# Pruning and Sparsity

First, install the required packages and download the dataset and the pretrained VGG model. This notebook uses the CIFAR-10 dataset and a VGG network.

In [None]:
# Install torchprofile, which provides profile_macs (used by get_model_macs).
# The leading `!` is IPython/Jupyter shell syntax, so this cell only runs in a
# notebook. `1>/dev/null` discards pip's stdout.
# NOTE: the success message below is printed regardless of pip's exit status.
print('Installing torchprofile...')
!pip install torchprofile 1>/dev/null
print('All required packages have been successfully installed!')


import copy
import math
import random
import time
from collections import OrderedDict, defaultdict
from typing import Union, List

import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.optim import *
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader
from torchprofile import profile_macs
from torchvision.datasets import *
from torchvision.transforms import *
from tqdm.auto import tqdm

from torchprofile import profile_macs

# Compute device used by train()/evaluate(); change to 'cuda' or 'mps' if available.
device = 'cpu'

# Seed the Python, NumPy and PyTorch RNGs so that runs are reproducible.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(42)

Defining a helper function for downloading files (used for the pretrained checkpoint).

In [None]:
def download_url(url, model_dir='.', overwrite=False):
    """
    Download a file from a URL into a directory, reusing an existing copy.

    The file name is the last '/'-separated component of ``url``. Data is first
    written to ``<name>.part`` and only renamed to its final name once the
    transfer has completed. An interrupted download therefore never leaves a
    truncated file that a later call would mistake for a cached copy.

    Args:
        url (str): The URL of the file to download.
        model_dir (str, optional): Directory where the file is saved ('~' is
            expanded). Created if it does not exist. Defaults to '.'.
        overwrite (bool, optional): If True, download again even if the file
            already exists. Defaults to False.

    Returns:
        str or None: Path to the cached or freshly downloaded file, or None if
        any step fails (the error is written to stderr).
    """
    import os, sys, ssl
    from urllib.request import urlretrieve

    # NOTE(review): this replaces the ssl module's process-wide default HTTPS
    # context, so certificate verification is also disabled for later HTTPS
    # requests in this process that rely on the default context, not only for
    # this download. Kept for compatibility; consider removing it.
    ssl._create_default_https_context = ssl._create_unverified_context

    file_name = url.split('/')[-1]
    target_dir = os.path.expanduser(model_dir)
    cached_file = os.path.join(target_dir, file_name)
    partial_file = cached_file + '.part'

    try:
        os.makedirs(target_dir, exist_ok=True)

        if overwrite or not os.path.exists(cached_file):
            sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
            urlretrieve(url, partial_file)
            # Atomic rename: cached_file only ever holds a complete download.
            os.replace(partial_file, cached_file)

        return cached_file

    except Exception as e:
        # Best-effort removal of a partial download so the next call retries
        # from scratch. If it was never created there is nothing to clean up.
        try:
            os.remove(partial_file)
        except OSError:
            pass

        sys.stderr.write('Failed to download from url %s' % url + '\n' + str(e) + '\n')
        return None

Defining my VGG net

In [None]:
class VGG(nn.Module):
  """VGG-style CIFAR-10 classifier.

  The backbone is a stack of conv(3x3, no bias) -> BatchNorm -> ReLU stages with
  2x2 max-pooling wherever ARCH has an 'M'. It is followed by global average
  pooling and a single linear layer producing 10 logits. Layers are named
  conv0, bn0, relu0, ..., pool0, ... so state-dict keys look like
  'backbone.conv0.weight'.
  """

  ARCH = [64, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']

  def __init__(self) -> None:
    super().__init__()

    named_layers = []
    counters = defaultdict(int)

    def push(kind: str, module: nn.Module) -> None:
      # Append `module` under the next free name for its kind, e.g. "conv3".
      idx = counters[kind]
      counters[kind] = idx + 1
      named_layers.append((kind + str(idx), module))

    channels = 3
    for spec in self.ARCH:
      if spec == 'M':
        push("pool", nn.MaxPool2d(2))
        continue
      push("conv", nn.Conv2d(channels, spec, 3, padding=1, bias=False))
      push("bn", nn.BatchNorm2d(spec))
      push("relu", nn.ReLU(True))
      channels = spec

    self.backbone = nn.Sequential(OrderedDict(named_layers))
    self.classifier = nn.Linear(512, 10)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    # [N, 3, 32, 32] -> [N, 512, 2, 2] after four 2x2 poolings
    features = self.backbone(x)
    # global average over the spatial dims: [N, 512, 2, 2] -> [N, 512]
    pooled = features.mean(dim=[2, 3])
    # logits: [N, 512] -> [N, 10]
    return self.classifier(pooled)

In [None]:
def train(
  model: nn.Module,
  dataloader: DataLoader,
  criterion: nn.Module,
  optimizer: Optimizer,
  scheduler: LambdaLR,
  callbacks = None
) -> None:
  """Run one training pass over `dataloader`.

  For every batch: move it to the global `device`, compute the loss, backprop,
  then step the optimizer AND the LR scheduler. The scheduler is therefore
  stepped once per batch, not once per epoch. Afterwards, call each function
  in `callbacks` with no arguments (e.g. to re-apply pruning masks).
  """
  model.train()
  hooks = callbacks if callbacks is not None else ()

  for batch_inputs, batch_targets in tqdm(dataloader, desc='train', leave=False):
    batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)

    optimizer.zero_grad()  # drop gradients left over from the previous batch
    batch_loss = criterion(model(batch_inputs), batch_targets)
    batch_loss.backward()

    optimizer.step()
    scheduler.step()

    for hook in hooks:
      hook()

In [None]:
@torch.inference_mode()
def evaluate(
  model: nn.Module,
  dataloader: DataLoader,
  verbose=True,
) -> float:
  """Return the top-1 accuracy of `model` on `dataloader`, in percent.

  Puts the model in eval mode and runs under inference mode. Batches are moved
  to the global `device`. `verbose=False` hides the progress bar.
  """
  model.eval()

  seen, correct = 0, 0
  progress = tqdm(dataloader, desc="eval", leave=False, disable=not verbose)
  for batch_inputs, batch_targets in progress:
    batch_inputs = batch_inputs.to(device)
    batch_targets = batch_targets.to(device)

    # predicted class = index of the largest logit
    predictions = model(batch_inputs).argmax(dim=1)

    seen += batch_targets.size(0)
    correct += (predictions == batch_targets).sum()

  return (correct / seen * 100).item()

Defining helper functions

In [None]:
import torch.nn as nn
import torch

def get_model_macs(model, inputs) -> int:
    """
    Count the multiply-accumulate operations (MACs) of one forward pass.

    Thin wrapper around ``torchprofile.profile_macs``.

    Args:
        model: The PyTorch model to profile.
        inputs: Example input tensor(s) passed to the model during profiling.

    Returns:
        int: The MAC count reported by torchprofile.
    """
    macs = profile_macs(model, inputs)
    return macs


def get_sparsity(tensor: torch.Tensor) -> float:
    """
    Return the fraction of elements of ``tensor`` that are exactly zero.

    Computed as ``1 - nonzeros / numel``. For example, a tensor with half of its
    entries zero has sparsity 0.5.

    Args:
        tensor (torch.Tensor): The tensor to inspect.

    Returns:
        float: Sparsity in [0, 1].
    """
    total = tensor.numel()
    nonzero = float(tensor.count_nonzero())
    return 1 - nonzero / total


def get_model_sparsity(model: nn.Module) -> float:
    """
    Return the fraction of zero-valued entries across ALL parameters of ``model``.

    Every parameter counts, including biases and BatchNorm affine parameters,
    not only the prunable weights.

    Args:
        model (nn.Module): The PyTorch model.

    Returns:
        float: Model-wide sparsity in [0, 1].
    """
    params = list(model.parameters())
    total = sum(p.numel() for p in params)
    nonzero = sum(int(p.count_nonzero()) for p in params)
    return 1 - float(nonzero) / total

def get_num_parameters(model: nn.Module, count_nonzero_only=False) -> int:
    """
    Count the parameter elements of ``model``.

    Args:
        model (nn.Module): The PyTorch model.
        count_nonzero_only (bool, optional): If True, count only non-zero
            elements (e.g. the weights that survive pruning). Defaults to False.

    Returns:
        int: The number of (non-zero) parameter elements, always a plain Python
        int, never a tensor.
    """
    if count_nonzero_only:
        # count_nonzero() returns a 0-d tensor; int() keeps the result an int so
        # callers get the annotated type.
        return sum(int(param.count_nonzero()) for param in model.parameters())
    return sum(param.numel() for param in model.parameters())


def get_model_size(model: nn.Module, data_width=32, count_nonzero_only=False) -> int:
    """
    Calculate the model size in bits.

    Args:
        model (nn.Module): The PyTorch model.
        data_width (int, optional): The number of bits per element. Defaults to 32.
        count_nonzero_only (bool, optional): If True, count only non-zero weights,
            i.e. the size of an ideally compressed sparse model. Defaults to False.

    Returns:
        int: The size of the model in bits (divide by the byte-unit constants,
        which are also expressed in bits, to convert).
    """
    return get_num_parameters(model, count_nonzero_only) * data_width

# Size units expressed in bits: get_model_size returns bits, so dividing a size
# by MiB gives mebibytes.
Byte = 8
KiB = Byte * 2**10
MiB = KiB * 2**10
GiB = MiB * 2**10

### Main Experiment

Loading the pre-trained model and CIFAR 10 dataset

In [None]:
checkpoint_url = "https://hanlab18.mit.edu/files/course/labs/vgg.cifar.pretrained.pth"
checkpoint_path = download_url(checkpoint_url)
if checkpoint_path is None:
    # download_url returns None on failure; fail with a clear message instead of
    # letting torch.load(None) raise a confusing error.
    raise RuntimeError(f"Could not download the pretrained checkpoint from {checkpoint_url}")
checkpoint = torch.load(checkpoint_path, map_location="cpu")

model = VGG().to(device)
model.load_state_dict(checkpoint['state_dict'])


def recover_model():
    """Reset `model` to the pretrained weights stored in `checkpoint`.

    Lets the experiments below modify (prune, fine-tune) the model freely and
    then start over from the original dense weights.
    """
    return model.load_state_dict(checkpoint['state_dict'])

In [None]:
image_size = 32

# Training uses light augmentation (random 32x32 crop of a 4-pixel-padded image
# and random horizontal flip); testing only converts images to tensors.
# NOTE(review): no normalization is applied, presumably matching how the
# pretrained checkpoint was trained -- verify before changing.
transforms = {
    "train": Compose([
        RandomCrop(image_size, padding=4),
        RandomHorizontalFlip(),
        ToTensor(),
    ]),
    "test": ToTensor(),
}

# CIFAR-10 train/test splits, downloaded into data/cifar10 if not present.
dataset = {
    split: CIFAR10(
        root="data/cifar10",
        train=(split == "train"),
        download=True,
        transform=transforms[split],
    )
    for split in ["train", "test"]
}

# Pinned host memory only speeds up host-to-CUDA copies. On CPU it is wasted
# work, and PyTorch warns that it is unsupported on MPS.
use_pin_memory = torch.device(device).type == 'cuda'

dataloader = {
    split: DataLoader(
        dataset[split],
        batch_size=512,
        shuffle=(split == 'train'),  # shuffle only the training split
        num_workers=0,
        pin_memory=use_pin_memory,
    )
    for split in ['train', 'test']
}

### Evaluating pre-trained model

In [None]:
# Baseline: accuracy and size of the unpruned (dense) model.
dense_model_accuracy = evaluate(model, dataloader['test'])
dense_model_size = get_model_size(model)
print("dense model has accuracy={:.2f}%".format(dense_model_accuracy))
print("dense model has size={:.2f} MiB".format(dense_model_size / MiB))

### Before pruning, let's take a look at the weight distribution of the current model.

In [None]:
def plot_weight_distribution(model, bins=256, count_nonzero_only=False):
    """
    Plot a density histogram for every weight tensor of ``model`` with dim > 1
    (conv and linear weights), three plots per row.

    Args:
        model: The model whose weights are plotted.
        bins (int): Number of histogram bins.
        count_nonzero_only (bool): If True, exact zeros are excluded. This is
            useful after pruning, when the spike at zero would dominate the plot.
    """
    named_weights = [(name, param) for name, param in model.named_parameters()
                     if param.dim() > 1]

    # Size the grid from the number of weights instead of a fixed 3x3, so models
    # with more than nine weight tensors do not raise IndexError. Height scales
    # with the row count; three rows give the original (10, 6) figure.
    num_cols = 3
    num_rows = max(1, math.ceil(len(named_weights) / num_cols))
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(10, 2 * num_rows), squeeze=False)
    axes = axes.ravel()

    for ax, (name, param) in zip(axes, named_weights):
        values = param.detach().reshape(-1).cpu()
        if count_nonzero_only:
            values = values[values != 0]
        ax.hist(values.numpy(), bins=bins, density=True, color='black', alpha=0.5)
        ax.set_xlabel(name)
        ax.set_ylabel('density')

    # Hide grid cells left over when the weight count is not a multiple of num_cols.
    for ax in axes[len(named_weights):]:
        ax.set_visible(False)

    fig.suptitle('Histogram of Weights')
    fig.tight_layout()
    fig.subplots_adjust(top=0.925)
    plt.show()

plot_weight_distribution(model)

The weight distribution plot clearly indicates that a substantial portion of the model's weights are close to zero. This suggests that these weights may not contribute significantly to the model's performance. As a result, there's a significant potential to reduce the model's size by pruning these unnecessary weights. This could lead to improved efficiency and make the model more suitable for deployment on devices with limited resources.

### Prune Time

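Sparsity of a weight tensor: $\text{sparsity} = \dfrac{\#\text{zeros}}{\#\text{elements}}$.

To prune a tensor to a target sparsity $s$, compute the number of weights to remove, $k = \operatorname{round}(s \cdot \#\text{elements})$. Find the $k$-th smallest weight magnitude $|w|$, and set every weight whose magnitude is at or below that threshold to zero.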

Sensitivity scan: prune a single layer to sparsity $x$, measure the accuracy of the resulting model, and restore the layer. Repeat with a higher sparsity (e.g. $x + 0.1$), and do this for every layer.

#### Fine Grained Prune

In [None]:
def fine_grained_prune(tensor: torch.Tensor, sparsity: float) -> torch.Tensor:
    """
    Magnitude-prune ``tensor`` IN PLACE to the requested sparsity.

    The ``round(sparsity * numel)`` weights with the smallest absolute values are
    zeroed. Weights whose magnitude ties the threshold are pruned as well, so the
    achieved sparsity can be slightly higher than requested.

    Args:
        tensor (torch.Tensor): Weight tensor; modified in place.
        sparsity (float): Target fraction of zeros, clamped to [0, 1]. Python
            floats and NumPy scalars (e.g. values from ``np.arange``) are accepted.

    Returns:
        torch.Tensor: Mask of the same shape as ``tensor`` that is truthy where a
        weight is kept. It is a bool tensor in the general case, and a ones/zeros
        tensor of ``tensor``'s dtype when nothing/everything is pruned.
    """
    # float() normalizes NumPy scalars so the arithmetic below uses Python numbers.
    sparsity = min(max(0.0, float(sparsity)), 1.0)

    # Degenerate targets: prune everything / nothing.
    if sparsity == 1.0:
        tensor.zero_()
        return torch.zeros_like(tensor)
    if sparsity == 0.0:
        return torch.ones_like(tensor)

    num_elements = tensor.numel()
    num_zeros = int(round(sparsity * num_elements))

    # A small sparsity on a small tensor can round to zero weights to prune.
    # torch.kthvalue requires k >= 1, so treat this as "no pruning".
    if num_zeros == 0:
        return torch.ones_like(tensor)

    # Importance of a weight = its magnitude; the k-th smallest magnitude is the cut-off.
    importance = tensor.abs()
    threshold = torch.kthvalue(input=importance.flatten(), k=num_zeros).values

    # Keep strictly-larger weights; everything <= threshold is zeroed.
    mask = torch.gt(importance, threshold)
    tensor.mul_(mask)
    return mask

Testing and visualizing the fine_grained_prune method.

In [None]:
def plot_matrix(tensor, ax, title):
    """
    Draw a 2-D tensor as a grid with each value printed in its cell.

    Zero and non-zero entries are shown in contrasting colours (the boolean
    ``tensor == 0`` image rendered with the 'tab20c' colormap).

    Args:
    - tensor: 2-D tensor to visualize
    - ax: Matplotlib axis to draw on
    - title: Title for the plot
    """
    ax.imshow(tensor.cpu().numpy() == 0, cmap='tab20c')
    ax.set_title(title)
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    num_rows, num_cols = tensor.shape
    for row in range(num_rows):
        for col in range(num_cols):
            # imshow puts column `col` at x=col and row `row` at y=row.
            ax.text(col, row, f'{tensor[row, col].item():.2f}',
                    ha="center", va="center", color="k")

torch.manual_seed(42)
random_tensor = torch.randn(6, 6)
fig, axes = plt.subplots(1, 2, figsize=(6, 10))
ax_left, ax_right = axes.ravel()
plot_matrix(random_tensor, ax_left, "Dense Tensor")

# fine_grained_prune zeroes weights in place, so prune a copy to keep
# random_tensor dense; the pruned copy equals copy * mask.
pruned_tensor = random_tensor.clone()
fine_grained_prune(tensor=pruned_tensor, sparsity=0.6)
plot_matrix(pruned_tensor, ax_right, "Sparse Tensor")

# Show the plot
plt.show()

Wrapping the fine_grained_prune function in a class to prune the whole model.

In [None]:
class FineGrainedPruner:
    """
    Magnitude pruner that applies per-layer sparsity targets to a whole model.

    On construction, every parameter with more than one dimension (conv and
    linear weights) is pruned in place with ``fine_grained_prune``. The
    resulting masks are stored so they can be re-applied later, e.g. after
    optimizer steps, to keep the pruned weights at zero.

    Attributes:
        masks (dict): Parameter name -> mask returned by ``fine_grained_prune``.
    """

    def __init__(self, model, sparsity_dict):
        """
        Prune ``model`` in place and remember the masks.

        Args:
            model (torch.nn.Module): The model to prune.
            sparsity_dict (dict): Parameter name -> target sparsity. It needs an
                entry for every parameter with dim > 1; a missing name raises
                KeyError.
        """
        self.masks = FineGrainedPruner.prune(model, sparsity_dict)

    @torch.no_grad()
    def apply(self, model):
        """
        Re-zero pruned weights by multiplying each masked parameter by its mask.

        Parameters without a stored mask are left untouched, and masks whose
        name does not exist in ``model`` are ignored.

        Args:
            model (torch.nn.Module): The model whose parameters are masked in place.
        """
        params_by_name = dict(model.named_parameters())
        for name, mask in self.masks.items():
            if name in params_by_name:
                params_by_name[name].mul_(mask)

    @staticmethod
    @torch.no_grad()
    def prune(model, sparsity_dict):
        """
        Prune every weight with dim > 1 in place and collect the masks.

        Args:
            model (torch.nn.Module): The model to prune.
            sparsity_dict (dict): Parameter name -> target sparsity.

        Returns:
            dict: Parameter name -> mask.
        """
        return {
            name: fine_grained_prune(param, sparsity_dict[name])
            for name, param in model.named_parameters()
            if param.dim() > 1  # only conv / fully-connected weights are pruned
        }

### Sensitivity Scan

The sensitivity scan measures how pruning a single layer of a neural network affects the network's accuracy. We go through the layers one at a time, prune each one at increasing sparsity levels (leaving all other layers dense), and plot the resulting accuracy.

By doing this for each layer and tracking the accuracy as we go, we can figure out which parts of the network are really important for its accuracy and which parts we can prune without causing too much trouble. It's basically about finding the balance between making the network smaller and maintaining accuracy.

Define the `sensitivity_scan` function. It prunes each layer individually at increasing sparsities, restores the layer after every measurement, and records the model's accuracy.

In [None]:
@torch.no_grad()
def sensitivity_scan(model, dataloader, scan_step=0.05, scan_start=0.3, scan_end=1.0, verbose=True):
    """
    Measure how accuracy degrades when each weight tensor is pruned on its own.

    For every parameter with dim > 1 and for every sparsity in
    ``np.arange(scan_start, scan_end, scan_step)`` (``scan_end`` is exclusive),
    the steps are:
    - prune that tensor in place with ``fine_grained_prune``,
    - evaluate the model on ``dataloader``,
    - restore the original weights.
    All other tensors stay dense.

    Returns:
        tuple: ``(sparsities, accuracies)``. ``sparsities`` is the NumPy array of
        scanned sparsities, and ``accuracies[i][j]`` is the accuracy (%) with the
        i-th weight tensor pruned to ``sparsities[j]``.
    """
    sparsities = np.arange(start=scan_start, stop=scan_end, step=scan_step)
    accuracies = []
    named_conv_weights = [(name, param) for (name, param)
                          in model.named_parameters() if param.dim() > 1]
    for i_layer, (name, param) in enumerate(named_conv_weights):
        param_clone = param.detach().clone()
        accuracy = []
        for sparsity in tqdm(sparsities, desc=f'scanning {i_layer}/{len(named_conv_weights)} weight - {name}'):
            try:
                # np.arange yields NumPy scalars; pass a plain float.
                fine_grained_prune(param.detach(), sparsity=float(sparsity))
                acc = evaluate(model, dataloader, verbose=False)
            finally:
                # Restore the dense weights even if pruning/evaluation fails, so
                # the caller's model is never left partially pruned.
                param.copy_(param_clone)
            if verbose:
                print(f'\r    sparsity={sparsity:.2f}: accuracy={acc:.2f}%', end='')
            accuracy.append(acc)
        if verbose:
            print(f'\r    sparsity=[{",".join(["{:.2f}".format(x) for x in sparsities])}]: accuracy=[{", ".join(["{:.2f}%".format(x) for x in accuracy])}]', end='')
        accuracies.append(accuracy)
    return sparsities, accuracies

Execute the sensitivity scan

In [None]:
# Scan every weight tensor at sparsities 0.4, 0.5, ..., 0.9 (scan_end is exclusive).
sparsities, accuracies = sensitivity_scan(
    model,
    dataloader['test'],
    scan_start=0.4,
    scan_end=1.0,
    scan_step=0.1,
)

### Plotting Sensitivity Scan

In [None]:
def plot_sensitivity_scan(sparsities, accuracies, dense_model_accuracy, layer_names=None):
    """
    Plot one sensitivity curve (accuracy vs. pruning sparsity) per weight tensor.

    Args:
    - sparsities (numpy.ndarray): Scanned pruning sparsities (also used as x-ticks).
    - accuracies (list of lists): ``accuracies[i][j]`` is the accuracy with layer i
      pruned to ``sparsities[j]``.
    - dense_model_accuracy (float): Accuracy of the original dense model.
    - layer_names (list of str, optional): Subplot titles, one per entry of
      ``accuracies``. Defaults to the names of the dim > 1 parameters of the
      notebook-global ``model``, the same order that ``sensitivity_scan`` uses.
    """
    if layer_names is None:
        layer_names = [name for name, param in model.named_parameters() if param.dim() > 1]

    # Reference line: the error rate is allowed to grow by 50% over the dense model.
    lower_bound_accuracy = 100 - (100 - dense_model_accuracy) * 1.5

    num_cols = int(math.ceil(len(accuracies) / 3))
    fig, axes = plt.subplots(3, num_cols, figsize=(15, 8), squeeze=False)
    axes = axes.ravel()

    num_plotted = 0
    for ax, name, layer_accuracy in zip(axes, layer_names, accuracies):
        ax.plot(sparsities, layer_accuracy)
        ax.plot(sparsities, [lower_bound_accuracy] * len(sparsities))

        ax.set_xticks(sparsities)
        # Fixed y-range; accuracies outside [80, 95] are clipped from view.
        ax.set_ylim(80, 95)

        ax.set_title(name)
        ax.set_xlabel('sparsity')
        ax.set_ylabel('top-1 accuracy')
        ax.legend([
            'accuracy after pruning',
            f'{lower_bound_accuracy / dense_model_accuracy * 100:.0f}% of dense model accuracy'
        ])
        ax.grid(axis='x')
        num_plotted += 1

    # Hide grid cells without a curve.
    for ax in axes[num_plotted:]:
        ax.set_visible(False)

    fig.suptitle('Sensitivity Curves: Validation Accuracy vs. Pruning Sparsity')
    fig.tight_layout()
    fig.subplots_adjust(top=0.925)
    plt.show()

Besides checking how much pruning affects a layer's performance, it's also crucial to know how many parameters are in that layer.

In [None]:
def plot_num_parameters_distribution(model):
    """Show a bar chart of the element count of every dim > 1 parameter of ``model``."""
    counts = {name: param.numel()
              for name, param in model.named_parameters()
              if param.dim() > 1}
    plt.figure(figsize=(8, 6))
    plt.grid(axis='y')
    plt.bar(list(counts), list(counts.values()))
    plt.title('#Parameter Distribution')
    plt.ylabel('Number of Parameters')
    plt.xticks(rotation=60)
    plt.tight_layout()
    plt.show()

plot_num_parameters_distribution(model)

### Experimentation Time

Now that all the tools are defined, it's time to choose a sparsity value for each layer.

In [None]:
recover_model()  # Restore the dense pretrained weights before every experiment.

# Target sparsity per weight tensor: the first conv stays dense (0), and deeper
# layers generally receive higher sparsity.
sparsity_dict = {
    'backbone.conv0.weight': 0,
    'backbone.conv1.weight': 0.7,
    'backbone.conv2.weight': 0.8,
    'backbone.conv3.weight': 0.7,
    'backbone.conv4.weight': 0.7,
    'backbone.conv5.weight': 0.8,
    'backbone.conv6.weight': 0.8,
    'backbone.conv7.weight': 0.9,
    'classifier.weight': 0.95
}

pruner = FineGrainedPruner(model, sparsity_dict)

print('After pruning with sparsity dictionary')
for layer_name, target_sparsity in sparsity_dict.items():
    print(f'  {layer_name}: {target_sparsity:.2f}')

print('The sparsity of each layer becomes')
for layer_name, weight in model.named_parameters():
    if layer_name not in sparsity_dict:
        continue
    print(f'  {layer_name}: {get_sparsity(weight):.2f}')

sparse_model_size = get_model_size(model, count_nonzero_only=True)
size_ratio = sparse_model_size / dense_model_size * 100
print(f"Sparse model has size={sparse_model_size / MiB:.2f} MiB = {size_ratio:.2f}% of dense model size")
sparse_model_accuracy = evaluate(model, dataloader['test'])
print(f"Sparse model has accuracy={sparse_model_accuracy:.2f}% before fintuning")

plot_weight_distribution(model, count_nonzero_only=True)

### Pruning lowers the model's accuracy, so we fine-tune it just like we normally train a model. However, the pruning masks must be re-applied after every optimizer step during fine-tuning. Otherwise the zeroed weights become non-zero again, and the model loses its sparsity (its effective size goes back to what it was before pruning).

In [None]:
num_finetune_epochs = 4
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
# train() calls scheduler.step() after EVERY batch, so the cosine schedule must
# span all fine-tuning iterations. With T_max=num_finetune_epochs the learning
# rate would cycle up and down every few batches instead of decaying once.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, num_finetune_epochs * len(dataloader['train']))
criterion = nn.CrossEntropyLoss()

best_sparse_model_checkpoint = dict()
best_accuracy = 0
print('Finetuning Fine-grained Pruned Sparse Model')
for epoch in range(num_finetune_epochs):
    # train() invokes the callbacks after every optimizer step; re-applying the
    # pruning masks there keeps the model sparse throughout fine-tuning.
    train(model, dataloader['train'], criterion, optimizer, scheduler,
          callbacks=[lambda: pruner.apply(model)])
    accuracy = evaluate(model, dataloader['test'])
    is_best = accuracy > best_accuracy
    if is_best:
        best_sparse_model_checkpoint['state_dict'] = copy.deepcopy(model.state_dict())
        best_accuracy = accuracy
    print(f'    Epoch {epoch+1} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%')

Lastly, the results of the experiment.

In [None]:
# Restore the best fine-tuned weights, then report the final size and accuracy.
model.load_state_dict(best_sparse_model_checkpoint['state_dict'])
sparse_model_size = get_model_size(model, count_nonzero_only=True)
size_ratio = sparse_model_size / dense_model_size * 100
print(f"Sparse model has size={sparse_model_size / MiB:.2f} MiB = {size_ratio:.2f}% of dense model size")
sparse_model_accuracy = evaluate(model, dataloader['test'])
print(f"Sparse model has accuracy={sparse_model_accuracy:.2f}% after fintuning")