In [7]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.quantization
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
import timm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("msambare/fer2013")

print("Path to dataset files:", path)

  from .autonotebook import tqdm as notebook_tqdm


Downloading from https://www.kaggle.com/api/v1/datasets/download/msambare/fer2013?dataset_version_number=1...


100%|██████████| 60.3M/60.3M [00:00<00:00, 83.8MB/s]

Extracting files...





Path to dataset files: /home/smahadi/.cache/kagglehub/datasets/msambare/fer2013/versions/1


In [None]:
# Dataset loading
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

dataset = ImageFolder('fer2013/versions/1/train', transform=transform)
val_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
test_dataset = ImageFolder('fer2013/versions/1/test', transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Check distribution
print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

Train samples: 22968
Validation samples: 5741
Test samples: 7178


In [8]:
# 1. ResNet-18 (CNN)
resnet18 = models.resnet18(pretrained=True)
resnet18.fc = nn.Linear(512, 7)  # FER2013 has 7 emotion classes
resnet18 = resnet18.to(device)

# 2. VGG-11 (CNN)
vgg11 = models.vgg11(pretrained=True)
vgg11.classifier[6] = nn.Linear(4096, 7)
vgg11 = vgg11.to(device)

# 3. ViT-B/16 (Vision Transformer)
vit = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=7)
vit = vit.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/smahadi/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 55.4MB/s]
Downloading: "https://download.pytorch.org/models/vgg11-8a719046.pth" to /home/smahadi/.cache/torch/hub/checkpoints/vgg11-8a719046.pth
100%|██████████| 507M/507M [00:03<00:00, 143MB/s]  


In [9]:
# Unstructured Pruning for CNN
def unstructured_prune_cnn(model, amount=0.3):
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            prune.l1_unstructured(module, name="weight", amount=amount)
            prune.remove(module, "weight")
    return model

# Structured Pruning for CNN
def structured_prune_cnn(model, amount=0.3):
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            prune.ln_structured(module, name="weight", amount=amount, n=2, dim=0)
            prune.remove(module, "weight")
    return model


In [10]:
# Unstructured Pruning for ViT
def unstructured_prune_vit(model, amount=0.3):
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            prune.l1_unstructured(module, name="weight", amount=amount)
            prune.remove(module, "weight")
    return model

# Structured Attention Head Pruning for ViT
def prune_vit_attention_heads(model, heads_to_prune=2):
    for name, module in model.named_modules():
        if hasattr(module, 'qkv') and hasattr(module, 'num_heads'):
            heads_dim = module.qkv.weight.shape[0] // 3
            head_size = heads_dim // module.num_heads
            qkv_weights = module.qkv.weight.data.view(3, module.num_heads, head_size, -1)
            norms = qkv_weights.norm(dim=(2, 3))
            importance = norms.sum(dim=0)
            prune_indices = torch.topk(importance, heads_to_prune, largest=False).indices
            for i in prune_indices:
                qkv_weights[:, i, :, :] = 0
            module.qkv.weight.data = qkv_weights.view(-1, module.qkv.weight.shape[1])
    return model

In [11]:
# Static INT8 Quantization (Post-training)
def quantize_cnn_static(model, loader):
    model.eval()
    model = model.to('cpu')
    fused_model = torch.quantization.fuse_modules(
        model,
        [["conv1", "bn1", "relu"]] +
        [[f"layer{i}.{j}.conv1", f"layer{i}.{j}.bn1", f"layer{i}.{j}.relu"]
         for i in range(1, 5) for j in range(len(getattr(model, f"layer{i}")))],
        inplace=False
    )
    fused_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    prepared_model = torch.quantization.prepare(fused_model)
    with torch.no_grad():
        for x, _ in loader:
            prepared_model(x)
            break
    quantized_model = torch.quantization.convert(prepared_model)
    return quantized_model

In [12]:
import time

def train(model, train_loader, val_loader, epochs=10, lr=1e-4):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0.0
        correct, total = 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                outputs = model(x)
                loss = criterion(outputs, y)
                val_loss += loss.item()
                correct += (outputs.argmax(1) == y).sum().item()
                total += y.size(0)

        train_loss = running_loss / len(train_loader)
        val_loss /= len(val_loader)
        val_acc = 100 * correct / total

        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}, Val Acc = {val_acc:.2f}%")

