In [1]:
!pip install -q tensorflow-model-optimization

In [2]:
from tensorflow.keras.layers import Input, Conv2D, AveragePooling2D, Flatten, Softmax, Dense, ReLU, BatchNormalization
from tensorflow.keras import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import SGD
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

In [3]:
import tensorflow_model_optimization as tfmot
quantize_model = tfmot.quantization.keras.quantize_model

In [4]:
(X_train, y_train), (X_test, y_test) = mnist.load_data()

In [5]:
# Convert y_train into one-hot format
temp = []
for i in range(len(y_train)):
    temp.append(to_categorical(y_train[i], num_classes=10))
y_train = np.array(temp)
# Convert y_test into one-hot format
temp = []
for i in range(len(y_test)):    
    temp.append(to_categorical(y_test[i], num_classes=10))
y_test = np.array(temp)

In [6]:
#reshaping
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)

In [7]:
inputs = Input(shape=(28,28,1))
# out = Lambda(lambda x: x/100)(inputs)
out = Conv2D(4, 3, use_bias=False)(inputs)
out = BatchNormalization()(out)
out = ReLU()(out)
# out = Lambda(lambda x: x**2+x)(out)
out = AveragePooling2D()(out)
# out = Lambda(lambda x: x*4)(out)
out = Conv2D(8, 3, use_bias=False)(out)
out = BatchNormalization()(out)
out = ReLU()(out)
# out = Lambda(lambda x: x**2+x)(out)
out = AveragePooling2D()(out)
# out = Lambda(lambda x: x*4)(out)
out = Flatten()(out)
out = Dense(10, activation=None)(out)
out = Softmax()(out)
model = Model(inputs, out)

In [8]:
q_aware_model = quantize_model(model)

In [9]:
q_aware_model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 quantize_layer (QuantizeLay  (None, 28, 28, 1)        3         
 er)                                                             
                                                                 
 quant_conv2d (QuantizeWrapp  (None, 26, 26, 4)        45        
 erV2)                                                           
                                                                 
 quant_batch_normalization (  (None, 26, 26, 4)        17        
 QuantizeWrapperV2)                                              
                                                                 
 quant_re_lu (QuantizeWrappe  (None, 26, 26, 4)        3         
 rV2)                                                        

In [10]:
q_aware_model.compile(
    loss='categorical_crossentropy',
    optimizer=SGD(learning_rate=0.01, momentum=0.9),
    metrics=['acc']
    )

In [11]:
q_aware_model.fit(X_train, y_train, epochs=15, batch_size=32, validation_data=(X_test, y_test))

Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15


<keras.callbacks.History at 0x7f8045540d10>

In [12]:
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

model = converter.convert()



In [13]:
'''
Create interpreter, allocate tensors
'''
tflite_interpreter = tf.lite.Interpreter(model_content=model)
tflite_interpreter.allocate_tensors()

'''
Check input/output details
'''
input_details = tflite_interpreter.get_input_details()
output_details = tflite_interpreter.get_output_details()

print("== Input details ==")
print("name:", input_details[0]['name'])
print("shape:", input_details[0]['shape'])
print("type:", input_details[0]['dtype'])
print("\n== Output details ==")
print("name:", output_details[0]['name'])
print("shape:", output_details[0]['shape'])
print("type:", output_details[0]['dtype'])

'''
This gives a list of dictionaries. 
'''
tensor_details = tflite_interpreter.get_tensor_details()

for dict in tensor_details:
    i = dict['index']
    tensor_name = dict['name']
    scales = dict['quantization_parameters']['scales']
    zero_points = dict['quantization_parameters']['zero_points']
    tensor = tflite_interpreter.tensor(i)()

    print(i, type, tensor_name, scales.shape, zero_points.shape, tensor.shape)
    # print(tensor)

== Input details ==
name: serving_default_input_1:0
shape: [ 1 28 28  1]
type: <class 'numpy.float32'>

== Output details ==
name: StatefulPartitionedCall:0
shape: [ 1 10]
type: <class 'numpy.float32'>
0 <class 'type'> serving_default_input_1:0 (0,) (0,) (1, 28, 28, 1)
1 <class 'type'> model/quant_flatten/Const (0,) (0,) (2,)
2 <class 'type'> model/quant_dense/BiasAdd/ReadVariableOp (1,) (1,) (10,)
3 <class 'type'> model/quant_batch_normalization_1/FusedBatchNormV3 (8,) (8,) (8,)
4 <class 'type'> model/quant_batch_normalization/FusedBatchNormV3 (4,) (4,) (4,)
5 <class 'type'> model/quantize_layer/AllValuesQuantize/FakeQuantWithMinMaxVars;model/quantize_layer/AllValuesQuantize/FakeQuantWithMinMaxVars/ReadVariableOp;model/quantize_layer/AllValuesQuantize/FakeQuantWithMinMaxVars/ReadVariableOp_1 (1,) (1,) (1, 28, 28, 1)
6 <class 'type'> model/quant_conv2d/Conv2D;model/quant_conv2d/LastValueQuant/FakeQuantWithMinMaxVarsPerChannel (4,) (4,) (4, 3, 3, 1)
7 <class 'type'> model/quant_re_lu/Re