Pip Installs

In [None]:
!pip install torchmetrics
!pip install cleverhans
!pip install quantus
!pip install captum
!pip install ranger-adabelief

Imports

In [None]:
import torch
from torch import nn
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.nn.utils.prune as prune
from torch.utils.data import random_split
from torch.utils.data import Dataset

import torchvision
from torchvision import datasets
from torchvision import transforms
import pandas as pd
from PIL import Image

from torchmetrics import Accuracy
import torch.optim as optim
from cleverhans.torch.attacks.projected_gradient_descent import (projected_gradient_descent)

import quantus
import captum
from captum.attr import Saliency, IntegratedGradients, NoiseTunnel

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import random
import copy
import gc
import math

import warnings

import os
from itertools import chain

from collections import Counter
from ranger_adabelief import RangerAdaBelief

Class Formation

In [None]:
def class_formation():
  """Build the lookup between integer class ids and traffic-sign labels.

  Returns:
    tuple: ``(classes, class_names)`` where ``classes`` is a dict mapping each
    class id (0-42) to its label string, and ``class_names`` is the list of
    those labels in class-id order.
  """
  # The position of a label in this list is its class id.
  class_names = [
    'Speed limit (20km/h)',
    'Speed limit (30km/h)',
    'Speed limit (50km/h)',
    'Speed limit (60km/h)',
    'Speed limit (70km/h)',
    'Speed limit (80km/h)',
    'End of speed limit (80km/h)',
    'Speed limit (100km/h)',
    'Speed limit (120km/h)',
    'No passing',
    'No passing veh over 3.5 tons',
    'Right-of-way at intersection',
    'Priority road',
    'Yield',
    'Stop',
    'No vehicles',
    'Veh > 3.5 tons prohibited',
    'No entry',
    'General caution',
    'Dangerous curve left',
    'Dangerous curve right',
    'Double curve',
    'Bumpy road',
    'Slippery road',
    'Road narrows on the right',
    'Road work',
    'Traffic signals',
    'Pedestrians',
    'Children crossing',
    'Bicycles crossing',
    'Beware of ice/snow',
    'Wild animals crossing',
    'End speed + passing limits',
    'Turn right ahead',
    'Turn left ahead',
    'Ahead only',
    'Go straight or right',
    'Go straight or left',
    'Keep right',
    'Keep left',
    'Roundabout mandatory',
    'End of no passing',
    'End no passing veh > 3.5 tons',
  ]

  classes = dict(enumerate(class_names))
  return classes, class_names

Model Eval

In [None]:
def evaluate_model(model, dataloader, device):
    """
    Run ``model`` over every batch of ``dataloader`` and collect the results.

    Args:
        model: Module called as ``model(x_batch)``. Its output is treated as
            per-class scores: softmax is applied along dim 1.
        dataloader: Iterable yielding ``(x_batch, y_batch)`` tensor pairs.
        device: Device each batch is moved to before the forward pass.

    Returns:
        tuple: ``(prediction, labels)``. ``prediction`` holds the softmax
        probabilities (not class indices) for all samples, concatenated along
        dim 0; ``labels`` holds the matching ``y_batch`` values. When the
        dataloader yields no batches, both are empty 1-D tensors on ``device``.

    Note:
        The model is switched to eval mode and is left in that mode.
    """
    model.eval()
    batch_outputs = []
    batch_labels = []

    with torch.no_grad():
        for x_batch, y_batch in dataloader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            batch_outputs.append(model(x_batch))
            batch_labels.append(y_batch)

    if not batch_outputs:
        # Nothing to evaluate: hand back empty tensors rather than failing
        # inside softmax on a tensor that has no dim 1.
        return (
            torch.empty(0, device=device),
            torch.empty(0, dtype=torch.long, device=device),
        )

    # Concatenate once at the end; growing a tensor with torch.cat on every
    # batch re-copies everything collected so far.
    logits = torch.cat(batch_outputs)
    labels = torch.cat(batch_labels)

    # Promote to at least the default float dtype / int64 so the result dtypes
    # do not silently narrow when a batch arrives in a smaller dtype.
    logits = logits.to(torch.promote_types(logits.dtype, torch.get_default_dtype()))
    labels = labels.to(torch.promote_types(labels.dtype, torch.long))

    # Convert the raw scores into per-class probabilities.
    prediction = torch.nn.functional.softmax(logits, dim=1)

    return prediction, labels

