# Pruning 
Pruning can significantly decrease model size and make inference faster and easier to deploy, especially when combined with quantization and knowledge distillation. Pruning can remove individual weights, neurons, channels, entire filters, or even whole layers. There are two types of pruning:

- Unstructured: prune individual weights (produces sparse weight tensors)
- Structured: remove entire filters/channels/layers (can be more hardware-friendly)
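To see the difference between the two types, here is a minimal sketch on a single toy `nn.Linear` layer (the layer sizes are arbitrary assumptions). Both layers start from identical weights and lose half of their entries. Unstructured pruning scatters the zeros across the matrix, while structured pruning (`prune.ln_structured` with `dim=0`) zeros whole output rows.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
unstructured = nn.Linear(6, 4)
structured = nn.Linear(6, 4)
structured.load_state_dict(unstructured.state_dict())  # identical starting weights

# Unstructured: zero the 50% of individual entries with the smallest |w|
prune.l1_unstructured(unstructured, name="weight", amount=0.5)

# Structured: zero the 50% of output rows (dim=0) with the smallest L2 norm
prune.ln_structured(structured, name="weight", amount=0.5, n=2, dim=0)

print(unstructured.weight_mask)  # zeros scattered across the 4x6 matrix
print(structured.weight_mask)    # whole rows are either all zeros or all ones
```

Both masks remove 12 of the 24 weights; only the pattern differs, and a row-shaped pattern is what hardware can exploit without special sparse kernels.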

We're going to prune ResNet-18 here. Because it has "res"idual connections, we need to be careful about how we do it: the output of each residual block is added to its input, so removing channels from one layer changes tensor shapes that must still match across the skip connection. We usually don't prune batch norm layers, since they are important for stability. We can apply structured pruning to conv layers and unstructured pruning to linear layers. Pruning typically reduces accuracy, so we may need to fine-tune the pruned model afterwards. Quantization can then follow to make the model even smaller.
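The per-layer policy above can be sketched as a loop over modules. This is a minimal sketch, not the method used later in this notebook (which applies unstructured pruning everywhere): `prune_by_layer_type`, its default amounts, and the tiny stand-in network are hypothetical. Note that `ln_structured` only zeros whole output channels; it does not shrink the tensors, so physically removing channels from a real ResNet would additionally require keeping the residual additions shape-consistent, which `torch.nn.utils.prune` does not handle.

```python
import torch.nn as nn
import torch.nn.utils.prune as prune


def prune_by_layer_type(model, conv_amount=0.3, linear_amount=0.5):
    """Hypothetical helper: structured pruning for convs, unstructured for linears."""
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            # Structured: zero whole output channels (dim=0) with the smallest L2 norm
            prune.ln_structured(module, name="weight", amount=conv_amount, n=2, dim=0)
        elif isinstance(module, nn.Linear):
            # Unstructured: zero individual weights with the smallest |w|
            prune.l1_unstructured(module, name="weight", amount=linear_amount)
        # BatchNorm layers (and everything else) are deliberately left untouched
    return model


# Tiny stand-in network (hypothetical) with conv, batch norm, and linear layers
toy = nn.Sequential(
    nn.Conv2d(3, 10, kernel_size=3, padding=1),
    nn.BatchNorm2d(10),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(10, 4),
)
prune_by_layer_type(toy)

# Count conv output channels whose weights are now entirely zero
zero_channels = (toy[0].weight.flatten(1).abs().sum(dim=1) == 0).sum().item()
print(f"zeroed conv channels: {zero_channels} of {toy[0].out_channels}")
```

The same loop could be run on the ResNet-18 `model` loaded below, since it only dispatches on layer type.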

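One thing to note about sparse models is that they're not necessarily faster. To exploit the sparsity you need "sparse kernels", or a post-training step that converts the model into a format a sparse backend can use. Unstructured pruning just sets weights to zero; it doesn't remove them, so the weight tensors keep their original shape and a dense kernel still multiplies by every zero. You need a backend that understands how to skip them.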
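As a quick check of that claim, the sketch below prunes 90% of a toy layer (the 256x256 size is an arbitrary assumption) and makes the pruning permanent with `prune.remove`. The weight is still a dense tensor with the same shape and the same number of bytes, so a dense matmul does exactly the same amount of work as before.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(256, 256)
shape_before = tuple(layer.weight.shape)
bytes_before = layer.weight.nelement() * layer.weight.element_size()

prune.l1_unstructured(layer, name="weight", amount=0.9)  # zero 90% of the weights
prune.remove(layer, "weight")  # bake the mask into layer.weight

sparsity = (layer.weight == 0).float().mean().item()
shape_after = tuple(layer.weight.shape)
bytes_after = layer.weight.nelement() * layer.weight.element_size()
print(f"sparsity={sparsity:.2f}, shape {shape_before} -> {shape_after}, "
      f"bytes {bytes_before} -> {bytes_after}")
```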

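A sparse kernel is a specialized operator that stores the weights in a compressed format, such as compressed sparse row (CSR), which keeps only the non-zero values and their positions, and skips computation on the zero weights. The PyTorch and ONNX ecosystems have compilers and runtimes that can help with that.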
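To make CSR concrete, here is a minimal NumPy sketch. The helpers `dense_to_csr` and `csr_matvec` are hypothetical illustrations, not a library API. CSR keeps three arrays: the non-zero values, their column indices, and a row-pointer array marking where each row's entries start. The matrix-vector product then loops only over stored entries. A pure-Python loop like this will not beat a dense BLAS call; it shows the data layout and the skipped work, while real sparse kernels do the same thing in compiled code. PyTorch's `Tensor.to_sparse_csr()` produces the equivalent arrays (`values`, `col_indices`, `crow_indices`).

```python
import numpy as np


def dense_to_csr(w):
    """Convert a 2-D dense array to CSR arrays (values, col_indices, row_ptr)."""
    w = np.asarray(w)
    values, col_indices, row_ptr = [], [], [0]
    for row in w:
        for j, v in enumerate(row):
            if v != 0:
                values.append(v)
                col_indices.append(j)
        row_ptr.append(len(values))  # end of this row's entries
    return (np.array(values, dtype=w.dtype),
            np.array(col_indices, dtype=np.int64),
            np.array(row_ptr, dtype=np.int64))


def csr_matvec(values, col_indices, row_ptr, x):
    """Compute y = W @ x, touching only the stored non-zero weights."""
    n_rows = len(row_ptr) - 1
    y = np.zeros(n_rows, dtype=values.dtype)
    for i in range(n_rows):
        start, end = row_ptr[i], row_ptr[i + 1]
        y[i] = np.dot(values[start:end], x[col_indices[start:end]])
    return y


w = np.array([[0.0, 2.0, 0.0, 0.0],
              [1.0, 0.0, 0.0, 3.0],
              [0.0, 0.0, 0.0, 0.0]])
x = np.array([1.0, 2.0, 3.0, 4.0])
values, cols, row_ptr = dense_to_csr(w)  # 3 stored weights instead of 12
y = csr_matvec(values, cols, row_ptr, x)
print(y, np.allclose(y, w @ x))
```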

We'll use PyTorch's `torch.nn.utils.prune` module to do the pruning in this notebook.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import random
import os 
import torch.nn.utils.prune as prune


import time 
import copy

In [3]:
# 1. Transforms for ResNet
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# 2. Load STL10 train split (for fine-tuning) and test split (for evaluation)
train_ds = datasets.STL10(root="../data", split="train", download=True, transform=transform)
test_ds  = datasets.STL10(root="../data", split="test",  download=True, transform=transform)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True,  num_workers=2)
test_loader  = DataLoader(test_ds,  batch_size=64, shuffle=False, num_workers=2)

num_labels = 10

device = 'cpu'

In [53]:
# Load teacher model
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 10)
model.load_state_dict(torch.load("../data/teacher_resnet.pth", map_location=device))

<All keys matched successfully>

In [15]:
model.eval()
correct = total = 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        preds = model(x).argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)

print(f"Accuracy on STL10 test set: {correct / total * 100:.2f}%")

Accuracy on STL10 test set: 92.58%


# Time for pruning 

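PyTorch doesn't overwrite the original weights when pruning. It moves them to a `weight_orig` parameter, stores a binary `weight_mask` buffer, and uses a forward pre-hook to recompute `weight = weight_orig * weight_mask` on every forward pass. To make the pruning permanent, we can bake the mask into the weights and drop `weight_orig` and `weight_mask` with the call below. The zeros stay in the dense tensor; nothing is physically removed.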

    prune.remove(module, "weight")
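The mask mechanics can be checked on a toy `nn.Linear` layer. This minimal sketch shows the reparametrization while the layer is pruned, then makes it permanent; `make_pruning_permanent` is a hypothetical helper that simply calls `prune.remove` on every module still carrying a `weight_mask`, so it could also be applied to the pruned ResNet.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)
prune.l1_unstructured(layer, name="weight", amount=0.5)

# While pruned: weight_orig is the trainable Parameter, weight_mask is a buffer,
# and layer.weight is recomputed as weight_orig * weight_mask.
print([n for n, _ in layer.named_parameters()])  # ['bias', 'weight_orig']
print([n for n, _ in layer.named_buffers()])     # ['weight_mask']
print(torch.equal(layer.weight, layer.weight_orig * layer.weight_mask))  # True


def make_pruning_permanent(model):
    """Hypothetical helper: call prune.remove on every module that has a weight mask."""
    for module in model.modules():
        if hasattr(module, "weight_mask"):
            prune.remove(module, "weight")
    return model


make_pruning_permanent(layer)
print([n for n, _ in layer.named_parameters()])  # ['bias', 'weight']
print(int((layer.weight == 0).sum()))            # 6: half of the 12 weights are still zero
```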


In [49]:
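def count_pruned_weights(model):
    """
    Count the weights covered by pruning masks and how many of them are masked out.
    Only pruned modules are counted, so biases and batch norm parameters are excluded.
    """
    total, zero = 0, 0
    for module in model.modules():
        if hasattr(module, "weight_mask") and hasattr(module, "weight_orig"):
            mask = module.weight_mask
            total += mask.numel()
            zero += (mask == 0).sum().item()
    return total, zero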


def apply_unstructured_pruning(model, amount=0.5):
    """
    Apply unstructured L1-norm pruning to the weights of every Conv2d and Linear layer.
    Pruning happens in place: the returned model is the same object that was passed in.
    """
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            prune.l1_unstructured(module, name='weight', amount=amount)
    return model

In [62]:
# Load teacher model
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 10)
model.load_state_dict(torch.load("../data/teacher_resnet.pth", map_location=device))

<All keys matched successfully>

In [64]:
model_pruned = apply_unstructured_pruning(model, amount=0.5)
total, zero = count_pruned_weights(model_pruned)
print(f"Total params: {total:,}")
print(f"Zeroed (pruned) params: {zero:,} ({100 * zero/total:.2f}%)")

Total params: 11,172,032
Zeroed (pruned) params: 5,586,016 (50.00%)


You can see that the pruned modules now have `weight_orig` and `weight_mask` attributes:

In [67]:
print(hasattr(model_pruned.layer1[0].conv1, "weight_orig"))   # True
print(hasattr(model_pruned.layer1[0].conv1, "weight_mask"))   # True

True
True
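Beyond the overall count, it can help to inspect sparsity layer by layer. The sketch below uses a hypothetical `layer_sparsity` helper and a toy model standing in for `model_pruned`. With `apply_unstructured_pruning`, every Conv2d and Linear layer gets the same amount, so each layer should report about 50%; a per-layer view becomes more informative when layers are pruned by different amounts.

```python
import torch.nn as nn
import torch.nn.utils.prune as prune


def layer_sparsity(model):
    """Map each pruned module's name to the fraction of its weight_mask that is zero."""
    return {
        name: (module.weight_mask == 0).float().mean().item()
        for name, module in model.named_modules()
        if hasattr(module, "weight_mask")
    }


# Hypothetical toy model standing in for model_pruned
toy = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
for m in toy.modules():
    if isinstance(m, nn.Linear):
        prune.l1_unstructured(m, name="weight", amount=0.5)

for name, frac in layer_sparsity(toy).items():
    print(f"{name}: {frac:.0%} of weights pruned")
```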


In [69]:
model_pruned.eval()
correct = total = 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        preds = model_pruned(x).argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)

print(f"Accuracy on STL10 test set: {correct / total * 100:.2f}%")

Accuracy on STL10 test set: 66.25%


Not surprisingly, the performance dropped. Let's fine-tune the pruned model (here only its final `fc` layer) and see whether that recovers some of the accuracy.

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model_pruned.fc.parameters(), lr=1e-3, momentum=0.9)

model_pruned.train()
for epoch in range(15):  
    total_loss = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model_pruned(x), y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}: Loss = {total_loss / len(train_loader):.4f}")


Epoch 1: Loss = 0.4014
Epoch 2: Loss = 0.3618
Epoch 3: Loss = 0.3361
Epoch 4: Loss = 0.3152
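The cell above only trains `model_pruned.fc`. A common alternative, which is not what this notebook does, is to fine-tune every layer, usually with a small learning rate. This sketch, on a hypothetical toy layer with random data, shows why that still preserves the pruning: the gradient reaching `weight_orig` is multiplied by the mask, and the forward pre-hook re-applies the mask, so pruned entries stay exactly zero.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(8, 3)
prune.l1_unstructured(layer, name="weight", amount=0.5)
mask_before = layer.weight_mask.clone()

# Optimize all parameters (weight_orig and bias), not just a classifier head
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1, momentum=0.9)
x, y = torch.randn(16, 8), torch.randint(0, 3, (16,))  # random toy data
for _ in range(5):
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(layer(x), y)
    loss.backward()
    optimizer.step()

pruned = mask_before == 0
print("grad on pruned entries is zero:", bool((layer.weight_orig.grad[pruned] == 0).all()))
print("pruned weights still zero:", bool((layer.weight[pruned] == 0).all()))
```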


In [None]:
total, zero = count_pruned_weights(model_pruned)
print(f"Total params: {total:,}")
print(f"Zeroed (pruned) params: {zero:,} ({100 * zero/total:.2f}%)")

In [None]:
model_pruned.eval()
correct = total = 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        preds = model_pruned(x).argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)

print(f"Accuracy on STL10 test set: {correct / total * 100:.2f}%")