# Structured Pruning of a Fully-Connected Keras Model using the Model Compression Toolkit (MCT)

[Run this tutorial in Google Colab](https://colab.research.google.com/github/sony/model_optimization/blob/main/tutorials/notebooks/mct_features_notebooks/keras/example_keras_pruning_mnist.ipynb)

## Overview
This tutorial provides a step-by-step guide to training, pruning, and finetuning a Keras fully connected neural network model using the Model Compression Toolkit (MCT). We will start by building and training the model from scratch on the MNIST dataset, followed by applying structured pruning to reduce the model size.

## Summary
In this tutorial, we will cover:

1. **Training a Keras model on MNIST:** We'll begin by constructing a basic fully connected neural network and training it on the MNIST dataset. 
2. **Applying structured pruning:** We'll introduce a pruning technique to reduce model size while maintaining performance. 
3. **Finetuning the pruned model:** After pruning, we'll finetune the model to recover any lost accuracy. 
4. **Evaluating the pruned model:** We'll evaluate the pruned model’s performance and compare it to the original model.

## Setup
Install the relevant packages:

In [None]:
TF_VER = '2.14.0'
!pip install -q tensorflow~={TF_VER}

In [None]:
import importlib.util
if not importlib.util.find_spec('model_compression_toolkit'):
    !pip install model_compression_toolkit

In [None]:
import tensorflow as tf
import model_compression_toolkit as mct
import numpy as np
from tensorflow.keras.datasets import mnist

## Loading and Preprocessing MNIST
Let's load the train and test splits of the MNIST dataset and preprocess them:

In [None]:
# Load the MNIST dataset
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the images to [0, 1] range
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0


## Creating a Fully-Connected Model
In this section, we create a simple fully connected model to demonstrate the pruning process. It flattens each 28x28 input image into a 784-element vector and passes it through three dense layers with 128, 64, and 10 neurons. After defining the model architecture, we compile it to prepare for training and evaluation.

In [None]:
def create_model():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(0.001),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )
    return model

## Training Dense Model on MNIST
Next, we will train the dense model using the preprocessed MNIST dataset.

In [None]:
# Train and evaluate the model
model = create_model()
model.fit(train_images, train_labels, epochs=6, validation_data=(test_images, test_labels))
model.evaluate(test_images, test_labels)

## Dense Model Properties
The `model.summary()` function in Keras provides a comprehensive overview of the model's architecture, including each layer's type, output shapes, and the number of trainable parameters.

In [None]:
model.summary()

Let's break down the details from our model summary:

- **First Dense Layer:** A fully connected layer with 128 output channels and 784 input channels.
- **Second Dense Layer:** A fully connected layer with 64 output channels and 128 input channels.
- **Third Dense Layer:** The final layer with 10 neurons (matching the number of MNIST classes) and 64 input channels.

The model has a total of 109,386 parameters, requiring approximately 427.29 KB of memory.
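These totals follow directly from the layer shapes. The minimal sketch below reproduces them in plain Python (no TensorFlow needed), assuming float32 weights at 4 bytes each and 1 KB = 1024 bytes. Each Dense layer contributes an `(in x out)` kernel plus one bias per output channel; the Flatten layer has no parameters.

```python
# Reproduce the parameter count and weight memory reported by model.summary()
# for the Flatten -> Dense(128) -> Dense(64) -> Dense(10) model.
BYTES_PER_FLOAT32 = 4

def dense_params(in_features: int, out_features: int) -> int:
    """A Dense layer has an (in x out) kernel plus one bias per output channel."""
    return in_features * out_features + out_features

# (input channels, output channels) for each Dense layer
layer_shapes = [(28 * 28, 128), (128, 64), (64, 10)]
per_layer = [dense_params(i, o) for i, o in layer_shapes]
total_params = sum(per_layer)
memory_kb = total_params * BYTES_PER_FLOAT32 / 2**10

print(per_layer)             # [100480, 8256, 650]
print(total_params)          # 109386
print(round(memory_kb, 2))   # 427.29
```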

## MCT Structured Pruning

### Target Platform Capabilities (TPC)
MCT optimizes models for dedicated hardware using Target Platform Capabilities (TPC). For more details, please refer to our [documentation](https://sonysemiconductorsolutions.github.io/mct-model-optimization/api/api_docs/modules/target_platform_capabilities.html). First, we'll configure the TPC to define each layer's SIMD (Single Instruction, Multiple Data) size.

In MCT, the SIMD size determines how channels are grouped, and the pruning process considers channel importance within each SIMD group.

For this demonstration, we'll use the simplest structured pruning scenario with SIMD set to 1.
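To build intuition for SIMD grouping, the sketch below splits per-channel importance scores into consecutive groups of `simd_size` channels and ranks the groups from least to most important. This is an illustration of the grouping idea only, not MCT's internal implementation: the helper names, the toy scores, and the choice to sum scores within a group are all assumptions.

```python
from typing import List

def group_scores_by_simd(scores: List[float], simd_size: int) -> List[float]:
    """Hypothetical helper: sum per-channel importance scores over consecutive
    groups of `simd_size` channels (the last group may be smaller)."""
    if simd_size < 1:
        raise ValueError("simd_size must be >= 1")
    return [sum(scores[i:i + simd_size]) for i in range(0, len(scores), simd_size)]

def rank_groups(scores: List[float], simd_size: int) -> List[int]:
    """Hypothetical helper: group indices ordered from least to most important."""
    group_scores = group_scores_by_simd(scores, simd_size)
    return sorted(range(len(group_scores)), key=lambda g: group_scores[g])

# Made-up importance scores for six output channels.
channel_scores = [0.9, 0.1, 0.4, 0.8, 0.05, 0.3]

# With simd_size=1 (as in this tutorial), every channel is its own group.
print(rank_groups(channel_scores, simd_size=1))  # [4, 1, 5, 2, 3, 0]

# With simd_size=2, channels are considered in pairs: (0, 1), (2, 3), (4, 5).
print([round(s, 2) for s in group_scores_by_simd(channel_scores, 2)])  # [1.0, 1.2, 0.35]
print(rank_groups(channel_scores, simd_size=2))  # [2, 0, 1]
```

With a larger SIMD size, a low-scoring channel can survive because it shares a group with a high-scoring one (channel 1 above), which is why SIMD=1 is the simplest case to reason about.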

In [None]:
from mct_quantizers import QuantizationMethod
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import schema

simd_size = 1

def get_tpc():
    # Define the default weight attribute configuration
    default_weight_attr_config = schema.AttributeQuantizationConfig(
        weights_quantization_method=QuantizationMethod.UNIFORM,
    )

    # Define the OpQuantizationConfig
    default_config = schema.OpQuantizationConfig(
        default_weight_attr_config=default_weight_attr_config,
        attr_weights_configs_mapping={},
        activation_quantization_method=QuantizationMethod.UNIFORM,
        activation_n_bits=8,
        supported_input_activation_n_bits=8,
        enable_activation_quantization=False,
        quantization_preserving=False,
        fixed_scale=None,
        fixed_zero_point=None,
        simd_size=simd_size,
        signedness=schema.Signedness.AUTO
    )
    
    # In this tutorial, we will use the default OpQuantizationConfig for all operator sets.
    operator_set = []

    # Create the quantization configuration options and the TPC
    default_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=tuple([default_config]))
    tpc = schema.TargetPlatformCapabilities(default_qco=default_configuration_options,
                                            tpc_minor_version=1,
                                            tpc_patch_version=0,
                                            tpc_platform_type="custom_pruning_notebook_tpc",
                                            operator_set=tuple(operator_set))
    return tpc


### Representative Dataset
We create a representative dataset to guide the pruning process; MCT uses it to compute an importance score for each channel. The dataset is implemented as a generator that yields a list containing a single batch of 32 randomly sampled training images.

In [None]:
import random

def representative_data_gen():
    # Sample 32 random training images and yield them as a single batch
    indices = random.sample(range(len(train_images)), 32)
    yield [np.stack([train_images[i] for i in indices])]

### Resource Utilization
We define a `resource_utilization` limit to constrain the memory usage of the pruned model. Our goal is to halve the memory footprint of the model's weights. Since the weights are stored as float32 (4 bytes per parameter), the weight memory equals the total number of parameters multiplied by 4. Limiting the weight memory to roughly 214 KB therefore corresponds to a 50% compression ratio.
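The budget arithmetic can be worked through in plain Python. This sketch assumes float32 weights (4 bytes each) and 1 KB = 1024 bytes. It also shows that rounding the dense size to 427 KB before applying the ratio, as the next cell does, gives a slightly smaller budget than starting from the exact parameter count.

```python
# Weight-memory budget arithmetic for a 50% compression target.
BYTES_PER_FLOAT32 = 4
total_params = 109_386                      # from model.summary()
compression_ratio = 0.5

dense_bytes = total_params * BYTES_PER_FLOAT32
target_bytes = dense_bytes * compression_ratio
print(dense_bytes)                          # 437544
print(round(target_bytes / 2**10, 2))       # 213.64 (KB)

# Rounding the dense size to 427 KB first gives a slightly smaller budget.
rounded_target = 427 * 2**10 * compression_ratio
print(rounded_target)                       # 218624.0 (bytes, i.e. 213.5 KB)

# Largest number of float32 parameters that fits the exact budget.
max_params = int(target_bytes // BYTES_PER_FLOAT32)
print(max_params)                           # 54693
```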

In [None]:
# Create a ResourceUtilization object to limit the pruned model weights memory to a certain resource constraint
dense_model_memory = 427*(2**10) # Original model weights require ~427KB
compression_ratio = 0.5

resource_utilization = mct.core.ResourceUtilization(weights_memory=dense_model_memory*compression_ratio)

### Model Pruning
We are now ready to perform the actual pruning using MCT’s `keras_pruning_experimental` function. The model will be pruned based on the defined resource utilization constraints and the previously generated representative dataset.

Each channel’s importance is measured using the [LFH (Label-Free-Hessian) method](https://arxiv.org/abs/2309.11531), which approximates the Hessian of the loss function with respect to the model’s weights.

For efficiency, we use a single score approximation. Although less precise, it significantly reduces processing time compared to multiple approximations, which offer better accuracy but at the cost of longer runtimes.
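Hessian-based scores are typically estimated stochastically, and each extra approximation costs extra computation. As a generic illustration of that precision/runtime trade-off, the sketch below implements Hutchinson's stochastic trace estimator, which averages `v^T H v` over random +/-1 vectors `v`. It is not MCT's LFH implementation: the explicit 3x3 matrix stands in for a Hessian, and whether `num_score_approximations` corresponds exactly to the number of samples here is an assumption.

```python
import random
from typing import List

Matrix = List[List[float]]

def quad_form(h: Matrix, v: List[float]) -> float:
    """Compute v^T H v for a square matrix H given as nested lists."""
    hv = [sum(h[i][j] * v[j] for j in range(len(v))) for i in range(len(v))]
    return sum(v[i] * hv[i] for i in range(len(v)))

def hutchinson_trace(h: Matrix, num_samples: int, seed: int = 0) -> float:
    """Estimate trace(H) as the mean of v^T H v over random Rademacher vectors v.

    Each sample is an unbiased estimate; averaging more samples reduces the
    variance at the cost of one extra matrix-vector product per sample.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        v = [rng.choice((-1.0, 1.0)) for _ in range(len(h))]
        total += quad_form(h, v)
    return total / num_samples

# Stand-in symmetric "Hessian" whose exact trace is 6.0.
H = [[3.0, 0.5, -0.2],
     [0.5, 2.0, 0.3],
     [-0.2, 0.3, 1.0]]

one_shot = hutchinson_trace(H, num_samples=1)
many = hutchinson_trace(H, num_samples=2000)
print(f"1 sample: {one_shot:.3f}, 2000 samples: {many:.3f}, exact: 6.0")
```

A single sample can land anywhere between 4 and 8 for this matrix, while 2000 samples concentrate near 6, mirroring the trade-off described above.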

MCT’s structured pruning will target the first two dense layers, where output channel reduction can be propagated to subsequent layers by adjusting their input channels accordingly.
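The channel bookkeeping behind this propagation can be sketched on toy weight matrices. This plain-Python sketch assumes the Keras `Dense` kernel layout `(input_dim, units)` and uses a hypothetical helper, `prune_dense_pair`; it illustrates the idea rather than MCT's code. Removing an output channel of one Dense layer deletes a kernel column and a bias entry there, and the matching input row of the next layer's kernel.

```python
from typing import List, Tuple

Matrix = List[List[float]]  # kernel[in_channel][out_channel], as in Keras Dense

def prune_dense_pair(kernel1: Matrix, bias1: List[float], kernel2: Matrix,
                     keep: List[int]) -> Tuple[Matrix, List[float], Matrix]:
    """Hypothetical helper: keep only the output channels `keep` of the first
    Dense layer and drop the matching input rows of the next layer's kernel."""
    new_kernel1 = [[row[c] for c in keep] for row in kernel1]  # drop columns
    new_bias1 = [bias1[c] for c in keep]                       # drop biases
    new_kernel2 = [kernel2[c] for c in keep]                   # drop input rows
    return new_kernel1, new_bias1, new_kernel2

def shape(m: Matrix) -> Tuple[int, int]:
    return len(m), len(m[0])

# Toy sizes: 4 inputs -> 3 hidden channels -> 2 outputs.
k1 = [[float(i * 3 + j) for j in range(3)] for i in range(4)]  # 4 x 3
b1 = [0.1, 0.2, 0.3]
k2 = [[float(i * 2 + j) for j in range(2)] for i in range(3)]  # 3 x 2

# Suppose hidden channel 1 is the least important and is removed.
nk1, nb1, nk2 = prune_dense_pair(k1, b1, k2, keep=[0, 2])
print(shape(nk1), nb1, shape(nk2))  # (4, 2) [0.1, 0.3] (2, 2)
```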

The output is a pruned model along with pruning information, including layer-specific pruning masks and scores.
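To relate pruning masks to the memory budget, the sketch below counts the weight memory left after applying per-layer channel masks to the two hidden Dense layers. The masks are hypothetical (made up for illustration, not MCT output), and the convention that 1 keeps a channel and 0 removes it is an assumption; the final 10-class layer is left unpruned, matching the description above.

```python
from typing import List

BYTES_PER_FLOAT32 = 4

def pruned_weight_bytes(in_features: int, masks: List[List[int]],
                        num_classes: int) -> int:
    """Weight memory of a Dense stack after pruning hidden output channels.

    masks[k] is a 0/1 list over the output channels of hidden layer k
    (assumed: 1 = keep). The final classifier layer is not pruned.
    """
    widths = [in_features] + [sum(m) for m in masks] + [num_classes]
    # Each Dense layer: (in x out) kernel plus `out` biases.
    params = sum(i * o + o for i, o in zip(widths[:-1], widths[1:]))
    return params * BYTES_PER_FLOAT32

# Hypothetical masks: keep 60 of 128 and 40 of 64 hidden channels.
mask1 = [1] * 60 + [0] * 68
mask2 = [1] * 40 + [0] * 24

budget = 427 * 2**10 * 0.5      # same budget as in the resource utilization cell
used = pruned_weight_bytes(28 * 28, [mask1, mask2], num_classes=10)
print(used, used <= budget)     # 199800 True
```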

In [None]:
num_score_approximations = 1

target_platform_cap = get_tpc()
pruned_model, pruning_info = mct.pruning.keras_pruning_experimental(
        model=model,
        target_resource_utilization=resource_utilization,
        representative_data_gen=representative_data_gen,
        target_platform_capabilities=target_platform_cap,
        pruning_config=mct.pruning.PruningConfig(num_score_approximations=num_score_approximations)
    )

### Pruned Model Properties
As before, we can use the Keras model API to inspect the new architecture and details of the pruned model.

In [None]:
pruned_model.summary()

## Finetuning the Pruned Model
After pruning, it’s common to see a temporary drop in model accuracy due to the reduction in model complexity. Let’s demonstrate this by evaluating the pruned model and observing its initial performance before finetuning.

In [None]:
pruned_model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
pruned_model.evaluate(test_images, test_labels)

To restore the model's performance, we finetune the pruned model, allowing it to adapt to its new, compressed architecture. Through this finetuning process, the model can often recover its original accuracy, and in some cases, even surpass it.

In [None]:
pruned_model.fit(train_images, train_labels, epochs=6, validation_data=(test_images, test_labels))
pruned_model.evaluate(test_images, test_labels)

## Conclusion
In this tutorial, we explored the process of structured model pruning using MCT to optimize a dense neural network. We demonstrated how to define resource constraints, apply pruning based on channel importance, and evaluate the impact on model architecture and performance. Finally, we showed how finetuning can recover the pruned model’s accuracy. This approach highlights the effectiveness of structured pruning for reducing model size while maintaining performance, making it a powerful tool for model optimization.

Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
