class ClusteringSummaries(keras.callbacks.TensorBoard):
  """TensorBoard callback that also logs histograms of clustered layers.

  This class is derived from `tf.keras.callbacks.TensorBoard` and adds one
  feature on top of it: for every layer of the model that wraps a layer
  exposing `get_clusterable_weights()`, a histogram of each of the wrapper's
  variables is written under the `clustering/` prefix. The step attached to
  each histogram is a batch counter that keeps increasing across epochs.

  Arguments:
    log_dir: The path to the directory where the log files are saved. Must be
      a non-empty string.
    cluster_update_freq: Determines the frequency of updates of the
      clustering histograms. Accepts `'batch'` (every batch), `'epoch'`
      (once at the end of every epoch) or a positive integer `N` (every
      `N` batches, counted across epochs).
    **kwargs: Forwarded unchanged to `tf.keras.callbacks.TensorBoard`.

  Raises:
    ValueError: If `log_dir` is not a non-empty string, or if
      `cluster_update_freq` is neither `'batch'`, `'epoch'` nor a positive
      integer.
  """

  def __init__(self,
               log_dir='logs',
               cluster_update_freq='epoch',
               **kwargs):
    # Validate the arguments before the base class gets a chance to act on
    # them, so that a bad call fails without any side effects.
    if not isinstance(log_dir, str) or not log_dir:
      raise ValueError(
          '`log_dir` must be a non-empty string. You passed `log_dir`='
          '{input}.'.format(input=log_dir))
    normalized_freq = self._normalize_update_freq(cluster_update_freq)

    super().__init__(log_dir=log_dir, **kwargs)

    # Either the string 'epoch' or a positive number of batches.
    self.cluster_update_freq = normalized_freq

    if compat.is_v1_apis():  # TF 1.X
      # NOTE(review): `_write_summary` calls `self.writer.as_default()`;
      # confirm that the TF1 `FileWriter` supports that usage.
      self.writer = tf.compat.v1.summary.FileWriter(log_dir)
    else:  # TF 2.X
      self.writer = tf.summary.create_file_writer(log_dir)

    # Number of training batches started so far, across all epochs.
    self.continuous_batch = 0

  @staticmethod
  def _normalize_update_freq(cluster_update_freq):
    """Maps `'batch'` to 1 and rejects values that cannot be used as a period.

    Args:
      cluster_update_freq: `'batch'`, `'epoch'` or a positive integer.

    Returns:
      `'epoch'`, or the number of batches between two histogram updates.

    Raises:
      ValueError: If the value is none of the accepted forms.
    """
    if cluster_update_freq == 'epoch':
      return 'epoch'
    if cluster_update_freq == 'batch':
      return 1

    try:
      is_valid = (not isinstance(cluster_update_freq, str)
                  and cluster_update_freq >= 1
                  and cluster_update_freq == int(cluster_update_freq))
    except (TypeError, ValueError, OverflowError):
      is_valid = False

    if not is_valid:
      raise ValueError(
          '`cluster_update_freq` must be \'batch\', \'epoch\' or a positive '
          'integer. You passed `cluster_update_freq`={input}.'.format(
              input=cluster_update_freq))
    return cluster_update_freq

  def on_train_batch_begin(self, batch, logs=None):
    """Advances the epoch-spanning batch counter."""
    super().on_train_batch_begin(batch, logs)
    # Count batches manually to get a continuous batch count spanning
    # epochs, because the function parameter 'batch' is reset to zero
    # every epoch.
    self.continuous_batch += 1

  def on_train_batch_end(self, batch, logs=None):
    """Writes the clustering histograms if this batch is due for an update."""
    assert self.continuous_batch >= batch, (
        'Continuous batch count must always be greater or equal than the '
        'batch count from the parameter in the current epoch.')

    super().on_train_batch_end(batch, logs)

    if self.cluster_update_freq == 'epoch':
      return  # histograms are written in on_epoch_end instead
    if self.continuous_batch % self.cluster_update_freq != 0:
      return  # skip this batch

    self._write_summary()

  def on_epoch_end(self, epoch, logs=None):
    """Writes the clustering histograms when updating once per epoch."""
    super().on_epoch_end(epoch, logs)
    if self.cluster_update_freq == 'epoch':
      self._write_summary()

  def on_train_end(self, logs=None):
    """Flushes pending clustering histograms once training has finished."""
    super().on_train_end(logs)
    # The writer stays open so the callback can be reused for another fit().
    self.writer.flush()

  def _write_summary(self):
    """Writes one histogram per variable of every clustered layer."""
    prefix = 'clustering/'
    with self.writer.as_default():
      for layer in self.model.layers:
        # Only wrapper layers whose inner layer is clusterable are logged.
        wrapped_layer = getattr(layer, 'layer', None)
        if not hasattr(wrapped_layer, 'get_clusterable_weights'):
          continue  # skip layer
        if not wrapped_layer.get_clusterable_weights():
          continue  # skip layers without clusterable weights
        # Log variables
        for var in layer.variables:
          success = tf.summary.histogram(
              prefix + var.name, var, step=self.continuous_batch)
          assert success, (
              'Failed to write the histogram summary for variable '
              '{name}.'.format(name=var.name))
