Setup

In [93]:
! pip install -q tensorflow-model-optimization

In [94]:
import tempfile
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras

Train model for MNIST without pruning

In [95]:
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=4,
  validation_split=0.1,
)
model.summary()

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4
Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 reshape_5 (Reshape)         (None, 28, 28, 1)         0         
                                                                 
 conv2d_5 (Conv2D)           (None, 26, 26, 12)        120       
                                                                 
 max_pooling2d_5 (MaxPooling  (None, 13, 13, 12)       0         
 2D)                                                             
                                                                 
 flatten_5 (Flatten)         (None, 2028)              0         
                                                                 
 dense_5 (Dense)             (None, 10)                20290     
                                                                 
Total params: 20,410
Trainable params: 20,410
Non-trainable params: 0
__________

Evaluate Baseline Metrics

In [96]:
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
print('Saved baseline model to:', keras_file)


Baseline test accuracy: 0.9793000221252441
Saved baseline model to: C:\Users\Jchap\AppData\Local\Temp\tmprlzoj1de.h5


Fine-tune model with pruning

In [97]:
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 4
validation_split = 0.1 # 10% of training set will be used for validation set. 

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define model for pruning.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                               final_sparsity=0.95,
                                                               begin_step=0,
                                                               end_step=end_step)
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model_for_pruning.summary()


Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_reshape  (None, 28, 28, 1)        1         
 _5 (PruneLowMagnitude)                                          
                                                                 
 prune_low_magnitude_conv2d_  (None, 26, 26, 12)       230       
 5 (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_max_poo  (None, 13, 13, 12)       1         
 ling2d_5 (PruneLowMagnitude                                     
 )                                                               
                                                                 
 prune_low_magnitude_flatten  (None, 2028)             1         
 _5 (PruneLowMagnitude)                                          
                                                      

Fine-tune the model with pruning

In [98]:
logdir = tempfile.mkdtemp()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(train_images, train_labels,
                  batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                  callbacks=callbacks)


Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


<keras.callbacks.History at 0x1662a286950>

Evaluate model performance against baseline 

In [99]:
_, model_for_pruning_accuracy = model_for_pruning.evaluate(
   test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy) 
print('Pruned test accuracy:', model_for_pruning_accuracy)


Baseline test accuracy: 0.9793000221252441
Pruned test accuracy: 0.913100004196167


Now the Compression Begins

Creating a compressible model for Tensorflow

In [100]:
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

_, pruned_keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)





Saved pruned Keras model to: C:\Users\Jchap\AppData\Local\Temp\tmplrl13ycz.h5


In [101]:
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
pruned_tflite_model = converter.convert()

_, pruned_tflite_file = tempfile.mkstemp('.tflite')

with open(pruned_tflite_file, 'wb') as f:
  f.write(pruned_tflite_model)

print('Saved pruned TFLite model to:', pruned_tflite_file)



INFO:tensorflow:Assets written to: C:\Users\Jchap\AppData\Local\Temp\tmpf_8f309e\assets


INFO:tensorflow:Assets written to: C:\Users\Jchap\AppData\Local\Temp\tmpf_8f309e\assets


Saved pruned TFLite model to: C:\Users\Jchap\AppData\Local\Temp\tmpyph5r2ra.tflite


Compress the models using gzip

In [102]:
def get_gzipped_model_size(file):
  # Returns size of gzipped model, in bytes.
  import os
  import zipfile

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)

In [103]:
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

_, tflite_model_file = tempfile.mkstemp('.tflite')

with open(tflite_model_file, 'wb') as f:
  f.write(tflite_model)

print('Saved pruned TFLite model to:', tflite_model_file)



INFO:tensorflow:Assets written to: C:\Users\Jchap\AppData\Local\Temp\tmp7m5b155p\assets


INFO:tensorflow:Assets written to: C:\Users\Jchap\AppData\Local\Temp\tmp7m5b155p\assets


Saved pruned TFLite model to: C:\Users\Jchap\AppData\Local\Temp\tmpsceg2xhy.tflite


In [104]:
print("Size of baseline Keras model: %.2f bytes" % (os.path.getsize(keras_file)))
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print(f"Zipping compression ratio of {(os.path.getsize(keras_file))/(get_gzipped_model_size(keras_file))}")
print()
print()
print("Size of Keras TFlite model: %.2f bytes" % (os.path.getsize(tflite_model_file)),f"({(os.path.getsize(keras_file))/(os.path.getsize(tflite_model_file))} times smaller than the baseline)")
print("Size of gzipped Keras TFlite model: %.2f bytes" % (get_gzipped_model_size(tflite_model_file)),f"({(get_gzipped_model_size(keras_file))/(get_gzipped_model_size(tflite_model_file)):.4f} times smaller than the zipped baseline)")
print(f"Zipping compression ratio of: {(os.path.getsize(tflite_model_file))/(get_gzipped_model_size(tflite_model_file))}")
print()
print()
print("Size of pruned Keras model: %.2f bytes" % (os.path.getsize(pruned_keras_file)),f"({(os.path.getsize(keras_file))/(os.path.getsize(pruned_keras_file))} times smaller than the baseline)")
print("Size of gzipped pruned Keras model: %.2f bytes" % (get_gzipped_model_size(pruned_keras_file)),f"({(get_gzipped_model_size(keras_file))/(get_gzipped_model_size(pruned_keras_file)):.4f} times smaller than the zipped baseline)")
print(f"Zipping compression ratio of: {(os.path.getsize(pruned_keras_file))/(get_gzipped_model_size(pruned_keras_file))}")
print()
print()
print("Size of pruned TFlite model: %.2f bytes" % (os.path.getsize(pruned_tflite_file)),f"({(os.path.getsize(keras_file))/(os.path.getsize(pruned_tflite_file))} times smaller than the baseline)")
print("Size of gzipped pruned TFlite model: %.2f bytes" % (get_gzipped_model_size(pruned_tflite_file)),f"({(get_gzipped_model_size(keras_file))/(get_gzipped_model_size(pruned_tflite_file)):.4f} times smaller than the zipped baseline)")
print(f"Zipping compression ratio of {(os.path.getsize(pruned_tflite_file))/(get_gzipped_model_size(pruned_tflite_file))}")
print()
print()
print(f"Final Complete Compression: {os.path.getsize(keras_file)/get_gzipped_model_size(pruned_tflite_file)} times the size of the baseline model")


Size of baseline Keras model: 98928.00 bytes
Size of gzipped baseline Keras model: 78253.00 bytes
Zipping compression ratio of 1.2642071230495955


Size of Keras TFlite model: 84896.00 bytes (1.1652845834903882 times smaller than the baseline)
Size of gzipped Keras TFlite model: 12190.00 bytes (6.4194 times smaller than the zipped baseline)
Zipping compression ratio of: 6.964397046759639


Size of pruned Keras model: 98928.00 bytes (1.0 times smaller than the baseline)
Size of gzipped pruned Keras model: 13019.00 bytes (6.0107 times smaller than the zipped baseline)
Zipping compression ratio of: 7.598740302634611


Size of pruned TFlite model: 84896.00 bytes (1.1652845834903882 times smaller than the baseline)
Size of gzipped pruned TFlite model: 12190.00 bytes (6.4194 times smaller than the zipped baseline)
Zipping compression ratio of 6.964397046759639


Final Complete Compression: 8.115504511894995 times the size of the baseline model