In [None]:
def evaluate(model, test_loader):
    model.eval()
    model.to("cpu")
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            correct += (outputs.argmax(1) == y).sum().item()
            total += y.size(0)
    accuracy = 100 * correct / total
    print(f"Test Accuracy = {accuracy:.2f}%")
    return accuracy

In [None]:
def measure_inference_speed(model, test_loader):
    model.eval()
    model.to("cpu")
    start = time.time()
    with torch.no_grad():
        for x, _ in test_loader:
            x = x.to(device)
            _ = model(x)
    end = time.time()
    latency = (end - start) / len(test_loader)
    print(f"Avg Inference Time per Batch: {latency:.4f} sec")
    return latency

def model_size_mb(model):
    total_params = sum(p.numel() for p in model.parameters())
    size_mb = total_params * 4 / (1024**2)
    print(f"Model Size: {size_mb:.2f} MB")
    return size_mb

In [None]:
from sklearn.cluster import KMeans
def deep_compression_quantize(model, num_clusters=32):
    codebooks = {}
    index_maps = {}

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            # Get weight tensor
            W = module.weight.detach().cpu().numpy()
            shape = W.shape

            # Flatten weights for clustering
            W_flat = W.flatten().reshape(-1, 1)

            # K-means clustering
            kmeans = KMeans(n_clusters=num_clusters, n_init=10, max_iter=300, random_state=42)
            kmeans.fit(W_flat)

            centroids = kmeans.cluster_centers_.flatten()  # shape: (num_clusters,)
            labels = kmeans.labels_  # shape: (num_weights,)

            # Replace each weight with its corresponding centroid
            W_compressed = centroids[labels].reshape(shape)
            module.weight.data = torch.tensor(W_compressed, dtype=module.weight.dtype, device=module.weight.device)

            # Save codebook + index map
            codebooks[name] = centroids
            index_maps[name] = labels.reshape(shape)

            print(f"✅ Quantized layer: {name} — original shape: {shape}, clusters: {num_clusters}")

    return model, codebooks, index_maps

In [16]:
# === Train baseline ===
print("=== Train baseline ===")
train(resnet18, train_loader, val_loader, epochs=20)
evaluate(resnet18, test_loader)
model_size_mb(resnet18)
measure_inference_speed(resnet18, test_loader)

# === Apply structured pruning, then fine-tune ===
print("=== Apply structured pruning, then fine-tune")
resnet18_pruned_st = structured_prune_cnn(resnet18, amount=0.3)
train(resnet18_pruned_st, train_loader, val_loader, epochs=5, lr=1e-5)
evaluate(resnet18_pruned_st, test_loader)
model_size_mb(resnet18_pruned_st)
measure_inference_speed(resnet18_pruned_st, test_loader)

# === Apply unstructured pruning, then fine-tune ===
print("=== Apply unstructured pruning, then fine-tune")
resnet18_pruned_unst = unstructured_prune_cnn(resnet18, amount=0.3)
train(resnet18_pruned_unst, train_loader, val_loader, epochs=5, lr=1e-5)
evaluate(resnet18_pruned_unst, test_loader)
model_size_mb(resnet18_pruned_unst)
measure_inference_speed(resnet18_pruned_unst, test_loader)

# === Apply quantization on baseline (no pruning) ===
print("=== Apply quantization on baseline (no pruning) ===")
resnet18_quant = quantize_cnn_static(resnet18, train_loader)
evaluate(resnet18_quant, test_loader)
model_size_mb(resnet18_quant)
measure_inference_speed(resnet18_quant, test_loader)

# === Apply quantization on pruned model ===
print("=== Apply quantization on pruned model ===")
resnet18_quant_pr = quantize_cnn_static(resnet18_pruned_st, train_loader)
evaluate(resnet18_quant_pr, test_loader)
model_size_mb(resnet18_quant_pr)
measure_inference_speed(resnet18_quant_pr, test_loader)

