In [14]:
import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot

import tempfile

In [2]:
mnist = tf.keras.datasets.mnist

In [23]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0


In [254]:
def setup_model():
    model = tf.keras.models.Sequential([
      tf.keras.layers.Flatten(input_shape=(28, 28)),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(256, activation='relu'),
      tf.keras.layers.Dropout(0.2),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10)
    ])
    return model

In [255]:
def setup_pretrained_weights(epochs=5):
    model = setup_model()
    
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
,
        optimizer='adam',
        metrics=['accuracy']
    )
    
    model.fit(x_train, y_train, epochs=epochs)
    
    _, pretrained_weights = tempfile.mkstemp('.tf')
    
    model.save_weights(pretrained_weights)
    
    return pretrained_weights

In [256]:
def setup_pretrained_model():
    model = setup_model()
    pretrained_weights = setup_pretrained_weights()
    model.load_weights(pretrained_weights)
    return model

In [257]:
setup_model()
pretrained_weights = setup_pretrained_weights()

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


# Quantization

In [258]:
base_model = setup_model()
base_model.load_weights(pretrained_weights)

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7f34a416e1d0>

In [259]:
base_model.summary()

Model: "sequential_15"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten_15 (Flatten)        (None, 784)               0         
                                                                 
 dense_48 (Dense)            (None, 128)               100480    
                                                                 
 dense_49 (Dense)            (None, 256)               33024     
                                                                 
 dropout_15 (Dropout)        (None, 256)               0         
                                                                 
 dense_50 (Dense)            (None, 128)               32896     
                                                                 
 dense_51 (Dense)            (None, 128)               16512     
                                                                 
 dense_52 (Dense)            (None, 10)              

In [260]:
quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)
quant_aware_model.summary()

Model: "sequential_15"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 quantize_layer_6 (QuantizeL  (None, 28, 28)           3         
 ayer)                                                           
                                                                 
 quant_flatten_15 (QuantizeW  (None, 784)              1         
 rapperV2)                                                       
                                                                 
 quant_dense_48 (QuantizeWra  (None, 128)              100485    
 pperV2)                                                         
                                                                 
 quant_dense_49 (QuantizeWra  (None, 256)              33029     
 pperV2)                                                         
                                                                 
 quant_dropout_15 (QuantizeW  (None, 256)            

In [261]:
### Q_aware model requires a recompile
quant_aware_model.compile(optimizer='adam',
                     loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                     metrics=['accuracy'])

quant_aware_model.summary()

# Note: the resulting model is quantization *aware* but not quantized

Model: "sequential_15"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 quantize_layer_6 (QuantizeL  (None, 28, 28)           3         
 ayer)                                                           
                                                                 
 quant_flatten_15 (QuantizeW  (None, 784)              1         
 rapperV2)                                                       
                                                                 
 quant_dense_48 (QuantizeWra  (None, 128)              100485    
 pperV2)                                                         
                                                                 
 quant_dense_49 (QuantizeWra  (None, 256)              33029     
 pperV2)                                                         
                                                                 
 quant_dropout_15 (QuantizeW  (None, 256)            

In [262]:
# We fine tune the model on a subset of the training data
train_images_subset = x_train[0:1000]
train_labels_subset = y_train[0:1000]

quant_aware_model.fit(train_images_subset, train_labels_subset,
                     batch_size=32, epochs=1, validation_split=0.1)



<keras.callbacks.History at 0x7f34c37099c0>

In [263]:
q_aware_model_accuracy = model.evaluate(x_test,  y_test, verbose=2)

313/313 - 1s - loss: 0.0676 - accuracy: 0.9798 - 572ms/epoch - 2ms/step


# TFLite Backend quantized model

In [264]:
converter = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model = converter.convert()



INFO:tensorflow:Assets written to: /tmp/tmpo9au5rvo/assets


INFO:tensorflow:Assets written to: /tmp/tmpo9au5rvo/assets
2023-11-04 21:38:54.028446: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2023-11-04 21:38:54.028492: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
2023-11-04 21:38:54.028757: I tensorflow/cc/saved_model/reader.cc:45] Reading SavedModel from: /tmp/tmpo9au5rvo
2023-11-04 21:38:54.035762: I tensorflow/cc/saved_model/reader.cc:89] Reading meta graph with tags { serve }
2023-11-04 21:38:54.035796: I tensorflow/cc/saved_model/reader.cc:130] Reading SavedModel debug info (if present) from: /tmp/tmpo9au5rvo
2023-11-04 21:38:54.061254: I tensorflow/cc/saved_model/loader.cc:229] Restoring SavedModel bundle.
2023-11-04 21:38:54.225996: I tensorflow/cc/saved_model/loader.cc:213] Running initialization op on SavedModel bundle at path: /tmp/tmpo9au5rvo
2023-11-04 21:38:54.264969: I tensorflow/cc/saved_model/loader.cc:305] SavedModel

