In [13]:
import tempfile
import os

import tensorflow as tf
import numpy as np
import keras

%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [46]:
from keras.datasets import mnist

# Load the MNIST digit dataset (images and integer labels).
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Scale pixel values into [0, 1] (this rescales values; the image shape is unchanged).
# Casting to float32 before dividing keeps the arrays in the dtype the model consumes,
# instead of the float64 that `uint8 / 255.0` would produce (twice the memory).
train_images = train_images.astype(np.float32) / 255.0
test_images = test_images.astype(np.float32) / 255.0
train_labels

array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

In [15]:
from keras.layers import InputLayer, Reshape, Conv2D, MaxPooling2D, Flatten, Dense

# Define model architecture: one conv + max-pool stage, then a dense layer that emits
# 10 unnormalized scores (logits), one per digit class.
baseline_layers = [
    InputLayer(input_shape=(28, 28)),
    Reshape(target_shape=(28, 28, 1)),  # add a single channel axis for Conv2D
    Conv2D(filters=12, kernel_size=(3, 3), activation="relu"),
    MaxPooling2D(pool_size=(2, 2)),
    Flatten(),
    Dense(units=10),
]
model = keras.Sequential(baseline_layers)

# Train model. The last Dense layer has no activation, so the loss is told it receives logits.
baseline_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer="adam", loss=baseline_loss, metrics=["accuracy"])

model.fit(train_images, train_labels, epochs=4, validation_split=0.1)


Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


<keras.callbacks.History at 0x184955525c0>

In [16]:
# Test-set accuracy of the unpruned baseline; the loss value is discarded.
baseline_results = model.evaluate(test_images, test_labels, verbose=0)
_, model_accuracy = baseline_results
print("Baseline model accuracy: ", model_accuracy)

Baseline model accuracy:  0.9775999784469604


In [17]:
# Save the trained baseline without optimizer state; its file size is compared later.
baseline_keras_file = "mnist_model.h5"
keras.models.save_model(model, baseline_keras_file, include_optimizer=False)

# Fine-tune pre-trained model with pruning

In [18]:
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Fine-tuning hyperparameters; the pruning schedule is sized so it ends after `epochs` epochs.
batch_size = 128
epochs = 2
validation_split = 0.1

train_fraction = 1 - validation_split
num_images = train_images.shape[0] * train_fraction
steps_per_epoch = np.ceil(num_images / batch_size).astype(np.int32)
end_step = steps_per_epoch * epochs

# Sparsity ramps from 50% to 80% between step 0 and end_step.
polynomial_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.50,
    final_sparsity=0.80,
    begin_step=0,
    end_step=end_step,
)
pruning_params = {"pruning_schedule": polynomial_schedule}

# Wrap the already-trained baseline model for pruning-aware fine-tuning.
model_for_pruning = prune_low_magnitude(model, **pruning_params)

# The wrapped model returned by prune_low_magnitude must be compiled again before training.
model_for_pruning.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model_for_pruning.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_reshape  (None, 28, 28, 1)        1         
 _1 (PruneLowMagnitude)                                          
                                                                 
 prune_low_magnitude_conv2d_  (None, 26, 26, 12)       230       
 1 (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_max_poo  (None, 13, 13, 12)       1         
 ling2d_1 (PruneLowMagnitude                                     
 )                                                               
                                                                 
 prune_low_magnitude_flatten  (None, 2028)             1         
 _1 (PruneLowMagnitude)                                          
                                                      

# Fine-tune the pruned model and evaluate it against the baseline model

In [19]:
# Callbacks for fine-tuning the pruning-wrapped model; summaries go to `pruning_log_dir`.
pruning_log_dir = "log_pruning_summaries"
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(pruning_log_dir),
]

model_for_pruning.fit(
    train_images,
    train_labels,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=validation_split,
    callbacks=callbacks,
)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x184964737c0>

In [20]:
# Test-set accuracy of the pruned model, reported next to the baseline measured earlier.
pruned_results = model_for_pruning.evaluate(test_images, test_labels, verbose=0)
_, model_for_pruning_accuracy = pruned_results

for label, accuracy in (
    ("Baseline test accuracy: ", model_accuracy),
    ("Pruned test accuracy: ", model_for_pruning_accuracy),
):
    print(label, accuracy)

Baseline test accuracy:  0.9775999784469604
Pruned test accuracy:  0.9692999720573425


In [21]:
# Uncomment to browse the pruning summaries in TensorBoard; the directory is the one passed to PruningSummaries.
# %tensorboard --logdir log_pruning_summaries

In [22]:
# Remove the PruneLowMagnitude wrappers and their training-only tf.Variables, which would
# otherwise add to the exported model's size at inference time.
pruned_keras_file = "pruned_mnist_model.h5"
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
keras.models.save_model(
    model_for_export,
    pruned_keras_file,
    include_optimizer=False,
)




In [23]:
# Convert the stripped (pruned, not quantized) Keras model to TFLite and write the bytes to disk.
pruned_tflite_file = "pruned_tflite_mnist_model.tflite"
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
pruned_tflite_model = converter.convert()

with open(pruned_tflite_file, "wb") as tflite_out:
    tflite_out.write(pruned_tflite_model)
    



INFO:tensorflow:Assets written to: C:\Users\sandr\AppData\Local\Temp\tmpx492_md5\assets


INFO:tensorflow:Assets written to: C:\Users\sandr\AppData\Local\Temp\tmpx492_md5\assets


