In [2]:
from torchvision import datasets

In [4]:
train_dataset = datasets.MNIST(root="./datasets/", train=True, download=True)
test_dataset = datasets.MNIST(root="./datasets/", train=False, download=True)

In [6]:
from torchvision import transforms
train_val_dataset = datasets.MNIST(root="./datasets/", train=True, download=False, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root="./datasets", train=False, download=False, transform=transforms.ToTensor())

In [8]:
import torch
imgs = torch.stack([img for img, _ in train_val_dataset], dim =0)
mean = imgs.view(1, -1).mean(dim =1)
std = imgs.view(1, -1).std(dim =1)
mean, std

(tensor([0.1307]), tensor([0.3081]))

In [10]:
mnist_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
train_val_dataset = datasets.MNIST(root="./datasets/", train=True, download=False, transform=mnist_transforms)
test_dataset = datasets.MNIST(root="./datasets/", train=False, download=False, transform=mnist_transforms)

In [12]:
train_size = int(0.9 * len(train_val_dataset))
val_size = len(train_val_dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset=train_val_dataset, lengths=[train_size, val_size])
len(train_dataset), len(val_dataset), len(test_dataset)

(54000, 6000, 10000)

In [14]:
from torch.utils.data import DataLoader

BATCH_SIZE = 32

train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Let's see no of batches that we have now with the current batch-size
len(train_dataloader), len(val_dataloader), len(test_dataloader) 

(1688, 188, 313)

In [16]:
from torch import nn
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet5V1(nn.Module):
    def __init__(self):
        super(LeNet5V1, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

In [18]:
from torchmetrics import Accuracy
model_lenet5v1 = LeNet5V1()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model_lenet5v1.parameters(), lr = 0.001)
accuracy = Accuracy(task='multiclass', num_classes=10)

In [20]:
from tqdm.notebook import tqdm

# device-agnostic setup
device = 'cuda' if torch.cuda.is_available() else 'cpu'
accuracy = accuracy.to(device)
model_lenet5v1 = model_lenet5v1.to(device)

EPOCHS = 12

for epoch in tqdm(range(EPOCHS)):
    # Training loop
    train_loss, train_acc = 0.0, 0.0
    for X, y in train_dataloader:
        X, y = X.to(device), y.to(device)
        
        model_lenet5v1.train()
        
        y_pred = model_lenet5v1(X)
        
        loss = loss_fn(y_pred, y)
        train_loss += loss.item()
        
        acc = accuracy(y_pred, y)
        train_acc += acc
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    train_loss /= len(train_dataloader)
    train_acc /= len(train_dataloader)
        
    # Validation loop
    val_loss, val_acc = 0.0, 0.0
    model_lenet5v1.eval()
    with torch.inference_mode():
        for X, y in val_dataloader:
            X, y = X.to(device), y.to(device)
            
            y_pred = model_lenet5v1(X)
            
            loss = loss_fn(y_pred, y)
            val_loss += loss.item()
            
            acc = accuracy(y_pred, y)
            val_acc += acc
            
        val_loss /= len(val_dataloader)
        val_acc /= len(val_dataloader)
    
    print(f"Epoch: {epoch}| Train loss: {train_loss: .5f}| Train acc: {train_acc: .5f}| Val loss: {val_loss: .5f}| Val acc: {val_acc: .5f}")

  0%|          | 0/12 [00:00<?, ?it/s]

Epoch: 0| Train loss:  0.19778| Train acc:  0.93974| Val loss:  0.05571| Val acc:  0.98188
Epoch: 1| Train loss:  0.08873| Train acc:  0.97380| Val loss:  0.04401| Val acc:  0.98703
Epoch: 2| Train loss:  0.07138| Train acc:  0.97860| Val loss:  0.03964| Val acc:  0.98770
Epoch: 3| Train loss:  0.05797| Train acc:  0.98202| Val loss:  0.04279| Val acc:  0.98770
Epoch: 4| Train loss:  0.04868| Train acc:  0.98478| Val loss:  0.03821| Val acc:  0.98920
Epoch: 5| Train loss:  0.04514| Train acc:  0.98586| Val loss:  0.03931| Val acc:  0.98770
Epoch: 6| Train loss:  0.03943| Train acc:  0.98730| Val loss:  0.03329| Val acc:  0.99119
Epoch: 7| Train loss:  0.03380| Train acc:  0.98936| Val loss:  0.03750| Val acc:  0.98936
Epoch: 8| Train loss:  0.03362| Train acc:  0.98891| Val loss:  0.03710| Val acc:  0.99036
Epoch: 9| Train loss:  0.03092| Train acc:  0.99002| Val loss:  0.03867| Val acc:  0.99136
Epoch: 10| Train loss:  0.02749| Train acc:  0.99093| Val loss:  0.03850| Val acc:  0.9908

In [22]:
def count_params(model):
    total_params = 0
    for layer_names, param in model.named_parameters():
        total_params += torch.count_nonzero(param.data)
    return total_params

In [24]:
orig_params = count_params(model_lenet5v1)
print(f"Unpruned LeNet-5 model has {orig_params} trainable parameters")

Unpruned LeNet-5 model has 1199882 trainable parameters


In [26]:
for layer, param in model_lenet5v1.named_parameters():
    print(f"layer.name: {layer} & param.shape = {param.shape}")

layer.name: conv1.weight & param.shape = torch.Size([32, 1, 3, 3])
layer.name: conv1.bias & param.shape = torch.Size([32])
layer.name: conv2.weight & param.shape = torch.Size([64, 32, 3, 3])
layer.name: conv2.bias & param.shape = torch.Size([64])
layer.name: fc1.weight & param.shape = torch.Size([128, 9216])
layer.name: fc1.bias & param.shape = torch.Size([128])
layer.name: fc2.weight & param.shape = torch.Size([10, 128])
layer.name: fc2.bias & param.shape = torch.Size([10])