In [265]:
def evaluate_model(interpreter):
      input_index = interpreter.get_input_details()[0]["index"]
      output_index = interpreter.get_output_details()[0]["index"]

      # Run predictions on every image in the "test" dataset.
      prediction_digits = []
      for i, test_image in enumerate(x_test):
        if i % 1000 == 0:
          print('Evaluated on {n} results so far.'.format(n=i))
        # Pre-processing: add batch dimension and convert to float32 to match with
        # the model's input data format.
        test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, test_image)

        # Run inference.
        interpreter.invoke()

        # Post-processing: remove batch dimension and find the digit with highest
        # probability.
        output = interpreter.tensor(output_index)
        digit = np.argmax(output()[0])
        prediction_digits.append(digit)

      print('\n')
      # Compare prediction results with ground truth labels to calculate accuracy.
      prediction_digits = np.array(prediction_digits)
      accuracy = (prediction_digits == y_test).mean()
      return accuracy

In [266]:
interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()

test_accuracy = evaluate_model(interpreter)

print('Quant TFLite test_accuracy:', test_accuracy)
print('Quant TF test accuracy:', q_aware_model_accuracy)

Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Quant TFLite test_accuracy: 0.9759
Quant TF test accuracy: [0.06757364422082901, 0.9797999858856201]


In [267]:
import os 
# Create float TFLite model.
float_converter = tf.lite.TFLiteConverter.from_keras_model(base_model)
float_tflite_model = float_converter.convert()

# Measure sizes of models.
_, float_file = tempfile.mkstemp('.tflite')
_, quant_file = tempfile.mkstemp('.tflite')

with open(quant_file, 'wb') as f:
  f.write(quantized_tflite_model)

with open(float_file, 'wb') as f:
  f.write(float_tflite_model)

print("Float model in Mb:", os.path.getsize(float_file) / float(2**20))
print("Quantized model in Mb:", os.path.getsize(quant_file) / float(2**20))


INFO:tensorflow:Assets written to: /tmp/tmpnc3xdhf4/assets


INFO:tensorflow:Assets written to: /tmp/tmpnc3xdhf4/assets


Float model in Mb: 0.7056655883789062
Quantized model in Mb: 0.18196868896484375


2023-11-04 21:39:11.116923: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2023-11-04 21:39:11.116962: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
2023-11-04 21:39:11.117186: I tensorflow/cc/saved_model/reader.cc:45] Reading SavedModel from: /tmp/tmpnc3xdhf4
2023-11-04 21:39:11.119377: I tensorflow/cc/saved_model/reader.cc:89] Reading meta graph with tags { serve }
2023-11-04 21:39:11.119401: I tensorflow/cc/saved_model/reader.cc:130] Reading SavedModel debug info (if present) from: /tmp/tmpnc3xdhf4
2023-11-04 21:39:11.125933: I tensorflow/cc/saved_model/loader.cc:229] Restoring SavedModel bundle.
2023-11-04 21:39:11.162101: I tensorflow/cc/saved_model/loader.cc:213] Running initialization op on SavedModel bundle at path: /tmp/tmpnc3xdhf4
2023-11-04 21:39:11.175568: I tensorflow/cc/saved_model/loader.cc:305] SavedModel load for tags { serve }; Status: success: OK. Took 58383 m

# Quantizine only particular layers

In [277]:
""" (https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide#quantize_some_layers)
While this example used the type of the layer to decide what to quantize, 
the easiest way to quantize a particular layer is to set its name property,
and look for that name in the clone_function.
"""
print([layer.name for layer in base_model.layers])

to_be_quantized = ['dense48', 'dense_49']
def quantization_wrapper(to_be_quantized):
    
    def quantize_layers(layer):
        if layer.name in to_be_quantized.copy(): 
            print(f"Layer {layer.name} will be quantized!")
            return tfmot.quantization.keras.quantize_annotate_layer(layer)

        # If not quantized: identity function
        return layer
    
    return quantize_layers

['flatten_15', 'dense_48', 'dense_49', 'dropout_15', 'dense_50', 'dense_51', 'dense_52']


In [278]:
annotated_model = tf.keras.models.clone_model(
    base_model,
    clone_function=quantization_wrapper(to_be_quantized)
)

Layer dense_49 will be quantized!


In [279]:
# Note that the layer we quantized AND the layer immediately before it both
# need to get quantized!
quant_aware_model2 = tfmot.quantization.keras.quantize_apply(annotated_model)
# quant_aware_model2 = tfmot.quantization.keras.quantize_model(annotated_model)

quant_aware_model2.summary()

