In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision import models

import numpy as np
import pandas as pd


import matplotlib.pyplot as plt
import seaborn as sns

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [11]:
num_epochs = 100

batch_size = 128

In [13]:
train_dataset = torchvision.datasets.CIFAR10(root= './data', train = True, download=True)
test_dataset =  torchvision.datasets.CIFAR10(root= './data', train = False, download=True)

Files already downloaded and verified
Files already downloaded and verified


In [15]:
class_names = train_dataset.classes
class_names

['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

In [17]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
cifar_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,])

In [19]:
train_dataset = torchvision.datasets.CIFAR10(root= './data', train = True, download=False, transform = cifar_transforms)
test_dataset =  torchvision.datasets.CIFAR10(root= './data', train = False, download=False, transform = cifar_transforms)

In [21]:
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset = train_dataset, lengths = [train_size, val_size])
len(train_dataset), len(val_dataset), len(test_dataset)

(45000, 5000, 10000)

In [23]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(dataset =train_dataset, batch_size = batch_size, shuffle = True)
test_dataloader = DataLoader(dataset= test_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(dataset = val_dataset, batch_size=batch_size, shuffle=True)

In [25]:
import math
class VGG(nn.Module):
    def __init__(self, features):
        super(VGG,self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(512,512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512,512),
            nn.ReLU(True),
            nn.Linear(512,10)
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                m.bias.data.zero_()

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

def make_layers(cfg):
    layers = []
    in_channels =3
    for out_channels in cfg:
        if out_channels == 'M':
            layers += [nn.MaxPool2d(kernel_size = 2, stride =2)]
        else:
            conv2d = nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding =1)
            layers += [conv2d, nn.ReLU(inplace = True)]
            in_channels = out_channels
    return nn.Sequential(*layers)

cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']

def vgg16():
    return VGG(make_layers(cfg))

In [27]:
model = vgg16()

In [29]:
from torchmetrics import Accuracy
model = vgg16()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = 0.05, momentum = 0.9, weight_decay = 5e-4)
accuracy = Accuracy(task='multiclass', num_classes=10)

In [31]:
def count_params(model):
    total_params = 0
    for layer_names, param in model.named_parameters():
        total_params += torch.count_nonzero(param.data)
    return total_params

In [33]:
orig_params = count_params(model)
print(f"Unpruned VGG-16 model has {orig_params} trainable parameters")

Unpruned VGG-16 model has 15240906 trainable parameters


In [35]:
for layer, param in model.named_parameters():
    print(f"layer.name: {layer} & param.shape = {param.shape}")

layer.name: features.0.weight & param.shape = torch.Size([64, 3, 3, 3])
layer.name: features.0.bias & param.shape = torch.Size([64])
layer.name: features.2.weight & param.shape = torch.Size([64, 64, 3, 3])
layer.name: features.2.bias & param.shape = torch.Size([64])
layer.name: features.5.weight & param.shape = torch.Size([128, 64, 3, 3])
layer.name: features.5.bias & param.shape = torch.Size([128])
layer.name: features.7.weight & param.shape = torch.Size([128, 128, 3, 3])
layer.name: features.7.bias & param.shape = torch.Size([128])
layer.name: features.10.weight & param.shape = torch.Size([256, 128, 3, 3])
layer.name: features.10.bias & param.shape = torch.Size([256])
layer.name: features.12.weight & param.shape = torch.Size([256, 256, 3, 3])
layer.name: features.12.bias & param.shape = torch.Size([256])
layer.name: features.14.weight & param.shape = torch.Size([256, 256, 3, 3])
layer.name: features.14.bias & param.shape = torch.Size([256])
layer.name: features.17.weight & param.shap

In [37]:
for layer_name in model.state_dict().keys():
    print(layer_name, model.state_dict()[layer_name].shape)

features.0.weight torch.Size([64, 3, 3, 3])
features.0.bias torch.Size([64])
features.2.weight torch.Size([64, 64, 3, 3])
features.2.bias torch.Size([64])
features.5.weight torch.Size([128, 64, 3, 3])
features.5.bias torch.Size([128])
features.7.weight torch.Size([128, 128, 3, 3])
features.7.bias torch.Size([128])
features.10.weight torch.Size([256, 128, 3, 3])
features.10.bias torch.Size([256])
features.12.weight torch.Size([256, 256, 3, 3])
features.12.bias torch.Size([256])
features.14.weight torch.Size([256, 256, 3, 3])
features.14.bias torch.Size([256])
features.17.weight torch.Size([512, 256, 3, 3])
features.17.bias torch.Size([512])
features.19.weight torch.Size([512, 512, 3, 3])
features.19.bias torch.Size([512])
features.21.weight torch.Size([512, 512, 3, 3])
features.21.bias torch.Size([512])
features.24.weight torch.Size([512, 512, 3, 3])
features.24.bias torch.Size([512])
features.26.weight torch.Size([512, 512, 3, 3])
features.26.bias torch.Size([512])
features.28.weight t

