<a href="https://colab.research.google.com/github/raki-rankawat/stm32/blob/main/CIFAR10_Pruned_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Load Trained CIFAR10 Model

In [16]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [17]:
# Mount Google Drive (Colab only) so the checkpoint paths under
# /content/drive/My Drive/... used later in this notebook are readable/writable.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [18]:
# Data Loaders
# Shuffling (RandomSampler) and the random augmentations below draw from torch's
# global RNG; DataLoader defaults to num_workers=0, so augmentation runs in this
# process. Seed it so the shuffle order and augmentation draws repeat across runs.
SEED = 0
torch.manual_seed(SEED)

batch_size = 64
DATA_ROOT = './data'

# Train and test images must be normalized identically (and identically to how
# the loaded checkpoint was trained); define the transform once so the two
# pipelines cannot drift apart. Maps ToTensor's [0, 1] pixels to [-1, 1].
cifar_normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Training pipeline: light augmentation for fine-tuning.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    cifar_normalize,
])

# Evaluation pipeline: no augmentation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    cifar_normalize,
])

train_data = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=train_transform)
test_data = datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=test_transform)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [19]:
# CNN Model
class CIFARConvNet(nn.Module):
    """Four-stage convolutional classifier for 32x32 RGB images (CIFAR-10).

    Each stage is Conv3x3 (padding=1) -> BatchNorm -> ReLU -> 2x2 max-pool, so
    the spatial size goes 32 -> 16 -> 8 -> 4 -> 2 while channels go
    3 -> 16 -> 32 -> 64 -> 128. The 128x2x2 map is flattened and passed through
    a 256-unit hidden layer (dropout applied while training) and a linear head.

    Args:
        num_classes: number of output logits. The default of 10 matches the
            saved checkpoint.
        dropout: dropout probability applied after ``fc1`` in training mode.

    Submodule names are fixed (conv1..conv4, bn1..bn4, fc1, fc2), so existing
    state_dicts load unchanged.
    """

    def __init__(self, num_classes: int = 10, dropout: float = 0.25):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)

        # 128 channels * 2 * 2 spatial: only valid for 32x32 inputs (4 halvings).
        self.fc1 = nn.Linear(128 * 2 * 2, 256)
        self.fc2 = nn.Linear(256, num_classes)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for x of shape (batch, 3, 32, 32)."""
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), 2)  # 32 -> 16
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)  # 16 -> 8
        x = F.max_pool2d(F.relu(self.bn3(self.conv3(x))), 2)  # 8 -> 4
        x = F.max_pool2d(F.relu(self.bn4(self.conv4(x))), 2)  # 4 -> 2

        x = x.view(x.size(0), -1)  # flatten to (batch, 512)

        x = F.relu(self.fc1(x))
        # nn.Dropout is already the identity in eval mode, so no explicit
        # `if self.training` check is needed.
        x = self.dropout(x)
        return self.fc2(x)

In [20]:
# Load weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BASE_MODEL_PTH = "/content/drive/My Drive/Colab Notebooks/stm_cifar10_model.pth"

model = CIFARConvNet()
# Deserialize onto the CPU so the checkpoint opens whether or not it was saved
# from a GPU. weights_only=True refuses arbitrary pickled objects, which a plain
# state_dict of tensors never needs.
state_dict = torch.load(BASE_MODEL_PTH, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)  # strict by default: raises on missing/unexpected keys
# Batches are moved to `device` before the forward pass, so the model must live
# there too; otherwise a GPU runtime fails with a device-mismatch error.
# Must happen before pruning and before the optimizer is built on its parameters.
model = model.to(device)

<All keys matched successfully>

In [21]:
# Accuracy Before Pruning
def test_accuracy_full(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            out = model(x)
            pred = out.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return 100.0 * correct / total

# Baseline: accuracy of the loaded, unpruned checkpoint on the full test set.
# Later cells compare against this value.
acc_base = test_accuracy_full(model, test_loader)
print(f"✅ PyTorch FULL test accuracy (BASE): {acc_base:.2f}%")

✅ PyTorch FULL test accuracy (BASE): 80.04%


### Pruning

In [22]:
PRUNE_AMOUNT = 0.15          # fraction pruned per layer, e.g. 0.10 | 0.15 | 0.20 | 0.30
PRUNE_TYPE = "structured"    # "structured" (whole filters) | "unstructured" (individual weights)

# Fail fast on a typo or out-of-range value, before any layer is modified.
if PRUNE_TYPE not in ("structured", "unstructured"):
    raise ValueError(f"PRUNE_TYPE must be 'structured' or 'unstructured', got {PRUNE_TYPE!r}")
if not 0.0 <= PRUNE_AMOUNT <= 1.0:
    raise ValueError(f"PRUNE_AMOUNT must be a fraction in [0, 1], got {PRUNE_AMOUNT!r}")

In [23]:
# (module, parameter-name) pairs handed to torch.nn.utils.prune. conv1 is
# excluded from pruning; only conv2..conv4 weights are targeted.
layers_to_prune = [
    (getattr(model, conv_name), "weight")
    for conv_name in ("conv2", "conv3", "conv4")
]

In [24]:
# Apply pruning in place to the layers listed in layers_to_prune.
# Re-running this cell on already-pruned layers would stack a second round of
# pruning on the remaining weights, so refuse instead of silently compounding.
if any(prune.is_pruned(layer) for layer, _ in layers_to_prune):
    raise RuntimeError("Layers are already pruned; reload the base weights and rebuild layers_to_prune first.")

# Only `weight` is pruned. Each conv's bias and the following BatchNorm still act
# on a pruned output channel, so a pruned filter emits a per-channel constant
# rather than exactly zero.
if PRUNE_TYPE == "structured":
    # Rank output channels (dim=0, i.e. whole filters) by L2 norm; zero the smallest.
    for layer, param in layers_to_prune:
        prune.ln_structured(layer, name=param, amount=PRUNE_AMOUNT, n=2, dim=0)
    print(f"✅ Structured pruning: {PRUNE_AMOUNT*100:.0f}% filters on conv2/conv3/conv4")
elif PRUNE_TYPE == "unstructured":
    # Zero the individual weights with the smallest absolute value.
    for layer, param in layers_to_prune:
        prune.l1_unstructured(layer, name=param, amount=PRUNE_AMOUNT)
    print(f"✅ Unstructured pruning: {PRUNE_AMOUNT*100:.0f}% weights on conv2/conv3/conv4")
else:
    raise ValueError(f"Unknown PRUNE_TYPE {PRUNE_TYPE!r}; expected 'structured' or 'unstructured'")

✅ Structured pruning: 15% filters on conv2/conv3/conv4


In [25]:
# Accuracy After Pruning (before any fine-tuning): measures how much accuracy the
# pruning masks alone cost relative to acc_base.
acc_after_prune = test_accuracy_full(model, test_loader)
print(f"✅ PyTorch FULL test accuracy (AFTER PRUNE, before FT): {acc_after_prune:.2f}%")

✅ PyTorch FULL test accuracy (AFTER PRUNE, before FT): 58.37%


### Fine-Tune

In [26]:
# Fine-tuning setup: a few epochs at a small learning rate to recover accuracy
# lost to pruning. With pruning active, model.parameters() yields `weight_orig`
# for the pruned convs; the masks are buffers, so pruned entries remain zero in
# the forward pass even though Adam updates weight_orig.
FT_EPOCHS, FT_LR = 3, 1e-4
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=FT_LR)

In [27]:
def finetune_one_epoch(model, loader, criterion, optimizer, device):
    """Run one training pass over ``loader``; return ``(mean_loss, accuracy_pct)``.

    Metrics are accumulated on the training batches as they are seen (augmented
    inputs, dropout active), so they are running training metrics, not a clean
    evaluation.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for x, y in loader:
        x = x.to(device)
        y = y.to(device)

        out = model(x)
        loss = criterion(out, y)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Assumes criterion uses mean reduction (nn.CrossEntropyLoss default);
        # scale back to a per-sample sum so the epoch mean is sample-weighted.
        loss_sum += loss.item() * y.size(0)
        n_correct += (out.argmax(dim=1) == y).sum().item()
        n_seen += y.size(0)

    if n_seen == 0:
        raise ValueError("finetune_one_epoch: loader yielded no samples")
    return loss_sum / n_seen, 100 * n_correct / n_seen