In [28]:
for layer_name in model_lenet5v1.state_dict().keys():
    print(layer_name, model_lenet5v1.state_dict()[layer_name].shape)

conv1.weight torch.Size([32, 1, 3, 3])
conv1.bias torch.Size([32])
conv2.weight torch.Size([64, 32, 3, 3])
conv2.bias torch.Size([64])
fc1.weight torch.Size([128, 9216])
fc1.bias torch.Size([128])
fc2.weight torch.Size([10, 128])
fc2.bias torch.Size([10])


In [30]:
model_lenet5v1.state_dict().keys()

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])

In [36]:
def compute_sparsity(model):
    conv1_sparsity = (torch.sum(model.conv1.weight == 0) / model.conv1.weight.nelement()) * 100
    conv2_sparsity = (torch.sum(model.conv2.weight == 0) / model.conv2.weight.nelement()) * 100
    fc1_sparsity = (torch.sum(model.fc1.weight == 0) / model.fc1.weight.nelement()) * 100
    op_sparsity = (torch.sum(model.fc2.weight == 0) / model.fc2.weight.nelement()) * 100

    num = torch.sum(model.conv1.weight == 0) + torch.sum(model.conv2.weight == 0) + torch.sum(model.fc1.weight == 0) + torch.sum(model.fc2.weight == 0)
    denom = model.conv1.weight.nelement() + model.conv2.weight.nelement() + model.fc1.weight.nelement() + model.fc2.weight.nelement()

    global_sparsity = num/denom * 100

    return global_sparsity

In [38]:
print(f"LeNet-5 global sparsity = {compute_sparsity(model_lenet5v1):.2f}%")

LeNet-5 global sparsity = 0.00%


In [40]:
import torch.nn.utils.prune as prune
for name, module in model_lenet5v1.named_modules():
    # prune 20% of weights/connections in for all hidden layaers-
    if isinstance(module, torch.nn.Linear) and name != 'fc2':
        prune.l1_unstructured(module = module, name = 'weight', amount = 0.2)
    
    # prune 10% of weights/connections for output layer-
    elif isinstance(module, torch.nn.Linear) and name == 'fc2':
        prune.l1_unstructured(module = module, name = 'weight', amount = 0.1)

In [42]:
print(f"LeNet-5 global sparsity = {compute_sparsity(model_lenet5v1):.2f}%")

LeNet-5 global sparsity = 19.68%


In [44]:
new_params = count_params(model_lenet5v1)
print(f"Unpruned LeNet-5 model has {new_params} trainable parameters")

Unpruned LeNet-5 model has 1199882 trainable parameters


In [46]:
for epoch in tqdm(range(EPOCHS)):
    # Training loop
    train_loss, train_acc = 0.0, 0.0
    for X, y in train_dataloader:
        X, y = X.to(device), y.to(device)
        
        model_lenet5v1.train()
        
        y_pred = model_lenet5v1(X)
        
        loss = loss_fn(y_pred, y)
        train_loss += loss.item()
        
        acc = accuracy(y_pred, y)
        train_acc += acc
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    train_loss /= len(train_dataloader)
    train_acc /= len(train_dataloader)
        
    # Validation loop
    val_loss, val_acc = 0.0, 0.0
    model_lenet5v1.eval()
    with torch.inference_mode():
        for X, y in val_dataloader:
            X, y = X.to(device), y.to(device)
            
            y_pred = model_lenet5v1(X)
            
            loss = loss_fn(y_pred, y)
            val_loss += loss.item()
            
            acc = accuracy(y_pred, y)
            val_acc += acc
            
        val_loss /= len(val_dataloader)
        val_acc /= len(val_dataloader)
    
    print(f"Epoch: {epoch}| Train loss: {train_loss: .5f}| Train acc: {train_acc: .5f}| Val loss: {val_loss: .5f}| Val acc: {val_acc: .5f}")

  0%|          | 0/12 [00:00<?, ?it/s]

Epoch: 0| Train loss:  0.02337| Train acc:  0.99232| Val loss:  0.04535| Val acc:  0.99069
Epoch: 1| Train loss:  0.02231| Train acc:  0.99276| Val loss:  0.04293| Val acc:  0.99219
Epoch: 2| Train loss:  0.02228| Train acc:  0.99330| Val loss:  0.04249| Val acc:  0.99186
Epoch: 3| Train loss:  0.01991| Train acc:  0.99341| Val loss:  0.04319| Val acc:  0.99102
Epoch: 4| Train loss:  0.01905| Train acc:  0.99384| Val loss:  0.04668| Val acc:  0.99136
Epoch: 5| Train loss:  0.01986| Train acc:  0.99374| Val loss:  0.04799| Val acc:  0.99003
Epoch: 6| Train loss:  0.01883| Train acc:  0.99384| Val loss:  0.05339| Val acc:  0.99086
Epoch: 7| Train loss:  0.01807| Train acc:  0.99437| Val loss:  0.05009| Val acc:  0.99119
Epoch: 8| Train loss:  0.01886| Train acc:  0.99391| Val loss:  0.04692| Val acc:  0.99086
Epoch: 9| Train loss:  0.01750| Train acc:  0.99459| Val loss:  0.05521| Val acc:  0.99053
Epoch: 10| Train loss:  0.01812| Train acc:  0.99450| Val loss:  0.04829| Val acc:  0.9896