# Quantization and Pruning

In this notebook, I'll provide code for mobile optimization techniques. These reduce model size and latency, which makes them well suited to edge and IoT devices. I will start by training a Keras model, then compare its size and accuracy after applying each of these techniques:

* post-training quantization
* quantization aware training
* weight pruning

Let's begin!

## Imports

In [1]:
import tensorflow as tf
import numpy as np
import os
import tempfile
import zipfile

<a name='utilities'></a>

## Utilities and constants

Let's first define a few string constants and utility functions to make our code easier to maintain.

In [3]:
# GLOBAL VARIABLES

# String constants for model filenames
FILE_WEIGHTS = 'baseline_weights.h5'
FILE_NON_QUANTIZED_H5 = 'non_quantized.h5'
FILE_NON_QUANTIZED_TFLITE = 'non_quantized.tflite'
FILE_PT_QUANTIZED = 'post_training_quantized.tflite'
FILE_QAT_QUANTIZED = 'quant_aware_quantized.tflite'
FILE_PRUNED_MODEL_H5 = 'pruned_model.h5'
FILE_PRUNED_QUANTIZED_TFLITE = 'pruned_quantized.tflite'
FILE_PRUNED_NON_QUANTIZED_TFLITE = 'pruned_non_quantized.tflite'

# Dictionaries to hold measurements
MODEL_SIZE = {}
ACCURACY = {}

In [4]:
# UTILITY FUNCTIONS

def print_metric(metric_dict, metric_name):
  '''Prints key and values stored in a dictionary'''
  for metric, value in metric_dict.items():
    print(f'{metric_name} for {metric}: {value}')


def model_builder():
  '''Returns a shallow CNN for training on the MNIST dataset'''

  keras = tf.keras

  # Define the model architecture.
  model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation='softmax')
  ])

  return model


def evaluate_tflite_model(filename, x_test, y_test):
  '''
  Measures the accuracy of a given TF Lite model and test set

  Args:
    filename (string) - filename of the model to load
    x_test (numpy array) - test images
    y_test (numpy array) - test labels

  Returns:
    float showing the accuracy against the test set
  '''

  # Initialize the TF Lite Interpreter and allocate tensors
  interpreter = tf.lite.Interpreter(model_path=filename)
  interpreter.allocate_tensors()

  # Get input and output index
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Initialize empty predictions list
  prediction_digits = []

  # Run predictions on every image in the "test" dataset.
  for i, test_image in enumerate(x_test):
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == y_test).mean()

  return accuracy


def get_gzipped_model_size(file):
  '''Returns size of the zip-compressed (DEFLATE) model, in bytes.'''
  # mkstemp returns an open file descriptor; close it so it does not leak
  fd, zipped_file = tempfile.mkstemp('.zip')
  os.close(fd)
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)

## Download and Prepare the Dataset

I will be using the [MNIST](https://keras.io/api/datasets/mnist/) dataset, which is hosted in [Keras Datasets](https://keras.io/api/datasets/). Some of the helper functions in this notebook are written for this dataset, so if you decide to switch to a different one, check whether they need to be modified (e.g. the input shape in `model_builder()`).

In [5]:
# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


## Baseline Model

I will first build and train a Keras model. This will be the baseline against which I compare the mobile-optimized versions later on. It is a shallow CNN with a softmax output that classifies MNIST digits. You can review the `model_builder()` function in the utilities at the top of this notebook, but I also print the model summary below to show the architecture.

I will also save the initial weights so I can initialize the other models from the same starting point later. This is not needed in real projects, but in this notebook it lets me compare the effects of the optimizations from an identical initial state.

In [6]:
# Create the baseline model
baseline_model = model_builder()

# Save the initial weights for use later
baseline_model.save_weights(FILE_WEIGHTS)

# Print the model summary
baseline_model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 reshape (Reshape)           (None, 28, 28, 1)         0         
                                                                 
 conv2d (Conv2D)             (None, 26, 26, 12)        120       
                                                                 
 max_pooling2d (MaxPooling2  (None, 13, 13, 12)        0         
 D)                                                              
                                                                 
 flatten (Flatten)           (None, 2028)              0         
                                                                 
 dense (Dense)               (None, 10)                20290     
                                                                 
Total params: 20410 (79.73 KB)
Trainable params: 20410 (79.73 KB)
Non-trainable params: 0 (0.00 Byte)
____________________
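
The parameter counts in the summary above can be reproduced by hand. The sketch below is only an illustrative calculation and not part of the training pipeline; the helper names `conv2d_params` and `dense_params` are hypothetical, and it assumes each parameter is stored as a 4-byte float32.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    """Kernel weights plus one bias per filter."""
    return kernel_h * kernel_w * in_channels * filters + filters


def dense_params(in_features, units):
    """Weight matrix plus one bias per unit."""
    return in_features * units + units


conv = conv2d_params(3, 3, 1, 12)   # 3x3 kernel, 1 input channel, 12 filters
flat = 13 * 13 * 12                 # Flatten output after 2x2 max pooling
dense = dense_params(flat, 10)      # 10 output classes
total = conv + dense
size_kb = total * 4 / 1024          # assuming 4 bytes per float32 parameter

print(conv, flat, dense, total, round(size_kb, 2))
# 120 2028 20290 20410 79.73
```

Almost all of the parameters sit in the Dense layer, so that is where quantization and pruning save the most space.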

In [7]:
# Setup the model for training
baseline_model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Train the model
baseline_model.fit(train_images, train_labels, epochs=1, shuffle=False)



<keras.src.callbacks.History at 0x79b7154112a0>

In [8]:
# Get the baseline accuracy
_, ACCURACY['baseline Keras model'] = baseline_model.evaluate(test_images, test_labels)



Next, I will save the Keras model as a file and record its size as well.

In [9]:
# Save the Keras model
baseline_model.save(FILE_NON_QUANTIZED_H5, include_optimizer=False)

# Save and get the model size
MODEL_SIZE['baseline h5'] = os.path.getsize(FILE_NON_QUANTIZED_H5)

# Print records so far
print_metric(ACCURACY, "test accuracy")
print_metric(MODEL_SIZE, "model size in bytes")

test accuracy for baseline Keras model: 0.9634000062942505
model size in bytes for baseline h5: 98968




### Convert the model to TF Lite format

In [10]:
def convert_tflite(model, filename, quantize=False):
  '''
  Converts the model to TF Lite format and writes to a file

  Args:
    model (Keras model) - model to convert to TF Lite
    filename (string) - string to use when saving the file
    quantize (bool) - flag to indicate quantization

  Returns:
    None
  '''

  # Initialize the converter
  converter = tf.lite.TFLiteConverter.from_keras_model(model)

  # Set for quantization if flag is set to True
  if quantize:
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

  # Convert the model
  tflite_model = converter.convert()

  # Save the model.
  with open(filename, 'wb') as f:
    f.write(tflite_model)

In [11]:
# Convert baseline model
convert_tflite(baseline_model, FILE_NON_QUANTIZED_TFLITE)

In [12]:
MODEL_SIZE['non quantized tflite'] = os.path.getsize(FILE_NON_QUANTIZED_TFLITE)

print_metric(MODEL_SIZE, 'model size in bytes')

model size in bytes for baseline h5: 98968
model size in bytes for non quantized tflite: 85012


In [13]:
ACCURACY['non quantized tflite'] = evaluate_tflite_model(FILE_NON_QUANTIZED_TFLITE, test_images, test_labels)

In [None]:
print_metric(ACCURACY, 'test accuracy')

## Post-Training Quantization

Now that I have the baseline metrics, I can observe the effects of quantization. This process converts floating-point representations (here, 32-bit weights) into lower-precision integers to reduce model size and speed up computation.

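As shown in the `convert_tflite()` helper function earlier, I can easily do [post-training quantization](https://www.tensorflow.org/lite/performance/post_training_quantization) with the TF Lite API. I just need to set the converter's `optimizations` attribute to a list containing an [Optimize](https://www.tensorflow.org/api_docs/python/tf/lite/Optimize) enum value. With `tf.lite.Optimize.DEFAULT` and no representative dataset, the converter applies dynamic range quantization, which stores the weights as 8-bit integers.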

I will set the `quantize` flag to do that and get the metrics again.
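
To make the idea concrete, here is a minimal NumPy sketch of mapping float32 weights to int8 with a single scale factor. It is not TF Lite's actual implementation; the function names and the symmetric per-tensor scheme are assumptions chosen for illustration, and the random weights merely stand in for a Dense kernel of the same shape as the baseline model's.

```python
import numpy as np


def quantize_symmetric_int8(weights):
    """Map float32 weights to int8 using one per-tensor scale (illustrative)."""
    max_abs = float(np.max(np.abs(weights)))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q, scale):
    """Recover approximate float32 values from the int8 representation."""
    return q.astype(np.float32) * scale


# Random stand-in for a Dense kernel of shape (2028, 10)
rng = np.random.default_rng(0)
weights = rng.normal(0.0, 0.3, size=(2028, 10)).astype(np.float32)

q_weights, scale = quantize_symmetric_int8(weights)
restored = dequantize(q_weights, scale)
max_error = float(np.max(np.abs(weights - restored)))

print('float32 bytes:', weights.nbytes)
print('int8 bytes:   ', q_weights.nbytes)
print('max abs error:', max_error, '(about scale / 2 =', scale / 2, ')')
```

The int8 copy takes a quarter of the storage, and the rounding error per weight is bounded by about half the scale. That is the trade-off to look for in the size and accuracy numbers that follow.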

In [14]:
# Convert and quantize the baseline model
convert_tflite(baseline_model, FILE_PT_QUANTIZED, quantize=True)

In [15]:
# Get the model size
MODEL_SIZE['post training quantized tflite'] = os.path.getsize(FILE_PT_QUANTIZED)

print_metric(MODEL_SIZE, 'model size')

model size for baseline h5: 98968
model size for non quantized tflite: 85012
model size for post training quantized tflite: 24256


In [16]:
ACCURACY['post training quantized tflite'] = evaluate_tflite_model(FILE_PT_QUANTIZED, test_images, test_labels)

In [17]:
print_metric(ACCURACY, 'test accuracy')

test accuracy for baseline Keras model: 0.9634000062942505
test accuracy for non quantized tflite: 0.9634
test accuracy for post training quantized tflite: 0.9639


## Quantization Aware Training

When post-training quantization results in a loss of accuracy that is unacceptable for your application, you can consider doing [quantization aware training](https://www.tensorflow.org/model_optimization/guide/quantization/training) before quantizing the model. This simulates the loss of precision by inserting fake-quantization nodes into the model during training. That way, the model learns to adapt to the reduced precision and makes more accurate predictions once it is quantized.
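
A fake-quantization node can be pictured as "quantize, then immediately dequantize": the tensor stays in floating point but only takes values that an 8-bit integer could represent. The sketch below is a simplified NumPy illustration of that forward pass, assuming a fixed `[x_min, x_max]` range; the function name `fake_quantize` is hypothetical, and the toolkit's real nodes also track ranges during training and handle gradients, which is omitted here.

```python
import numpy as np


def fake_quantize(x, x_min, x_max, num_bits=8):
    """Snap x onto a grid of 2**num_bits levels in [x_min, x_max], in float."""
    levels = 2 ** num_bits - 1
    scale = (x_max - x_min) / levels
    clipped = np.clip(x, x_min, x_max)
    return np.round((clipped - x_min) / scale) * scale + x_min


activations = np.array([-1.5, -0.3, 0.0, 0.123456, 0.9, 2.0], dtype=np.float32)
simulated = fake_quantize(activations, x_min=-1.0, x_max=1.0)

print(simulated.dtype)   # still a float tensor
print(simulated)         # values clipped to [-1, 1] and snapped to the grid
```

Because the network sees these rounded values during training, the weights it learns already account for the precision that will be lost at conversion time.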

The [TensorFlow Model Optimization Toolkit](https://www.tensorflow.org/model_optimization) provides a [quantize_model()](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/quantization/keras/quantize_model) function to do this quickly, and you will see that below. But first, let's install the toolkit into the notebook environment.

In [18]:
# Install the toolkit
!pip install tensorflow_model_optimization

Collecting tensorflow_model_optimization
  Downloading tensorflow_model_optimization-0.7.5-py2.py3-none-any.whl (241 kB)
     241.2/241.2 kB 3.3 MB/s eta 0:00:00
Installing collected packages: tensorflow_model_optimization
Successfully installed tensorflow_model_optimization-0.7.5


In [19]:
import tensorflow_model_optimization as tfmot

# method to quantize a Keras model
quantize_model = tfmot.quantization.keras.quantize_model

# Define the model architecture.
model_to_quantize = model_builder()

# Reinitialize weights with saved file
model_to_quantize.load_weights(FILE_WEIGHTS)

# Quantize the model
q_aware_model = quantize_model(model_to_quantize)

# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

q_aware_model.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 quantize_layer (QuantizeLa  (None, 28, 28)            3         
 yer)                                                            
                                                                 
 quant_reshape_1 (QuantizeW  (None, 28, 28, 1)         1         
 rapperV2)                                                       
                                                                 
 quant_conv2d_1 (QuantizeWr  (None, 26, 26, 12)        147       
 apperV2)                                                        
                                                                 
 quant_max_pooling2d_1 (Qua  (None, 13, 13, 12)        1         
 ntizeWrapperV2)                                                 
                                                                 
 quant_flatten_1 (QuantizeW  (None, 2028)             

In [20]:
# Train the model
q_aware_model.fit(train_images, train_labels, epochs=1, shuffle=False)



<keras.src.callbacks.History at 0x79b71599d600>

In [21]:
# Reinitialize the dictionary
ACCURACY = {}

# Get the accuracy of the quantization aware trained model (not yet quantized)
_, ACCURACY['quantization aware non-quantized'] = q_aware_model.evaluate(test_images, test_labels, verbose=0)
print_metric(ACCURACY, 'test accuracy')

test accuracy for quantization aware non-quantized: 0.9635999798774719


In [24]:
# Convert and quantize the model.
convert_tflite(q_aware_model, FILE_QAT_QUANTIZED, quantize=True)

# Get the accuracy of the quantized model
ACCURACY['quantization aware quantized'] = evaluate_tflite_model(FILE_QAT_QUANTIZED, test_images, test_labels)
print_metric(ACCURACY, 'test accuracy')



test accuracy for quantization aware non-quantized: 0.9635999798774719
test accuracy for quantization aware quantized: 0.9636


## Pruning

Let's now move on to another technique for reducing model size: [Pruning](https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras). This process zeroes out insignificant (i.e. low-magnitude) weights. The intuition is that these weights contribute little to the predictions, so we can remove them and get nearly the same result. Sparse weights also compress much more efficiently, as you will see in this section.
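
As a rough illustration of magnitude-based pruning, the NumPy sketch below zeroes out a chosen fraction of the smallest-magnitude entries in a tensor. It is a one-shot simplification under my own assumptions (the helper `prune_by_magnitude` is hypothetical, and the random tensor only mimics the shape of the Conv2D kernel); the toolkit instead prunes gradually while the model keeps training.

```python
import numpy as np


def prune_by_magnitude(weights, sparsity):
    """Zero out the `sparsity` fraction of weights with the smallest magnitude."""
    flat = np.abs(weights).ravel()
    k = int(round(sparsity * flat.size))   # number of weights to remove
    if k == 0:
        return weights.copy()
    threshold = np.sort(flat)[k - 1]       # largest magnitude that gets pruned
    mask = np.abs(weights) > threshold     # True for weights that survive
    return weights * mask


# Random stand-in with the same shape as the Conv2D kernel: (3, 3, 1, 12)
rng = np.random.default_rng(42)
kernel = rng.normal(size=(3, 3, 1, 12)).astype(np.float32)

pruned = prune_by_magnitude(kernel, sparsity=0.8)
zero_fraction = float(np.mean(pruned == 0))
print('fraction of zeros:', zero_fraction)
```

Multiplying a negative weight by a zero mask yields `-0.`, which is why negative zeros can show up when printing pruned weights.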

The TensorFlow Model Optimization Toolkit again has a convenience function for this. The [prune_low_magnitude()](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/prune_low_magnitude) function wraps the layers of a Keras model so it can be pruned during training. I will pass in the baseline model that I trained earlier. You will notice that the model summary shows more parameters because of the wrapper layers added by the pruning function.

We can set how the pruning is done during training. Below, I will use [PolynomialDecay](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/PolynomialDecay) to specify how the sparsity ramps up with each step, here from 50% to 80%. Another option available in the library is [ConstantSparsity](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/ConstantSparsity).
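
To get a feel for how such a schedule ramps up, here is a plain-Python sketch of a polynomial decay from the initial to the final sparsity. The helper `polynomial_decay_sparsity` is hypothetical, it assumes a cubic exponent (`power=3`, which I believe matches the library default), and it ignores how often the library actually updates the pruning masks. The step budget uses the 60,000 MNIST training images with a 10% validation split, a batch size of 128, and 2 epochs.

```python
import math


def polynomial_decay_sparsity(step, initial_sparsity, final_sparsity,
                              begin_step, end_step, power=3):
    """Target sparsity at `step`, ramping from initial to final sparsity."""
    step = min(max(step, begin_step), end_step)   # clamp to the schedule window
    progress = (step - begin_step) / (end_step - begin_step)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1 - progress) ** power


num_images = 60000 * (1 - 0.1)                 # training images after the split
end_step = math.ceil(num_images / 128) * 2     # steps per epoch * epochs

for step in (0, end_step // 2, end_step):
    sparsity = polynomial_decay_sparsity(step, 0.50, 0.80, 0, end_step)
    print(step, round(sparsity, 4))
# 0 0.5
# 422 0.7625
# 844 0.8
```

Most of the increase happens early: halfway through the schedule the target sparsity is already above 76%, which leaves the remaining steps for the model to recover accuracy.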

In [25]:
# Get the pruning method
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set.

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define pruning schedule.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                               final_sparsity=0.80,
                                                               begin_step=0,
                                                               end_step=end_step)
}

# Pass in the trained baseline model
model_for_pruning = prune_low_magnitude(baseline_model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model_for_pruning.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_reshap  (None, 28, 28, 1)         1         
 e (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_conv2d  (None, 26, 26, 12)        230       
  (PruneLowMagnitude)                                            
                                                                 
 prune_low_magnitude_max_po  (None, 13, 13, 12)        1         
 oling2d (PruneLowMagnitude                                      
 )                                                               
                                                                 
 prune_low_magnitude_flatte  (None, 2028)              1         
 n (PruneLowMagnitude)                                           
                                                        

In [26]:
# Preview model weights
model_for_pruning.weights[1]

<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 12) dtype=float32, numpy=
array([[[[ 0.32880077, -0.00370286, -0.43661624, -0.7047039 ,
           0.25139543, -0.03380103,  0.12268101, -0.47647282,
           0.06151805,  0.19884391, -0.09929218,  0.64335364]],

        [[ 0.44360656,  0.1415229 ,  0.05076738, -0.22037843,
           0.33123186,  0.28065774,  0.16067812, -0.4581575 ,
           0.10179768,  0.41382894,  0.06424277, -0.54322696]],

        [[ 0.2417517 , -0.15722294,  0.38414168,  0.39335033,
          -0.06849516,  0.4636034 , -0.13752379, -0.4403433 ,
           0.22671011,  0.38459358,  0.04796667, -0.33809617]]],


       [[[-0.1376045 ,  0.07523245, -0.3559213 , -0.6807172 ,
           0.09241283, -0.15107277, -0.13837588, -0.01652741,
           0.2560323 ,  0.35705173,  0.28778088, -0.59544927]],

        [[ 0.3113292 ,  0.28261712,  0.29757544,  0.2872469 ,
           0.31608188,  0.17711316,  0.10828544,  0.19653368,
          -0.02341902,  0.2382372 ,  0.112398

In [27]:
# Callback to update pruning wrappers at each step
callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
]

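# Train and prune the model. Pass the same batch_size used to compute end_step
# so that the pruning schedule finishes after `epochs` epochs.
model_for_pruning.fit(train_images, train_labels,
                  batch_size=batch_size, epochs=epochs,
                  validation_split=validation_split,
                  callbacks=callbacks)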

Epoch 1/2
Epoch 2/2


<keras.src.callbacks.History at 0x79b6f80cbd00>

In [28]:
# Preview model weights
model_for_pruning.weights[1]

<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 12) dtype=float32, numpy=
array([[[[-0.        , -0.        , -0.        , -0.7656396 ,
           0.        , -0.        , -0.        , -1.0429403 ,
          -0.        , -0.        , -0.        ,  0.8965841 ]],

        [[ 0.88872886, -0.        ,  0.        ,  0.        ,
           0.        , -0.        , -0.        , -0.        ,
           0.        ,  1.1694382 , -0.        , -0.8108421 ]],

        [[-0.        , -0.        ,  0.85616434,  0.        ,
           0.        ,  0.8353219 ,  0.        , -0.        ,
           0.        , -0.        , -0.        , -0.        ]]],


       [[[-0.        , -0.        ,  0.        , -0.68185574,
           0.        , -0.        , -0.        , -0.        ,
          -0.        , -0.        , -0.        , -1.3538901 ]],

        [[-0.        , -0.        ,  0.        ,  0.        ,
           0.        , -0.        , -0.        , -0.        ,
          -0.        , -0.        , -0.      

After pruning, we can remove the wrapper layers to have the same layers and params as the baseline model.

In [29]:
# Remove pruning wrappers
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
model_for_export.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 reshape (Reshape)           (None, 28, 28, 1)         0         
                                                                 
 conv2d (Conv2D)             (None, 26, 26, 12)        120       
                                                                 
 max_pooling2d (MaxPooling2  (None, 13, 13, 12)        0         
 D)                                                              
                                                                 
 flatten (Flatten)           (None, 2028)              0         
                                                                 
 dense (Dense)               (None, 10)                20290     
                                                                 
Total params: 20410 (79.73 KB)
Trainable params: 20410 (79.73 KB)
Non-trainable params: 0 (0.00 Byte)
____________________

In [30]:
# Preview model weights (index 1 earlier is now 0 because pruning wrappers were removed)
model_for_export.weights[0]

<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 12) dtype=float32, numpy=
array([[[[-0.        , -0.        , -0.        , -0.7656396 ,
           0.        , -0.        , -0.        , -1.0429403 ,
          -0.        , -0.        , -0.        ,  0.8965841 ]],

        [[ 0.88872886, -0.        ,  0.        ,  0.        ,
           0.        , -0.        , -0.        , -0.        ,
           0.        ,  1.1694382 , -0.        , -0.8108421 ]],

        [[-0.        , -0.        ,  0.85616434,  0.        ,
           0.        ,  0.8353219 ,  0.        , -0.        ,
           0.        , -0.        , -0.        , -0.        ]]],


       [[[-0.        , -0.        ,  0.        , -0.68185574,
           0.        , -0.        , -0.        , -0.        ,
          -0.        , -0.        , -0.        , -1.3538901 ]],

        [[-0.        , -0.        ,  0.        ,  0.        ,
           0.        , -0.        , -0.        , -0.        ,
          -0.        , -0.        , -0.      

You will notice below that the pruned model has the same file size as `baseline_model` when saved as H5. This is expected: the zeroed weights are still stored as 32-bit floats. The improvement becomes noticeable once you compress the model, as shown in the cell after this.
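
The effect is easy to reproduce outside of Keras. The sketch below uses random data rather than the real model files (an assumption made purely for illustration) and compares the raw and DEFLATE-compressed sizes of a dense float32 array against a copy with roughly 80% of its entries set to zero.

```python
import zlib

import numpy as np

rng = np.random.default_rng(0)
dense = rng.normal(size=20000).astype(np.float32)

# Zero out roughly the 80% of values with the smallest magnitude
threshold = np.quantile(np.abs(dense), 0.8)
sparse = np.where(np.abs(dense) >= threshold, dense, 0).astype(np.float32)

raw_dense, raw_sparse = dense.tobytes(), sparse.tobytes()
zip_dense, zip_sparse = zlib.compress(raw_dense), zlib.compress(raw_sparse)

print('raw bytes:       ', len(raw_dense), len(raw_sparse))
print('compressed bytes:', len(zip_dense), len(zip_sparse))
```

The raw sizes are identical because every zero still occupies four bytes, while the long runs of zero bytes let the compressor shrink the sparse array far more than the dense one.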

In [31]:
# Save Keras model
model_for_export.save(FILE_PRUNED_MODEL_H5, include_optimizer=False)

# Get uncompressed model size of baseline and pruned models
MODEL_SIZE = {}
MODEL_SIZE['baseline h5'] = os.path.getsize(FILE_NON_QUANTIZED_H5)
MODEL_SIZE['pruned non quantized h5'] = os.path.getsize(FILE_PRUNED_MODEL_H5)

print_metric(MODEL_SIZE, 'model_size in bytes')



model_size in bytes for baseline h5: 98968
model_size in bytes for pruned non quantized h5: 98968


In [32]:
# Get compressed size of baseline and pruned models
MODEL_SIZE = {}
MODEL_SIZE['baseline h5'] = get_gzipped_model_size(FILE_NON_QUANTIZED_H5)
MODEL_SIZE['pruned non quantized h5'] = get_gzipped_model_size(FILE_PRUNED_MODEL_H5)

print_metric(MODEL_SIZE, "gzipped model size in bytes")

gzipped model size in bytes for baseline h5: 78077
gzipped model size in bytes for pruned non quantized h5: 25831


In [33]:
# Convert and quantize the pruned model (convert_tflite writes the file and returns None).
convert_tflite(model_for_export, FILE_PRUNED_QUANTIZED_TFLITE, quantize=True)

# Compress and get the model size
MODEL_SIZE['pruned quantized tflite'] = get_gzipped_model_size(FILE_PRUNED_QUANTIZED_TFLITE)
print_metric(MODEL_SIZE, "gzipped model size in bytes")

gzipped model size in bytes for baseline h5: 78077
gzipped model size in bytes for pruned non quantized h5: 25831
gzipped model size in bytes for pruned quantized tflite: 8593


In [34]:
# Get accuracy of pruned Keras and TF Lite models
ACCURACY = {}

_, ACCURACY['pruned model h5'] = model_for_pruning.evaluate(test_images, test_labels)
ACCURACY['pruned and quantized tflite'] = evaluate_tflite_model(FILE_PRUNED_QUANTIZED_TFLITE, test_images, test_labels)

print_metric(ACCURACY, 'accuracy')

accuracy for pruned model h5: 0.9710999727249146
accuracy for pruned and quantized tflite: 0.9713
