diff --git a/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py b/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py index cb4a48cee..4da722635 100644 --- a/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py +++ b/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py @@ -13,10 +13,13 @@ # limitations under the License. # ============================================================================== # pylint: disable=missing-docstring -"""Train a simple convnet on the MNIST dataset.""" +"""Train a simple convnet on the MNIST dataset and cluster it. -from __future__ import print_function +This example is based on the sample that can be found here: +https://www.tensorflow.org/model_optimization/guide/quantization/training_example +""" +from __future__ import print_function import datetime import os @@ -29,12 +32,10 @@ from tensorflow_model_optimization.python.core.clustering.keras import clustering_callbacks keras = tf.keras -l = keras.layers FLAGS = flags.FLAGS batch_size = 128 -num_classes = 10 epochs = 12 epochs_fine_tuning = 4 @@ -43,158 +44,149 @@ 'Output directory to hold tensorboard events') -def build_sequential_model(input_shape): - return tf.keras.Sequential([ - l.Conv2D( - 32, 5, padding='same', activation='relu', input_shape=input_shape), - l.MaxPooling2D((2, 2), (2, 2), padding='same'), - l.BatchNormalization(), - l.Conv2D(64, 5, padding='same', activation='relu'), - l.MaxPooling2D((2, 2), (2, 2), padding='same'), - l.Flatten(), - l.Dense(1024, activation='relu'), - l.Dropout(0.4), - l.Dense(num_classes, activation='softmax') - ]) +def load_mnist_dataset(): + mnist = keras.datasets.mnist + (train_images, train_labels), (test_images, test_labels) = mnist.load_data() + # Normalize the input image so that each pixel value is between 0 to 1. + train_images = train_images / 255.0 + test_images = test_images / 255.0 -def build_functional_model(input_shape): - inp = tf.keras.Input(shape=input_shape) - x = l.Conv2D(32, 5, padding='same', activation='relu')(inp) - x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x) - x = l.BatchNormalization()(x) - x = l.Conv2D(64, 5, padding='same', activation='relu')(x) - x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x) - x = l.Flatten()(x) - x = l.Dense(1024, activation='relu')(x) - x = l.Dropout(0.4)(x) - out = l.Dense(num_classes, activation='softmax')(x) - - return tf.keras.models.Model([inp], [out]) - -def train_and_save(models, x_train, y_train, x_test, y_test): - for model in models: - model.compile( - loss=tf.keras.losses.categorical_crossentropy, - optimizer='adam', - metrics=['accuracy']) - - # Print the model summary. - model.summary() - - # Model needs to be clustered after initial training - # and having achieved good accuracy - model.fit( - x_train, - y_train, - batch_size=batch_size, - epochs=epochs, - verbose=1, - validation_data=(x_test, y_test)) - score = model.evaluate(x_test, y_test, verbose=0) - print('Test loss:', score[0]) - print('Test accuracy:', score[1]) - - print('Clustering model') - - clustering_params = { - 'number_of_clusters': 8, - 'cluster_centroids_init': cluster_config.CentroidInitialization.DENSITY_BASED - } - - # Cluster model - clustered_model = cluster.cluster_weights(model, **clustering_params) - - # Use smaller learning rate for fine-tuning - # clustered model - opt = tf.keras.optimizers.Adam(learning_rate=1e-5) - - clustered_model.compile( - loss=tf.keras.losses.categorical_crossentropy, - optimizer=opt, - metrics=['accuracy']) + return (train_images, train_labels), (test_images, test_labels) + +def build_sequential_model(): + "Define the model architecture." - # Add callback for tensorboard summaries - log_dir = os.path.join( - FLAGS.output_dir, - datetime.datetime.now().strftime("%Y%m%d-%H%M%S-clustering")) - callbacks = [ - clustering_callbacks.ClusteringSummaries( - log_dir, - cluster_update_freq='epoch', - update_freq='batch', - histogram_freq=1) - ] - - # Fine-tune model - clustered_model.fit( - x_train, - y_train, - batch_size=batch_size, - epochs=epochs_fine_tuning, - verbose=1, - callbacks=callbacks, - validation_data=(x_test, y_test)) - - score = clustered_model.evaluate(x_test, y_test, verbose=0) - print('Clustered Model Test loss:', score[0]) - print('Clustered Model Test accuracy:', score[1]) - - #Ensure accuracy persists after stripping the model - stripped_model = cluster.strip_clustering(clustered_model) - - stripped_model.compile( - loss=tf.keras.losses.categorical_crossentropy, + return keras.Sequential([ + keras.layers.InputLayer(input_shape=(28, 28)), + keras.layers.Reshape(target_shape=(28, 28, 1)), + keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'), + keras.layers.MaxPooling2D(pool_size=(2, 2)), + keras.layers.Flatten(), + keras.layers.Dense(10) + ]) + + +def train_model(model, x_train, y_train, x_test, y_test): + model.compile( + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer='adam', + metrics=['accuracy']) + + # Print the model summary. + model.summary() + + # Model needs to be clustered after initial training + # and having achieved good accuracy + model.fit( + x_train, + y_train, + batch_size=batch_size, + epochs=epochs, + verbose=1, + validation_split=0.1) + + score = model.evaluate(x_test, y_test, verbose=0) + print('Test loss:', score[0]) + print('Test accuracy:', score[1]) + + return model + + +def cluster_model(model, x_train, y_train, x_test, y_test): + print('Clustering model') + + clustering_params = { + 'number_of_clusters': 8, + 'cluster_centroids_init': cluster_config.CentroidInitialization.DENSITY_BASED + } + + # Cluster model + clustered_model = cluster.cluster_weights(model, **clustering_params) + + # Use smaller learning rate for fine-tuning + # clustered model + opt = tf.keras.optimizers.Adam(learning_rate=1e-5) + + clustered_model.compile( + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=opt, + metrics=['accuracy']) + + # Add callback for tensorboard summaries + log_dir = os.path.join( + FLAGS.output_dir, + datetime.datetime.now().strftime("%Y%m%d-%H%M%S-clustering")) + callbacks = [ + clustering_callbacks.ClusteringSummaries( + log_dir, + cluster_update_freq='epoch', + update_freq='batch', + histogram_freq=1) + ] + + # Fine-tune clustered model + clustered_model.fit( + x_train, + y_train, + batch_size=batch_size, + epochs=epochs_fine_tuning, + verbose=1, + callbacks=callbacks, + validation_split=0.1) + + score = clustered_model.evaluate(x_test, y_test, verbose=0) + print('Clustered model test loss:', score[0]) + print('Clustered model test accuracy:', score[1]) + + return clustered_model + + +def test_clustered_model(clustered_model, x_test, y_test): + # Ensure accuracy persists after serializing/deserializing the model + clustered_model.save('clustered_model.h5') + # To deserialize the clustered model, use the clustering scope + with cluster.cluster_scope(): + loaded_clustered_model = keras.models.load_model('clustered_model.h5') + + # Checking that the deserialized model's accuracy matches the clustered model + score = loaded_clustered_model.evaluate(x_test, y_test, verbose=0) + print('Deserialized model test loss:', score[0]) + print('Deserialized model test accuracy:', score[1]) + + # Ensure accuracy persists after stripping the model + stripped_model = cluster.strip_clustering(loaded_clustered_model) + stripped_model.compile( + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer='adam', metrics=['accuracy']) - stripped_model.save('stripped_model.h5') - # To acquire the stripped model, - # deserialize with clustering scope - with cluster.cluster_scope(): - loaded_model = keras.models.load_model('stripped_model.h5') + # Checking that the stripped model's accuracy matches the clustered model + score = stripped_model.evaluate(x_test, y_test, verbose=0) + print('Stripped model test loss:', score[0]) + print('Stripped model test accuracy:', score[1]) - # Checking that the stripped model's accuracy matches the clustered model - score = loaded_model.evaluate(x_test, y_test, verbose=0) - print('Stripped Model Test loss:', score[0]) - print('Stripped Model Test accuracy:', score[1]) def main(unused_argv): if FLAGS.enable_eager: print('Running in Eager mode.') tf.compat.v1.enable_eager_execution() - # input image dimensions - img_rows, img_cols = 28, 28 - # the data, shuffled and split between train and test sets - (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() - - if tf.keras.backend.image_data_format() == 'channels_first': - x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) - x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) - input_shape = (1, img_rows, img_cols) - else: - x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) - x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) - input_shape = (img_rows, img_cols, 1) - - x_train = x_train.astype('float32') - x_test = x_test.astype('float32') - x_train /= 255 - x_test /= 255 + (x_train, y_train), (x_test, y_test) = load_mnist_dataset() + print('x_train shape:', x_train.shape) print(x_train.shape[0], 'train samples') print(x_test.shape[0], 'test samples') - # convert class vectors to binary class matrices - y_train = tf.keras.utils.to_categorical(y_train, num_classes) - y_test = tf.keras.utils.to_categorical(y_test, num_classes) - - sequential_model = build_sequential_model(input_shape) - functional_model = build_functional_model(input_shape) - models = [sequential_model, functional_model] - train_and_save(models, x_train, y_train, x_test, y_test) + # Build model + model = build_sequential_model() + # Train model + model = train_model(model, x_train, y_train, x_test, y_test) + # Cluster and fine-tune model + clustered_model = cluster_model(model, x_train, y_train, x_test, y_test) + # Test clustered model (serialize/deserialize, strip clustering) + test_clustered_model(clustered_model, x_test, y_test) if __name__ == '__main__':