From 2e70a9e8d53bb4ef6d3a4c0fe1c2799b39e8088b Mon Sep 17 00:00:00 2001 From: Aron Virginas-Tar Date: Fri, 29 May 2020 11:55:33 +0100 Subject: [PATCH] Improve configuration object for clustering API * Added enum CentroidInitialization to be used in clustering params to specify centroid initialization method (instead of a string) * Updated serialization and unit tests accordingly --- .../python/core/clustering/keras/BUILD | 14 ++++ .../python/core/clustering/keras/cluster.py | 24 +++--- .../core/clustering/keras/cluster_config.py | 23 ++++++ .../keras/cluster_integration_test.py | 5 +- .../core/clustering/keras/cluster_test.py | 4 +- .../core/clustering/keras/cluster_wrapper.py | 4 +- .../clustering/keras/cluster_wrapper_test.py | 76 ++++++++++++------- .../clustering/keras/clustering_centroids.py | 19 +++-- .../keras/clustering_centroids_test.py | 21 +++-- 9 files changed, 138 insertions(+), 52 deletions(-) create mode 100644 tensorflow_model_optimization/python/core/clustering/keras/cluster_config.py diff --git a/tensorflow_model_optimization/python/core/clustering/keras/BUILD b/tensorflow_model_optimization/python/core/clustering/keras/BUILD index c5bbcd7af..1d97c1437 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/BUILD +++ b/tensorflow_model_optimization/python/core/clustering/keras/BUILD @@ -23,12 +23,20 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ + ":cluster_config", ":cluster_wrapper", ":clustering_centroids", ":clustering_registry", ], ) +py_library( + name = "cluster_config", + srcs = ["cluster_config.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], +) + py_library( name = "clustering_registry", srcs = ["clustering_registry.py"], @@ -51,6 +59,9 @@ py_library( srcs = ["clustering_centroids.py"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], + deps = [ + ":cluster_config", + ], ) py_library( @@ -58,6 +69,9 @@ py_library( srcs = ["cluster_wrapper.py"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], + deps = [ + ":cluster_config", + ], ) py_test( diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster.py index cd3506aee..45357c7da 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster.py @@ -17,8 +17,9 @@ from tensorflow import keras from tensorflow.keras import initializers -from tensorflow_model_optimization.python.core.clustering.keras import clustering_centroids +from tensorflow_model_optimization.python.core.clustering.keras import cluster_config from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper +from tensorflow_model_optimization.python.core.clustering.keras import clustering_centroids k = keras.backend CustomObjectScope = keras.utils.CustomObjectScope @@ -47,7 +48,7 @@ def cluster_scope(): """ return CustomObjectScope( { - 'ClusterWeights': cluster_wrapper.ClusterWeights + 'ClusterWeights' : cluster_wrapper.ClusterWeights } ) @@ -79,7 +80,8 @@ def cluster_weights(to_cluster, ```python clustering_params = { 'number_of_clusters': 8, - 'cluster_centroids_init': 'density-based' + 'cluster_centroids_init': + cluster_config.CentroidInitialization.DENSITY_BASED } clustered_model = cluster_weights(original_model, **clustering_params) @@ -90,7 +92,8 @@ def cluster_weights(to_cluster, ```python clustering_params = { 'number_of_clusters': 8, - 'cluster_centroids_init': 'density-based' + 'cluster_centroids_init': + cluster_config.CentroidInitialization.DENSITY_BASED } model = keras.Sequential([ @@ -105,15 +108,16 @@ def cluster_weights(to_cluster, number_of_clusters: the number of cluster centroids to form when clustering a layer/model. For example, if number_of_clusters=8 then only 8 unique values will be used in each weight array. - cluster_centroids_init: how to initialize the cluster centroids. + cluster_centroids_init: enum value that determines how the cluster + centroids will be initialized. Can have following values: - 1. 'random' : centroids are sampled using the uniform distribution + 1. RANDOM : centroids are sampled using the uniform distribution between the minimum and maximum weight values in a given layer - 2. 'density-based' : density-based sampling. First, cumulative + 2. DENSITY_BASED : density-based sampling. First, cumulative distribution function is built for weights, then y-axis is evenly spaced into number_of_clusters regions. After this the corresponding x values are obtained and used to initialize clusters centroids. - 3. 'linear' : cluster centroids are evenly spaced between the minimum + 3. LINEAR : cluster centroids are evenly spaced between the minimum and maximum values of a given weight **kwargs: Additional keyword arguments to be passed to the keras layer. Ignored when to_cluster is not a keras layer. @@ -127,8 +131,8 @@ def cluster_weights(to_cluster, """ if not clustering_centroids.CentroidsInitializerFactory.\ init_is_supported(cluster_centroids_init): - raise ValueError("cluster centroids can only be one of three values: " - "random, density-based, linear") + raise ValueError("Cluster centroid initialization {} not supported".\ + format(cluster_centroids_init)) def _add_clustering_wrapper(layer): if isinstance(layer, cluster_wrapper.ClusterWeights): diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_config.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_config.py new file mode 100644 index 000000000..75c0f6653 --- /dev/null +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_config.py @@ -0,0 +1,23 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Configuration classes for clustering.""" + +import enum + + +class CentroidInitialization(str, enum.Enum): + LINEAR = "LINEAR" + RANDOM = "RANDOM" + DENSITY_BASED = "DENSITY_BASED" diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py index 4db369f21..1c9af9f29 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py @@ -19,12 +19,15 @@ from absl.testing import parameterized from tensorflow.python.keras import keras_parameterized + from tensorflow_model_optimization.python.core.clustering.keras import cluster +from tensorflow_model_optimization.python.core.clustering.keras import cluster_config keras = tf.keras layers = keras.layers test = tf.test +CentroidInitialization = cluster_config.CentroidInitialization class ClusterIntegrationTest(test.TestCase, parameterized.TestCase): """Integration tests for clustering.""" @@ -43,7 +46,7 @@ def testValuesRemainClusteredAfterTraining(self): clustered_model = cluster.cluster_weights( original_model, number_of_clusters=number_of_clusters, - cluster_centroids_init='linear' + cluster_centroids_init=CentroidInitialization.LINEAR ) clustered_model.compile( diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py index 38c805973..7e2f0b949 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py @@ -21,6 +21,7 @@ from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.clustering.keras import cluster +from tensorflow_model_optimization.python.core.clustering.keras import cluster_config from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry @@ -76,7 +77,8 @@ def setUp(self): self.model = keras.Sequential() self.params = { 'number_of_clusters': 8, - 'cluster_centroids_init': 'density-based' + 'cluster_centroids_init': + cluster_config.CentroidInitialization.DENSITY_BASED } def _build_clustered_layer_model(self, layer): diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py index d5385c4d9..37e3681a5 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py @@ -18,6 +18,7 @@ from tensorflow.keras import initializers +from tensorflow_model_optimization.python.core.clustering.keras import cluster_config from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer from tensorflow_model_optimization.python.core.clustering.keras import clustering_centroids from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry @@ -273,7 +274,8 @@ def from_config(cls, config, custom_objects=None): number_of_clusters = config.pop('number_of_clusters') cluster_centroids_init = config.pop('cluster_centroids_init') config['number_of_clusters'] = number_of_clusters - config['cluster_centroids_init'] = cluster_centroids_init + config['cluster_centroids_init'] = cluster_config.CentroidInitialization( + cluster_centroids_init) from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top layer = deserialize_layer(config.pop('layer'), diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py index 07b10ad5c..68aad7241 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py @@ -21,6 +21,7 @@ from absl.testing import parameterized from tensorflow_model_optimization.python.core.clustering.keras import cluster +from tensorflow_model_optimization.python.core.clustering.keras import cluster_config from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry @@ -30,7 +31,7 @@ layers = keras.layers test = tf.test -layers = keras.layers +CentroidInitialization = cluster_config.CentroidInitialization ClusterRegistry = clustering_registry.ClusteringRegistry ClusteringLookupRegistry = clustering_registry.ClusteringLookupRegistry @@ -55,9 +56,11 @@ def testCannotBeInitializedWithNonLayerObject(self): not an instance of keras.layers.Layer. """ with self.assertRaises(ValueError): - cluster_wrapper.ClusterWeights({ - 'this': 'is not a Layer instance' - }, number_of_clusters=13, cluster_centroids_init='linear') + cluster_wrapper.ClusterWeights( + {'this': 'is not a Layer instance'}, + number_of_clusters=13, + cluster_centroids_init=CentroidInitialization.LINEAR + ) def testCannotBeInitializedWithNonClusterableLayer(self): """ @@ -65,18 +68,22 @@ def testCannotBeInitializedWithNonClusterableLayer(self): custom layer. """ with self.assertRaises(ValueError): - cluster_wrapper.ClusterWeights(NonClusterableLayer(10), - number_of_clusters=13, - cluster_centroids_init='linear') + cluster_wrapper.ClusterWeights( + NonClusterableLayer(10), + number_of_clusters=13, + cluster_centroids_init=CentroidInitialization.LINEAR + ) def testCanBeInitializedWithClusterableLayer(self): """ Verifies that ClusterWeights can be initialized with a built-in clusterable layer. """ - l = cluster_wrapper.ClusterWeights(layers.Dense(10), - number_of_clusters=13, - cluster_centroids_init='linear') + l = cluster_wrapper.ClusterWeights( + layers.Dense(10), + number_of_clusters=13, + cluster_centroids_init=CentroidInitialization.LINEAR + ) self.assertIsInstance(l, cluster_wrapper.ClusterWeights) def testCannotBeInitializedWithNonIntegerNumberOfClusters(self): @@ -85,9 +92,11 @@ def testCannotBeInitializedWithNonIntegerNumberOfClusters(self): provided for the number of clusters. """ with self.assertRaises(ValueError): - cluster_wrapper.ClusterWeights(layers.Dense(10), - number_of_clusters="13", - cluster_centroids_init='linear') + cluster_wrapper.ClusterWeights( + layers.Dense(10), + number_of_clusters="13", + cluster_centroids_init=CentroidInitialization.LINEAR + ) def testCannotBeInitializedWithFloatNumberOfClusters(self): """ @@ -95,9 +104,11 @@ def testCannotBeInitializedWithFloatNumberOfClusters(self): provided for the number of clusters. """ with self.assertRaises(ValueError): - cluster_wrapper.ClusterWeights(layers.Dense(10), - number_of_clusters=13.4, - cluster_centroids_init='linear') + cluster_wrapper.ClusterWeights( + layers.Dense(10), + number_of_clusters=13.4, + cluster_centroids_init=CentroidInitialization.LINEAR + ) @parameterized.parameters( (0), @@ -111,9 +122,11 @@ def testCannotBeInitializedWithNumberOfClustersLessThanTwo( clusters. """ with self.assertRaises(ValueError): - cluster_wrapper.ClusterWeights(layers.Dense(10), - number_of_clusters=number_of_clusters, - cluster_centroids_init='linear') + cluster_wrapper.ClusterWeights( + layers.Dense(10), + number_of_clusters=number_of_clusters, + cluster_centroids_init=CentroidInitialization.LINEAR + ) def testCanBeInitializedWithAlreadyClusterableLayer(self): """ @@ -121,9 +134,11 @@ def testCanBeInitializedWithAlreadyClusterableLayer(self): layer. """ layer = AlreadyClusterableLayer(10) - l = cluster_wrapper.ClusterWeights(layer, - number_of_clusters=13, - cluster_centroids_init='linear') + l = cluster_wrapper.ClusterWeights( + layer, + number_of_clusters=13, + cluster_centroids_init=CentroidInitialization.LINEAR + ) self.assertIsInstance(l, cluster_wrapper.ClusterWeights) def testIfLayerHasBatchShapeClusterWeightsMustHaveIt(self): @@ -131,14 +146,23 @@ def testIfLayerHasBatchShapeClusterWeightsMustHaveIt(self): Verifies that the ClusterWeights instance created from a layer that has a batch shape attribute, will also have this attribute. """ - l = cluster_wrapper.ClusterWeights(layers.Dense(10, input_shape=(10,)), - number_of_clusters=13, - cluster_centroids_init='linear') + l = cluster_wrapper.ClusterWeights( + layers.Dense(10, input_shape=(10,)), + number_of_clusters=13, + cluster_centroids_init=CentroidInitialization.LINEAR + ) self.assertTrue(hasattr(l, '_batch_input_shape')) # Makes it easier to test all possible parameters combinations. @parameterized.parameters( - *itertools.product(range(2, 16, 4), ('linear', 'random', 'density-based')) + *itertools.product( + range(2, 16, 4), + ( + CentroidInitialization.LINEAR, + CentroidInitialization.RANDOM, + CentroidInitialization.DENSITY_BASED + ) + ) ) def testValuesAreClusteredAfterStripping(self, number_of_clusters, diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py index 5a62edab8..d24f5a17a 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py @@ -18,8 +18,10 @@ import six import tensorflow as tf -k = tf.keras.backend +from tensorflow_model_optimization.python.core.clustering.keras import cluster_config +k = tf.keras.backend +CentroidInitialization = cluster_config.CentroidInitialization @six.add_metaclass(abc.ABCMeta) class AbstractCentroidsInitialisation: @@ -187,9 +189,10 @@ class CentroidsInitializerFactory: reflect new methods available. """ _initialisers = { - 'linear': LinearCentroidsInitialisation, - 'random': RandomCentroidsInitialisation, - 'density-based': DensityBasedCentroidsInitialisation + CentroidInitialization.LINEAR : LinearCentroidsInitialisation, + CentroidInitialization.RANDOM : RandomCentroidsInitialisation, + CentroidInitialization.DENSITY_BASED : + DensityBasedCentroidsInitialisation } @classmethod @@ -199,9 +202,11 @@ def init_is_supported(cls, init_method): @classmethod def get_centroid_initializer(cls, init_method): """ - :param init_method: a string representation of the init methods requested - :return: A concrete implementation of AbstractCentroidsInitialisation - :raises: ValueError if the string representation is not recognised + :param init_method: a CentroidInitialization value representing the init + method requested + :return: A concrete implementation of AbstractCentroidsInitialisation + :raises: ValueError if the requested centroid initialization method is not + recognised """ if not cls.init_is_supported(init_method): raise ValueError( diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py index a271dba5a..693276b6d 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py @@ -19,6 +19,7 @@ from absl.testing import parameterized +from tensorflow_model_optimization.python.core.clustering.keras import cluster_config from tensorflow_model_optimization.python.core.clustering.keras import clustering_centroids keras = tf.keras @@ -26,6 +27,8 @@ layers = keras.layers test = tf.test +CentroidInitialization = cluster_config.CentroidInitialization + class ClusteringCentroidsTest(test.TestCase, parameterized.TestCase): """Unit tests for the clustering_centroids module.""" @@ -34,9 +37,9 @@ def setUp(self): self.factory = clustering_centroids.CentroidsInitializerFactory @parameterized.parameters( - ('linear'), - ('random'), - ('density-based'), + (CentroidInitialization.LINEAR), + (CentroidInitialization.RANDOM), + (CentroidInitialization.DENSITY_BASED), ) def testExistingInitsAreSupported(self, init_type): """ @@ -48,10 +51,16 @@ def testNonExistingInitIsNotSupported(self): self.assertFalse(self.factory.init_is_supported("DEADBEEF")) @parameterized.parameters( - ('linear', clustering_centroids.LinearCentroidsInitialisation), - ('random', clustering_centroids.RandomCentroidsInitialisation), ( - 'density-based', + CentroidInitialization.LINEAR, + clustering_centroids.LinearCentroidsInitialisation + ), + ( + CentroidInitialization.RANDOM, + clustering_centroids.RandomCentroidsInitialisation + ), + ( + CentroidInitialization.DENSITY_BASED, clustering_centroids.DensityBasedCentroidsInitialisation ), )