In [2]:
from torchvision import datasets

In [3]:
# Fetch the MNIST train and test splits into ./datasets/ (download=True); no transform is attached here.
# Both names are rebound by later cells, so these objects mainly serve to trigger the download.
train_dataset = datasets.MNIST(root="./datasets/", train=True, download=True)
test_dataset = datasets.MNIST(root="./datasets/", train=False, download=True)

In [4]:
from torchvision import transforms
# Re-open both splits from disk (download=False) with the ToTensor transform attached.
# train_val_dataset is the copy the dataset-wide mean/std are computed from in the next cell.
train_val_dataset = datasets.MNIST(root="./datasets/", train=True, download=False, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root="./datasets", train=False, download=False, transform=transforms.ToTensor())

In [8]:
import torch
# Stack every training image into one tensor, then flatten it to a single row so
# that one reduction yields the dataset-wide mean and standard deviation
# (each result has shape [1]).
imgs = torch.stack([sample[0] for sample in train_val_dataset], dim=0)
all_pixels = imgs.view(1, -1)
mean = all_pixels.mean(dim=1)
std = all_pixels.std(dim=1)
mean, std

(tensor([0.1307]), tensor([0.3081]))

In [9]:
# Convert to tensor first, then normalise with the dataset-wide mean/std computed earlier.
mnist_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

# Both splits share the same location and preprocessing; only the `train` flag differs.
mnist_kwargs = dict(root="./datasets/", download=False, transform=mnist_transforms)
train_val_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

In [12]:
# Hold out 10% of the official training split for validation.
train_size = int(0.9 * len(train_val_dataset))
val_size = len(train_val_dataset) - train_size

# Seed the split through a dedicated generator so the same samples land in
# train/val on every run; otherwise validation metrics are not comparable
# across kernel restarts.
SPLIT_SEED = 42
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset=train_val_dataset,
    lengths=[train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
len(train_dataset), len(val_dataset), len(test_dataset)

(54000, 6000, 10000)

In [14]:
from torch.utils.data import DataLoader

BATCH_SIZE = 32

# Only the training loader is shuffled. Shuffling brings nothing to evaluation,
# and with per-batch metrics it changes which samples end up in the final
# (smaller) batch, so validation/test numbers would vary from run to run.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Number of batches per loader at the current batch size.
len(train_dataloader), len(val_dataloader), len(test_dataloader)

(1688, 188, 313)

In [16]:
from torch import nn
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet5V1(nn.Module):
    """Small convolutional classifier that outputs per-class log-probabilities.

    Architecture: two 3x3 convolutions (1 -> 32 -> 64 channels, stride 1, no
    padding), a 2x2 max-pool, dropout, a 128-unit hidden linear layer, dropout,
    and a 10-way linear output followed by ``log_softmax``.

    ``fc1`` expects 9216 flattened features, i.e. 64 channels of 12x12, which
    is what a 1x28x28 input produces after the two convolutions and the pool.
    """

    def __init__(self):
        super().__init__()
        # Creation order determines parameter registration order, and therefore
        # the state_dict key order.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Map a batch of images to log-probabilities of shape (batch, 10)."""
        # Convolutional feature extractor.
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))

        # Classifier head on the flattened feature maps (batch dim is kept).
        hidden = F.relu(self.fc1(features.flatten(start_dim=1)))
        scores = self.fc2(self.dropout2(hidden))
        return F.log_softmax(scores, dim=1)

In [18]:
from torchmetrics import Accuracy
model_lenet5v1 = LeNet5V1()

# LeNet5V1.forward already ends in log_softmax, so the model emits
# log-probabilities. NLLLoss is the criterion that consumes log-probabilities
# directly; CrossEntropyLoss expects raw scores and would apply log-softmax a
# second time (numerically a no-op here, but redundant and misleading).
loss_fn = nn.NLLLoss()

optimizer = torch.optim.Adam(params=model_lenet5v1.parameters(), lr = 0.001)

