In [None]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.quantization
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
import timm
import copy
import os
import tempfile

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
import kagglehub

# Download the latest version of the FER2013 dataset; `path` holds the local
# directory that kagglehub returns.
path = kagglehub.dataset_download("msambare/fer2013")
print(f"Path to dataset files: {path}")

In [None]:
# Dataset loading
# Every image is resized to 224x224, turned into 3 identical grayscale channels
# (so the 3-channel pretrained ResNet-18 below can consume it) and scaled to
# [-1, 1] by Normalize(mean=0.5, std=0.5).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Read the splits from the directory kagglehub reported above (`path`) instead of
# a hard-coded relative path that only resolves from one working directory.
dataset = ImageFolder(os.path.join(path, "train"), transform=transform)
test_dataset = ImageFolder(os.path.join(path, "test"), transform=transform)

# Hold out 20% of the training images for validation. The seeded generator keeps
# the split identical across kernel restarts, so runs stay comparable.
val_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Check distribution
print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

In [None]:
from torchvision.models.quantization import resnet18 as resnet18_model

# Float ResNet-18 from torchvision's quantization-ready model zoo (quantize=False),
# with its classifier swapped for a 7-way head (presumably one logit per FER2013
# class — TODO confirm it matches len(dataset.classes)).
resnet18 = resnet18_model(pretrained=True, quantize=False)
resnet18.fc = nn.Linear(in_features=resnet18.fc.in_features, out_features=7)
resnet18.eval().to(device)

In [None]:
# Unstructured Pruning for CNN
def unstructured_prune_cnn(model, amount=0.3):
    """Zero the smallest-magnitude weights of every Conv2d and Linear layer.

    Each layer is pruned on its own with L1 (magnitude) unstructured pruning.
    The pruning re-parametrization is then removed, so the zeros are baked into a
    plain ``weight`` parameter. The model is modified in place.

    Args:
        model: module whose Conv2d / Linear submodules are pruned.
        amount: how much to prune per layer; forwarded to ``prune.l1_unstructured``.

    Returns:
        The same ``model`` object.
    """
    prunable = [m for _, m in model.named_modules() if isinstance(m, (nn.Conv2d, nn.Linear))]
    for layer in prunable:
        prune.l1_unstructured(layer, name="weight", amount=amount)
        prune.remove(layer, "weight")
    return model

# Structured Pruning for CNN
def structured_prune_cnn(model, amount=0.3):
    """Zero entire output filters of every Conv2d layer, ranked by L2 norm.

    Uses ``prune.ln_structured`` along dim 0 (output channels) with n=2. It then
    removes the re-parametrization so ``weight`` is a plain parameter again. Tensor
    shapes are unchanged: pruned filters are zeroed, not removed. Linear layers are
    left alone. The model is modified in place.

    Args:
        model: module whose Conv2d submodules are pruned.
        amount: how much to prune per layer; forwarded to ``prune.ln_structured``.

    Returns:
        The same ``model`` object.
    """
    conv_layers = [m for m in model.modules() if isinstance(m, torch.nn.Conv2d)]
    for conv in conv_layers:
        prune.ln_structured(conv, name="weight", amount=amount, n=2, dim=0)
        # Make the pruning permanent: drop weight_orig/weight_mask and the
        # forward pre-hook so state_dict, fusion and quantization see a plain weight.
        prune.remove(conv, "weight")
    return model

In [None]:
# Unstructured Pruning for ViT
def unstructured_prune_vit(model, amount=0.3):
    """L1-unstructured pruning of every ``nn.Linear`` weight, made permanent.

    Other layer types (e.g. convolutions) are not touched. The model is modified in
    place and returned.

    Args:
        model: module whose Linear submodules are pruned.
        amount: how much to prune per layer; forwarded to ``prune.l1_unstructured``.
    """
    linear_layers = [layer for layer in model.modules() if isinstance(layer, nn.Linear)]
    for layer in linear_layers:
        prune.l1_unstructured(layer, name="weight", amount=amount)
        # Bake the mask into the weight and drop the re-parametrization.
        prune.remove(layer, "weight")
    return model

