In [1]:
!pip install tensorflow_model_optimization

Collecting tensorflow_model_optimization
[?25l  Downloading https://files.pythonhosted.org/packages/78/8f/f6969dc64709c5c5e22cfd7057a83adbc927e6855a431b234168222cbf03/tensorflow_model_optimization-0.6.0-py2.py3-none-any.whl (211kB)
[K     |█▌                              | 10kB 12.7MB/s eta 0:00:01[K     |███                             | 20kB 8.1MB/s eta 0:00:01[K     |████▋                           | 30kB 5.7MB/s eta 0:00:01[K     |██████▏                         | 40kB 5.0MB/s eta 0:00:01[K     |███████▊                        | 51kB 2.7MB/s eta 0:00:01[K     |█████████▎                      | 61kB 3.0MB/s eta 0:00:01[K     |██████████▉                     | 71kB 3.3MB/s eta 0:00:01[K     |████████████▍                   | 81kB 3.5MB/s eta 0:00:01[K     |██████████████                  | 92kB 3.7MB/s eta 0:00:01[K     |███████████████▌                | 102kB 4.0MB/s eta 0:00:01[K     |█████████████████               | 112kB 4.0MB/s eta 0:00:01[K     |██████

In [2]:
import tempfile
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
import tensorflow_model_optimization as tfmot
import zipfile

In [3]:
# Enumerate the GPUs visible to TensorFlow through the stable (non-experimental)
# device-listing API; the result is an empty list on CPU-only machines.
gpus = tf.config.list_physical_devices('GPU')

In [4]:
# Ask TensorFlow to grow GPU memory on demand for every detected device.
# Looping over an empty `gpus` list is a no-op, so no separate guard is needed.
try:
    for device in gpus:
        tf.config.experimental.set_memory_growth(device, True)
except RuntimeError as err:
    # Report the failure instead of aborting the notebook.
    print(err)

In [5]:
# Fetch the MNIST digits and unpack the train / test splits.
mnist = keras.datasets.mnist
train_split, test_split = mnist.load_data()
train_images, train_labels = train_split
test_images, test_labels = test_split

# Rescale pixel intensities by 1/255 before feeding them to the network.
train_images = train_images / 255.0
test_images = test_images / 255.0

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [6]:
# Baseline classifier: reshape to a single channel, one 3x3 convolution with
# ReLU, 2x2 max-pooling, then a dense layer emitting 10 raw scores (logits).
baseline_layers = [
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu),
    keras.layers.MaxPool2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10),
]
model = keras.Sequential(baseline_layers)

In [7]:
# The final Dense layer has no activation, so the loss must treat its output
# as logits.
baseline_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=baseline_loss, metrics=['accuracy'])

In [8]:
# Train the baseline for 5 epochs, holding out 10% of the training data for
# validation.
model.fit(train_images, train_labels, epochs=5, validation_split=0.1)

# evaluate() returns the loss followed by the compiled metric (accuracy).
baseline_loss_value, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)
print('Baseline test accuracy: ', baseline_model_accuracy)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Baseline test accuracy:  0.982200026512146


In [9]:
# Persist the trained baseline (weights and architecture only, no optimizer
# state) so its size can be compared with the compressed variants later.
keras_file = './baseline.h5'
model.save(keras_file, include_optimizer=False)
print('Save baseline model to: ', keras_file)

Save baseline model to:  ./baseline.h5


In [10]:
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Fine-tuning hyper-parameters used while pruning.
batch_size = 128
epochs = 10
validation_split = 0.1

# The pruning schedule ends on the last optimizer step, so count the steps:
# batches per epoch (on the non-validation part of the data) times epochs.
num_images = train_images.shape[0] * (1 - validation_split)
steps_per_epoch = np.ceil(num_images / batch_size).astype(np.int32)
end_step = steps_per_epoch * epochs

In [11]:
# Ramp sparsity from 50% to 80% between step 0 and `end_step`.
sparsity_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.5,
    final_sparsity=0.8,
    begin_step=0,
    end_step=end_step,
)
pruning_params = {'pruning_schedule': sparsity_schedule}

# Wrap the trained baseline with pruning wrappers, then compile the wrapped
# model with the same optimizer, loss and metric as the baseline.
model_for_pruning = prune_low_magnitude(model, **pruning_params)
pruning_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model_for_pruning.compile(optimizer='adam', loss=pruning_loss, metrics=['accuracy'])
model_for_pruning.summary()



Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_reshape  (None, 28, 28, 1)         1         
_________________________________________________________________
prune_low_magnitude_conv2d ( (None, 26, 26, 12)        230       
_________________________________________________________________
prune_low_magnitude_max_pool (None, 13, 13, 12)        1         
_________________________________________________________________
prune_low_magnitude_flatten  (None, 2028)              1         
_________________________________________________________________
prune_low_magnitude_dense (P (None, 10)                40572     
Total params: 40,805
Trainable params: 20,410
Non-trainable params: 20,395
_________________________________________________________________


In [12]:
logdir = './logs'

# UpdatePruningStep advances the pruning schedule during training;
# PruningSummaries writes pruning summaries to `logdir` once per epoch.
step_updater = tfmot.sparsity.keras.UpdatePruningStep()
summary_writer = tfmot.sparsity.keras.PruningSummaries(log_dir=logdir, update_freq='epoch')
callbacks = [step_updater, summary_writer]

model_for_pruning.fit(
    train_images,
    train_labels,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=validation_split,
    callbacks=callbacks,
)


Epoch 1/10
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f192d617210>

In [13]:
# Score the pruned model on the held-out test set and show it next to the
# baseline accuracy measured earlier.
pruned_loss_value, model_for_pruning_accuracy = model_for_pruning.evaluate(
    test_images, test_labels, verbose=0)
print('Baseline test accuracy: ', baseline_model_accuracy)
print('Pruned test accuracy: ', model_for_pruning_accuracy)

Baseline test accuracy:  0.982200026512146
Pruned test accuracy:  0.9765999913215637


In [14]:
# Remove the pruning wrappers before export, then save the stripped model
# without optimizer state.
pruned_keras_file = './baseline_pruned.h5'
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
model_for_export.save(pruned_keras_file, include_optimizer=False)
print('Save pruned keras model to: ', pruned_keras_file)

Save pruned keras model to:  ./baseline_pruned.h5


In [15]:
# Convert the stripped, pruned Keras model to a TFLite flatbuffer and write it
# to disk.
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
pruned_tflite_model = converter.convert()
pruned_tflite_file = './baseline_pruned.tflite'
with open(pruned_tflite_file, 'wb') as f:
    f.write(pruned_tflite_model)

# Report the TFLite path that was just written (previously this printed the
# Keras .h5 path by mistake).
print('save pruned TFLite model to: ', pruned_tflite_file)

INFO:tensorflow:Assets written to: /tmp/tmp2g3ii3oz/assets
save pruned TFLite model to:  ./baseline_pruned.h5


In [16]:
def get_gzipped_model_size(file):
    """Return the size in bytes of `file` after DEFLATE compression in a zip archive.

    The file is written into a temporary `.zip` archive, the archive's size is
    measured, and the temporary archive is removed before returning.

    Args:
        file: Path of the file to compress.

    Returns:
        Size of the resulting zip archive in bytes (int).
    """
    fd, zipped_file = tempfile.mkstemp('.zip')
    # mkstemp hands back an already-open descriptor; close it right away so it
    # is not leaked (ZipFile reopens the path itself).
    os.close(fd)
    try:
        with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
            f.write(file)
        return os.path.getsize(zipped_file)
    finally:
        # Do not leave one stray archive in the temp directory per call.
        os.remove(zipped_file)

In [17]:
# Print the compressed size of each saved artifact, one line per model.
pruned_size_reports = (
    ("Size of gzipped baseline Keras model: %.2f bytes", keras_file),
    ("Size of gzipped pruned Keras model: %.2f bytes", pruned_keras_file),
    ("Size of gzipped pruned TFlite model: %.2f bytes", pruned_tflite_file),
)
for report_template, model_path in pruned_size_reports:
    print(report_template % get_gzipped_model_size(model_path))

Size of gzipped baseline Keras model: 78195.00 bytes
Size of gzipped pruned Keras model: 25938.00 bytes
Size of gzipped pruned TFlite model: 25437.00 bytes


In [18]:
quantized_and_pruned_tflite_file = './baseline_pruned_quantized.tflite'

# Convert the pruned model again, this time with the converter's default
# optimizations enabled.
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_and_pruned_tflite_model = converter.convert()





INFO:tensorflow:Assets written to: /tmp/tmpnqbn9zfd/assets


INFO:tensorflow:Assets written to: /tmp/tmpnqbn9zfd/assets


In [19]:
# Write the quantized + pruned flatbuffer bytes to disk.
with open(quantized_and_pruned_tflite_file, 'wb') as tflite_out:
    tflite_out.write(quantized_and_pruned_tflite_model)

In [20]:
print('Save quantized and pruned TFLite model to: ', quantized_and_pruned_tflite_file)

# Compare the compressed baseline against the compressed pruned + quantized model.
quantized_size_reports = (
    ("Size of gzipped baseline Keras model: %.2f bytes", keras_file),
    ("Size of gzipped pruned and quantized TFlite model: %.2f bytes", quantized_and_pruned_tflite_file),
)
for report_template, model_path in quantized_size_reports:
    print(report_template % get_gzipped_model_size(model_path))

Save quantized and pruned TFLite model to:  ./baseline_pruned_quantized.tflite
Size of gzipped baseline Keras model: 78195.00 bytes
Size of gzipped pruned and quantized TFlite model: 8325.00 bytes


In [21]:
def evaluate_model(interpreter, images=None, labels=None):
    """Measure the classification accuracy of a TFLite interpreter, one image at a time.

    Args:
        interpreter: A TFLite interpreter (call `allocate_tensors()` on it
            first). Each image is fed to its first input as a float32 batch of
            one, and its first output is read as per-class scores.
        images: Optional array of images to classify. Defaults to the
            notebook-level `test_images`.
        labels: Optional array of ground-truth class indices aligned with
            `images`. Defaults to the notebook-level `test_labels`.

    Returns:
        Fraction of images whose arg-max prediction equals its label.

    Raises:
        ValueError: If `images` and `labels` have different lengths.
    """
    # Fall back to the notebook's test split so existing one-argument calls
    # keep working.
    if images is None:
        images = test_images
    if labels is None:
        labels = test_labels
    if len(images) != len(labels):
        raise ValueError(
            'images and labels must have the same length, got %d and %d'
            % (len(images), len(labels)))

    input_index = interpreter.get_input_details()[0]['index']
    output_index = interpreter.get_output_details()[0]['index']
    # tensor() returns a callable that yields the current output buffer, so it
    # only needs to be looked up once.
    read_output = interpreter.tensor(output_index)

    # Run predictions on every image in the dataset
    prediction_digits = []
    for image in images:
        # Pre-processing: add batch dimension and convert to float32 to match
        # the model's input data format
        batch = np.expand_dims(image, axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, batch)

        # Run inference
        interpreter.invoke()

        # Post-processing: drop the batch dimension and pick the class with
        # the highest score (the model emits logits, not probabilities)
        prediction_digits.append(np.argmax(read_output()[0]))

    # Compare predictions with ground-truth labels to calculate accuracy
    prediction_digits = np.array(prediction_digits)
    return (prediction_digits == np.asarray(labels)).mean()

In [22]:
# Load the in-memory quantized + pruned flatbuffer into a TFLite interpreter
# and allocate its tensors so it is ready for inference.
interpreter = tf.lite.Interpreter(model_content=quantized_and_pruned_tflite_model)
interpreter.allocate_tensors()

In [24]:
# Score the TFLite model on the test set; evaluate_model invokes the
# interpreter once per image.
test_accuracy = evaluate_model(interpreter)

In [25]:
# Compare the TFLite interpreter's accuracy with the accuracy Keras reported
# for the pruned model before conversion.
print('Pruned and quantized TFLite test_accuracy:', test_accuracy)
print('Pruned TF test accuracy:', model_for_pruning_accuracy)

Pruned and quantized TFLite test_accuracy: 0.9766
Pruned TF test accuracy: 0.9765999913215637
