In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.nn.utils import prune
import os
import copy

# Step 0: Setup - Check for GPU and set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Step 1: Load the Dataset (MNIST)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

# Step 2: Define the CNN Model (a simple LeNet-style model)
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=2) # 1 input channel (grayscale)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, padding=2)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(32 * 7 * 7, 120) # 7x7 is the image size after two pooling layers
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 10) # 10 output classes for digits 0-9

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = x.view(-1, 32 * 7 * 7) # Flatten the tensor
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x

# Helper function to train the model
def train_model(model, train_loader, epochs=5):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    model.train()
    for epoch in range(epochs):
        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

# Helper function to test the model
def test_model(model, test_loader):
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy

# Helper function to count non-zero parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Helper function to save model and get its size
def save_and_get_size(model, filename):
    torch.save(model.state_dict(), filename)
    size_mb = os.path.getsize(filename) / (1024 * 1024)
    os.remove(filename) # Clean up the file
    return size_mb

# --- Main Workflow ---

print("\n--- 1. Training the Original (Baseline) Model ---")
original_model = SimpleCNN().to(device)
train_model(original_model, train_loader, epochs=5)
original_accuracy = test_model(original_model, test_loader)
original_params = count_parameters(original_model)
original_size_mb = save_and_get_size(original_model, "original.pth")

print("\n--- 2. Pruning the Model ---")
# Create a deep copy to prune, so we don't modify the original
pruned_model = copy.deepcopy(original_model)

# Define which layers and parameters to prune
parameters_to_prune = (
    (pruned_model.conv1, 'weight'),
    (pruned_model.conv2, 'weight'),
    (pruned_model.fc1, 'weight'),
    (pruned_model.fc2, 'weight'),
)

# Apply pruning globally
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.4,  # Prune 40% of the connections
)

# Check sparsity (percentage of zeroed-out weights)
total_params = 0
zero_params = 0
for module, param_name in parameters_to_prune:
    total_params += getattr(module, param_name).nelement()
    zero_params += torch.sum(getattr(module, param_name) == 0)

print(f"Global Sparsity: {100. * float(zero_params) / total_params:.2f}%")

# Test accuracy right after pruning (before fine-tuning)
accuracy_after_pruning = test_model(pruned_model, test_loader)
print(f"Accuracy after pruning (before fine-tuning): {accuracy_after_pruning:.2f}%")


print("\n--- 3. Fine-tuning the Pruned Model ---")
train_model(pruned_model, train_loader, epochs=3) # Train for fewer epochs


print("\n--- 4. Final Comparison --- 📉")
# Make the pruning permanent before counting parameters and saving
for module, param_name in parameters_to_prune:
    prune.remove(module, param_name)

pruned_accuracy = test_model(pruned_model, test_loader)
pruned_params = count_parameters(pruned_model)
pruned_size_mb = save_and_get_size(pruned_model, "pruned.pth")

# --- Print the final results table ---
print("\n==================== PROJECT RESULTS ====================")
print(f"| Metric                | Original Model | Pruned Model   | Reduction      |")
print(f"|-----------------------|----------------|----------------|----------------|")
print(f"| Accuracy (%)          | {original_accuracy:>11.2f}    | {pruned_accuracy:>11.2f}     | {'N/A':<14} |")
print(f"| Parameters (Total)    | {original_params:>11,}     | {pruned_params:>11,}     | {100*(1-pruned_params/original_params):>11.2f}%    |")
print(f"| Model Size (MB)       | {original_size_mb:>11.4f}    | {pruned_size_mb:>11.4f}     | {100*(1-pruned_size_mb/original_size_mb):>11.2f}%    |")
print("=========================================================")

Using device: cuda


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.1MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 483kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.42MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.53MB/s]



--- 1. Training the Original (Baseline) Model ---
Epoch [1/5], Loss: 0.0138
Epoch [2/5], Loss: 0.0206
Epoch [3/5], Loss: 0.0012
Epoch [4/5], Loss: 0.0893
Epoch [5/5], Loss: 0.0010

--- 2. Pruning the Model ---
Global Sparsity: 40.00%
Accuracy after pruning (before fine-tuning): 98.87%

--- 3. Fine-tuning the Pruned Model ---
Epoch [1/3], Loss: 0.0067
Epoch [2/3], Loss: 0.0006
Epoch [3/3], Loss: 0.0015

--- 4. Final Comparison --- 📉

| Metric                | Original Model | Pruned Model   | Reduction      |
|-----------------------|----------------|----------------|----------------|
| Accuracy (%)          |       98.87    |       99.12     | N/A            |
| Parameters (Total)    |     202,738     |     202,738     |        0.00%    |
| Model Size (MB)       |      0.7769    |      0.7765     |        0.05%    |
