In [1]:
from torchvision import datasets

In [2]:
train_dataset = datasets.MNIST(root="./datasets/", train=True, download=True)
test_dataset = datasets.MNIST(root="./datasets/", train=False, download=True)

In [3]:
from torchvision import transforms
train_val_dataset = datasets.MNIST(root="./datasets/", train=True, download=False, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root="./datasets", train=False, download=False, transform=transforms.ToTensor())

In [4]:
import torch
imgs = torch.stack([img for img, _ in train_val_dataset], dim =0)
mean = imgs.view(1, -1).mean(dim =1)
std = imgs.view(1, -1).std(dim =1)
mean, std

(tensor([0.1307]), tensor([0.3081]))

In [5]:
mnist_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
train_val_dataset = datasets.MNIST(root="./datasets/", train=True, download=False, transform=mnist_transforms)
test_dataset = datasets.MNIST(root="./datasets/", train=False, download=False, transform=mnist_transforms)

In [6]:
train_size = int(0.9 * len(train_val_dataset))
val_size = len(train_val_dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset=train_val_dataset, lengths=[train_size, val_size])
len(train_dataset), len(val_dataset), len(test_dataset)

(54000, 6000, 10000)

In [7]:
from torch.utils.data import DataLoader

BATCH_SIZE = 32

train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Let's see no of batches that we have now with the current batch-size
len(train_dataloader), len(val_dataloader), len(test_dataloader) 

(1688, 188, 313)

In [8]:
from torch import nn
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet5V1(nn.Module):
    def __init__(self):
        super(LeNet5V1, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

In [22]:
from torchmetrics import Accuracy
model_lenet5v1 = LeNet5V1()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model_lenet5v1.parameters(), lr = 0.01)
accuracy = Accuracy(task='multiclass', num_classes=10)

In [10]:
# from tqdm.notebook import tqdm

# # device-agnostic setup
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# accuracy = accuracy.to(device)
# model_lenet5v1 = model_lenet5v1.to(device)

# EPOCHS = 12

# for epoch in tqdm(range(EPOCHS)):
#     # Training loop
#     train_loss, train_acc = 0.0, 0.0
#     for X, y in train_dataloader:
#         X, y = X.to(device), y.to(device)
        
#         model_lenet5v1.train()
        
#         y_pred = model_lenet5v1(X)
        
#         loss = loss_fn(y_pred, y)
#         train_loss += loss.item()
        
#         acc = accuracy(y_pred, y)
#         train_acc += acc
        
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
        
#     train_loss /= len(train_dataloader)
#     train_acc /= len(train_dataloader)
        
#     # Validation loop
#     val_loss, val_acc = 0.0, 0.0
#     model_lenet5v1.eval()
#     with torch.inference_mode():
#         for X, y in val_dataloader:
#             X, y = X.to(device), y.to(device)
            
#             y_pred = model_lenet5v1(X)
            
#             loss = loss_fn(y_pred, y)
#             val_loss += loss.item()
            
#             acc = accuracy(y_pred, y)
#             val_acc += acc
            
#         val_loss /= len(val_dataloader)
#         val_acc /= len(val_dataloader)
    
#     print(f"Epoch: {epoch}| Train loss: {train_loss: .5f}| Train acc: {train_acc: .5f}| Val loss: {val_loss: .5f}| Val acc: {val_acc: .5f}")

In [24]:
def count_params(model):
    total_params = 0
    for layer_names, param in model.named_parameters():
        total_params += torch.count_nonzero(param.data)
    return total_params

In [26]:
orig_params = count_params(model_lenet5v1)
print(f"Unpruned LeNet-5 model has {orig_params} trainable parameters")

Unpruned LeNet-5 model has 1199881 trainable parameters


In [28]:
for layer, param in model_lenet5v1.named_parameters():
    print(f"layer.name: {layer} & param.shape = {param.shape}")

layer.name: conv1.weight & param.shape = torch.Size([32, 1, 3, 3])
layer.name: conv1.bias & param.shape = torch.Size([32])
layer.name: conv2.weight & param.shape = torch.Size([64, 32, 3, 3])
layer.name: conv2.bias & param.shape = torch.Size([64])
layer.name: fc1.weight & param.shape = torch.Size([128, 9216])
layer.name: fc1.bias & param.shape = torch.Size([128])
layer.name: fc2.weight & param.shape = torch.Size([10, 128])
layer.name: fc2.bias & param.shape = torch.Size([10])


