In [1]:
# Imports
import math
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torchvision
import torchvision.datasets as datasets  # Standard datasets
import torchvision.transforms as transforms  # Transformations we can perform on our dataset for augmentation
from sklearn.metrics import accuracy_score
from torch import nn  # All neural network modules
from torch import optim  # For optimizers like SGD, Adam, etc.
from torch.utils.data import Dataset, DataLoader  # Gives easier dataset managment
from tqdm.notebook import tqdm

In [None]:
def set_seed(seed=27):
    """Seed every random number generator the notebook draws from.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    CUDA), forces deterministic cuDNN kernels, stores the seed in the
    PYTHONHASHSEED environment variable and prints a confirmation.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # cuDNN is only reproducible with autotuning off and deterministic kernels on
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here only
    # influences Python processes launched afterwards, not this kernel's str hashing
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

# Put the model and the data on the GPU when one is visible, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
set_seed(seed=27)
device

### 1. Download CIFAR10 dataset from torchvision

In [None]:
# Both CIFAR-10 splits are converted from PIL images to tensors by the same
# ToTensor transform; download=True fetches the archive into dataset/ if missing
to_tensor = transforms.ToTensor()
train_dataset = datasets.CIFAR10(root="dataset/", train=True, transform=to_tensor, download=True)
test_dataset = datasets.CIFAR10(root="dataset/", train=False, transform=to_tensor, download=True)

# Report the number of samples in each split
for split in (train_dataset, test_dataset):
    print(len(split))

In [4]:
# Hyperparameters shared by the experiments below
batch_size = 64        # images per mini-batch
num_classes = 10       # CIFAR-10 has 10 categories
learning_rate = 0.01   # SGD step size
num_epochs = 25        # number of training epochs

In [5]:
# Wrap the datasets in mini-batch iterators. The training set is reshuffled every
# epoch so SGD does not see the batches in the same fixed order each time (the
# shuffle order comes from torch's seeded RNG, so runs stay reproducible); the
# test set keeps its order because evaluation does not depend on it.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size)

### 2. Creating a Convolutional Neural Network (CNN)

In [6]:
# Creating a CNN class
class ConvNeuralNet(nn.Module):
    """Small CNN for 3-channel images whose conv/linear weights can be pruned with masks.

    Architecture: two (conv -> conv -> BatchNorm -> max-pool) stages followed by
    fc1 -> ReLU -> fc2. ``fc1`` expects 64 * 5 * 5 = 1600 features, which is what
    3x32x32 inputs (e.g. CIFAR-10) produce after the two stages.

    ``self.masks[i]`` is a 0/1 tensor shaped like the weight of the layer named
    ``self.layer_names[i]``. Every forward pass multiplies each weight by its
    mask, so zeroed entries act as pruned and receive zero gradient. Masks are
    plain tensors, not parameters, so optimizers never update them.

    NOTE(review): there is no activation function between the conv layers; the
    only nonlinearities before fc1 are BatchNorm and max-pooling. Kept as-is to
    preserve the original architecture.
    """

    def __init__(self, num_classes):
        super(ConvNeuralNet, self).__init__()
        self.conv_layer1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
        self.conv_layer2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(32)
        self.max_pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv_layer3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.conv_layer4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(64)
        self.max_pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(1600, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, num_classes)

        # Names of the maskable layers, in the same order as self.masks
        self.layer_names = ['conv_layer1', 'conv_layer2', 'conv_layer3', 'conv_layer4', 'fc1', 'fc2']
        # One all-ones (i.e. unpruned) mask per maskable layer; requires_grad is False
        self.masks = [torch.ones_like(getattr(self, name).weight) for name in self.layer_names]

    def _masked_weight(self, name):
        """Return the weight of layer ``name`` multiplied element-wise by its mask."""
        index = self.layer_names.index(name)
        weight = getattr(self, name).weight
        mask = self.masks[index]
        if mask.device != weight.device or mask.dtype != weight.dtype:
            # Masks are not registered buffers, so model.to(device) does not move
            # them; convert lazily and keep the converted copy for later calls.
            mask = mask.to(device=weight.device, dtype=weight.dtype)
            self.masks[index] = mask
        return weight * mask

    def _masked_conv(self, x, name):
        """Apply conv layer ``name`` to ``x`` using its masked weight and its own hyper-parameters."""
        conv = getattr(self, name)
        return nn.functional.conv2d(x, self._masked_weight(name), conv.bias,
                                    conv.stride, conv.padding, conv.dilation, conv.groups)

    def forward(self, x):
        """Compute class scores for ``x`` with every maskable weight multiplied by its mask.

        Returns a tensor of shape (batch, num_classes) with unnormalized scores.
        """
        out = self._masked_conv(x, 'conv_layer1')
        out = self._masked_conv(out, 'conv_layer2')
        out = self.bn1(out)
        out = self.max_pool1(out)

        out = self._masked_conv(out, 'conv_layer3')
        out = self._masked_conv(out, 'conv_layer4')
        out = self.bn2(out)
        out = self.max_pool2(out)

        # Flatten everything except the batch dimension for the fully connected layers
        out = out.reshape(out.size(0), -1)

        out = nn.functional.linear(out, self._masked_weight('fc1'), self.fc1.bias)
        out = self.relu1(out)
        out = nn.functional.linear(out, self._masked_weight('fc2'), self.fc2.bias)
        return out

### Model training and evaluation routines


In [7]:
def train_model(epochs, dataloader, device, model, optimizer, loss_function):
    """Train ``model`` in place for ``epochs`` passes over ``dataloader``.

    Args:
        epochs: number of full passes over ``dataloader``.
        dataloader: iterable of ``(images, labels)`` batches.
        device: device every batch is moved to (should match the model's device).
        model: network to optimize; it is left in training mode.
        optimizer: optimizer already bound to ``model``'s parameters.
        loss_function: criterion called as ``loss_function(preds, labels)``.

    Returns:
        ``(model, total_loss)`` where ``total_loss[e]`` is the mean batch loss of epoch ``e``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    # train() makes layers such as BatchNorm use per-batch statistics and update
    # their running averages; it does not freeze or unfreeze any parameters.
    model.train()

    # Average loss of each epoch
    total_loss = []

    for epoch in range(epochs):
        print("Epoch:", epoch+1)
        running_loss, correct, seen, num_batches = 0.0, 0, 0, 0

        for images, labels in tqdm(dataloader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            preds = model(images)
            loss = loss_function(preds, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
            correct += (preds.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)

        if num_batches == 0:
            raise ValueError("dataloader yielded no batches")
        total_loss.append(running_loss / num_batches)
        print(f"Loss: {total_loss[-1]:.4f}  Train accuracy: {correct / max(seen, 1):.4f}")
    return model, total_loss


def test_model(dataloader, device, model):  
    # Set model to eval mode in order to freeze all layers so that no parameter gets updated during testing
    model.eval()
    with torch.no_grad():
        correct, total = 0, 0        
        for images, labels in tqdm(train_loader):        
            # Make sure that data is on the same device as the model
            images, labels = images.to(device), labels.to(device)            
            preds = model(images)        
            _, predicted = torch.max(preds.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        accuracy= correct / total

    return accuracy

### 3/4 Run the CNN for 25 epochs to get baseline results.

In [None]:
# Baseline: train the dense (unpruned) network for the configured number of epochs
epochs = num_epochs
# Create the model
model = ConvNeuralNet(num_classes).to(device)
# Define Loss
loss_function = nn.CrossEntropyLoss()
# Plain SGD (no momentum) with the shared learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
model, _ = train_model(epochs, train_loader, device, model, optimizer, loss_function)
print("Accuracy: ", test_model(test_loader, device, model))

### Defining a common function for pruning

In [241]:
def _prune_mask(weight, mask, fraction, pruning_type):
    """Return a new mask in which ceil(fraction * weight.numel()) entries are zero.

    Entries that are already zero in ``mask`` are always counted first, so the
    pruned set only grows across rounds. Among the surviving entries, 'LTH'
    prunes the smallest magnitudes and 'rnd' picks uniformly at random (torch RNG).
    """
    if pruning_type == 'LTH':
        scores = (weight * mask).abs()
    else:
        scores = torch.rand_like(weight)
    # Score -1 for already-pruned entries so argsort places them first
    scores = scores.masked_fill(mask == 0, -1.0).flatten()
    num_pruned = min(math.ceil(fraction * weight.numel()), weight.numel())
    new_mask = mask.flatten().clone()
    new_mask[torch.argsort(scores)[:num_pruned]] = 0
    return new_mask.view_as(weight)


def _rewind(model, init_state, weights):
    """Load ``init_state`` into ``model`` and zero every weight its mask marks as pruned."""
    model.load_state_dict(init_state)
    with torch.no_grad():
        for weight, mask in zip(weights, model.masks):
            weight.mul_(mask)


def pruning(epochs, rounds, sparsities, pruningType):
    """Iteratively prune the conv layers to each target sparsity and report the ticket's accuracy.

    For every target sparsity the same initialization is used:
    ``rounds`` times (train for ``epochs`` -> prune a larger fraction of each conv
    layer -> rewind all parameters and buffers to their initial values), then the
    final sparse network is trained for ``epochs`` more and evaluated on
    ``test_loader``.

    Args:
        epochs: training epochs per round (and for the final sparse network).
        rounds: number of prune/rewind iterations used to reach each sparsity (>= 1).
        sparsities: target fractions (0-1) of conv weights to remove.
        pruningType: 'LTH' prunes the smallest-magnitude weights, 'rnd' prunes at random.

    Returns:
        List of test accuracies, one per entry of ``sparsities``.

    Raises:
        ValueError: on an unknown ``pruningType`` or ``rounds < 1``.

    Relies on the notebook globals ConvNeuralNet, num_classes, device,
    learning_rate, train_model, test_model, train_loader and test_loader.
    """
    if pruningType not in ('LTH', 'rnd'):
        raise ValueError(f"pruningType must be 'LTH' or 'rnd', got {pruningType!r}")
    if rounds < 1:
        raise ValueError("rounds must be at least 1")

    # Create the model, loss and (stateless, plain) SGD optimizer
    model = ConvNeuralNet(num_classes).to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    # Layers with an entry in model.masks, in the order ConvNeuralNet.__init__ builds them
    layers = ['conv_layer1', 'conv_layer2', 'conv_layer3', 'conv_layer4', 'fc1', 'fc2']
    weights = [getattr(model, name).weight for name in layers]
    # Only conv layers are pruned; the fc masks stay all-ones
    prunable = [isinstance(getattr(model, name), nn.Conv2d) for name in layers]

    # Snapshot all parameters and buffers at initialization. clone() is required:
    # .data.to(device) on a tensor already on `device` returns the same storage,
    # so the snapshot would silently follow the trained weights.
    init_state = {key: value.detach().clone() for key, value in model.state_dict().items()}

    # Multiply each weight gradient by its current mask so pruned (zeroed) weights
    # stay exactly zero under plain SGD, whatever the model's forward does.
    hooks = [weight.register_hook(lambda grad, i=i: grad * model.masks[i])
             for i, weight in enumerate(weights)]

    test_accuracies = []
    try:
        for sparsity in sparsities:
            print("\n\n\nFor Sparsity: ", sparsity, "\n-----------------------")

            # Every sparsity starts from the same dense initialization
            for i, weight in enumerate(weights):
                model.masks[i] = torch.ones_like(weight)
            _rewind(model, init_state, weights)

            for round_ in range(rounds):
                print("\n\nROUND", round_+1, "Started\n----------------------")
                model, _ = train_model(epochs, train_loader, device, model, optimizer, loss_function)
                print("Test Accuracy: ", test_model(test_loader, device, model))

                # Cumulative fraction of each conv layer pruned after this round;
                # it grows geometrically and equals `sparsity` in the last round.
                fraction = ((sparsity * 100) ** ((round_ + 1) / rounds)) / 100
                with torch.no_grad():
                    for i, weight in enumerate(weights):
                        if prunable[i]:
                            model.masks[i] = _prune_mask(weight, model.masks[i], fraction, pruningType)
                # Lottery-ticket rewind: surviving weights restart from their initial values
                _rewind(model, init_state, weights)

            # Train the final sparse network at the target sparsity and report it
            print("\n\nTraining the final sparse network\n----------------------")
            model, _ = train_model(epochs, train_loader, device, model, optimizer, loss_function)
            val_acc = test_model(test_loader, device, model)
            test_accuracies.append(val_acc)
            print("Test Accuracy: ", val_acc)
    finally:
        for hook in hooks:
            hook.remove()
    return test_accuracies

### Performing Lottery Ticket pruning 
   

In [None]:
# Lottery-ticket (magnitude) pruning: 5 rounds of 5 epochs for every target sparsity
epochs, rounds = 5, 5
sparsities = [0.1, 0.2, 0.3, 0.4, 0.5]

accLTH = pruning(epochs, rounds, sparsities, 'LTH')
accLTH

### Performing random pruning 


In [None]:
# Random-pruning control: same schedule (5 rounds of 5 epochs) and target sparsities
epochs, rounds = 5, 5
sparsities = [0.1, 0.2, 0.3, 0.4, 0.5]

accRnd = pruning(epochs, rounds, sparsities, 'rnd')
accRnd

### 5. Plot a graph of sparsity vs validation accuracy

In [None]:
# `sparsities` holds fractions (0.1 = 10 %), so scale to percent to match the axis label
sparsity_pct = [100 * s for s in sparsities]

fig, ax = plt.subplots()
sns.lineplot(x=sparsity_pct, y=accLTH, marker="o", label="LTH", ax=ax)
sns.lineplot(x=sparsity_pct, y=accRnd, marker="o", label="random", ax=ax)
ax.set(xlabel="Sparsity (%)", ylabel="Test Accuracy",
       title="Accuracy vs. sparsity: lottery-ticket vs. random pruning")
sns.despine(ax=ax)

### 6. Lottery Ticket style pruning with different hyper-parameters (number of pruning rounds at a fixed sparsity of 0.5)



In [None]:
# Reach a fixed 50 % sparsity in 1..5 pruning rounds; epochs per round are chosen
# so that rounds x epochs is roughly 30 for every configuration.
rounds = [1, 2, 3, 4, 5]
epochs = [30, 15, 10, 7, 6]
# Separate name so the `sparsities` list used by the earlier cells is not overwritten
fixed_sparsity = [0.5]

accByRounds = []
for num_rounds, epochs_per_round in zip(rounds, epochs):
    acc = pruning(epochs_per_round, num_rounds, fixed_sparsity, 'LTH')
    # Keep the last accuracy pruning() reports for the single target sparsity
    accByRounds.append(acc[-1])

fig, ax = plt.subplots()
sns.lineplot(x=rounds, y=accByRounds, marker="o", label="LTH, sparsity 0.5", ax=ax)
ax.set(xlabel="Pruning rounds", ylabel="Test Accuracy",
       title="LTH accuracy vs. number of pruning rounds")
ax.set_xticks(rounds)
sns.despine(ax=ax)

### 7. Randomly re-initializing the weights after pruning

In [None]:
# Control experiment: iterative magnitude pruning as in pruning(), but after each
# pruning step the surviving weights are randomly re-initialized instead of being
# rewound to the original initialization.
epochs = 5
rounds = 5
sparsity = 0.5

# Layers with an entry in model.masks, in the order ConvNeuralNet.__init__ builds them
layers = ['conv_layer1', 'conv_layer2', 'conv_layer3', 'conv_layer4', 'fc1', 'fc2']

# Create the model, loss and (stateless, plain) SGD optimizer
model = ConvNeuralNet(num_classes).to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
weights = [getattr(model, name).weight for name in layers]

# Start dense: all-ones masks on the same device as the weights
for i, weight in enumerate(weights):
    model.masks[i] = torch.ones_like(weight)

# Multiply each weight gradient by its current mask so pruned (zeroed) weights
# stay exactly zero under plain SGD
mask_hooks = [weight.register_hook(lambda grad, i=i: grad * model.masks[i])
              for i, weight in enumerate(weights)]

# Test accuracy after each round's training, followed by the final sparse network's
val_accuracies = []

for round_ in range(rounds):
    print("\n\nROUND", round_+1, "Started\n----------------------")

    # First train the model for some epochs
    model, _ = train_model(epochs, train_loader, device, model, optimizer, loss_function)
    val_acc = test_model(test_loader, device, model)
    val_accuracies.append(val_acc)
    print("Test Accuracy: ", val_acc)

    with torch.no_grad():
        # Cumulative fraction of each conv layer to prune; equals `sparsity` in the last round
        fraction = ((sparsity * 100) ** ((round_ + 1) / rounds)) / 100
        for i, name in enumerate(layers):
            if not isinstance(getattr(model, name), nn.Conv2d):
                continue  # only conv layers are pruned
            mask = model.masks[i]
            # Already-pruned weights score -1 so they stay pruned; the rest are ranked by magnitude
            scores = (weights[i] * mask).abs().masked_fill(mask == 0, -1.0).flatten()
            num_pruned = min(math.ceil(fraction * mask.numel()), mask.numel())
            new_mask = mask.flatten().clone()
            new_mask[torch.argsort(scores)[:num_pruned]] = 0
            model.masks[i] = new_mask.view_as(mask)

        # Random re-initialization: copy the parameters and buffers of a freshly
        # constructed network (same default init scheme), then zero the pruned weights
        model.load_state_dict(ConvNeuralNet(num_classes).state_dict())
        for weight, mask in zip(weights, model.masks):
            weight.mul_(mask)

# Train the final randomly re-initialized sparse network and evaluate it
model, _ = train_model(epochs, train_loader, device, model, optimizer, loss_function)
val_acc = test_model(test_loader, device, model)
val_accuracies.append(val_acc)
print("Test Accuracy (final, sparsity", sparsity, "):", val_acc)

for hook in mask_hooks:
    hook.remove()
val_accuracies
