Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions tensorflow_model_optimization/python/core/clustering/keras/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ py_library(
],
)

# Library target exposing clustering_callbacks.py. Visibility is public so
# targets in other packages can list it in their deps.
py_library(
name = "clustering_callbacks",
srcs = ["clustering_callbacks.py"],
srcs_version = "PY3",
visibility = ["//visibility:public"],
deps = [
# tensorflow dep1,
],
)

py_test(
name = "cluster_test",
size = "medium",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@

import tensorflow as tf

from tensorflow import keras
from tensorflow_model_optimization.python.core.keras import compat

class ClusteringSummaries(keras.callbacks.TensorBoard):
  """Helper class to create tensorboard summaries for the clustering progress.

  This class is derived from tf.keras.callbacks.TensorBoard and just adds
  functionality to write histograms of the variables of clustered layers,
  either once per epoch or every N training batches.

  Arguments:
    log_dir: The path to the directory where the log files are saved. Must be
      a non-empty string.
    cluster_update_freq: determines the frequency of updates of the
      clustering histograms. Same behaviour as parameter update_freq of
      the base class, i.e. it accepts `'batch'`, `'epoch'` or integer.
      `'batch'` is treated as the integer 1.
    **kwargs: Forwarded unchanged to the base-class constructor.

  Raises:
    ValueError: If `log_dir` is not a non-empty string, or if
      `cluster_update_freq` is a string other than `'batch'` or `'epoch'`.
  """

  def __init__(self,
               log_dir='logs',
               cluster_update_freq='epoch',
               **kwargs):
    # Reject bad arguments before any base-class or writer setup happens.
    if not isinstance(log_dir, str) or not log_dir:
      raise ValueError(
          '`log_dir` must be a non-empty string. You passed `log_dir`='
          '{input}.'.format(input=log_dir))

    # Any other string would only fail later, inside on_train_batch_end, when
    # it is used as the right-hand operand of `%` on the batch counter.
    if (isinstance(cluster_update_freq, str) and
        cluster_update_freq not in ('batch', 'epoch')):
      raise ValueError(
          '`cluster_update_freq` must be \'batch\', \'epoch\' or an integer. '
          'You passed `cluster_update_freq`={input}.'.format(
              input=cluster_update_freq))

    super().__init__(log_dir=log_dir, **kwargs)

    # 'batch' means "every batch", i.e. a period of one batch.
    self.cluster_update_freq = (
        1 if cluster_update_freq == 'batch' else cluster_update_freq)

    # NOTE(review): _write_summary() uses `self.writer.as_default()` and the
    # TF 2.x `tf.summary.histogram(..., step=...)` signature. Confirm that the
    # TF 1.x FileWriter created here actually supports that usage.
    if compat.is_v1_apis():  # TF 1.X
      self.writer = tf.compat.v1.summary.FileWriter(log_dir)
    else:  # TF 2.X
      self.writer = tf.summary.create_file_writer(log_dir)

    # Number of training batches started so far, across all epochs. Used as
    # the `step` of every histogram that is written.
    self.continuous_batch = 0

  def on_train_batch_begin(self, batch, logs=None):
    """Advances the epoch-spanning batch counter.

    Arguments:
      batch: Index of the batch within the current epoch.
      logs: Passed through to the base class.
    """
    super().on_train_batch_begin(batch, logs)
    # Count batches manually to get a continuous batch count spanning
    # epochs, because the function parameter 'batch' is reset to zero
    # every epoch.
    self.continuous_batch += 1

  def on_train_batch_end(self, batch, logs=None):
    """Writes the clustering histograms if this batch is due for an update.

    Nothing is written here when `cluster_update_freq` is `'epoch'`; that case
    is handled by `on_epoch_end`.

    Arguments:
      batch: Index of the batch within the current epoch.
      logs: Passed through to the base class.
    """
    assert self.continuous_batch >= batch, (
        'Continuous batch count must always be greater or equal than the '
        'batch count from the parameter in the current epoch.')

    super().on_train_batch_end(batch, logs)

    if self.cluster_update_freq == 'epoch':
      return
    if self.continuous_batch % self.cluster_update_freq != 0:
      return  # skip this batch

    self._write_summary()

  def on_epoch_end(self, epoch, logs=None):
    """Writes the clustering histograms when updates are epoch-wise.

    Arguments:
      epoch: Index of the epoch that just finished.
      logs: Passed through to the base class.
    """
    super().on_epoch_end(epoch, logs)
    if self.cluster_update_freq == 'epoch':
      self._write_summary()

  def _write_summary(self):
    """Writes one histogram per variable of every clustered layer.

    A layer is considered only if it has a `layer` attribute whose value
    provides `get_clusterable_weights()`, and that call returns at least one
    entry. Histograms are tagged `clustering/<variable name>` and use
    `self.continuous_batch` as the step.
    """
    prefix = 'clustering/'
    with self.writer.as_default():
      for layer in self.model.layers:
        wrapped = getattr(layer, 'layer', None)
        if wrapped is None or not hasattr(wrapped, 'get_clusterable_weights'):
          continue  # skip layer
        clusterable_weights = wrapped.get_clusterable_weights()
        if len(clusterable_weights) < 1:
          continue  # skip layers without clusterable weights
        # Log variables
        for var in layer.variables:
          success = tf.summary.histogram(
              prefix + var.name, var, step=self.continuous_batch)
          assert success, (
              'Failed to write histogram summary for variable '
              '{name}.'.format(name=var.name))
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,7 @@ py_binary(
# python/keras tensorflow dep2,
# python/keras/optimizer_v2 tensorflow dep2,
"//tensorflow_model_optimization/python/core/clustering/keras:cluster",
"//tensorflow_model_optimization/python/core/clustering/keras:cluster_config",
"//tensorflow_model_optimization/python/core/clustering/keras:clustering_callbacks",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@
"""

from __future__ import print_function
import datetime
import os

from absl import app as absl_app
from absl import flags

import tensorflow as tf
from tensorflow_model_optimization.python.core.clustering.keras import cluster
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
from tensorflow_model_optimization.python.core.clustering.keras import clustering_callbacks

keras = tf.keras

Expand All @@ -51,7 +54,6 @@ def load_mnist_dataset():

return (train_images, train_labels), (test_images, test_labels)


def build_sequential_model():
"Define the model architecture."

Expand Down Expand Up @@ -111,13 +113,26 @@ def cluster_model(model, x_train, y_train, x_test, y_test):
optimizer=opt,
metrics=['accuracy'])

# Add callback for tensorboard summaries
log_dir = os.path.join(
FLAGS.output_dir,
datetime.datetime.now().strftime("%Y%m%d-%H%M%S-clustering"))
callbacks = [
clustering_callbacks.ClusteringSummaries(
log_dir,
cluster_update_freq='epoch',
update_freq='batch',
histogram_freq=1)
]

# Fine-tune clustered model
clustered_model.fit(
x_train,
y_train,
batch_size=batch_size,
epochs=epochs_fine_tuning,
verbose=1,
callbacks=callbacks,
validation_split=0.1)

score = clustered_model.evaluate(x_test, y_test, verbose=0)
Expand Down