In [1]:
import tempfile
import os
import tensorflow as tf
import keras

In [2]:
from keras.datasets import mnist
# Load the MNIST train and test splits, each unpacked as an (images, labels) pair.
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Divide every pixel by 255.0 in both image arrays.
# NOTE(review): MNIST pixels are documented as 0-255, so this should map them
# into [0, 1] -- confirm against the Keras dataset docs.
train_images, test_images = (
    pixels / 255.0 for pixels in (train_images, test_images)
)

In [5]:
from keras.layers import Dense, Flatten, MaxPooling2D, InputLayer, Reshape, Conv2D

# Baseline classifier for 28x28 inputs: add a channel axis, apply one 3x3
# convolution with 12 filters (ReLU), 2x2 max pooling, then flatten into a
# 10-unit dense layer. The last layer has no activation, so it emits raw logits.
model = keras.Sequential(
    [
        InputLayer(input_shape=(28, 28)),
        Reshape((28, 28, 1)),
        Conv2D(12, (3, 3), activation="relu"),
        MaxPooling2D((2, 2)),
        Flatten(),
        Dense(10),
    ]
)

# The loss is configured with from_logits=True to match the activation-free
# output layer above.
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer="adam",
    metrics=["accuracy"],
)

# Train for a single epoch, holding out 10% of the training data for validation.
model.fit(x=train_images, y=train_labels, validation_split=0.1, epochs=1)

model.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 reshape_1 (Reshape)         (None, 28, 28, 1)         0         
                                                                 
 conv2d_1 (Conv2D)           (None, 26, 26, 12)        120       
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 13, 13, 12)       0         
 2D)                                                             
                                                                 
 flatten_1 (Flatten)         (None, 2028)              0         
                                                                 
 dense_1 (Dense)             (None, 10)                20290     
                                                                 
Total params: 20,410
Trainable params: 20,410
Non-trainable params: 0
__________________________________________________

In [6]:
import tensorflow_model_optimization as tfmot
# Short alias for the tfmot helper that wraps a Keras model for quantization.
quantize_model = tfmot.quantization.keras.quantize_model

# "q_aware" means quantization aware: the trained baseline passed through
# quantize_model.
q_aware_model = quantize_model(model)

# The model returned by quantize_model must be compiled again before it can be
# trained or evaluated; optimizer, loss and metric mirror the baseline model.
q_aware_model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer="adam",
    metrics=["accuracy"],
)

q_aware_model.summary()


Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 quantize_layer_1 (QuantizeL  (None, 28, 28)           3         
 ayer)                                                           
                                                                 
 quant_reshape_1 (QuantizeWr  (None, 28, 28, 1)        1         
 apperV2)                                                        
                                                                 
 quant_conv2d_1 (QuantizeWra  (None, 26, 26, 12)       147       
 pperV2)                                                         
                                                                 
 quant_max_pooling2d_1 (Quan  (None, 13, 13, 12)       1         
 tizeWrapperV2)                                                  
                                                                 
 quant_flatten_1 (QuantizeWr  (None, 2028)            

In [7]:
# Fine-tune the quantization-aware model on only the first 1000 training
# samples (out of 60000).
train_images_subset, train_labels_subset = train_images[:1000], train_labels[:1000]

# Kept as the cell's last expression so the returned History object is displayed.
q_aware_model.fit(
    x=train_images_subset,
    y=train_labels_subset,
    validation_split=0.1,
    epochs=1,
    batch_size=500,
)



<keras.callbacks.History at 0x19d9ea676d0>

In [8]:
# Score both models on the full test set. evaluate() is unpacked into two
# values here; the first is discarded and the second is kept as the accuracy.
_, baseline_model_accuracy = model.evaluate(
    x=test_images, y=test_labels, verbose=0
)
_, q_aware_model_accuracy = q_aware_model.evaluate(
    x=test_images, y=test_labels, verbose=0
)

# Report the two accuracies one after the other for comparison.
print("Baseline test accuracy: ", baseline_model_accuracy)
print("Quantized test accuracy: ", q_aware_model_accuracy)

Baseline test accuracy:  0.9641000032424927
Quantized test accuracy:  0.9646000266075134


In [9]:
# Build a TFLite converter from the quantization-aware Keras model.
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
# Select the converter's default optimization set; this must be assigned
# before convert() is called for it to take effect.
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# Run the conversion. The result is later passed to tf.lite.Interpreter as
# its model_content.
quantized_tflite_model = converter.convert()

# After this you have an actually quantized model with int8 weights and uint8 activations




INFO:tensorflow:Assets written to: C:\Users\sandr\AppData\Local\Temp\tmp0d0im6zh\assets


INFO:tensorflow:Assets written to: C:\Users\sandr\AppData\Local\Temp\tmp0d0im6zh\assets


# Persistence of accuracy from TF to TFLite

In [10]:
import numpy as np

def evaluate_model(interpreter, images=None, labels=None):
    """Measure the classification accuracy of a TFLite interpreter.

    Each image is fed through the interpreter one at a time. The index of the
    largest output score is taken as the predicted class, and the predictions
    are compared element-wise with the ground-truth labels.

    Args:
        interpreter: Object exposing ``get_input_details``,
            ``get_output_details``, ``set_tensor``, ``invoke`` and ``tensor``.
            In this notebook it is a ``tf.lite.Interpreter`` on which
            ``allocate_tensors()`` has already been called.
        images: Sequence of input samples without a batch dimension. When
            omitted, the notebook-level ``test_images`` is used, which keeps
            existing ``evaluate_model(interpreter)`` calls working.
        labels: Ground-truth class indices, one per image. When omitted, the
            notebook-level ``test_labels`` is used.

    Returns:
        The fraction of samples whose predicted class equals its label.

    Raises:
        ValueError: If ``images`` and ``labels`` differ in length.
    """
    # Fall back to the notebook globals so the original one-argument call
    # keeps its behavior.
    if images is None:
        images = test_images
    if labels is None:
        labels = test_labels
    labels = np.asarray(labels)

    # Fail fast: a mismatch would otherwise only surface after running
    # inference on every image.
    if len(images) != len(labels):
        raise ValueError(
            "images and labels must have the same length, got {} and {}".format(
                len(images), len(labels)
            )
        )

    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    # The accessor is loop-invariant; calling it after each invoke() reads the
    # current output values.
    read_output = interpreter.tensor(output_index)

    # Run predictions on every image in the dataset
    prediction_digits = []
    for i, image in enumerate(images):
        if i % 1000 == 0:
            print("Evaluated on {n} results so far.".format(n=i))

        # Pre-processing: add batch dimension and convert to float32 to match
        # the model's input data format
        batched = np.expand_dims(image, axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, batched)

        # Run inference
        interpreter.invoke()

        # Post-processing: drop the batch dimension and take the index of the
        # highest score as the predicted digit
        prediction_digits.append(np.argmax(read_output()[0]))

    print("\n")

    # Compare predictions with ground-truth labels to calculate accuracy
    prediction_digits = np.array(prediction_digits)
    accuracy = (prediction_digits == labels).mean()
    return accuracy


In [11]:
# Load the converted model content into a TFLite interpreter and allocate its
# tensors before any inference is run.
interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()

# Score the TFLite model with evaluate_model, then print it next to the
# accuracy measured earlier for the quantization-aware Keras model so the two
# can be compared directly.
test_accuracy = evaluate_model(interpreter)
print("Quantized TFLite test_accuracy: ", test_accuracy)
print("Quant TF test accuracy: ", q_aware_model_accuracy)

Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Quantized TFLite test_accuracy:  0.9646
Quant TF test accuracy:  0.9646000266075134