Sparsity Functions

In [None]:
def compute_sparsity_vgg(model):
    """
    Compute the global weight sparsity of a VGG-style model, in percent.

    The count covers the ``weight`` tensors of ``model.features`` at indices
    0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28 and of ``model.classifier``
    at indices 1, 4, 6. Biases and any other modules are not included.

    Args:
        model: Object exposing indexable ``features`` and ``classifier``
            containers whose entries at the indices above have a ``weight``.
            NOTE(review): the classifier indices assume a custom head with
            weighted layers at positions 1, 4 and 6 - confirm against the
            model definition.

    Returns:
        0-dim tensor: ``100 * (entries equal to 0) / (total entries)``.
    """
    feature_indices = (0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28)
    classifier_indices = (1, 4, 6)

    weights = [model.features[idx].weight for idx in feature_indices]
    weights += [model.classifier[idx].weight for idx in classifier_indices]

    # Entries that are exactly zero, summed over every tracked weight tensor.
    num = sum(torch.sum(weight == 0) for weight in weights)
    denom = sum(weight.nelement() for weight in weights)

    global_sparsity = num / denom * 100
    return global_sparsity

In [None]:
def compute_sparsity_resnet(model):
    """
    Compute the global weight sparsity of a ResNet-style model, in percent.

    The count covers the ``weight`` tensors of ``conv1``, ``bn1``, the
    ``conv1``/``bn1``/``conv2``/``bn2`` members of blocks 0 and 1 in each of
    ``layer1`` .. ``layer4``, and ``fc``. Biases and any other modules (for
    example shortcut/downsample paths, if the model has them) are not included.

    Args:
        model: Object exposing the attributes listed above.

    Returns:
        0-dim tensor: ``100 * (entries equal to 0) / (total entries)``.
    """
    tracked = [model.conv1, model.bn1]
    for stage in (model.layer1, model.layer2, model.layer3, model.layer4):
        for block in (stage[0], stage[1]):
            tracked += [block.conv1, block.bn1, block.conv2, block.bn2]
    tracked.append(model.fc)

    # Entries that are exactly zero, summed over every tracked weight tensor.
    num = sum(torch.sum(module.weight == 0) for module in tracked)
    denom = sum(module.weight.nelement() for module in tracked)

    global_sparsity = num / denom * 100
    return global_sparsity

Pruning Functions

L1 Unstructured Pruning

In [None]:
def l1unstructured_prune(input_model, conv_amount=0.2, linear_amount=0.1):
  """Apply per-layer L1 unstructured pruning to conv and linear weights.

  Every ``torch.nn.Conv2d`` and every ``torch.nn.Linear`` module found in
  ``input_model`` (not only the output layer) has its ``weight`` pruned with
  ``prune.l1_unstructured``. Other module types are left untouched.

  Args:
    input_model: Module to prune; it is modified in place.
    conv_amount: ``amount`` passed to the pruning call for each ``Conv2d``
      (default 0.2, i.e. 20% of that layer's weights).
    linear_amount: ``amount`` passed to the pruning call for each ``Linear``
      (default 0.1, i.e. 10% of that layer's weights).

  Returns:
    The same ``input_model`` object, now carrying the pruning masks.
  """
  for module in input_model.modules():
    if isinstance(module, torch.nn.Conv2d):
      prune.l1_unstructured(module=module, name='weight', amount=conv_amount)
    elif isinstance(module, torch.nn.Linear):
      prune.l1_unstructured(module=module, name='weight', amount=linear_amount)
  return input_model

Global Pruning - ResNet

