In [1]:

import absl, os, time
import numpy as np
import matplotlib.pyplot as plt

# Import tensorflow
import tensorflow as tf
import tensorflow.keras as keras
# Import tensorflow model optimization, used for quantization-aware training
import tensorflow_model_optimization as tfmot

# Reduce TensorFlow / absl log noise to errors only.
# `import absl` alone does not load the `absl.logging` submodule (it only works
# when another library imported it first), so import it explicitly.
import absl.logging

tf.get_logger().setLevel('ERROR')
absl.logging.set_verbosity(absl.logging.ERROR)
# NOTE(review): TF's C++ runtime reads TF_CPP_MIN_LOG_LEVEL when it first logs;
# setting it after `import tensorflow` may be too late to silence import-time
# messages — consider moving this line above the tensorflow import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

## Useful functions

In [2]:
# Convert a Keras model containing LSTMs to an int8-quantized TFLite model.
def convert_LSTM_to_float_TFLite(model_tf, TFLite_target_filename, batch_size=1,
                                 representative_data=None, num_calibration_samples=1000):
    """Convert a Keras LSTM model to an int8 post-training-quantized TFLite model.

    Despite the "float" in the name, the produced model uses int8 input and
    output tensors, and int8 builtin kernels; SELECT_TF_OPS is also enabled so
    ops without a TFLite builtin can run as TF ops.

    Args:
        model_tf: Keras model to convert (converted from memory, not from disk).
        TFLite_target_filename: path the serialized TFLite model is written to.
        batch_size: currently unused; kept only for backward compatibility.
        representative_data: calibration samples WITHOUT a batch dimension; each
            sample is fed to the converter as a batch of one. Defaults to the
            notebook-global ``x_train`` (original behaviour).
        num_calibration_samples: number of leading samples of
            ``representative_data`` used for calibration (fewer if the data is
            shorter).

    Returns:
        The serialized TFLite model as returned by ``converter.convert()``.
    """
    if representative_data is None:
        representative_data = x_train  # notebook global, defined in the data-loading cell
    # Slicing (rather than indexing 0..N-1) tolerates datasets smaller than N.
    calibration_samples = np.asarray(
        representative_data[:num_calibration_samples], dtype=np.float32)

    converter = tf.lite.TFLiteConverter.from_keras_model(model_tf)
    # DEFAULT enables post-training quantization (OPTIMIZE_FOR_SIZE is a
    # deprecated alias with the same effect).
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # int8 builtin kernels, plus TF ops as a fallback for unsupported ops.
    converter.target_spec.supported_ops = [tf.lite.OpsSet.SELECT_TF_OPS,
                                           tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    # Fully integer model: int8 input and output tensors.
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8

    def generate_representative_dataset():
        # One calibration sample per step, with a leading batch dimension of 1.
        for sample in calibration_samples:
            yield [sample[np.newaxis, ...]]
    converter.representative_dataset = generate_representative_dataset

    model_tflite = converter.convert()

    # `with` guarantees the file is flushed and closed before its size is read.
    with open(TFLite_target_filename, "wb") as tflite_file:
        tflite_file.write(model_tflite)
    print(f"TFLite model size: {os.path.getsize(TFLite_target_filename)}")
    return model_tflite

In [3]:
def predict_TFLite(model, X):
    """Run a serialized TFLite model on every sample of X, one sample per invoke.

    Args:
        model: serialized TFLite model (passed as ``model_content`` to the
            interpreter), e.g. the return value of the conversion helper.
        X: array of samples; ``X[i]`` is fed as a batch of one, i.e. with shape
            ``(1,) + X[i].shape``. X itself is never modified.

    Returns:
        np.ndarray of shape ``(len(X),) + <per-invoke output tensor shape>``.
        If the output tensor is quantized it is dequantized to float.
    """
    # Initialize the TFLite interpreter.
    interpreter = tf.lite.Interpreter(model_content=model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]

    # Quantize the input if the model expects quantized values. The arithmetic
    # below creates new arrays, so the caller's X is left untouched.
    x_data = np.asarray(X)
    input_dtype = input_details["dtype"]
    input_scale, input_zero_point = input_details["quantization"]
    if (input_scale, input_zero_point) != (0.0, 0):
        x_data = x_data / input_scale + input_zero_point
        if np.issubdtype(input_dtype, np.integer):
            # Round to nearest and saturate: a bare astype() truncates toward
            # zero and is undefined for values outside the integer range.
            dtype_info = np.iinfo(input_dtype)
            x_data = np.clip(np.round(x_data), dtype_info.min, dtype_info.max)
    x_data = x_data.astype(input_dtype)

    # Every sample has the same shape, so resize the (dynamic) batch dimension
    # to 1 once instead of re-allocating tensors on every iteration.
    if len(x_data):
        interpreter.resize_tensor_input(input_details["index"], (1,) + x_data.shape[1:], strict=True)
        interpreter.allocate_tensors()

    outputs = []
    for sample in x_data:
        interpreter.set_tensor(input_details["index"], sample[np.newaxis, ...])
        interpreter.invoke()
        # get_tensor may return a view into interpreter memory; copy it.
        outputs.append(np.copy(interpreter.get_tensor(output_details["index"])))

    # Dequantize the output: real = (q - zero_point) * scale.
    outputs = np.array(outputs)
    output_scale, output_zero_point = output_details["quantization"]
    if (output_scale, output_zero_point) != (0.0, 0):
        outputs = outputs.astype(np.float32)
        outputs = (outputs - output_zero_point) * output_scale
    # TODO: reshape output into array for each exit
    return outputs

In [4]:
# MNIST: 28x28 grayscale digits with integer labels.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Scale pixel intensities from [0, 255] into [0, 1] (min-max normalization,
# not zero-mean standardization).
x_train, x_test = x_train / 255.0, x_test / 255.0

# TF model - float baseline (no quantization)

In [5]:
# Float baseline: LSTM(20) over the 28 image rows, then a 10-way logit layer.
baseline_layers = [
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.LSTM(20),
    # The LSTM already emits (batch, 20), so Flatten is effectively a no-op; kept
    # so the architecture matches the QAT variants below.
    keras.layers.Flatten(),
    keras.layers.Dense(10),
]
model = keras.Sequential(baseline_layers)

model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),  # Dense emits raw logits
    metrics=['accuracy'],
)
# No random seed is set, so the reported accuracy varies between runs.
model.fit(x_train, y_train, epochs=3)

