# Simple Quantization Example

In this example, we demonstrate a simple quantization workflow using `femtoflow`. Quantization reduces a model's memory footprint and improves its computational efficiency by converting weights and activations from a floating-point representation to a fixed-point representation.
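
To make the float-to-fixed-point conversion concrete, here is a minimal sketch of affine (scale and zero-point) quantization to signed 8-bit integers, written in plain Python so that it runs without extra dependencies. The function names and example values are illustrative assumptions, not part of the `femtoflow` API, and the exact scheme `femtoflow` applies may differ.

```python
def quantize(values, scale, zero_point, qmin=-128, qmax=127):
    """Map floats to integers: q = clamp(round(x / scale) + zero_point)."""
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]


def dequantize(q_values, scale, zero_point):
    """Map integers back to approximate floats: x ~ (q - zero_point) * scale."""
    return [(q - zero_point) * scale for q in q_values]


# Illustrative weights and a power-of-two scale covering roughly [-1, 1).
weights = [-1.0, -0.5, 0.0, 0.3, 0.99]
scale = 1.0 / 128
zero_point = 0

q_weights = quantize(weights, scale, zero_point)
restored = dequantize(q_weights, scale, zero_point)

# For values inside the representable range, the round-trip error is at most
# half of one quantization step.
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(q_weights, max_error)
```

Values outside the representable range are clamped to `qmin` or `qmax`, which is one reason the choice of scale matters for accuracy.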

## Considerations

- Quantization may introduce a slight drop in model accuracy due to the reduced precision. The representative dataset and calibration process help minimize the accuracy loss.

- Different quantization modes (e.g., `8x8`, `8x16`) offer trade-offs between model size, accuracy, and performance. Choosing the right quantization mode depends on the specific use case and deployment requirements.
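
The precision side of this trade-off can be sketched by quantizing the same values at two bit widths and comparing the worst-case round-trip error. The snippet below is a plain-Python illustration under assumed conditions (symmetric quantization, a made-up activation range of [-4, 4]); it is not how `femtoflow` computes anything internally.

```python
def roundtrip_error(values, num_bits, max_abs):
    """Worst-case error after symmetric quantization to `num_bits` and back."""
    qmax = 2 ** (num_bits - 1) - 1          # 127 for 8 bits, 32767 for 16 bits
    scale = max_abs / qmax                  # float value of one integer step
    worst = 0.0
    for v in values:
        q = max(-qmax, min(qmax, round(v / scale)))
        worst = max(worst, abs(v - q * scale))
    return worst


# Illustrative activations in [-4, 4]; real ranges come from calibration.
activations = [i / 100.0 for i in range(-400, 401, 7)]

err_8 = roundtrip_error(activations, num_bits=8, max_abs=4.0)
err_16 = roundtrip_error(activations, num_bits=16, max_abs=4.0)
print(f"8-bit worst-case error:  {err_8:.6f}")
print(f"16-bit worst-case error: {err_16:.8f}")
```

The 16-bit representation has a step size roughly 258 times smaller over the same range, at the cost of twice the storage per activation value.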



## Installation

In [None]:
# ! pip install femtoflow --quiet

## Imports

In [None]:
import tempfile
import os

import tensorflow as tf
import tensorflow_model_optimization as tfmot
import numpy as np
from tensorflow import keras

from femtoflow.quantization.quantize_tflite import TFLiteModelWrapper

In [None]:
import warnings 
warnings.filterwarnings('ignore')

## MNIST Dataset Download

We will use the MNIST dataset for this example. The dataset consists of grayscale images of handwritten digits (0-9) and corresponding labels. Each image has a size of 28x28 pixels.


In [None]:
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

BATCH_SIZE = 1024
# Note: drop_remainder=True discards the final partial batch, so the test
# dataset below covers 9 * 1024 = 9216 of the 10,000 test images.
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).batch(BATCH_SIZE, drop_remainder=True)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(BATCH_SIZE, drop_remainder=True)

## Model Training

### Model Definition

We will define a simple convolutional neural network (CNN) model for digit classification. The model architecture is as follows:

1. Input layer with shape (28, 28) to accept grayscale images of size 28x28 pixels.
2. Reshape layer to reshape the input images into a 4D tensor of shape (batch_size, 28, 28, 1).
3. Conv2D layer with 12 filters, a kernel size of (3, 3), and ReLU activation.
4. MaxPooling2D layer with a pool size of (2, 2) for downsampling.
5. Flatten layer to convert the 2D feature maps into a 1D feature vector.
6. Dense layer with 100 neurons and no activation.
7. Dense layer with 50 neurons and no activation.
8. Output Dense layer with 10 neurons (one for each class label) and no activation (logits).
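
The layer list above fixes the tensor shapes and the number of trainable parameters. As a rough check, the arithmetic can be sketched in plain Python, assuming the Keras defaults of stride 1 and `'valid'` padding for `Conv2D` and a stride equal to the pool size for `MaxPooling2D`. The helper name is hypothetical.

```python
def conv2d_output(height, width, kernel, filters):
    """Output shape of a stride-1 Conv2D with 'valid' padding."""
    return height - kernel + 1, width - kernel + 1, filters


# Reshape: (28, 28) -> (28, 28, 1)
h, w, c = 28, 28, 1

# Conv2D: 12 filters of size 3x3 over 1 input channel, plus one bias per filter.
conv_params = 3 * 3 * c * 12 + 12
h, w, c = conv2d_output(h, w, kernel=3, filters=12)      # (26, 26, 12)

# MaxPooling2D with a 2x2 pool halves each spatial dimension; no parameters.
h, w = h // 2, w // 2                                     # (13, 13, 12)

# Flatten: 13 * 13 * 12 features feed the first Dense layer.
features = h * w * c

# Each Dense layer has (inputs * units) weights plus one bias per unit.
dense_params = []
for units in (100, 50, 10):
    dense_params.append(features * units + units)
    features = units

total_params = conv_params + sum(dense_params)
print(conv_params, dense_params, total_params)
```

Almost all of the parameters sit in the first Dense layer, so that layer dominates the size of the model before and after quantization.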

**Note:** Although the Femtosense SPU does not currently support `Conv2D` layers, TFLite files containing Conv2D models can still be generated using the `femtoflow` tool. However, be aware that attempting to compile these TFLite files with [femtocrux](https://femtocrux.femtosense.ai/en/latest/) will result in an error.


In [None]:
# Define the model architecture.
model = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(28, 28)),
  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(100),
  tf.keras.layers.Dense(50),
  tf.keras.layers.Dense(10)
])

### Define Training-Related Parameters

Before we proceed with training the model, we need to define a few training-related parameters:

1. `optimizer`: The optimization algorithm used to update the model weights. We will use the Adam optimizer.
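2. `loss_fn`: The loss function used to measure the difference between the predicted and true labels. We will use the Sparse Categorical Crossentropy loss, which suits multi-class classification with integer labels. It is configured with `from_logits=True` because the model outputs raw logits rather than probabilities.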
3. `metrics`: The evaluation metric used to assess the performance of the model. We will use classification accuracy.
4. `epochs`: The number of times the training process iterates over the entire dataset. We will set it to 2 for demonstration purposes.
5. `validation_split`: The fraction of the training dataset reserved for validation. We will use a 10% validation split.
6. `batch_size`: The number of samples per gradient update. We will use a batch size of 512.
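
For intuition about the loss in item 2, the per-sample computation can be sketched in plain Python: the cross-entropy of a sample is the negative log of the softmax probability assigned to its true class. This is an illustrative re-derivation with a hypothetical function name, not the Keras implementation.

```python
import math


def sparse_categorical_crossentropy_from_logits(logits, label):
    """Cross-entropy for one sample: -log(softmax(logits)[label])."""
    # Subtract the maximum logit before exponentiating for numerical stability.
    shift = max(logits)
    log_sum_exp = shift + math.log(sum(math.exp(z - shift) for z in logits))
    return log_sum_exp - logits[label]


# Illustrative logits for a 3-class problem (the MNIST model produces 10).
logits = [2.0, 0.5, -1.0]

loss_correct = sparse_categorical_crossentropy_from_logits(logits, label=0)
loss_wrong = sparse_categorical_crossentropy_from_logits(logits, label=2)
print(loss_correct, loss_wrong)
```

The loss is small when the largest logit belongs to the true class and grows as the true class's logit falls behind the others.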



In [None]:
optimizer = 'adam'
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = ['accuracy']
epochs = 2
validation_split = 0.1
batch_size = 512

### Train the Digit Classification Model

Once the model architecture and training-related parameters are defined, we can proceed to train the digit classification model. To do this, we use the `compile` method to configure the model for training by specifying the optimizer, loss function, and evaluation metric. We then use the `fit` method to start the training process with the specified batch size, number of epochs, and validation split.


In [None]:
# Train the digit classification model
model.compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)

model.fit(
  train_images,
  train_labels,
  batch_size=batch_size,
  epochs=epochs,
  validation_split=validation_split,
)

In [None]:
_, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

## Quantize the Model Using TFLite

In this section, we quantize the trained model using TensorFlow Lite (TFLite).

### Quantize Using `TFLiteModelWrapper()` Class

To perform quantization, we will use the `TFLiteModelWrapper()` class. This class provides a convenient interface for converting a TensorFlow model to a quantized TFLite model. 

We first define a `representative_data_gen` function that generates the representative dataset for quantization calibration. It yields the first 100 training images, one per batch, as a sample of the typical input data distribution.

We then set the quantization mode to either `'8x16'` or `'8x8'`. The mode `'8x16'` indicates 8-bit quantization for weights and 16-bit quantization for activations, while `'8x8'` indicates 8-bit quantization for both weights and activations. We provide the model, representative dataset, and quantization mode to the `TFLiteModelWrapper()` class and specify the save path for the quantized TFLite model (`tflite_save_path`).
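
Conceptually, calibration passes the representative samples through the model and records the range of values each tensor takes; scales and zero points are then derived from those ranges. The plain-Python sketch below shows that idea for a single tensor. The function names are hypothetical, and the actual TFLite calibrator is more involved than this.

```python
def calibrate_range(batches):
    """Track the smallest and largest value seen across all calibration batches."""
    lo, hi = float("inf"), float("-inf")
    for batch in batches:
        lo = min(lo, min(batch))
        hi = max(hi, max(batch))
    return lo, hi


def affine_params(lo, hi, qmin=-128, qmax=127):
    """Derive a scale and zero point that map [lo, hi] onto [qmin, qmax]."""
    lo, hi = min(lo, 0.0), max(hi, 0.0)     # the range must contain zero
    scale = (hi - lo) / (qmax - qmin)       # assumes hi > lo
    zero_point = round(qmin - lo / scale)
    return scale, zero_point


# Stand-in for a representative dataset: small batches of pixel values in [0, 1].
batches = [[0.0, 0.2, 0.9], [0.1, 1.0], [0.4, 0.6]]

lo, hi = calibrate_range(batches)
scale, zero_point = affine_params(lo, hi)
print(lo, hi, scale, zero_point)
```

If the representative samples miss part of the real input range, values outside the recorded range are clamped at inference time, which is why the samples should resemble deployment data.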


In [None]:
batch_size = 1     # Calibrate one image at a time (overrides the training batch size)
num_samples = 100  # Number of calibration batches drawn from the training set
input_name = model.input_names[0]
output_name = model.output_names[0]


def representative_data_gen():
  dataset = tf.data.Dataset.from_tensor_slices(train_images).batch(batch_size)
  for input_value in dataset.take(num_samples):
    # The model has a single input, so each sample is a one-entry dict keyed by its name.
    yield {input_name: tf.cast(input_value, dtype=tf.float32)}

tflite_save_path = 'tflite_dense.tflite'
quantize_mode = '8x16' # Or '8x8'
model_tflite = TFLiteModelWrapper(quantize_mode=quantize_mode,
                                  model=model,
                                  representative_dataset=representative_data_gen,
                                  tflite_save_path=tflite_save_path)


## Compare Baseline/Quantized Accuracy

To evaluate the performance of the quantized model, we will compare its accuracy to the accuracy of the original (baseline) model. We define a helper function `_accuracy_mnist_` that calculates the classification accuracy for the given model on the MNIST test dataset. 

We then use this function to calculate the accuracy of both the quantized TFLite model (`model_tflite`) and the original TensorFlow model (`model`). The results will help us understand the impact of quantization on the model's performance.


In [None]:

def _accuracy_mnist_(model, test_dataset, output_name='output_0', input_name=input_name):
  """Return the top-1 accuracy of `model` over the batches in `test_dataset`."""
  num_correct = 0
  num_samples = 0
  for x_batch, y_batch in test_dataset:
    y_pred = model({input_name: x_batch})
    # Some models return a dict of named outputs; select the logits in that case.
    if isinstance(y_pred, dict):
      y_pred = y_pred[output_name]
    num_samples += len(y_batch)
    num_correct += int(np.sum(np.argmax(y_pred, axis=1) == y_batch.numpy()))

  return num_correct / num_samples


acc = _accuracy_mnist_(model_tflite, test_dataset)
print("TFLite Quantized Accuracy", acc)

acc_orig = _accuracy_mnist_(model, test_dataset)
print("Baseline Model Accuracy", acc_orig)
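
Accuracy is one side of the comparison; the considerations at the top also mention model size. A minimal sketch for checking the on-disk size of a model file follows. The helper name is hypothetical, and the snippet uses a temporary stand-in file so that it runs on its own; in this notebook one would pass `tflite_save_path` instead, assuming the wrapper has written the file to that path.

```python
import os
import tempfile


def file_size_kib(path):
    """Return the on-disk size of the file at `path` in KiB."""
    return os.path.getsize(path) / 1024


# Stand-in for a model file so the sketch runs on its own; in the notebook,
# pass `tflite_save_path` instead.
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "demo.tflite")
    with open(demo_path, "wb") as f:
        f.write(bytes(2048))
    demo_size = file_size_kib(demo_path)

print(f"Size: {demo_size:.1f} KiB")
```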