In [30]:
for layer_name in model_lenet5v1.state_dict().keys():
    print(layer_name, model_lenet5v1.state_dict()[layer_name].shape)

conv1.weight torch.Size([32, 1, 3, 3])
conv1.bias torch.Size([32])
conv2.weight torch.Size([64, 32, 3, 3])
conv2.bias torch.Size([64])
fc1.weight torch.Size([128, 9216])
fc1.bias torch.Size([128])
fc2.weight torch.Size([10, 128])
fc2.bias torch.Size([10])


In [32]:
model_lenet5v1.state_dict().keys()

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])

In [34]:
def compute_sparsity(model):
    conv1_sparsity = (torch.sum(model.conv1.weight == 0) / model.conv1.weight.nelement()) * 100
    conv2_sparsity = (torch.sum(model.conv2.weight == 0) / model.conv2.weight.nelement()) * 100
    fc1_sparsity = (torch.sum(model.fc1.weight == 0) / model.fc1.weight.nelement()) * 100
    op_sparsity = (torch.sum(model.fc2.weight == 0) / model.fc2.weight.nelement()) * 100

    num = torch.sum(model.conv1.weight == 0) + torch.sum(model.conv2.weight == 0) + torch.sum(model.fc1.weight == 0) + torch.sum(model.fc2.weight == 0)
    denom = model.conv1.weight.nelement() + model.conv2.weight.nelement() + model.fc1.weight.nelement() + model.fc2.weight.nelement()

    global_sparsity = num/denom * 100

    return global_sparsity

In [36]:
print(f"LeNet-5 global sparsity = {compute_sparsity(model_lenet5v1):.2f}%")

LeNet-5 global sparsity = 0.00%


In [38]:
import torch.nn.utils.prune as prune
from tqdm.notebook import tqdm


device = 'cuda' if torch.cuda.is_available() else 'cpu'
accuracy = accuracy.to(device)
model_lenet5v1 = model_lenet5v1.to(device)

EPOCHS = 12

for iter_prune_round in range(2):
    print(f"\n\nIterative Global pruning round = {iter_prune_round + 1}")
    
    # Prune layer-wise in a structured manner-
    prune.ln_structured(model_lenet5v1.conv1, name = "weight", amount = 0.1, n = 2, dim = 0)
    prune.ln_structured(model_lenet5v1.conv2, name = "weight", amount = 0.1, n = 2, dim = 0)
    prune.ln_structured(model_lenet5v1.fc1, name = "weight", amount = 0.1, n = 2, dim = 0)
    prune.ln_structured(model_lenet5v1.fc2, name = "weight", amount = 0.1, n = 2, dim = 0)

    # Print current global sparsity level-
    print(f"LeNet-5 global sparsity = {compute_sparsity(model_lenet5v1):.2f}%")
    
    
    # Fine-training loop-
    print("\nFine-tuning pruned model to recover model's performance\n")
    
    
    for epoch in range(EPOCHS):
        train_loss, train_acc = 0.0, 0.0
        for X, y in train_dataloader:
            X, y = X.to(device), y.to(device)
            
            model_lenet5v1.train()
            
            y_pred = model_lenet5v1(X)
            
            loss = loss_fn(y_pred, y)
            train_loss += loss.item()
            
            acc = accuracy(y_pred, y)
            train_acc += acc
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        train_loss /= len(train_dataloader)
        train_acc /= len(train_dataloader)
            
        # Validation loop
        val_loss, val_acc = 0.0, 0.0
        model_lenet5v1.eval()
        with torch.inference_mode():
            for X, y in val_dataloader:
                X, y = X.to(device), y.to(device)
                
                y_pred = model_lenet5v1(X)
                
                loss = loss_fn(y_pred, y)
                val_loss += loss.item()
                
                acc = accuracy(y_pred, y)
                val_acc += acc
                
            val_loss /= len(val_dataloader)
            val_acc /= len(val_dataloader)
        
        print(f"Epoch: {epoch}| Train loss: {train_loss: .5f}| Train acc: {train_acc: .5f}| Val loss: {val_loss: .5f}| Val acc: {val_acc: .5f}")