# Structured Attention Head Pruning for ViT
def prune_vit_attention_heads(model, heads_to_prune=2):
    """Zero the least important attention heads in every module with ``qkv`` and ``num_heads``.

    Targets modules that expose a fused ``qkv`` Linear and a ``num_heads`` attribute
    (e.g. timm ViT attention blocks — assumption). The qkv output rows are assumed
    to be laid out as (q/k/v, head, head_dim). Head importance is the sum, over
    q/k/v, of the Frobenius norm of that head's weight slice. The
    ``heads_to_prune`` lowest-scoring heads get both their qkv weights AND their
    qkv bias entries zeroed. With the bias left intact, a "pruned" head would still
    produce a constant v-bias output. Heads are zeroed, not removed: shapes are
    unchanged and ``proj`` is not modified.

    Args:
        model: module to prune in place.
        heads_to_prune: number of heads to zero per attention module. ``torch.topk``
            raises if this exceeds ``num_heads``.

    Returns:
        The same ``model`` object.
    """
    for _, module in model.named_modules():
        if not (hasattr(module, "qkv") and hasattr(module, "num_heads")):
            continue
        qkv = module.qkv
        num_heads = module.num_heads
        out_features, in_features = qkv.weight.shape
        head_size = out_features // 3 // num_heads

        # `.data.view` shares storage with the parameter, so writes below prune in place.
        weight = qkv.weight.data.view(3, num_heads, head_size, in_features)
        importance = weight.norm(dim=(2, 3)).sum(dim=0)
        prune_idx = torch.topk(importance, heads_to_prune, largest=False).indices

        weight[:, prune_idx] = 0
        if qkv.bias is not None:
            qkv.bias.data.view(3, num_heads, head_size)[:, prune_idx] = 0
    return model

In [None]:
def quantize_model(model, calibration_loader, num_calibration_batches=None):
    """Post-training static int8 quantization (fbgemm qconfig) of a fusable model.

    Expects ``model`` to provide ``fuse_model()`` and quant/dequant stubs, as the
    torchvision ``models.quantization`` networks used in this notebook do.

    The work happens on a deep copy. ``fuse_model``, ``prepare`` and ``convert``
    all rewrite modules in place. Running them on the caller's model would turn it
    (and every variable aliasing it) into a fused/quantized model, breaking later
    training, pruning or fusion.

    Args:
        model: float model with a ``fuse_model()`` method; left unmodified.
        calibration_loader: iterable of ``(images, labels)`` batches used to
            collect activation statistics; labels are ignored.
        num_calibration_batches: stop calibrating after this many batches.
            ``None`` (default) uses every batch.

    Returns:
        A new quantized model on the CPU.
    """
    qmodel = copy.deepcopy(model).cpu().eval()
    qmodel.fuse_model()

    qmodel.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(qmodel, inplace=True)

    # Calibration: run float data through the observers.
    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(calibration_loader):
            if num_calibration_batches is not None and batch_idx >= num_calibration_batches:
                break
            qmodel(images.to(torch.float32).cpu())

    torch.quantization.convert(qmodel, inplace=True)
    return qmodel

def quantize_vgg_dynamic(model):
    """Dynamically quantize every ``nn.Linear`` in ``model`` to int8 weights.

    The model is moved to the CPU first, mirroring ``quantize_vit_dynamic``.
    PyTorch's dynamically quantized Linear kernels are CPU-only, so a model still
    on CUDA cannot be converted.

    Returns:
        A quantized copy (``quantize_dynamic`` is not in-place). Note that
        ``model`` itself is left on the CPU.
    """
    return torch.quantization.quantize_dynamic(model.to("cpu"), {nn.Linear}, dtype=torch.qint8)


def quantize_vit_dynamic(model):
    """Return an int8 dynamically quantized copy of ``model`` (Linear layers only).

    ``model`` is moved to the CPU before quantization; that move also applies to
    the caller's object.
    """
    cpu_model = model.to("cpu")
    return torch.quantization.quantize_dynamic(
        cpu_model,
        qconfig_spec={nn.Linear},
        dtype=torch.qint8,
    )

In [None]:
import time

