In [3]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.quantization
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
import timm
import copy
import os
import tempfile
from torch.ao.quantization import get_default_qconfig, prepare, convert
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torch.ao.quantization.qconfig import QConfig

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("msambare/fer2013")

print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/msambare/fer2013?dataset_version_number=1...


100%|██████████| 60.3M/60.3M [00:00<00:00, 90.8MB/s]

Extracting files...





Path to dataset files: /home/smahadi/.cache/kagglehub/datasets/msambare/fer2013/versions/1


In [5]:
# Dataset loading
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

dataset = ImageFolder('fer2013/versions/1/train', transform=transform)
val_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
test_dataset = ImageFolder('fer2013/versions/1/test', transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Check distribution
print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

Train samples: 22968
Validation samples: 5741
Test samples: 7178


In [6]:
from torchvision.models.quantization import resnet18 as resnet18_model

resnet18 = resnet18_model(pretrained=True, quantize=False)
resnet18.fc = nn.Linear(512, 7)
resnet18.eval()
resnet18.to(device)

QuantizableResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): QuantizableBasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (add_relu): FloatFunctional(
        (activation_post_process): Identity()
      )
    )
    (1): QuantizableBasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, e

In [7]:
# Unstructured Pruning for CNN
def unstructured_prune_cnn(model, amount=0.3):
    model = copy.deepcopy(model)
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            prune.l1_unstructured(module, name="weight", amount=amount)
            prune.remove(module, "weight")
    return model

# Structured Pruning for CNN
def structured_prune_cnn(model, amount=0.5):
    model = copy.deepcopy(model)
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            prune.ln_structured(module, name="weight", amount=amount, n=2, dim=0)
            prune.remove(module, "weight")
    return model

In [8]:
def fuse_model_blocks(model):
    torch.quantization.fuse_modules(model, [["conv1", "bn1", "relu"]], inplace=True)
    for module_name, module in model.named_children():
        if "layer" in module_name:
            for block in module:
                torch.quantization.fuse_modules(
                    block, [["conv1", "bn1", "relu"], ["conv2", "bn2"]],
                    inplace=True
                )
                if hasattr(block, "downsample") and isinstance(block.downsample, torch.nn.Sequential):
                    if len(block.downsample) >= 2:
                        torch.quantization.fuse_modules(block.downsample, ["0", "1"], inplace=True)

In [9]:
def quantize_trained_pruned_model(model, calibration_loader, num_batches=10):
    import copy
    model = copy.deepcopy(model)
    model.cpu().eval()

    # Fuse layers (must happen after pruning and training)
    fuse_model_blocks(model)

    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(model, inplace=True)

    with torch.no_grad():
        for i, (x, _) in enumerate(calibration_loader):
            x = x.to(torch.float32)
            x = x * 0.5 + 0.5  # Undo normalization: [0, 1]
            x = torch.clamp(x, 0.0, 1.0)
            model(x)
            if i >= num_batches:
                break

    torch.quantization.convert(model, inplace=True)
    return model

In [10]:
def quantize_model(model, calibration_loader, num_batches=10):
    import copy
    model = copy.deepcopy(model)
    model.eval()
    model.cpu()

    # Fuse layers
    fuse_model_blocks(model)

    # Set quantization config and prepare
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(model, inplace=True)

    # Calibration loop
    with torch.no_grad():
        for i, (x, _) in enumerate(calibration_loader):
            x = x.to(torch.float32)
            x = x * 0.5 + 0.5  # Undo Normalize([0.5], [0.5])
            x = torch.clamp(x, 0.0, 1.0)  # Ensure values in [0, 1]
            model(x)
            if i >= num_batches:
                break

    # Convert to quantized model
    torch.quantization.convert(model, inplace=True)
    return model

In [11]:
import time

def train(model, train_loader, val_loader, epochs=10, lr=1e-4):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0.0
        correct, total = 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                outputs = model(x)
                loss = criterion(outputs, y)
                val_loss += loss.item()
                correct += (outputs.argmax(1) == y).sum().item()
                total += y.size(0)

        train_loss = running_loss / len(train_loader)
        val_loss /= len(val_loader)
        val_acc = 100 * correct / total

        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}, Val Acc = {val_acc:.2f}%")

