diff --git a/tensorflow_model_optimization/python/core/clustering/keras/BUILD b/tensorflow_model_optimization/python/core/clustering/keras/BUILD index 1d97c1437..1591f8928 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/BUILD +++ b/tensorflow_model_optimization/python/core/clustering/keras/BUILD @@ -137,5 +137,6 @@ py_test( deps = [ ":cluster", # tensorflow dep1, + "//tensorflow_model_optimization/python/core/keras:compat", ], -) +) \ No newline at end of file diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster.py index 45357c7da..743312f28 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster.py @@ -216,12 +216,9 @@ def _strip_clustering_wrapper(layer): # If the value was not clustered(e.g. bias), we still store a valid # reference to the tensor. We use this reference to get the value new_weight_value = k.batch_get_value([weight])[0] - layer.layer.add_weight( - name=name, - shape=new_weight_value.shape, - initializer=initializers.Constant(new_weight_value), - trainable=True - ) + setattr(layer.layer, + name, + k.variable(new_weight_value, name=name)) # When all weights are filled with the values, just return the underlying # layer since it is now fully autonomous from its wrapper return layer.layer diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py index 1c9af9f29..38878badc 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py @@ -16,12 +16,14 @@ import numpy as np import tensorflow as tf - +import tempfile from absl.testing import parameterized from 
tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.clustering.keras import cluster from tensorflow_model_optimization.python.core.clustering.keras import cluster_config +from tensorflow_model_optimization.python.core.keras import compat +import os keras = tf.keras layers = keras.layers @@ -30,55 +32,105 @@ CentroidInitialization = cluster_config.CentroidInitialization class ClusterIntegrationTest(test.TestCase, parameterized.TestCase): - """Integration tests for clustering.""" - @keras_parameterized.run_all_keras_modes - def testValuesRemainClusteredAfterTraining(self): """ - Verifies that training a clustered model does not destroy the clusters. + Integration tests for clustering. """ - number_of_clusters = 10 - original_model = keras.Sequential([ - layers.Dense(2, input_shape=(2,)), - layers.Dense(2), - ]) - - clustered_model = cluster.cluster_weights( - original_model, - number_of_clusters=number_of_clusters, - cluster_centroids_init=CentroidInitialization.LINEAR - ) - - clustered_model.compile( - loss=keras.losses.categorical_crossentropy, - optimizer='adam', - metrics=['accuracy'] - ) - - def dataset_generator(): - x_train = np.array([ - [0, 1], - [2, 0], - [0, 3], - [4, 1], - [5, 1], - ]) - y_train = np.array([ - [0, 1], - [1, 0], - [1, 0], - [0, 1], - [0, 1], - ]) - for x, y in zip(x_train, y_train): - yield np.array([x]), np.array([y]) - - clustered_model.fit_generator(dataset_generator(), steps_per_epoch=1) - stripped_model = cluster.strip_clustering(clustered_model) - weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist() - unique_weights = set(weights_as_list) - self.assertLessEqual(len(unique_weights), number_of_clusters) - - -if __name__ == '__main__': - test.main() + def setUp(self): + self.params = { + "number_of_clusters": 8, + "cluster_centroids_init": CentroidInitialization.LINEAR, + } + + self.x_train = np.array( + [[0.0, 1.0], [2.0, 0.0], [0.0, 3.0], [4.0, 1.0], [5.0, 1.0]], + 
dtype="float32", + ) + + self.y_train = np.array( + [[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], + dtype="float32", + ) + + def dataset_generator(self): + for x, y in zip(self.x_train, self.y_train): + yield np.array([x]), np.array([y]) + + @staticmethod + def _verify_tflite(tflite_file, x_test): + interpreter = tf.lite.Interpreter(model_path=tflite_file) + interpreter.allocate_tensors() + input_index = interpreter.get_input_details()[0]["index"] + output_index = interpreter.get_output_details()[0]["index"] + x = x_test[0] + x = x.reshape((1,) + x.shape) + interpreter.set_tensor(input_index, x) + interpreter.invoke() + interpreter.get_tensor(output_index) + + @keras_parameterized.run_all_keras_modes + def testValuesRemainClusteredAfterTraining(self): + + """ + Verifies that training a clustered model does not destroy the clusters. + """ + original_model = keras.Sequential( + [layers.Dense(2, input_shape=(2,)), layers.Dense(2),] + ) + + clustered_model = cluster.cluster_weights(original_model, **self.params) + + clustered_model.compile( + loss=keras.losses.categorical_crossentropy, + optimizer="adam", + metrics=["accuracy"], + ) + + clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1) + stripped_model = cluster.strip_clustering(clustered_model) + weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist() + unique_weights = set(weights_as_list) + self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"]) + + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def testEndToEnd(self): + + """ + Test End to End clustering. 
+ """ + original_model = keras.Sequential( + [layers.Dense(2, input_shape=(2,)), layers.Dense(2),] + ) + + clustered_model = cluster.cluster_weights(original_model, **self.params) + + clustered_model.compile( + loss=keras.losses.categorical_crossentropy, + optimizer="adam", + metrics=["accuracy"], + ) + + clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1) + stripped_model = cluster.strip_clustering(clustered_model) + + _, tflite_file = tempfile.mkstemp(".tflite") + _, keras_file = tempfile.mkstemp(".h5") + + if not compat.is_v1_apis(): + converter = tf.lite.TFLiteConverter.from_keras_model(stripped_model) + else: + tf.keras.models.save_model(stripped_model, keras_file) + converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) + + converter.experimental_new_converter = True + tflite_model = converter.convert() + with open(tflite_file, "wb") as f: + f.write(tflite_model) + + self._verify_tflite(tflite_file, self.x_train) + + os.remove(keras_file) + os.remove(tflite_file) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py index 7e2f0b949..73d2c160b 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py @@ -446,6 +446,28 @@ def testClusterWeightsStrippedWeights(self): self.assertEqual(self._count_clustered_layers(stripped_model), 0) self.assertEqual(len(stripped_model.get_weights()), cluster_weight_length) + @keras_parameterized.run_all_keras_modes + def testStrippedKernel(self): + """ + Verifies that stripping the clustering wrappers from a functional model + restores the layer's kernel and the layer's weight array to the new clustered weight value.
+ """ + i1 = keras.Input(shape=(1, 1, 1)) + x1 = layers.Conv2D(1, 1)(i1) + outputs = x1 + model = keras.Model(inputs=[i1], outputs=outputs) + + clustered_model = cluster.cluster_weights(model, **self.params) + clustered_conv2d_layer = clustered_model.layers[1] + clustered_kernel = clustered_conv2d_layer.layer.kernel + stripped_model = cluster.strip_clustering(clustered_model) + stripped_conv2d_layer = stripped_model.layers[1] + + self.assertEqual(self._count_clustered_layers(stripped_model), 0) + self.assertIsNot(stripped_conv2d_layer.kernel, clustered_kernel) + self.assertEqual(stripped_conv2d_layer.kernel, + stripped_conv2d_layer.weights[0]) + @keras_parameterized.run_all_keras_modes def testStripSelectivelyClusteredFunctionalModel(self): """