diff --git a/tensorflow_model_optimization/python/core/clustering/keras/BUILD b/tensorflow_model_optimization/python/core/clustering/keras/BUILD
index e85bfe48e..b53daf5db 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/BUILD
+++ b/tensorflow_model_optimization/python/core/clustering/keras/BUILD
@@ -202,6 +202,7 @@ py_strict_test(
         # numpy dep1,
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/keras:test_utils",
+        "//tensorflow_model_optimization/python/core/clustering/keras/experimental:cluster",
     ],
 )
 
@@ -228,5 +229,6 @@ py_strict_test(
         ":cluster",
         ":cluster_config",
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/clustering/keras/experimental",
     ],
 )
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_distributed_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_distributed_test.py
index 3e796f58f..5bfe73aa5 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_distributed_test.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_distributed_test.py
@@ -14,14 +14,18 @@
 # ==============================================================================
 """Distributed clustering test."""
 
-from absl.testing import parameterized
+import itertools
+import unittest
+
 import numpy as np
 import tensorflow as tf
-
-from tensorflow_model_optimization.python.core.clustering.keras import cluster
-from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
-from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
-from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
+from absl.testing import parameterized
+from tensorflow_model_optimization.python.core.clustering.keras import (
+    cluster, cluster_config, cluster_wrapper)
+from tensorflow_model_optimization.python.core.clustering.keras.experimental import \
+    cluster as experimental_cluster
+from tensorflow_model_optimization.python.core.keras import \
+    test_utils as keras_test_utils
 
 keras = tf.keras
 CentroidInitialization = cluster_config.CentroidInitialization
@@ -30,23 +34,37 @@ def _distribution_strategies():
   return [tf.distribute.MirroredStrategy()]
 
 
+def _clustering_strategies():
+  return [
+      {
+          'number_of_clusters': 2,
+          'cluster_centroids_init': CentroidInitialization.LINEAR,
+          'preserve_sparsity': False
+      },
+      {
+          'number_of_clusters': 3,
+          'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
+          'preserve_sparsity': True
+      }
+  ]
+
+
 class ClusterDistributedTest(tf.test.TestCase, parameterized.TestCase):
   """Distributed tests for clustering."""
 
   def setUp(self):
     super(ClusterDistributedTest, self).setUp()
-    self.params = {
-        'number_of_clusters': 2,
-        'cluster_centroids_init': CentroidInitialization.LINEAR
-    }
 
-  @parameterized.parameters(_distribution_strategies())
-  def testClusterSimpleDenseModel(self, distribution):
+  @parameterized.parameters(
+      *itertools.product(
+          _distribution_strategies(),
+          _clustering_strategies()
+      )
+  )
+  def testClusterSimpleDenseModel(self, distribution, clustering):
     """End-to-end test."""
     with distribution.scope():
-      model = cluster.cluster_weights(
-          keras_test_utils.build_simple_dense_model(), **self.params)
+      model = experimental_cluster.cluster_weights(
+          keras_test_utils.build_simple_dense_model(), **clustering)
       model.compile(
           loss='categorical_crossentropy',
           optimizer='sgd',
@@ -64,9 +82,11 @@ def testClusterSimpleDenseModel(self, distribution):
     weights_as_list = stripped_model.layers[0].kernel.numpy().reshape(
         -1,).tolist()
     unique_weights = set(weights_as_list)
-    self.assertLessEqual(len(unique_weights), self.params['number_of_clusters'])
+    self.assertLessEqual(len(unique_weights), clustering["number_of_clusters"])
 
-  @parameterized.parameters(_distribution_strategies())
+  @parameterized.parameters(
+      _distribution_strategies()
+  )
   def testAssociationValuesPerReplica(self, distribution):
     """Verifies that associations of weights are updated per replica."""
     assert tf.distribute.get_replica_context() is not None
@@ -76,8 +96,9 @@ def testAssociationValuesPerReplica(self, distribution):
     output_shape = (2, 8)
     l = cluster_wrapper.ClusterWeights(
         keras.layers.Dense(8, input_shape=input_shape),
-        number_of_clusters=self.params['number_of_clusters'],
-        cluster_centroids_init=self.params['cluster_centroids_init'])
+        number_of_clusters=2,
+        cluster_centroids_init=CentroidInitialization.LINEAR
+    )
     l.build(input_shape)
 
     clusterable_weights = l.layer.get_clusterable_weights()
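For readers skimming the patch, the cross-product parameterization introduced above expands as in the following minimal sketch: `itertools.product` builds every (distribution, clustering) pair, and `@parameterized.parameters` generates one test case per pair. The values below are illustrative stand-ins, not the real `tf.distribute` strategies.

```python
# Sketch of the cross-product parameterization used in the distributed test.
import itertools

from absl.testing import absltest, parameterized


class CrossProductTest(parameterized.TestCase):

  @parameterized.parameters(
      *itertools.product(
          ['mirrored'],  # stand-in for tf.distribute strategy objects
          [{'number_of_clusters': 2}, {'number_of_clusters': 3}]
      )
  )
  def test_combination(self, distribution, clustering):
    # 1 x 2 = 2 generated test cases, one per (distribution, clustering) pair.
    self.assertIn('number_of_clusters', clustering)


if __name__ == '__main__':
  absltest.main()
```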
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py
index 56eb23420..b08bacb0c 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py
@@ -128,6 +128,14 @@ def _verify_tflite(tflite_file, x_test):
     interpreter.invoke()
     interpreter.get_tensor(output_index)
 
+  @staticmethod
+  def _get_number_of_unique_weights(stripped_model, layer_nr, weight_name):
+    layer = stripped_model.layers[layer_nr]
+    weight = getattr(layer, weight_name)
+    weights_as_list = weight.numpy().flatten()
+    nr_of_unique_weights = len(set(weights_as_list))
+    return nr_of_unique_weights
+
   @keras_parameterized.run_all_keras_modes
   def testValuesRemainClusteredAfterTraining(self):
     """Verifies that training a clustered model does not destroy the clusters."""
@@ -150,73 +158,59 @@ def testValuesRemainClusteredAfterTraining(self):
     unique_weights = set(weights_as_list)
     self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
 
+  @keras_parameterized.run_all_keras_modes
   def testSparsityIsPreservedDuringTraining(self):
-    # Set a specific random seed to ensure that we get some null weights to
-    # test sparsity preservation with.
+    """Verifies that training a clustered model with null weights in it
+    does not destroy the sparsity of the weights."""
+    # Set a specific random seed to ensure that we get some null weights
+    # to test sparsity preservation with.
     tf.random.set_seed(1)
-
-    # Verifies that training a clustered model does not destroy the sparsity of
-    # the weights.
     original_model = keras.Sequential([
         layers.Dense(5, input_shape=(5,)),
-        layers.Dense(5),
+        layers.Flatten(),
     ])
-
-    # Using a mininum number of centroids to make it more likely that some
-    # weights will be zero.
+    # Reset the kernel weights to reflect potential zero drifting of
+    # the cluster centroids.
+    first_layer_weights = original_model.layers[0].get_weights()
+    first_layer_weights[0][:][0:2] = 0.0
+    first_layer_weights[0][:][3] = [-0.13, -0.08, -0.05, 0.005, 0.13]
+    first_layer_weights[0][:][4] = [-0.13, -0.08, -0.05, 0.005, 0.13]
+    original_model.layers[0].set_weights(first_layer_weights)
 
     clustering_params = {
-        "number_of_clusters": 3,
+        "number_of_clusters": 6,
         "cluster_centroids_init": CentroidInitialization.LINEAR,
         "preserve_sparsity": True
     }
-
     clustered_model = experimental_cluster.cluster_weights(
         original_model, **clustering_params)
-
     stripped_model_before_tuning = cluster.strip_clustering(clustered_model)
-    weights_before_tuning = stripped_model_before_tuning.layers[0].kernel
-    non_zero_weight_indices_before_tuning = np.nonzero(weights_before_tuning)
-
+    nr_of_unique_weights_before = self._get_number_of_unique_weights(
+        stripped_model_before_tuning, 0, 'kernel')
     clustered_model.compile(
         loss=keras.losses.categorical_crossentropy,
         optimizer="adam",
         metrics=["accuracy"],
     )
-
-    clustered_model.fit(x=self.dataset_generator2(), steps_per_epoch=1)
-
+    clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=100)
     stripped_model_after_tuning = cluster.strip_clustering(clustered_model)
     weights_after_tuning = stripped_model_after_tuning.layers[0].kernel
-    non_zero_weight_indices_after_tuning = np.nonzero(weights_after_tuning)
-    weights_as_list_after_tuning = weights_after_tuning.numpy().reshape(
-        -1,).tolist()
-    unique_weights_after_tuning = set(weights_as_list_after_tuning)
-
+    nr_of_unique_weights_after = self._get_number_of_unique_weights(
+        stripped_model_after_tuning, 0, 'kernel')
+    # Check that, after sparsity-aware clustering, the number of unique
+    # weights does not increase, even though the zero centroid can drift.
+    self.assertLessEqual(nr_of_unique_weights_after,
+                         nr_of_unique_weights_before)
     # Check that the null weights stayed the same before and after tuning.
+    # New weights may become zero during training, but sparsity-aware
+    # clustering preserves the original null weights at their original
+    # positions in the weight array.
     self.assertTrue(
-        np.array_equal(non_zero_weight_indices_before_tuning,
-                       non_zero_weight_indices_after_tuning))
-
+        np.array_equal(first_layer_weights[0][:][0:2],
+                       weights_after_tuning[:][0:2]))
     # Check that the number of unique weights matches the number of clusters.
     self.assertLessEqual(
-        len(unique_weights_after_tuning), self.params["number_of_clusters"])
-
-  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
-  def testEndToEndSequential(self):
-    """Test End to End clustering - sequential model."""
-    original_model = keras.Sequential([
-        layers.Dense(5, input_shape=(5,)),
-        layers.Dense(5),
-    ])
-
-    def clusters_check(stripped_model):
-      # dense layer
-      weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
-      unique_weights = set(weights_as_list)
-      self.assertLessEqual(
-          len(unique_weights), self.params["number_of_clusters"])
-
-    self.end_to_end_testing(original_model, clusters_check)
+        nr_of_unique_weights_after,
+        clustering_params["number_of_clusters"])
 
   @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
   def testEndToEndFunctional(self):
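The two assertions the rewritten test performs reduce to a small invariant, shown here as a self-contained numpy sketch: the unique-weight count must not grow during fine-tuning, and every originally-null entry must stay null. The toy kernels are illustrative, not the test's real weights.

```python
# Toy illustration of the sparsity checks in testSparsityIsPreservedDuringTraining.
import numpy as np

kernel_before = np.array([[0.0, 0.0, 0.5],
                          [0.0, 0.0, -0.5],
                          [0.2, -0.2, 0.5]])
# Hypothetical weights after fine-tuning: clusters moved, zeros stayed put.
kernel_after = np.array([[0.0, 0.0, 0.45],
                         [0.0, 0.0, -0.45],
                         [0.15, -0.15, 0.45]])

def n_unique(kernel):
  return len(set(kernel.flatten().tolist()))

# Unique-weight count must not increase during fine-tuning.
assert n_unique(kernel_after) <= n_unique(kernel_before)
# Every position that was null before tuning is still null afterwards.
assert np.all(kernel_after[kernel_before == 0.0] == 0.0)
```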
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py
index 15583d7d4..5ce2e77cf 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py
@@ -108,6 +108,9 @@ def __init__(self,
     # Stores the pairs of weight names and their respective sparsity masks
     self.sparsity_masks = {}
 
+    # Stores the pairs of weight names and the indices of their zero centroids
+    self.zero_idx = {}
+
     # Map weight names to original clusterable weights variables
     # Those weights will still be updated during backpropagation
     self.original_clusterable_weights = {}
@@ -199,10 +202,32 @@ def build(self, input_shape):
             pulling_indices, original_weight))
       self.sparsity_masks[weight_name] = (
          tf.cast(tf.math.not_equal(clustered_weights, 0), dtype=tf.float32))
+      # If the model is pruned (which we assume here), the centroid closest
+      # to zero is approximately zero.
+      self.zero_idx[weight_name] = tf.argmin(
+          tf.abs(self.cluster_centroids[weight_name]), axis=-1)
 
   def update_clustered_weights_associations(self):
     for weight_name, original_weight in self.original_clusterable_weights.items(
     ):
+
+      if self.preserve_sparsity:
+        # Set the smallest centroid to zero to force sparsity and to
+        # prevent an extra cluster from forming.
+        zero_idx_mask = (
+            tf.cast(tf.math.not_equal(
+                self.cluster_centroids[weight_name],
+                self.cluster_centroids[weight_name][self.zero_idx[weight_name]]),
+                    dtype=tf.float32)
+        )
+        self.cluster_centroids[weight_name].assign(
+            tf.math.multiply(self.cluster_centroids[weight_name],
+                             zero_idx_mask))
+        # During training, the original zero weights can drift slightly.
+        # Prevent this by forcing the weights to stay zero at the positions
+        # that were zero to begin with.
+        original_weight = tf.math.multiply(original_weight,
+                                           self.sparsity_masks[weight_name])
+
       # Update pulling indices (cluster associations)
       pulling_indices = (
           self.clustering_algorithms[weight_name].get_pulling_indices(
@@ -214,11 +239,6 @@ def update_clustered_weights_associations(self):
           self.clustering_algorithms[weight_name].get_clustered_weight(
               pulling_indices, original_weight))
 
-      if self.preserve_sparsity:
-        # Apply the sparsity mask to the clustered weights
-        clustered_weights = tf.math.multiply(clustered_weights,
-                                             self.sparsity_masks[weight_name])
-
       # Replace the weights with their clustered counterparts
       setattr(self.layer, weight_name, clustered_weights)
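The heart of this change can be sketched as standalone TensorFlow code, assuming eager execution: force the centroid nearest zero to exactly zero, mask originally-null weights so they cannot drift, then re-associate weights with centroids. The variable names and the nearest-centroid association below are simplified stand-ins for the wrapper's real state and its pluggable clustering algorithm.

```python
# Sketch of sparsity-preserving clustering outside the wrapper class.
import tensorflow as tf

centroids = tf.Variable([-0.4, 0.01, 0.3, 0.7])      # one centroid near zero
weights = tf.constant([[0.0, 0.29], [0.68, -0.41]])  # toy kernel with a null
sparsity_mask = tf.cast(tf.math.not_equal(weights, 0), tf.float32)

# Force the smallest-magnitude centroid to exactly zero.
zero_idx = tf.argmin(tf.abs(centroids))
zero_idx_mask = tf.cast(
    tf.math.not_equal(centroids, tf.gather(centroids, zero_idx)), tf.float32)
centroids.assign(centroids * zero_idx_mask)

# Pin originally-zero weights to zero before re-associating them.
masked_weights = weights * sparsity_mask

# Nearest-centroid association (the wrapper delegates this step to its
# clustering algorithm; plain argmin over distances stands in for it here).
pulling_indices = tf.argmin(
    tf.abs(tf.expand_dims(masked_weights, -1) - centroids), axis=-1)
clustered = tf.gather(centroids, pulling_indices)
print(clustered.numpy())  # the null weight maps to the exact-zero centroid
```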
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/mnist_clustering_test.py b/tensorflow_model_optimization/python/core/clustering/keras/mnist_clustering_test.py
index 5435dd1e0..23903600e 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/mnist_clustering_test.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/mnist_clustering_test.py
@@ -14,10 +14,12 @@
 # ==============================================================================
 """Tests for a simple convnet with clusterable layer on the MNIST dataset."""
 
+from absl.testing import parameterized
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.clustering.keras import cluster
 from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
+from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster
 
 tf.random.set_seed(42)
 
@@ -63,7 +65,7 @@ def _train_model(model):
   model.fit(x_train, y_train, epochs=EPOCHS)
 
 
-def _cluster_model(model, number_of_clusters):
+def _cluster_model(model, number_of_clusters, preserve_sparsity=False):
 
   (x_train, y_train), _ = _get_dataset()
 
@@ -71,11 +73,13 @@ def _cluster_model(model, number_of_clusters):
       'number_of_clusters':
           number_of_clusters,
       'cluster_centroids_init':
-          cluster_config.CentroidInitialization.KMEANS_PLUS_PLUS
+          cluster_config.CentroidInitialization.KMEANS_PLUS_PLUS,
+      'preserve_sparsity':
+          preserve_sparsity,
   }
 
   # Cluster model
-  clustered_model = cluster.cluster_weights(model, **clustering_params)
+  clustered_model = experimental_cluster.cluster_weights(
+      model, **clustering_params)
 
   # Use smaller learning rate for fine-tuning
   # clustered model
@@ -106,13 +110,27 @@ def _get_number_of_unique_weights(stripped_model, layer_nr, weight_name):
   return nr_of_unique_weights
 
 
+def _deepcopy_model(model):
+  model_copy = keras.models.clone_model(model)
+  model_copy.set_weights(model.get_weights())
+  return model_copy
+
 
-class FunctionalTest(tf.test.TestCase):
+class FunctionalTest(tf.test.TestCase, parameterized.TestCase):
 
-  def testMnist(self):
-    """In this test we test that 'kernel' weights are clustered."""
+  def setUp(self):
+    super().setUp()
     model = _build_model()
     _train_model(model)
+    self.model = model
+    self.dataset = _get_dataset()
+
+  @parameterized.parameters(
+      (False),
+      (True),
+  )
+  def testMnist(self, preserve_sparsity):
+    """Tests that 'kernel' weights are clustered."""
+    model = self.model
+    _, (x_test, y_test) = self.dataset
 
     # Checks that number of original weights('kernel') is greater than the
     # number of clusters
@@ -123,12 +141,11 @@ def testMnist(self):
     nr_of_bias_weights = _get_number_of_unique_weights(model, -1, 'bias')
     self.assertGreater(nr_of_bias_weights, NUMBER_OF_CLUSTERS)
 
-    _, (x_test, y_test) = _get_dataset()
-
     results_original = model.evaluate(x_test, y_test)
     self.assertGreater(results_original[1], 0.8)
 
-    clustered_model = _cluster_model(model, NUMBER_OF_CLUSTERS)
+    model_copy = _deepcopy_model(model)
+    clustered_model = _cluster_model(model_copy, NUMBER_OF_CLUSTERS,
+                                     preserve_sparsity)
 
     results = clustered_model.evaluate(x_test, y_test)
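For context, a hedged sketch of how a caller would exercise the experimental entry point these tests now route through. The toy model and parameter values are illustrative; `preserve_sparsity` is the flag this patch threads through the test suite.

```python
# Usage sketch for the experimental clustering API exercised by the tests.
import tensorflow as tf

from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
from tensorflow_model_optimization.python.core.clustering.keras.experimental import (
    cluster as experimental_cluster)

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(4,))])
clustering_params = {
    'number_of_clusters': 8,
    'cluster_centroids_init':
        cluster_config.CentroidInitialization.KMEANS_PLUS_PLUS,
    'preserve_sparsity': True,  # keep null weights null while clustering
}
clustered_model = experimental_cluster.cluster_weights(model,
                                                       **clustering_params)
```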