In [26]:
import gzip
import os 
import zipfile
def get_gzipped_model_size(file, file_name=None):
    """Return the size in bytes of ``file`` after DEFLATE compression.

    The file is stored in a ZIP archive with ZIP_DEFLATED (the same DEFLATE
    algorithm gzip uses), so the result measures how compressible the model is.

    Args:
        file: path of the model file to measure.
        file_name: base name for the archive. When given, ``file_name + ".zip"``
            is written (relative to the current directory) and kept on disk.
            When None, a temporary archive is used and deleted afterwards.

    Returns:
        Size of the compressed archive, in bytes.
    """
    import os
    import tempfile
    import zipfile

    if file_name is None:
        fd, zipped_file = tempfile.mkstemp(suffix=".zip")
        os.close(fd)  # close our handle so ZipFile can reopen it (required on Windows)
        keep_archive = False
    else:
        zipped_file = file_name + ".zip"
        keep_archive = True

    try:
        with zipfile.ZipFile(zipped_file, "w", compression=zipfile.ZIP_DEFLATED) as archive:
            # Store only the base name so the measured size does not depend on the directory path.
            archive.write(file, arcname=os.path.basename(file))
        return os.path.getsize(zipped_file)
    finally:
        if not keep_archive:
            os.remove(zipped_file)


In [29]:
# Convert the same stripped model again, this time with the default TFLite optimizations
# enabled, and save the quantized + pruned result.
quantized_and_pruned_tflite_file = "quantized_pruned_mnist_model.tflite"
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_and_pruned_tflite_model = converter.convert()

with open(quantized_and_pruned_tflite_file, "wb") as tflite_out:
    tflite_out.write(quantized_and_pruned_tflite_model)



INFO:tensorflow:Assets written to: C:\Users\sandr\AppData\Local\Temp\tmpib2abzfi\assets


INFO:tensorflow:Assets written to: C:\Users\sandr\AppData\Local\Temp\tmpib2abzfi\assets


In [30]:
# Compare compressed on-disk sizes of every saved model. Paths reuse the variables the
# saving cells defined, so a renamed file cannot silently diverge from what is measured.
# Byte counts are integers, so they are printed with %d rather than %.2f.
model_files = [
    ("baseline Keras model", "mnist_model.h5", "zip_mnist_model"),
    ("pruned Keras model", pruned_keras_file, "zip_pruned_mnist_model"),
    ("pruned TFLite model", pruned_tflite_file, "zip_pruned_tflite_mnist_model"),
    ("quantized pruned TFLite model", quantized_and_pruned_tflite_file, "zip_quantized_pruned_mnist_model"),
]
for description, model_file, zip_name in model_files:
    print("Size of gzipped %s: %d bytes" % (description, get_gzipped_model_size(model_file, zip_name)))

Size of gzipped baseline Keras model: 78198.00 bytes
Size of gzipped pruned Keras model: 25858.00 bytes
Size of gzipped pruned TFLite model: 25172.00 bytes
Size of gzipped quantized pruned TFLite model: 8353.00 bytes


# Check that accuracy persists after converting from TF to TFLite


In [43]:
import numpy as np

def evaluate_model(interpreter, images=None, labels=None):
    """Compute top-1 accuracy of a TFLite interpreter on a labelled image set.

    Each image is fed on its own (batch dimension of 1, cast to float32). The
    predicted class is the argmax of the first output tensor.

    Args:
        interpreter: a ``tf.lite.Interpreter`` on which ``allocate_tensors()``
            has already been called.
        images: images to classify; defaults to the notebook-level ``test_images``.
        labels: ground-truth labels aligned with ``images``; defaults to the
            notebook-level ``test_labels``.

    Returns:
        Fraction of images whose predicted class equals its label.

    Raises:
        ValueError: if ``images`` and ``labels`` have different lengths.
    """
    if images is None:
        images = test_images
    if labels is None:
        labels = test_labels
    labels = np.asarray(labels)
    if len(images) != len(labels):
        raise ValueError(
            "images and labels differ in length: %d vs %d" % (len(images), len(labels))
        )

    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    # Run predictions on every image in the dataset
    prediction_digits = []
    for i, image in enumerate(images):
        if i % 1000 == 0:
            print("Evaluated on {n} results so far.".format(n=i))

        # Pre-processing: add batch dimension and convert to float32 to match the model's input format
        batch = np.expand_dims(image, axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, batch)

        # Run inference
        interpreter.invoke()

        # Post-processing: get_tensor returns a copy, so no view into the interpreter's
        # internal buffers is held; drop the batch dimension and take the top class.
        output = interpreter.get_tensor(output_index)
        prediction_digits.append(np.argmax(output[0]))

    print("\n")

    # Compare prediction result with ground truth labels to calculate accuracy
    prediction_digits = np.array(prediction_digits)
    accuracy = (prediction_digits == labels).mean()
    return accuracy
        

In [44]:
# Prints the same progress messages evaluate_model emits (one per 1000 test images);
# no inference happens here. NOTE(review): looks like a leftover debugging cell — consider removing.
for n_seen in range(0, len(test_images), 1000):
    print("Evaluated on {n} results so far.".format(n=n_seen))

Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


In [45]:
# Score the in-memory quantized + pruned TFLite model on the test set.
interpreter = tf.lite.Interpreter(model_content=quantized_and_pruned_tflite_model)
interpreter.allocate_tensors()
test_accuracy = evaluate_model(interpreter)

# Show both accuracies side by side to check that accuracy persists from TF to TFLite.
for label, value in (
    ("Pruned and quantized TFLite test_accuracy: ", test_accuracy),
    ("Pruned TF test accuracy: ", model_for_pruning_accuracy),
):
    print(label, value)

Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Pruned and quantized TFLite test_accuracy:  0.9693
Pruned TF test accuracy:  0.9692999720573425