Iterative Global pruning round = 1
LeNet-5 global sparsity = 10.14%

Fine-tuning pruned model to recover model's performance

Epoch: 0| Train loss:  0.64788| Train acc:  0.79617| Val loss:  0.23024| Val acc:  0.93068
Epoch: 1| Train loss:  0.42809| Train acc:  0.86752| Val loss:  0.21763| Val acc:  0.93251
Epoch: 2| Train loss:  0.40774| Train acc:  0.87554| Val loss:  0.22863| Val acc:  0.94182
Epoch: 3| Train loss:  0.38819| Train acc:  0.88328| Val loss:  0.19034| Val acc:  0.94249
Epoch: 4| Train loss:  0.37631| Train acc:  0.88585| Val loss:  0.17694| Val acc:  0.94914
Epoch: 5| Train loss:  0.38099| Train acc:  0.88603| Val loss:  0.17760| Val acc:  0.94232
Epoch: 6| Train loss:  0.36187| Train acc:  0.89242| Val loss:  0.19060| Val acc:  0.94332
Epoch: 7| Train loss:  0.35817| Train acc:  0.89261| Val loss:  0.17934| Val acc:  0.94764
Epoch: 8| Train loss:  0.35484| Train acc:  0.89436| Val loss:  0.17289| Val acc:  0.94747
Epoch: 9| Train loss:  0.35069| Train acc:  0.89333| 

In [20]:
from pathlib import Path

MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents=True, exist_ok=True)

MODEL_NAME = "lenet5_v1_mnist_prune.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME


# Saving the model
print(f"Saving the model: {MODEL_SAVE_PATH}")
torch.save(obj=model_lenet5v1.state_dict(), f=MODEL_SAVE_PATH)

# Loading the saved model
model_lenet5_v1_mnist_prune_loaded = LeNet5V1()



# Load the cleaned state dict into your model
model_lenet5_v1_mnist_prune_loaded.load_state_dict(torch.load(MODEL_SAVE_PATH))

OrderedDict({'conv1.bias': tensor([-0.1765, -0.2138, -0.2956, -0.4034, -0.3447, -0.4388, -0.3617, -0.3391,
        -0.2558, -0.1817, -0.0759, -0.2134, -0.3188, -0.2356, -0.2306, -0.4459,
        -0.1513, -0.1414, -0.2197, -0.1255, -0.4387, -0.1542, -0.6356, -0.0012,
        -0.3279, -0.2402, -0.0015, -0.1514, -0.4706, -0.3682, -0.3107, -0.3700],
       device='cuda:0'), 'conv1.weight_orig': tensor([[[[-0.0638,  0.2512,  0.0241],
          [-0.0077, -0.0086,  0.2295],
          [-0.2666, -0.2892,  0.1790]]],


        [[[ 0.1059,  0.2335, -0.4714],
          [-0.2065,  0.3846, -0.2992],
          [-0.3419,  0.1817,  0.0993]]],


        [[[ 0.2011, -0.2696,  0.3057],
          [-0.0553, -0.4592,  0.2178],
          [ 0.2844, -0.1168, -0.0028]]],


        [[[ 0.1878, -0.0242, -0.0665],
          [ 0.1380, -0.1186, -0.0161],
          [-0.2211, -0.2948,  0.3633]]],


        [[[ 0.2028, -0.1597,  0.3027],
          [-0.3570,  0.1262, -0.3636],
          [ 0.1878, -0.1347,  0.1959]]],


 

In [None]:
test_loss, test_acc = 0, 0

model_lenet5v1.to(device)

model_lenet5v1.eval()
with torch.inference_mode():
    for X, y in test_dataloader:
        X, y = X.to(device), y.to(device)
        y_pred = model_lenet5v1(X)
        
        test_loss += loss_fn(y_pred, y)
        test_acc += accuracy(y_pred, y)
        
    test_loss /= len(test_dataloader)
    test_acc /= len(test_dataloader)

print(f"Test loss: {test_loss: .5f}| Test acc: {test_acc: .5f}")