diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..279c66f17 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +*.h5 +bazel-* +__pycache__ + + diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster.py index e645e7c3f..b8ad81e6c 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster.py @@ -13,7 +13,9 @@ # limitations under the License. # ============================================================================== """Clustering API functions for Keras models.""" +import distutils.version +import tensorflow as tf from tensorflow import keras from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper @@ -24,6 +26,11 @@ Layer = keras.layers.Layer InputLayer = keras.layers.InputLayer +# From tf version 2.4.0 onwards the internal variable +# _layers has been renamed to _self_tracked_trackables. +# This variable is the only way to add cluster wrapper +# to layers of a subclassed model. +TF_VERSION_LAYERS = "2.4.0" def cluster_scope(): """Provides a scope in which Clustered layers and models can be deserialized. @@ -50,6 +57,27 @@ def cluster_scope(): } ) +def _type_model(model): + """ Auxiliary function to check type of the model: + Sequential/Functional, Layer or Subclassed. + + Args: + model : provided model to check + + Returns: + [tuple]: (is_sequential_or_functional, is_keras_layer, is_subclassed_model) + """ + is_sequential_or_functional = isinstance( + model, keras.Model) and (isinstance(model, keras.Sequential) or + model._is_graph_network) + + is_keras_layer = isinstance( + model, keras.layers.Layer) and not isinstance(model, keras.Model) + + is_subclassed_model = isinstance(model, keras.Model) and \ + not model._is_graph_network + + return (is_sequential_or_functional, is_keras_layer, is_subclassed_model) def cluster_weights(to_cluster, number_of_clusters, @@ -221,8 +249,9 @@ def _cluster_weights(to_cluster, number_of_clusters, cluster_centroids_init, cluster_centroids_init)) def _add_clustering_wrapper(layer): - if isinstance(layer, keras.Model): - # Check whether the model is a subclass. + if (isinstance(layer, keras.Model)): + # Check whether the model is subclassed. + # NB: This check is copied from keras.py file in tensorflow. # There is no available public API to do this check. # pylint: disable=protected-access @@ -248,29 +277,49 @@ def _wrap_list(layers): return output - if isinstance(to_cluster, keras.Model): + (is_sequential_or_functional, is_keras_layer, is_subclassed_model) =\ + _type_model(to_cluster) + + if isinstance(to_cluster, list): + return _wrap_list(to_cluster) + elif is_sequential_or_functional: return keras.models.clone_model(to_cluster, input_tensors=None, clone_function=_add_clustering_wrapper) - if isinstance(to_cluster, Layer): + elif is_keras_layer: return _add_clustering_wrapper(layer=to_cluster) - if isinstance(to_cluster, list): - return _wrap_list(to_cluster) - + elif is_subclassed_model: + # If the subclassed model is provided, then + # we add wrappers for all available layers and + # we wrap the whole model, so that augmented + # 'build' and 'call' functions are called. 
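+    # The LooseVersion comparison below selects the internal attribute that
+    # tracks the model's layers: _self_tracked_trackables on newer TF
+    # releases, _layers on older ones (see TF_VERSION_LAYERS above).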
+ tf_version = distutils.version.LooseVersion(tf.__version__) + layers_tf_version = distutils.version.LooseVersion(TF_VERSION_LAYERS) + for i, layer in enumerate(to_cluster.submodules): + if tf_version > layers_tf_version: + to_cluster._self_tracked_trackables[i] = _add_clustering_wrapper(layer=layer) + else: + to_cluster._layers[i] = _add_clustering_wrapper(layer=layer) + return cluster_wrapper.WrapperSubclassedModel(to_cluster) + else: + raise ValueError( + ' Clustering cannot be applied. You passed ' + 'an object of type: {input}.'.format(input=to_cluster.__class__.__name__)) -def strip_clustering(model): - """Strips clustering wrappers from the model. +def strip_clustering(to_strip): + """Strip clustering wrappers from the model. Once a model has been clustered, this method can be used - to restore the original model with the clustered weights. + to restore the original model or layer with the clustered weights. - Only sequential and functional models are supported for now. + Sequential, functional and subclassed models are supported. Arguments: - model: A `tf.keras.Model` instance with clustered layers. + to_strip: A `tf.keras.Model` instance with clustered layers or a + `tf.keras.layers.Layer` instance Returns: - A keras model with clustering wrappers removed. + A keras model or layer with clustering wrappers removed. Raises: ValueError: if the model is not a `tf.keras.Model` instance. @@ -285,9 +334,11 @@ def strip_clustering(model): ``` The exported_model and the orig_model have the same structure. """ - if not isinstance(model, keras.Model): + if not isinstance(to_strip, keras.Model) and not isinstance( + to_strip, keras.layers.Layer): raise ValueError( - 'Expected model to be a `tf.keras.Model` instance but got: ', model) + 'Expected to_strip to be a `tf.keras.Model` or \ + `tf.keras.layers.Layer` instance but got: ', to_strip) def _strip_clustering_wrapper(layer): if isinstance(layer, keras.Model): @@ -325,7 +376,30 @@ def _strip_clustering_wrapper(layer): return layer.layer return layer + (is_sequential_or_functional, is_keras_layer, is_subclassed_model) =\ + _type_model(to_strip) + # Just copy the model with the right callback - return keras.models.clone_model(model, + if is_sequential_or_functional: + return keras.models.clone_model(to_strip, input_tensors=None, clone_function=_strip_clustering_wrapper) + elif is_keras_layer: + if isinstance(to_strip, keras.layers.Layer): + return _strip_clustering_wrapper(to_strip) + elif is_subclassed_model: + to_strip_model = to_strip.model + tf_version = distutils.version.LooseVersion(tf.__version__) + layers_tf_version = distutils.version.LooseVersion(TF_VERSION_LAYERS) + if tf_version > layers_tf_version: + for i, layer in enumerate(to_strip_model._self_tracked_trackables): + to_strip_model._self_tracked_trackables[i] = _strip_clustering_wrapper(layer=layer) + else: + for i, layer in enumerate(to_strip_model._layers): + to_strip_model._layers[i] = _strip_clustering_wrapper(layer=layer) + return to_strip_model + else: + raise ValueError( + ' Strip clustering cannot be applied. 
You passed ' + 'an object of type: {input}.'.format(input=to_strip.__class__.__name__)) + diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py index a133bd687..87539a31c 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py @@ -24,14 +24,42 @@ from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.clustering.keras import cluster from tensorflow_model_optimization.python.core.clustering.keras import cluster_config + +from tensorflow_model_optimization.python.core.keras import compat +from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper + from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster + keras = tf.keras layers = keras.layers test = tf.test CentroidInitialization = cluster_config.CentroidInitialization +class SubclassedModel(keras.Model): + """Subclassed model with one layer.""" + + def __init__(self): + """Subclassed model with one dense layer.""" + super(SubclassedModel, self).__init__(name='subclass_model') + self.dense_layer = keras.layers.Dense(5, activation='relu') + + def call(self, inputs): + return self.dense_layer(inputs) + +class SubclassedModelTwoLayers(keras.Model): + """Subclassed model with two layers.""" + + def __init__(self): + """Subclassed model with two layers.""" + super(SubclassedModelTwoLayers, self).__init__(name='subclass_model') + self.dense_layer1 = keras.layers.Dense(5, activation='relu') + self.dense_layer2 = keras.layers.Dense(5, activation='softmax') + + def call(self, inputs): + x = self.dense_layer1(inputs) + return self.dense_layer2(x) class ClusterIntegrationTest(test.TestCase, parameterized.TestCase): """Integration tests for clustering.""" @@ -233,6 +261,105 @@ def clusters_check(stripped_model): self.end_to_end_testing(original_model, clusters_check) + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def testEndToEndSubclassedModel(self): + """Test End to End clustering for the subclassed model. + In this test we pass the whole subclassed model for clustering. + We check that the number of weights is less the requested + number of clusters after stripping clustering wrapper. + + """ + subclassed_model = SubclassedModel() + + clustered_model = cluster.cluster_weights(subclassed_model, **self.params) + + clustered_model.compile( + loss=keras.losses.categorical_crossentropy, + optimizer="adam", + metrics=["accuracy"] + ) + + # The model should be trained a little bit. + clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1) + stripped_model = cluster.strip_clustering(clustered_model) + + nr_unique_weights = len(np.unique(stripped_model.layers[0].\ + trainable_weights[0].numpy().flatten())) + self.assertLessEqual(nr_unique_weights, self.params["number_of_clusters"]) + + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def testEndToEndSubclassedModelTwoLayers(self): + """Test End to End clustering for the subclass model. + + This test demonstrates another approach. + All layers that are present in the subclassed model + (see SubclassedModelTwoLayers definition above) are wrapped + manually. The model should be re-build in this case. 
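+    Re-building triggers build() on the new ClusterWeights wrappers so
+    that the clustering variables are created before training.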
+ + We need to strip clustering away manually as well (see how it is + done inside the test). + + Clustering is working well and clusters are updated during + training.""" + subclassed_model = SubclassedModelTwoLayers() + input_shape = (1, 5) + + # We need to build the model + subclassed_model.build(input_shape=input_shape) + + # Check that the number of weights is bigger than the number of clusters. + nr_unique_weights = len(np.unique(subclassed_model.layers[0].\ + trainable_weights[0].numpy().flatten())) + self.assertGreater(nr_unique_weights, self.params["number_of_clusters"]) + nr_unique_weights = len(np.unique(subclassed_model.layers[1].\ + trainable_weights[0].numpy().flatten())) + self.assertGreater(nr_unique_weights, self.params["number_of_clusters"]) + + # Now we apply cluster_weights for each layer. + subclassed_model.dense_layer1 = cluster.cluster_weights( + subclassed_model.dense_layer1, **self.params) + subclassed_model.dense_layer2 = cluster.cluster_weights( + subclassed_model.dense_layer2, **self.params) + + # We need to re-build the model again. + subclassed_model.build(input_shape=input_shape) + + subclassed_model.compile( + loss=keras.losses.categorical_crossentropy, + optimizer="adam", + metrics=["accuracy"] + ) + + subclassed_model.fit(x=self.dataset_generator(), steps_per_epoch=1) + + # We strip from layers that were wrapped. + subclassed_model.dense_layer1 = cluster.strip_clustering(subclassed_model.dense_layer1) + subclassed_model.dense_layer2 = cluster.strip_clustering(subclassed_model.dense_layer2) + + # Checks that the number of unique values is less than the requested + # number of clusters. + nr_unique_weights = len(np.unique(subclassed_model.layers[0].\ + trainable_weights[0].numpy().flatten())) + self.assertLessEqual(nr_unique_weights, self.params["number_of_clusters"]) + nr_unique_weights = len(np.unique(subclassed_model.layers[1].\ + trainable_weights[0].numpy().flatten())) + self.assertLessEqual(nr_unique_weights, self.params["number_of_clusters"]) + + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def testEndToEndSubclassedModelAsDeepLayer(self): + """Test End to End clustering for the model with the layer as a subclass model.""" + # This case is not supported currently. 
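+    # A subclassed model nested inside a Sequential model is expected to
+    # raise a ValueError ("Subclassed models.*"), which is asserted below.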
+ + internal_model = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(5,))]) + subclassed_model = SubclassedModel() + original_model = keras.Sequential([ + internal_model, + subclassed_model, + ]) + + with self.assertRaisesRegexp(ValueError, "Subclassed models.*"): + self.end_to_end_testing(original_model) + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) def testEndToEndDeepLayer(self): """Test End to End clustering for the model with deep layer.""" @@ -302,3 +429,4 @@ def clusters_check(stripped_model): if __name__ == "__main__": test.main() + diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py index 771f4dee6..6c7a48141 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py @@ -33,12 +33,12 @@ test = tf.test -class TestModel(keras.Model): - """A model subclass.""" +class SubclassedModel(keras.Model): + """A subclassed model.""" def __init__(self): """A test subclass model with one dense layer.""" - super(TestModel, self).__init__(name='test_model') + super(SubclassedModel, self).__init__(name='test_model') self.layer1 = keras.layers.Dense(10, activation='relu') def call(self, inputs): @@ -418,6 +418,48 @@ def testClusterFunctionalModel(self): outputs = layers.Add()([x1, x2]) model = keras.Model(inputs=[i1, i2], outputs=outputs) clustered_model = cluster.cluster_weights(model, **self.params) + # layer Add does not have trainable weights + self.assertEqual(self._count_clustered_layers(clustered_model), 2) + + @keras_parameterized.run_all_keras_modes + def testClusterFunctionalModelWithLayerReused(self): + """ + Verifies that a layer reused within a functional model multiple times is + only being clustered once. + """ + # The model reuses the Dense() layer. Make sure it's only clustered once. + inp = keras.Input(shape=(10,)) + dense_layer = layers.Dense(10) + x = dense_layer(inp) + x = dense_layer(x) + model = keras.Model(inputs=[inp], outputs=[x]) + clustered_model = cluster.cluster_weights(model, **self.params) + self.assertEqual(self._count_clustered_layers(clustered_model), 1) + + @keras_parameterized.run_all_keras_modes + def testClusterConfigAcceptsStrParameters(self): + """ + Verifies that cluster_config enum accepts enum set as a string. + We need this, for example, when we use keras-tuner with + clustering. + """ + wrapped_layer = cluster.cluster_weights(self.keras_clusterable_layer, + **self.params_str) + + self._validate_clustered_layer(self.keras_clusterable_layer, wrapped_layer) + + @keras_parameterized.run_all_keras_modes + def testClusterFunctionalModel(self): + """ + Verifies that a functional model is being clustered correctly. 
+ """ + i1 = keras.Input(shape=(10,)) + i2 = keras.Input(shape=(10,)) + x1 = layers.Dense(10)(i1) + x2 = layers.Dense(10)(i2) + outputs = layers.Add()([x1, x2]) + model = keras.Model(inputs=[i1, i2], outputs=outputs) + clustered_model = cluster.cluster_weights(model, **self.params) self.assertEqual(self._count_clustered_layers(clustered_model), 3) @keras_parameterized.run_all_keras_modes @@ -433,18 +475,30 @@ def testClusterFunctionalModelWithLayerReused(self): self.assertEqual(self._count_clustered_layers(clustered_model), 1) @keras_parameterized.run_all_keras_modes - def testClusterSubclassModel(self): - """Verifies that attempting to cluster an instance of a subclass of keras.Model raises an exception.""" - model = TestModel() - with self.assertRaises(ValueError): - _ = cluster.cluster_weights(model, **self.params) + def testClusterSubclassedModel(self): + """ + Verifies clustering of a subclassed model. + """ + model = SubclassedModel() + + clustered_model = cluster.cluster_weights(model, **self.params) + self.assertEqual(self._count_clustered_layers(model), 1) + + stripped_model = cluster.strip_clustering(clustered_model) + self.assertEqual(self._count_clustered_layers(stripped_model), 0) @keras_parameterized.run_all_keras_modes def testClusterSubclassModelAsSubmodel(self): - """Verifies that attempting to cluster a model with submodel that is a subclass throws an exception.""" - model_subclass = TestModel() - model = keras.Sequential([layers.Dense(10), model_subclass]) - with self.assertRaisesRegex(ValueError, 'Subclassed models.*'): + """ + Verifies that attempting to cluster a model with submodel + that is a subclass throws an exception. + """ + model_subclass = SubclassedModel() + model = keras.Sequential([ + layers.Dense(10), + model_subclass + ]) + with self.assertRaisesRegexp(ValueError, "Subclassed models.*"): _ = cluster.cluster_weights(model, **self.params) @keras_parameterized.run_all_keras_modes @@ -460,7 +514,7 @@ def testStripClusteringSequentialModel(self): self.assertEqual(self._count_clustered_layers(stripped_model), 0) self.assertEqual(model.get_config(), stripped_model.get_config()) - + @keras_parameterized.run_all_keras_modes def testClusterStrippingFunctionalModel(self): """Verifies that stripping the clustering wrappers from a functional model produces the expected config.""" diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py index c9c47e824..88880c6f9 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py @@ -379,3 +379,33 @@ def get_weights(self): def set_weights(self, weights): self.layer.set_weights(weights) + +class WrapperSubclassedModel(keras.Model): + """This wrapper wraps a keras subclassed model so that the weight tensor(s) + in keras layers that are defined in this model can be clustered. + """ + def __init__(self, model): + super(WrapperSubclassedModel, self).__init__() + + # This wrapper is needed only for subclassed models. + is_subclassed_model = isinstance(model, keras.Model) and \ + not model._is_graph_network + if not is_subclassed_model: + raise ValueError( + "The provided model should be subclassed. 
The provided: {}".format( + model.__class__ + ) + ) + self.model = model + + def build(self, input_shape): + for layer in self.model.layers: + if isinstance(layer, ClusterWeights): + layer.build(input_shape = input_shape) + return self.model.build(input_shape = input_shape) + + def call(self, inputs): + for layer in self.model.layers: + if isinstance(layer, ClusterWeights): + layer.call(inputs) + return self.model.call(inputs) \ No newline at end of file diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py index 591c8f459..dea784efa 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py @@ -49,6 +49,16 @@ class AlreadyClusterableLayer(layers.Dense, clusterable_layer.ClusterableLayer): def get_clusterable_weights(self): pass +class SubclassedModel(keras.Model): + """A subclassed model.""" + + def __init__(self): + """A test subclass model with one dense layer.""" + super(SubclassedModel, self).__init__(name='test_model') + self.dense_layer = keras.layers.Dense(32, activation='relu') + + def call(self, inputs): + return self.dense_layer(inputs) class ClusterWeightsTest(test.TestCase, parameterized.TestCase): """Unit tests for the cluster_wrapper module.""" @@ -173,6 +183,27 @@ def testValuesAreClusteredAfterStripping(self, # Make sure that the stripped layer is the Dense one self.assertIsInstance(stripped_model.layers[0], layers.Dense) + def testClusterWrappersAreStrippedInSubclassedModel(self): + """ + Verifies that for a subclassed model all ClusterWeights + wrappers are stripped from the model. + """ + original_model = SubclassedModel() + + clustered_model = cluster.cluster_weights( + original_model, + number_of_clusters=8, + cluster_centroids_init=CentroidInitialization.DENSITY_BASED + ) + + self.assertIsInstance(clustered_model, cluster_wrapper.WrapperSubclassedModel) + + stripped_model = cluster.strip_clustering(clustered_model) + + # Make sure that the stripped layer is the Dense one + self.assertIsInstance(stripped_model.layers[0], layers.Dense) + self.assertIsInstance(stripped_model.dense_layer, layers.Dense) + def testClusterReassociation(self): """Verifies that the association of weights to cluster centroids are updated every iteration.""" diff --git a/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_subclassed.py b/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_subclassed.py new file mode 100644 index 000000000..c43f5bf9b --- /dev/null +++ b/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_subclassed.py @@ -0,0 +1,194 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# pylint: disable=missing-docstring +"""Train a simple convnet that is written as a keras subclassed model +on the MNIST dataset and cluster it. + +This example is based on the sample that can be found here: +https://www.tensorflow.org/tutorials/quickstart/advanced +""" + +from __future__ import print_function + +import tensorflow as tf +import os + +from tensorflow.keras.layers import Dense, Flatten, Conv2D +from tensorflow.keras import Model + +from tensorflow_model_optimization.python.core.clustering.keras import cluster +from tensorflow_model_optimization.python.core.clustering.keras import cluster_config +from tensorflow_model_optimization.python.core.clustering.keras import clustering_callbacks + +BATCH_SIZE = 32 +EPOCHS = 5 +EPOCHS_FINE_TUNING = 4 + +# Load and prepare MNIST dataset +mnist = tf.keras.datasets.mnist + +(x_train, y_train), (x_test, y_test) = mnist.load_data() +x_train, x_test = x_train / 255.0, x_test / 255.0 + +# Add a channels dimension +x_train = x_train[..., tf.newaxis].astype("float32") +x_test = x_test[..., tf.newaxis].astype("float32") + +# Use tf.data to batch and shuffle the dataset. +train_ds = tf.data.Dataset.from_tensor_slices( + (x_train, y_train)).shuffle(10000).batch(BATCH_SIZE) + +test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE) + +# Build the model using the Keras model subclassing API. +class MyModel(Model): + def __init__(self): + super(MyModel, self).__init__() + self.conv1 = Conv2D(32, 3, activation='relu') + self.flatten = Flatten() + self.d1 = Dense(128, activation='relu') + self.d2 = Dense(10) + + def call(self, x): + x = self.conv1(x) + x = self.flatten(x) + x = self.d1(x) + return self.d2(x) + +# Create an instance of the model +model = MyModel() + +# Choose an optimizer and loss function for training. +loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + +optimizer = tf.keras.optimizers.Adam() + +# Select metrics to measure the loss and the accuracy of the model. +# These metrics accumulate the values over epochs and then print the overall result. +train_loss = tf.keras.metrics.Mean(name='train_loss') +train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy') + +test_loss = tf.keras.metrics.Mean(name='test_loss') +test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy') + +# Use tf.GradientTape to train the model as it is done in the tutorial. +@tf.function +def train_step(images, labels): + with tf.GradientTape() as tape: + # training=True is only needed if there are layers with different + # behavior during training versus inference (e.g. Dropout). + predictions = model(images, training=True) + loss = loss_object(labels, predictions) + gradients = tape.gradient(loss, model.trainable_variables) + optimizer.apply_gradients(zip(gradients, model.trainable_variables)) + + train_loss(loss) + train_accuracy(labels, predictions) + +# Test the model. +@tf.function +def test_step(images, labels): + # training=False is only needed if there are layers with different + # behavior during training versus inference (e.g. Dropout). 
+ predictions = model(images, training=False) + t_loss = loss_object(labels, predictions) + + test_loss(t_loss) + test_accuracy(labels, predictions) + +for epoch in range(EPOCHS): + # Reset the metrics at the start of the next epoch + train_loss.reset_states() + train_accuracy.reset_states() + test_loss.reset_states() + test_accuracy.reset_states() + + for images, labels in train_ds: + train_step(images, labels) + + for test_images, test_labels in test_ds: + test_step(test_images, test_labels) + + print( + f'Epoch {epoch + 1}, ' + f'Loss: {train_loss.result()}, ' + f'Accuracy: {train_accuracy.result()}, ' + f'Test Loss: {test_loss.result()}, ' + f'Test Accuracy: {test_accuracy.result()}' + ) + +def cluster_model(model, x_train, y_train, x_test, y_test): + print('Clustering model') + + clustering_params = { + 'number_of_clusters': 8, + 'cluster_centroids_init': cluster_config.CentroidInitialization.DENSITY_BASED + } + + # Cluster model + clustered_model = cluster.cluster_weights(model, **clustering_params) + + # Use smaller learning rate for fine-tuning + # clustered model + opt = tf.keras.optimizers.Adam(learning_rate=1e-5) + + clustered_model.compile( + loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer = opt, + metrics = ['accuracy']) + + # Fine-tune clustered model + clustered_model.fit( + x_train, + y_train, + batch_size = BATCH_SIZE, + epochs = EPOCHS_FINE_TUNING, + verbose = 1, + validation_split = 0.1) + + score = clustered_model.evaluate(x_test, y_test, verbose=0) + print('Clustered model test loss:', score[0]) + print('Clustered model test accuracy:', score[1]) + + return clustered_model + +def test_clustered_model(clustered_model, x_test, y_test): + # Stripping the model + stripped_model = cluster.strip_clustering(clustered_model) + stripped_model.compile( + loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer = 'adam', + metrics = ['accuracy']) + + # Checking that the stripped model's accuracy matches the clustered model + score = stripped_model.evaluate(x_test, y_test, verbose=0) + + print('Stripped model test loss:', score[0]) + print('Stripped model test accuracy:', score[1]) + + # Checking that we have the number of weights less than the + # number of clusters. + for layer in stripped_model.layers: + nr_unique_weights = len(set(layer.get_weights()[0].flatten())) \ + if len(layer.get_weights()) > 0 else 0 + print("Layer name: {}, number of clusters: {}".format( + layer.name, nr_unique_weights + )) + +# Cluster and fine-tune model +clustered_model = cluster_model(model, x_train, y_train, x_test, y_test) + +# Test clustered model (strip clustering) +test_clustered_model(clustered_model, x_test, y_test)
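
Note (not part of the patch): a minimal usage sketch of the subclassed-model flow
this change enables, mirroring testEndToEndSubclassedModel above. The model class,
the random data and the 8-cluster setting are illustrative assumptions.

  import numpy as np
  import tensorflow as tf
  from tensorflow_model_optimization.python.core.clustering.keras import cluster
  from tensorflow_model_optimization.python.core.clustering.keras import cluster_config

  class SmallSubclassed(tf.keras.Model):
    # Illustrative subclassed model with a single Dense layer.
    def __init__(self):
      super(SmallSubclassed, self).__init__()
      self.dense = tf.keras.layers.Dense(5, activation='relu')

    def call(self, inputs):
      return self.dense(inputs)

  params = {
      'number_of_clusters': 8,
      'cluster_centroids_init': cluster_config.CentroidInitialization.DENSITY_BASED,
  }

  # cluster_weights() wraps each inner layer and returns a WrapperSubclassedModel.
  clustered = cluster.cluster_weights(SmallSubclassed(), **params)
  clustered.compile(optimizer='adam', loss='mse')
  clustered.fit(np.random.rand(32, 10), np.random.rand(32, 5), epochs=1)

  # strip_clustering() removes the wrappers and keeps the clustered weights.
  stripped = cluster.strip_clustering(clustered)
  print(len(np.unique(stripped.layers[0].get_weights()[0])))  # expected <= 8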