def build_sequential_model(input_shape):
  """Builds the MNIST convnet as a `tf.keras.Sequential` model.

  The stack is two 5x5 convolutions (32 and 64 filters), each followed by
  2x2 max pooling, with batch normalization after the first pooling, then a
  1024-unit dense layer, dropout of 0.4 and a softmax over `num_classes`.

  Args:
    input_shape: Shape of a single input image, without the batch dimension.

  Returns:
    The uncompiled sequential model.
  """
  layers = tf.keras.layers

  first_conv = layers.Conv2D(
      32, 5, padding='same', activation='relu', input_shape=input_shape)
  stack = [
      first_conv,
      layers.MaxPooling2D((2, 2), (2, 2), padding='same'),
      layers.BatchNormalization(),
      layers.Conv2D(64, 5, padding='same', activation='relu'),
      layers.MaxPooling2D((2, 2), (2, 2), padding='same'),
      layers.Flatten(),
      layers.Dense(1024, activation='relu'),
      layers.Dropout(0.4),
      layers.Dense(num_classes, activation='softmax'),
  ]
  return tf.keras.Sequential(stack)
def build_functional_model(input_shape):
  """Builds the MNIST convnet with the Keras functional API.

  The layers are two 5x5 convolutions (32 and 64 filters), each followed by
  2x2 max pooling, with batch normalization after the first pooling, then a
  1024-unit dense layer, dropout of 0.4 and a softmax over `num_classes`.

  Args:
    input_shape: Shape of a single input image, without the batch dimension.

  Returns:
    The uncompiled functional model.
  """
  layers = tf.keras.layers

  inputs = tf.keras.Input(shape=input_shape)
  x = layers.Conv2D(32, 5, padding='same', activation='relu')(inputs)
  x = layers.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
  x = layers.BatchNormalization()(x)
  x = layers.Conv2D(64, 5, padding='same', activation='relu')(x)
  x = layers.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
  x = layers.Flatten()(x)
  x = layers.Dense(1024, activation='relu')(x)
  x = layers.Dropout(0.4)(x)
  outputs = layers.Dense(num_classes, activation='softmax')(x)

  return tf.keras.models.Model([inputs], [outputs])


def _compile_for_classification(model, optimizer):
  """Compiles `model` with categorical cross-entropy and accuracy."""
  model.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer=optimizer,
      metrics=['accuracy'])


def _evaluate_and_report(model, x_test, y_test, label=''):
  """Evaluates `model` on the test split and prints loss and accuracy."""
  score = model.evaluate(x_test, y_test, verbose=0)
  print(label + 'Test loss:', score[0])
  print(label + 'Test accuracy:', score[1])


def _train_baseline(model, x_train, y_train, x_test, y_test):
  """Trains the unclustered model in place and reports its test score."""
  _compile_for_classification(model, 'adam')

  # Print the model summary.
  model.summary()

  # Model needs to be clustered after initial training
  # and having achieved good accuracy.
  # NOTE: the test split doubles as the validation data here.
  model.fit(
      x_train,
      y_train,
      batch_size=batch_size,
      epochs=epochs,
      verbose=1,
      validation_data=(x_test, y_test))
  _evaluate_and_report(model, x_test, y_test)


