In [1]:
from torchvision import datasets

In [2]:
train_dataset = datasets.MNIST(root="./datasets/", train=True, download=True)
test_dataset = datasets.MNIST(root="./datasets/", train=False, download=True)

In [3]:
from torchvision import transforms
train_val_dataset = datasets.MNIST(root="./datasets/", train=True, download=False, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root="./datasets", train=False, download=False, transform=transforms.ToTensor())

In [4]:
import torch
imgs = torch.stack([img for img, _ in train_val_dataset], dim =0)
mean = imgs.view(1, -1).mean(dim =1)
std = imgs.view(1, -1).std(dim =1)
mean, std

(tensor([0.1307]), tensor([0.3081]))

In [5]:
mnist_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
train_val_dataset = datasets.MNIST(root="./datasets/", train=True, download=False, transform=mnist_transforms)
test_dataset = datasets.MNIST(root="./datasets/", train=False, download=False, transform=mnist_transforms)

In [6]:
train_size = int(0.9 * len(train_val_dataset))
val_size = len(train_val_dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset=train_val_dataset, lengths=[train_size, val_size])
len(train_dataset), len(val_dataset), len(test_dataset)

(54000, 6000, 10000)

In [7]:
from torch.utils.data import DataLoader

BATCH_SIZE = 32

train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Let's see no of batches that we have now with the current batch-size
len(train_dataloader), len(val_dataloader), len(test_dataloader) 

(1688, 188, 313)

In [8]:
from torch import nn
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet5V1(nn.Module):
    def __init__(self):
        super(LeNet5V1, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

In [9]:
from torchmetrics import Accuracy
model_lenet5v1 = LeNet5V1()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model_lenet5v1.parameters(), lr = 0.001)
accuracy = Accuracy(task='multiclass', num_classes=10)

In [10]:
from tqdm.notebook import tqdm

# device-agnostic setup
device = 'cuda' if torch.cuda.is_available() else 'cpu'
accuracy = accuracy.to(device)
model_lenet5v1 = model_lenet5v1.to(device)

EPOCHS = 12

for epoch in tqdm(range(EPOCHS)):
    # Training loop
    train_loss, train_acc = 0.0, 0.0
    for X, y in train_dataloader:
        X, y = X.to(device), y.to(device)
        
        model_lenet5v1.train()
        
        y_pred = model_lenet5v1(X)
        
        loss = loss_fn(y_pred, y)
        train_loss += loss.item()
        
        acc = accuracy(y_pred, y)
        train_acc += acc
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    train_loss /= len(train_dataloader)
    train_acc /= len(train_dataloader)
        
    # Validation loop
    val_loss, val_acc = 0.0, 0.0
    model_lenet5v1.eval()
    with torch.inference_mode():
        for X, y in val_dataloader:
            X, y = X.to(device), y.to(device)
            
            y_pred = model_lenet5v1(X)
            
            loss = loss_fn(y_pred, y)
            val_loss += loss.item()
            
            acc = accuracy(y_pred, y)
            val_acc += acc
            
        val_loss /= len(val_dataloader)
        val_acc /= len(val_dataloader)
    
    print(f"Epoch: {epoch}| Train loss: {train_loss: .5f}| Train acc: {train_acc: .5f}| Val loss: {val_loss: .5f}| Val acc: {val_acc: .5f}")

  0%|          | 0/12 [00:00<?, ?it/s]

Epoch: 0| Train loss:  0.18226| Train acc:  0.94485| Val loss:  0.04386| Val acc:  0.98787
Epoch: 1| Train loss:  0.08170| Train acc:  0.97499| Val loss:  0.04052| Val acc:  0.98703
Epoch: 2| Train loss:  0.06159| Train acc:  0.98104| Val loss:  0.03223| Val acc:  0.99069
Epoch: 3| Train loss:  0.05199| Train acc:  0.98378| Val loss:  0.02978| Val acc:  0.99169
Epoch: 4| Train loss:  0.04244| Train acc:  0.98704| Val loss:  0.03371| Val acc:  0.99202
Epoch: 5| Train loss:  0.03875| Train acc:  0.98771| Val loss:  0.04058| Val acc:  0.99152
Epoch: 6| Train loss:  0.03186| Train acc:  0.98963| Val loss:  0.03737| Val acc:  0.99119
Epoch: 7| Train loss:  0.03265| Train acc:  0.98985| Val loss:  0.02814| Val acc:  0.99352
Epoch: 8| Train loss:  0.02709| Train acc:  0.99180| Val loss:  0.03764| Val acc:  0.99136
Epoch: 9| Train loss:  0.02488| Train acc:  0.99213| Val loss:  0.03597| Val acc:  0.99019
Epoch: 10| Train loss:  0.02497| Train acc:  0.99221| Val loss:  0.03810| Val acc:  0.9898