# Multiclass accuracy over the 10 digit classes.
accuracy = Accuracy(task='multiclass', num_classes=10)

In [20]:
from tqdm.notebook import tqdm

# Device-agnostic setup: use the GPU when one is available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
accuracy = accuracy.to(device)
model_lenet5v1 = model_lenet5v1.to(device)

EPOCHS = 12

for epoch in tqdm(range(EPOCHS)):
    # ---- Training pass ----
    # Switch to training mode once per epoch (enables dropout); the validation
    # pass below switches back to eval mode.
    model_lenet5v1.train()
    train_loss, train_acc, train_seen = 0.0, 0.0, 0
    for X, y in train_dataloader:
        X, y = X.to(device), y.to(device)
        batch_size = y.size(0)

        y_pred = model_lenet5v1(X)
        loss = loss_fn(y_pred, y)

        # Weight each batch by its size so the (smaller) final batch does not
        # count as much as a full one in the epoch average. `.item()` keeps the
        # accumulators as plain Python floats.
        train_loss += loss.item() * batch_size
        train_acc += accuracy(y_pred, y).item() * batch_size
        train_seen += batch_size

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # max(..., 1) avoids a ZeroDivisionError if the loader is empty.
    train_loss /= max(train_seen, 1)
    train_acc /= max(train_seen, 1)

    # ---- Validation pass ----
    model_lenet5v1.eval()
    val_loss, val_acc, val_seen = 0.0, 0.0, 0
    with torch.inference_mode():
        for X, y in val_dataloader:
            X, y = X.to(device), y.to(device)
            batch_size = y.size(0)

            y_pred = model_lenet5v1(X)

            val_loss += loss_fn(y_pred, y).item() * batch_size
            val_acc += accuracy(y_pred, y).item() * batch_size
            val_seen += batch_size

    val_loss /= max(val_seen, 1)
    val_acc /= max(val_seen, 1)

    print(f"Epoch: {epoch}| Train loss: {train_loss: .5f}| Train acc: {train_acc: .5f}| Val loss: {val_loss: .5f}| Val acc: {val_acc: .5f}")

  0%|          | 0/12 [00:00<?, ?it/s]

Epoch: 0| Train loss:  0.18371| Train acc:  0.94542| Val loss:  0.04981| Val acc:  0.98554
Epoch: 1| Train loss:  0.08385| Train acc:  0.97427| Val loss:  0.04055| Val acc:  0.98803
Epoch: 2| Train loss:  0.06440| Train acc:  0.98025| Val loss:  0.04389| Val acc:  0.98803
Epoch: 3| Train loss:  0.05160| Train acc:  0.98421| Val loss:  0.03794| Val acc:  0.98920
Epoch: 4| Train loss:  0.04361| Train acc:  0.98697| Val loss:  0.03645| Val acc:  0.98920
Epoch: 5| Train loss:  0.03724| Train acc:  0.98854| Val loss:  0.04338| Val acc:  0.98836
Epoch: 6| Train loss:  0.03686| Train acc:  0.98861| Val loss:  0.03385| Val acc:  0.98969
Epoch: 7| Train loss:  0.03017| Train acc:  0.99050| Val loss:  0.02706| Val acc:  0.99269
Epoch: 8| Train loss:  0.03038| Train acc:  0.99056| Val loss:  0.03002| Val acc:  0.99152
Epoch: 9| Train loss:  0.02760| Train acc:  0.99104| Val loss:  0.03457| Val acc:  0.99019
Epoch: 10| Train loss:  0.02399| Train acc:  0.99200| Val loss:  0.04935| Val acc:  0.9887

In [22]:
def count_params(model):
    """Return how many parameter entries of ``model`` are non-zero.

    Every tensor registered as a parameter of ``model`` contributes the number
    of its entries that differ from zero (``torch.count_nonzero``).

    Returns:
        A 0-dim integer tensor on the parameters' device, or the Python int
        ``0`` when the model has no parameters.

    NOTE(review): only tensors registered as parameters are inspected. If a
    pruning reparametrization keeps its masks outside the parameters, masked
    weights may not appear as zeros here - verify before relying on this for a
    pruned model.
    """
    return sum(torch.count_nonzero(tensor.data) for tensor in model.parameters())