def _cluster_and_fine_tune(model, x_train, y_train, x_test, y_test):
  """Wraps `model` for clustering, fine-tunes it and returns the result."""
  print('Clustering model')

  clustering_params = {
      'number_of_clusters': 8,
      'cluster_centroids_init':
          cluster_config.CentroidInitialization.DENSITY_BASED,
  }
  clustered_model = cluster.cluster_weights(model, **clustering_params)

  # Use a smaller learning rate for fine-tuning the clustered model.
  _compile_for_classification(
      clustered_model, tf.keras.optimizers.Adam(learning_rate=1e-5))

  # Add callback for tensorboard summaries; each run gets its own
  # timestamped directory below the output directory.
  log_dir = os.path.join(
      FLAGS.output_dir,
      datetime.datetime.now().strftime('%Y%m%d-%H%M%S-clustering'))
  callbacks = [
      clustering_callbacks.ClusteringSummaries(
          log_dir,
          cluster_update_freq='epoch',
          update_freq='batch',
          histogram_freq=1)
  ]

  clustered_model.fit(
      x_train,
      y_train,
      batch_size=batch_size,
      epochs=epochs_fine_tuning,
      verbose=1,
      callbacks=callbacks,
      validation_data=(x_test, y_test))
  _evaluate_and_report(clustered_model, x_test, y_test, 'Clustered Model ')

  return clustered_model


def _strip_save_and_verify(clustered_model, x_test, y_test,
                           path='stripped_model.h5'):
  """Strips clustering, round-trips the model through `path` and scores it."""
  # Ensure accuracy persists after stripping the model.
  stripped_model = cluster.strip_clustering(clustered_model)
  _compile_for_classification(stripped_model, 'adam')
  stripped_model.save(path)

  # To acquire the stripped model, deserialize with clustering scope.
  with cluster.cluster_scope():
    loaded_model = tf.keras.models.load_model(path)

  # Checking that the stripped model's accuracy matches the clustered model.
  _evaluate_and_report(loaded_model, x_test, y_test, 'Stripped Model ')


def train_and_save(models, x_train, y_train, x_test, y_test):
  """Trains, clusters, fine-tunes, strips and saves every model in `models`.

  For each model in turn: train it, wrap it for weight clustering and
  fine-tune it while logging TensorBoard summaries, then strip the clustering
  wrappers, save the result to `stripped_model.h5`, load it back and evaluate
  it. Loss and accuracy on the test split are printed after each stage. The
  same file name is used for every model, so a later model overwrites the
  file written for an earlier one.

  Args:
    models: Iterable of uncompiled Keras models.
    x_train: Training images.
    y_train: One-hot encoded training labels.
    x_test: Test images; also used as validation data during fitting.
    y_test: One-hot encoded test labels.
  """
  for model in models:
    _train_baseline(model, x_train, y_train, x_test, y_test)
    clustered_model = _cluster_and_fine_tune(
        model, x_train, y_train, x_test, y_test)
    _strip_save_and_verify(clustered_model, x_test, y_test)


def _load_mnist():
  """Loads MNIST, scales pixels to [0, 1] and one-hot encodes the labels.

  Returns:
    A tuple `((x_train, y_train), (x_test, y_test), input_shape)` where the
    images carry an explicit channel axis placed according to the backend's
    image data format and `input_shape` is the shape of a single image.
  """
  # input image dimensions
  img_rows, img_cols = 28, 28

  # the data, shuffled and split between train and test sets
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

  if tf.keras.backend.image_data_format() == 'channels_first':
    input_shape = (1, img_rows, img_cols)
  else:
    input_shape = (img_rows, img_cols, 1)

  x_train = x_train.reshape((x_train.shape[0],) + input_shape)
  x_test = x_test.reshape((x_test.shape[0],) + input_shape)
  x_train = x_train.astype('float32') / 255
  x_test = x_test.astype('float32') / 255

  # convert class vectors to binary class matrices
  y_train = tf.keras.utils.to_categorical(y_train, num_classes)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes)

  return (x_train, y_train), (x_test, y_test), input_shape


def main(unused_argv):
  """Runs the full example for the sequential and the functional model."""
  if FLAGS.enable_eager:
    print('Running in Eager mode.')
    tf.compat.v1.enable_eager_execution()

  (x_train, y_train), (x_test, y_test), input_shape = _load_mnist()

  print('x_train shape:', x_train.shape)
  print(x_train.shape[0], 'train samples')
  print(x_test.shape[0], 'test samples')

  models = [
      build_sequential_model(input_shape),
      build_functional_model(input_shape),
  ]
  train_and_save(models, x_train, y_train, x_test, y_test)