In [11]:
def count_params(model):
    total_params = 0
    for layer_names, param in model.named_parameters():
        total_params += torch.count_nonzero(param.data)
    return total_params

In [12]:
orig_params = count_params(model_lenet5v1)
print(f"Unpruned LeNet-5 model has {orig_params} trainable parameters")

Unpruned LeNet-5 model has 1199882 trainable parameters


In [13]:
for layer, param in model_lenet5v1.named_parameters():
    print(f"layer.name: {layer} & param.shape = {param.shape}")

layer.name: conv1.weight & param.shape = torch.Size([32, 1, 3, 3])
layer.name: conv1.bias & param.shape = torch.Size([32])
layer.name: conv2.weight & param.shape = torch.Size([64, 32, 3, 3])
layer.name: conv2.bias & param.shape = torch.Size([64])
layer.name: fc1.weight & param.shape = torch.Size([128, 9216])
layer.name: fc1.bias & param.shape = torch.Size([128])
layer.name: fc2.weight & param.shape = torch.Size([10, 128])
layer.name: fc2.bias & param.shape = torch.Size([10])


In [14]:
for layer_name in model_lenet5v1.state_dict().keys():
    print(layer_name, model_lenet5v1.state_dict()[layer_name].shape)

conv1.weight torch.Size([32, 1, 3, 3])
conv1.bias torch.Size([32])
conv2.weight torch.Size([64, 32, 3, 3])
conv2.bias torch.Size([64])
fc1.weight torch.Size([128, 9216])
fc1.bias torch.Size([128])
fc2.weight torch.Size([10, 128])
fc2.bias torch.Size([10])


In [15]:
model_lenet5v1.state_dict().keys()

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])

In [16]:
def compute_sparsity(model):
    conv1_sparsity = (torch.sum(model.conv1.weight == 0) / model.conv1.weight.nelement()) * 100
    conv2_sparsity = (torch.sum(model.conv2.weight == 0) / model.conv2.weight.nelement()) * 100
    fc1_sparsity = (torch.sum(model.fc1.weight == 0) / model.fc1.weight.nelement()) * 100
    op_sparsity = (torch.sum(model.fc2.weight == 0) / model.fc2.weight.nelement()) * 100

    num = torch.sum(model.conv1.weight == 0) + torch.sum(model.conv2.weight == 0) + torch.sum(model.fc1.weight == 0) + torch.sum(model.fc2.weight == 0)
    denom = model.conv1.weight.nelement() + model.conv2.weight.nelement() + model.fc1.weight.nelement() + model.fc2.weight.nelement()

    global_sparsity = num/denom * 100

    return global_sparsity

In [17]:
print(f"LeNet-5 global sparsity = {compute_sparsity(model_lenet5v1):.2f}%")

LeNet-5 global sparsity = 0.00%


In [18]:
import torch.nn.utils.prune as prune

for iter_prune_round in range(2):
    print(f"\n\nIterative Global pruning round = {iter_prune_round + 1}")
    
    # Prune layer-wise in a structured manner-
    prune.ln_structured(model_lenet5v1.conv1, name = "weight", amount = 0.1, n = 2, dim = 0)
    prune.ln_structured(model_lenet5v1.conv2, name = "weight", amount = 0.1, n = 2, dim = 0)
    prune.ln_structured(model_lenet5v1.fc1, name = "weight", amount = 0.1, n = 2, dim = 0)
    prune.ln_structured(model_lenet5v1.fc2, name = "weight", amount = 0.1, n = 2, dim = 0)

    # Print current global sparsity level-
    print(f"LeNet-5 global sparsity = {compute_sparsity(model_lenet5v1):.2f}%")
    
    
    # Fine-training loop-
    print("\nFine-tuning pruned model to recover model's performance\n")
    
    
    for epoch in range(EPOCHS):
        train_loss, train_acc = 0.0, 0.0
        for X, y in train_dataloader:
            X, y = X.to(device), y.to(device)
            
            model_lenet5v1.train()
            
            y_pred = model_lenet5v1(X)
            
            loss = loss_fn(y_pred, y)
            train_loss += loss.item()
            
            acc = accuracy(y_pred, y)
            train_acc += acc
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        train_loss /= len(train_dataloader)
        train_acc /= len(train_dataloader)
            
        # Validation loop
        val_loss, val_acc = 0.0, 0.0
        model_lenet5v1.eval()
        with torch.inference_mode():
            for X, y in val_dataloader:
                X, y = X.to(device), y.to(device)
                
                y_pred = model_lenet5v1(X)
                
                loss = loss_fn(y_pred, y)
                val_loss += loss.item()
                
                acc = accuracy(y_pred, y)
                val_acc += acc
                
            val_loss /= len(val_dataloader)
            val_acc /= len(val_dataloader)
        
        print(f"Epoch: {epoch}| Train loss: {train_loss: .5f}| Train acc: {train_acc: .5f}| Val loss: {val_loss: .5f}| Val acc: {val_acc: .5f}")