In [39]:
model.state_dict().keys()

odict_keys(['features.0.weight', 'features.0.bias', 'features.2.weight', 'features.2.bias', 'features.5.weight', 'features.5.bias', 'features.7.weight', 'features.7.bias', 'features.10.weight', 'features.10.bias', 'features.12.weight', 'features.12.bias', 'features.14.weight', 'features.14.bias', 'features.17.weight', 'features.17.bias', 'features.19.weight', 'features.19.bias', 'features.21.weight', 'features.21.bias', 'features.24.weight', 'features.24.bias', 'features.26.weight', 'features.26.bias', 'features.28.weight', 'features.28.bias', 'classifier.1.weight', 'classifier.1.bias', 'classifier.4.weight', 'classifier.4.bias', 'classifier.6.weight', 'classifier.6.bias'])

In [41]:
def compute_sparsity(model):
    conv1_sparsity = (torch.sum(model.features[0].weight == 0) / model.features[0].weight.nelement()) * 100
    conv2_sparsity = (torch.sum(model.features[2].weight == 0) / model.features[2].weight.nelement()) * 100
    conv3_sparsity = (torch.sum(model.features[5].weight == 0) / model.features[5].weight.nelement()) * 100
    conv4_sparsity = (torch.sum(model.features[7].weight == 0) / model.features[7].weight.nelement()) * 100
    conv5_sparsity = (torch.sum(model.features[10].weight == 0) / model.features[10].weight.nelement()) * 100
    conv6_sparsity = (torch.sum(model.features[12].weight == 0) / model.features[12].weight.nelement()) * 100
    conv7_sparsity = (torch.sum(model.features[14].weight == 0) / model.features[14].weight.nelement()) * 100
    conv8_sparsity = (torch.sum(model.features[17].weight == 0) / model.features[17].weight.nelement()) * 100
    conv9_sparsity = (torch.sum(model.features[19].weight == 0) / model.features[19].weight.nelement()) * 100
    conv10_sparsity = (torch.sum(model.features[21].weight == 0) / model.features[21].weight.nelement()) * 100
    conv11_sparsity = (torch.sum(model.features[24].weight == 0) / model.features[24].weight.nelement()) * 100
    conv12_sparsity = (torch.sum(model.features[26].weight == 0) / model.features[26].weight.nelement()) * 100
    conv13_sparsity = (torch.sum(model.features[28].weight == 0) / model.features[28].weight.nelement()) * 100
    fc1_sparsity = (torch.sum(model.classifier[1].weight == 0) / model.classifier[1].weight.nelement()) * 100
    fc2_sparsity = (torch.sum(model.classifier[4].weight == 0) / model.classifier[4].weight.nelement()) * 100
    op_sparsity = (torch.sum(model.classifier[6].weight == 0) / model.classifier[6].weight.nelement()) * 100

    num = torch.sum(model.features[0].weight == 0) + torch.sum(model.features[2].weight == 0) + torch.sum(model.features[5].weight == 0) + torch.sum(model.features[7].weight == 0) + torch.sum(model.features[10].weight == 0) + torch.sum(model.features[12].weight == 0) + torch.sum(model.features[14].weight == 0) + torch.sum(model.features[17].weight == 0) + torch.sum(model.features[19].weight == 0) + torch.sum(model.features[21].weight == 0)+ torch.sum(model.features[24].weight == 0) + torch.sum(model.features[26].weight == 0) + torch.sum(model.features[28].weight == 0) + torch.sum(model.classifier[1].weight == 0) + torch.sum(model.classifier[4].weight == 0) + torch.sum(model.classifier[6].weight == 0)
    denom = model.features[0].weight.nelement() + model.features[2].weight.nelement() + model.features[5].weight.nelement() + model.features[7].weight.nelement() + model.features[10].weight.nelement() + model.features[12].weight.nelement() + model.features[14].weight.nelement() + model.features[17].weight.nelement() + model.features[19].weight.nelement() + model.features[21].weight.nelement() + model.features[24].weight.nelement() + model.features[26].weight.nelement() + model.features[28].weight.nelement() + model.classifier[1].weight.nelement() + model.classifier[4].weight.nelement() + model.classifier[6].weight.nelement()
    global_sparsity = num/denom * 100
    return global_sparsity