In [24]:
# count_params (defined in an earlier cell) counts non-zero parameter entries, so the
# figure printed here matches the total parameter count only while no entry is exactly zero.
orig_params = count_params(model_lenet5v1)
print(f"Unpruned LeNet-5 model has {orig_params} trainable parameters")

Unpruned LeNet-5 model has 1199882 trainable parameters


In [26]:
# List every parameter tensor of the model by name together with its shape.
for layer, param in model_lenet5v1.named_parameters():
    print(f"layer.name: {layer} & param.shape = {param.shape}")

layer.name: conv1.weight & param.shape = torch.Size([32, 1, 3, 3])
layer.name: conv1.bias & param.shape = torch.Size([32])
layer.name: conv2.weight & param.shape = torch.Size([64, 32, 3, 3])
layer.name: conv2.bias & param.shape = torch.Size([64])
layer.name: fc1.weight & param.shape = torch.Size([128, 9216])
layer.name: fc1.bias & param.shape = torch.Size([128])
layer.name: fc2.weight & param.shape = torch.Size([10, 128])
layer.name: fc2.bias & param.shape = torch.Size([10])


In [28]:
# Walk the state_dict once and show each entry's name next to its shape.
for layer_name, tensor in model_lenet5v1.state_dict().items():
    print(layer_name, tensor.shape)

conv1.weight torch.Size([32, 1, 3, 3])
conv1.bias torch.Size([32])
conv2.weight torch.Size([64, 32, 3, 3])
conv2.bias torch.Size([64])
fc1.weight torch.Size([128, 9216])
fc1.bias torch.Size([128])
fc2.weight torch.Size([10, 128])
fc2.bias torch.Size([10])


In [30]:
# Display only the state_dict entry names.
model_lenet5v1.state_dict().keys()

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])

In [32]:
# (module, parameter-name) pairs handed to the pruning call: the weight tensor
# of each convolutional and linear layer. Biases are not listed.
parameters_to_prune = tuple(
    (getattr(model_lenet5v1, layer_attr), 'weight')
    for layer_attr in ('conv1', 'conv2', 'fc1', 'fc2')
)

In [34]:
def compute_sparsity(model, layer_names=("conv1", "conv2", "fc1", "fc2")):
    """Print the percentage of exactly-zero weights across the selected layers.

    The zero entries and total entries of ``<layer>.weight`` are summed over
    all layers in ``layer_names`` and reported as one global percentage.

    Args:
        model: Object exposing each name in ``layer_names`` as an attribute
            that has a ``weight`` tensor.
        layer_names: Attribute names of the layers to include. The default
            covers the four weighted layers this notebook prunes.

    Returns:
        None. The result is printed, not returned.
    """
    zero_count = 0
    total_count = 0
    for layer_name in layer_names:
        weight = getattr(model, layer_name).weight
        zero_count += int(torch.sum(weight == 0))
        total_count += weight.nelement()

    # An empty selection has no weights at all; report 0% instead of dividing by zero.
    global_sparsity = zero_count / total_count * 100 if total_count else 0.0

    print(f"Lenet 5 Global sparsity = {global_sparsity:.2f}%")
    return None

In [36]:
# Baseline: report the global sparsity before any pruning call has been made.
compute_sparsity(model_lenet5v1)

Lenet 5 Global sparsity = 0.00%


In [38]:
# Non-zero parameter count before pruning; the value comes back as a tensor rather than a plain int.
count_params(model_lenet5v1)

tensor(1199882, device='cuda:0')

In [40]:
# `amount` passed to prune.global_unstructured in each successive pruning round (one entry per round).
prune_rates_global = [0.2, 0.3, 0.4, 0.5, 0.6]

In [42]:
import torch.nn.utils.prune as prune