# perf_counter is monotonic, so the elapsed time is unaffected by wall-clock changes.
start_time = time.perf_counter()

for epoch in range(1, FT_EPOCHS + 1):
    train_loss, train_acc = finetune_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f"FT Epoch {epoch}/{FT_EPOCHS} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")

print(f"Fine-tune time: {(time.perf_counter() - start_time)/60:.2f} minutes")

FT Epoch 1/3 | Train Loss: 0.6645 | Train Acc: 77.07%
FT Epoch 2/3 | Train Loss: 0.6048 | Train Acc: 78.94%
FT Epoch 3/3 | Train Loss: 0.5832 | Train Acc: 79.72%
Fine-tune time: 5.73 minutes


In [28]:
# Accuracy After Fine-Tune: compare with acc_after_prune (recovery from
# fine-tuning) and acc_base (net cost of pruning).
acc_after_ft = test_accuracy_full(model, test_loader)
print(f"✅ PyTorch FULL test accuracy (AFTER FT): {acc_after_ft:.2f}%")

✅ PyTorch FULL test accuracy (AFTER FT): 79.63%


In [29]:
# Make pruning permanent before saving/exporting.
# prune.remove folds weight_orig * weight_mask back into a plain `weight`
# parameter and drops the mask buffer, so the state_dict has the same keys as
# the unpruned CIFARConvNet. Tensor shapes do not change: pruned filters/weights
# are stored as zeros, so the file is not smaller unless the architecture is
# slimmed separately.
for layer, param in layers_to_prune:
    # Skip layers already finalized so re-running this cell does not raise.
    if prune.is_pruned(layer):
        prune.remove(layer, param)

PRUNED_FT_PTH = "/content/drive/My Drive/Colab Notebooks/stm_cifar10_pruned_model.pth"
torch.save(model.state_dict(), PRUNED_FT_PTH)
print("✅ Saved pruned+finetuned weights:", PRUNED_FT_PTH)

✅ Saved pruned+finetuned weights: /content/drive/My Drive/Colab Notebooks/stm_cifar10_pruned_model.pth
