In [1]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.quantization
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
import timm
import copy
import os
import tempfile
from torch.ao.quantization import get_default_qconfig, prepare, convert
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torch.ao.quantization.qconfig import QConfig

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

  from .autonotebook import tqdm as notebook_tqdm


cuda


In [2]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("msambare/fer2013")

print("Path to dataset files:", path)

Path to dataset files: /home/ssukuma2/.cache/kagglehub/datasets/msambare/fer2013/versions/1


In [3]:
# Dataset loading
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

dataset = ImageFolder('../FER2013/train', transform=transform)
val_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
test_dataset = ImageFolder('../FER2013/test', transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Check distribution
print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

Train samples: 22968
Validation samples: 5741
Test samples: 7178


In [4]:
vit = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=7).to(device)

In [5]:
# Unstructured Pruning for CNN
def unstructured_prune_cnn(model, amount=0.3):
    model = copy.deepcopy(model)
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            prune.l1_unstructured(module, name="weight", amount=amount)
            prune.remove(module, "weight")
    return model

# Structured Pruning for CNN
def structured_prune_cnn(model, amount=0.5):
    model = copy.deepcopy(model)
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            prune.ln_structured(module, name="weight", amount=amount, n=2, dim=0)
            prune.remove(module, "weight")
    return model

In [6]:
# Unstructured Pruning for ViT
def unstructured_prune_vit(model, amount=0.3):
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            prune.l1_unstructured(module, name="weight", amount=amount)
            prune.remove(module, "weight")
    return model

# Structured Attention Head Pruning for ViT
def prune_vit_attention_heads(model, heads_to_prune=2):
    for name, module in model.named_modules():
        if hasattr(module, 'qkv') and hasattr(module, 'num_heads'):
            heads_dim = module.qkv.weight.shape[0] // 3
            head_size = heads_dim // module.num_heads
            qkv_weights = module.qkv.weight.data.view(3, module.num_heads, head_size, -1)
            norms = qkv_weights.norm(dim=(2, 3))
            importance = norms.sum(dim=0)
            prune_indices = torch.topk(importance, heads_to_prune, largest=False).indices
            for i in prune_indices:
                qkv_weights[:, i, :, :] = 0
            module.qkv.weight.data = qkv_weights.view(-1, module.qkv.weight.shape[1])
    return model

In [7]:
def fuse_model_blocks(model):
    torch.quantization.fuse_modules(model, [["conv1", "bn1", "relu"]], inplace=True)
    for module_name, module in model.named_children():
        if "layer" in module_name:
            for block in module:
                torch.quantization.fuse_modules(
                    block, [["conv1", "bn1", "relu"], ["conv2", "bn2"]],
                    inplace=True
                )
                if hasattr(block, "downsample") and isinstance(block.downsample, torch.nn.Sequential):
                    if len(block.downsample) >= 2:
                        torch.quantization.fuse_modules(block.downsample, ["0", "1"], inplace=True)

In [51]:
import torch
import torch.nn as nn
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization.qconfig_mapping import get_default_qconfig_mapping

# --- Custom quantization-safe LayerNorm replacement ---
class QuantLayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.ones(*self.normalized_shape))
            self.bias = nn.Parameter(torch.zeros(*self.normalized_shape))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, unbiased=False, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)
        if self.elementwise_affine:
            x = x * self.weight + self.bias
        return x

# --- Main function to quantize a pruned ViT model ---
def quantize_pruned_vit_model(model, calibration_loader, num_batches=10):
    import copy

    # Clone the model to avoid modifying the original
    model = copy.deepcopy(model)
    model.eval().cpu()

    # Replace incompatible modules (GELU, LayerNorm)
    def patch_for_static_quant(model):
        replacements = []

        for name, module in model.named_modules():
            if isinstance(module, nn.GELU):
                replacements.append((name, nn.ReLU()))
            elif isinstance(module, nn.LayerNorm):
                try:
                    qln = QuantLayerNorm(module.normalized_shape, eps=module.eps, elementwise_affine=True)
                    qln.weight.data = module.weight.data.clone()
                    qln.bias.data = module.bias.data.clone()
                    replacements.append((name, qln))
                except Exception as e:
                    print(f"Skipping {name} due to shape mismatch: {e}")

        # Apply replacements safely after iteration
        for name, new_module in replacements:
            parent = model
            parts = name.split(".")
            for part in parts[:-1]:
                parent = getattr(parent, part)
            setattr(parent, parts[-1], new_module)

        return model

    model = patch_for_static_quant(model)

    # Quantization config
    qconfig_mapping = get_default_qconfig_mapping("fbgemm")

    # Get example input for FX tracing
    example_input = next(iter(calibration_loader))[0]

    # Prepare for FX static quantization
    prepared_model = prepare_fx(model, qconfig_mapping, example_inputs=example_input)

    # Calibration loop
    with torch.no_grad():
        for i, (x, _) in enumerate(calibration_loader):
            x = x.to(torch.float32)
            prepared_model(x)
            if i >= num_batches:
                break

    # Convert to quantized model
    quantized_model = convert_fx(prepared_model)
    return quantized_model


In [42]:
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization.qconfig_mapping import get_default_qconfig_mapping


def quantize_vit(model, calibration_loader, num_batches=10):
    model.eval()
    model.cpu()

    qconfig_mapping = get_default_qconfig_mapping("fbgemm")
    example_input = next(iter(calibration_loader))[0]

    # FX Graph Mode Quantization
    prepared = prepare_fx(model, qconfig_mapping, example_inputs=example_input)

    with torch.no_grad():
        for i, (x, _) in enumerate(calibration_loader):
            prepared(x.to(torch.float32))
            x.to("cpu")
            model(x)
            if i >= num_batches:
                break

    quantized_model = convert_fx(prepared)
    return quantized_model