Iterative Global pruning round = 1
LeNet-5 global sparsity = 10.14%

Fine-tuning pruned model to recover model's performance

Epoch: 0| Train loss:  0.02905| Train acc:  0.99104| Val loss:  0.04188| Val acc:  0.99252
Epoch: 1| Train loss:  0.02143| Train acc:  0.99311| Val loss:  0.04164| Val acc:  0.99169
Epoch: 2| Train loss:  0.01939| Train acc:  0.99380| Val loss:  0.04199| Val acc:  0.99036
Epoch: 3| Train loss:  0.01914| Train acc:  0.99369| Val loss:  0.04699| Val acc:  0.99036
Epoch: 4| Train loss:  0.01786| Train acc:  0.99409| Val loss:  0.05110| Val acc:  0.99053
Epoch: 5| Train loss:  0.01639| Train acc:  0.99502| Val loss:  0.04721| Val acc:  0.99053
Epoch: 6| Train loss:  0.01603| Train acc:  0.99472| Val loss:  0.05711| Val acc:  0.99069
Epoch: 7| Train loss:  0.01778| Train acc:  0.99441| Val loss:  0.04669| Val acc:  0.99202
Epoch: 8| Train loss:  0.01610| Train acc:  0.99483| Val loss:  0.05230| Val acc:  0.99053
Epoch: 9| Train loss:  0.01462| Train acc:  0.99558| 

In [26]:
from pathlib import Path

MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents=True, exist_ok=True)

MODEL_NAME = "lenet5_v1_mnist_prune.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

# Saving the model
print(f"Saving the model: {MODEL_SAVE_PATH}")
torch.save(obj=model_lenet5v1.state_dict(), f=MODEL_SAVE_PATH)

# Loading the saved model
model_lenet5_v1_mnist_prune_loaded = LeNet5V1()
# Load the state dict
state_dict = torch.load(MODEL_SAVE_PATH)

# Remove the suffixes '_orig' and '_mask'
new_state_dict = {}
for key in state_dict:
    if '_orig' in key:
        new_key = key.replace('_orig', '')
        new_state_dict[new_key] = state_dict[key]
    # Skip '_mask' keys as they're not needed in normal models

# Save the new state dict
torch.save(new_state_dict, MODEL_SAVE_PATH)

# Load the cleaned state dict into your model
model_lenet5_v1_mnist_prune_loaded.load_state_dict(torch.load(MODEL_SAVE_PATH))

Saving the model: models\lenet5_v1_mnist_prune.pth


  state_dict = torch.load(MODEL_SAVE_PATH)
  model_lenet5_v1_mnist_prune_loaded.load_state_dict(torch.load(MODEL_SAVE_PATH))


RuntimeError: Error(s) in loading state_dict for LeNet5V1:
	Missing key(s) in state_dict: "conv1.bias", "conv2.bias", "fc1.bias", "fc2.bias". 

In [28]:
test_loss, test_acc = 0, 0

model_lenet5v1.to(device)

model_lenet5v1.eval()
with torch.inference_mode():
    for X, y in test_dataloader:
        X, y = X.to(device), y.to(device)
        y_pred = model_lenet5v1(X)
        
        test_loss += loss_fn(y_pred, y)
        test_acc += accuracy(y_pred, y)
        
    test_loss /= len(test_dataloader)
    test_acc /= len(test_dataloader)

print(f"Test loss: {test_loss: .5f}| Test acc: {test_acc: .5f}")

Test loss:  0.18088| Test acc:  0.89607