In [12]:
def evaluate(model, test_loader, quantization=False):
    model.eval()

    if quantization:
        model = model.to("cpu")  # Quantized models must be on CPU
    else:
        model = model.to("cuda")

    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            if quantization:
                x, y = x.to("cpu"), y.to("cpu")
            else:
                x, y = x.to("cuda"), y.to("cuda")

            outputs = model(x)
            pred = outputs.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)

    accuracy = 100 * correct / total
    print(f"Test Accuracy = {accuracy:.2f}%")
    return accuracy

In [13]:
def measure_inference_speed(model, test_loader, quantization=False):
    model.eval()
    
    if quantization:
        device = "cpu"
        
    else:
        device = torch.device("cuda")
    
    model.to(device)
    
    start = time.time()
    with torch.no_grad():
        for x, _ in test_loader:
            x = x.to(device)
            _ = model(x)
    end = time.time()
    latency = (end - start) / len(test_loader)
    print(f"Avg Inference Time per Batch: {latency:.4f} sec")
    return latency

def model_size_mb(model, use_state_dict=True):
    with tempfile.NamedTemporaryFile(delete=False) as f:
        if use_state_dict:
            torch.save(model.state_dict(), f.name)
        else:
            torch.save(model, f.name)
        size_mb = os.path.getsize(f.name) / (1024 * 1024)
    print(f"Model Size ({'state_dict' if use_state_dict else 'full model'}): {size_mb:.2f} MB")
    return size_mb

In [14]:
# === Train baseline ===
print("=== Train baseline ===")
train(resnet18, train_loader, val_loader, epochs=1)
evaluate(resnet18, test_loader)
model_size_mb(resnet18)
measure_inference_speed(resnet18, test_loader)

# === Apply structured pruning, then fine-tune ===
print("=== Apply structured pruning, then fine-tune ===")
resnet18_pruned_st = structured_prune_cnn(resnet18, amount=0.5)
train(resnet18_pruned_st, train_loader, val_loader, epochs=1, lr=1e-5)
evaluate(resnet18_pruned_st, test_loader)
model_size_mb(resnet18_pruned_st)
measure_inference_speed(resnet18_pruned_st, test_loader)

# === Apply unstructured pruning, then fine-tune ===
print("=== Apply unstructured pruning, then fine-tune ===")
resnet18_pruned_unst = unstructured_prune_cnn(resnet18, amount=0.5)
train(resnet18_pruned_unst, train_loader, val_loader, epochs=1, lr=1e-5)
evaluate(resnet18_pruned_unst, test_loader)
model_size_mb(resnet18_pruned_unst)
measure_inference_speed(resnet18_pruned_unst, test_loader)

=== Train baseline ===
Epoch 1: Train Loss = 1.1643, Val Loss = 1.0048, Val Acc = 61.99%
Test Accuracy = 62.58%
Model Size (state_dict): 42.73 MB
Avg Inference Time per Batch: 0.1355 sec
=== Apply structured pruning, then fine-tune ===
Epoch 1: Train Loss = 1.1940, Val Loss = 1.0761, Val Acc = 59.05%
Test Accuracy = 58.58%
Model Size (state_dict): 42.73 MB
Avg Inference Time per Batch: 0.1286 sec
=== Apply unstructured pruning, then fine-tune ===
Epoch 1: Train Loss = 0.8098, Val Loss = 0.9449, Val Acc = 64.95%
Test Accuracy = 64.61%
Model Size (state_dict): 42.73 MB
Avg Inference Time per Batch: 0.1330 sec


0.13296038703580873

In [15]:
# === Apply quantization on baseline (no pruning) ===
print("=== Apply quantization on baseline (no pruning) ===")
resnet18_quant = quantize_model(resnet18, train_loader)
evaluate(resnet18_quant, test_loader, quantization=True)
model_size_mb(resnet18_quant)
measure_inference_speed(resnet18_quant, test_loader, quantization=True)

# === Apply quantization on pruned model ===
print("=== Apply quantization on pruned model ===")
resnet18_quant_pr = quantize_trained_pruned_model(resnet18_pruned_st, train_loader)
evaluate(resnet18_quant_pr, test_loader, quantization=True)
model_size_mb(resnet18_quant_pr)
measure_inference_speed(resnet18_quant_pr, test_loader, quantization=True)

=== Apply quantization on baseline (no pruning) ===




Test Accuracy = 48.16%
Model Size (state_dict): 10.79 MB
Avg Inference Time per Batch: 0.1500 sec
=== Apply quantization on pruned model ===
Test Accuracy = 44.82%
Model Size (state_dict): 10.79 MB
Avg Inference Time per Batch: 0.1555 sec


0.15552488681489388