# Pruning and quantization combined

<!-- ## Prerequisites

1. It is recommended to create a virtual environment

2. Configure basic environment dependencies

3. If you have a CUDA-supported GPU, add -->

## Preparation

1. Use the **CIFAR-10** dataset
>The CIFAR-10 dataset is a widely used benchmark in machine learning and computer vision, consisting of 60,000 32x32 color images divided into 10 categories, with 6,000 images in each category.

2. Use the pre-trained **ResNet-18** model
>The residual neural network (also known as a residual network or ResNet) is a pioneering deep learning model in which the weight layers learn residual functions with reference to the layer inputs. It was developed for image recognition in 2015 and won the ImageNet Large Scale Visual Recognition Challenge (ILSVRC) that year.
For a visualization of the ResNet-18 model structure and other information, please refer to the ResNet chapter of Li Mu's [d2l](https://d2l.ai/chapter_convolutional-modern/resnet.html).

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.utils.prune as prune
import torch.optim as optim


In [None]:
# Preprocessing pipeline for CIFAR-10: convert each image to a tensor, then
# normalize all three channels with mean 0.5 and std 0.5
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Training split: downloaded to ./data if needed and served in shuffled
# mini-batches of 4 by two worker processes
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=4,
    num_workers=2,
    shuffle=True,
)

# Test split: same preprocessing, but iterated in a fixed order
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transform,
    download=True,
)
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=4,
    num_workers=2,
    shuffle=False,
)

In [9]:
# Instantiate the torchvision ResNet-18 architecture.
# NOTE(review): no weights argument is passed here, so this appears to build
# a randomly initialized network rather than downloading pre-trained
# parameters -- confirm whether pre-trained weights were intended.
resnet18 = torchvision.models.resnet18()
# Show the ResNet-18 model structure as the notebook cell output
resnet18

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [11]:
def prune_model(model, pruning_rate=0.1):
    """Prune the weights of every Conv2d and Linear layer in *model* in place.

    Each matching layer gets unstructured L1-norm pruning on its ``weight``
    tensor, after which the pruning re-parametrization is removed so the
    layer is left with an ordinary ``weight`` parameter holding the pruned
    values.

    Args:
        model: Module whose sub-modules are scanned for prunable layers.
        pruning_rate: Value forwarded as ``amount`` to
            ``prune.l1_unstructured``.

    Returns:
        None. The model is modified in place.
    """
    prunable_types = (torch.nn.Conv2d, torch.nn.Linear)
    for layer in model.modules():
        if not isinstance(layer, prunable_types):
            continue
        # Mask out the lowest-magnitude entries of this layer's weights ...
        prune.l1_unstructured(layer, name='weight', amount=pruning_rate)
        # ... then bake the mask in so no `weight_orig`/`weight_mask` remain.
        prune.remove(layer, 'weight')

In [None]:
# Loss function and SGD optimizer (with momentum) over all ResNet-18 parameters
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(
    params=resnet18.parameters(),
    momentum=0.9,
    lr=0.001,
)

def train_model(model, epochs=10, prune_every_n_epochs=5):
    """Train *model* on ``trainloader`` and prune it periodically.

    Relies on the module-level ``trainloader``, ``optimizer`` and
    ``criterion``. The average loss is printed every 2000 mini-batches, and
    ``prune_model`` is applied with a pruning rate of 0.1 after every
    ``prune_every_n_epochs``-th epoch.

    Args:
        model: Network to train; it is updated and pruned in place.
        epochs: Number of passes over ``trainloader``.
        prune_every_n_epochs: Interval, in epochs, between pruning steps.
    """
    report_interval = 2000
    for epoch_no in range(1, epochs + 1):
        loss_sum = 0.0
        for step, (inputs, labels) in enumerate(trainloader, start=1):
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()

            # Report the mean loss of the last 2000 mini-batches, then reset
            if step % report_interval == 0:
                print(f'[{epoch_no}, {step}] loss: {loss_sum / report_interval}')
                loss_sum = 0.0

        # Only prune on epochs that are a multiple of the pruning interval
        if epoch_no % prune_every_n_epochs != 0:
            continue
        print(f'Pruning after epoch {epoch_no}')
        prune_model(model, pruning_rate=0.1)
        print('Pruning done.')

# Train with the defaults: 10 epochs, pruning after every 5th epoch
train_model(resnet18)

In [None]:
# Quantize the trained (and pruned) model to int8 with eager-mode
# post-training static quantization: attach a qconfig, insert observers with
# `prepare`, run calibration data through the model, then `convert`.
# Calling `convert` on a model that was never prepared leaves it un-quantized.
CALIBRATION_BATCHES = 100  # number of mini-batches used to collect activation statistics

resnet18.eval()  # calibration must use BatchNorm running statistics
resnet18.qconfig = torch.quantization.get_default_qconfig('fbgemm')
resnet18_prepared = torch.quantization.prepare(resnet18, inplace=False)

# Calibrate: forward a few batches so the observers record activation ranges
with torch.no_grad():
    for batch_idx, (inputs, _) in enumerate(trainloader):
        if batch_idx >= CALIBRATION_BATCHES:
            break
        resnet18_prepared(inputs)

resnet18_int8 = torch.quantization.convert(resnet18_prepared, inplace=False)

# Save the quantized model under the same file name that is loaded below
torch.save(resnet18_int8.state_dict(), 'resnet18_quantized.pth')

# Load the quantized model: rebuild the same quantized module structure
# (prepare + convert on a fresh network) so the state-dict keys match, then
# restore the saved quantized weights and quantization parameters.
# NOTE(review): the stock torchvision ResNet has no QuantStub/DeQuantStub and
# uses a plain float residual add, so running inference on float inputs with
# the converted model will likely need the quantization-ready variant
# (torchvision.models.quantization.resnet18) -- confirm before evaluating.
resnet18_int8_loaded = torchvision.models.resnet18()
resnet18_int8_loaded.eval()
resnet18_int8_loaded.qconfig = torch.quantization.get_default_qconfig('fbgemm')
resnet18_int8_loaded = torch.quantization.prepare(resnet18_int8_loaded)
resnet18_int8_loaded = torch.quantization.convert(resnet18_int8_loaded)
resnet18_int8_loaded.load_state_dict(torch.load('resnet18_quantized.pth'))