test_loss, accuracy = model.evaluate(x_test, y_test)
print(f"test accuracy: {accuracy}")

Epoch 1/3
Epoch 2/3
Epoch 3/3
test accuracy: 0.9372000098228455


# Post-Training Quantization (PTQ)

In [6]:
# Post-training quantization of the float baseline, then int8 TFLite accuracy.
model_ptq = convert_LSTM_to_float_TFLite(model, "model_ptq.tflite")
ptq_outputs = predict_TFLite(model_ptq, x_test)
x_pred = ptq_outputs.argmax(axis=-1)
ptq_correct = x_pred.ravel() == y_test.ravel()
print(f"Test accuracy: {ptq_correct.mean()}")



TFLite model size: 15984
Test accuracy: 0.9322


# Quantization-Aware Training (QAT) - LSTM is excluded from QAT and quantized only at TFLite conversion (Option A)

In [7]:
class NoOpQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
    """QuantizeConfig that requests no quantization at all.

    It returns no weight, activation or output quantizers, so no
    fake-quantization is inserted for a layer annotated with it.
    """

    def get_weights_and_quantizers(self, layer):
        # No (weight, quantizer) pairs.
        return list()

    def get_activations_and_quantizers(self, layer):
        # No (activation, quantizer) pairs.
        return list()

    def set_quantize_weights(self, layer, quantize_weights):
        # Nothing was requested, so nothing is written back.
        return None

    def set_quantize_activations(self, layer, quantize_activations):
        # Nothing was requested, so nothing is written back.
        return None

    def get_output_quantizers(self, layer):
        return list()

    def get_config(self):
        # Stateless: the constructor takes no arguments.
        return dict()

def annotate_layers(layer):
    """`clone_function` for clone_model: annotate LSTM layers with NoOpQuantizeConfig.

    LSTM layers are wrapped with quantize_annotate_layer using a config that adds
    no quantizers; any other layer is returned unchanged (the same instance).
    """
    is_lstm = isinstance(layer, tf.keras.layers.LSTM)
    if not is_lstm:
        return layer
    return tfmot.quantization.keras.quantize_annotate_layer(
        layer, quantize_config=NoOpQuantizeConfig()
    )


# Same architecture as the float baseline.
model_qat = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.LSTM(20),
    keras.layers.Flatten(),
    keras.layers.Dense(10),
])

# Annotate while cloning: the LSTM gets NoOpQuantizeConfig (no fake-quantization
# during training); the remaining layers are left for quantize_model.
model_qat = tf.keras.models.clone_model(model_qat, clone_function=annotate_layers)

# Custom QuantizeConfig classes used in annotations are registered via
# quantize_scope so tfmot can deserialize them while rebuilding the model.
with tfmot.quantization.keras.quantize_scope({'NoOpQuantizeConfig': NoOpQuantizeConfig}):
    model_qat = tfmot.quantization.keras.quantize_model(model_qat)

# Compile only the final (quantize-aware) model: quantize_model returns a new
# model, so compiling earlier would have no effect.
model_qat.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

