From 2f9d87438c63a1160b8546679756b1c5766ff578 Mon Sep 17 00:00:00 2001 From: Elena Zhelezina Date: Tue, 13 Apr 2021 13:44:30 +0100 Subject: [PATCH] Small fixes after merging with master. Change-Id: I51ce035855ff8e82c339f1cd260b7f5891050ab7 --- .../python/core/clustering/keras/cluster_test.py | 1 - .../python/core/clustering/keras/clusterable_layer.py | 2 -- .../python/core/clustering/keras/clustering_registry.py | 2 +- 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py index 4f9703f46..0c71f16f8 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py @@ -128,7 +128,6 @@ def setUp(self): self.keras_custom_layer = KerasCustomLayer() self.clusterable_layer = MyClusterableLayer(10) - clustering_registry.ClusteringLookupRegistry.register_new_implementation( { CustomClusterableLayer: { diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clusterable_layer.py b/tensorflow_model_optimization/python/core/clustering/keras/clusterable_layer.py index 5115442d5..0d4ce9306 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/clusterable_layer.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/clusterable_layer.py @@ -55,10 +55,8 @@ def get_clusterable_algorithm(self, weight_name): # pylint: disable=unused-argu The returned class should be derived from AbstractClusteringAlgorithm and implements the function get_pulling_indices. - This function is used to provide a special lookup function for the custom weights. - It reshapes and tile centroids the same way as the weights. This allows us to find pulling indices efficiently. diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py index a380e9397..796edba75 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py @@ -23,7 +23,6 @@ AbstractClusteringAlgorithm = clustering_algorithm.AbstractClusteringAlgorithm - class ConvolutionalWeightsCA(AbstractClusteringAlgorithm): """Look-ups for convolutional kernels, e.g. tensors with shape [B,W,H,C].""" @@ -80,6 +79,7 @@ def get_pulling_indices(self, weight): return pulling_indices + class ClusteringLookupRegistry(object): """Map of layers to strategy.