From 19b532235b806685749292a684ae7ccc9d2af66f Mon Sep 17 00:00:00 2001 From: Elena Zhelezina Date: Mon, 4 Jan 2021 19:00:41 +0000 Subject: [PATCH] Fixed clustering for Conv3D. Change-Id: Ic1ee16760d9643b8142e5644d6cb68a6890bc621 --- .../core/clustering/keras/cluster_test.py | 34 +++++++++++++++++++ .../clustering/keras/clustering_registry.py | 28 +++++++++++++-- .../keras/clustering_registry_test.py | 1 + 3 files changed, 61 insertions(+), 2 deletions(-) diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py index 0eb335940..44939ab2b 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py @@ -66,6 +66,8 @@ def setUp(self): self.custom_clusterable_layer = CustomClusterableLayer(10) self.custom_non_clusterable_layer = CustomNonClusterableLayer(10) self.keras_depthwiseconv2d_layer = layers.DepthwiseConv2D((3, 3), (1, 1)) + self.keras_conv2d_layer =tf.keras.layers.Conv2D(filters=3, kernel_size=(4, 5)) + self.keras_conv3d_layer =tf.keras.layers.Conv3D(filters=2, kernel_size=(3, 4, 5)) clustering_registry.ClusteringLookupRegistry.register_new_implementation( { @@ -141,6 +143,38 @@ def testDepthwiseConv2DLayerNonClusterable(self): wrapped_layer) self.assertEqual([], wrapped_layer.layer.get_clusterable_weights()) + @keras_parameterized.run_all_keras_modes + def testConv2DLayer(self): + """ + Verifies that we can cluster a Conv2D layer. 
+ """ + input_shape =(4, 28, 28, 1) + wrapped_layer = self._build_clustered_layer_model( + self.keras_conv2d_layer, + input_shape=input_shape + ) + + self._validate_clustered_layer(self.keras_conv2d_layer, + wrapped_layer) + self.assertEqual([4, 5, 1, 3], + wrapped_layer.layer.get_clusterable_weights()[0][1].shape) + + @keras_parameterized.run_all_keras_modes + def testConv3DLayer(self): + """ + Verifies that we can cluster a Conv3D layer. + """ + input_shape =(4, 28, 28, 28, 1) + wrapped_layer = self._build_clustered_layer_model( + self.keras_conv3d_layer, + input_shape=input_shape + ) + + self._validate_clustered_layer(self.keras_conv3d_layer, + wrapped_layer) + self.assertEqual([3, 4, 5, 1, 2], + wrapped_layer.layer.get_clusterable_weights()[0][1].shape) + def testClusterKerasUnsupportedLayer(self): """ Verifies that attempting to cluster an unsupported layer raises an diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py index e36ae8f56..a5ed54e0b 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py @@ -133,6 +133,30 @@ def get_pulling_indices(self, weight): return pulling_indices +class ConvolutionalWeights3DCA(AbstractClusteringAlgorithm): + """ + Look-ups for convolutional 3D kernels, e.g. 
tensors with shape [D1,D2,D3,C_in,C_out] + """ + + def get_pulling_indices(self, weight): + clst_num = self.cluster_centroids.shape[0] + tiled_weights = tf.tile(tf.expand_dims(weight, 5), [1, 1, 1, 1, 1, clst_num]) + + # Do the ugly reshape to the clustering points + tiled_cluster_centroids = tf.stack( + [tf.tile(tf.stack( + [tf.reshape(self.cluster_centroids, [1, 1, 1, clst_num])] * + weight.shape[-2], axis=3), + [weight.shape[0], weight.shape[1], weight.shape[2], 1, 1])] * weight.shape[-1], + axis=4) + + # We find the nearest cluster centroids and store them so that ops can build + # their kernels upon it + pulling_indices = tf.argmin( + tf.abs(tiled_weights - tiled_cluster_centroids), axis=5 + ) + + return pulling_indices class DenseWeightsCA(AbstractClusteringAlgorithm): """ @@ -182,8 +206,8 @@ class ClusteringLookupRegistry(object): layers.Conv1D: {'kernel': ConvolutionalWeightsCA}, layers.Conv2D: {'kernel': ConvolutionalWeightsCA}, layers.Conv2DTranspose: {'kernel': ConvolutionalWeightsCA}, - layers.Conv3D: {'kernel': ConvolutionalWeightsCA}, - layers.Conv3DTranspose: {'kernel': ConvolutionalWeightsCA}, + layers.Conv3D: {'kernel': ConvolutionalWeights3DCA}, + layers.Conv3DTranspose: {'kernel': ConvolutionalWeights3DCA}, layers.SeparableConv1D: {'pointwise_kernel': ConvolutionalWeightsCA}, layers.SeparableConv2D: {'pointwise_kernel': ConvolutionalWeightsCA}, layers.Dense: {'kernel': DenseWeightsCA}, diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry_test.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry_test.py index ac60fe1a5..fca72ab03 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry_test.py @@ -204,6 +204,7 @@ def testGetClusteringImplFailsWithKnonwClassUnknownWeight(self): @parameterized.parameters( (layers.Conv2D, 'kernel', 
clustering_registry.ConvolutionalWeightsCA), (layers.Conv1D, 'kernel', clustering_registry.ConvolutionalWeightsCA), + (layers.Conv3D, 'kernel', clustering_registry.ConvolutionalWeights3DCA), ) def testReturnsResultsForKnownTypeKnownWeights(self, layer_type,