In [1]:
from torchvision import datasets

In [9]:
# Fetch the MNIST training and test splits under ./datasets/ (download=True).
# No transform is attached at this point.
train_dataset, test_dataset = (
    datasets.MNIST(root="./datasets/", train=is_train, download=True)
    for is_train in (True, False)
)

In [11]:
from torchvision import transforms
# Re-open both splits from the local copy (download=False), this time with a
# ToTensor transform attached to every sample.
train_val_dataset = datasets.MNIST(
    root="./datasets/",
    train=True,
    download=False,
    transform=transforms.ToTensor(),
)
test_dataset = datasets.MNIST(
    root="./datasets",
    train=False,
    download=False,
    transform=transforms.ToTensor(),
)

In [13]:
import torch
# Stack the image tensor of every training sample along a new leading axis, then
# flatten everything into a single row so mean/std are taken over all pixel
# values at once (each result is a 1-element tensor).
imgs = torch.stack([sample[0] for sample in train_val_dataset])
mean = imgs.view(1, -1).mean(dim=1)
std = imgs.view(1, -1).std(dim=1)
mean, std

(tensor([0.1307]), tensor([0.3081]))

In [15]:
# Per-sample preprocessing: convert to a tensor, then normalise with the
# `mean` / `std` tensors computed from the training images.
mnist_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

# Rebuild both splits so that they use the normalising pipeline.
train_val_dataset, test_dataset = (
    datasets.MNIST(root="./datasets/", train=is_train, download=False, transform=mnist_transforms)
    for is_train in (True, False)
)

In [17]:
# Keep 90% of the training data for fitting and hold out the rest for validation.
train_size = int(0.9 * len(train_val_dataset))
val_size = len(train_val_dataset) - train_size

