# 07. Pruning Study

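This notebook explores network pruning techniques for reducing model size: the smallest-magnitude weights of a trained network are zeroed out, and test accuracy is measured at increasing sparsity levels.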

## Experiment Overview
- **Goal**: Explore network pruning techniques
- **Model**: Prunable MLP with magnitude-based pruning
- **Features**: Pruning schedules, sparsity analysis, accuracy vs. compression
- **Learning**: Understanding model compression and efficiency

## What You'll Learn
- Network pruning techniques
- Magnitude-based pruning
- Pruning schedules
- Accuracy vs. compression trade-offs


In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import sys
import os

# Add scripts directory to path
sys.path.append('../scripts')
from utils import load_mnist_data, get_device, set_seed

# Seed the RNGs through the project helper so runs are repeatable
# (assumed to cover torch/numpy — see scripts/utils.set_seed).
set_seed(42)

# Resolve the compute device once; later cells move models and batches onto it.
device = get_device()
print(f"Using device: {device}")

# Build the three MNIST loaders via the project helper.
print("Loading MNIST dataset...")
train_loader, val_loader, test_loader = load_mnist_data(
    batch_size=64,
    test_split=0.2,
)

# Report how many samples each split holds.
print(f"Training samples: {len(train_loader.dataset)}")
print(f"Validation samples: {len(val_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")


In [None]:
# Define a prunable MLP model
class PrunableMLP(nn.Module):
    """Three-layer fully connected classifier used as the pruning subject.

    Data flow: flatten -> fc1 -> ReLU -> dropout -> fc2 -> ReLU -> dropout
    -> fc3 -> logits. The module stores no pruning masks; pruning is done by
    zeroing entries of its Linear weights from outside.
    """

    def __init__(self, input_size=784, hidden_size=256, num_classes=10):
        super().__init__()
        # Creation order (fc1, fc2, fc3, dropout) determines state_dict key
        # order and the sequence of RNG draws used for weight initialization.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Flatten each sample to a vector and return unnormalized class scores."""
        flat = x.view(x.size(0), -1)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        hidden = self.dropout(F.relu(self.fc2(hidden)))
        return self.fc3(hidden)

# Create model instance
model = PrunableMLP().to(device)

# Print model summary
print("Model Architecture:")
print(model)
# Count parameters once, and size each tensor by its actual element width
# rather than assuming every parameter is a 4-byte float32.
total_param_count = sum(p.numel() for p in model.parameters())
total_param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"\nTotal parameters: {total_param_count:,}")
print(f"Model size: {total_param_bytes / 1024 / 1024:.2f} MB")

# Train the model first
def train_model(model, train_loader, val_loader, epochs=10, lr=0.001, device=None, log_every=5):
    """Train ``model`` with Adam + cross-entropy and record per-epoch losses.

    Args:
        model: Classifier whose forward pass returns class logits.
        train_loader: Iterable of ``(data, target)`` batches used for updates.
        val_loader: Iterable of ``(data, target)`` batches evaluated without
            gradients after every epoch.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: Where batches are moved before the forward pass. Defaults to
            the device of the model's first parameter, so the function does
            not depend on a notebook-global ``device``.
        log_every: Print both losses every ``log_every`` epochs; a falsy
            value disables logging.

    Returns:
        ``(train_losses, val_losses)``: lists with one entry per epoch, each
        the mean of the per-batch losses for that epoch. The model is left
        in eval mode.
    """
    if device is None:
        device = next(model.parameters()).device

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        # Training pass: dropout active, parameters updated every batch.
        model.train()
        train_loss = 0.0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Validation pass: eval mode and no autograd bookkeeping.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                val_loss += criterion(model(data), target).item()

        # Average over batches (not samples), matching the original reporting.
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        if log_every and (epoch + 1) % log_every == 0:
            print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    return train_losses, val_losses

# Train the model
print("Training model before pruning...")
train_losses, val_losses = train_model(model, train_loader, val_loader, epochs=10)

# Plot training history. Epochs are numbered from 1 to match the training log
# ("Epoch 5/10"), instead of matplotlib's default 0-based list index.
epoch_axis = range(1, len(train_losses) + 1)

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(epoch_axis, train_losses, label='Train Loss')
plt.plot(epoch_axis, val_losses, label='Validation Loss')
plt.title('Training History (Before Pruning)')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(epoch_axis, train_losses, label='Train Loss')
plt.plot(epoch_axis, val_losses, label='Validation Loss')
plt.title('Training History (Log Scale)')
plt.xlabel('Epoch')
plt.ylabel('Loss (log)')
plt.yscale('log')
plt.legend()
plt.grid(True)

plt.tight_layout()
# savefig does not create missing directories; ensure the output folder exists.
os.makedirs('../results/plots', exist_ok=True)
plt.savefig('../results/plots/pruning_training.png', dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# Magnitude-based pruning implementation
def magnitude_pruning(model, sparsity):
    """Return a globally magnitude-pruned copy of ``model``.

    All parameters whose name contains ``'weight'`` are pooled together. A
    single threshold is taken at the ``sparsity`` percentile of their
    absolute values, and every pooled entry at or below it is set to zero.
    Biases, and any other parameter without ``'weight'`` in its name, are
    left untouched. The input ``model`` is never modified.

    Args:
        model: Trained module to prune. It is deep-copied, so any
            architecture works (not only a default-sized ``PrunableMLP``).
            The copy keeps the original's device and dtype.
        sparsity: Target fraction of pooled weights to zero, in ``[0, 1]``.
            ``0`` returns an unpruned copy.

    Returns:
        The pruned copy. No masks are stored, so further training can make
        zeroed weights non-zero again.

    Raises:
        ValueError: If ``sparsity`` is outside ``[0, 1]``.
    """
    import copy  # local import keeps this notebook cell self-contained

    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")

    # deepcopy replaces `PrunableMLP()` + load_state_dict, which only worked
    # for the default architecture and hyperparameters.
    pruned_model = copy.deepcopy(model)

    weights = [param for name, param in pruned_model.named_parameters() if 'weight' in name]
    if sparsity == 0.0 or not weights:
        # At sparsity 0 the threshold equals min |w|, and the strict '>' test
        # below would still zero that smallest weight, so skip pruning.
        return pruned_model

    # One vectorized concatenation instead of extending a Python list with
    # every individual weight value.
    magnitudes = np.concatenate([p.detach().abs().flatten().cpu().numpy() for p in weights])
    threshold = float(np.percentile(magnitudes, sparsity * 100))

    with torch.no_grad():
        for param in weights:
            mask = param.abs() > threshold
            param.mul_(mask.to(param.dtype))

    return pruned_model

# Evaluate model accuracy
def evaluate_accuracy(model, data_loader, device=None):
    """Return top-1 accuracy of ``model`` on ``data_loader``, in percent.

    Args:
        model: Classifier whose forward pass returns per-class scores; it is
            switched to eval mode (and left there).
        data_loader: Iterable of ``(data, target)`` batches.
        device: Where batches are moved. Defaults to the device of the
            model's first parameter, so the function does not depend on a
            notebook-global ``device``.

    Returns:
        ``100 * correct / total`` as a float.

    Raises:
        ValueError: If ``data_loader`` yields no samples, because accuracy
            is undefined in that case.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

    if total == 0:
        raise ValueError("data_loader yielded no samples; accuracy is undefined")
    return 100. * correct / total

# Calculate sparsity
def calculate_sparsity(model, weights_only=False):
    """Return the fraction of parameter entries that are exactly zero.

    Args:
        model: Module whose named parameters are inspected.
        weights_only: If True, count only parameters whose name contains
            ``'weight'``, so unpruned biases do not dilute the figure.
            The default, False, counts every parameter as before.

    Returns:
        A float in ``[0, 1]``. Returns 0.0 when no parameters are selected,
        instead of dividing by zero.
    """
    total_params = 0
    zero_params = 0

    for name, param in model.named_parameters():
        if weights_only and 'weight' not in name:
            continue
        total_params += param.numel()
        zero_params += (param == 0).sum().item()

    if total_params == 0:
        return 0.0
    return zero_params / total_params

# Sweep global-magnitude pruning targets from 0% to 90%. For each target,
# prune a fresh copy of the trained model, measure the sparsity it actually
# reached, and score it on the test set.
sparsity_levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
accuracies, sparsities = [], []

print("Testing different sparsity levels...")
for sparsity in sparsity_levels:
    pruned_model = magnitude_pruning(model, sparsity)
    # Tuple elements are evaluated left to right: sparsity is measured first,
    # then accuracy.
    actual_sparsity, accuracy = (
        calculate_sparsity(pruned_model),
        evaluate_accuracy(pruned_model, test_loader),
    )
    sparsities.append(actual_sparsity)
    accuracies.append(accuracy)

    print(f"Sparsity: {actual_sparsity:.2%}, Accuracy: {accuracy:.2f}%")

# Plot results: accuracy against measured sparsity, and against the implied
# compression ratio. The first sweep entry is treated as the unpruned baseline.
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(sparsities, accuracies, 'bo-', linewidth=2, markersize=8)
plt.xlabel('Sparsity')
plt.ylabel('Accuracy (%)')
plt.title('Accuracy vs Sparsity')
plt.grid(True)
plt.axhline(y=accuracies[0], color='r', linestyle='--', alpha=0.7, label='Original Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
# Compression ratio = total params / non-zero params. A fully zeroed model
# (s == 1) would otherwise raise ZeroDivisionError; map it to infinity,
# which matplotlib simply does not draw.
compression_ratios = [1 / (1 - s) if s < 1 else float('inf') for s in sparsities]
plt.plot(compression_ratios, accuracies, 'go-', linewidth=2, markersize=8)
plt.xlabel('Compression Ratio')
plt.ylabel('Accuracy (%)')
plt.title('Accuracy vs Compression Ratio')
plt.grid(True)
plt.axhline(y=accuracies[0], color='r', linestyle='--', alpha=0.7, label='Original Accuracy')
plt.legend()

plt.tight_layout()
# savefig does not create missing directories; ensure the output folder exists.
os.makedirs('../results/plots', exist_ok=True)
plt.savefig('../results/plots/pruning_results.png', dpi=300, bbox_inches='tight')
plt.show()

# Print summary. Results are looked up by target sparsity instead of by fixed
# list positions, so editing `sparsity_levels` cannot silently mislabel them.
print("\nPruning Summary:")
print(f"Original accuracy: {accuracies[0]:.2f}%")
if len(accuracies) > 1:
    print(f"Best pruned accuracy: {max(accuracies[1:]):.2f}%")
_accuracy_by_target = dict(zip(sparsity_levels, accuracies))
for _target in (0.5, 0.8):
    if _target in _accuracy_by_target:
        print(f"Accuracy drop at {_target:.0%} sparsity: {accuracies[0] - _accuracy_by_target[_target]:.2f}%")