In [43]:
print(f"VGG-16 global sparsity = {compute_sparsity(model):.2f}%")

VGG-16 global sparsity = 0.00%


In [59]:
for name, module in model.named_modules():
    print(name, module)

 VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=

In [53]:
import torch.nn.utils.prune as prune
for name, module in model.named_modules():
    # prune 20% of weights/connections in for all hidden layaers-
    if isinstance(module, torch.nn.Linear) and name != 'classifier.6':
        prune.l1_unstructured(module = module, name = 'weight', amount = 0.2)
    
    # prune 10% of weights/connections for output layer-
    elif isinstance(module, torch.nn.Linear) and name == 'classifier.6':
        prune.l1_unstructured(module = module, name = 'weight', amount = 0.1)

In [55]:
print(f"VGG-16 global sparsity = {compute_sparsity(model):.2f}%")

VGG-16 global sparsity = 1.25%


In [63]:
new_params = count_params(model)
print(f"Unpruned VGG-16 model has {new_params} trainable parameters")

Unpruned VGG-16 model has 15240906 trainable parameters


In [67]:
import torch.nn.utils.prune as prune
from tqdm.notebook import tqdm
# torch.cuda.empty_cache()
# device-agnostic setup
device = 'cuda' if torch.cuda.is_available() else 'cpu'
accuracy = accuracy.to(device)
model = model.to(device)
for epoch in tqdm(range(num_epochs)):
        # torch.cuda.empty_cache()
        # Training loop
        train_loss, train_acc = 0.0, 0.0
        for X, y in train_dataloader:
            X, y = X.to(device), y.to(device)
            
            model.train()
            
            y_pred = model(X)
            
            loss = loss_fn(y_pred, y)
            train_loss += loss.item()
            
            acc = accuracy(y_pred, y)
            train_acc += acc
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        train_loss /= len(train_dataloader)
        train_acc /= len(train_dataloader)
            
        # Validation loop
        val_loss, val_acc = 0.0, 0.0
        model.eval()
        with torch.inference_mode():
            for X, y in val_dataloader:
                X, y = X.to(device), y.to(device)
                
                y_pred = model(X)
                
                loss = loss_fn(y_pred, y)
                val_loss += loss.item()
                
                acc = accuracy(y_pred, y)
                val_acc += acc
                
            val_loss /= len(val_dataloader)
            val_acc /= len(val_dataloader)
        
        print(f"Epoch: {epoch}| Train loss: {train_loss: .5f}| Train acc: {train_acc: .5f}| Val loss: {val_loss: .5f}| Val acc: {val_acc: .5f}")

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 0| Train loss:  2.12329| Train acc:  0.16485| Val loss:  1.92164| Val acc:  0.21387
Epoch: 1| Train loss:  1.88344| Train acc:  0.24189| Val loss:  1.75705| Val acc:  0.28730
Epoch: 2| Train loss:  1.69761| Train acc:  0.32361| Val loss:  1.62123| Val acc:  0.35977
Epoch: 3| Train loss:  1.57004| Train acc:  0.38664| Val loss:  1.49766| Val acc:  0.44238
Epoch: 4| Train loss:  1.41709| Train acc:  0.47229| Val loss:  1.27192| Val acc:  0.53809
Epoch: 5| Train loss:  1.20250| Train acc:  0.57385| Val loss:  1.02846| Val acc:  0.63965
Epoch: 6| Train loss:  1.06467| Train acc:  0.62943| Val loss:  1.04742| Val acc:  0.64395
Epoch: 7| Train loss:  0.98891| Train acc:  0.66015| Val loss:  0.89172| Val acc:  0.69863
Epoch: 8| Train loss:  0.88995| Train acc:  0.70130| Val loss:  1.01201| Val acc:  0.67246
Epoch: 9| Train loss:  0.84054| Train acc:  0.72211| Val loss:  0.88594| Val acc:  0.69648
Epoch: 10| Train loss:  0.77875| Train acc:  0.74454| Val loss:  0.76246| Val acc:  0.7439