# Simple Pruning Example (with Custom Training Loop)

This section walks through a simple example of pruning a neural network with a custom training loop. We apply structured pruning to a Keras model using the `PruneHelper` class from `femtoflow`, train the wrapped model in a hand-written training loop, and evaluate its accuracy. The goal is to show how structured pruning fits into a custom training pipeline.


## Installation

In [None]:
# ! pip install femtoflow --quiet

## Imports

In [None]:
import tempfile
import os

import tensorflow as tf
import tensorflow_model_optimization as tfmot
import numpy as np
from tensorflow import keras

from femtoflow.sparsity.prune import PruneHelper
from femtoflow.utils.plot import plot_prune_mask

In [None]:
import warnings 
warnings.filterwarnings('ignore')

## MNIST Dataset Download

In this example, we will use the MNIST dataset, which is a widely used dataset for handwritten digit recognition. The dataset consists of 60,000 training images and 10,000 test images, each of which is a grayscale image with a size of 28x28 pixels. Each image is labeled with the corresponding digit (0-9) that it represents.

We will download the MNIST dataset, normalize the images so that each pixel value is between 0 and 1, and then prepare the dataset for training and evaluation. We will create TensorFlow Datasets (`tf.data.Dataset`) for the training and test sets, and batch the data using a batch size of 1024.


In [None]:
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

BATCH_SIZE = 1024
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).batch(BATCH_SIZE, drop_remainder=True)
test_dataset  = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(BATCH_SIZE, drop_remainder=True)
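Because both datasets are batched with `drop_remainder=True`, the final partial batch of each is discarded. A quick arithmetic check, using only the dataset sizes stated above, shows how many batches and samples that leaves:

```python
# Batch arithmetic for drop_remainder=True (sizes taken from the text above).
BATCH_SIZE = 1024
NUM_TRAIN, NUM_TEST = 60_000, 10_000

train_batches, train_dropped = divmod(NUM_TRAIN, BATCH_SIZE)
test_batches, test_dropped = divmod(NUM_TEST, BATCH_SIZE)

print(f"train: {train_batches} batches, {train_dropped} samples dropped")
print(f"test:  {test_batches} batches, {test_dropped} samples dropped")
```

Calls that take the raw arrays, such as `model.evaluate(test_images, test_labels)`, still see every sample; only code that iterates over `train_dataset` or `test_dataset` is affected.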

## Model Training

In this section, we will define and train a convolutional neural network (CNN) for digit classification using the MNIST dataset. The model architecture consists of an input layer, a reshape layer, a Conv2D layer, a MaxPooling2D layer, a Flatten layer, and three Dense layers.

### Model Definition

We use the Keras Sequential API to define the model architecture. The input layer accepts grayscale images of size 28x28 pixels, and the Reshape layer adds a channel dimension. The Conv2D layer applies 12 filters with a 3x3 kernel and a ReLU activation. The MaxPooling2D layer reduces the spatial dimensions by taking the maximum value from each 2x2 window. The Flatten layer flattens the 3D output into a 1D array. The final three Dense layers contain 100, 50, and 10 units, respectively, and have no activation, so the last layer outputs raw logits.


In [None]:
# Define the model architecture.
model = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(28, 28)),
  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(100),
  tf.keras.layers.Dense(50),
  tf.keras.layers.Dense(10)
])
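To see where the weights live before pruning, the layer shapes and parameter counts can be worked out by hand. The sketch below is plain arithmetic, assuming the Keras defaults used above (`'valid'` padding and stride 1 for `Conv2D`, and a pooling stride equal to the pool size):

```python
# Shape and parameter bookkeeping for the architecture above (pure arithmetic).
conv_out = 28 - 3 + 1                  # 3x3 kernel, 'valid' padding -> 26
pool_out = conv_out // 2               # 2x2 max pooling -> 13
flat_units = pool_out * pool_out * 12  # 12 feature maps flattened

conv_params = 3 * 3 * 1 * 12 + 12      # kernel weights + biases

dense_params = []
fan_in = flat_units
for units in (100, 50, 10):
    dense_params.append(fan_in * units + units)  # kernel + bias
    fan_in = units

total_params = conv_params + sum(dense_params)
print(flat_units, conv_params, dense_params, total_params)
```

The first Dense layer holds roughly 97% of all parameters, which is why restricting pruning to Dense layers still targets most of the model.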

### Define Training Related Params

Before we start training the model, we need to define some training-related parameters such as the optimizer, loss function, evaluation metrics, number of epochs, validation split, and batch size.


In [None]:
optimizer = 'adam'
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = ['accuracy']
epochs = 2
validation_split = 0.1
batch_size = 512
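The last Dense layer has no activation, so the model outputs raw logits; that is why the loss is created with `from_logits=True`. As a minimal sketch of what this loss computes for a single example (the helper name `sparse_ce_from_logits` is hypothetical, and by default the Keras loss additionally averages over the batch):

```python
import math

def sparse_ce_from_logits(logits, label):
    """Cross-entropy of one example, computed directly from raw logits."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

# Ten equal logits: the model is maximally uncertain over 10 classes.
uniform_loss = sparse_ce_from_logits([0.0] * 10, label=3)
# One dominant logit on the correct class: the loss is close to zero.
confident_loss = sparse_ce_from_logits([0.0, 8.0, 0.0], label=1)
print(uniform_loss, confident_loss)
```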

### Train the Digit Classification Model

After defining the model architecture and training-related parameters, we are ready to train the digit classification model. We will compile the model using the optimizer, loss function, and evaluation metrics defined earlier. Then, we will start the training process using the `fit` method.


In [None]:
# Train the digit classification model
model.compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)

model.fit(
  train_images,
  train_labels,
  batch_size=batch_size,
  epochs=epochs,
  validation_split=validation_split,
)

In [None]:
_, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

## Prune the Model

After training the model, we will perform structured pruning to sparsify the model's weights. For this purpose, we will use the `PruneHelper` class, which provides utilities for applying structured pruning to the model.

### Instantiate `PruneHelper`

We first create a `PruneHelper` instance, `prune_helper`, by specifying parameters that control the pruning process, such as `pencil_size`, `pencil_pooling_type`, and `prune_scheduler`.


In [None]:
pencil_size = 4
pencil_pooling_type = 'AVG'
prune_scheduler = 'linear'  # 'constant' # 'poly_decay'
prune_helper = PruneHelper(pencil_size=pencil_size,
                           pencil_pooling_type=pencil_pooling_type,
                           prune_scheduler=prune_scheduler,
                           min_parameter_thresh=0)
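The `pencil_size` and `pencil_pooling_type` arguments suggest block-wise magnitude pruning: weights are grouped into "pencils" of `pencil_size` consecutive entries, each pencil gets a pooled score, and the lowest-scoring pencils are zeroed together. The sketch below illustrates that idea in plain Python. It is an assumption about the general technique, not femtoflow's implementation: the helper `pencil_prune`, the choice to group along rows, and the toy weights are all hypothetical.

```python
def pencil_prune(weights, pencil_size, sparsity):
    """Zero the lowest-scoring pencils (runs of `pencil_size` entries per row).

    Each pencil is scored by the average absolute value of its entries,
    mimicking an 'AVG' pooling type.
    """
    scores = []  # (average |w|, row index, start column)
    for r, row in enumerate(weights):
        for start in range(0, len(row), pencil_size):
            pencil = row[start:start + pencil_size]
            avg_magnitude = sum(abs(w) for w in pencil) / len(pencil)
            scores.append((avg_magnitude, r, start))

    num_to_prune = int(len(scores) * sparsity)
    pruned = [list(row) for row in weights]  # copy, leave the input intact
    for _, r, start in sorted(scores)[:num_to_prune]:
        for c in range(start, min(start + pencil_size, len(pruned[r]))):
            pruned[r][c] = 0.0
    return pruned

toy_weights = [
    [0.9, -0.8, 0.7, 0.6, 0.01, 0.02, -0.01, 0.03],
    [0.05, 0.04, -0.06, 0.02, 0.5, -0.4, 0.3, 0.2],
]
pruned = pencil_prune(toy_weights, pencil_size=4, sparsity=0.5)
for row in pruned:
    print(row)
```

Zeroing whole pencils rather than individual weights produces a regular sparsity pattern, which is what distinguishes structured from unstructured pruning.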
                             

### Apply Pruning to the Model

To apply structured pruning to the model, we will call the `PruneHelper` instance with additional parameters related to the pruning process. The `PruneHelper` class will apply pruning masks to the specified layers of the model that we want to sparsify.

#### Define additional parameters for pruning
- `layers_to_prune`: A list of Keras layer classes that we want to prune. 
  In this example, we want to prune only the dense layers, so we specify `[tf.keras.layers.Dense]`.
- `initial_sparsity`: The initial sparsity level (fraction of weights to be set to zero) 
  at the start of the pruning process. We set it to `0.2`.
- `final_sparsity`: The final sparsity level (fraction of weights to be set to zero) 
  at the end of the pruning process. We set it to `0.6`.
- `begin_step`: The step at which to start pruning. We set it to `0`.
- `end_step`: The step at which to end pruning. We calculate it as the total number 
  of training steps, which is `len(train_dataset) * epochs`.
- `prune_frequency`: The frequency (in number of steps) at which to update the pruning mask. 
  We set it to `100`.
- `power`: The exponent for polynomial decay of the sparsity level. This parameter is used 
  when the `prune_scheduler` is set to `'poly_decay'`. We set it to `3`.
- `force_skip_layers`: A list of layer names to exclude from pruning. In this example, we want 
  to exclude a specific dense layer with the name `'dense_1/kernel:0'` from the pruning process, 
  so we specify `['dense_1/kernel:0']`.
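To make the interaction between `initial_sparsity`, `final_sparsity`, `end_step`, `prune_frequency`, and `power` concrete, here is a small sketch of a polynomial-decay sparsity schedule. It is modeled on the formula used by `tfmot.sparsity.keras.PolynomialDecay`. Two things are assumptions, not confirmed femtoflow behavior: that femtoflow's schedulers follow the same formula, and that `'linear'` corresponds to `power=1`. The function name `scheduled_sparsity` is hypothetical.

```python
def scheduled_sparsity(step, initial, final, begin_step, end_step, power):
    """Target sparsity at `step` under a polynomial-decay style schedule."""
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))  # clamp to [0, 1]
    return final + (initial - final) * (1.0 - progress) ** power

steps_per_epoch, epochs = 58, 2  # 60,000 // 1024 batches with drop_remainder=True
end_step = steps_per_epoch * epochs
prune_frequency = 100

# Steps at which the mask would be refreshed, assuming the tfmot-style rule
# (step - begin_step) % prune_frequency == 0 within [begin_step, end_step].
update_steps = [s for s in range(0, end_step + 1) if s % prune_frequency == 0]

for s in update_steps:
    linear = scheduled_sparsity(s, 0.2, 0.6, 0, end_step, power=1)
    cubic = scheduled_sparsity(s, 0.2, 0.6, 0, end_step, power=3)
    print(f"step {s:3d}: power=1 -> {linear:.3f}, power=3 -> {cubic:.3f}")
```

Under these assumptions the mask is refreshed only at a couple of steps in such a short run, so a smaller `prune_frequency` or more epochs would be needed to follow the schedule closely up to `final_sparsity`.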


In [None]:
"""
Define additional parameters for pruning
"""
layers_to_prune = [tf.keras.layers.Dense] # Layers we want to prune 
initial_sparsity = 0.2
final_sparsity = 0.6
begin_step = 0
end_step = len(train_dataset) * epochs  # Total training steps: batches per epoch * epochs
prune_frequency = 100
power = 3

model_to_prune = prune_helper(model=model,
                            layers_to_prune=layers_to_prune,
                            initial_sparsity=initial_sparsity,
                            final_sparsity=final_sparsity,
                            begin_step=begin_step,
                            end_step=end_step,
                            prune_frequency=prune_frequency,
                            power=power,
                            force_skip_layers=['dense_1/kernel:0'])

#### Note how Pruning Wrappers are applied to layers to be pruned.

In [None]:
model_to_prune.layers

### Perform Training-With-Sparsity on `model_to_prune` with `tfmot.sparsity.keras.UpdatePruningStep()` applied in a Custom Training Loop

In this section, we define a custom training loop and train `model_to_prune` while inducing sparsity in it. To do so, we create an instance of `tfmot.sparsity.keras.UpdatePruningStep()` and call its methods (`set_model`, `on_train_begin`, `on_train_batch_begin`, `on_epoch_end`) at specific points in the training loop (marked **1-4** in `custom_sparsity_train_loop` below). Without these calls the pruning step is never advanced, so the pruning masks are not updated.

Reference: https://www.tensorflow.org/model_optimization/guide/pruning/comprehensive_guide.md.


In [None]:

def custom_sparsity_train_loop(model_to_prune, 
                                optimizer, 
                                metrics, 
                                loss_fn, 
                                train_dataset, 
                                val_dataset=None,
                                num_epochs=2):
    """
    Custom Training Loop, with instances of tfmot.sparsity.keras.UpdatePruningStep() 
    called in specific points of the training loop (marked 1-4) to induce sparsity.
    Reference: https://www.tensorflow.org/model_optimization/guide/pruning/comprehensive_guide.md
    """
    model_to_prune.compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)

    """
    1) Define prune_step and attach Model to prune_step callback
    """
    prune_step = tfmot.sparsity.keras.UpdatePruningStep()
    prune_step.set_model(model_to_prune)

    """
    2) prune_step.on_train_begin() call
    """
    prune_step.on_train_begin() # run pruning callback
    for epoch in range(0, num_epochs):
        print(f"Processing epoch {epoch}")
        for batch_id, (x_batch, y_batch) in enumerate(train_dataset):
            
            """
            3) prune_step.on_train_batch_begin() call
            """
            prune_step.on_train_batch_begin(batch=-1) # run pruning callback

            model_to_prune.train_on_batch(x_batch, y_batch)

        """
        4) prune_step.on_epoch_end() call
        """
        prune_step.on_epoch_end(batch=-1) # run pruning callback

        # Evaluate once per epoch so the reported epoch number is meaningful.
        if val_dataset is not None:
            _, val_acc = model_to_prune.evaluate(val_dataset, verbose=0)
            print(f"Validation Accuracy Epoch {epoch}: {val_acc}")

    return model_to_prune
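The order of the four hook calls is the essential part of the loop above. The sketch below isolates it with a stand-in object that only records which hooks were called. `RecordingPruneStep` and `run_loop` are hypothetical names used for illustration, and no model or pruning is involved.

```python
class RecordingPruneStep:
    """Hypothetical stand-in that records hook calls instead of pruning."""

    def __init__(self):
        self.calls = []

    def set_model(self, model):
        self.calls.append("set_model")

    def on_train_begin(self):
        self.calls.append("on_train_begin")

    def on_train_batch_begin(self, batch):
        self.calls.append("on_train_batch_begin")

    def on_epoch_end(self, batch):
        self.calls.append("on_epoch_end")


def run_loop(prune_step, num_epochs, batches_per_epoch):
    prune_step.set_model(model=None)                   # 1) attach the model
    prune_step.on_train_begin()                        # 2) once, before training
    for _ in range(num_epochs):
        for _ in range(batches_per_epoch):
            prune_step.on_train_batch_begin(batch=-1)  # 3) before every batch
            # model.train_on_batch(...) would run here
        prune_step.on_epoch_end(batch=-1)              # 4) after every epoch
    return prune_step.calls


calls = run_loop(RecordingPruneStep(), num_epochs=2, batches_per_epoch=3)
print(calls)
```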

In [None]:
model_to_prune = custom_sparsity_train_loop(model_to_prune=model_to_prune,
                                            optimizer=optimizer,
                                            metrics=metrics, 
                                            loss_fn=loss_fn, 
                                            train_dataset=train_dataset, 
                                            val_dataset=test_dataset,
                                            num_epochs=epochs)

In [None]:
_, prune_accuracy = model_to_prune.evaluate(
   test_images, test_labels, verbose=0)

print('Pruned test accuracy:', prune_accuracy)

### Apply `tfmot.sparsity.keras.strip_pruning()` to Remove Sparse Layer Wrappers and Get the Model with Sparse Weights (Prune Masks Applied)

After training, each pruned layer in `model_to_prune` is still wrapped in the pruning wrapper of the kind that `tfmot.sparsity.keras.prune_low_magnitude()` applies. The wrapper holds extra pruning state, such as the mask, that is only needed during training. To obtain the final model with sparse weights, we need to remove this wrapper.

The `strip_pruning()` function is used for this purpose. It removes the pruning wrapper and returns the pruned model with the final pruning mask applied to the weights, resulting in the desired sparse weights.


In [None]:
model_pruned_stripped = tfmot.sparsity.keras.strip_pruning(model_to_prune)

In [None]:
# Notice how the Pruning Wrappers have been removed again!
model_pruned_stripped.layers

## Visualize Pruned Weights

After the model has been pruned, it is helpful to visualize the pruned weights to examine the sparsity pattern achieved during the pruning process. This can provide insights into the impact of pruning on the model's internal structure.

In this example, we demonstrate visualizing the pruned weights of the model. We also highlight the fact that setting `force_skip_layers=['dense_1/kernel:0']` in the pruning configuration led to the exclusion of the `dense_1/kernel:0` layer from pruning, resulting in no sparsity in that layer's weights.


In [None]:
trainable_weights_dict = {weight.name: weight.numpy() for weight in model_pruned_stripped.trainable_weights}
trainable_weights_dict.keys()
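Alongside the plots, a numeric sparsity figure per weight tensor is a useful sanity check. The sketch below uses plain Python and a made-up toy kernel so that it runs on its own; `zero_fraction` is a hypothetical helper, and in the notebook the equivalent for each array in `trainable_weights_dict` would be `float(np.mean(weights == 0))`.

```python
def zero_fraction(matrix):
    """Fraction of entries in a 2D list of weights that are exactly zero."""
    values = [w for row in matrix for w in row]
    return sum(1 for w in values if w == 0) / len(values)

# Toy stand-in for one entry of trainable_weights_dict (values are made up).
toy_kernel = [
    [0.0, 0.0, 0.0, 0.0, 0.3, -0.2],
    [0.5, 0.1, -0.4, 0.2, 0.0, 0.0],
]
print(f"sparsity: {zero_fraction(toy_kernel):.2f}")
```

Comparing this value for each pruned kernel against the sparsity target, and against the layer listed in `force_skip_layers`, gives a quick confirmation that pruning was applied where intended.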

In [None]:

base_path = 'prune_using_fit'
os.makedirs(base_path, exist_ok=True)
max_len_axis = 24
for layer_name, layer in trainable_weights_dict.items():
    title = f"{layer_name}-shape-opxip-{layer.T.shape}"
    save_path = f"{base_path}/{layer_name.replace('/', '-')}-prune-mask.png"
    plot_prune_mask(data=layer, axis_stride=pencil_size, max_xlen=max_len_axis, max_ylen=max_len_axis, title=title, save_path=save_path)