# Pruning and Quantization on MNIST Digit Recognition

In this notebook, we demonstrate the application of pruning and quantization techniques to optimize a neural network model for the MNIST Digit Recognition task. Our objective is to create an efficient model with a reduced memory footprint while maintaining high accuracy.

## Overview
- **Task:** MNIST Digit Recognition
- **Model:** Neural Network with Dense layers
- **Layers to Prune:** Dense layers
- **Pencil Size:** 4 (for pruning)

## Tutorial Flow

1. **Pruning Using the `model.fit()` Method:** Prune a TensorFlow model with the standard `model.fit()` training loop. The goal is to sparsify the model while maintaining its accuracy.

2. **Quantization and TFLite Conversion:** The sparsified model is quantized and converted to TFLite format for deployment on edge devices.

3. **Compilation:** Compile the TFLite model using [femtocrux](https://femtocrux.femtosense.ai/en/latest/).

4. **Deployment:** Generate Program Files from the compiled model using [femtodriverpub](https://github.com/femtosense/femtodriverpub) and load them onto the SPU.

## Installation


In [None]:
# ! pip install femtoflow --quiet

## Imports

In [None]:
import tempfile
import os

import tensorflow as tf
import tensorflow_model_optimization as tfmot
import numpy as np
from tensorflow import keras

from femtoflow.quantization.quantize_tflite import TFLiteModelWrapper
from femtoflow.sparsity.prune import PruneHelper
from femtoflow.utils.plot import plot_prune_mask
from femtoflow.utils.metrics import calculate_sparsity, get_gzipped_model_size

In [None]:
import warnings 
warnings.filterwarnings('ignore')

## Load MNIST Dataset

In this section, we load the MNIST dataset, which is a popular dataset for handwritten digit recognition. It consists of grayscale images of handwritten digits, each of size 28x28 pixels, along with corresponding labels indicating the digit (0-9) represented by each image.

### Data Preparation

1. **Load MNIST Data:** We use the Keras built-in `mnist` module to load the dataset, which is split into training and testing sets.

2. **Normalize Data:** We normalize the input images by scaling the pixel values to the range [0, 1]. This is achieved by dividing each pixel value by 255, the maximum value of an 8-bit grayscale pixel.

3. **Flatten Images:** We reshape each 28x28 image into a 784-element vector, since the model is built only from Dense layers.

4. **Create TensorFlow Datasets:** We use the `tf.data.Dataset.from_tensor_slices` method to create TensorFlow datasets for both the training and testing sets and batch them with the `batch` method. The batch size is set to 1024, and `drop_remainder=True` discards the final incomplete batch of each set.

The resulting `train_dataset` and `test_dataset` are ready for use in training and evaluating the model, respectively.
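The preprocessing above can be sketched in plain NumPy, which also makes the effect of `drop_remainder=True` concrete. This is an illustration only: the images are random stand-ins with MNIST's 28x28 shape, and `batching_summary` is a hypothetical helper, not part of TensorFlow.

```python
import numpy as np

def batching_summary(num_samples, batch_size):
    """Return (full_batches, dropped_samples) for batch(..., drop_remainder=True)."""
    full_batches = num_samples // batch_size
    dropped = num_samples - full_batches * batch_size
    return full_batches, dropped

# Normalize and flatten stand-in uint8 images shaped like MNIST (values are synthetic).
fake_images = np.random.default_rng(0).integers(0, 256, size=(3, 28, 28), dtype=np.uint8)
flat = (fake_images / 255.0).reshape(fake_images.shape[0], -1)
print(flat.shape, flat.min() >= 0.0, flat.max() <= 1.0)

# MNIST has 60,000 training and 10,000 test images.
print(batching_summary(60_000, 1024))  # (58, 608)
print(batching_summary(10_000, 1024))  # (9, 784)
```

With a batch size of 1024, 608 training images and 784 test images are dropped, so any metric computed by iterating over `test_dataset` covers 9,216 of the 10,000 test images.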


In [None]:
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

num_samples_train, hx, hy = train_images.shape
train_images = tf.reshape(train_images, (num_samples_train, hx * hy))

num_samples_test, hx, hy = test_images.shape
test_images = tf.reshape(test_images, (num_samples_test, hx * hy))

BATCH_SIZE = 1024
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).batch(BATCH_SIZE, drop_remainder=True)
test_dataset  = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(BATCH_SIZE, drop_remainder=True)

## Model Training

In this section, we define the architecture of the neural network model that we will train and optimize for the MNIST digit recognition task. The model consists only of dense (fully connected) layers.

### Define Model Architecture

1. **Input Layer:** The input layer accepts flattened grayscale images, i.e., 784-element vectors (28x28 pixels).

2. **Dense Layers:** A sequence of four dense (fully connected) layers with 200, 100, 50, and 10 units. The final dense layer has 10 units, one for each possible digit (0-9). No activation functions are specified, so every layer uses Keras's default linear activation and the final layer outputs raw logits.

The defined model architecture is ready for training on the MNIST dataset.


In [None]:
# Define the model architecture.
model = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(784, )),
  tf.keras.layers.Dense(200),
  tf.keras.layers.Dense(100),
  tf.keras.layers.Dense(50),
  tf.keras.layers.Dense(10),
])
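The architecture above can be sanity-checked with a little arithmetic. The NumPy sketch below counts the kernel and bias parameters of the 784-200-100-50-10 stack (the total should match `model.count_params()`), and shows that, because no `Dense` layer specifies an activation, the whole stack computes a single affine map. `dense_param_count` and `forward` are illustrative names, not part of Keras or femtoflow, and the weights are random stand-ins.

```python
import numpy as np

def dense_param_count(layer_sizes):
    """Kernel + bias parameters for a stack of Dense layers, e.g. [784, 200, ...]."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))

sizes = [784, 200, 100, 50, 10]
print(dense_param_count(sizes))  # 182660

# Random stand-in weights with the same shapes as the Keras Dense kernels and biases.
rng = np.random.default_rng(0)
weights = [rng.standard_normal((n_in, n_out)) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
biases = [rng.standard_normal(n_out) for n_out in sizes[1:]]

def forward(x):
    for w, b in zip(weights, biases):
        x = x @ w + b  # Dense layer with the default linear activation
    return x

# Fold all layers into one kernel and one bias: x @ w_total + b_total.
w_total = np.eye(sizes[0])
b_total = np.zeros(sizes[0])
for w, b in zip(weights, biases):
    w_total = w_total @ w
    b_total = b_total @ w + b

x = rng.standard_normal((5, sizes[0]))
print(np.allclose(forward(x), x @ w_total + b_total))  # True
```

If a nonlinear classifier is intended, passing an activation such as `activation='relu'` to the hidden layers would prevent this collapse; the notebook as written keeps the linear stack.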


### Define Training-Related Parameters

Before training the model, we need to define various parameters and configurations that will be used during the training process:

1. **Optimizer:** We choose the `'adam'` optimizer, which is an adaptive optimization algorithm commonly used for training neural networks.

2. **Loss Function:** We use the `SparseCategoricalCrossentropy` loss function with `from_logits=True`. This loss function is suitable for multi-class classification tasks, such as digit recognition. The `from_logits=True` option indicates that the model's output is not yet passed through the softmax activation function.

3. **Metrics:** We monitor the `'accuracy'` metric during training to evaluate the performance of the model.

4. **Epochs:** The number of epochs is set to `2`. An epoch is a complete pass through the entire training dataset.

5. **Validation Split:** We specify a validation split of `0.1`, which means that 10% of the training data will be used as a validation set to evaluate the model's performance during training.

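6. **Batch Size:** We set the batch size to `512`, which determines the number of samples used in each update of the model weights. This value applies to the baseline training below, which is fed the image and label tensors directly; the pruning step later trains on `train_dataset`, which is already batched at 1024.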

With these parameters defined, we can proceed with training the model on the MNIST dataset.
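To make `from_logits=True` concrete, here is a minimal NumPy sketch of sparse categorical cross-entropy computed from raw logits: a numerically stable log-softmax followed by picking the log-probability of the true label. `sparse_ce_from_logits` is a hypothetical name; it is meant to mirror the Keras loss with its default mean-over-batch reduction, not to replace it.

```python
import numpy as np

def sparse_ce_from_logits(logits, labels):
    """Mean sparse categorical cross-entropy computed from raw (pre-softmax) logits."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.int64)
    # Numerically stable log-softmax: subtract each row's max before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Integer labels select the log-probability of the true class in each row.
    picked = log_probs[np.arange(len(labels)), labels]
    return -picked.mean()

print(sparse_ce_from_logits([[0.0, 0.0]], [0]))        # ~0.6931 (ln 2)
print(sparse_ce_from_logits([[2.0, 1.0, 0.1]], [0]))
```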


In [None]:
optimizer = 'adam'
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = ['accuracy']
epochs = 2
validation_split = 0.1
batch_size = 512


### Train the Digit Classification Model

With the model architecture defined and the training-related parameters configured, we can proceed with training the digit classification model on the MNIST dataset.

1. **Compile the Model:** We use the `compile` method to configure the model for training. We pass the optimizer, loss function, and evaluation metrics as arguments.

2. **Fit the Model:** We use the `fit` method to train the model on the training data. We provide the training images and corresponding labels, the batch size, the number of epochs, and the validation split. The `fit` method will perform forward and backward propagation, update the model weights, and monitor the training and validation accuracy over the specified number of epochs.

During training, the model's performance is evaluated on both the training and validation sets at the end of each epoch. The training accuracy indicates how well the model is fitting the training data, while the validation accuracy indicates how well the model generalizes to unseen data.
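The settings above determine how much data each epoch actually trains on. A short worked calculation, assuming the 60,000-image MNIST training set and the documented Keras behavior of taking validation samples from the end of the arrays before shuffling (the variable names are illustrative and chosen so they do not overwrite the notebook's own):

```python
import math

num_train_images = 60_000
val_fraction = 0.1    # same value as validation_split above
fit_batch_size = 512  # same value as batch_size above

# The last 10% of the arrays is held out for validation.
num_fit = int(math.floor(num_train_images * (1.0 - val_fraction)))
num_val = num_train_images - num_fit
# The final, partial batch still counts as a training step.
steps_per_epoch = math.ceil(num_fit / fit_batch_size)

print(num_fit, num_val, steps_per_epoch)  # 54000 6000 106
```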


In [None]:
# Train the digit classification model
model.compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)

model.fit(
  train_images,
  train_labels,
  batch_size=batch_size,
  epochs=epochs,
  validation_split=validation_split,
)

### Get Baseline Model Accuracy

After training the digit classification model, it is important to evaluate its performance on the test dataset. The test dataset consists of images and labels that were not used during training, so it provides an unbiased evaluation of the model's ability to generalize to new data.

We use the `evaluate` method to calculate the test accuracy of the model. The method takes the test images and corresponding labels as input, performs forward propagation through the model, and compares the predicted labels with the true labels to calculate the accuracy.


In [None]:
_, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

## Prune the Model

Pruning is a model optimization technique that reduces the number of parameters in the model by setting some of them to zero. This results in a sparse model that requires less memory and computational resources. In this section, we will apply structured pruning to the digit classification model to obtain a sparser version of the model.

### Create a `PruneHelper` Instance

To facilitate the pruning process, we create an instance of the `PruneHelper` class, which provides utility functions for applying structured pruning to specific layers of the model. It is configured with three parameters:

1. **Pencil Size:** The pencil size sets the granularity of the pruning mask: weights are pruned in groups ("pencils") of this many elements rather than individually. Supported pencil sizes are `4` and `8`.

2. **Pencil Pooling Type:** The pooling operation used to reduce each pencil to a single score when computing the pruning mask. The options are `'AVG'` (average pooling) and `'MAX'` (max pooling).

3. **Prune Scheduler:** The schedule that controls how sparsity evolves over training. The options are `'linear'` (sparsity increases linearly), `'constant'` (sparsity stays fixed), and `'poly_decay'` (sparsity ramps from the initial to the final value along a polynomial-decay curve whose shape is set by `power`).

We configure these parameters and create an instance of the `PruneHelper` class.
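The pencil and pooling parameters are easiest to grasp on a toy kernel. The sketch below is a hypothetical NumPy illustration of pencil-structured magnitude pruning, not femtoflow's `PruneHelper` implementation. It assumes pencils run along the input axis of a Keras `Dense` kernel and that the lowest-scoring pencils are pruned first; `pencil_mask` and its arguments are illustrative names.

```python
import numpy as np

def pencil_mask(kernel, sparsity, pencil_size=4, pooling="AVG"):
    """Hypothetical pencil-pruning mask, for illustration only.

    Assumes a Dense kernel of shape (in_features, out_features) and pencils made of
    `pencil_size` consecutive input rows within one output column.
    """
    n_in, n_out = kernel.shape
    if n_in % pencil_size:
        raise ValueError("in_features must be divisible by pencil_size")
    # Group weights into pencils: (pencils_per_column, pencil_size, out_features).
    pencils = np.abs(kernel).reshape(n_in // pencil_size, pencil_size, n_out)
    # Pool each pencil down to a single importance score.
    if pooling == "AVG":
        scores = pencils.mean(axis=1)
    elif pooling == "MAX":
        scores = pencils.max(axis=1)
    else:
        raise ValueError("pooling must be 'AVG' or 'MAX'")
    # Prune the lowest-scoring pencils until the target sparsity is reached.
    num_pruned = int(round(sparsity * scores.size))
    keep = np.ones(scores.size, dtype=bool)
    keep[np.argsort(scores, axis=None, kind="stable")[:num_pruned]] = False
    # Broadcast each pencil's keep/prune decision back to its pencil_size weights.
    return np.repeat(keep.reshape(scores.shape), pencil_size, axis=0)

# One output column with two pencils: a single large weight vs. four medium ones.
kernel = np.array([[10.0], [0.0], [0.0], [0.0], [3.0], [3.0], [3.0], [3.0]])
print(pencil_mask(kernel, 0.5, pooling="AVG").ravel())  # first pencil pruned
print(pencil_mask(kernel, 0.5, pooling="MAX").ravel())  # second pencil pruned
```

The toy kernel shows why the pooling choice matters: AVG favors pencils whose weights are uniformly moderate, while MAX keeps a pencil as long as one of its weights is large.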


In [None]:
pencil_size = 4
pencil_pooling_type = 'AVG'  # or 'MAX'
prune_scheduler = 'linear'   # or 'constant', 'poly_decay'
prune_helper = PruneHelper(pencil_size=pencil_size,
                           pencil_pooling_type=pencil_pooling_type,
                           prune_scheduler=prune_scheduler)

### Apply Pruning Wrappers to the Model

The `PruneHelper` class applies pruning wrappers to the layers of the model that we wish to sparsify. The output is a new model, `model_to_prune`, with pruning wrappers applied to the specified layers.

To configure the pruning process, we define the following additional parameters:

1. **Layers to Prune:** A list specifying the types of layers we want to prune. In this case, we choose to prune only the Dense layers.

2. **Initial Sparsity:** The initial sparsity level of the model at the beginning of the pruning process.

3. **Final Sparsity:** The target sparsity level of the model at the end of the pruning process.

4. **Begin Step:** The training step at which pruning begins.

5. **End Step:** The training step at which pruning ends. This is calculated based on the total number of training steps per epoch and the total number of epochs.

6. **Prune Frequency:** The frequency (in number of training steps) at which the pruning mask is updated.

7. **Power:** The exponent used in the polynomial decay schedule for sparsity (only applicable if the prune scheduler is set to 'poly_decay').

We pass these parameters to the `PruneHelper` instance to obtain the model with pruning wrappers applied.

The `model_to_prune` now has pruning wrappers applied to the Dense layers, and it is ready for training with sparsity. The goal is to achieve a high level of sparsity while preserving the model's accuracy.
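To see what these settings imply, the sketch below evaluates a sparsity schedule at each mask-update step. It assumes femtoflow follows the semantics of tfmot's `PolynomialDecay` (sparsity moves from `initial_sparsity` to `final_sparsity` as `(1 - progress) ** power` shrinks, and the mask is recomputed every `prune_frequency` steps between `begin_step` and `end_step`), that `'linear'` is the same formula with power 1, and that `'constant'` holds `final_sparsity`. `target_sparsity` and `mask_update_steps` are hypothetical helpers.

```python
def target_sparsity(step, scheduler, initial, final, begin_step, end_step, power=3):
    """Hypothetical schedule; 'poly_decay' follows tfmot's PolynomialDecay formula."""
    if scheduler == "constant":
        return final  # assumption: the constant schedule holds the final sparsity
    if scheduler == "linear":
        exponent = 1
    elif scheduler == "poly_decay":
        exponent = power
    else:
        raise ValueError(f"unknown scheduler: {scheduler}")
    # Fraction of the pruning window that has elapsed, clipped to [0, 1].
    progress = min(1.0, max(0.0, (step - begin_step) / (end_step - begin_step)))
    return final + (initial - final) * (1.0 - progress) ** exponent

def mask_update_steps(begin_step, end_step, frequency):
    """Steps in [begin_step, end_step] at which the mask would be recomputed."""
    return list(range(begin_step, end_step + 1, frequency))

steps_per_epoch = 60_000 // 1024  # 58 full batches with drop_remainder=True
end_step = steps_per_epoch * 2    # 116, matching len(train_dataset) * epochs
updates = mask_update_steps(0, end_step, 100)
print(updates)                    # [0, 100]
for sched in ("constant", "linear", "poly_decay"):
    print(sched, [round(target_sparsity(s, sched, 0.2, 0.6, 0, end_step), 3) for s in updates])
```

Under these assumptions, the mask is only recomputed at steps 0 and 100, where the linear schedule targets roughly 0.54 rather than the full 0.6. The per-layer sparsity printed later in the notebook is the authoritative figure.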



In [None]:
"""
Define additional parameters for pruning
"""
layers_to_prune = [tf.keras.layers.Dense]  # Layers we want to prune
initial_sparsity = 0.2
final_sparsity = 0.6
begin_step = 0
end_step = len(train_dataset) * epochs  # Total training steps: batches per epoch x epochs
prune_frequency = 100
power = 3  # Only used by the 'poly_decay' scheduler

model_to_prune = prune_helper(model=model,
                              layers_to_prune=layers_to_prune,
                              initial_sparsity=initial_sparsity,
                              final_sparsity=final_sparsity,
                              begin_step=begin_step,
                              end_step=end_step,
                              prune_frequency=prune_frequency,
                              power=power)

In [None]:
model_to_prune.layers

### Train the Model with Sparsity

After applying the pruning wrappers to the model, we proceed with training the model to introduce sparsity into the weights. During training, the sparsity level of the weights is gradually increased according to the pruning schedule defined earlier.

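To achieve this, we pass the `tfmot.sparsity.keras.UpdatePruningStep()` callback to training. This callback updates the pruning step counter of the wrappers at every training batch; the wrappers use that counter, together with the pruning schedule, to decide when to recompute the masks. tfmot requires this callback whenever a model with pruning wrappers is trained.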

We compile the model and use the `.fit()` function to start training with sparsity. The `UpdatePruningStep()` callback is included in the list of callbacks for the `.fit()` function.

During training, the pruning mask is updated based on the prune_frequency parameter, and the sparsity level of the weights is progressively increased. At the end of training, we obtain a pruned model with sparse weights while aiming to maintain a similar level of accuracy as the original model.
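Conceptually, training with sparsity means the forward pass uses masked weights and the mask is re-applied after every update, so pruned weights stay at zero. The NumPy sketch below illustrates that idea on a toy linear regression with a fixed 50% mask. It is a simplification: tfmot's wrappers store the mask as a layer variable and recompute it on the pruning schedule rather than keeping it fixed.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((256, 8))
true_w = rng.standard_normal((8, 1))
y = x @ true_w

w = rng.standard_normal((8, 1))
# Fixed illustrative mask: 1 keeps a weight, 0 prunes it (50% sparsity).
mask = np.array([[1], [1], [0], [0], [1], [1], [0], [0]], dtype=float)
lr = 0.1
for step in range(200):
    w_eff = w * mask                                # forward pass uses masked weights
    grad = 2 * x.T @ (x @ w_eff - y) / len(x)       # MSE gradient w.r.t. the effective weights
    w = (w - lr * grad) * mask                      # update, then re-apply the mask

print(np.count_nonzero(w), "of", w.size, "weights are non-zero")
```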



In [None]:
model_to_prune.compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)
model_to_prune.fit(train_dataset,
                   epochs=epochs, 
                   callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])

### Evaluate the Pruned Model

After training the model with sparsity, it is important to evaluate the accuracy of the pruned model to ensure that the pruning process has not adversely impacted the model's performance. To do this, we use the `.evaluate()` function on the pruned model and test it against the test dataset.

The `prune_accuracy` represents the accuracy of the pruned model on the test dataset. Ideally, the pruned model should have an accuracy close to that of the original baseline model while benefiting from the reduced model size and sparsity.

Comparing the accuracy of the pruned model (`prune_accuracy`) with the baseline model accuracy (`baseline_model_accuracy`) allows us to assess the impact of pruning on the model's performance. A successful pruning process should achieve a good balance between accuracy and sparsity.

In [None]:
_, prune_accuracy = model_to_prune.evaluate(
   test_images, test_labels, verbose=0)

print('Pruned test accuracy:', prune_accuracy)

### Strip Pruning Wrapper to Obtain the Final Pruned Model

The pruned model (`model_to_prune`) contains pruning wrappers applied to the layers we want to sparsify. These wrappers are used to introduce and manage sparsity during the training process. After pruning is complete, we can remove these wrappers to obtain the final pruned model with the sparse weights (i.e., the final pruning mask applied to the weights).

To remove the pruning wrappers, we use the `tfmot.sparsity.keras.strip_pruning()` function. This function takes the pruned model as input and returns the pruned model with the sparse weights, but without the pruning wrappers.


In [None]:
model_pruned_stripped = tfmot.sparsity.keras.strip_pruning(model_to_prune)


### Visualize Pruned Weights

After pruning, it is helpful to visualize the pruned weights to observe the sparsity pattern achieved by the pruning process. We can plot the pruned weights as a binary mask, where white pixels represent non-zero (preserved) weights and black pixels represent zero (pruned) weights.

To do this, we use the `plot_prune_mask` function, which takes the pruned weight matrix and some visualization parameters as inputs. We iterate over the trainable weights of the stripped model, skip the 1-D bias vectors, and plot the binary mask of each 2-D kernel matrix.


In [None]:
trainable_weights_dict = {weight.name: weight.numpy() for weight in model_pruned_stripped.trainable_weights}
trainable_weights_dict.keys()

In [None]:
base_path = f'mnist_dense_pencil_{pencil_size}'
os.makedirs(base_path, exist_ok=True)
max_len_axis = 24
for weight_name, weight in trainable_weights_dict.items():
    if weight.ndim != 2:
        continue  # Skip 1-D bias vectors; the mask plot is meant for 2-D kernels
    title = f"{weight_name}-shape-opxip-{weight.T.shape}"
    save_path = f"{base_path}/{weight_name.replace('/', '-')}-prune-mask.png"
    plot_prune_mask(data=weight, axis_stride=pencil_size, max_xlen=max_len_axis, max_ylen=max_len_axis, title=title, save_path=save_path)

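### Calculate Sparsity Metrics

We use `calculate_sparsity` to report the sparsity of every trainable weight. The kernels should reflect the pruning schedule, while bias vectors are typically not pruned and should show little or no sparsity.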

In [None]:
sparsity_dict = {}
for layer_name, layer_weight in trainable_weights_dict.items():
    sparsity_dict[layer_name] = calculate_sparsity(layer_weight)
print('sparsity_dict', sparsity_dict)
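For reference, here is a minimal sketch of a per-tensor sparsity measure: the share of entries that are exactly zero. `fraction_of_zeros` is a hypothetical helper, and it is an assumption that femtoflow's `calculate_sparsity` computes the same quantity.

```python
import numpy as np

def fraction_of_zeros(weight):
    """Sparsity of one tensor: the share of entries that are exactly zero."""
    weight = np.asarray(weight)
    return float(np.count_nonzero(weight == 0)) / weight.size

example = np.array([[0.0, 0.5], [0.0, 0.0], [1.2, -0.3], [0.0, 0.0]])
print(fraction_of_zeros(example))  # 0.625
```

If that assumption holds, `{name: fraction_of_zeros(w) for name, w in trainable_weights_dict.items()}` should give values comparable to `sparsity_dict`.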

## Quantize the Pruned Model using TFLite

After pruning, the next step is to quantize the pruned model to further reduce its memory footprint and improve its computational efficiency. Quantization involves converting the weights and activations from floating-point representation to fixed-point representation. In this section, we will quantize the pruned model using TensorFlow Lite (TFLite).

### Quantize using `TFLiteModelWrapper()` class

To perform quantization, we will use the `TFLiteModelWrapper()` class. This class provides a convenient interface for converting a TensorFlow model to a quantized TFLite model. We first define a `representative_data_gen` function that generates representative data for quantization calibration. This data should reflect the typical input distribution that the model will encounter during inference.

We can choose either `'8x8'` quantization mode (Int8 weights and Int8 activations) or `'8x16'` quantization mode (Int8 weights and Int16 activations). We provide the pruned model, representative dataset, and quantization mode to the `TFLiteModelWrapper()` class and specify the save path for the quantized TFLite model (`tflite_save_path`).

The `model_tflite` is the quantized TFLite model that is optimized for deployment on resource-constrained devices. It benefits from both the sparsity achieved during pruning and the reduced memory footprint achieved through quantization.
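Why do pruning and quantization combine well? TFLite's integer quantization specification represents weights symmetrically with a zero point of 0, so a weight that was pruned to exactly zero is still exactly zero after quantization. The NumPy sketch below (using a hypothetical `quantize_symmetric` helper, per-tensor for simplicity) shows that, and then compares the rounding error of 8-bit and 16-bit activations, which is the difference between the `'8x8'` and `'8x16'` modes. It quantizes activations symmetrically in both cases for simplicity; TFLite's int8 activations actually use an asymmetric zero point.

```python
import numpy as np

def quantize_symmetric(values, num_bits):
    """Symmetric per-tensor quantization: zero point 0, scale from the largest magnitude."""
    qmax = 2 ** (num_bits - 1) - 1  # 127 for int8, 32767 for int16
    scale = np.abs(values).max() / qmax
    q = np.clip(np.round(values / scale), -qmax, qmax).astype(np.int32)
    return q, scale

rng = np.random.default_rng(0)

# Weights are int8 in both modes; zero out half of them to mimic pruning.
weights = rng.standard_normal(16)
weights[::2] = 0.0
q_w, s_w = quantize_symmetric(weights, 8)
print(np.all(q_w[::2] == 0))  # True: pruned weights stay exactly zero

# Activations: rounding error of int8 vs. int16 on the same values.
activations = rng.standard_normal(1000)
q8, s8 = quantize_symmetric(activations, 8)
q16, s16 = quantize_symmetric(activations, 16)
err8 = np.abs(q8 * s8 - activations).max()
err16 = np.abs(q16 * s16 - activations).max()
print(err8, err16)
```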


In [None]:
rep_batch_size = 1  # Separate name so the training batch_size defined earlier is not overwritten
num_samples = 100   # Number of calibration batches drawn from the training images
input_name = model_pruned_stripped.input_names[0]
output_name = model_pruned_stripped.output_names[0]

def representative_data_gen():
  for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(rep_batch_size).take(num_samples):
    # Model has only one input, so each data point has one element.
    yield {input_name: tf.cast(input_value, dtype=tf.float32)}

tflite_save_path = 'tflite_dense.tflite'
quantize_mode = '8x16' # or '8x8'
model_tflite = TFLiteModelWrapper(quantize_mode=quantize_mode,
                                  model=model_pruned_stripped,
                                  representative_dataset=representative_data_gen,
                                  tflite_save_path=tflite_save_path)


### Check Performance of TFLite Model

After quantizing the pruned model, we want to evaluate its performance on the test dataset to ensure that the quantization process did not adversely affect the model's accuracy. We will define a helper function `_accuracy_mnist_` that calculates the classification accuracy for a given model and dataset.

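The `_accuracy_mnist_` function takes the model, test dataset, and optional names of the output and input tensors as arguments. It performs inference on each batch of the test dataset and compares the model's predictions with the ground-truth labels to calculate the overall accuracy. Because `test_dataset` was batched with `drop_remainder=True`, this helper scores 9,216 of the 10,000 test images (nine full batches of 1024), so its numbers can differ slightly from `model.evaluate` on the full test set.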

We will then calculate the accuracy for the TFLite pruned and quantized model (`model_tflite`), the original baseline model (`model`), and the pruned model without quantization (`model_pruned_stripped`). Comparing these accuracies will provide insight into the effectiveness of the pruning and quantization processes.


In [None]:

def _accuracy_mnist_(model, test_dataset, output_name='output_0', input_name=input_name):
  """Return the top-1 accuracy of `model` over all batches in `test_dataset`."""
  num_correct = 0
  num_samples = 0
  for x_batch, y_batch in test_dataset:
    y_pred = model({input_name: x_batch})
    # The TFLite wrapper may return a dict of named outputs; Keras models return a tensor.
    if isinstance(y_pred, dict):
      y_pred = y_pred[output_name]
    num_samples += len(y_batch)
    num_correct += int(np.sum(np.argmax(y_pred, axis=1) == y_batch.numpy()))

  return num_correct / num_samples


acc = _accuracy_mnist_(model_tflite, test_dataset)
print("TFLite Pruned+Quantized Accuracy", acc)

acc_orig = _accuracy_mnist_(model, test_dataset)
print("Baseline Model Accuracy", acc_orig)

acc_model_pruned_stripped = _accuracy_mnist_(model_pruned_stripped, test_dataset)
print("Pruned Model Accuracy", acc_model_pruned_stripped)

### Check Model Sizes

After completing the pruning and quantization processes, it is important to check the model sizes to confirm that the optimizations have effectively reduced the memory footprint of the models. We will compare the sizes of three models: the original baseline model, the pruned model without quantization, and the TFLite pruned and quantized model.

To calculate the sizes, we will save each model to a temporary file and then use gzip compression to measure the compressed size. This will give us an estimate of the storage space required for each model.
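Gzip is used because it exposes the benefit of sparsity: runs of zero bytes compress very well, whereas dense floating-point weights barely compress. The sketch below demonstrates this on a randomly initialized 784x200 float32 kernel with about 60% of its entries zeroed. `gzipped_size` is a hypothetical helper that compresses an array's raw bytes in memory, whereas femtoflow's `get_gzipped_model_size` is called on saved model files.

```python
import gzip

import numpy as np

def gzipped_size(array):
    """Size in bytes of the gzip-compressed raw bytes of an array."""
    return len(gzip.compress(np.ascontiguousarray(array).tobytes()))

rng = np.random.default_rng(0)
dense = rng.standard_normal((784, 200)).astype(np.float32)
pruned = dense.copy()
pruned[rng.random(dense.shape) < 0.6] = 0.0  # zero out roughly 60% of the weights

print("raw bytes:", dense.nbytes)
print("gzipped dense: ", gzipped_size(dense))
print("gzipped pruned:", gzipped_size(pruned))
```

Both arrays occupy the same number of raw bytes; only the compressed size reveals the sparsity, which is why the comparison above gzips every model file.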


In [None]:
_, keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
print('Saved baseline model to:', keras_file)

_, pruned_keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model_pruned_stripped, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)

pruned_tflite_file = tflite_save_path
print("TFLite File was generated at:", pruned_tflite_file)

In [None]:
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped pruned Keras model: %.2f bytes" % (get_gzipped_model_size(pruned_keras_file)))
print("Size of gzipped pruned TFLite model: %.2f bytes" % (get_gzipped_model_size(pruned_tflite_file)))

## Compiling the TFLite Model Using [femtocrux](https://femtocrux.femtosense.ai/en/latest/) and Generating the Memory Image Bitfile

Next, we compile the generated TFLite model with femtocrux. Compiling the model with femtocrux is a necessary step in deploying the TensorFlow model on Femtosense's SPU.

Docker must be installed for this step. We instantiate a `CompilerClient`, which lets us make API calls to Femtosense's compiler, and call its `compile` method to produce the memory image bitfile.

In [None]:
from femtocrux import CompilerClient, TFLiteModel
client = CompilerClient()

In [None]:
flatbuffer = model_tflite.instance.flatbuffer
signature_name = model_tflite.instance.signature_name

In [None]:
bitstream = client.compile(
    TFLiteModel(flatbuffer=flatbuffer, signature_name=signature_name)
)
# Write the compiled bitfile to disk for the next steps
with open('my_bitfile.zip', 'wb') as f:
    f.write(bitstream)

## Generating Program Files and Loading Them onto the SPU
The memory image bitfile zip can then be converted to Program Files using [femtodriverpub](https://github.com/femtosense/femtodriverpub).

Once generated, these Program Files can be transferred to an SD card, which can then be inserted into Femtosense's SPU. 


#### Install `femtodriverpub`
##### Step One - Clone femtodriverpub from GitHub and Install It


In [None]:
! git clone https://github.com/femtosense/femtodriverpub.git && cd femtodriverpub && pip install -e .

#### Unzip the `my_bitfile.zip` Archive


In [None]:
! rm -rf 'my_bitfile'
! unzip 'my_bitfile.zip' -d 'my_bitfile'
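If the `unzip` command is not available (for example, on some Windows setups), the same step can be done with Python's standard library. `unzip_fresh` is a hypothetical helper that mirrors the two shell commands above: it removes any previous output directory, then extracts the archive into it.

```python
import shutil
import zipfile
from pathlib import Path

def unzip_fresh(archive, dest):
    """Python equivalent of `rm -rf dest` followed by `unzip archive -d dest`."""
    dest = Path(dest)
    shutil.rmtree(dest, ignore_errors=True)  # start from an empty directory
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest)
    # Return the extracted file paths relative to dest, for a quick check.
    return sorted(p.relative_to(dest).as_posix() for p in dest.rglob("*") if p.is_file())

# In the notebook: unzip_fresh('my_bitfile.zip', 'my_bitfile')
```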

#### Generate the Program Files to load onto SPU

In [None]:
! python femtodriverpub/femtodriverpub/run/sd_from_femtocrux.py 'my_bitfile'

#### The Program Files should be generated in the `apb_records` folder


In [None]:
! ls 'apb_records'

The contents of `apb_records` can be copied to an SD card, which can then be inserted into the SPU. Congratulations, your model is now ready to deploy on the SPU!
