@@ -13,10 +13,13 @@
 # limitations under the License.
 # ==============================================================================
 # pylint: disable=missing-docstring
-"""Train a simple convnet on the MNIST dataset."""
-from __future__ import print_function
+"""Train a simple convnet on the MNIST dataset and cluster it.
+
+This example is based on the sample that can be found here:
+https://www.tensorflow.org/model_optimization/guide/quantization/training_example
+"""
+
+from __future__ import print_function
 import datetime
 import os

@@ -29,12 +32,10 @@
 from tensorflow_model_optimization.python.core.clustering.keras import clustering_callbacks

 keras = tf.keras
-l = keras.layers

 FLAGS = flags.FLAGS

 batch_size = 128
-num_classes = 10
 epochs = 12
 epochs_fine_tuning = 4

@@ -43,158 +44,149 @@
                     'Output directory to hold tensorboard events')


-def build_sequential_model(input_shape):
-  return tf.keras.Sequential([
-      l.Conv2D(
-          32, 5, padding='same', activation='relu', input_shape=input_shape),
-      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
-      l.BatchNormalization(),
-      l.Conv2D(64, 5, padding='same', activation='relu'),
-      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
-      l.Flatten(),
-      l.Dense(1024, activation='relu'),
-      l.Dropout(0.4),
-      l.Dense(num_classes, activation='softmax')
-  ])
-
-
-def build_functional_model(input_shape):
-  inp = tf.keras.Input(shape=input_shape)
-  x = l.Conv2D(32, 5, padding='same', activation='relu')(inp)
-  x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
-  x = l.BatchNormalization()(x)
-  x = l.Conv2D(64, 5, padding='same', activation='relu')(x)
-  x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
-  x = l.Flatten()(x)
-  x = l.Dense(1024, activation='relu')(x)
-  x = l.Dropout(0.4)(x)
-  out = l.Dense(num_classes, activation='softmax')(x)
-
-  return tf.keras.models.Model([inp], [out])
-
-def train_and_save(models, x_train, y_train, x_test, y_test):
-  for model in models:
-    model.compile(
-        loss=tf.keras.losses.categorical_crossentropy,
-        optimizer='adam',
-        metrics=['accuracy'])
-
-    # Print the model summary.
-    model.summary()
-
-    # Model needs to be clustered after initial training
-    # and having achieved good accuracy
-    model.fit(
-        x_train,
-        y_train,
-        batch_size=batch_size,
-        epochs=epochs,
-        verbose=1,
-        validation_data=(x_test, y_test))
-    score = model.evaluate(x_test, y_test, verbose=0)
-    print('Test loss:', score[0])
-    print('Test accuracy:', score[1])
-
-    print('Clustering model')
-
-    clustering_params = {
-        'number_of_clusters': 8,
-        'cluster_centroids_init': cluster_config.CentroidInitialization.DENSITY_BASED
-    }
-
-    # Cluster model
-    clustered_model = cluster.cluster_weights(model, **clustering_params)
-
-    # Use smaller learning rate for fine-tuning
-    # clustered model
-    opt = tf.keras.optimizers.Adam(learning_rate=1e-5)
-
-    clustered_model.compile(
-        loss=tf.keras.losses.categorical_crossentropy,
-        optimizer=opt,
-        metrics=['accuracy'])
-
-    # Add callback for tensorboard summaries
-    log_dir = os.path.join(
-        FLAGS.output_dir,
-        datetime.datetime.now().strftime("%Y%m%d-%H%M%S-clustering"))
-    callbacks = [
-        clustering_callbacks.ClusteringSummaries(
-            log_dir,
-            cluster_update_freq='epoch',
-            update_freq='batch',
-            histogram_freq=1)
-    ]
-
-    # Fine-tune model
-    clustered_model.fit(
-        x_train,
-        y_train,
-        batch_size=batch_size,
-        epochs=epochs_fine_tuning,
-        verbose=1,
-        callbacks=callbacks,
-        validation_data=(x_test, y_test))
-
-    score = clustered_model.evaluate(x_test, y_test, verbose=0)
-    print('Clustered Model Test loss:', score[0])
-    print('Clustered Model Test accuracy:', score[1])
-
-    # Ensure accuracy persists after stripping the model
-    stripped_model = cluster.strip_clustering(clustered_model)
-
-    stripped_model.compile(
-        loss=tf.keras.losses.categorical_crossentropy,
-        optimizer='adam',
-        metrics=['accuracy'])
-    stripped_model.save('stripped_model.h5')
-
-    # To acquire the stripped model,
-    # deserialize with clustering scope
-    with cluster.cluster_scope():
-      loaded_model = keras.models.load_model('stripped_model.h5')
-
-    # Checking that the stripped model's accuracy matches the clustered model
-    score = loaded_model.evaluate(x_test, y_test, verbose=0)
-    print('Stripped Model Test loss:', score[0])
-    print('Stripped Model Test accuracy:', score[1])
+def load_mnist_dataset():
+  mnist = keras.datasets.mnist
+  (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
+
+  # Normalize the input image so that each pixel value is between 0 to 1.
+  train_images = train_images / 255.0
+  test_images = test_images / 255.0
+
+  return (train_images, train_labels), (test_images, test_labels)
+
+
+def build_sequential_model():
+  "Define the model architecture."
+
+  return keras.Sequential([
+      keras.layers.InputLayer(input_shape=(28, 28)),
+      keras.layers.Reshape(target_shape=(28, 28, 1)),
+      keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
+      keras.layers.MaxPooling2D(pool_size=(2, 2)),
+      keras.layers.Flatten(),
+      keras.layers.Dense(10)
+  ])
+
+
+def train_model(model, x_train, y_train, x_test, y_test):
+  model.compile(
+      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+      optimizer='adam',
+      metrics=['accuracy'])
+
+  # Print the model summary.
+  model.summary()
+
+  # Model needs to be clustered after initial training
+  # and having achieved good accuracy
+  model.fit(
+      x_train,
+      y_train,
+      batch_size=batch_size,
+      epochs=epochs,
+      verbose=1,
+      validation_split=0.1)
+
+  score = model.evaluate(x_test, y_test, verbose=0)
+  print('Test loss:', score[0])
+  print('Test accuracy:', score[1])
+
+  return model
+
+
+def cluster_model(model, x_train, y_train, x_test, y_test):
+  print('Clustering model')
+
+  clustering_params = {
+      'number_of_clusters': 8,
+      'cluster_centroids_init': cluster_config.CentroidInitialization.DENSITY_BASED
+  }
+
+  # Cluster model
+  clustered_model = cluster.cluster_weights(model, **clustering_params)
+
+  # Use smaller learning rate for fine-tuning
+  # clustered model
+  opt = tf.keras.optimizers.Adam(learning_rate=1e-5)
+
+  clustered_model.compile(
+      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+      optimizer=opt,
+      metrics=['accuracy'])
+
+  # Add callback for tensorboard summaries
+  log_dir = os.path.join(
+      FLAGS.output_dir,
+      datetime.datetime.now().strftime("%Y%m%d-%H%M%S-clustering"))
+  callbacks = [
+      clustering_callbacks.ClusteringSummaries(
+          log_dir,
+          cluster_update_freq='epoch',
+          update_freq='batch',
+          histogram_freq=1)
+  ]
+
+  # Fine-tune clustered model
+  clustered_model.fit(
+      x_train,
+      y_train,
+      batch_size=batch_size,
+      epochs=epochs_fine_tuning,
+      verbose=1,
+      callbacks=callbacks,
+      validation_split=0.1)
+
+  score = clustered_model.evaluate(x_test, y_test, verbose=0)
+  print('Clustered model test loss:', score[0])
+  print('Clustered model test accuracy:', score[1])
+
+  return clustered_model
+
+
+def test_clustered_model(clustered_model, x_test, y_test):
+  # Ensure accuracy persists after serializing/deserializing the model
+  clustered_model.save('clustered_model.h5')
+  # To deserialize the clustered model, use the clustering scope
+  with cluster.cluster_scope():
+    loaded_clustered_model = keras.models.load_model('clustered_model.h5')
+
+  # Checking that the deserialized model's accuracy matches the clustered model
+  score = loaded_clustered_model.evaluate(x_test, y_test, verbose=0)
+  print('Deserialized model test loss:', score[0])
+  print('Deserialized model test accuracy:', score[1])
+
+  # Ensure accuracy persists after stripping the model
+  stripped_model = cluster.strip_clustering(loaded_clustered_model)
+  stripped_model.compile(
+      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+      optimizer='adam',
+      metrics=['accuracy'])
+
+  # Checking that the stripped model's accuracy matches the clustered model
+  score = stripped_model.evaluate(x_test, y_test, verbose=0)
+  print('Stripped model test loss:', score[0])
+  print('Stripped model test accuracy:', score[1])

 def main(unused_argv):
   if FLAGS.enable_eager:
     print('Running in Eager mode.')
     tf.compat.v1.enable_eager_execution()

-  # input image dimensions
-  img_rows, img_cols = 28, 28
-
   # the data, shuffled and split between train and test sets
-  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
-
-  if tf.keras.backend.image_data_format() == 'channels_first':
-    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
-    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
-    input_shape = (1, img_rows, img_cols)
-  else:
-    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
-    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
-    input_shape = (img_rows, img_cols, 1)
-
-  x_train = x_train.astype('float32')
-  x_test = x_test.astype('float32')
-  x_train /= 255
-  x_test /= 255
+  (x_train, y_train), (x_test, y_test) = load_mnist_dataset()

   print('x_train shape:', x_train.shape)
   print(x_train.shape[0], 'train samples')
   print(x_test.shape[0], 'test samples')

-  # convert class vectors to binary class matrices
-  y_train = tf.keras.utils.to_categorical(y_train, num_classes)
-  y_test = tf.keras.utils.to_categorical(y_test, num_classes)
-
-  sequential_model = build_sequential_model(input_shape)
-  functional_model = build_functional_model(input_shape)
-  models = [sequential_model, functional_model]
-  train_and_save(models, x_train, y_train, x_test, y_test)
+  # Build model
+  model = build_sequential_model()
+  # Train model
+  model = train_model(model, x_train, y_train, x_test, y_test)
+  # Cluster and fine-tune model
+  clustered_model = cluster_model(model, x_train, y_train, x_test, y_test)
[Inline review thread on the line above]

@Ruomei (Contributor), Oct 1, 2020:

More of a question than a comment: have you tried to apply the clustering wrapper to the inference graph rather than the training graph of a small example like MNIST? If yes, what does the curve of the training loss look like?

PR author (Contributor):

I'm not sure what you mean by "inference graph" and "training graph". Could you explain this a bit further?

@Ruomei (Contributor):

Hi, yes. An inference graph usually refers to a training graph after modifications targeted at inference. For instance, you can take a look at a TensorFlow legacy tool explaining some typical graph transforms (e.g. fold_batch_norms, merge_duplicate_nodes, remove_control_dependencies). I will write a summary on the internal site soon.

In our application context, I was just curious whether clustering behaves differently on a toy model when tested with an inference graph rather than a training graph, in a similar way to a realistic model. If so, it could help our future debugging of upcoming new features, with all the nice visualization you have done. We can check this later; it is not essential for this PR.
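A minimal sketch (not part of this PR) of the training-graph half of that experiment: it reuses this example's own helpers (load_mnist_dataset, build_sequential_model, train_model) and the clustering API imported in this file, and reads the fine-tuning loss curve out of the Keras History object. The batch size and epoch count are illustrative assumptions, and the inference-graph variant (clustering after transforms such as fold_batch_norms) is not implemented here.

# Hypothetical sketch: cluster the trained (training-graph) model and
# inspect the fine-tuning loss curve. Assumes this example's helpers and
# imports (tf, cluster, cluster_config) are in scope.
(x_train, y_train), (x_test, y_test) = load_mnist_dataset()
model = train_model(build_sequential_model(), x_train, y_train, x_test, y_test)

clustered = cluster.cluster_weights(
    model,
    number_of_clusters=8,
    cluster_centroids_init=cluster_config.CentroidInitialization.DENSITY_BASED)
clustered.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    metrics=['accuracy'])

# history.history['loss'] holds the per-epoch training loss, i.e. the
# curve asked about above.
history = clustered.fit(x_train, y_train, batch_size=128, epochs=4,
                        validation_split=0.1, verbose=1)
print(history.history['loss'])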

+  # Test clustered model (serialize/deserialize, strip clustering)
+  test_clustered_model(clustered_model, x_test, y_test)


 if __name__ == '__main__':