# One pruning round per configured rate: prune, report sparsity, then fine-tune.
# Iterating over the list itself keeps the number of rounds in sync with it.
for iter_prune_round, prune_rate in enumerate(prune_rates_global):
    print(f"\n\nIterative Global pruning round = {iter_prune_round + 1}")

    # Global *unstructured* magnitude pruning: individual weights from all
    # tensors in parameters_to_prune are ranked together by L1 magnitude and
    # the smallest ones are masked out.
    # NOTE(review): per the PyTorch pruning docs, pruning an already-pruned
    # tensor applies `amount` to the weights that are still unpruned, so the
    # sparsity reached after each round is cumulative rather than equal to
    # prune_rate - confirm this is the intended schedule.
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=prune_rate,
    )

    # Print current global sparsity level.
    compute_sparsity(model_lenet5v1)

    # Fine-tuning loop.
    print("\nFine-tuning pruned model to recover model's performance\n")

    for epoch in range(EPOCHS):
        # ---- Training pass ----
        # Switch to training mode once per epoch; the validation pass below
        # switches back to eval mode.
        model_lenet5v1.train()
        train_loss, train_acc, train_seen = 0.0, 0.0, 0
        for X, y in train_dataloader:
            X, y = X.to(device), y.to(device)
            batch_size = y.size(0)

            y_pred = model_lenet5v1(X)
            loss = loss_fn(y_pred, y)

            # Weight each batch by its size so a smaller final batch does not
            # count as much as a full one in the epoch average.
            train_loss += loss.item() * batch_size
            train_acc += accuracy(y_pred, y).item() * batch_size
            train_seen += batch_size

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # max(..., 1) avoids a ZeroDivisionError if the loader is empty.
        train_loss /= max(train_seen, 1)
        train_acc /= max(train_seen, 1)

        # ---- Validation pass ----
        model_lenet5v1.eval()
        val_loss, val_acc, val_seen = 0.0, 0.0, 0
        with torch.inference_mode():
            for X, y in val_dataloader:
                X, y = X.to(device), y.to(device)
                batch_size = y.size(0)

                y_pred = model_lenet5v1(X)

                val_loss += loss_fn(y_pred, y).item() * batch_size
                val_acc += accuracy(y_pred, y).item() * batch_size
                val_seen += batch_size

        val_loss /= max(val_seen, 1)
        val_acc /= max(val_seen, 1)

        print(f"Epoch: {epoch}| Train loss: {train_loss: .5f}| Train acc: {train_acc: .5f}| Val loss: {val_loss: .5f}| Val acc: {val_acc: .5f}")




Iterative Global pruning round = 1
Lenet 5 Global sparsity = 20.00%

Fine-tuning pruned model to recover model's performance

Epoch: 0| Train loss:  0.02198| Train acc:  0.99319| Val loss:  0.03269| Val acc:  0.99219
Epoch: 1| Train loss:  0.02060| Train acc:  0.99341| Val loss:  0.03065| Val acc:  0.99202
Epoch: 2| Train loss:  0.01814| Train acc:  0.99458| Val loss:  0.03360| Val acc:  0.99235
Epoch: 3| Train loss:  0.01826| Train acc:  0.99426| Val loss:  0.03459| Val acc:  0.99252
Epoch: 4| Train loss:  0.01794| Train acc:  0.99413| Val loss:  0.04571| Val acc:  0.99152
Epoch: 5| Train loss:  0.01593| Train acc:  0.99513| Val loss:  0.03782| Val acc:  0.99036
Epoch: 6| Train loss:  0.01826| Train acc:  0.99461| Val loss:  0.04089| Val acc:  0.99169
Epoch: 7| Train loss:  0.01572| Train acc:  0.99446| Val loss:  0.04173| Val acc:  0.99152
Epoch: 8| Train loss:  0.01528| Train acc:  0.99496| Val loss:  0.03508| Val acc:  0.99252
Epoch: 9| Train loss:  0.01436| Train acc:  0.99537| 