## Overview

Welcome to an end-to-end MNIST example for *quantization-aware training*.

### Other Pages
For an introduction to what quantization-aware training is and to determine if you want to use it, see the [overview page](https://www.tensorflow.org/model_optimization/guide/quantization/training.md).

To quickly find the APIs you need for your use case (beyond the single path in this example), see the
[comprehensive guide](https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide.md).

### Contents

In this tutorial, we will:

1.   Train a `tf.keras` model for MNIST from scratch.

2.   Apply the quantization-aware training API to the MNIST model, check its accuracy, and
     export a quantization-aware model.

3.   Use the model to create an actually quantized model for the TFLite
     backend. 

4.   See the 4x model size reduction and the persistence of accuracy in 
     TFLite. To see the latency benefits on mobile, try out the TFLite examples [in the TFLite app repository](https://www.tensorflow.org/lite/models).

```python
! pip uninstall -y tensorflow
! pip install -q tf-nightly==2.2.0.dev20200305
! pip install -q --extra-index-url=https://test.pypi.org/simple/ tensorflow-model-optimization==0.3.0.dev3

import tempfile

import tensorflow as tf

from tensorflow import keras
```


# Train an MNIST model without quantization-aware training

```python
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
inp = keras.layers.Input(shape=(28, 28))
x = keras.layers.Reshape(target_shape=(28, 28, 1))(inp)
x = keras.layers.Conv2D(32, 5, padding='same', activation='relu')(x)
x = keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
x = keras.layers.Conv2D(64, 5, padding='same', activation='relu')(x)
x = keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(1024, activation='relu')(x)
x = keras.layers.Dropout(0.4)(x)
out = keras.layers.Dense(10, activation='softmax')(x)

model = keras.models.Model([inp], [out])

# Train the digit classification model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_data=(test_images, test_labels)
)
```


# Apply quantization-aware training to the pre-trained MNIST model


## Define the model

We apply quantization-aware training to the whole model, as you can see in the model summary: all layers are now prefixed by "quant". The resulting model is quantization aware but not yet quantized; for example, its weights are still float32. In the [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide.md), you can see how to quantize only some layers to improve model accuracy.

```python
import tensorflow_model_optimization as tfmot

quantize_model = tfmot.quantization.keras.quantize_model

# q_aware stands for quantization aware.
q_aware_model = quantize_model(model)

# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])

q_aware_model.summary()
```

```
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28)]          0         
_________________________________________________________________
quantize_layer (QuantizeLaye (None, 28, 28)            3         
_________________________________________________________________
quant_reshape (QuantizeWrapp (None, 28, 28, 1)         1         
_________________________________________________________________
quant_conv2d (QuantizeWrappe (None, 28, 28, 32)        899       
_________________________________________________________________
quant_max_pooling2d (Quantiz (None, 14, 14, 32)        1         
_________________________________________________________________
quant_conv2d_1 (QuantizeWrap (None, 14, 14, 64)        51395     
_________________________________________________________________
quant_max_pooling2d_1 (Quant (None, 7, 7, 64)          1     
```
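Conceptually, quantization-aware training simulates 8-bit quantization in the forward pass: values are snapped to a small grid of representable levels and then mapped back to floats ("fake quantization"), so the model learns weights that tolerate the rounding error. The NumPy sketch below illustrates only that quantize-then-dequantize step, assuming a simplified per-tensor, asymmetric 8-bit range taken from the tensor's min and max. `fake_quantize` is a hypothetical helper, not a tfmot API: the library's default scheme differs in details such as per-channel weight ranges and tracking activation ranges during training, and the gradient handling (typically a straight-through estimator that treats rounding as identity) is not shown.

```python
import numpy as np


def fake_quantize(x, num_bits=8):
    """Quantize-then-dequantize `x` with a per-tensor asymmetric range.

    Hypothetical helper for illustration only; not tfmot's implementation.
    """
    x = np.asarray(x, dtype=np.float64)
    qmin, qmax = 0, 2**num_bits - 1
    # Widen the range to include 0 so that 0.0 is exactly representable.
    x_min = min(float(x.min()), 0.0)
    x_max = max(float(x.max()), 0.0)
    if x_max == x_min:  # all-zero tensor: nothing to quantize
        return x.copy()
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = int(round(qmin - x_min / scale))
    # Snap each value to the nearest of 2**num_bits integer levels...
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
    # ...then map it back to float, as happens in the QAT forward pass.
    return (q - zero_point) * scale


rng = np.random.default_rng(0)
weights = rng.normal(size=1000)
fq_weights = fake_quantize(weights)
print("distinct values before:", np.unique(weights).size)
print("distinct values after: ", np.unique(fq_weights).size)
print("max abs error:", np.abs(fq_weights - weights).max())
```

After fake quantization at most 256 distinct values remain, and each value moves by at most half a quantization step.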

## Train and evaluate the model against baseline

```python
q_aware_model.fit(train_images, train_labels, batch_size=500, epochs=1)
```




For this example, the test accuracy before and after quantization-aware training is similar.

```python
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

_, q_aware_model_accuracy = q_aware_model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)
```

```
Baseline test accuracy: 0.9879000186920166
Quant test accuracy: 0.9915000200271606
```


# Create quantized model for TFLite backend

After converting the quantization-aware model with `TFLiteConverter`, we have an actually quantized model.

```python
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model = converter.convert()
```
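In the converted model, quantized tensors are stored as 8-bit integers rather than simulated in float, and each integer is interpreted through the affine mapping `real_value = scale * (int_value - zero_point)`. The NumPy sketch below encodes a float32 tensor as int8 under that mapping with a single per-tensor scale, using the hypothetical helpers `quantize_int8` and `dequantize`. It is a simplification, not what `TFLiteConverter` does internally (which may, for example, use per-channel scales for weights), but it shows where the 4x storage saving comes from.

```python
import numpy as np


def quantize_int8(x):
    """Encode a float tensor as int8 using real = scale * (q - zero_point).

    Hypothetical helper for illustration; uses one scale for the whole tensor.
    """
    x = np.asarray(x, dtype=np.float32)
    qmin, qmax = -128, 127
    # Include 0 in the range so that 0.0 maps exactly to zero_point.
    x_min = min(float(x.min()), 0.0)
    x_max = max(float(x.max()), 0.0)
    scale = (x_max - x_min) / (qmax - qmin) or 1.0  # avoid dividing by 0 for all-zero input
    zero_point = int(round(qmin - x_min / scale))
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax).astype(np.int8)
    return q, scale, zero_point


def dequantize(q, scale, zero_point):
    """Recover approximate float32 values from the int8 encoding."""
    return (scale * (q.astype(np.float32) - zero_point)).astype(np.float32)


# A random tensor with the shape of the first Conv2D kernel (5x5x1x32).
rng = np.random.default_rng(0)
kernel = rng.normal(size=(5, 5, 1, 32)).astype(np.float32)
q_kernel, scale, zero_point = quantize_int8(kernel)
print("float32 bytes:", kernel.nbytes, "int8 bytes:", q_kernel.nbytes)
# float32 bytes: 3200 int8 bytes: 800
print("max abs error:", np.abs(dequantize(q_kernel, scale, zero_point) - kernel).max())
```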

# See 4x model size reduction and persistence of accuracy in TFLite

We define a helper function to evaluate the TFLite model on the test dataset.

```python
import numpy as np

def evaluate_model(interpreter):
  """Returns the accuracy of a TFLite interpreter on the MNIST test set."""
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for test_image in test_images:
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.get_tensor(output_index)
    digit = np.argmax(output[0])
    prediction_digits.append(digit)

  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()

  return accuracy
```

We evaluate the quantized model and see that the accuracy from TensorFlow persists to the TFLite backend.

```python
interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()

test_accuracy = evaluate_model(interpreter)
print('Quant TFLite test_accuracy:', test_accuracy)
print('Quant TF test accuracy:', q_aware_model_accuracy)
```

```
Quant TFLite test_accuracy: 0.9912
Quant TF test accuracy: 0.9915000200271606
```


We create a float TFLite model and then see that the quantized TFLite model
is 4x smaller.

```python
import os

# Create float TFLite model.
float_converter = tf.lite.TFLiteConverter.from_keras_model(model)
float_tflite_model = float_converter.convert()

_, float_file = tempfile.mkstemp('.tflite')
_, quant_file = tempfile.mkstemp('.tflite')

with open(quant_file, 'wb') as f:
  f.write(quantized_tflite_model)

with open(float_file, 'wb') as f:
  f.write(float_tflite_model)

print("Float model in Mb:", os.path.getsize(float_file) / float(2**20))
print("Quantized model in Mb:", os.path.getsize(quant_file) / float(2**20))
```

```
Float model in Mb: 12.494964599609375
Quantized model in Mb: 3.1320648193359375
```
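The roughly 4x figure follows from storing each parameter in 1 byte (int8) instead of 4 bytes (float32). As a back-of-the-envelope check, the snippet below counts the parameters of the baseline model defined above from its layer shapes and converts them to units of 2**20 bytes, the same divisor the printout uses. It assumes every parameter takes exactly 1 byte after quantization and ignores the `.tflite` file's own overhead; in TFLite's 8-bit scheme biases are actually stored as 32-bit integers, so the real file is slightly larger. The helpers `conv2d_params` and `dense_params` are hypothetical names used only here.

```python
def conv2d_params(kernel_size, in_channels, filters):
    """Weights plus one bias per filter for a square Conv2D kernel."""
    return kernel_size * kernel_size * in_channels * filters + filters


def dense_params(in_units, units):
    """Weights plus one bias per output unit for a Dense layer."""
    return in_units * units + units


# Layer shapes from the baseline model. Two 2x2 max-pools shrink the
# 28x28 input to 7x7, so Flatten yields 7 * 7 * 64 features.
num_params = (
    conv2d_params(5, 1, 32)
    + conv2d_params(5, 32, 64)
    + dense_params(7 * 7 * 64, 1024)
    + dense_params(1024, 10)
)

MIB = 2**20
float_mib = num_params * 4 / MIB  # float32: 4 bytes per parameter
int8_mib = num_params * 1 / MIB   # int8: 1 byte per parameter

print("Parameters:", num_params)                        # Parameters: 3274634
print("Float32 estimate in MiB:", round(float_mib, 2))  # 12.49
print("Int8 estimate in MiB:", round(int8_mib, 2))      # 3.12
print("Ratio:", float_mib / int8_mib)                   # 4.0
```

Both estimates land close to the measured 12.49 and 3.13 above; the small remainder comes from file-format overhead and the tensors kept at higher precision.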


# Conclusion

In this tutorial, we showed you how to create *quantization-aware models* with the TensorFlow Model Optimization Toolkit API and then *quantized models* in the TFLite backend. 

We saw a 4x model size compression benefit for MNIST, with minimal accuracy
difference. To see the latency benefits on mobile, try out the TFLite examples [in the TFLite app repository](https://www.tensorflow.org/lite/models).

We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments.

