In [6]:
# 1. Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from tqdm import tqdm

# 2. Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): lfw_dataset / lfw_loader are not defined in this cell; presumably
# they come from an earlier cell (a dataset exposing `.classes` and a DataLoader
# yielding (images, labels) batches) — TODO confirm.
num_classes = len(lfw_dataset.classes)

# 3. Load pre-trained ResNet18 and adjust
# Swap the ImageNet classification head for one output per dataset class.
model = models.resnet18(weights='IMAGENET1K_V1')
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

# 4. Loss & Optimiser
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 5. Training loop (short demo: at most NUM_TRAIN_BATCHES batches)
# NOTE(review): this is a tiny training budget; the logged losses stay around
# 8-9, so most predictions afterwards are expected to be wrong.
NUM_TRAIN_BATCHES = 20
model.train()
# zip() pulls from range() first, so it stops after NUM_TRAIN_BATCHES batches
# without fetching an extra one, and simply ends early (instead of raising
# StopIteration) if the loader has fewer batches than that.
for batch_idx, (images, labels) in zip(range(NUM_TRAIN_BATCHES), lfw_loader):
    images, labels = images.to(device), labels.to(device)

    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    print(f"Batch {batch_idx+1}/{NUM_TRAIN_BATCHES} - Loss: {loss.item():.4f}")

# 6. FGSM Attack
def fgsm_attack(image, epsilon, data_grad, *, clip_min=0.0, clip_max=1.0):
    """Return a Fast Gradient Sign Method (FGSM) adversarial example.

    Computes ``clamp(image + epsilon * sign(data_grad), clip_min, clip_max)``.

    Args:
        image: Input tensor to perturb.
        epsilon: Step size applied to every element along the gradient sign.
        data_grad: Gradient of the loss w.r.t. ``image`` (same shape).
        clip_min: Lower bound of the valid input range (default 0.0).
        clip_max: Upper bound of the valid input range (default 1.0).
            The [0, 1] defaults assume un-normalized pixel values; pass the
            normalized range if the inputs were mean/std normalized.

    Returns:
        The perturbed tensor, detached from the autograd graph of ``image``.
    """
    perturbed = image + epsilon * data_grad.sign()
    # Detach: the adversarial image is a fresh input, not a continuation of
    # the graph that produced the gradient.
    return torch.clamp(perturbed, clip_min, clip_max).detach()

# Get one image (first sample of a fresh pass over the loader)
images, labels = next(iter(lfw_loader))
# [:1] keeps the leading batch dimension for both the image and its label.
image = images[:1].to(device)
label = labels[:1].to(device)
# Track gradients w.r.t. the input pixels; FGSM steps along their sign.
image.requires_grad_(True)

# Use eval mode for BOTH the clean and the adversarial prediction. Left in
# train mode (from the training loop), BatchNorm would normalize with this
# single image's statistics and update its running averages, so the two
# predictions would come from different model behaviours. Gradients still
# flow in eval mode.
model.eval()

# Predict and check correct
output = model(image)
init_pred = output.argmax(dim=1)
if init_pred.item() != label.item():
    # With the short training run above this branch is the common outcome.
    print("Initial prediction was incorrect; skipping FGSM attack.")
else:
    loss = criterion(output, label)
    model.zero_grad()
    loss.backward()
    data_grad = image.grad.detach()
    epsilon = 0.1
    # NOTE(review): fgsm_attack clamps to [0, 1]; if lfw_loader applies
    # mean/std normalization that range is wrong — confirm the transforms.
    perturbed_image = fgsm_attack(image, epsilon, data_grad)

    with torch.no_grad():
        adv_output = model(perturbed_image)

    adv_pred = adv_output.argmax(dim=1)
    print(f"Original prediction: {lfw_dataset.classes[init_pred.item()]}")
    print(f"Adversarial prediction: {lfw_dataset.classes[adv_pred.item()]}")


Batch 1/20 - Loss: 9.0471
Batch 2/20 - Loss: 8.6899
Batch 3/20 - Loss: 8.3191
Batch 4/20 - Loss: 8.7172
Batch 5/20 - Loss: 8.6107
Batch 6/20 - Loss: 8.9704
Batch 7/20 - Loss: 9.0471
Batch 8/20 - Loss: 8.3831
Batch 9/20 - Loss: 9.2827
Batch 10/20 - Loss: 8.1877
Batch 11/20 - Loss: 9.7383
Batch 12/20 - Loss: 8.5505
Batch 13/20 - Loss: 9.4257
Batch 14/20 - Loss: 9.5584
Batch 15/20 - Loss: 9.5067
Batch 16/20 - Loss: 9.5039
Batch 17/20 - Loss: 8.8694
Batch 18/20 - Loss: 8.6729
Batch 19/20 - Loss: 7.3447
Batch 20/20 - Loss: 8.9476
Initial prediction was incorrect; skipping FGSM attack.