model_qat.fit(x_train, y_train, epochs=3)

# The float-trained LSTM is quantized here, during TFLite conversion.
model_qat_skip = convert_LSTM_to_float_TFLite(model_qat, "model_qat_skip.tflite")

Epoch 1/3
Epoch 2/3
Epoch 3/3




TFLite model size: 16072


In [8]:
qat_skip_outputs = predict_TFLite(model_qat_skip, x_test)
x_pred = qat_skip_outputs.argmax(axis=-1)
qat_skip_correct = x_pred.ravel() == y_test.ravel()
print(f"Test accuracy: {qat_skip_correct.mean()}")

Test accuracy: 0.9358


# Quantization-Aware Training (QAT) - Experimental LSTM quantization (Option B)

In [9]:
LastValueQuantizer = tfmot.quantization.keras.quantizers.LastValueQuantizer
MovingAverageQuantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer

class LSTMQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
    """Experimental QuantizeConfig that fake-quantizes the cell of a Keras LSTM.

    - Weights: ``cell.kernel`` and ``cell.recurrent_kernel`` use LastValueQuantizer.
    - Activations: ``cell.activation`` and ``cell.recurrent_activation`` use
      MovingAverageQuantizer.
    - Outputs: not quantized. Bias: not quantized by this config.

    All quantizers are 8-bit, asymmetric, full-range and per-tensor.
    """

    # Shared quantizer settings (a class attribute, so get_config stays empty).
    _QUANTIZER_KWARGS = dict(num_bits=8, symmetric=False, narrow_range=False, per_axis=False)

    def get_weights_and_quantizers(self, layer):
        cell = layer.cell
        weights = (cell.kernel, cell.recurrent_kernel)
        return [(weight, LastValueQuantizer(**self._QUANTIZER_KWARGS)) for weight in weights]

    def get_activations_and_quantizers(self, layer):
        cell = layer.cell
        activations = (cell.activation, cell.recurrent_activation)
        return [(activation, MovingAverageQuantizer(**self._QUANTIZER_KWARGS))
                for activation in activations]

    def set_quantize_weights(self, layer, quantize_weights):
        # Same order as get_weights_and_quantizers: kernel, then recurrent_kernel.
        cell = layer.cell
        cell.kernel = quantize_weights[0]
        cell.recurrent_kernel = quantize_weights[1]

    def set_quantize_activations(self, layer, quantize_activations):
        # Same order as get_activations_and_quantizers.
        cell = layer.cell
        cell.activation = quantize_activations[0]
        cell.recurrent_activation = quantize_activations[1]

    def get_output_quantizers(self, layer):
        # Layer outputs are left in float.
        return []

    def get_config(self):
        return {}



quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

# Same architecture as the float baseline, with the LSTM annotated with the
# experimental LSTMQuantizeConfig. use_bias=False because that config quantizes
# only kernel and recurrent_kernel.
quant_aware_model = quantize_annotate_model(keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    quantize_annotate_layer(keras.layers.LSTM(20, use_bias=False), LSTMQuantizeConfig()),
    keras.layers.Flatten(),
    keras.layers.Dense(10),
]))

# `quantize_apply` needs the custom `LSTMQuantizeConfig` (and the LSTM layer
# class) registered through `quantize_scope` to rebuild the annotated model.
with quantize_scope({'LSTMQuantizeConfig': LSTMQuantizeConfig, 'LSTM': keras.layers.LSTM}):
    # Make the model quantization aware.
    quant_aware_model = tfmot.quantization.keras.quantize_apply(quant_aware_model)

quant_aware_model.summary()

quant_aware_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

quant_aware_model.fit(x_train, y_train, epochs=3)

model_qat_noskip = convert_LSTM_to_float_TFLite(quant_aware_model, "model_qat_noskip.tflite")

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 quantize_layer_1 (QuantizeL  (None, 28, 28)           3         
 ayer)                                                           
                                                                 
 quant_lstm_2 (QuantizeWrapp  (None, 20)               3849      
 erV2)                                                           
                                                                 
 quant_flatten_2 (QuantizeWr  (None, 20)               1         
 apperV2)                                                        
                                                                 
 quant_dense_2 (QuantizeWrap  (None, 10)               215       
 perV2)                                                          
                                                                 
Total params: 4,068
Trainable params: 4,050
Non-traina



TFLite model size: 19456


In [10]:
qat_noskip_outputs = predict_TFLite(model_qat_noskip, x_test)
x_pred = qat_noskip_outputs.argmax(axis=-1)
qat_noskip_correct = x_pred.ravel() == y_test.ravel()
print(f"Test accuracy: {qat_noskip_correct.mean()}")

Test accuracy: 0.9233
