In [1]:
# import everything

import pathlib

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras

In [2]:
# Fetch the MNIST train/test splits through Keras.
# Shapes noted for this dataset:
#   x_train: [60000, 28, 28]   y_train: [60000, ]
#   x_test:  [10000, 28, 28]   y_test:  [10000, ]

train_pair, test_pair = keras.datasets.mnist.load_data()
x_train, y_train = train_pair
x_test, y_test = test_pair

In [3]:
# Clean and organize the data: cast the images to float32, flatten each
# 28x28 image into a 784-element row, divide by 255, and one-hot encode the
# labels into nClass columns.
#
# NOTE: this cell overwrites x_train/x_test/y_train/y_test in place, so it is
# not idempotent -- re-run the data-loading cell before running it again.

# reshape(-1, ...) infers the number of samples instead of hard-coding
# 60000 / 10000, so the cell keeps working on a subset of the data.
x_train = x_train.astype(np.float32).reshape(-1, 28 * 28) / 255.0
x_test = x_test.astype(np.float32).reshape(-1, 28 * 28) / 255.0

nClass = 10
y_train = keras.utils.to_categorical(y_train, nClass)
y_test = keras.utils.to_categorical(y_test, nClass)

In [4]:
# Restore the Keras model saved by main.ipynb (compiled on load).

model = keras.models.load_model(
    'handwritten_model',
    compile=True,
)

# Convert to TF Lite Model

In [5]:
# model conversion code taken from Tensorflow website

def representative_data_gen():
  """Yield calibration inputs for the TFLite converter.

  Walks the first 100 rows of the global `x_train`, one row per batch, and
  yields each batch wrapped in a single-element list.
  """
  calibration_set = tf.data.Dataset.from_tensor_slices(x_train)
  for sample in calibration_set.batch(1).take(100):
    yield [sample]

converter = tf.lite.TFLiteConverter.from_keras_model(model)

# Default optimizations, calibrated with samples from representative_data_gen
converter.representative_dataset = representative_data_gen
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# Restrict the converter to int8 builtin ops so that it raises an error if
# any op cannot be quantized
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]

# Use uint8 for both the model input and output tensors (APIs added in r2.3)
converter.inference_output_type = tf.uint8
converter.inference_input_type = tf.uint8

tflite_model_quant = converter.convert()

INFO:tensorflow:Assets written to: /tmp/tmpp1mgmzve/assets




In [6]:
# Confirm that the converted model's input and output tensors are uint8.

interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)
input_type = interpreter.get_input_details()[0]['dtype']
output_type = interpreter.get_output_details()[0]['dtype']

print('input: ', input_type)
print('output: ', output_type)

input:  <class 'numpy.uint8'>
output:  <class 'numpy.uint8'>


In [7]:
# Save the converted model as mnist_model_quant.tflite in the current
# working directory.

tflite_model_quant_file = pathlib.Path.cwd()/"mnist_model_quant.tflite"
# Write the model once, to the intended path. write_bytes returns the number
# of bytes written, which the cell displays as its output.
tflite_model_quant_file.write_bytes(tflite_model_quant)

675600

In [8]:
# helper function to run inference on a TFLite model

def run_tflite_model(tflite_file, test_image_indices):
  """Run a TFLite model on selected rows of the global `x_test`.

  Args:
    tflite_file: Location of the .tflite file; converted with `str()` before
      being handed to the interpreter, so a `pathlib.Path` works too.
    test_image_indices: Sized iterable of row indices into `x_test`.

  Returns:
    1-D integer numpy array in which element i is the argmax of the model
    output for the i-th requested image.
  """
  # Initialize the interpreter
  interpreter = tf.lite.Interpreter(model_path=str(tflite_file))
  interpreter.allocate_tensors()

  input_details = interpreter.get_input_details()[0]
  output_details = interpreter.get_output_details()[0]
  input_dtype = input_details["dtype"]

  predictions = np.zeros((len(test_image_indices),), dtype=int)
  for i, test_image_index in enumerate(test_image_indices):
    test_image = x_test[test_image_index]

    # If the model takes 8-bit quantized input, map the float image onto the
    # integer grid: q = round(x / scale) + zero_point.
    if input_dtype in (np.uint8, np.int8):
      input_scale, input_zero_point = input_details["quantization"]
      # Round to nearest rather than letting astype() truncate toward zero
      # (e.g. 254.9999 must become 255, not 254).
      test_image = np.round(test_image / input_scale) + input_zero_point
      # Clip so out-of-range values saturate instead of wrapping around.
      limits = np.iinfo(input_dtype)
      test_image = np.clip(test_image, limits.min, limits.max)

    # Batch of one flattened image, kept as a numpy array of the input dtype
    # because set_tensor expects a numpy array.
    test_image = np.asarray(test_image).astype(input_dtype).reshape(1, -1)

    interpreter.set_tensor(input_details["index"], test_image)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details["index"])[0]

    predictions[i] = output.argmax()

  return predictions

In [9]:
# helper function that gets the classification

def getClass(lst):
    """Return the index of the largest entry in `lst` (first one on ties)."""
    scores = np.asarray(lst)
    return scores.argmax()

In [10]:
# helper function to evaluate a TFLite model on all images

def evaluate_model(tflite_file, model_type):
  """Print the accuracy of a TFLite model over the whole global test set.

  Runs `run_tflite_model` on every row of the global `x_test`, compares the
  predictions with the class index of each row of the global `y_test`, and
  prints the percentage of matches.

  Args:
    tflite_file: Location of the .tflite file, passed to `run_tflite_model`.
    model_type: Label inserted at the start of the printed message.
  """
  n_samples = len(x_test)
  predictions = run_tflite_model(tflite_file, range(x_test.shape[0]))

  # Class index for every label row, in the same order as the predictions.
  true_classes = np.array([getClass(label_row) for label_row in y_test])
  n_correct = np.sum(true_classes == predictions)
  accuracy = (n_correct * 100) / n_samples

  print('%s model accuracy is %.4f%% (Number of test samples=%d)' % (
      model_type, accuracy, n_samples))

In [11]:
# Evaluate the quantized model through the Path object it was written to.

evaluate_model(tflite_model_quant_file, "Quantized")

Quantized model accuracy is 97.8900% (Number of test samples=10000)


In [12]:
# Evaluate again, this time locating the saved file by its relative file name.

evaluate_model(
    tflite_file="mnist_model_quant.tflite",
    model_type="Quantized",
)

Quantized model accuracy is 97.8900% (Number of test samples=10000)
