In [3]:
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.layers import Input, Dense, Dropout
from tensorflow.keras.applications import MobileNet
import tensorflow_model_optimization as tfmot


In [4]:

# Load the data
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Reshape the data to 32x32x3 and normalize it
x_train = tf.image.grayscale_to_rgb(tf.expand_dims(x_train, axis=-1))
x_train = tf.image.resize(x_train, (32, 32))
x_train = x_train / 255.0
x_test = tf.image.grayscale_to_rgb(tf.expand_dims(x_test, axis=-1))
x_test = tf.image.resize(x_test, (32, 32))
x_test = x_test / 255.0


In [5]:

# Define the model
inputs = Input(shape=(32, 32, 3))
base_model = MobileNet(input_tensor=inputs, include_top=False, weights=None, pooling='avg')
x = Dropout(0.25)(base_model.output)
outputs = Dense(10, activation='softmax')(x)
model = keras.models.Model(inputs, outputs)

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

model.fit(x_train, y_train, epochs=1, validation_data=(x_test, y_test))



<keras.callbacks.History at 0x20887a34c40>

In [6]:

# Apply QAT
q_aware_model = tfmot.quantization.keras.quantize_model(model)
q_aware_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

q_aware_model.fit(x_train, y_train, epochs=1, validation_data=(x_test, y_test))




<keras.callbacks.History at 0x208ba331490>

In [9]:
q_aware_model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 quantize_layer (QuantizeLay  (None, 32, 32, 3)        3         
 er)                                                             
                                                                 
 quant_conv1 (QuantizeWrappe  (None, 16, 16, 32)       929       
 rV2)                                                            
                                                                 
 quant_conv1_bn (QuantizeWra  (None, 16, 16, 32)       129       
 pperV2)                                                         
                                                                 
 quant_conv1_relu (QuantizeW  (None, 16, 16, 32)       3         
 rapperV2)                                                   

In [7]:
_ , float_acc = model.evaluate(x_test, y_test)
_ , qat_acc = q_aware_model.evaluate(x_test, y_test)




In [8]:
print(f"Model Test Accuracy (without Quantization) \t: {float_acc:.4f}")
print(f"QAT Model Test Accuracy (without Quantization) \t: {qat_acc:.4f}")

Model Test Accuracy (without Quantization) 	: 0.9546
QAT Model Test Accuracy (without Quantization) 	: 0.9814
