From 51d4c2228336f1ded6631c614c8802c77f984ee2 Mon Sep 17 00:00:00 2001
From: Ruomei Yan
Date: Mon, 17 Aug 2020 17:34:15 +0100
Subject: [PATCH] Enable differentiable training and update cluster indices

---
 .../python/core/clustering/keras/cluster.py   |  2 +-
 .../core/clustering/keras/cluster_wrapper.py  | 52 ++++++++++----
 .../clustering/keras/cluster_wrapper_test.py  | 69 +++++++++++++++++--
 .../clustering/keras/clustering_registry.py   | 32 ++++++++-
 .../keras/clustering_registry_test.py         | 65 +++++++++++++++++
 5 files changed, 199 insertions(+), 21 deletions(-)

diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster.py
index 849d583f7..bdd344c5e 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/cluster.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster.py
@@ -48,7 +48,7 @@ def cluster_scope():
   """
   return CustomObjectScope(
       {
-          'ClusterWeights' : cluster_wrapper.ClusterWeights
+          'ClusterWeights': cluster_wrapper.ClusterWeights
       }
   )

diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py
index c5215ddd4..99db6cabf 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py
@@ -27,6 +27,7 @@
 k = keras.backend
 Layer = keras.layers.Layer
 Wrapper = keras.layers.Wrapper
+CentroidInitialization = cluster_config.CentroidInitialization


 class ClusterWeights(Wrapper):
@@ -105,7 +106,7 @@ def __init__(self,
     self.number_of_clusters = number_of_clusters

     # Stores the pairs of weight names and references to their tensors
-    self.clustered_vars = []
+    self.ori_weights_vars_tf = {}

     # Stores references to class instances that implement different clustering
     # behaviour for different shapes of objects
@@ -227,24 +228,33 @@ def build(self, input_shape):
       )

       # We store these pairs to easily update this variables later on
-      self.clustered_vars.append((weight_name, weight))
+      self.ori_weights_vars_tf[weight_name] = self.add_weight(
+          'ori_weights_vars_tf',
+          shape=weight.shape,
+          dtype=weight.dtype,
+          trainable=True,
+          initializer=initializers.Constant(
+              value=k.batch_get_value([weight])[0]
+          )
+      )

     # We use currying here to get an updater which can be triggered at any time
     # in future and it would return the latest version of clustered weights
     def get_updater(for_weight_name):
       def fn():
-        return self.clustering_impl[for_weight_name].get_clustered_weight(
-            self.pulling_indices_tf[for_weight_name]
-        )
+        # Get the clustered weights
+        pulling_indices = self.pulling_indices_tf[for_weight_name]
+        clustered_weights = self.clustering_impl[for_weight_name].\
+        get_clustered_weight(pulling_indices)
+        return clustered_weights

       return fn

     # This will allow us to restore the order of weights later
     # This loop stores pairs of weight names and how to restore them
-
     for ct, weight in enumerate(self.layer.weights):
       name = self._weight_name(weight.name)
-      full_name = self.layer.name + "/" + name
+      full_name = '{}{}{}'.format(self.layer.name, '/', name)
       if ct in self.gone_variables:
         # Again, not sure if this is needed
         weight_name = clusterable_weights_to_variables[name]
@@ -253,14 +263,26 @@ def fn():
         self.restore.append((name, full_name, weight))

   def call(self, inputs):
+    # In the forward pass, we need to update the cluster associations manually
+    # since they are integers and not differentiable. Gradients won't flow back
+    # through tf.argmin
     # Go through all tensors and replace them with their clustered copies.
-    for weight_name, _ in self.clustered_vars:
-      setattr(
-          self.layer, weight_name,
-          self.clustering_impl[weight_name].get_clustered_weight(
-              self.pulling_indices_tf[weight_name]
-          )
-      )
+    for weight_name in self.ori_weights_vars_tf:
+      pulling_indices = self.pulling_indices_tf[weight_name]
+
+      # Update cluster associations
+      pulling_indices.assign(tf.dtypes.cast(
+          self.clustering_impl[weight_name].\
+          get_pulling_indices(self.ori_weights_vars_tf[weight_name]),
+          pulling_indices.dtype
+      ))
+
+      clustered_weights = self.clustering_impl[weight_name].\
+          get_clustered_weight_forward(pulling_indices,\
+          self.ori_weights_vars_tf[weight_name])
+
+      # Replace the weights with their clustered counterparts
+      setattr(self.layer, weight_name, clustered_weights)

     return self.layer.call(inputs)

@@ -271,7 +293,7 @@ def get_config(self):
     base_config = super(ClusterWeights, self).get_config()
     config = {
         'number_of_clusters': self.number_of_clusters,
-        'cluster_centroids_init': self.cluster_centroids_init,
+        'cluster_centroids_init': self.cluster_centroids_init
     }
     return dict(list(base_config.items()) + list(config.items()))
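Illustrative aside (not part of the patch): the call() change above re-associates every weight with its nearest centroid on each forward pass and only then gathers the clustered values. The short standalone sketch below, with made-up tensors and an argmin-based association, shows that idea; the per-shape implementations actually used live in clustering_registry.py.

    # Illustrative sketch only: one forward pass for a single weight tensor.
    import tensorflow as tf

    centroids = tf.constant([-0.5, 0.0, 0.5])           # cluster centroids
    weights = tf.constant([[0.4, -0.6], [0.1, 0.45]])   # original (non-clustered) weights

    # Re-associate every weight with its nearest centroid. The argmin produces
    # integer indices, which is why the wrapper updates pulling_indices_tf
    # manually instead of relying on gradients.
    distances = tf.abs(tf.expand_dims(weights, axis=-1) - centroids)
    pulling_indices = tf.argmin(distances, axis=-1)

    # Rebuild the clustered weight tensor by looking the centroids up.
    clustered = tf.gather(centroids, pulling_indices)

    print(pulling_indices.numpy())  # [[2 0] [1 2]]
    print(clustered.numpy())        # [[ 0.5 -0.5] [ 0.   0.5]]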
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py
index 6a1bc78db..510230863 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py
@@ -15,7 +15,6 @@
 """Tests for keras ClusterWeights wrapper API."""

 import itertools
-import numpy as np
 import tensorflow as tf

 from absl.testing import parameterized
@@ -155,9 +154,9 @@ def testIfLayerHasBatchShapeClusterWeightsMustHaveIt(self):
       *itertools.product(
           range(2, 16, 4),
           (
-            CentroidInitialization.LINEAR,
-            CentroidInitialization.RANDOM,
-            CentroidInitialization.DENSITY_BASED
+              CentroidInitialization.LINEAR,
+              CentroidInitialization.RANDOM,
+              CentroidInitialization.DENSITY_BASED
           )
       )
   )
@@ -194,5 +193,67 @@ def testValuesAreClusteredAfterStripping(self,
     self.assertEqual(stripped_model.layers[0].weights[0].name, weights_name)
     self.assertEqual(stripped_model.layers[0].weights[1].name, bias_name)

+  def testClusterReassociation(self):
+    """
+    Verifies that the association of weights to cluster centroids is updated
+    every iteration.
+    """
+
+    # Create a dummy layer for this test
+    input_shape = (1, 2,)
+    l = cluster_wrapper.ClusterWeights(
+        keras.layers.Dense(8, input_shape=input_shape),
+        number_of_clusters=2,
+        cluster_centroids_init=CentroidInitialization.LINEAR
+    )
+    # Build a layer with the given shape
+    l.build(input_shape)
+
+    # Get name of the clusterable weights
+    clusterable_weights = l.layer.get_clusterable_weights()
+    self.assertEqual(len(clusterable_weights), 1)
+    weights_name = clusterable_weights[0][0]
+    self.assertEqual(weights_name, 'kernel')
+    # Get cluster centroids
+    centroids = l.cluster_centroids_tf[weights_name]
+
+    # Calculate some statistics of the weights to set the centroids later on
+    mean_weight = tf.reduce_mean(l.layer.kernel)
+    min_weight = tf.reduce_min(l.layer.kernel)
+    max_weight = tf.reduce_max(l.layer.kernel)
+    max_dist = max_weight - min_weight
+
+    def assert_all_weights_associated(weights, centroid_index):
+      """Helper function to make sure that all weights are associated with one
+      centroid."""
+      all_associated = tf.reduce_all(
+          tf.equal(
+              weights,
+              tf.constant(centroids[centroid_index], shape=weights.shape)
+          )
+      )
+      self.assertTrue(all_associated)
+
+    # Set centroids so that all weights should be re-associated with centroid 0
+    centroids[0].assign(mean_weight)
+    centroids[1].assign(mean_weight + 2.0 * max_dist)
+
+    # Update associations of weights to centroids
+    l.call(tf.ones(shape=input_shape))
+
+    # Weights should now be all clustered with the centroid 0
+    assert_all_weights_associated(l.layer.kernel, centroid_index=0)
+
+    # Set centroids so that all weights should be re-associated with centroid 1
+    centroids[0].assign(mean_weight - 2.0 * max_dist)
+    centroids[1].assign(mean_weight)
+
+    # Update associations of weights to centroids
+    l.call(tf.ones(shape=input_shape))
+
+    # Weights should now be all clustered with the centroid 1
+    assert_all_weights_associated(l.layer.kernel, centroid_index=1)
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py
index f832e2bc0..0a467f5a3 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py
@@ -64,6 +64,22 @@ def get_pulling_indices(self, weight):
     """
     pass

+  @tf.custom_gradient
+  def add_custom_gradients(self, clst_weights, weights):
+    """
+    This function overrides the gradients in the backprop stage: the original
+    mul becomes add, and tf.sign becomes tf.identity. This lets the original
+    weights be updated with the gradient updates coming directly from the
+    wrapped layer. We assume the gradient updates on individual elements
+    inside a cluster will differ, so there is no point in mapping the
+    gradient updates back to the original weight matrix using the LUT.
+    """
+    override_weights = tf.sign(tf.reshape(weights, shape=(-1,)) + 1e+6)
+    z = clst_weights*override_weights
+    def grad(dz):
+      return dz, dz
+    return z, grad
+
   def get_clustered_weight(self, pulling_indices):
     """
     Takes an array with integer number that represent lookup indices and forms a
@@ -75,9 +91,23 @@ def get_clustered_weight(self, pulling_indices):
     return tf.reshape(
         tf.gather(self.cluster_centroids,
                   tf.reshape(pulling_indices, shape=(-1,))),
-        pulling_indices.shape
+        shape=pulling_indices.shape
     )
+  def get_clustered_weight_forward(self, pulling_indices, weight):
+    """
+    Takes indices (pulling_indices) and original weights (weight) as inputs
+    and then forms a new array according to the given indices. The original
+    weights (weight) are added to the graph here because we want backprop to
+    update their values via the new implementation using tf.custom_gradient.
+    :param pulling_indices: an array of indices used for lookup.
+    :param weight: the original weights of the wrapped layer.
+    :return: array with the same shape as `pulling_indices`. Each array element
+      is a member of self.cluster_centroids
+    """
+    x = tf.reshape(self.get_clustered_weight(pulling_indices), shape=(-1,))
+    return tf.reshape(self.add_custom_gradients(
+        x, tf.reshape(weight, shape=(-1,))), pulling_indices.shape)


 class ConvolutionalWeightsCA(AbstractClusteringAlgorithm):
   """
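Illustrative aside (not part of the patch): add_custom_gradients() above acts as a straight-through estimator. The forward result is numerically just the clustered weights, since tf.sign(w + 1e+6) is 1.0 for any realistic weight value, while the custom backward pass routes the same upstream gradient to both the clustered and the original weights. A standalone sketch of that behaviour, with made-up names and values:

    # Illustrative sketch only: straight-through gradients for clustered weights.
    import tensorflow as tf

    @tf.custom_gradient
    def clustered_forward(clustered_weights, original_weights):
      # Forward pass: tf.sign(w + 1e+6) is 1.0 for realistic weight values, so
      # the result is numerically the clustered weights.
      override = tf.sign(original_weights + 1e+6)
      z = clustered_weights * override

      def grad(dz):
        # Backward pass: the same upstream gradient flows to BOTH inputs, so
        # the original (non-clustered) weights keep receiving updates.
        return dz, dz

      return z, grad

    clustered = tf.constant([0.0, -0.5, 0.5, 1.0])   # e.g. gathered centroids
    original = tf.constant([0.1, -0.3, 0.25, 0.9])   # pre-clustering weights

    with tf.GradientTape(persistent=True) as tape:
      tape.watch([clustered, original])
      loss = tf.reduce_sum(clustered_forward(clustered, original) ** 2)

    # Both gradients are identical, which is what the new _check_gradients()
    # test helper below asserts element-wise.
    print(tape.gradient(loss, clustered).numpy())
    print(tape.gradient(loss, original).numpy())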
diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry_test.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry_test.py
index 990d36f1d..e03d1a452 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry_test.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry_test.py
@@ -45,6 +45,48 @@ def _pull_values(self, ca, pulling_indices, expected_output):

     self.assertSequenceEqual(res_np_list, expected_output)

+  def _check_gradients(self, ca, weight, pulling_indices, expected_output):
+    pulling_indices_tf = tf.convert_to_tensor(pulling_indices)
+    weight_tf = tf.convert_to_tensor(weight)
+    with tf.GradientTape(persistent=True) as t:
+      t.watch(pulling_indices_tf)
+      t.watch(weight_tf)
+      cls_weights_tf = tf.reshape(
+          ca.get_clustered_weight(pulling_indices_tf), shape=(-1,))
+      t.watch(cls_weights_tf)
+      out_forward = ca.add_custom_gradients(cls_weights_tf, weight_tf)
+    grad_cls_weight = t.gradient(out_forward, cls_weights_tf)
+    grad_weight = t.gradient(out_forward, weight_tf)
+
+    chk_output = tf.math.equal(grad_cls_weight, grad_weight)
+    chk_output_np = k.batch_get_value(chk_output)
+
+    self.assertSequenceEqual(chk_output_np, expected_output)
+
+  @parameterized.parameters(
+      ([-0.800450444, 0.864694357],
+       [[0.220442653, 0.854694366, 0.0328432359, 0.506857157],
+        [0.0527950861, -0.659555554, -0.849919915, -0.54047],
+        [-0.305815876, 0.0865516588, 0.659202456, -0.355699599],
+        [-0.348868281, -0.662001, 0.6171574, -0.296582848]],
+       [[1, 1, 1, 1],
+        [1, 0, 0, 0],
+        [0, 1, 1, 0],
+        [0, 0, 1, 0]],
+       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+      )
+  )
+  def testDenseWeightsCAGrad(self,
+                             clustering_centroids,
+                             weight,
+                             pulling_indices,
+                             expected_output):
+    """
+    Verifies that the gradients of DenseWeightsCA work as expected.
+    """
+    ca = clustering_registry.DenseWeightsCA(clustering_centroids)
+    self._check_gradients(ca, weight, pulling_indices, expected_output)
+
   @parameterized.parameters(
       ([-1, 1], [[0, 0, 1], [1, 1, 1]], [[-1, -1, 1], [1, 1, 1]]),
       ([-1, 0, 1], [[1, 1, 1], [1, 1, 1]], [[0, 0, 0], [0, 0, 0]]),
@@ -73,6 +115,29 @@ def testBiasWeightsCA(self,
     ca = clustering_registry.BiasWeightsCA(clustering_centroids)
     self._pull_values(ca, pulling_indices, expected_output)

+  @parameterized.parameters(
+      ([0.0, 3.0],
+       [[0.1, 0.1, 0.1],
+        [3.0, 3.0, 3.0],
+        [0.2, 0.2, 0.2]],
+       [[0, 0, 0],
+        [1, 1, 1],
+        [0, 0, 0]],
+       [1, 1, 1, 1, 1, 1, 1, 1, 1]
+      )
+  )
+  def testConvolutionalWeightsCAGrad(self,
+                                     clustering_centroids,
+                                     weight,
+                                     pulling_indices,
+                                     expected_output):
+    """
+    Verifies that the gradients of ConvolutionalWeightsCA work as expected.
+    """
+    ca = clustering_registry.ConvolutionalWeightsCA(clustering_centroids)
+    self._check_gradients(ca, weight, pulling_indices, expected_output)
+
+
   @parameterized.parameters(
       ([0, 3], [[[[0, 0, 0], [1, 1, 1], [0, 0, 0]]]], [[[[0, 0, 0], [3, 3, 3], [0, 0, 0]]]]),