Model: "sequential_15"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten_15 (Flatten)        (None, 784)               0         
                                                                 
 quant_dense_48 (QuantizeWra  (None, 128)              100483    
 pperV2)                                                         
                                                                 
 quant_dense_49 (QuantizeWra  (None, 256)              33029     
 pperV2)                                                         
                                                                 
 dropout_15 (Dropout)        (None, 256)               0         
                                                                 
 dense_50 (Dense)            (None, 128)               32896     
                                                                 
 dense_51 (Dense)            (None, 128)             

### Fine-tuning the quant aware model

In [280]:
# Recompile
quant_aware_model2.compile(optimizer='adam',
                     loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                     metrics=['accuracy'])

In [281]:
# Finetune
quant_aware_model2.fit(train_images_subset, train_labels_subset,
                     batch_size=32, epochs=1, validation_split=0.1)



<keras.callbacks.History at 0x7f34a424d0f0>

In [282]:
# Evaluate
q_aware_model2_accuracy = quant_aware_model2.evaluate(x_test, y_test, verbose=2)

313/313 - 1s - loss: 0.0956 - accuracy: 0.9771 - 625ms/epoch - 2ms/step


# Compare model sizes
Note: we convert the model to the TFLite backend (to support 8bit quantization)


In [283]:
converter2 = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model2)
converter2.optimizations = [tf.lite.Optimize.DEFAULT]

# REVIEW This
# converter2.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# converter2.inference_input_type = tf.int8  # or tf.uint8
# converter2.inference_output_type = tf.int8  # or tf.uint8

quantized_tflite_model2 = converter2.convert()



INFO:tensorflow:Assets written to: /tmp/tmpyy_yy8a4/assets


INFO:tensorflow:Assets written to: /tmp/tmpyy_yy8a4/assets
2023-11-04 21:41:51.989484: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2023-11-04 21:41:51.989521: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
2023-11-04 21:41:51.989746: I tensorflow/cc/saved_model/reader.cc:45] Reading SavedModel from: /tmp/tmpyy_yy8a4
2023-11-04 21:41:51.995336: I tensorflow/cc/saved_model/reader.cc:89] Reading meta graph with tags { serve }
2023-11-04 21:41:51.995389: I tensorflow/cc/saved_model/reader.cc:130] Reading SavedModel debug info (if present) from: /tmp/tmpyy_yy8a4
2023-11-04 21:41:52.014643: I tensorflow/cc/saved_model/loader.cc:229] Restoring SavedModel bundle.
2023-11-04 21:41:52.152638: I tensorflow/cc/saved_model/loader.cc:213] Running initialization op on SavedModel bundle at path: /tmp/tmpyy_yy8a4
2023-11-04 21:41:52.186359: I tensorflow/cc/saved_model/loader.cc:305] SavedModel

In [284]:
### TODO: Abstract this into a function
# Create float TFLite model
float_converter_base = tf.lite.TFLiteConverter.from_keras_model(base_model)
float_tflite_model_base = float_converter_base.convert()

# Measure sizes of models.
_, float_base_file = tempfile.mkstemp('.tflite')
_, quant2_file = tempfile.mkstemp('.tflite')

with open(quant2_file, 'wb') as f:
  f.write(quantized_tflite_model2)

with open(float_base_file, 'wb') as f:
  f.write(float_tflite_model_base)

print("Float model in Mb:", os.path.getsize(float_base_file) / float(2**20))
print("Quantized model in Mb:", os.path.getsize(quant2_file) / float(2**20))

INFO:tensorflow:Assets written to: /tmp/tmpoeqcr7cp/assets


INFO:tensorflow:Assets written to: /tmp/tmpoeqcr7cp/assets


Float model in Mb: 0.7056655883789062
Quantized model in Mb: 0.6127700805664062


2023-11-04 21:41:56.257106: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2023-11-04 21:41:56.257146: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
2023-11-04 21:41:56.257370: I tensorflow/cc/saved_model/reader.cc:45] Reading SavedModel from: /tmp/tmpoeqcr7cp
2023-11-04 21:41:56.259804: I tensorflow/cc/saved_model/reader.cc:89] Reading meta graph with tags { serve }
2023-11-04 21:41:56.259829: I tensorflow/cc/saved_model/reader.cc:130] Reading SavedModel debug info (if present) from: /tmp/tmpoeqcr7cp
2023-11-04 21:41:56.267058: I tensorflow/cc/saved_model/loader.cc:229] Restoring SavedModel bundle.
2023-11-04 21:41:56.304451: I tensorflow/cc/saved_model/loader.cc:213] Running initialization op on SavedModel bundle at path: /tmp/tmpoeqcr7cp
2023-11-04 21:41:56.318629: I tensorflow/cc/saved_model/loader.cc:305] SavedModel load for tags { serve }; Status: success: OK. Took 61259 m