In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import time
import timm
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [15]:
import os

FER_CLASSES = ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral']
base_dir = 'fer2013/versions/1/train'

# Auto-create missing class folders if not present
for cls in FER_CLASSES:
    cls_path = os.path.join(base_dir, cls)
    if not os.path.exists(cls_path):
        os.makedirs(cls_path)
        print(f"⚠️ Created empty folder: {cls_path}")

# Remove unsupported or hidden files
for cls in FER_CLASSES:
    cls_path = os.path.join(base_dir, cls)
    if not os.path.exists(cls_path): continue
    for f in os.listdir(cls_path):
        full_path = os.path.join(cls_path, f)
        if not f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp')):
            print(f"⛔ Skipping invalid file: {full_path}")
            os.remove(full_path)

In [18]:
for cls in FER_CLASSES:
    cls_path = os.path.join('fer2013/versions/1/test', cls)
    files = [f for f in os.listdir(cls_path) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'))]
    print(f"{cls}: {len(files)} images")

angry: 958 images
disgust: 111 images
fear: 1024 images
happy: 1774 images
sad: 1247 images
surprise: 831 images
neutral: 1233 images


In [14]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("msambare/fer2013")

print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/msambare/fer2013?dataset_version_number=1...


100%|██████████| 60.3M/60.3M [00:00<00:00, 86.8MB/s]

Extracting files...





Path to dataset files: /home/smahadi/.cache/kagglehub/datasets/msambare/fer2013/versions/1


In [19]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),  # Required for ViTs/CNNs expecting 3 channels
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

train_dataset = ImageFolder(root='fer2013/versions/1/train', transform=transform)
test_dataset = ImageFolder(root='fer2013/versions/1/test', transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64)

In [30]:
# CNN Models
vgg11 = timm.create_model("vgg11", pretrained=True, num_classes=7).to(device)
resnet18 = timm.create_model("resnet18", pretrained=True, num_classes=7).to(device)

# Vision Transformer
vit = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=7).to(device)

In [21]:
def train(model, loader, epochs=5):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    model.train()

    for epoch in range(epochs):
        correct = total = 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()

            correct += (out.argmax(1) == y).sum().item()
            total += y.size(0)
        acc = 100 * correct / total
        print(f"Epoch {epoch+1}: Accuracy = {acc:.2f}%")

In [22]:
def evaluate(model, loader):
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            correct += (out.argmax(1) == y).sum().item()
            total += y.size(0)
    acc = 100 * correct / total
    print(f"Test Accuracy = {acc:.2f}%")
    return acc

In [23]:
import torch.nn.utils.prune as prune

def apply_pruning(model, amount=0.3):
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=amount)
    return model

In [24]:
def quantize_model(model):
    model.eval()
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    model_fp32_prepared = torch.quantization.prepare(model, inplace=False)

    # Run a few batches to calibrate
    with torch.no_grad():
        for x, _ in train_loader:
            x = x.to(device)
            model_fp32_prepared(x)
            break

    quantized_model = torch.quantization.convert(model_fp32_prepared, inplace=False)
    return quantized_model

In [25]:
def measure_latency(model, loader):
    model.eval()
    start = time.time()
    with torch.no_grad():
        for x, _ in loader:
            x = x.to(device)
            _ = model(x)
    end = time.time()
    latency = (end - start) / len(loader)
    print(f"Avg. Inference Latency: {latency:.4f} seconds per batch")
    return latency

In [26]:
def model_size_mb(model):
    param_size = sum(p.numel() for p in model.parameters()) * 4 / (1024 ** 2)
    print(f"Model Size: {param_size:.2f} MB")
    return param_size

In [31]:
# Train baseline ResNet18
print("\n--- Training ResNet18 (Baseline) ---")
train(resnet18, train_loader, epochs=5)
evaluate(resnet18, test_loader)
model_size_mb(resnet18)
measure_latency(resnet18, test_loader)

# Apply pruning
print("\n--- Pruning ResNet18 ---")
resnet18_pruned = apply_pruning(resnet18, amount=0.3)
train(resnet18_pruned, train_loader, epochs=5)
evaluate(resnet18_pruned, test_loader)
model_size_mb(resnet18_pruned)
measure_latency(resnet18_pruned, test_loader)

# Apply quantization
print("\n--- Quantizing ResNet18 ---")
resnet18_quant = quantize_model(resnet18)
evaluate(resnet18_quant, test_loader)
model_size_mb(resnet18_quant)
measure_latency(resnet18_quant, test_loader)


--- Training ResNet18 (Baseline) ---
Epoch 1: Accuracy = 36.00%
Epoch 2: Accuracy = 53.44%
Epoch 3: Accuracy = 60.33%
Epoch 4: Accuracy = 65.48%
Epoch 5: Accuracy = 70.59%
Test Accuracy = 61.27%
Model Size: 42.65 MB
Avg. Inference Latency: 0.1168 seconds per batch

--- Pruning ResNet18 ---
Epoch 1: Accuracy = 74.39%
Epoch 2: Accuracy = 78.57%
Epoch 3: Accuracy = 82.69%
Epoch 4: Accuracy = 86.69%
Epoch 5: Accuracy = 90.39%
Test Accuracy = 62.87%
Model Size: 42.65 MB
Avg. Inference Latency: 0.1170 seconds per batch

--- Quantizing ResNet18 ---




RuntimeError: Unsupported qscheme: per_channel_affine

In [None]:
# Train baseline ResNet18
print("\n--- Training VGG11 (Baseline) ---")
train(vgg11, train_loader, epochs=5)
evaluate(vgg11, test_loader)
model_size_mb(vgg11)
measure_latency(vgg11, test_loader)

# Apply pruning
print("\n--- Pruning VGG11 ---")
vgg11_pruned = apply_pruning(vgg11, amount=0.3)
evaluate( vgg11_pruned, test_loader)
model_size_mb(vgg11_pruned)
measure_latency(vgg11_pruned, test_loader)

# Apply quantization
print("\n--- Quantizing VGG11 ---")
vgg11_quant = quantize_model(vgg11)
evaluate(vgg11_quant, test_loader)
model_size_mb(vgg11_quant)
measure_latency(vgg11_quant, test_loader)

In [29]:
import gc

del resnet18
del resnet18_pruned
gc.collect()

9179