# Structured Pruning of a Fully-Connected PyTorch Model using the Model Compression Toolkit (MCT)

[Run this tutorial in Google Colab](https://colab.research.google.com/github/sony/model_optimization/blob/main/tutorials/notebooks/mct_features_notebooks/pytorch/example_pytorch_pruning_mnist.ipynb)

## Overview
This tutorial provides a step-by-step guide to training, pruning, and finetuning a PyTorch fully connected neural network model using the Model Compression Toolkit (MCT). We will start by building and training the model from scratch on the MNIST dataset, followed by applying structured pruning to reduce the model size.

## Summary
In this tutorial, we will cover:

1. **Training a PyTorch model on MNIST:** We'll begin by constructing a basic fully connected neural network and training it on the MNIST dataset. 
2. **Applying structured pruning:** We'll introduce a pruning technique to reduce model size while maintaining performance. 
3. **Finetuning the pruned model:** After pruning, we'll finetune the model to recover any lost accuracy. 
4. **Evaluating the pruned model:** We'll evaluate the pruned model’s performance and compare it to the original model.

## Setup
Install the relevant packages:

In [None]:
!pip install -q torch torchvision

In [None]:
import importlib.util
if not importlib.util.find_spec('model_compression_toolkit'):
    !pip install model_compression_toolkit

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import datasets, transforms

## Train a PyTorch classifier model on MNIST
Next, we'll define functions to train and evaluate our neural network model. The `train` function handles the training loop: forward propagation, loss calculation, backpropagation, and updating the model parameters. The `test` function computes the average loss and accuracy on the test set. We will call it at the end of each epoch to monitor the model's accuracy. The following code snippets are adapted from the official [PyTorch examples](https://github.com/pytorch/examples/blob/main/mnist/main.py).

In [None]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    model.to(device)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)

        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    model.to(device)  # make sure the model and the data are on the same device
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        accuracy))
    
    return accuracy 

random_seed = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

## Creating a Fully-Connected Model
In this section, we create a simple example of a fully connected model to demonstrate the pruning process. It consists of three linear layers with 128, 64, and 10 neurons.

In [None]:
# Define the Fully-Connected Model
class FCModel(nn.Module):
    def __init__(self):
        super(FCModel, self).__init__()
        self.flatten = nn.Flatten()
        self.fc_layers = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.fc_layers(x)
        output = F.log_softmax(logits, dim=1)
        return output

## Loading and Preprocessing MNIST Dataset
Let's define the dataset loaders to retrieve the train and test parts of the MNIST dataset, including preprocessing:

In [None]:
batch_size = 128
test_batch_size = 1000

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
dataset_folder = './mnist'
train_dataset = datasets.MNIST(dataset_folder, train=True, download=True,
                               transform=transform)
test_dataset = datasets.MNIST(dataset_folder, train=False,
                              transform=transform)
train_loader = DataLoader(train_dataset, pin_memory=True, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, pin_memory=True, batch_size=test_batch_size, shuffle=False)
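`transforms.ToTensor()` scales 8-bit pixel values to the range [0, 1], and `transforms.Normalize((0.1307,), (0.3081,))` then computes `(x - mean) / std` using the commonly used MNIST mean and standard deviation. A minimal stdlib sketch of that per-pixel arithmetic follows; `normalize_pixel` is a hypothetical helper, not a torchvision function:

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize_pixel(raw_value: int) -> float:
    """Mimic ToTensor() followed by Normalize() for one 8-bit grayscale pixel."""
    scaled = raw_value / 255.0                # ToTensor: uint8 [0, 255] -> float [0.0, 1.0]
    return (scaled - MNIST_MEAN) / MNIST_STD  # Normalize: (x - mean) / std


black = normalize_pixel(0)    # background pixel
white = normalize_pixel(255)  # full-intensity stroke pixel
print(f'black -> {black:.4f}, white -> {white:.4f}')  # prints: black -> -0.4242, white -> 2.8215
```

After normalization, the inputs are roughly centered around zero, which generally makes training with Adam better behaved than feeding raw pixel values.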

## Training the Dense Model
We will now train the dense model using the MNIST dataset.

In [None]:
epochs = 6
lr = 0.001

dense_model = FCModel().to(device)
optimizer = optim.Adam(dense_model.parameters(), lr=lr)
for epoch in range(1, epochs + 1):
    train(dense_model, device, train_loader, optimizer, epoch)
    test(dense_model, device, test_loader)

## Dense Model Properties
Next, we display the model's architecture, including its layers, their types, and their number of parameters.
Notably, MCT's structured pruning will target the first two dense layers. Their output channels can be removed, and the reduction is propagated by adjusting the input channels of the following layer. The output channels of the last layer are the 10 class scores, so they must be kept. The first two layers also have more channels than the last one (128 and 64 versus 10), which offers more room for pruning with a limited effect on accuracy.

In [None]:
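def display_model_params(model):
    model_params = sum(p.numel() for p in model.parameters())
    for name, module in model.named_modules():
        # Count only the parameters owned directly by this module (e.g., a Linear layer's weight and bias),
        # so container modules such as the root model or nn.Sequential are not double-counted
        module_params = sum(p.numel() for p in module.parameters(recurse=False))
        if module_params > 0:
            print(f'{name} number of parameters {module_params}')
    print(f'\nTotal number of parameters {model_params}')
    return model_params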

dense_model_params = display_model_params(dense_model)
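As a sanity check on the printed counts: each `nn.Linear(in_features, out_features)` layer holds `in_features * out_features` weights plus `out_features` biases. This short stdlib sketch reproduces the expected per-layer and total counts for `FCModel` by hand:

```python
# Layer sizes of FCModel: 784 -> 128 -> 64 -> 10
layer_sizes = [28 * 28, 128, 64, 10]


def linear_params(in_features: int, out_features: int) -> int:
    """Parameters of an nn.Linear layer with bias: a weight matrix plus one bias per output."""
    return in_features * out_features + out_features


per_layer = [linear_params(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])]
total = sum(per_layer)
print(per_layer, total)  # [100480, 8256, 650] 109386
```

The first layer accounts for roughly 92% of all parameters, so shrinking its output channels (together with the matching input columns of the second layer) is where most of the memory savings come from.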

## Create a Representative Dataset
Next, we create a representative dataset, which MCT uses during pruning to compute an importance score for each channel:

In [None]:
n_iter=10

def representative_dataset_gen():
    dataloader_iter = iter(train_loader)
    for _ in range(n_iter):
        yield [next(dataloader_iter)[0]]
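# The generator above yields a list with one entry per model input (FCModel has a single
# input) and drops the labels. The sketch below restates the same pattern with a
# hypothetical factory, make_representative_dataset, and placeholder batches standing in
# for train_loader, so the yielded structure can be inspected without downloading MNIST.
def make_representative_dataset(loader, n_iter):
    """Return a zero-argument generator function that yields n_iter input batches."""
    def gen():
        loader_iter = iter(loader)
        for _ in range(n_iter):
            data, _target = next(loader_iter)  # keep the images, discard the labels
            yield [data]
    return gen


# Stand-in for train_loader: a list of (data, target) pairs with placeholder strings.
fake_loader = [(f'images_{i}', f'labels_{i}') for i in range(5)]
gen = make_representative_dataset(fake_loader, n_iter=3)
batches = list(gen())
print(batches)  # [['images_0'], ['images_1'], ['images_2']]

# Note: if n_iter exceeded the number of batches, next() would raise StopIteration inside
# the generator, which Python 3.7+ converts into a RuntimeError. With batch_size=128,
# train_loader has 469 batches, so n_iter=10 is well within range.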
        

## Model Pruning
We are now ready to perform the actual pruning using MCT’s `pytorch_pruning_experimental` function. The model will be pruned based on the defined resource utilization constraints and the previously generated representative dataset.

Each channel’s importance is measured using the [LFH (Label-Free-Hessian) method](https://arxiv.org/abs/2309.11531), which approximates the Hessian of the loss function with respect to the model’s weights.

For efficiency, we use a single score approximation (`num_score_approximations=1`). Although less precise, it significantly reduces processing time compared to using multiple approximations, which yield more accurate scores at the cost of longer runtimes.
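Randomized estimators such as Hutchinson's method are a standard way to approximate Hessian information cheaply, and they show the same speed versus precision trade-off. As a generic illustration only (this is not MCT's implementation, and treating `num_score_approximations` as analogous to the number of random probes below is an assumption), the stdlib sketch estimates the trace of a small explicit symmetric matrix standing in for a Hessian:

```python
import random


def hutchinson_trace(matrix, num_samples, seed=0):
    """Estimate tr(A) as the mean of v^T A v over random Rademacher vectors v.

    Each v has independent entries drawn from {-1, +1}, so E[v^T A v] = tr(A).
    More samples reduce the variance of the estimate but cost more matrix-vector products.
    """
    rng = random.Random(seed)
    n = len(matrix)
    total = 0.0
    for _ in range(num_samples):
        v = [rng.choice((-1.0, 1.0)) for _ in range(n)]
        av = [sum(matrix[i][j] * v[j] for j in range(n)) for i in range(n)]  # A @ v
        total += sum(v[i] * av[i] for i in range(n))                        # v^T (A @ v)
    return total / num_samples


# A small symmetric matrix standing in for a Hessian; its exact trace is 4 + 3 + 2 = 9.
H = [[4.0, 1.0, 0.5],
     [1.0, 3.0, 0.2],
     [0.5, 0.2, 2.0]]

for k in (1, 10, 1000):
    print(k, hutchinson_trace(H, k))
```

With one sample, the off-diagonal terms can shift the estimate by up to 3.4 in this example; averaging over many samples brings it close to 9, at a proportionally higher cost.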

MCT’s structured pruning will target the first two dense layers, where output channel reduction can be propagated to subsequent layers by adjusting their input channels accordingly.
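To make the propagation concrete, here is a toy sketch with plain Python lists of what removing output channels means for two consecutive `nn.Linear` layers (PyTorch stores `weight` as `[out_features, in_features]`). `prune_linear_pair` is a hypothetical helper, and the boolean `keep` list is a simplified stand-in that is not necessarily the format of the masks MCT returns in `pruning_info`:

```python
def prune_linear_pair(w1, b1, w2, keep):
    """Remove output channels of one linear layer and the matching inputs of the next.

    w1:   first layer weights, shape [out1][in1] (nn.Linear layout)
    b1:   first layer biases, length out1
    w2:   next layer weights, shape [out2][out1]
    keep: booleans of length out1, True for channels that survive
    """
    kept = [i for i, k in enumerate(keep) if k]
    new_w1 = [w1[i] for i in kept]                   # drop rows = output channels
    new_b1 = [b1[i] for i in kept]                   # drop the matching biases
    new_w2 = [[row[i] for i in kept] for row in w2]  # drop columns = input channels
    return new_w1, new_b1, new_w2


# Toy layers: 2 inputs -> 4 hidden channels -> 3 outputs.
w1 = [[1, 2], [3, 4], [5, 6], [7, 8]]
b1 = [0.1, 0.2, 0.3, 0.4]
w2 = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
keep = [True, False, True, False]  # e.g., channels 1 and 3 had the lowest importance scores

new_w1, new_b1, new_w2 = prune_linear_pair(w1, b1, w2, keep)
print(new_w1)  # [[1, 2], [5, 6]]
print(new_b1)  # [0.1, 0.3]
print(new_w2)  # [[1, 3], [5, 7], [9, 11]]
```

The last layer of `FCModel` has no following layer to absorb such a change, and its 10 outputs are the class scores, which is why only the first two layers are pruned.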

The output is a pruned model along with pruning information, including layer-specific pruning masks and scores.

In [None]:
import model_compression_toolkit as mct
compression_ratio = 0.5

# Define Resource Utilization constraint for pruning. Each float32 parameter requires 4 bytes, hence we multiply the total parameter count by 4 to calculate the memory footprint.
target_resource_utilization = mct.core.ResourceUtilization(weights_memory=dense_model_params * 4 * compression_ratio)

# Define a pruning configuration
pruning_config=mct.pruning.PruningConfig(num_score_approximations=1)

# Prune the model
pruned_model, pruning_info = mct.pruning.pytorch_pruning_experimental(
    model=dense_model,
    target_resource_utilization=target_resource_utilization, 
    representative_data_gen=representative_dataset_gen, 
    pruning_config=pruning_config)
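To see what the `weights_memory` budget means in terms of channels, the stdlib sketch below computes the budget (109,386 parameters × 4 bytes × 0.5 = 218,772 bytes) and checks a few candidate channel counts for the two hidden layers. The `(k1, k2)` pairs are purely illustrative and are not necessarily the sizes MCT selects, since those depend on the importance scores; the parameter total is hard-coded here so the snippet runs on its own:

```python
BYTES_PER_FLOAT32 = 4
dense_params = 109386  # total parameter count of the dense FCModel
compression_ratio = 0.5
budget_bytes = dense_params * BYTES_PER_FLOAT32 * compression_ratio


def fc_model_params(k1: int, k2: int) -> int:
    """Parameter count of FCModel when the two hidden layers keep k1 and k2 channels."""
    layer1 = 28 * 28 * k1 + k1  # weights + biases of the first Linear
    layer2 = k1 * k2 + k2       # its inputs shrink with k1, its outputs with k2
    layer3 = k2 * 10 + 10       # the 10 class outputs are never pruned
    return layer1 + layer2 + layer3


for k1, k2 in [(128, 64), (64, 32), (68, 32)]:
    mem = fc_model_params(k1, k2) * BYTES_PER_FLOAT32
    print(k1, k2, mem, 'fits' if mem <= budget_bytes else 'exceeds budget')
```

Because the first layer dominates the parameter count, the budget is mostly a constraint on how many of its 128 channels survive.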

### Model after pruning
Let us inspect the model after pruning and check its accuracy. Pruning usually causes some degradation in accuracy, which the finetuning step below aims to recover.

In [None]:
pruned_model_nparams = display_model_params(pruned_model)
acc_before_finetuning = test(pruned_model, device, test_loader)
print(f'Pruned model accuracy before finetuning {acc_before_finetuning}%')

## Finetuning the Pruned Model
After pruning, we often need to finetune the model to recover any lost performance.

In [None]:
optimizer = optim.Adam(pruned_model.parameters(), lr=lr)
for epoch in range(1, epochs + 1):
    train(pruned_model, device, train_loader, optimizer, epoch)
    test(pruned_model, device, test_loader)

Finally, we can export the pruned model to ONNX:

In [None]:
mct.exporter.pytorch_export_model(pruned_model, save_model_path='pruned_model.onnx', repr_dataset=representative_dataset_gen)

## Conclusions
In this tutorial, we demonstrated how to train, prune, and finetune a neural network model using MCT. We began by setting up our environment and loading the dataset, then built and trained a fully connected neural network. Next, we applied structured pruning to the first two dense layers, targeting a 50% reduction of the model's weights memory footprint. Finally, we evaluated the pruned model and finetuned it to recover accuracy lost during pruning. This tutorial provided a hands-on approach to model optimization through pruning, illustrating the trade-off between model size and accuracy.

## Copyrights
Copyright 2024 Sony Semiconductor Solutions, Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
