# Structured Pruning of a Fully-Connected Keras Model

[Run this tutorial in Google Colab](https://colab.research.google.com/github/sony/model_optimization/blob/main/tutorials/notebooks/mct_features_notebooks/keras/example_keras_pruning_mnist.ipynb)

This tutorial walks through training, pruning, and retraining a fully connected Keras model. We begin by building and training a simple neural network in Keras. We then apply structured pruning with the Model Compression Toolkit (MCT) to reduce the network's size. Finally, we retrain the pruned model to recover the accuracy lost to pruning.


## Installing TensorFlow and Model Compression Toolkit

We start by installing TensorFlow and the Model Compression Toolkit, then importing them.

In [None]:
TF_VER = '2.14.0'

!pip install -q tensorflow=={TF_VER}
!pip install mct-nightly


In [None]:
import tensorflow as tf
import model_compression_toolkit as mct
import numpy as np
from tensorflow.keras.datasets import mnist

## Loading and Preprocessing MNIST

Let's load the training and test splits of the MNIST dataset and preprocess them:

In [None]:
# Load the MNIST dataset
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the images to [0, 1] range
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0


## Creating a Fully-Connected Model

In this section, we create a small fully connected model to demonstrate pruning with MCT. It consists of three dense layers with 128, 64, and 10 neurons.

MCT's structured pruning targets the first two dense layers, because their output channels can be removed and the reduction can then be propagated to the input channels of the layer that follows.
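
The propagation described above can be sketched with plain NumPy arrays. This is a minimal illustration, not MCT's implementation: the random weights and the choice to keep the first 96 channels are made up, and the Keras `[input, output]` kernel layout is assumed.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins for two consecutive Dense layers (kernel layout: [in, out]).
w1, b1 = rng.normal(size=(784, 128)), rng.normal(size=128)
w2, b2 = rng.normal(size=(128, 64)), rng.normal(size=64)

# Hypothetical pruning decision: keep the first 96 of the 128 output channels.
keep = np.arange(128) < 96

# Removing output channels of layer 1 drops kernel columns and bias entries...
w1_pruned, b1_pruned = w1[:, keep], b1[keep]
# ...and the matching input channels (kernel rows) of layer 2.
w2_pruned = w2[keep, :]


def forward(x, w1, b1, w2, b2):
    hidden = np.maximum(x @ w1 + b1, 0.0)  # Dense + ReLU
    return hidden @ w2 + b2


x = rng.normal(size=(4, 784))
out_pruned = forward(x, w1_pruned, b1_pruned, w2_pruned, b2)

# Reference: the full-size network with the removed channels zeroed out.
out_masked = forward(x, w1 * keep, b1 * keep, w2, b2)

print(w1_pruned.shape, w2_pruned.shape)
print(np.allclose(out_pruned, out_masked))
```

The smaller network computes the same outputs as the full-size network with the removed channels zeroed, which is why structured pruning shrinks the actual weight tensors instead of only masking them.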

Once our model is created, we compile it to prepare the model for training and evaluation.


In [None]:
def create_model():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(0.001),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )
    return model

## Training Dense Model on MNIST

Now we can train the model on the dataset we loaded and evaluate it on the test set.

In [None]:
# Train and evaluate the model
model = create_model()
model.fit(train_images, train_labels, epochs=6, validation_data=(test_images, test_labels))

model.evaluate(test_images, test_labels)

## Dense Model Properties

The `model.summary()` function in Keras provides a snapshot of the model's architecture, including layers, their types, output shapes, and the number of parameters.


In [None]:
model.summary()

Let's break down what we see in our model summary:

- First Dense Layer: A fully connected layer with 128 output channels and 784 input channels.

- Second Dense Layer: A fully connected layer with 64 output channels and 128 input channels.

- Third Dense Layer: The final dense layer with 10 neurons (as per the number of MNIST classes) and 64 input channels.

The model has 109,386 parameters in total, which take up roughly 427.29 KB when stored as 32-bit floats.
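
The figures above can be reproduced with a few lines of arithmetic. The sketch below assumes each parameter is stored as a 4-byte float32 value and that each dense layer has one bias per output channel.

```python
# (input channels, output channels) of the three Dense layers.
layer_shapes = [(784, 128), (128, 64), (64, 10)]

# Each Dense layer has in*out kernel weights plus one bias per output channel.
params_per_layer = [n_in * n_out + n_out for n_in, n_out in layer_shapes]
total_params = sum(params_per_layer)

bytes_per_param = 4  # float32 (assumption)
total_kb = total_params * bytes_per_param / 1024

print(params_per_layer)    # [100480, 8256, 650]
print(total_params)        # 109386
print(round(total_kb, 2))  # 427.29
```

Almost all of the memory sits in the first layer's kernel, so pruning its output channels has the largest effect on model size.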

## MCT Structured Pruning

### Create TPC

First, we set up the Target Platform Capabilities (TPC), which specify each layer's SIMD (Single Instruction, Multiple Data) size.

In MCT, the SIMD size determines how channels are grouped: pruning decisions are made per SIMD group of channels, based on the importance of the channels in that group.

For this demonstration we use the simplest structured pruning scenario, SIMD=1, in which every channel forms its own group.
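
To make the idea of SIMD groups concrete, here is a small sketch of how per-channel scores could be turned into per-group scores. The helper `group_scores`, the made-up scores, the use of consecutive channels, and the summation are all illustrative assumptions; MCT's internal grouping may differ.

```python
def group_scores(channel_scores, simd_size):
    """Sum per-channel scores over consecutive groups of `simd_size` channels."""
    return [
        sum(channel_scores[start:start + simd_size])
        for start in range(0, len(channel_scores), simd_size)
    ]


# Made-up importance scores for a layer with 8 output channels.
scores = [0.9, 0.1, 0.4, 0.3, 0.8, 0.7, 0.2, 0.6]

groups_simd1 = group_scores(scores, simd_size=1)  # one group per channel
groups_simd4 = group_scores(scores, simd_size=4)  # two groups of 4 channels

print(len(groups_simd1), len(groups_simd4))
```

With a SIMD size of 1 each channel can be kept or removed independently, while with a SIMD size of 4 channels are kept or removed four at a time.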

In [None]:
from model_compression_toolkit.target_platform_capabilities.target_platform import Signedness
tp = mct.target_platform

simd_size = 1

def get_tpc():
    # Define the default weight attribute configuration
    default_weight_attr_config = tp.AttributeQuantizationConfig(
        weights_quantization_method=tp.QuantizationMethod.UNIFORM,
        weights_n_bits=None,
        weights_per_channel_threshold=None,
        enable_weights_quantization=None,
        lut_values_bitwidth=None
    )

    # Define the OpQuantizationConfig
    default_config = tp.OpQuantizationConfig(
        default_weight_attr_config=default_weight_attr_config,
        attr_weights_configs_mapping={},
        activation_quantization_method=tp.QuantizationMethod.UNIFORM,
        activation_n_bits=8,
        supported_input_activation_n_bits=8,
        enable_activation_quantization=None,
        quantization_preserving=None,
        fixed_scale=None,
        fixed_zero_point=None,
        simd_size=simd_size,
        signedness=Signedness.AUTO
    )

    # Create the quantization configuration options and model
    default_configuration_options = tp.QuantizationConfigOptions([default_config])
    tp_model = tp.TargetPlatformModel(default_configuration_options)

    # Return the target platform capabilities
    tpc = tp.TargetPlatformCapabilities(tp_model)
    return tpc


### Create a Representative Dataset

Next, we create a representative dataset, which MCT uses during pruning to compute an importance score for each channel:

In [None]:
import random

def representative_data_gen():
    indices = random.sample(range(len(train_images)), 32)
    yield [np.stack([train_images[i] for i in indices])]
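
To see what this generator hands to the pruning routine, the sketch below runs a parameterized variant on synthetic data. The `fake_images` array and the `images`, `batch_size`, and `n_iter` parameters are hypothetical additions for illustration, not part of the tutorial's generator.

```python
import random

import numpy as np

# Synthetic stand-in for the normalized MNIST training images.
fake_images = np.random.default_rng(0).random((1000, 28, 28)).astype("float32")


def representative_data_gen(images=fake_images, batch_size=32, n_iter=1):
    for _ in range(n_iter):
        indices = random.sample(range(len(images)), batch_size)
        # One list entry per model input; this model has a single input.
        yield [np.stack([images[i] for i in indices])]


batches = list(representative_data_gen())
print(len(batches), batches[0][0].shape, batches[0][0].dtype)
```

Each iteration yields a list with one array per model input, here a single batch of 32 images of shape 28x28.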

### Create a Resource Utilization Constraint

We define a `resource_utilization` limit that constrains the weight memory of the pruned model.

The original weights take around 427 KB, so setting the target to half of that size corresponds to a compression ratio of 50%:

In [None]:
# Create a ResourceUtilization object to limit the pruned model weights memory to a certain resource constraint
dense_model_memory = 427*(2**10)  # Original model weights require ~427KB
compression_ratio = 0.5

resource_utilization = mct.core.ResourceUtilization(weights_memory=dense_model_memory*compression_ratio)
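
As a quick sanity check, the target can be converted into bytes and into an approximate parameter budget. The sketch assumes 4-byte float32 weights.

```python
dense_model_memory = 427 * (2 ** 10)  # bytes, ~427 KB as in the cell above
compression_ratio = 0.5

target_bytes = dense_model_memory * compression_ratio
max_float32_params = int(target_bytes // 4)  # assumes 4 bytes per parameter

print(target_bytes)        # 218624.0
print(max_float32_params)  # 54656
```

The pruned model therefore has to fit in roughly half of the original 109,386 parameters.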

### Prune Model

We're ready to execute the actual pruning using MCT's `keras_pruning_experimental` function. The model is pruned according to our defined target Resource Utilization and using the representative dataset generated earlier.

Each channel's importance is measured using LFH (Label-Free Hessian), which approximates the Hessian of the loss function with respect to the model's weights.

For efficiency, this example uses a single score approximation. In real-world applications, using multiple approximations yields more precise importance scores, at the cost of longer processing time.
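
The trade-off between the number of approximations and precision can be illustrated with a generic randomized estimator. The sketch below uses a Hutchinson-style trace estimate on a toy symmetric matrix; it is an analogy under that assumption and not necessarily how MCT computes LFH scores internally.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy symmetric matrix standing in for a Hessian.
a = rng.normal(size=(6, 6))
hessian = (a + a.T) / 2
true_trace = float(np.trace(hessian))


def estimate_trace(matrix, num_approximations, rng):
    """Average v^T M v over random +/-1 vectors (Hutchinson-style estimator)."""
    samples = []
    for _ in range(num_approximations):
        v = rng.choice([-1.0, 1.0], size=matrix.shape[0])
        samples.append(v @ matrix @ v)
    return float(np.mean(samples))


# Print the absolute error for an increasing number of approximations.
for n in (1, 10, 1000):
    print(n, abs(estimate_trace(hessian, n, rng) - true_trace))
```

Each approximation costs one extra pass, so the error tends to shrink as more samples are averaged while the runtime grows linearly.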

The pruning call returns a pruned model along with pruning information, which includes the pruning masks and importance scores for each layer.

In [None]:
num_score_approximations = 1

target_platform_cap = get_tpc()
pruned_model, pruning_info = mct.pruning.keras_pruning_experimental(
    model=model,
    target_resource_utilization=resource_utilization,
    representative_data_gen=representative_data_gen,
    target_platform_capabilities=target_platform_cap,
    pruning_config=mct.pruning.PruningConfig(num_score_approximations=num_score_approximations)
)
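
For intuition about how importance scores and a memory target interact, here is a simplified greedy selection in pure Python. The scores are made up, the layer names are hypothetical, and the "drop the least important channel until the budget is met" strategy is an assumption for illustration; MCT's actual mask computation may differ.

```python
def weights_bytes(c1, c2, n_in=784, n_out=10, bytes_per_param=4):
    """float32 weight memory of a Dense(c1) -> Dense(c2) -> Dense(n_out) stack."""
    params = (n_in * c1 + c1) + (c1 * c2 + c2) + (c2 * n_out + n_out)
    return params * bytes_per_param


# Made-up importance scores, one per output channel of the two prunable layers.
scores = {
    "dense_1": [((7 * i) % 128 + 1) / 128 for i in range(128)],
    "dense_2": [((5 * i) % 64 + 1) / 64 for i in range(64)],
}
kept = {name: set(range(len(s))) for name, s in scores.items()}
budget = 427 * 1024 * 0.5

# Greedily drop the least important remaining channel until the budget is met.
while weights_bytes(len(kept["dense_1"]), len(kept["dense_2"])) > budget:
    candidates = [
        (scores[name][ch], name, ch)
        for name, channels in kept.items()
        for ch in channels
        if len(channels) > 1  # never remove a layer's last channel
    ]
    _, name, ch = min(candidates)
    kept[name].discard(ch)

final_bytes = weights_bytes(len(kept["dense_1"]), len(kept["dense_2"]))
print(len(kept["dense_1"]), len(kept["dense_2"]), final_bytes)
```

The set of kept channels per layer plays the role of a pruning mask: channels that remain in the set survive, and all others are removed from the layer and from the inputs of the next one.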

### Pruned Model Properties

As before, we can use the Keras model API to inspect the architecture of the pruned model:

In [None]:
pruned_model.summary()

## Retraining Pruned Model

After pruning, it is common to observe a drop in the model's accuracy, a direct result of removing part of the model's capacity. We compile and evaluate the pruned model to measure this drop:

In [None]:
pruned_model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
pruned_model.evaluate(test_images, test_labels)

To recover the lost accuracy, we retrain the pruned model so that it adapts to its new, compressed architecture. Through retraining, the model can regain, and sometimes even surpass, its original accuracy.

In [None]:
pruned_model.fit(train_images, train_labels, epochs=6, validation_data=(test_images, test_labels))
pruned_model.evaluate(test_images, test_labels)

Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
