In [2]:
# 2.1 Imports
from itertools import islice
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models
from torchvision import transforms, datasets

# 2.2 Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

lfw_path = "../Datasets/lfw-dataset"

# Per-channel statistics that torchvision's ImageNet-pretrained models expect
# their inputs to be normalised with. The ResNet18 in section 2.4 is loaded
# with IMAGENET1K_V1 weights, so skipping this step feeds the pretrained
# backbone inputs on a different scale from the one it was trained on.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # PIL image -> float tensor with values in [0, 1]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# 2.3 Load dataset
# ImageFolder assigns one class label per sub-directory of lfw_path
# (presumably one directory per person — confirm against the dataset layout).
lfw_dataset = datasets.ImageFolder(lfw_path, transform=transform)
lfw_loader = DataLoader(
    dataset=lfw_dataset,
    batch_size=16,
    shuffle=True,  # data is reshuffled each time the loader is iterated
)

num_classes = len(lfw_dataset.classes)
print(f"Number of classes: {num_classes}")

# 2.4 Load pretrained ResNet18 and replace its classifier head
# The backbone keeps its ImageNet weights. The final fully-connected layer is
# swapped for a freshly initialised one with one output per dataset class.
resnet = models.resnet18(weights="IMAGENET1K_V1")
resnet.fc = nn.Linear(
    in_features=resnet.fc.in_features,
    out_features=num_classes,
)
resnet = resnet.to(device)

# 2.5 Loss and optimiser
# All parameters (backbone and new head) share a single learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=resnet.parameters(), lr=1e-3)

# 2.6 Train for 1 epoch (limited batches for demo)
# Maximum number of optimiser steps for this demo run. Set to None to run
# the full epoch.
MAX_BATCHES = 20

# The progress denominator reflects how many batches will actually run.
# That is fewer than MAX_BATCHES when the loader itself is shorter.
num_batches = len(lfw_loader)
if MAX_BATCHES is not None:
    num_batches = min(MAX_BATCHES, num_batches)

resnet.train()
# islice stops before requesting another batch. A `break` at the top of the
# loop would fetch (and decode) one extra batch only to discard it.
for batch_idx, (images, labels) in enumerate(islice(lfw_loader, num_batches)):
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()
    outputs = resnet(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    print(f"Batch {batch_idx+1}/{num_batches} - Loss: {loss.item():.4f}")

# 2.7 Save model
save_path = Path("../Models/resnet18_lfw.pth")
# torch.save does not create missing directories, so ensure the target exists.
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(resnet.state_dict(), save_path)


Using device: cpu
Number of classes: 5749
Batch 1/20 - Loss: 8.8171
Batch 2/20 - Loss: 8.9891
Batch 3/20 - Loss: 9.1070
Batch 4/20 - Loss: 8.8183
Batch 5/20 - Loss: 8.7465
Batch 6/20 - Loss: 8.9638
Batch 7/20 - Loss: 8.7879
Batch 8/20 - Loss: 8.6742
Batch 9/20 - Loss: 8.4216
Batch 10/20 - Loss: 9.1010
Batch 11/20 - Loss: 8.5683
Batch 12/20 - Loss: 9.9553
Batch 13/20 - Loss: 10.1958
Batch 14/20 - Loss: 9.7819
Batch 15/20 - Loss: 9.2203
Batch 16/20 - Loss: 9.9391
Batch 17/20 - Loss: 10.0719
Batch 18/20 - Loss: 8.5529
Batch 19/20 - Loss: 8.9700
Batch 20/20 - Loss: 9.9950
