<a href="https://colab.research.google.com/github/rrankawat/pytorch-cnn/blob/main/CIFAR_10_Pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.utils.prune as prune

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import time
import os

In [19]:
# Mount Google Drive (Colab only): the Drive folder listing and the
# model_cifar10.pth checkpoint used later are read from under this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [20]:
# List the Drive folder that holds the trained checkpoint.
# `os` is already imported in the first code cell, so it is not re-imported here.
NOTEBOOKS_DIR = "/content/drive/My Drive/Colab Notebooks"
notebook_files = os.listdir(NOTEBOOKS_DIR)
# Fail here with a clear message rather than later inside torch.load.
assert "model_cifar10.pth" in notebook_files, f"model_cifar10.pth not found in {NOTEBOOKS_DIR}"
notebook_files

['wildfire-detection.ipynb',
 '1. Getting Started.ipynb',
 '2. Grayscaling Images.ipynb',
 '3. Color Spaces.ipynb',
 '4. Drawing on Images.ipynb',
 'Defect Analysis.ipynb',
 '01 Tensors.ipynb',
 '02 Tensor Operations.ipynb',
 '03 Tensor Math Operations.ipynb',
 '05 Convolutional Neural Network.ipynb',
 'FashionMnist (1).ipynb',
 '04 Neural Network.ipynb',
 'CIFAR-100.ipynb',
 'Mnist.ipynb',
 'model_fashion_mnist.pth',
 '__pycache__',
 'model_fashion_mnist.py',
 'FashionMnist.ipynb',
 'model_cifar10.py',
 'model_cifar10.pth',
 'CIFAR-10.ipynb',
 'CIFAR-10 Pruning.ipynb']

In [21]:
class CIFARConvNet(nn.Module):
    """Four-block convolutional classifier for 3x32x32 CIFAR-10 images.

    Each block is conv3x3 (padding=1) -> BatchNorm -> ReLU -> 2x2 max-pool, so a
    3x32x32 input is reduced to 128x2x2 (512 features) before two linear layers.

    The attribute names (conv1..conv4, bn1..bn4, fc1, fc2) become the state_dict
    keys that ``load_state_dict`` matches against a saved checkpoint, so they
    must not be renamed.

    Args:
        num_classes: number of output logits (default 10, as for CIFAR-10).
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Output shape (C x H x W) of each conv. The 3x3/padding=1 convs keep
        # H and W; the max-pool that precedes conv2..conv4 halves them.
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)    # -> 16x32x32
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)   # -> 32x16x16
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)   # -> 64x8x8
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, 3, padding=1)  # -> 128x4x4
        self.bn4 = nn.BatchNorm2d(128)

        # After the fourth max-pool: 128 channels x 2 x 2 = 512 features.
        self.fc1 = nn.Linear(128 * 2 * 2, 256)
        self.fc2 = nn.Linear(256, num_classes)

        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        """Map a (batch, 3, 32, 32) image tensor to (batch, num_classes) logits."""
        # Block 1
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)  # 32 -> 16

        # Block 2
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)  # 16 -> 8

        # Block 3
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.max_pool2d(x, 2, 2)  # 8 -> 4

        # Block 4
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.max_pool2d(x, 2, 2)  # 4 -> 2

        # Flatten per sample. Unlike view(-1, 512), this never folds feature
        # values into the batch dimension: a wrongly sized input now fails at
        # fc1 with a shape error instead of silently producing extra "samples".
        x = torch.flatten(x, 1)

        # Fully connected head; dropout is only active in train() mode.
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [22]:
# Rebuild the architecture and load the trained weights saved on Drive.
# map_location="cpu": nothing in this notebook moves the model or data to a GPU,
#   so a checkpoint saved from a CUDA session must be remapped to load here.
# weights_only=True: the file is expected to be a plain state_dict of tensors, so
#   restrict unpickling to tensors/containers instead of arbitrary Python objects.
MODEL_PATH = "/content/drive/My Drive/Colab Notebooks/model_cifar10.pth"
model = CIFARConvNet()
state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()

CIFARConvNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=512, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=10, bias=True)
  (dropout): Dropout(p=0.25, inplace=False)
)

In [23]:
# Evaluation preprocessing: ToTensor converts the image to a float tensor, then
# Normalize applies (x - mean) / std independently to each channel.
# NOTE(review): these statistics presumably match the transform used when
# model_cifar10.pth was trained; confirm, since a mismatch silently costs accuracy.
NORM_MEAN = (0.5, 0.5, 0.5)
NORM_STD = (0.5, 0.5, 0.5)
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]
)

In [24]:
# CIFAR-10 test split; download=True fetches it into ./data if it is not already there.
testset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=test_transform,
)

# shuffle=False keeps the batch order fixed, so evaluation runs are repeatable.
EVAL_BATCH_SIZE = 128
testloader = DataLoader(testset, batch_size=EVAL_BATCH_SIZE, shuffle=False)

In [25]:
def test_accuracy(model, testloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            outputs = model(images)
            predicted = torch.max(outputs.data, 1)[1]
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total * 100

In [26]:
# Layer-wise L1 unstructured pruning: zero the PRUNE_AMOUNT fraction of
# smallest-magnitude entries in each listed weight tensor. Only conv1, conv2 and
# fc1 are pruned; conv3, conv4 and fc2 stay dense.
PRUNE_AMOUNT = 0.3
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
)

for layer, param_name in parameters_to_prune:
    # Pruning adds a '<name>_mask' buffer (e.g. weight_mask). If one is present,
    # this tensor was already pruned by an earlier run of this cell. Pruning it
    # again would stack a second pass over the still-unpruned weights and
    # compound the sparsity, so skip it.
    if hasattr(layer, f"{param_name}_mask"):
        print(f"Skipping already-pruned {param_name} of {layer}")
        continue
    prune.l1_unstructured(layer, name=param_name, amount=PRUNE_AMOUNT)

In [28]:
# Report the percentage of zeroed entries in each module's weight pruning mask.
for name, module in model.named_modules():
    if not hasattr(module, 'weight_mask'):
        continue
    mask = module.weight_mask
    zero_count = (mask == 0).sum().item()
    sparsity = zero_count / mask.numel() * 100
    print(f"Sparsity in {name}.weight: {sparsity:.2f}%")

Sparsity in conv1.weight: 30.09%
Sparsity in conv2.weight: 29.99%
Sparsity in fc1.weight: 30.00%


In [29]:
# Top-1 test accuracy measured *after* the pruning cell has run. No earlier
# cell records a pre-pruning baseline to compare against.
acc_after_pruning = test_accuracy(model, testloader)
print(f"Accuracy after pruning: {acc_after_pruning:.2f}%")
# Legacy alias kept for any later cell that reads the old name. Despite the
# name, it holds the post-pruning accuracy.
acc_before = acc_after_pruning

Accuracy after pruning: 70.08%