In [None]:
def global_prune_resnet(input_model):
  """Apply one round of global L1 unstructured pruning to a ResNet-style model.

  The pruning pool is the ``weight`` of ``conv1``, ``bn1``, the
  ``conv1``/``bn1``/``conv2``/``bn2`` members of blocks 0 and 1 in each of
  ``layer1`` .. ``layer4``, and ``fc``. The threshold is chosen across the
  whole pool, so individual layers can end up with different sparsities.

  Args:
    input_model: Model exposing the attributes listed above; it is modified
      in place.

  Returns:
    The same ``input_model`` object, now carrying the pruning masks.
  """
  parameters_to_prune = [
    (input_model.conv1, 'weight'),
    (input_model.bn1, 'weight'),
  ]
  for stage in (input_model.layer1, input_model.layer2,
                input_model.layer3, input_model.layer4):
    for block in (stage[0], stage[1]):
      for layer in (block.conv1, block.bn1, block.conv2, block.bn2):
        parameters_to_prune.append((layer, 'weight'))
  parameters_to_prune.append((input_model.fc, 'weight'))
  parameters_to_prune = tuple(parameters_to_prune)

  # Rate schedule for iterative pruning; only the first entry is used because
  # a single round is run below.
  prune_rates_global = [0.2, 0.3, 0.4, 0.5, 0.6]
  for iter_prune_round in range(1):
    print(f"\n\nIterative Global pruning round = {iter_prune_round + 1}")

    # Global (cross-layer) unstructured pruning by absolute weight magnitude.
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method = prune.L1Unstructured,
        amount = prune_rates_global[iter_prune_round]
    )
  return input_model

Global Pruning - VGG

In [None]:
# Global magnitude pruning over the weighted layers of the VGG16 model
def global_prune_vgg(input_model):
  """Apply one round of global L1 unstructured pruning to a VGG-style model.

  The pruning pool is the ``weight`` of ``input_model.features`` at indices
  0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28 and of
  ``input_model.classifier`` at indices 1, 4, 6. The model is modified in
  place and the same object is returned.
  """
  feature_indices = (0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28)
  classifier_indices = (1, 4, 6)

  pool = [input_model.features[idx] for idx in feature_indices]
  pool.extend(input_model.classifier[idx] for idx in classifier_indices)
  parameters_to_prune = tuple((layer, 'weight') for layer in pool)

  # Rate schedule; a single round is run, so only the first rate is applied.
  rate_schedule = [0.2, 0.3, 0.4, 0.5, 0.6]
  rounds_to_run = 1
  for round_number, rate in enumerate(rate_schedule[:rounds_to_run], start=1):
    print(f"\n\nIterative Global pruning round = {round_number}")

    # One magnitude threshold is chosen across the whole pool.
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=rate,
    )
  return input_model

Layered Structured Pruning - ResNet

In [None]:
# Layer-by-layer structured pruning for the ResNet model
def layeredstructured_prune_resnet(input_model):
  """Apply Ln-structured pruning to each conv layer and to ``fc``.

  ``prune.ln_structured`` is called with ``amount=0.1``, ``n=2``, ``dim=0`` on
  the ``weight`` of ``conv1``, then of ``conv1`` and ``conv2`` in blocks 0 and
  1 of ``layer1`` .. ``layer4``, and finally of ``fc``. No other modules are
  touched. The model is modified in place and the same object is returned.

  NOTE(review): for ``fc`` dim 0 is the output-unit axis, so whole output rows
  are zeroed - confirm that this is intended.
  """
  targets = [input_model.conv1]
  for stage in (input_model.layer1, input_model.layer2,
                input_model.layer3, input_model.layer4):
    for block in (stage[0], stage[1]):
      targets.extend((block.conv1, block.conv2))
  targets.append(input_model.fc)

  for layer in targets:
    prune.ln_structured(layer, name="weight", amount=0.1, n=2, dim=0)

  return input_model

Linear(in_features=512, out_features=43, bias=True)

Layered Structured Pruning - VGG

In [None]:
def layeredstructured_prune_vgg(input_model):
  """Apply Ln-structured pruning to each weighted layer of a VGG-style model.

  ``prune.ln_structured`` is called with ``amount=0.1``, ``n=2``, ``dim=0`` on
  the ``weight`` of ``input_model.features`` at indices 0, 2, 5, 7, 10, 12,
  14, 17, 19, 21, 24, 26, 28 and then of ``input_model.classifier`` at indices
  1, 4, 6. The model is modified in place and the same object is returned.
  """
  feature_indices = (0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28)
  classifier_indices = (1, 4, 6)

  targets = [input_model.features[idx] for idx in feature_indices]
  targets.extend(input_model.classifier[idx] for idx in classifier_indices)

  for layer in targets:
    prune.ln_structured(layer, name="weight", amount=0.1, n=2, dim=0)

  return input_model

Linear(in_features=512, out_features=43, bias=True)