diff --git a/tensorflow_model_optimization/python/core/api/clustering/keras/__init__.py b/tensorflow_model_optimization/python/core/api/clustering/keras/__init__.py index c1be4b95b..8e28f9ff5 100644 --- a/tensorflow_model_optimization/python/core/api/clustering/keras/__init__.py +++ b/tensorflow_model_optimization/python/core/api/clustering/keras/__init__.py @@ -14,6 +14,8 @@ # ============================================================================== """Module containing clustering code built on Keras abstractions.""" # pylint: disable=g-bad-import-order +from tensorflow_model_optimization.python.core.clustering.keras import experimental + from tensorflow_model_optimization.python.core.clustering.keras.cluster import cluster_scope from tensorflow_model_optimization.python.core.clustering.keras.cluster import cluster_weights from tensorflow_model_optimization.python.core.clustering.keras.cluster import strip_clustering diff --git a/tensorflow_model_optimization/python/core/api/clustering/keras/experimental/__init__.py b/tensorflow_model_optimization/python/core/api/clustering/keras/experimental/__init__.py new file mode 100644 index 000000000..bf3f3241c --- /dev/null +++ b/tensorflow_model_optimization/python/core/api/clustering/keras/experimental/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Module containing experimental clustering code built on Keras abstractions.""" +from tensorflow_model_optimization.python.core.clustering.keras.experimental.cluster import cluster_weights diff --git a/tensorflow_model_optimization/python/core/clustering/keras/BUILD b/tensorflow_model_optimization/python/core/clustering/keras/BUILD index 72d4150c6..0aa6fb1cc 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/BUILD +++ b/tensorflow_model_optimization/python/core/clustering/keras/BUILD @@ -12,6 +12,7 @@ py_library( srcs_version = "PY3", deps = [ ":cluster", + "//tensorflow_model_optimization/python/core/clustering/keras/experimental", ], ) @@ -90,6 +91,7 @@ py_test( visibility = ["//visibility:public"], deps = [ ":cluster", + "//tensorflow_model_optimization/python/core/clustering/keras/experimental:cluster", # tensorflow dep1, ], ) @@ -146,6 +148,7 @@ py_test( ":cluster", # tensorflow dep1, "//tensorflow_model_optimization/python/core/keras:compat", + "//tensorflow_model_optimization/python/core/clustering/keras/experimental:cluster", ], ) diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster.py index e8cebe0f9..f2ae873ba 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster.py @@ -80,8 +80,7 @@ def cluster_weights(to_cluster, ```python clustering_params = { 'number_of_clusters': 8, - 'cluster_centroids_init': - CentroidInitialization.DENSITY_BASED + 'cluster_centroids_init': CentroidInitialization.DENSITY_BASED } clustered_model = cluster_weights(original_model, **clustering_params) @@ -92,8 +91,108 @@ def cluster_weights(to_cluster, ```python clustering_params = { 'number_of_clusters': 8, - 'cluster_centroids_init': - CentroidInitialization.DENSITY_BASED + 
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED + } + + model = keras.Sequential([ + layers.Dense(10, activation='relu', input_shape=(100,)), + cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params) + ]) + ``` + + Arguments: + to_cluster: A single keras layer, list of keras layers, or a + `tf.keras.Model` instance. + number_of_clusters: the number of cluster centroids to form when + clustering a layer/model. For example, if number_of_clusters=8 then only + 8 unique values will be used in each weight array. + cluster_centroids_init: enum value that determines how the cluster + centroids will be initialized. + Can have following values: + 1. RANDOM : centroids are sampled using the uniform distribution + between the minimum and maximum weight values in a given layer + 2. DENSITY_BASED : density-based sampling. First, cumulative + distribution function is built for weights, then y-axis is evenly + spaced into number_of_clusters regions. After this the corresponding x + values are obtained and used to initialize clusters centroids. + 3. LINEAR : cluster centroids are evenly spaced between the minimum + and maximum values of a given weight + preserve_sparsity: optional boolean value that determines whether or not + sparsity preservation will be enforced during training + **kwargs: Additional keyword arguments to be passed to the keras layer. + Ignored when to_cluster is not a keras layer. + + Returns: + Layer or model modified to include clustering related metadata. + + Raises: + ValueError: if the keras layer is unsupported, or the keras model contains + an unsupported layer. + """ + return _cluster_weights(to_cluster, + number_of_clusters, + cluster_centroids_init, + preserve_sparsity=False, + **kwargs) + + +def _cluster_weights(to_cluster, + number_of_clusters, + cluster_centroids_init, + preserve_sparsity, + **kwargs): + """Modify a keras layer or model to be clustered during training (private method). 
+ + This function wraps a keras model or layer with clustering functionality + which clusters the layer's weights during training. For examples, using + this with number_of_clusters equals 8 will ensure that each weight tensor has + no more than 8 unique values. + + Before passing to the clustering API, a model should already be trained and + show some acceptable performance on the testing/validation sets. + + The function accepts either a single keras layer + (subclass of `keras.layers.Layer`), list of keras layers or a keras model + (instance of `keras.models.Model`) and handles them appropriately. + + If it encounters a layer it does not know how to handle, it will throw an + error. While clustering an entire model, even a single unknown layer would + lead to an error. + + Cluster a model: + + ```python + clustering_params = { + 'number_of_clusters': 8, + 'cluster_centroids_init': CentroidInitialization.DENSITY_BASED, + 'preserve_sparsity': False + } + + clustered_model = cluster_weights(original_model, **clustering_params) + ``` + + Cluster a layer: + + ```python + clustering_params = { + 'number_of_clusters': 8, + 'cluster_centroids_init': CentroidInitialization.DENSITY_BASED, + 'preserve_sparsity': False + } + + model = keras.Sequential([ + layers.Dense(10, activation='relu', input_shape=(100,)), + cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params) + ]) + ``` + + Cluster a layer with sparsity preservation (experimental): + + ```python + clustering_params = { + 'number_of_clusters': 8, + 'cluster_centroids_init': CentroidInitialization.DENSITY_BASED, + 'preserve_sparsity': True } model = keras.Sequential([ @@ -110,6 +209,8 @@ def cluster_weights(to_cluster, 8 unique values will be used in each weight array. cluster_centroids_init: `tfmot.clustering.keras.CentroidInitialization` instance that determines how the cluster centroids will be initialized. 
+ preserve_sparsity (experimental): optional boolean value that determines whether or not + sparsity preservation will be enforced during training. **kwargs: Additional keyword arguments to be passed to the keras layer. Ignored when to_cluster is not a keras layer. @@ -146,6 +247,7 @@ def _add_clustering_wrapper(layer): return cluster_wrapper.ClusterWeights(layer, number_of_clusters, cluster_centroids_init, + preserve_sparsity, **kwargs) def _wrap_list(layers): diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py index e1cbbb171..3c677e94f 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py @@ -26,6 +26,8 @@ from tensorflow_model_optimization.python.core.clustering.keras import cluster_config from tensorflow_model_optimization.python.core.keras import compat +from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster + keras = tf.keras layers = keras.layers test = tf.test @@ -60,10 +62,26 @@ def setUp(self): dtype="float32", ) + self.x_train2 = np.array( + [[0.0, 1.0, 2.0, 3.0, 4.0], [2.0, 0.0, 2.0, 3.0, 4.0], [0.0, 3.0, 2.0, 3.0, 4.0], + [4.0, 1.0, 2.0, 3.0, 4.0], [5.0, 1.0, 2.0, 3.0, 4.0]], + dtype="float32", + ) + + self.y_train2 = np.array( + [[0.0, 1.0, 2.0, 3.0, 4.0], [1.0, 0.0, 2.0, 3.0, 4.0], [1.0, 0.0, 2.0, 3.0, 4.0], + [0.0, 1.0, 2.0, 3.0, 4.0], [0.0, 1.0, 2.0, 3.0, 4.0]], + dtype="float32", + ) + def dataset_generator(self): for x, y in zip(self.x_train, self.y_train): yield np.array([x]), np.array([y]) + def dataset_generator2(self): + for x, y in zip(self.x_train2, self.y_train2): + yield np.array([x]), np.array([y]) + def end_to_end_testing(self, original_model, clusters_check=None): """Test End to End clustering.""" @@ 
-128,6 +146,50 @@ def testValuesRemainClusteredAfterTraining(self): unique_weights = set(weights_as_list) self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"]) + @keras_parameterized.run_all_keras_modes + def testSparsityIsPreservedDuringTraining(self): + """Verifies that training a clustered model does not destroy the sparsity of the weights.""" + + """Set a specific random seed to ensure that we get some null weights to test sparsity preservation with.""" + tf.random.set_seed(1) + original_model = keras.Sequential([ + layers.Dense(5, input_shape=(5,)), + layers.Dense(5), + ]) + + """Using a minimum number of centroids to make it more likely that some weights will be zero.""" + clustering_params = { + "number_of_clusters": 3, + "cluster_centroids_init": CentroidInitialization.LINEAR, + "preserve_sparsity": True + } + + clustered_model = experimental_cluster.cluster_weights(original_model, **clustering_params) + + stripped_model_before_tuning = cluster.strip_clustering(clustered_model) + weights_before_tuning = stripped_model_before_tuning.get_weights()[0] + non_zero_weight_indices_before_tuning = np.nonzero(weights_before_tuning) + + clustered_model.compile( + loss=keras.losses.categorical_crossentropy, + optimizer="adam", + metrics=["accuracy"], + ) + clustered_model.fit(x=self.dataset_generator2(), steps_per_epoch=1) + + stripped_model_after_tuning = cluster.strip_clustering(clustered_model) + weights_after_tuning = stripped_model_after_tuning.get_weights()[0] + non_zero_weight_indices_after_tuning = np.nonzero(weights_after_tuning) + weights_as_list_after_tuning = weights_after_tuning.reshape(-1,).tolist() + unique_weights_after_tuning = set(weights_as_list_after_tuning) + + """Check that the null weights stayed the same before and after tuning.""" + self.assertTrue(np.array_equal(non_zero_weight_indices_before_tuning, + non_zero_weight_indices_after_tuning)) + + """Check that the number of unique weights does not exceed the number of 
clusters used for this model.""" + self.assertLessEqual(len(unique_weights_after_tuning), clustering_params["number_of_clusters"]) + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) def testEndToEndSequential(self): """Test End to End clustering - sequential model.""" diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py index 0eb335940..ed239f50a 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py @@ -26,6 +26,8 @@ from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry +from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster + keras = tf.keras errors_impl = tf.errors layers = keras.layers @@ -112,6 +114,18 @@ def testClusterKerasClusterableLayer(self): self._validate_clustered_layer(self.keras_clusterable_layer, wrapped_layer) + @keras_parameterized.run_all_keras_modes + def testClusterKerasClusterableLayerWithSparsityPreservation(self): + """ + Verifies that a built-in keras layer marked as clusterable is being + clustered correctly when sparsity preservation is enabled. 
+ """ + preserve_sparsity_params = { 'preserve_sparsity': True } + params = { **self.params, **preserve_sparsity_params } + wrapped_layer = experimental_cluster.cluster_weights(self.keras_clusterable_layer, **params) + + self._validate_clustered_layer(self.keras_clusterable_layer, wrapped_layer) + @keras_parameterized.run_all_keras_modes def testClusterKerasNonClusterableLayer(self): """ @@ -164,6 +178,22 @@ def testClusterCustomClusterableLayer(self): self.assertEqual([('kernel', wrapped_layer.layer.kernel)], wrapped_layer.layer.get_clusterable_weights()) + @keras_parameterized.run_all_keras_modes + def testClusterCustomClusterableLayerWithSparsityPreservation(self): + """ + Verifies that a custom clusterable layer is being clustered correctly + when sparsity preservation is enabled. + """ + preserve_sparsity_params = { 'preserve_sparsity': True } + params = { **self.params, **preserve_sparsity_params } + wrapped_layer = experimental_cluster.cluster_weights(self.custom_clusterable_layer, **params) + self.model.add(wrapped_layer) + self.model.build(input_shape=(10, 1)) + + self._validate_clustered_layer(self.custom_clusterable_layer, wrapped_layer) + self.assertEqual([('kernel', wrapped_layer.layer.kernel)], + wrapped_layer.layer.get_clusterable_weights()) + def testClusterCustomNonClusterableLayer(self): """ Verifies that attempting to cluster a custom non-clusterable layer raises @@ -193,6 +223,22 @@ def testClusterSequentialModelSelectively(self): self.assertIsInstance(clustered_model.layers[0], cluster_wrapper.ClusterWeights) self.assertNotIsInstance(clustered_model.layers[1], cluster_wrapper.ClusterWeights) + @keras_parameterized.run_all_keras_modes + def testClusterSequentialModelSelectivelyWithSparsityPreservation(self): + """ + Verifies that layers within a sequential model can be clustered + selectively when sparsity preservation is enabled. 
+ """ + preserve_sparsity_params = { 'preserve_sparsity': True } + params = { **self.params, **preserve_sparsity_params } + clustered_model = keras.Sequential() + clustered_model.add(experimental_cluster.cluster_weights(self.keras_clusterable_layer, **params)) + clustered_model.add(self.keras_clusterable_layer) + clustered_model.build(input_shape=(1, 10)) + + self.assertIsInstance(clustered_model.layers[0], cluster_wrapper.ClusterWeights) + self.assertNotIsInstance(clustered_model.layers[1], cluster_wrapper.ClusterWeights) + @keras_parameterized.run_all_keras_modes def testClusterFunctionalModelSelectively(self): """ @@ -209,6 +255,24 @@ def testClusterFunctionalModelSelectively(self): self.assertIsInstance(clustered_model.layers[2], cluster_wrapper.ClusterWeights) self.assertNotIsInstance(clustered_model.layers[3], cluster_wrapper.ClusterWeights) + @keras_parameterized.run_all_keras_modes + def testClusterFunctionalModelSelectivelyWithSparsityPreservation(self): + """ + Verifies that layers within a functional model can be clustered + selectively when sparsity preservation is enabled. 
+ """ + preserve_sparsity_params = { 'preserve_sparsity': True } + params = { **self.params, **preserve_sparsity_params } + i1 = keras.Input(shape=(10,)) + i2 = keras.Input(shape=(10,)) + x1 = experimental_cluster.cluster_weights(layers.Dense(10), **params)(i1) + x2 = layers.Dense(10)(i2) + outputs = layers.Add()([x1, x2]) + clustered_model = keras.Model(inputs=[i1, i2], outputs=outputs) + + self.assertIsInstance(clustered_model.layers[2], cluster_wrapper.ClusterWeights) + self.assertNotIsInstance(clustered_model.layers[3], cluster_wrapper.ClusterWeights) + @keras_parameterized.run_all_keras_modes def testClusterModelValidLayersSuccessful(self): """ @@ -227,6 +291,26 @@ def testClusterModelValidLayersSuccessful(self): for layer, clustered_layer in zip(model.layers, clustered_model.layers): self._validate_clustered_layer(layer, clustered_layer) + @keras_parameterized.run_all_keras_modes + def testClusterModelValidLayersSuccessfulWithSparsityPreservation(self): + """ + Verifies that clustering a sequential model results in all clusterable + layers within the model being clustered when sparsity preservation is enabled. 
+ """ + preserve_sparsity_params = { 'preserve_sparsity': True } + params = { **self.params, **preserve_sparsity_params } + model = keras.Sequential([ + self.keras_clusterable_layer, + self.keras_non_clusterable_layer, + self.custom_clusterable_layer + ]) + clustered_model = experimental_cluster.cluster_weights(model, **params) + clustered_model.build(input_shape=(1, 28, 28, 1)) + + self.assertEqual(len(model.layers), len(clustered_model.layers)) + for layer, clustered_layer in zip(model.layers, clustered_model.layers): + self._validate_clustered_layer(layer, clustered_layer) + def testClusterModelUnsupportedKerasLayerRaisesError(self): """ Verifies that attempting to cluster a model that contains an unsupported diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py index 942096268..f3e12d151 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py @@ -57,6 +57,7 @@ def __init__(self, layer, number_of_clusters, cluster_centroids_init, + preserve_sparsity=False, **kwargs): if not isinstance(layer, Layer): raise ValueError( @@ -90,9 +91,11 @@ def __init__(self, ) ) - if number_of_clusters <= 1: + limit_number_of_clusters = 2 if preserve_sparsity else 1 + if number_of_clusters <= limit_number_of_clusters: raise ValueError( - "number_of_clusters must be greater than 1. Given: {}".format( + "number_of_clusters must be greater than {}. 
Given: {}".format( + limit_number_of_clusters, number_of_clusters ) ) @@ -105,6 +108,12 @@ def __init__(self, # The number of cluster centroids self.number_of_clusters = number_of_clusters + # Whether to apply sparsity preservation or not + self.preserve_sparsity = preserve_sparsity + + # Stores the pairs of weight names and their respective sparsity masks + self.sparsity_masks = {} + # Stores the pairs of weight names and references to their tensors self.ori_weights_vars_tf = {} @@ -187,7 +196,7 @@ def build(self, input_shape): centroid_initializer = clustering_centroids.CentroidsInitializerFactory.\ get_centroid_initializer( self.cluster_centroids_init - )(weight, self.number_of_clusters) + )(weight, self.number_of_clusters, self.preserve_sparsity) cluster_centroids = centroid_initializer.get_cluster_centroids() @@ -229,6 +238,16 @@ def build(self, input_shape): ) ) + if self.preserve_sparsity: + # Get the clustered weights + clustered_weights = self.clustering_impl[weight_name].get_clustered_weight(pulling_indices) + + # Create the sparsity mask + sparsity_mask = tf.cast(tf.math.not_equal(clustered_weights, 0), dtype=tf.float32) + + # Store the sparsity mask for training + self.sparsity_masks[weight_name] = sparsity_mask + # We store these pairs to easily update this variables later on self.ori_weights_vars_tf[weight_name] = self.add_weight( '{}{}'.format('ori_weights_vars_tf_', weight_name), @@ -241,13 +260,21 @@ def build(self, input_shape): ) # We use currying here to get an updater which can be triggered at any time - # in future and it would return the latest version of clustered weights + # in the future and it would return the latest version of clustered weights def get_updater(for_weight_name): def fn(): # Get the clustered weights pulling_indices = self.pulling_indices_tf[for_weight_name] clustered_weights = self.clustering_impl[for_weight_name].\ get_clustered_weight(pulling_indices) + + if self.preserve_sparsity: + # Get the sparsity mask + 
sparsity_mask = self.sparsity_masks[for_weight_name] + + # Apply the sparsity mask to the clustered weights + clustered_weights = tf.math.multiply(clustered_weights, sparsity_mask) + return clustered_weights return fn @@ -279,10 +306,18 @@ def call(self, inputs): pulling_indices.dtype )) + # Get the clustered weights clustered_weights = self.clustering_impl[weight_name].\ get_clustered_weight_forward(pulling_indices,\ self.ori_weights_vars_tf[weight_name]) + if self.preserve_sparsity: + # Get the sparsity mask + sparsity_mask = self.sparsity_masks[weight_name] + + # Apply the sparsity mask to the clustered weights + clustered_weights = tf.math.multiply(clustered_weights, sparsity_mask) + # Replace the weights with their clustered counterparts setattr(self.layer, weight_name, clustered_weights) @@ -295,7 +330,8 @@ def get_config(self): base_config = super(ClusterWeights, self).get_config() config = { 'number_of_clusters': self.number_of_clusters, - 'cluster_centroids_init': self.cluster_centroids_init + 'cluster_centroids_init': self.cluster_centroids_init, + 'preserve_sparsity': self.preserve_sparsity } return dict(list(base_config.items()) + list(config.items())) @@ -305,9 +341,11 @@ def from_config(cls, config, custom_objects=None): number_of_clusters = config.pop('number_of_clusters') cluster_centroids_init = config.pop('cluster_centroids_init') + preserve_sparsity = config.pop('preserve_sparsity', False) config['number_of_clusters'] = number_of_clusters config['cluster_centroids_init'] = cluster_config.CentroidInitialization( cluster_centroids_init) + config['preserve_sparsity'] = preserve_sparsity from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top layer = deserialize_layer(config.pop('layer'), diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py index 7b7a81ba8..0c2f27b8f 100644 
--- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py @@ -131,6 +131,25 @@ def testCannotBeInitializedWithNumberOfClustersLessThanTwo( cluster_centroids_init=CentroidInitialization.LINEAR ) + @parameterized.parameters( + (0), + (2), + (-32) + ) + def testCannotBeInitializedWithSparsityPreservationAndNumberOfClustersLessThanThree( + self, number_of_clusters): + """ + Verifies that ClusterWeights cannot be initialized with less than three + clusters when sparsity preservation is enabled. + """ + with self.assertRaises(ValueError): + cluster_wrapper.ClusterWeights( + layers.Dense(10), + number_of_clusters=number_of_clusters, + cluster_centroids_init=CentroidInitialization.LINEAR, + preserve_sparsity=True + ) + def testCanBeInitializedWithAlreadyClusterableLayer(self): """ Verifies that ClusterWeights can be initialized with a custom clusterable diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py index 7f32090f2..55c14b02e 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py @@ -28,44 +28,126 @@ class AbstractCentroidsInitialisation: Abstract base class for implementing different cluster centroid initialisation algorithms. Must be initialised with a reference to the weights and implement the single method below. + + Optionally, zero-centroid initialization (used for sparsity-aware clustering) + can be enforced by setting the preserve_sparsity option in the clustering parameters. + The procedure is the following: + 1. First, one centroid is set to zero explicitly + 2. The zero-point centroid divides the weights into two intervals: positive + and negative + 3. 
The remaining centroids are proportionally allocated to the two intervals + 4. For each interval (positive and negative), the standard initialization is used """ - def __init__(self, weights, number_of_clusters): + def __init__(self, weights, number_of_clusters, preserve_sparsity=False): self.weights = weights self.number_of_clusters = number_of_clusters + self.preserve_sparsity = preserve_sparsity @abc.abstractmethod - def get_cluster_centroids(self): + def _calculate_centroids_for_interval(self, weight_interval, number_of_clusters_for_interval): pass + def _regular_clustering(self): + # Regular clustering calculates the centroids using all the weights + centroids = self._calculate_centroids_for_interval(self.weights, self.number_of_clusters) + cluster_centroids = tf.reshape(centroids, (self.number_of_clusters,)) + return cluster_centroids + + def _zero_centroid_initialization(self): + # The zero-centroid sparsity preservation technique works as follows: + # + # 1. First, one centroid is set to zero explicitly + # 2. The zero-point centroid divides the weights into two intervals: positive and negative + # 3. The remaining centroids are proportionally allocated to the two intervals + # 4. For each interval (positive and negative), the standard initialization is used + # + # This method is also referred to as sparsity-aware centroid initialization. 
+ + # Zero-point centroid + zero_centroid = tf.zeros(shape=(1,)) + + # Get the negative weights + negative_weights = tf.boolean_mask(self.weights, tf.math.less(self.weights, 0)) + negative_weights_count = tf.size(negative_weights) + + # Get the positive weights + positive_weights = tf.boolean_mask(self.weights, tf.math.greater(self.weights, 0)) + positive_weights_count = tf.size(positive_weights) + + # Get the number of non-zero weights + non_zero_weights_count = negative_weights_count + positive_weights_count + + if tf.math.equal(non_zero_weights_count, 0): + # No non-zero weights available, simply return the zero-centroid + return zero_centroid + + # Reduce the number of clusters by one to allow room for the zero-point centroid + number_of_non_zero_clusters = self.number_of_clusters - 1 + + # Split the non-zero clusters proportionally among negative and positive weights + negative_weights_ratio = negative_weights_count / non_zero_weights_count + number_of_negative_clusters = tf.cast(tf.math.round(number_of_non_zero_clusters * negative_weights_ratio), dtype=tf.int64) + number_of_positive_clusters = number_of_non_zero_clusters - number_of_negative_clusters + + # Calculate the negative centroids + negative_cluster_centroids = self._calculate_centroids_for_interval(negative_weights, number_of_negative_clusters) + + # Calculate the positive centroids + positive_cluster_centroids = self._calculate_centroids_for_interval(positive_weights, number_of_positive_clusters) + + # Put all the centroids together: negative, zero, positive + centroids = tf.concat([negative_cluster_centroids, zero_centroid, positive_cluster_centroids], axis=0) + + return centroids + + def get_cluster_centroids(self): + # Check whether sparsity preservation should be enforced + if self.preserve_sparsity: + # Apply the zero-centroid sparsity preservation technique + return self._zero_centroid_initialization() + else: + # Perform regular clustering + return self._regular_clustering() + class 
LinearCentroidsInitialisation(AbstractCentroidsInitialisation): """ Spaces cluster centroids evenly in the interval [min(weights), max(weights)] """ - def get_cluster_centroids(self): - weight_min = tf.reduce_min(self.weights) - weight_max = tf.reduce_max(self.weights) + def _calculate_centroids_for_interval(self, weight_interval, number_of_clusters_for_interval): + if tf.math.less_equal(number_of_clusters_for_interval, 0): + # Return an empty array of centroids + return tf.constant([]) + + weight_min = tf.reduce_min(weight_interval) + weight_max = tf.reduce_max(weight_interval) cluster_centroids = tf.linspace(weight_min, weight_max, - self.number_of_clusters) + number_of_clusters_for_interval) + return cluster_centroids + class KmeansPlusPlusCentroidsInitialisation(AbstractCentroidsInitialisation): """ Cluster centroids based on kmeans++ algorithm """ - def get_cluster_centroids(self): - weights = tf.reshape(self.weights, [-1, 1]) + def _calculate_centroids_for_interval(self, weight_interval, number_of_clusters_for_interval): + if tf.math.less_equal(number_of_clusters_for_interval, 0): + # Return an empty array of centroids + return tf.constant([]) + weights = tf.reshape(weight_interval, [-1, 1]) cluster_centroids = clustering_ops.kmeans_plus_plus_initialization(weights, - self.number_of_clusters, + number_of_clusters_for_interval, seed=9, num_retries_per_sample=-1) - return cluster_centroids + return tf.reshape(cluster_centroids, [number_of_clusters_for_interval]) + class RandomCentroidsInitialisation(AbstractCentroidsInitialisation): """ @@ -73,13 +155,13 @@ class RandomCentroidsInitialisation(AbstractCentroidsInitialisation): [min(weights), max(weights)] """ - def get_cluster_centroids(self): - weight_min = tf.reduce_min(self.weights) - weight_max = tf.reduce_max(self.weights) - cluster_centroids = tf.random.uniform(shape=(self.number_of_clusters,), + def _calculate_centroids_for_interval(self, weight_interval, number_of_clusters_for_interval): + weight_min = 
tf.reduce_min(weight_interval) + weight_max = tf.reduce_max(weight_interval) + cluster_centroids = tf.random.uniform(shape=(number_of_clusters_for_interval,), minval=weight_min, maxval=weight_max, - dtype=self.weights.dtype) + dtype=weight_interval.dtype) return cluster_centroids @@ -147,48 +229,72 @@ class DensityBasedCentroidsInitialisation(AbstractCentroidsInitialisation): centroid """ - def get_cluster_centroids(self): - weight_min = tf.reduce_min(self.weights) - weight_max = tf.reduce_max(self.weights) - # Calculating interpolation nodes, +/- 0.01 is introduced to guarantee that - # CDF will have 0 and 1 and the first and last value respectively. - # The value 30 is a guess. We just need a sufficiently large number here - # since we are going to interpolate values linearly anyway and the initial - # guess will drift away. For these reasons we do not really - # care about the granularity of the lookup. - cdf_x_grid = tf.linspace(weight_min - 0.01, weight_max + 0.01, 30) - - f = TFCumulativeDistributionFunction(weights=self.weights) - - cdf_values = k.map_fn(f.get_cdf_value, cdf_x_grid) - - probability_space = tf.linspace(0 + 0.01, 1, self.number_of_clusters) - - # Use upper-bound algorithm to find the appropriate bounds - matching_indices = tf.searchsorted(sorted_sequence=cdf_values, - values=probability_space, - side='right') - - # Interpolate linearly between every found indices I at position using I at - # pos n-1 as a second point. The value of x is a new cluster centroid + def _get_centroids(self, cdf_x_grid, cdf_values, matching_indices): + # Interpolate linearly between every found index using 'i' as the current position + # and 'i-1' as a second point. 
The value of 'x' is a new cluster centroid def get_single_centroid(i): i_clipped = tf.minimum(i, tf.size(cdf_values) - 1) i_previous = tf.maximum(0, i_clipped - 1) - s = TFLinearEquationSolver(x1=cdf_x_grid[i_clipped], - y1=cdf_values[i_clipped], - x2=cdf_x_grid[i_previous], - y2=cdf_values[i_previous]) - - y = cdf_values[i_clipped] + x1 = cdf_x_grid[i_clipped] + x2 = cdf_x_grid[i_previous] + y1 = cdf_values[i_clipped] + y2 = cdf_values[i_previous] + + # Check whether interpolation is possible + if y2 == y1: + # If there's no delta y it doesn't make sense to try to interpolate + # the value of x, so just take the lower bound instead + single_centroid = x1 + else: + # Interpolate linearly + s = TFLinearEquationSolver(x1=x1, y1=y1, x2=x2, y2=y2) + single_centroid = s.solve_for_x(y1) - single_centroid = s.solve_for_x(y) return single_centroid centroids = k.map_fn(get_single_centroid, matching_indices, dtype=tf.float32) - cluster_centroids = tf.reshape(centroids, (self.number_of_clusters,)) + return centroids + + def _calculate_centroids_for_interval(self, weight_interval, number_of_clusters_for_interval): + if tf.math.less_equal(number_of_clusters_for_interval, 0): + # Return an empty array of centroids + return tf.constant([]) + + # Get the limits of the weight interval + weights_min = tf.reduce_min(weight_interval) + weights_max = tf.reduce_max(weight_interval) + + # Calculate the gap to put at either side of the given interval + weights_gap = 0.01 if not self.preserve_sparsity \ + else tf.minimum(0.01, + tf.minimum(tf.math.abs(weights_min), + tf.math.abs(weights_max)) / 2) + + # Calculating the interpolation nodes for the given weights. + # A gap is introduced on either side to guarantee that the CDF will have + # 0 and 1 as the first and last value respectively. + # The value 30 is a guess, we just need a sufficiently large number here + # since we are going to interpolate values linearly anyway and the initial + # guess will drift away. 
For these reasons we do not really + # care about the granularity of the lookup + cdf_x_grid = tf.linspace(weights_min - weights_gap, + weights_max + weights_gap, + 30) + + # Calculate the centroids within the given interval + cdf = TFCumulativeDistributionFunction(weights=weight_interval) + cdf_values = k.map_fn(cdf.get_cdf_value, cdf_x_grid) + probability_space = tf.linspace(0 + 0.01, 1, number_of_clusters_for_interval) + matching_indices = tf.searchsorted(sorted_sequence=cdf_values, + values=probability_space, + side='right') + + centroids = self._get_centroids(cdf_x_grid, cdf_values, matching_indices) + cluster_centroids = tf.reshape(centroids, (number_of_clusters_for_interval,)) + return cluster_centroids diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py index a0fc2052c..2dcf69fb2 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py @@ -157,6 +157,57 @@ def testCDFValues(self, weights, point, probability): K.batch_get_value([cdf_calc.get_cdf_value(point)])[0] ) + @parameterized.parameters( + ( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + 5, + [0., 2.5, 5., 7.5, 10.] + ), + ( + [0, 1, 2, 3, 3.1, 3.2, 3.3, 3.4, 3.5], + 3, + [0., 1.75, 3.5] + ), + ( + [-3., -2., -1., 0., 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.], + 6, + [-3., -0.6, 1.8, 4.2, 6.6, 9.] + ) + ) + def testLinearClusterCentroids(self, weights, number_of_clusters, centroids): + dbci = clustering_centroids.LinearCentroidsInitialisation( + weights, + number_of_clusters + ) + calc_centroids = K.batch_get_value([dbci.get_cluster_centroids()])[0] + self.assertSequenceAlmostEqual(centroids, calc_centroids, places=4) + + @parameterized.parameters( + ( + [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.], + 5, + [0., 1., 4., 7., 10.] 
+ ), + ( + [0., 1., 2., 3., 3.1, 3.2, 3.3, 3.4, 3.5], + 3, + [0., 1., 3.5] + ), + ( + [-3., -2., -1., 0., 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.], + 6, + [-3., 0., 1.1, 3.7333333, 6.366666, 9.] + ) + ) + def testLinearClusterCentroidsWithSparsityPreservation(self, weights, number_of_clusters, centroids): + dbci = clustering_centroids.LinearCentroidsInitialisation( + weights, + number_of_clusters, + True + ) + calc_centroids = K.batch_get_value([dbci.get_cluster_centroids()])[0] + self.assertSequenceAlmostEqual(centroids, calc_centroids, places=4) + @parameterized.parameters( ( [0, 1, 2, 3, 3.1, 3.2, 3.3, 3.4, 3.5], @@ -174,7 +225,7 @@ def testCDFValues(self, weights, point, probability): [0.3010345, 5.2775865, 9.01] ) ) - def testClusterCentroids(self, weights, number_of_clusters, centroids): + def testDensityBasedClusterCentroids(self, weights, number_of_clusters, centroids): dbci = clustering_centroids.DensityBasedCentroidsInitialisation( weights, number_of_clusters @@ -182,6 +233,66 @@ def testClusterCentroids(self, weights, number_of_clusters, centroids): calc_centroids = K.batch_get_value([dbci.get_cluster_centroids()])[0] self.assertSequenceAlmostEqual(centroids, calc_centroids, places=4) + @parameterized.parameters( + ( + [0., -1., -2., -3., -4., -5., -6.], + 4, + [-5.836897, -2.8941379, -0.98999995, 0.] + ), + ( + [0., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], + 5, + [0., 1.2665517, 4.032069, 7.0741386, 9.01] + ), + ( + [-4., -3., -2., -1., 0., 0., 0., 1., 2., 3., 4., 5., 6., 7.], + 6, + [-3.9058623, -0.99, 0., 1.1975863, 4.103793, 7.01] + ), + ( + [0., 1., 2., 3., -3.1, -3.2, -3.3, -0.005, 3.5], + 3, + [-3.1887069, 0., 1.0768965] + ), + ( + [0., 0., 0., 0.], + 2, + [0.] 
+ ) + ) + def testDensityBasedClusterCentroidsWithSparsityPreservation( + self, weights, number_of_clusters, centroids): + dbci = clustering_centroids.DensityBasedCentroidsInitialisation( + weights, + number_of_clusters, + True + ) + calc_centroids = K.batch_get_value([dbci.get_cluster_centroids()])[0] + self.assertSequenceAlmostEqual(centroids, calc_centroids, places=4) + + @parameterized.parameters( + ( + [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.], + 5 + ), + ( + [0., 1., 2., 3., 3.1, 3.2, 3.3, 3.4, 3.5], + 3 + ), + ( + [-3., -2., -1., 0., 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.], + 6 + ) + ) + def testRandomClusterCentroidsWithSparsityPreservation(self, weights, number_of_clusters): + dbci = clustering_centroids.RandomCentroidsInitialisation( + weights, + number_of_clusters, + True + ) + calc_centroids = K.batch_get_value([dbci.get_cluster_centroids()])[0] + self.assertContainsSubset([0.], calc_centroids, msg="The centroids must include the zero-point cluster") + @parameterized.parameters( ( [0, 1, 2, 3, 3.1, 3.2, 3.3, 3.4, 3.5], @@ -199,7 +310,7 @@ def testClusterCentroids(self, weights, number_of_clusters, centroids): [6., 1., 8.] ) ) - def testKmeanPlusPlusValues(self, weights, number_of_clusters, centroids): + def testKmeansPlusPlusClusterCentroids(self, weights, number_of_clusters, centroids): kmci = clustering_centroids.KmeansPlusPlusCentroidsInitialisation( weights, number_of_clusters @@ -207,5 +318,31 @@ def testKmeanPlusPlusValues(self, weights, number_of_clusters, centroids): calc_centroids = K.batch_get_value([kmci.get_cluster_centroids()])[0] self.assertSequenceAlmostEqual(centroids, calc_centroids, places=4) + @parameterized.parameters( + ( + [0, 1, 2, 3, 3.1, 3.2, 3.3, 3.4, 3.5], + 5, + [0., 3., 1., 2., 3.3] + ), + ( + [0, 1, 2, 3, 3.1, 3.2, 3.3, 3.4, 3.5], + 3, + [0., 3., 1.] 
+ ), + ( + [-4., -3., -2., -1., 0., 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.], + 6, + [-2., -4., 0., 5.5, 2.2, 8.8] + ) + ) + def testKmeansPlusPlusClusterCentroidsWithSparsityPreservation(self, weights, number_of_clusters, centroids): + kmci = clustering_centroids.KmeansPlusPlusCentroidsInitialisation( + weights, + number_of_clusters, + True + ) + calc_centroids = K.batch_get_value([kmci.get_cluster_centroids()])[0] + self.assertSequenceAlmostEqual(centroids, calc_centroids, places=4) + if __name__ == '__main__': test.main() diff --git a/tensorflow_model_optimization/python/core/clustering/keras/experimental/BUILD b/tensorflow_model_optimization/python/core/clustering/keras/experimental/BUILD new file mode 100644 index 000000000..1950d14fe --- /dev/null +++ b/tensorflow_model_optimization/python/core/clustering/keras/experimental/BUILD @@ -0,0 +1,28 @@ +package(default_visibility = [ + "//tensorflow_model_optimization:__subpackages__", +]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_library( + name = "experimental", + srcs = [ + "__init__.py", + ], + srcs_version = "PY3", + deps = [ + ":cluster", + ], +) + +py_library( + name = "cluster", + srcs = ["cluster.py"], + srcs_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow_model_optimization/python/core/clustering/keras:cluster", + ], +) diff --git a/tensorflow_model_optimization/python/core/clustering/keras/experimental/__init__.py b/tensorflow_model_optimization/python/core/clustering/keras/experimental/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tensorflow_model_optimization/python/core/clustering/keras/experimental/cluster.py b/tensorflow_model_optimization/python/core/clustering/keras/experimental/cluster.py new file mode 100644 index 000000000..f3124aacb --- /dev/null +++ b/tensorflow_model_optimization/python/core/clustering/keras/experimental/cluster.py @@ -0,0 +1,109 @@ +# Copyright 2020 The TensorFlow Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental clustering API functions for Keras models.""" + +from tensorflow_model_optimization.python.core.clustering.keras.cluster import _cluster_weights + + +def cluster_weights(to_cluster, + number_of_clusters, + cluster_centroids_init, + preserve_sparsity, + **kwargs): + """Modify a keras layer or model to be clustered during training (experimental). + + This function wraps a keras model or layer with clustering functionality + which clusters the layer's weights during training. For example, using + this with number_of_clusters equal to 8 will ensure that each weight tensor has + no more than 8 unique values. + + Before passing to the clustering API, a model should already be trained and + show some acceptable performance on the testing/validation sets. + + The function accepts either a single keras layer + (subclass of `keras.layers.Layer`), list of keras layers or a keras model + (instance of `keras.models.Model`) and handles them appropriately. + + If it encounters a layer it does not know how to handle, it will throw an + error. While clustering an entire model, even a single unknown layer would + lead to an error.
+ + Cluster a model: + + ```python + clustering_params = { + 'number_of_clusters': 8, + 'cluster_centroids_init': CentroidInitialization.DENSITY_BASED, + 'preserve_sparsity': False + } + + clustered_model = cluster_weights(original_model, **clustering_params) + ``` + + Cluster a layer: + + ```python + clustering_params = { + 'number_of_clusters': 8, + 'cluster_centroids_init': CentroidInitialization.DENSITY_BASED, + 'preserve_sparsity': False + } + + model = keras.Sequential([ + layers.Dense(10, activation='relu', input_shape=(100,)), + cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params) + ]) + ``` + + Cluster a layer with sparsity preservation: + + ```python + clustering_params = { + 'number_of_clusters': 8, + 'cluster_centroids_init': CentroidInitialization.DENSITY_BASED, + 'preserve_sparsity': True + } + + model = keras.Sequential([ + layers.Dense(10, activation='relu', input_shape=(100,)), + cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params) + ]) + ``` + + Arguments: + to_cluster: A single keras layer, list of keras layers, or a + `tf.keras.Model` instance. + number_of_clusters: the number of cluster centroids to form when + clustering a layer/model. For example, if number_of_clusters=8 then only + 8 unique values will be used in each weight array. + cluster_centroids_init: `tfmot.clustering.keras.CentroidInitialization` + instance that determines how the cluster centroids will be initialized. + preserve_sparsity: boolean value that determines whether or not + sparsity preservation will be enforced during training. + **kwargs: Additional keyword arguments to be passed to the keras layer. + Ignored when to_cluster is not a keras layer. + + Returns: + Layer or model modified to include clustering related metadata. + + Raises: + ValueError: if the keras layer is unsupported, or the keras model contains + an unsupported layer.
+ """ + return _cluster_weights(to_cluster, + number_of_clusters, + cluster_centroids_init, + preserve_sparsity, + **kwargs)