Model Quantization Tutorial

TensorFlow Lite is a set of tools developed by Google that enables faster inference and smaller models for edge applications.

In many use cases inference time is a critical requirement. In edge applications such as mobile and IoT, memory and battery constraints also play a big role.

There are different types of techniques that enable this; the main two are:

  • Post-Training Quantization
  • During-Training Quantization (quantization-aware training)

The difference between the two is that the first reduces the size of the weights from float32 to smaller types (down to int8) after training, while the latter does this during training.

The main advantage of the latter technique is that, in some circumstances, it allows the model to adjust to the reduced weight precision, which results in a smaller deterioration of model accuracy.

Post Training Example


!pip install tensorflow
# Package installation

import logging

import tensorflow as tf
from tensorflow import keras
import numpy as np
import pathlib
Train tensorflow model

# Fetch the MNIST digits through the Keras dataset helper and unpack the
# train/test splits.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Rescale both image sets by 255 so every pixel value lies between 0 and 1.
train_images, test_images = train_images / 255.0, test_images / 255.0

# Define the model architecture
# NOTE(review): the layers after MaxPooling2D and the compile/fit calls were
# missing from this export; they are reconstructed here from the standard
# MNIST example and the training log shown below (one epoch, accuracy metric,
# validation on the test split) -- confirm against the original notebook.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  # Flatten the feature maps and emit one logit per digit class (0-9).
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
# The labels are integer class ids, so the sparse categorical cross-entropy
# loss is used; per the Keras docs it applies when there are two or more label
# classes. `from_logits=True` because the last layer has no softmax.
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_data=(test_images, test_labels)
)
1875/1875 [==============================] - 12s 6ms/step - loss: 0.2881 - accuracy: 0.9185 - val_loss: 0.1389 - val_accuracy: 0.9601

Convert to tensorflow lite

# The model is converted to the dedicated TensorFlow Lite model format.
# Just by converting to tf.lite (before any quantization) we already gain
# in model size.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
Check size

# Directory that holds the converted models; created if it does not exist yet.
tflite_models_dir = pathlib.Path("/tmp/mnist_tflite_models/")
tflite_models_dir.mkdir(exist_ok=True, parents=True)

# Write the converted (non-quantized) model to disk so that its size can be
# inspected and the interpreter can later load it by path.
tflite_model_file = tflite_models_dir/"mnist_model.tflite"
tflite_model_file.write_bytes(tflite_model)

Perform Quantization

# Re-run the conversion with the default optimization enabled, which quantizes
# the model weights.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()

# Write the quantized model next to the float one so both sizes can be
# compared and the interpreter can later load it by path.
tflite_model_quant_file = tflite_models_dir/"mnist_model_quant.tflite"
tflite_model_quant_file.write_bytes(tflite_quant_model)
ls -lh {tflite_models_dir}
Inference time

# Load both models into TFLite interpreters. Tensors must be allocated before
# an input can be set or inference can run.
interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))
interpreter.allocate_tensors()
interpreter_quant = tf.lite.Interpreter(model_path=str(tflite_model_quant_file))
interpreter_quant.allocate_tensors()

# Test using one image: add the batch dimension and cast to float32.
test_image = np.expand_dims(test_images[0], axis=0).astype(np.float32)

input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]

interpreter.set_tensor(input_index, test_image)
# Run inference; without this call the output tensor is never computed.
interpreter.invoke()
predictions = interpreter.get_tensor(output_index)
import matplotlib.pylab as plt

# Title the figure with the ground-truth label and the class that received the
# highest score from the interpreter.
template = "True:{true}, predicted:{predict}"
_ = plt.title(template.format(true=str(test_labels[0]),
                              predict=str(np.argmax(predictions[0]))))


