Skip to content

Commit

Permalink
Fixed clustering for Conv3D.
Browse files Browse the repository at this point in the history
Change-Id: Ic1ee16760d9643b8142e5644d6cb68a6890bc621
  • Loading branch information
wwwind committed Jan 5, 2021
1 parent 973f5b3 commit 19b5322
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 2 deletions.
Expand Up @@ -66,6 +66,8 @@ def setUp(self):
self.custom_clusterable_layer = CustomClusterableLayer(10)
self.custom_non_clusterable_layer = CustomNonClusterableLayer(10)
self.keras_depthwiseconv2d_layer = layers.DepthwiseConv2D((3, 3), (1, 1))
self.keras_conv2d_layer =tf.keras.layers.Conv2D(filters=3, kernel_size=(4, 5))
self.keras_conv3d_layer =tf.keras.layers.Conv3D(filters=2, kernel_size=(3, 4, 5))

clustering_registry.ClusteringLookupRegistry.register_new_implementation(
{
Expand Down Expand Up @@ -141,6 +143,38 @@ def testDepthwiseConv2DLayerNonClusterable(self):
wrapped_layer)
self.assertEqual([], wrapped_layer.layer.get_clusterable_weights())

@keras_parameterized.run_all_keras_modes
def testConv2DLayer(self):
"""
Verifies that we can cluster a Conv2D layer.
"""
input_shape =(4, 28, 28, 1)
wrapped_layer = self._build_clustered_layer_model(
self.keras_conv2d_layer,
input_shape=input_shape
)

self._validate_clustered_layer(self.keras_conv2d_layer,
wrapped_layer)
self.assertEqual([4, 5, 1, 3],
wrapped_layer.layer.get_clusterable_weights()[0][1].shape)

@keras_parameterized.run_all_keras_modes
def testConv3DLayer(self):
"""
Verifies that we can cluster a Conv3D layer.
"""
input_shape =(4, 28, 28, 28, 1)
wrapped_layer = self._build_clustered_layer_model(
self.keras_conv3d_layer,
input_shape=input_shape
)

self._validate_clustered_layer(self.keras_conv3d_layer,
wrapped_layer)
self.assertEqual([3, 4, 5, 1, 2],
wrapped_layer.layer.get_clusterable_weights()[0][1].shape)

def testClusterKerasUnsupportedLayer(self):
"""
Verifies that attempting to cluster an unsupported layer raises an
Expand Down
Expand Up @@ -133,6 +133,30 @@ def get_pulling_indices(self, weight):

return pulling_indices

class ConvolutionalWeights3DCA(AbstractClusteringAlgorithm):
"""
Look-ups for convolutional 3D kernels, e.g. tensors with shape [B,D1,D2,D3,C]
"""

def get_pulling_indices(self, weight):
clst_num = self.cluster_centroids.shape[0]
tiled_weights = tf.tile(tf.expand_dims(weight, 5), [1, 1, 1, 1, 1, clst_num])

# Do the ugly reshape to the clustering points
tiled_cluster_centroids = tf.stack(
[tf.tile(tf.stack(
[tf.reshape(self.cluster_centroids, [1, 1, 1, clst_num])] *
weight.shape[-2], axis=3),
[weight.shape[0], weight.shape[1], weight.shape[2], 1, 1])] * weight.shape[-1],
axis=4)

# We find the nearest cluster centroids and store them so that ops can build
# their kernels upon it
pulling_indices = tf.argmin(
tf.abs(tiled_weights - tiled_cluster_centroids), axis=5
)

return pulling_indices

class DenseWeightsCA(AbstractClusteringAlgorithm):
"""
Expand Down Expand Up @@ -182,8 +206,8 @@ class ClusteringLookupRegistry(object):
layers.Conv1D: {'kernel': ConvolutionalWeightsCA},
layers.Conv2D: {'kernel': ConvolutionalWeightsCA},
layers.Conv2DTranspose: {'kernel': ConvolutionalWeightsCA},
layers.Conv3D: {'kernel': ConvolutionalWeightsCA},
layers.Conv3DTranspose: {'kernel': ConvolutionalWeightsCA},
layers.Conv3D: {'kernel': ConvolutionalWeights3DCA},
layers.Conv3DTranspose: {'kernel': ConvolutionalWeights3DCA},
layers.SeparableConv1D: {'pointwise_kernel': ConvolutionalWeightsCA},
layers.SeparableConv2D: {'pointwise_kernel': ConvolutionalWeightsCA},
layers.Dense: {'kernel': DenseWeightsCA},
Expand Down
Expand Up @@ -204,6 +204,7 @@ def testGetClusteringImplFailsWithKnonwClassUnknownWeight(self):
@parameterized.parameters(
(layers.Conv2D, 'kernel', clustering_registry.ConvolutionalWeightsCA),
(layers.Conv1D, 'kernel', clustering_registry.ConvolutionalWeightsCA),
(layers.Conv3D, 'kernel', clustering_registry.ConvolutionalWeights3DCA),
)
def testReturnsResultsForKnownTypeKnownWeights(self,
layer_type,
Expand Down

0 comments on commit 19b5322

Please sign in to comment.