diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py index 0eb335940..5cf42eca2 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py @@ -53,6 +53,25 @@ def get_clusterable_weights(self): class CustomNonClusterableLayer(layers.Dense): pass +class KerasCustomLayer(keras.layers.Layer): + def __init__(self, units=32): + super(KerasCustomLayer, self).__init__() + self.units = units + + def build(self, input_shape): + self.w = self.add_weight( + shape=(input_shape[-1], self.units), + initializer="random_normal", + trainable=True, + ) + self.b = self.add_weight( + shape=(self.units,), + initializer="random_normal", + trainable=False + ) + + def call(self, inputs): + return tf.matmul(inputs, self.w) + self.b class ClusterTest(test.TestCase, parameterized.TestCase): """Unit tests for the cluster module.""" @@ -66,6 +85,7 @@ def setUp(self): self.custom_clusterable_layer = CustomClusterableLayer(10) self.custom_non_clusterable_layer = CustomNonClusterableLayer(10) self.keras_depthwiseconv2d_layer = layers.DepthwiseConv2D((3, 3), (1, 1)) + self.keras_custom_layer = KerasCustomLayer() clustering_registry.ClusteringLookupRegistry.register_new_implementation( { @@ -179,6 +199,22 @@ def testClusterCustomNonClusterableLayer(self): cluster_wrapper.ClusterWeights(custom_non_clusterable_layer, **self.params) + def testClusterKerasCustomLayer(self): + """ + Verifies that attempting to cluster a keras custom layer raises + an exception. + """ + # If layer is not built, it has not weights, so + # we just skip it. + keras_custom_layer = self.keras_custom_layer + cluster_wrapper.ClusterWeights(keras_custom_layer, + **self.params) + # We need to build weights before check that clustering is not supported. + keras_custom_layer.build(input_shape=(10, 10)) + with self.assertRaises(ValueError): + cluster_wrapper.ClusterWeights(keras_custom_layer, + **self.params) + @keras_parameterized.run_all_keras_modes def testClusterSequentialModelSelectively(self): """