# Compare both models 
# A helper function to evaluate the TF Lite model using "test" dataset.
def evaluate_model(interpreter, images=None, labels=None):
  """Return the classification accuracy of a TF Lite interpreter.

  Every image is fed through the interpreter one at a time and the index of
  the highest output score is taken as the predicted digit.

  Args:
    interpreter: Object exposing the `tf.lite.Interpreter` methods used here
      (`get_input_details`, `get_output_details`, `set_tensor`, `invoke`,
      `tensor`). Its tensors are expected to be allocated already.
    images: Iterable of images to classify. Defaults to the module-level
      `test_images`.
    labels: Ground-truth labels, indexable in the same order as `images`.
      Defaults to the module-level `test_labels`.

  Returns:
    Fraction of images whose predicted digit equals the label, or 0.0 when
    there are no images.
  """
  if images is None:
    images = test_images
  if labels is None:
    labels = test_labels

  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the dataset.
  prediction_digits = []
  for image in images:
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    batch = np.expand_dims(image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, batch)

    # Run inference. Without this call the output tensor is never updated.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    prediction_digits.append(np.argmax(output()[0]))

  # Avoid dividing by zero when there is nothing to evaluate.
  if not prediction_digits:
    return 0.0

  # Compare prediction results with ground truth labels to calculate accuracy.
  accurate_count = sum(
      1 for index, digit in enumerate(prediction_digits)
      if digit == labels[index])
  return accurate_count * 1.0 / len(prediction_digits)
# Repeat the evaluation on the dynamic range quantized model to obtain:

Inference time

results = {}

# Standard model
import time
model.predict(test_images[0].reshape(1,28,28)) # First predict is slower
# Time a single-image call through the batch-oriented `model.predict` API.
a = time.time()
model.predict(test_images[0].reshape(1,28,28))
b = time.time()
print(f"Standard batch predict took {(b-a)*1000} ms")

# Faster predict
# Calling the model object directly skips the `predict` batching machinery.
a = time.time()
model(test_images[0].reshape(1,28,28))
b = time.time()
print(f"Faster standard predict took {(b-a)*1000} ms")

# TFLite model WITHOUT Quantization
interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))
# Tensors must be allocated before set_tensor/invoke; done outside the timed
# region so only the prediction itself is measured.
interpreter.allocate_tensors()
a = time.time()
test_image = np.expand_dims(test_images[0], axis=0).astype(np.float32)

input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]

interpreter.set_tensor(input_index, test_image)
# Run inference; without this call nothing is computed and the timing is void.
interpreter.invoke()
predictions = interpreter.get_tensor(output_index)
b = time.time()
print(f"TFLite without Quantization predict took {(b-a)*1000} ms")

# With Quantization
interpreter = tf.lite.Interpreter(model_path=str(tflite_model_quant_file))
# Tensors must be allocated before set_tensor/invoke; done outside the timed
# region so only the prediction itself is measured.
interpreter.allocate_tensors()
a = time.time()
test_image = np.expand_dims(test_images[0], axis=0).astype(np.float32)

input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]

interpreter.set_tensor(input_index, test_image)
# Run inference; without this call nothing is computed and the timing is void.
interpreter.invoke()
predictions = interpreter.get_tensor(output_index)
b = time.time()
print(f"TFLite with Quantization predict took {(b-a)*1000} ms")
# Inference time in ms per prediction method (hard-coded values; presumably
# recorded from a run of the timing cells -- verify before reuse).
results = {"model.predict()": 56, "model()":3.37,"TfLite":0.91, "TfLiteQuantization":0.69}
import numpy as np
import matplotlib.pyplot as plt
predicts = list(results.keys())
values = list(results.values())
fig = plt.figure(figsize=(10,8))
# creating the bar plot
plt.bar(predicts, values, color ='maroon',
        width = 0.4)
plt.xlabel("Types of predict")
plt.ylabel("Inference time in ms")
plt.title("MNIST inference time for 96% accuracy model")



Comparing the values above we see that the accuracy of the model was not affected significantly.

The main issue that I have found so far is the case where the model has layers with dynamic inputs. In that case we need to manually set the schema of the interpreter layers.