# Seed the split so the same samples land in train/val on every run; with an
# unseeded split the validation numbers are not comparable between runs.
SPLIT_SEED = 42
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset=train_val_dataset,
    lengths=[train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
len(train_dataset), len(val_dataset), len(test_dataset)

(54000, 6000, 10000)

In [19]:
from torch.utils.data import DataLoader

BATCH_SIZE = 32

# Only the training loader is shuffled. Evaluation does not need a random order,
# and a fixed order keeps metrics that are averaged per batch repeatable (the
# final batch can be smaller than BATCH_SIZE, so its composition matters).
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Number of batches each loader yields at the current batch size
len(train_dataloader), len(val_dataloader), len(test_dataloader)

(1688, 188, 313)

In [21]:
from torch import nn
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet5V1(nn.Module):
    """Small convolutional classifier for single-channel images.

    Two 3x3 convolutions (1 -> 32 -> 64 channels, stride 1, no padding) are
    followed by a 2x2 max-pool, dropout, and two fully connected layers
    (9216 -> 128 -> 10). ``forward`` returns log-softmax scores over the 10
    outputs.

    ``fc1`` expects 9216 = 64 * 12 * 12 input features, which is what the
    convolution/pooling stack produces for 28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        # Dropout after pooling and after the hidden dense layer.
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # Classifier head.
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Map a batch of images to per-class log-probabilities (softmax over dim 1)."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))

        # Flatten everything except the batch dimension before the dense layers.
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)

In [23]:
from torchmetrics import Accuracy
# Build the network together with the loss, optimizer and metric used for training.
model_lenet5v1 = LeNet5V1()

# NOTE: LeNet5V1.forward already returns log-softmax scores; CrossEntropyLoss
# applies log-softmax once more, which leaves log-probabilities unchanged.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_lenet5v1.parameters(), lr=0.001)
accuracy = Accuracy(num_classes=10, task='multiclass')

In [25]:
from tqdm.notebook import tqdm

# Device-agnostic setup: run on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
accuracy = accuracy.to(device)
model_lenet5v1 = model_lenet5v1.to(device)

EPOCHS = 12

for epoch in tqdm(range(EPOCHS)):
    # ---- Training pass ----
    # Switch to train mode once per epoch: the validation pass below leaves the
    # model in eval mode, and nothing inside the batch loop changes the mode.
    model_lenet5v1.train()
    train_loss, train_acc = 0.0, 0.0
    for X, y in train_dataloader:
        X, y = X.to(device), y.to(device)

        y_pred = model_lenet5v1(X)

        loss = loss_fn(y_pred, y)
        train_loss += loss.item()

        acc = accuracy(y_pred, y)
        train_acc += acc.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Report the mean of the per-batch values.
    train_loss /= len(train_dataloader)
    train_acc /= len(train_dataloader)

    # ---- Validation pass (no gradient tracking) ----
    val_loss, val_acc = 0.0, 0.0
    model_lenet5v1.eval()
    with torch.inference_mode():
        for X, y in val_dataloader:
            X, y = X.to(device), y.to(device)

            y_pred = model_lenet5v1(X)

            loss = loss_fn(y_pred, y)
            val_loss += loss.item()

            acc = accuracy(y_pred, y)
            val_acc += acc.item()

    val_loss /= len(val_dataloader)
    val_acc /= len(val_dataloader)

    print(f"Epoch: {epoch}| Train loss: {train_loss: .5f}| Train acc: {train_acc: .5f}| Val loss: {val_loss: .5f}| Val acc: {val_acc: .5f}")

  0%|          | 0/12 [00:00<?, ?it/s]

Epoch: 0| Train loss:  0.19825| Train acc:  0.94059| Val loss:  0.07203| Val acc:  0.97955
Epoch: 1| Train loss:  0.08707| Train acc:  0.97367| Val loss:  0.05336| Val acc:  0.98437
Epoch: 2| Train loss:  0.07069| Train acc:  0.97954| Val loss:  0.05107| Val acc:  0.98703
Epoch: 3| Train loss:  0.05812| Train acc:  0.98306| Val loss:  0.05112| Val acc:  0.98604
Epoch: 4| Train loss:  0.04928| Train acc:  0.98582| Val loss:  0.04461| Val acc:  0.98903
Epoch: 5| Train loss:  0.04234| Train acc:  0.98699| Val loss:  0.04914| Val acc:  0.98903
Epoch: 6| Train loss:  0.03798| Train acc:  0.98778| Val loss:  0.04580| Val acc:  0.98803
Epoch: 7| Train loss:  0.03255| Train acc:  0.98971| Val loss:  0.05236| Val acc:  0.98836
Epoch: 8| Train loss:  0.03072| Train acc:  0.99061| Val loss:  0.05877| Val acc:  0.98787
Epoch: 9| Train loss:  0.02961| Train acc:  0.99032| Val loss:  0.05240| Val acc:  0.98886
Epoch: 10| Train loss:  0.02716| Train acc:  0.99154| Val loss:  0.04985| Val acc:  0.9900

In [27]:
def count_params(model):
    """Count the non-zero parameter values of ``model``.

    Modules pruned with ``torch.nn.utils.prune`` expose the dense tensor as
    ``<name>_orig`` (a parameter) and the pruning mask as ``<name>_mask`` (a
    buffer). For those, the mask is applied before counting so that pruned
    entries are not reported as non-zero.

    Args:
        model: an ``nn.Module``.

    Returns:
        int: total number of non-zero entries over all (effective) parameters.
    """
    masks = dict(model.named_buffers())
    total_params = 0
    for name, param in model.named_parameters():
        values = param.detach()
        if name.endswith("_orig"):
            # Pruned parameter: the effective weight is <name>_orig * <name>_mask.
            mask = masks.get(name[: -len("_orig")] + "_mask")
            if mask is not None:
                values = values * mask
        # Convert to a plain int so the result formats and compares cleanly.
        total_params += int(torch.count_nonzero(values))
    return total_params

In [29]:
# Record the parameter count of the model before any pruning as a baseline.
# NOTE(review): count_params tallies non-zero entries of every parameter, so the
# "trainable parameters" wording is only exact while no value is exactly zero.
orig_params = count_params(model_lenet5v1)
print(f"Unpruned LeNet-5 model has {orig_params} trainable parameters")

Unpruned LeNet-4 model has 1199882 trainable parameters


In [31]:
# List every registered parameter by name together with its shape.
for layer, param in model_lenet5v1.named_parameters():
    print(f"layer.name: {layer} & param.shape = {param.shape}")

layer.name: conv1.weight & param.shape = torch.Size([32, 1, 3, 3])
layer.name: conv1.bias & param.shape = torch.Size([32])
layer.name: conv2.weight & param.shape = torch.Size([64, 32, 3, 3])
layer.name: conv2.bias & param.shape = torch.Size([64])
layer.name: fc1.weight & param.shape = torch.Size([128, 9216])
layer.name: fc1.bias & param.shape = torch.Size([128])
layer.name: fc2.weight & param.shape = torch.Size([10, 128])
layer.name: fc2.bias & param.shape = torch.Size([10])


In [33]:
# Same listing, this time driven by the state_dict entries (name -> tensor).
for layer_name, tensor in model_lenet5v1.state_dict().items():
    print(layer_name, tensor.shape)

conv1.weight torch.Size([32, 1, 3, 3])
conv1.bias torch.Size([32])
conv2.weight torch.Size([64, 32, 3, 3])
conv2.bias torch.Size([64])
fc1.weight torch.Size([128, 9216])
fc1.bias torch.Size([128])
fc2.weight torch.Size([10, 128])
fc2.bias torch.Size([10])


In [47]:
# Display the state_dict entry names in registration order.
model_lenet5v1.state_dict().keys()

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])

In [65]:
def compute_sparsity(model):
    """Return the percentage of exactly-zero weights across the model's weight tensors.

    Only the ``weight`` tensors of ``conv1``, ``conv2``, ``fc1`` and ``fc2`` are
    inspected; biases are not counted.

    Args:
        model: object exposing ``conv1``, ``conv2``, ``fc1`` and ``fc2``
            attributes, each with a ``weight`` tensor.

    Returns:
        0-dim tensor holding the global sparsity in percent (0-100).
    """
    layers = (model.conv1, model.conv2, model.fc1, model.fc2)

    # Pool zero counts and element counts over all layers before dividing, so
    # each layer contributes in proportion to its size.
    zero_count = sum(torch.sum(layer.weight == 0) for layer in layers)
    total_count = sum(layer.weight.nelement() for layer in layers)

    return zero_count / total_count * 100

In [67]:
# Report the share of exactly-zero weights before any pruning is applied.
print(f"LeNet-5 global sparsity = {compute_sparsity(model_lenet5v1):.2f}%")

LeNet-5 global sparsity = 0.00%


In [69]:
import torch.nn.utils.prune as prune

# Pruning schedule: every round applies structured L2-norm pruning to each layer
# below and then fine-tunes the whole model for EPOCHS epochs.
PRUNE_ROUNDS = 10
PRUNE_AMOUNT = 0.1  # fraction passed to prune.ln_structured as `amount`

for iter_prune_round in range(PRUNE_ROUNDS):
    print(f"\n\nIterative Global pruning round = {iter_prune_round + 1}")

    # Prune layer-wise in a structured manner: whole slices along dim 0 (output
    # channels for the conv layers, output rows for the linear layers) are
    # selected by their L2 norm (n=2).
    # NOTE(review): the banner above says "Global", but each layer is pruned
    # independently with the same amount.
    # NOTE(review): dim=0 on fc2 zeroes entire rows of the 10-row output layer,
    # i.e. the weights of whole classes - confirm that pruning it is intended.
    for prunable_layer in (
        model_lenet5v1.conv1,
        model_lenet5v1.conv2,
        model_lenet5v1.fc1,
        model_lenet5v1.fc2,
    ):
        prune.ln_structured(prunable_layer, name="weight", amount=PRUNE_AMOUNT, n=2, dim=0)

    # Print current global sparsity level
    print(f"LeNet-5 global sparsity = {compute_sparsity(model_lenet5v1):.2f}%")

    # Fine-tuning loop
    print("\nFine-tuning pruned model to recover model's performance\n")

    for epoch in range(EPOCHS):
        # ---- Training pass ----
        # Switch to train mode once per epoch: the validation pass below leaves
        # the model in eval mode, and nothing in the batch loop changes the mode.
        model_lenet5v1.train()
        train_loss, train_acc = 0.0, 0.0
        for X, y in train_dataloader:
            X, y = X.to(device), y.to(device)

            y_pred = model_lenet5v1(X)

            loss = loss_fn(y_pred, y)
            train_loss += loss.item()

            acc = accuracy(y_pred, y)
            train_acc += acc.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Report the mean of the per-batch values.
        train_loss /= len(train_dataloader)
        train_acc /= len(train_dataloader)

        # ---- Validation pass (no gradient tracking) ----
        val_loss, val_acc = 0.0, 0.0
        model_lenet5v1.eval()
        with torch.inference_mode():
            for X, y in val_dataloader:
                X, y = X.to(device), y.to(device)

                y_pred = model_lenet5v1(X)

                loss = loss_fn(y_pred, y)
                val_loss += loss.item()

                acc = accuracy(y_pred, y)
                val_acc += acc.item()

        val_loss /= len(val_dataloader)
        val_acc /= len(val_dataloader)

        print(f"Epoch: {epoch}| Train loss: {train_loss: .5f}| Train acc: {train_acc: .5f}| Val loss: {val_loss: .5f}| Val acc: {val_acc: .5f}")




Iterative Global pruning round = 1
LeNet-5 global sparsity = 10.14%

Fine-tuning pruned model to recover model's performance

Epoch: 0| Train loss:  0.03098| Train acc:  0.99024| Val loss:  0.04797| Val acc:  0.99069
Epoch: 1| Train loss:  0.02270| Train acc:  0.99265| Val loss:  0.05705| Val acc:  0.99003
Epoch: 2| Train loss:  0.02143| Train acc:  0.99271| Val loss:  0.05703| Val acc:  0.98953
Epoch: 3| Train loss:  0.01961| Train acc:  0.99382| Val loss:  0.05070| Val acc:  0.99069
Epoch: 4| Train loss:  0.02046| Train acc:  0.99332| Val loss:  0.06339| Val acc:  0.99019
Epoch: 5| Train loss:  0.02047| Train acc:  0.99335| Val loss:  0.06260| Val acc:  0.99019
Epoch: 6| Train loss:  0.01653| Train acc:  0.99454| Val loss:  0.05469| Val acc:  0.98986
Epoch: 7| Train loss:  0.01741| Train acc:  0.99445| Val loss:  0.05249| Val acc:  0.99019
Epoch: 8| Train loss:  0.01651| Train acc:  0.99496| Val loss:  0.05503| Val acc:  0.99086
Epoch: 9| Train loss:  0.01754| Train acc:  0.99450| 

In [None]:
pl