In [10]:
import time

def train(model, train_loader, val_loader, epochs=10, lr=1e-4):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0.0
        correct, total = 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                outputs = model(x)
                loss = criterion(outputs, y)
                val_loss += loss.item()
                correct += (outputs.argmax(1) == y).sum().item()
                total += y.size(0)

        train_loss = running_loss / len(train_loader)
        val_loss /= len(val_loader)
        val_acc = 100 * correct / total

        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}, Val Acc = {val_acc:.2f}%")

In [11]:
def evaluate(model, test_loader, quantization=False):
    model.eval()

    if quantization:
        model = model.to("cpu")  # Quantized models must be on CPU
    else:
        model = model.to("cuda")

    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            if quantization:
                x, y = x.to("cpu"), y.to("cpu")
            else:
                x, y = x.to("cuda"), y.to("cuda")

            outputs = model(x)
            pred = outputs.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)

    accuracy = 100 * correct / total
    print(f"Test Accuracy = {accuracy:.2f}%")
    return accuracy

In [12]:
def measure_inference_speed(model, test_loader, quantization=False):
    model.eval()
    
    if quantization:
        device = "cpu"
        
    else:
        device = torch.device("cuda")
    
    model.to(device)
    
    start = time.time()
    with torch.no_grad():
        for x, _ in test_loader:
            x = x.to(device)
            _ = model(x)
    end = time.time()
    latency = (end - start) / len(test_loader)
    print(f"Avg Inference Time per Batch: {latency:.4f} sec")
    return latency

def model_size_mb(model, use_state_dict=True):
    with tempfile.NamedTemporaryFile(delete=False) as f:
        if use_state_dict:
            torch.save(model.state_dict(), f.name)
        else:
            torch.save(model, f.name)
        size_mb = os.path.getsize(f.name) / (1024 * 1024)
    print(f"Model Size ({'state_dict' if use_state_dict else 'full model'}): {size_mb:.2f} MB")
    return size_mb

In [None]:
# === Train baseline ===
print("=== Train baseline ===")
train(vit, train_loader, val_loader, epochs=1)
evaluate(vit, test_loader)
model_size_mb(vit)
measure_inference_speed(vit, test_loader)


=== Train baseline ===
Epoch 1: Train Loss = 1.7604, Val Loss = 1.5743, Val Acc = 38.37%
Test Accuracy = 40.40%
Model Size (state_dict): 327.37 MB
Avg Inference Time per Batch: 0.8452 sec
=== Apply structured pruning, then fine-tune ===


TypeError: prune_vit_attention_heads() got an unexpected keyword argument 'amount'

In [14]:
# === Apply structured pruning, then fine-tune ===
print("=== Apply structured pruning, then fine-tune ===")
vit_pruned_st = prune_vit_attention_heads(vit)
train(vit_pruned_st, train_loader, val_loader, epochs=1, lr=1e-5)
evaluate(vit_pruned_st, test_loader)
model_size_mb(vit_pruned_st)
measure_inference_speed(vit_pruned_st, test_loader)

# === Apply unstructured pruning, then fine-tune ===
print("=== Apply unstructured pruning, then fine-tune ===")
vit_pruned_unst = unstructured_prune_vit(vit, amount=0.5)
train(vit_pruned_unst, train_loader, val_loader, epochs=1, lr=1e-5)
evaluate(vit_pruned_unst, test_loader)
model_size_mb(vit_pruned_unst)
measure_inference_speed(vit_pruned_unst, test_loader)

=== Apply structured pruning, then fine-tune ===
Epoch 1: Train Loss = 1.4207, Val Loss = 1.3857, Val Acc = 46.84%
Test Accuracy = 47.10%
Model Size (state_dict): 327.37 MB
Avg Inference Time per Batch: 0.8827 sec
=== Apply unstructured pruning, then fine-tune ===
Epoch 1: Train Loss = 1.3008, Val Loss = 1.2828, Val Acc = 51.38%
Test Accuracy = 51.48%
Model Size (state_dict): 327.37 MB
Avg Inference Time per Batch: 0.8392 sec


0.8391690359706372

In [43]:
# === Apply quantization on baseline (no pruning) ===
print("=== Apply quantization on baseline (no pruning) ===")
vit_quant = quantize_vit(vit, train_loader)
evaluate(vit_quant, test_loader, quantization=True)
model_size_mb(vit_quant)
measure_inference_speed(vit_quant, test_loader, quantization=True)



=== Apply quantization on baseline (no pruning) ===


  cat = torch.cat([quantize_per_tensor_6, patch_embed_norm], dim = 1);  quantize_per_tensor_6 = patch_embed_norm = None


Test Accuracy = 19.27%
Model Size (state_dict): 84.06 MB
Avg Inference Time per Batch: 2.3836 sec


2.383622593584314

In [52]:
# === Apply quantization on pruned model ===
print("=== Apply quantization on pruned model ===")
vit_quant_pr = quantize_pruned_vit_model(vit_pruned_st, train_loader)
evaluate(vit_quant_pr, test_loader, quantization=True)
model_size_mb(vit_quant_pr)
measure_inference_speed(vit_quant_pr, test_loader, quantization=True)

=== Apply quantization on pruned model ===


  cat = torch.cat([quantize_per_tensor_6, patch_embed_norm], dim = 1);  quantize_per_tensor_6 = patch_embed_norm = None


Test Accuracy = 22.72%
Model Size (state_dict): 84.12 MB
Avg Inference Time per Batch: 1.7040 sec


1.7040463654340896