In [12]:
#Setup + Data Loading

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import io


# 1.2 Set seed and device
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 1.3 Dataset path
lfw_path = "../Datasets/lfw-dataset"

# 1.4 Transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # For ResNet input
    transforms.ToTensor(),
])

# 1.5 Load dataset
lfw_dataset = datasets.ImageFolder(root=lfw_path, transform=transform)
lfw_loader = DataLoader(lfw_dataset, batch_size=16, shuffle=True)


Using device: cpu


In [13]:
#Define + Train ResNet (small batch for demo)

model = models.resnet18(weights="IMAGENET1K_V1")
model.fc = nn.Linear(model.fc.in_features, len(lfw_dataset.classes))
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

model.train()
for batch_idx, (images, labels) in enumerate(lfw_loader):
    if batch_idx >= 30: break
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    print(f"Batch {batch_idx+1}/30 - Loss: {loss.item():.4f}")

Batch 1/30 - Loss: 9.0540
Batch 2/30 - Loss: 8.7226
Batch 3/30 - Loss: 8.8491
Batch 4/30 - Loss: 9.0046
Batch 5/30 - Loss: 8.3589
Batch 6/30 - Loss: 9.2388
Batch 7/30 - Loss: 9.0561
Batch 8/30 - Loss: 8.8714
Batch 9/30 - Loss: 8.5220
Batch 10/30 - Loss: 8.9919
Batch 11/30 - Loss: 9.5368
Batch 12/30 - Loss: 10.0698
Batch 13/30 - Loss: 9.0608
Batch 14/30 - Loss: 8.7379
Batch 15/30 - Loss: 9.3890
Batch 16/30 - Loss: 8.7246
Batch 17/30 - Loss: 9.0874
Batch 18/30 - Loss: 8.9169
Batch 19/30 - Loss: 9.2737
Batch 20/30 - Loss: 9.5900
Batch 21/30 - Loss: 9.2163
Batch 22/30 - Loss: 8.7496
Batch 23/30 - Loss: 8.4844
Batch 24/30 - Loss: 9.3101
Batch 25/30 - Loss: 9.3439
Batch 26/30 - Loss: 9.2269
Batch 27/30 - Loss: 10.1566
Batch 28/30 - Loss: 9.1716
Batch 29/30 - Loss: 8.8061
Batch 30/30 - Loss: 9.5924


In [14]:
#FGSM Attack (loop until model gets one right)

def fgsm_attack(image, epsilon, data_grad):
    sign_data_grad = data_grad.sign()
    perturbed = image + epsilon * sign_data_grad
    return torch.clamp(perturbed, 0, 1)

model.eval()
image_found = False

for _ in range(30):
    image, label = next(iter(lfw_loader))
    image = image[0].unsqueeze(0).to(device)
    label = torch.tensor([label[0]]).to(device)
    image.requires_grad = True

    output = model(image)
    init_pred = output.argmax(dim=1)

    if init_pred.item() == label.item():
        print("✅ Found correct prediction:", lfw_dataset.classes[init_pred.item()])
        loss = criterion(output, label)
        model.zero_grad()
        loss.backward()
        data_grad = image.grad.data
        epsilon = 0.1
        perturbed_image = fgsm_attack(image, epsilon, data_grad)
        clean_pred = init_pred.item()
        break

if not image_found:
    print("❌ Could not find a correctly classified image.")
    

✅ Found correct prediction: George_W_Bush
❌ Could not find a correctly classified image.


In [15]:
#JPEG Defence
def jpeg_defence(image_tensor, quality=30):
    pil_img = transforms.ToPILImage()(image_tensor.squeeze().cpu())
    buffer = io.BytesIO()
    pil_img.save(buffer, format='JPEG', quality=quality)
    buffer.seek(0)
    compressed_img = Image.open(buffer)
    return transforms.ToTensor()(compressed_img).unsqueeze(0).to(device)

jpeg_image = jpeg_defence(perturbed_image)

model.eval()
with torch.no_grad():
    adv_out = model(perturbed_image)
    jpeg_out = model(jpeg_image)

adv_pred = adv_out.argmax(dim=1).item()
jpeg_pred = jpeg_out.argmax(dim=1).item()

print(f"🎯 Clean Prediction: {lfw_dataset.classes[clean_pred]}")
print(f"⚠️  Adversarial Prediction: {lfw_dataset.classes[adv_pred]}")
print(f"🔧 JPEG Defence Prediction: {lfw_dataset.classes[jpeg_pred]}")


🎯 Clean Prediction: George_W_Bush
⚠️  Adversarial Prediction: George_W_Bush
🔧 JPEG Defence Prediction: George_W_Bush
