<a href="https://colab.research.google.com/github/z4ziad/TensorFlow/blob/main/My_QAT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Quantization Aware Training Notebook Example
This notebook shows the effect, if any, of quantization-aware training (QAT) on a simple ConvNet model trained on the MNIST dataset.

Install the TensorFlow Model Optimization toolkit

In [None]:
# Install the TensorFlow Model Optimization toolkit, which provides the
# `tfmot.quantization.keras` APIs used in the cells below.
! pip install -q tensorflow-model-optimization

In [None]:
import tempfile
import os

import tensorflow as tf
from tensorflow import keras
import tensorflow_model_optimization as tfmot

# Print the library versions so results can be tied to the exact releases used.
print("TF version: ", tf.__version__)
print("TensorFlow Model Optimization version:", tfmot.__version__)

Load the MNIST dataset and normalize it

In [None]:
# Load the MNIST train/test splits bundled with Keras.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Scale the raw pixel values (0-255) so that each pixel is between 0 and 1.
# Cast to float32 first: dividing the integer arrays by 255.0 directly would
# produce float64 arrays twice the size, while the model and evaluate_model's
# TFLite input use float32 anyway.
train_images = train_images.astype("float32") / 255.0
test_images = test_images.astype("float32") / 255.0

Build a model with quantization annotations for use in quantization-aware training, and train it as the baseline

In [None]:
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer

# Define the model architecture. Each layer that should become
# quantization-aware is wrapped with quantize_annotate_layer. The annotations
# only mark layers; they take effect when quantize_apply is called later.
# The InputLayer is left unwrapped. It has no weights or activations to
# quantize, and annotations belong on regular layers. This also matches the
# plain InputLayer used by sanity_model below.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  quantize_annotate_layer(keras.layers.Reshape(target_shape=(28, 28, 1))),
  quantize_annotate_layer(
      keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu')),
  quantize_annotate_layer(keras.layers.MaxPooling2D(pool_size=(2, 2))),
  quantize_annotate_layer(keras.layers.Flatten()),
  quantize_annotate_layer(keras.layers.Dense(10)),
])

# Dense(10) has no activation, so the model emits raw logits; hence
# from_logits=True in the loss.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.summary()

# Train the baseline digit classifier for one epoch, holding out 10% of the
# training data for validation. Quantization awareness is not applied yet.
print("training baseline model...")
model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_split=0.1,
)

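Quantization awareness has not been applied yet: the annotations only mark which layers should be quantized. To make the model quantization-aware, we create a new quantization-aware model by calling `quantize_apply(model)`, which must then be compiled again before training.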

In [None]:
# quantize_apply builds a new quantization-aware model from the annotated
# `model`. The result is a separate Keras model, so it has to be compiled
# before it can be trained.
quantized_aware_model = tfmot.quantization.keras.quantize_apply(model)
quantized_aware_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

quantized_aware_model.summary()

# Train the quantization-aware model for one epoch, again holding out 10% of
# the training data for validation.
print("training quantized model...")
quantized_aware_model.fit(
    train_images, train_labels,
    epochs=1,
    validation_split=0.1,
)

Let's get the accuracy of both models on `test_images`:

In [None]:
# Score both Keras models on the test set. evaluate() yields
# [loss, accuracy] for the compiled metrics; only the accuracy is kept.
baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)[1]
q_aware_model_accuracy = quantized_aware_model.evaluate(
    test_images, test_labels, verbose=0)[1]

print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)

Now, let's convert the quantization-aware model to a TFLite model

In [None]:
# Convert the quantization-aware Keras model into a TFLite model, enabling
# the converter's default optimizations.
print("converting quantized model to tflite...")
converter = tf.lite.TFLiteConverter.from_keras_model(
    quantized_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_tflite_model = converter.convert()

import numpy as np

def evaluate_model(interpreter, images=None, labels=None):
  """Return the top-1 accuracy of a TFLite interpreter on a labelled image set.

  Each image is fed to the interpreter one at a time, and the predicted digit
  is the argmax of the first output tensor.

  Args:
    interpreter: a `tf.lite.Interpreter` on which `allocate_tensors()` has
      already been called. Its first input is fed one float32 image with a
      leading batch dimension of 1; its first output is read as per-class
      scores.
    images: sequence of images to classify. Defaults to the notebook-level
      `test_images`, so existing `evaluate_model(interpreter)` calls still work.
    labels: ground-truth digit labels aligned with `images`. Defaults to the
      notebook-level `test_labels`.

  Returns:
    The fraction of images whose predicted digit equals its label.

  Raises:
    ValueError: if `images` and `labels` have different lengths.
  """
  if images is None:
    images = test_images
  if labels is None:
    labels = test_labels
  labels = np.asarray(labels)
  # Fail loudly instead of letting the final element-wise comparison
  # broadcast or error on mismatched shapes.
  if len(images) != len(labels):
    raise ValueError(
        'images and labels must have the same length, got {} and {}'.format(
            len(images), len(labels)))

  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image, one at a time.
  prediction_digits = []
  for i, image in enumerate(images):
    if i % 1000 == 0:
      print('Evaluated on {n} results so far.'.format(n=i))
    # Pre-processing: add a batch dimension and convert to float32 to match
    # the model's input data format.
    batch = np.expand_dims(image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, batch)

    # Run inference.
    interpreter.invoke()

    # Post-processing: get_tensor returns a copy of the output, so no view
    # into the interpreter's internal buffers is held. Drop the batch
    # dimension and take the highest-scoring digit.
    scores = interpreter.get_tensor(output_index)
    prediction_digits.append(np.argmax(scores[0]))

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  return (np.array(prediction_digits) == labels).mean()

# Load the converted TFLite model into an interpreter. Tensors must be
# allocated before evaluate_model can set inputs and invoke it.
interpreter = tf.lite.Interpreter(
    model_content=quantized_tflite_model,
)
interpreter.allocate_tensors()

Finally, let's evaluate the accuracy of the TFLite model on the `test_images` dataset:

In [None]:
# Run the TFLite interpreter over the test set and report its accuracy.
test_tflite_accuracy = evaluate_model(interpreter=interpreter)
print('Quant TFLite test_accuracy:', test_tflite_accuracy)

Now compare the TFLite test accuracy with those of the baseline model and the quantization-aware model:

In [None]:
# Side-by-side summary of the three test accuracies measured above.
for caption, accuracy in (
    ('Baseline test accuracy:', baseline_model_accuracy),
    ('Quant test accuracy:', q_aware_model_accuracy),
    ('Quant TFLite test_accuracy:', test_tflite_accuracy),
):
  print(caption, accuracy)

As a sanity check, let's build the same model without quantization annotations and check its accuracy:

In [None]:
# Define the same architecture as the baseline, using plain Keras layers
# (no quantization annotations).
sanity_model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Compile and train like the baseline: same optimizer, loss and metrics, and
# one epoch.
sanity_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Use the same 10% validation split as the baseline and quantization-aware
# models. With validation_data=(test_images, test_labels), this model would
# train on all of the training images and validate on the test set, so its
# accuracy would not be comparable to the other two.
sanity_model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_split=0.1,
)

# Report accuracy on the test set.
_, sanity_model_accuracy = sanity_model.evaluate(
    test_images, test_labels, verbose=0)

print('Sanity test accuracy:', sanity_model_accuracy)