def train(model, train_loader, val_loader, epochs=10, lr=1e-4):
    """Fine-tune ``model`` with Adam + cross-entropy and validate after every epoch.

    The model is first moved to the notebook-global ``device``. The quantization
    helpers in this notebook can leave a model on the CPU, while batches are sent
    to ``device``. Training happens in place.

    Args:
        model: classifier returning logits; trained in place.
        train_loader: iterable of ``(inputs, labels)`` training batches.
        val_loader: iterable of ``(inputs, labels)`` validation batches.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        list[dict]: one entry per epoch with ``train_loss``, ``val_loss`` and
        ``val_acc`` (percent). Callers that ignore the return value are unaffected.
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0.0
        correct, total = 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                outputs = model(x)
                loss = criterion(outputs, y)
                val_loss += loss.item()
                correct += (outputs.argmax(1) == y).sum().item()
                total += y.size(0)

        # max(..., 1) / the `total` check avoid ZeroDivisionError on empty loaders.
        train_loss = running_loss / max(len(train_loader), 1)
        val_loss /= max(len(val_loader), 1)
        val_acc = 100 * correct / total if total else 0.0

        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}, Val Acc = {val_acc:.2f}%")
        history.append({"train_loss": train_loss, "val_loss": val_loss, "val_acc": val_acc})

    return history

In [None]:
def evaluate(model, test_loader, quantization=False):
    """Print and return the top-1 accuracy (percent) of ``model`` on ``test_loader``.

    Args:
        model: classifier whose output is argmax-ed over dim 1 (assumes
            (batch, num_classes) logits — TODO confirm for other heads).
        test_loader: iterable of ``(inputs, labels)`` batches.
        quantization: run on the CPU (quantized models must be on CPU). Otherwise
            run on CUDA when it is available, falling back to the CPU instead of
            crashing on CPU-only machines.

    Returns:
        Accuracy in percent.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    if quantization:
        run_device = torch.device("cpu")
    else:
        run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()
    model = model.to(run_device)

    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(run_device), y.to(run_device)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)

    if total == 0:
        raise ValueError("test_loader yielded no samples")
    accuracy = 100 * correct / total
    print(f"Test Accuracy = {accuracy:.2f}%")
    return accuracy

