In [None]:
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

In [None]:
# Load the MNIST digits as a (train, test) pair of (images, labels) tuples.
(X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()

In [None]:
# Number of training images (length of the first axis).
X_train.shape[0]

In [None]:
# Number of test images (length of the first axis).
X_test.shape[0]

In [None]:
# Shape of a single image: every axis after the leading sample axis.
X_train.shape[1:]

In [None]:
# Render the first training image as a matrix plot.
plt.matshow(X_train[0])

In [None]:
# Label that belongs to the first training image shown above.
y_train[0]

In [None]:
# Scale pixel intensities from [0, 255] down to [0, 1].
# keras.datasets.mnist.load_data() returns uint8 images, while the division
# yields floats, so the dtype check makes this cell safe to re-run: a second
# execution would otherwise divide the already-scaled data by 255 again.
if X_train.dtype == np.uint8:
    X_train = X_train / 255
if X_test.dtype == np.uint8:
    X_test = X_test / 255

In [None]:
# Flatten every 28x28 image into one 784-element row (one row per sample).
# The model cell below uses a Flatten layer and is fed X_train directly, so
# these flattened copies are only inspected, not trained on.
X_train_flattened = X_train.reshape(X_train.shape[0], 28 * 28)
X_test_flattened = X_test.reshape(X_test.shape[0], 28 * 28)

In [None]:
# Confirm the flattened training array is (samples, 28*28).
np.shape(X_train_flattened)

In [None]:
# Build a small fully connected digit classifier.
# The Flatten layer turns each 28x28 image into a vector inside the model, so
# we don't have to call .reshape on the input dataset before training.
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(100, activation='relu'),
    # Softmax rather than sigmoid: 'sparse_categorical_crossentropy' (with its
    # default from_logits=False) expects each output row to be a probability
    # distribution over the 10 mutually exclusive digit classes.
    keras.layers.Dense(10, activation='softmax')
])

# Labels are integer class ids (not one-hot), hence the "sparse" loss.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(X_train, y_train, epochs=5)

In [None]:
# Measure loss and accuracy of the trained model on the held-out test split.
model.evaluate(x=X_test, y=y_test)

In [None]:
# Persist the trained model to disk; the TFLite converter cells below read
# it back from this directory.
model.save(filepath="./saved_model/")

In [None]:
# (1) Post-training quantization
#
# Baseline, without quantization: convert the saved model to the TFLite
# format as-is, so its size can be compared with the quantized variant
# produced in the next cell.
converter = tf.lite.TFLiteConverter.from_saved_model("./saved_model")
tflite_model = converter.convert()

In [None]:
# With quantization: same saved model as the baseline conversion, but the
# converter's default optimizations are switched on before converting.
converter = tf.lite.TFLiteConverter.from_saved_model("./saved_model")
converter.optimizations = [
    tf.lite.Optimize.DEFAULT,
]
tflite_quant_model = converter.convert()

In [None]:
# Length of the serialized, non-quantized TFLite model.
len(tflite_model)

In [None]:
# Length of the serialized, quantized TFLite model (compare with the
# non-quantized length above).
len(tflite_quant_model)

In [None]:
# The lengths printed above show that the quantized model is about 1/4th the
# size of the non-quantized one. Write both serialized models to disk so the
# file sizes can be checked there as well.
for tflite_path, tflite_bytes in (
    ("tflite_model.tflite", tflite_model),
    ("tflite_quant_model.tflite", tflite_quant_model),
):
    with open(tflite_path, "wb") as out_file:
        out_file.write(tflite_bytes)

In [None]:
# (2) Quantization-aware training
import tensorflow_model_optimization as tfmot

quantize_model = tfmot.quantization.keras.quantize_model

# "q_aware" stands for quantization aware: wrap the already-trained model so
# that further training runs through the quantization-aware version.
q_aware_model = quantize_model(model)

# The model returned by `quantize_model` has to be compiled again before it
# can be trained; reuse the same optimizer, loss and metric as the base model.
q_aware_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

q_aware_model.summary()

In [None]:
# Fine-tune the quantization-aware model for a single epoch.
q_aware_model.fit(x=X_train, y=y_train, epochs=1)

In [None]:
# Test-set loss and accuracy of the quantization-aware model.
q_aware_model.evaluate(x=X_test, y=y_test)

In [None]:
# Convert the in-memory quantization-aware Keras model to TFLite, with the
# converter's default optimizations enabled.
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [
    tf.lite.Optimize.DEFAULT,
]
tflite_qaware_model = converter.convert()

In [None]:
# Length of the serialized quantization-aware TFLite model.
len(tflite_qaware_model)

In [None]:
# Save the quantization-aware TFLite model to disk as raw bytes.
with open("tflite_qaware_model.tflite", 'wb') as out_file:
    out_file.write(tflite_qaware_model)