=== Train baseline ===
Epoch 1: Train Loss = 0.9218, Val Loss = 0.9770, Val Acc = 64.00%
Epoch 2: Train Loss = 0.5687, Val Loss = 1.1052, Val Acc = 62.95%
Epoch 3: Train Loss = 0.2404, Val Loss = 1.2244, Val Acc = 63.02%
Epoch 4: Train Loss = 0.0980, Val Loss = 1.3237, Val Acc = 65.60%
Epoch 5: Train Loss = 0.0492, Val Loss = 1.3581, Val Acc = 64.85%
Epoch 6: Train Loss = 0.0393, Val Loss = 1.5328, Val Acc = 64.97%
Epoch 7: Train Loss = 0.0767, Val Loss = 1.6254, Val Acc = 63.96%
Epoch 8: Train Loss = 0.1276, Val Loss = 1.6225, Val Acc = 62.71%
Epoch 9: Train Loss = 0.0571, Val Loss = 1.6189, Val Acc = 63.75%
Epoch 10: Train Loss = 0.0349, Val Loss = 1.6120, Val Acc = 65.32%
Epoch 11: Train Loss = 0.0267, Val Loss = 1.6874, Val Acc = 64.57%
Epoch 12: Train Loss = 0.0163, Val Loss = 1.7041, Val Acc = 65.96%
Epoch 13: Train Loss = 0.0791, Val Loss = 1.9050, Val Acc = 60.76%
Epoch 14: Train Loss = 0.1029, Val Loss = 1.6870, Val Acc = 64.87%
Epoch 15: Train Loss = 0.0364, Val Loss = 1.7789



NotImplementedError: Could not run 'quantized::conv2d_relu.new' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::conv2d_relu.new' is only available for these backends: [Meta, QuantizedCPU, QuantizedCUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMTIA, AutogradMeta, Tracer, AutocastCPU, AutocastXPU, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

Meta: registered at /pytorch/aten/src/ATen/core/MetaFallbackKernel.cpp:23 [backend fallback]
QuantizedCPU: registered at /pytorch/aten/src/ATen/native/quantized/cpu/qconv.cpp:2045 [kernel]
QuantizedCUDA: registered at /pytorch/aten/src/ATen/native/quantized/cudnn/Conv.cpp:391 [kernel]
BackendSelect: fallthrough registered at /pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:194 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at /pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:503 [backend fallback]
Functionalize: registered at /pytorch/aten/src/ATen/FunctionalizeFallbackKernel.cpp:349 [backend fallback]
Named: registered at /pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at /pytorch/aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at /pytorch/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
ZeroTensor: registered at /pytorch/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:100 [backend fallback]
AutogradOther: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:63 [backend fallback]
AutogradCPU: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:67 [backend fallback]
AutogradCUDA: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:75 [backend fallback]
AutogradXLA: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:83 [backend fallback]
AutogradMPS: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:91 [backend fallback]
AutogradXPU: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:71 [backend fallback]
AutogradHPU: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:104 [backend fallback]
AutogradLazy: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:87 [backend fallback]
AutogradMTIA: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:79 [backend fallback]
AutogradMeta: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:95 [backend fallback]
Tracer: registered at /pytorch/torch/csrc/autograd/TraceTypeManual.cpp:294 [backend fallback]
AutocastCPU: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:322 [backend fallback]
AutocastXPU: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:465 [backend fallback]
AutocastMPS: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:209 [backend fallback]
AutocastCUDA: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:165 [backend fallback]
FuncTorchBatched: registered at /pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:731 [backend fallback]
BatchedNestedTensor: registered at /pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:758 [backend fallback]
FuncTorchVmapMode: fallthrough registered at /pytorch/aten/src/ATen/functorch/VmapModeRegistrations.cpp:27 [backend fallback]
Batched: registered at /pytorch/aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at /pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at /pytorch/aten/src/ATen/functorch/TensorWrapper.cpp:207 [backend fallback]
PythonTLSSnapshot: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:202 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at /pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:499 [backend fallback]
PreDispatch: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:206 [backend fallback]
PythonDispatcher: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:198 [backend fallback]


In [None]:
# Train baseline ResNet18
print("\n--- Training VGG11 (Baseline) ---")
train(vgg11, train_loader, epochs=5)
evaluate(vgg11, test_loader)
model_size_mb(vgg11)
measure_latency(vgg11, test_loader)

# Apply pruning
print("\n--- Pruning VGG11 ---")
vgg11_pruned = apply_pruning(vgg11, amount=0.3)
evaluate( vgg11_pruned, test_loader)
model_size_mb(vgg11_pruned)
measure_latency(vgg11_pruned, test_loader)

# Apply quantization
print("\n--- Quantizing VGG11 ---")
vgg11_quantm, _, _ = deep_compression_quantize(vgg11)
evaluate(vgg11_quant, test_loader)
model_size_mb(vgg11_quant)
measure_latency(vgg11_quant, test_loader)

In [None]:
import gc

del resnet18
gc.collect()

690