diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py index 6e3319211..0c71f16f8 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py @@ -124,6 +124,7 @@ def setUp(self): self.custom_clusterable_layer = CustomClusterableLayer(10) self.custom_non_clusterable_layer = CustomNonClusterableLayer(10) self.keras_depthwiseconv2d_layer = layers.DepthwiseConv2D((3, 3), (1, 1)) + self.clusterable_layer = MyClusterableLayer(10) self.keras_custom_layer = KerasCustomLayer() self.clusterable_layer = MyClusterableLayer(10) @@ -242,6 +243,34 @@ def testClusterCustomNonClusterableLayer(self): cluster_wrapper.ClusterWeights(custom_non_clusterable_layer, **self.params) + def testClusterMyClusterableLayer(self): + # we have weights to cluster. + clusterable_layer = self.clusterable_layer + clusterable_layer.build(input_shape=(10, 10)) + + wrapped_layer = cluster_wrapper.ClusterWeights(clusterable_layer, + **self.params) + + self.assertIsInstance(wrapped_layer, cluster_wrapper.ClusterWeights) + + def testKerasCustomLayerClusterable(self): + """ + Verifies that we can wrap keras custom layer that is customerable. + """ + clusterable_layer = KerasCustomLayerClusterable() + wrapped_layer = cluster_wrapper.ClusterWeights(clusterable_layer, + **self.params) + + self.assertIsInstance(wrapped_layer, cluster_wrapper.ClusterWeights) + + def testClusterMyClusterableLayerInvalid(self): + """ + Verifies that assertion is thrown when function + get_clusterable_weights is not provided. + """ + with self.assertRaises(TypeError): + MyClusterableLayerInvalid(10) # pylint: disable=abstract-class-instantiated + def testClusterKerasCustomLayer(self): """Verifies that attempting to cluster a keras custom layer raises an exception.""" # If layer is not built, it has not weights, so @@ -275,8 +304,7 @@ def testClusterMyClusterableLayerInvalid(self): @keras_parameterized.run_all_keras_modes def testClusterSequentialModelSelectively(self): clustered_model = keras.Sequential() - clustered_model.add( - cluster.cluster_weights(self.keras_clusterable_layer, **self.params)) + clustered_model.add(cluster.cluster_weights(self.keras_clusterable_layer, **self.params)) clustered_model.add(self.keras_clusterable_layer) clustered_model.build(input_shape=(1, 10)) diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_algorithm.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_algorithm.py index 1be5da6c5..955044dbf 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/clustering_algorithm.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_algorithm.py @@ -21,7 +21,7 @@ @six.add_metaclass(abc.ABCMeta) class AbstractClusteringAlgorithm(object): - """Abstrac class to implement highly efficient vectorised look-ups. + """Abstract class to implement highly efficient vectorised look-ups. We do not utilise looping for that purpose, instead we `smartly` reshape and tile arrays. The trade-off is that we are potentially using way more memory