class ClusteringSummaries(keras.callbacks.TensorBoard):
  """Helper class to create tensorboard summaries for the clustering progress.

  This class is derived from tf.keras.callbacks.TensorBoard and just adds
  functionality to write histograms with batch-wise frequency.

  Arguments:
    log_dir: The path to the directory where the log files are saved.
    cluster_update_freq: determines the frequency of updates of the
      clustering histograms. Same behaviour as parameter update_freq of
      the base class, i.e. it accepts `'batch'`, `'epoch'` or integer.
  """

  def __init__(self,
               log_dir='logs',
               cluster_update_freq='epoch',
               **kwargs):
    # Validate before delegating to the base class so an invalid log_dir
    # fails fast here instead of deep inside the TensorBoard callback.
    if not isinstance(log_dir, str) or not log_dir:
      raise ValueError(
          '`log_dir` must be a non-empty string. You passed `log_dir`='
          '{input}.'.format(input=log_dir))

    super(ClusteringSummaries, self).__init__(log_dir=log_dir, **kwargs)

    # Normalize 'batch' to 1 so the modulo check in on_train_batch_end
    # handles 'batch' and integer frequencies uniformly.
    self.cluster_update_freq = (
        1 if cluster_update_freq == 'batch' else cluster_update_freq)

    if compat.is_v1_apis():  # TF 1.X
      self.writer = tf.compat.v1.summary.FileWriter(log_dir)
    else:  # TF 2.X
      self.writer = tf.summary.create_file_writer(log_dir)

    # Continuous batch counter spanning epochs, see on_train_batch_begin.
    self.continuous_batch = 0

  def on_train_batch_begin(self, batch, logs=None):
    super().on_train_batch_begin(batch, logs)
    # Count batches manually to get a continuous batch count spanning
    # epochs, because the function parameter 'batch' is reset to zero
    # every epoch.
    self.continuous_batch += 1

  def on_train_batch_end(self, batch, logs=None):
    assert self.continuous_batch >= batch, (
        'Continuous batch count must always be greater or equal than the '
        'batch count from the parameter in the current epoch.')

    super().on_train_batch_end(batch, logs)

    if self.cluster_update_freq == 'epoch':
      return  # histograms are written from on_epoch_end instead
    elif self.continuous_batch % self.cluster_update_freq != 0:
      return  # skip this batch

    self._write_summary()

  def on_epoch_end(self, epoch, logs=None):
    super().on_epoch_end(epoch, logs)
    if self.cluster_update_freq == 'epoch':
      self._write_summary()

  def _write_summary(self):
    """Writes histograms of the variables of all clustered layers.

    Layers are recognized as clustered when they wrap an inner layer that
    exposes `get_clusterable_weights`; all other layers are skipped.
    """
    # NOTE(review): `as_default()`/`tf.summary.histogram` is the TF 2.X
    # summary API; the TF 1.X FileWriter created in __init__ does not
    # support this code path — confirm intended TF1 behaviour.
    with self.writer.as_default():
      for layer in self.model.layers:
        if (not hasattr(layer, 'layer') or
            not hasattr(layer.layer, 'get_clusterable_weights')):
          continue  # skip layer
        clusterable_weights = layer.layer.get_clusterable_weights()
        if len(clusterable_weights) < 1:
          continue  # skip layers without clusterable weights
        prefix = 'clustering/'
        # Log variables
        for var in layer.variables:
          success = tf.summary.histogram(
              prefix + var.name, var, step=self.continuous_batch)
          assert success
a/tensorflow_model_optimization/python/examples/clustering/keras/mnist/BUILD +++ b/tensorflow_model_optimization/python/examples/clustering/keras/mnist/BUILD @@ -22,5 +22,7 @@ py_binary( # python/keras tensorflow dep2, # python/keras/optimizer_v2 tensorflow dep2, "//tensorflow_model_optimization/python/core/clustering/keras:cluster", + "//tensorflow_model_optimization/python/core/clustering/keras:cluster_config", + "//tensorflow_model_optimization/python/core/clustering/keras:clustering_callbacks", ], ) diff --git a/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py b/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py index a9559ab89..22eedca48 100644 --- a/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py +++ b/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py @@ -20,6 +20,8 @@ """ from __future__ import print_function +import datetime +import os from absl import app as absl_app from absl import flags @@ -27,6 +29,7 @@ import tensorflow as tf from tensorflow_model_optimization.python.core.clustering.keras import cluster from tensorflow_model_optimization.python.core.clustering.keras import cluster_config +from tensorflow_model_optimization.python.core.clustering.keras import clustering_callbacks keras = tf.keras @@ -51,7 +54,6 @@ def load_mnist_dataset(): return (train_images, train_labels), (test_images, test_labels) - def build_sequential_model(): "Define the model architecture." 
@@ -111,6 +113,18 @@ def cluster_model(model, x_train, y_train, x_test, y_test): optimizer=opt, metrics=['accuracy']) + # Add callback for tensorboard summaries + log_dir = os.path.join( + FLAGS.output_dir, + datetime.datetime.now().strftime("%Y%m%d-%H%M%S-clustering")) + callbacks = [ + clustering_callbacks.ClusteringSummaries( + log_dir, + cluster_update_freq='epoch', + update_freq='batch', + histogram_freq=1) + ] + # Fine-tune clustered model clustered_model.fit( x_train, @@ -118,6 +132,7 @@ def cluster_model(model, x_train, y_train, x_test, y_test): batch_size=batch_size, epochs=epochs_fine_tuning, verbose=1, + callbacks=callbacks, validation_split=0.1) score = clustered_model.evaluate(x_test, y_test, verbose=0)