In [None]:
def measure_inference_speed(model, test_loader, quantization=False):
    """Print and return the average forward-pass time per batch, in seconds.

    Only the model call is timed. Fetching a batch (including any decoding and
    transforms the loader does) and copying it to the device happen outside the
    timed region. On CUDA, the device is synchronized before and after each
    forward pass. Without that, asynchronous kernel launches would make
    ``perf_counter`` measure launch time rather than execution time.

    Args:
        model: module to time; moved to the chosen device and put in eval mode.
        test_loader: iterable of ``(inputs, labels)`` batches; labels are ignored.
        quantization: run on the CPU (quantized models must be on CPU). Otherwise
            use CUDA when available, falling back to the CPU.

    Returns:
        Mean seconds per batch.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    if quantization:
        run_device = torch.device("cpu")
    else:
        run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = run_device.type == "cuda"

    model.eval()
    model.to(run_device)

    total_time = 0.0
    num_batches = 0
    with torch.no_grad():
        for x, _ in test_loader:
            x = x.to(run_device)
            if use_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            model(x)
            if use_cuda:
                torch.cuda.synchronize()
            total_time += time.perf_counter() - start
            num_batches += 1

    if num_batches == 0:
        raise ValueError("test_loader yielded no batches")
    latency = total_time / num_batches
    print(f"Avg Inference Time per Batch: {latency:.4f} sec")
    return latency

def model_size_mb(model, use_state_dict=True):
    """Print and return the on-disk size of ``model`` in MiB after ``torch.save``.

    Args:
        model: module to measure.
        use_state_dict: save only ``model.state_dict()`` (default). Otherwise
            pickle the whole module object.

    Returns:
        Size of the serialized checkpoint in MiB.
    """
    payload = model.state_dict() if use_state_dict else model
    # A temporary directory is deleted together with the checkpoint inside it on
    # exit. NamedTemporaryFile(delete=False) would leave a checkpoint file behind
    # on every call, and reopening an open NamedTemporaryFile by name fails on Windows.
    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint_path = os.path.join(tmp_dir, "model.pt")
        torch.save(payload, checkpoint_path)
        size_mb = os.path.getsize(checkpoint_path) / (1024 * 1024)
    print(f"Model Size ({'state_dict' if use_state_dict else 'full model'}): {size_mb:.2f} MB")
    return size_mb


In [None]:
def quantize_pruned_model(model, dataloader, num_calibration_batches=10):
    """Post-training static int8 quantization (fbgemm qconfig) of a pruned ResNet-style model.

    Conv+BN(+ReLU) groups are fused by explicit module name instead of calling
    ``fuse_model()``:

    * top level: ``conv1``/``bn1``/``relu``;
    * every block of each child whose name contains "layer": ``conv1``/``bn1``/``relu``
      and ``conv2``/``bn2``, plus the ``downsample`` Conv2d+BatchNorm2d pair when
      present. torchvision's quantizable BasicBlock ``fuse_model()`` (used by
      ``quantize_model`` for the baseline) fuses that pair as well, so both
      quantized models are built the same way.

    The names follow torchvision's ResNet BasicBlock layout (assumption for other
    architectures).

    Args:
        model: float model. It is deep-copied, so the caller's weights, device and
            train/eval mode are left untouched.
        dataloader: iterable of ``(images, labels)``; labels are ignored.
        num_calibration_batches: number of batches used for calibration.

    Returns:
        A new quantized model on the CPU.
    """
    work_model = copy.deepcopy(model).cpu().eval()

    # Fuse top-level stem
    torch.quantization.fuse_modules(work_model, [["conv1", "bn1", "relu"]], inplace=True)

    # Fuse residual blocks (and their Conv+BN shortcut, if any)
    for name, stage in work_model.named_children():
        if "layer" not in name:
            continue
        for block in stage:
            torch.quantization.fuse_modules(
                block,
                [["conv1", "bn1", "relu"], ["conv2", "bn2"]],
                inplace=True,
            )
            downsample = getattr(block, "downsample", None)
            if (
                isinstance(downsample, nn.Sequential)
                and len(downsample) == 2
                and isinstance(downsample[0], nn.Conv2d)
                and isinstance(downsample[1], nn.BatchNorm2d)
            ):
                torch.quantization.fuse_modules(downsample, [["0", "1"]], inplace=True)

    # Set qconfig and prepare for calibration
    work_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(work_model, inplace=True)

    # Calibrate
    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(dataloader):
            if batch_idx >= num_calibration_batches:
                break
            work_model(images.to(torch.float32).cpu())

    # Convert to quantized model (work_model is already a private copy)
    torch.quantization.convert(work_model, inplace=True)
    return work_model

In [None]:
# === Apply quantization on baseline (no pruning) ===
# NOTE: this cell runs before the training/pruning cell further down, so the
# baseline here still has its freshly initialised 7-way head. Quantize a deep copy
# so this cell can never modify `resnet18`, the float model trained below.
print("=== Apply quantization on baseline (no pruning) ===")
resnet18_quant = quantize_model(copy.deepcopy(resnet18), train_loader)
evaluate(resnet18_quant, test_loader, quantization=True)
model_size_mb(resnet18_quant)
measure_inference_speed(resnet18_quant, test_loader, quantization=True)

# === Apply quantization on pruned model ===
# `resnet18_pruned_st` is only created by the training/pruning cell below, so skip
# this part on a fresh top-to-bottom run instead of raising NameError.
if "resnet18_pruned_st" in globals():
    print("=== Apply quantization on pruned model ===")
    resnet18_quant_pr = quantize_pruned_model(resnet18_pruned_st, train_loader)
    evaluate(resnet18_quant_pr, test_loader, quantization=True)
    model_size_mb(resnet18_quant_pr)
    measure_inference_speed(resnet18_quant_pr, test_loader, quantization=True)
else:
    print("Skipping pruned-model quantization: resnet18_pruned_st is created by a later cell.")

In [None]:
import torch

def quantize_trained_pruned_model(model, dataloader, num_calibration_batches=10):
    """
    Applies static int8 quantization (fbgemm qconfig) to a pruned + trained ResNet18 model.

    - Deep-copies the model, so the caller's weights, device and train/eval mode
      are left untouched
    - Fuses Conv+BN(+ReLU) by name: stem ``conv1``/``bn1``/``relu``. Then, for every
      block of each child whose name contains "layer": ``conv1``/``bn1``/``relu``,
      ``conv2``/``bn2`` and the ``downsample`` Conv2d+BatchNorm2d pair when present.
      torchvision's quantizable BasicBlock ``fuse_model()`` fuses that pair too,
      keeping this consistent with the baseline path in ``quantize_model``.
    - Prepares with the fbgemm qconfig
    - Calibrates on the first ``num_calibration_batches`` batches of ``dataloader``
      (``(images, labels)`` pairs; labels are ignored)
    - Converts to a quantized model, returned on the CPU
    """
    work_model = copy.deepcopy(model).cpu().eval()

    # Fuse top-level layers
    torch.quantization.fuse_modules(work_model, [["conv1", "bn1", "relu"]], inplace=True)

    # Fuse residual blocks (and their Conv+BN shortcut, if any)
    for name, stage in work_model.named_children():
        if "layer" not in name:
            continue
        for block in stage:
            torch.quantization.fuse_modules(
                block,
                [["conv1", "bn1", "relu"], ["conv2", "bn2"]],
                inplace=True,
            )
            downsample = getattr(block, "downsample", None)
            if (
                isinstance(downsample, nn.Sequential)
                and len(downsample) == 2
                and isinstance(downsample[0], nn.Conv2d)
                and isinstance(downsample[1], nn.BatchNorm2d)
            ):
                torch.quantization.fuse_modules(downsample, [["0", "1"]], inplace=True)

    # Attach quantization config and prepare for calibration
    work_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(work_model, inplace=True)

    # Calibrate on a few batches
    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(dataloader):
            if batch_idx >= num_calibration_batches:
                break
            work_model(images.to(torch.float32).cpu())

    # Convert to quantized model (work_model is already a private copy)
    torch.quantization.convert(work_model, inplace=True)
    return work_model


In [None]:
# === Train baseline ===
print("=== Train baseline ===")
train(resnet18, train_loader, val_loader, epochs=1)
evaluate(resnet18, test_loader)
model_size_mb(resnet18)
measure_inference_speed(resnet18, test_loader)

# Each experiment below starts from its own deep copy of the trained baseline.
# The pruning helpers modify their argument in place and return it. Without the
# copies, `resnet18`, `resnet18_pruned_st` and `resnet18_pruned_unst` would be one
# object, and each step would compound on the previous one (e.g. the
# "unstructured" run would start from the structurally pruned weights). The
# baseline is also copied before static quantization so `resnet18` stays a float model.

# === Apply structured pruning, then fine-tune ===
print("=== Apply structured pruning, then fine-tune ===")
resnet18_pruned_st = structured_prune_cnn(copy.deepcopy(resnet18), amount=0.5)
train(resnet18_pruned_st, train_loader, val_loader, epochs=1, lr=1e-5)
evaluate(resnet18_pruned_st, test_loader)
model_size_mb(resnet18_pruned_st)
measure_inference_speed(resnet18_pruned_st, test_loader)

# === Apply unstructured pruning, then fine-tune ===
print("=== Apply unstructured pruning, then fine-tune ===")
resnet18_pruned_unst = unstructured_prune_cnn(copy.deepcopy(resnet18), amount=0.5)
train(resnet18_pruned_unst, train_loader, val_loader, epochs=1, lr=1e-5)
evaluate(resnet18_pruned_unst, test_loader)
model_size_mb(resnet18_pruned_unst)
measure_inference_speed(resnet18_pruned_unst, test_loader)

# === Apply quantization on baseline (no pruning) ===
print("=== Apply quantization on baseline (no pruning) ===")
resnet18_quant = quantize_model(copy.deepcopy(resnet18), train_loader)
evaluate(resnet18_quant, test_loader, quantization=True)
model_size_mb(resnet18_quant)
measure_inference_speed(resnet18_quant, test_loader, quantization=True)

# === Apply quantization on pruned model ===
print("=== Apply quantization on pruned model ===")
resnet18_quant_pr = quantize_trained_pruned_model(resnet18_pruned_st, train_loader)
evaluate(resnet18_quant_pr, test_loader, quantization=True)
model_size_mb(resnet18_quant_pr)
measure_inference_speed(resnet18_quant_pr, test_loader, quantization=True)

In [None]:
# === Apply quantization on pruned model ===
# Re-run the name-based fusion + static quantization on the structurally pruned model.
print("=== Apply quantization on pruned model ===")
resnet18_quant_pr = quantize_pruned_model(model=resnet18_pruned_st, dataloader=train_loader)
evaluate(model=resnet18_quant_pr, test_loader=test_loader, quantization=True)
model_size_mb(model=resnet18_quant_pr)
measure_inference_speed(model=resnet18_quant_pr, test_loader=test_loader, quantization=True)

In [None]:
print("=== Apply quantization on baseline (no pruning) ===")
# Quantize a deep copy so `resnet18` stays a float model and this cell can be
# re-run.
resnet18_quant = quantize_model(copy.deepcopy(resnet18), train_loader)
#evaluate(resnet18_quant, test_loader, quantization=True)
model_size_mb(resnet18_quant)

In [None]:
print("=== Quantizing trained pruned model ===")
resnet18_quant_pr = quantize_trained_pruned_model(
    model=resnet18_pruned_st,
    dataloader=train_loader,
)
#evaluate(resnet18_quant_pr, test_loader, quantization=True)
# Measure the pickled module object itself rather than only its state_dict.
model_size_mb(model=resnet18_quant_pr, use_state_dict=False)
#measure_inference_speed(resnet18_quant_pr, test_loader, quantization=True)
