## Imports

In [1]:
import torch
from torch import nn, optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score
import numpy as np
import random

## Set seed for reproducibility

In [2]:
def set_seed(seed, deterministic=False):
    """
    Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch's
    CPU generator, plus every CUDA device when CUDA is available.

    Args:
        seed (int): Seed value to set for random number generation.
        deterministic (bool): If True, additionally force cuDNN to use
            deterministic algorithms and disable its benchmark autotuner.
            This trades GPU speed for run-to-run reproducibility. Defaults
            to False, which leaves the cuDNN settings untouched.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if deterministic:
        # Seeding alone does not make cuDNN convolutions reproducible: the
        # autotuner may select different, nondeterministic kernels per run.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# Set seed to be 42
set_seed(42)

## ResNet-18

## 10-fold Stratified Cross-Validation

In [3]:
# Training hyperparameters: optimiser learning rate and mini-batch size.
best_hyperparams = dict(lr=0.01, batch_size=16)

In [4]:
# Set device: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create the ResNet18 model
def create_resnet_model(num_classes):
    """
    Build a randomly initialised ResNet-18 classifier on ``device``.

    The network's default ``fc`` head is replaced by a fresh linear layer
    with ``num_classes`` outputs.

    Args:
        num_classes (int): Number of output classes.

    Returns:
        torch.nn.Module: The model, moved to the module-level ``device``.
    """
    # weights=None trains from scratch; it replaces ``pretrained=False``,
    # which is deprecated since torchvision 0.13 and warns on every call.
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model.to(device)

# Preprocessing applied to every image (training and validation alike):
# resize to 224x224, convert to a tensor, then standardise each RGB channel.
# NOTE(review): these mean/std values are not the usual ImageNet statistics;
# presumably they were computed over the dataset below — confirm, and make
# sure they were not computed on data that includes the validation folds.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.43636918, 0.38563913, 0.34477144],
            std=[0.29639485, 0.2698132, 0.26158142],
        ),
    ]
)

# Set the dataset folder (ImageFolder derives class labels from its sub-directories)
dataset_folder = '../data/lfw/'

# Load the full dataset once; the cross-validation loop takes fold subsets by index.
dataset = datasets.ImageFolder(root=dataset_folder, transform=transform)

# Integer label of every sample, in dataset order (used for stratification).
targets = np.array(dataset.targets)
# ImageFolder assigns label indices 0..len(classes)-1, so the classifier head
# needs exactly len(dataset.classes) outputs. Counting only the labels that
# occur could under-size the head if any class had no samples.
num_classes = len(dataset.classes)
assert targets.min() >= 0 and targets.max() < num_classes, 'label index outside classifier range'

# Define 10-fold stratified cross-validation (seeded shuffle -> reproducible folds)
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)

# Early stopping criteria: stop a fold after this many consecutive epochs
# without a decrease in validation loss.
early_stopping_patience = 5

# Upper bound on training epochs per fold; early stopping usually ends sooner.
max_epochs = 100

# List to store the results for each fold
fold_results = []


def _train_one_epoch(model, loader, criterion, optimizer):
    """
    Run one optimisation pass of ``model`` over ``loader``.

    Returns:
        tuple[float, float]: Mean per-sample training loss and training
        accuracy over ``loader.dataset``.
    """
    model.train()
    running_loss = 0.0
    correct_preds = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight by batch size so the epoch
        # figure is a per-sample average even when the last batch is smaller.
        running_loss += loss.item() * images.size(0)
        correct_preds += (outputs.argmax(1) == labels).sum().item()

    n_samples = len(loader.dataset)
    return running_loss / n_samples, correct_preds / n_samples


def _evaluate(model, loader, criterion):
    """
    Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns:
        tuple: ``(loss, accuracy, macro_f1, labels, preds)``. ``loss`` and
        ``accuracy`` are per-sample averages over ``loader.dataset``, and
        ``labels`` / ``preds`` are the collected true and predicted classes.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    collected_labels = []
    collected_preds = []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            preds = outputs.argmax(1)
            total_loss += criterion(outputs, labels).item() * images.size(0)
            correct += (preds == labels).sum().item()

            # Store predictions and labels for F1 score calculation
            collected_labels.extend(labels.cpu().numpy())
            collected_preds.extend(preds.cpu().numpy())

    n_samples = len(loader.dataset)
    # Macro F1 over the classes present. zero_division=0 yields the same value
    # as the default ('warn') without an UndefinedMetricWarning per epoch.
    macro_f1 = f1_score(collected_labels, collected_preds, average='macro', zero_division=0)
    return total_loss / n_samples, correct / n_samples, macro_f1, collected_labels, collected_preds


# 10-fold cross-validation loop.
# NOTE(review): the same validation fold drives early stopping/model selection
# AND the reported metrics, so the per-fold numbers are likely optimistic.
for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(targets)), targets)):
    print(f'Fold {fold+1}')

    # Create training and validation subsets
    train_subset = Subset(dataset, train_idx)
    val_subset = Subset(dataset, val_idx)

    # Create data loaders (only the training loader is shuffled)
    train_loader = DataLoader(train_subset, batch_size=best_hyperparams['batch_size'], shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=best_hyperparams['batch_size'], shuffle=False)

    # Fresh model, loss and optimizer per fold so no weights carry over between folds
    model = create_resnet_model(num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=best_hyperparams['lr'])

    # Initialize early stopping parameters
    best_val_loss = float('inf')
    early_stopping_counter = 0
    best_val_acc = 0
    best_f1 = 0

    for epoch in range(max_epochs):
        epoch_loss, epoch_acc = _train_one_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_acc, f1, all_labels, all_preds = _evaluate(model, val_loader, criterion)

        print(f'Epoch {epoch+1} - Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')
        print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}, F1 Score: {f1:.4f}')

        # Early stopping check: metrics are recorded at the lowest-validation-loss epoch
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_acc = val_acc
            best_f1 = f1
            early_stopping_counter = 0
            torch.save(model.state_dict(), f'best_model_fold_{fold+1}.pth')  # Save the best model for this fold
        else:
            early_stopping_counter += 1

        if early_stopping_counter >= early_stopping_patience:
            print(f'Early stopping triggered after {early_stopping_counter} epochs without improvement.')
            break

    # Store the best validation loss, accuracy, and F1 score for this fold
    fold_results.append({'fold': fold+1, 'val_loss': best_val_loss, 'val_acc': best_val_acc, 'f1_score': best_f1})

# Summarise cross-validation: the per-fold best metrics and their unweighted
# mean across folds.
print("\nCross-validation results:")
avg_val_loss, avg_val_acc, avg_f1_score = (
    np.mean([fold_entry[metric] for fold_entry in fold_results])
    for metric in ('val_loss', 'val_acc', 'f1_score')
)

for result in fold_results:
    print(f"Fold {result['fold']}: Validation Loss = {result['val_loss']:.4f}, Validation Accuracy = {result['val_acc']:.4f}, F1 Score = {result['f1_score']:.4f}")

print("\n".join((
    f"\nAverage Validation Loss: {avg_val_loss:.4f}",
    f"Average Validation Accuracy: {avg_val_acc:.4f}",
    f"Average F1 Score: {avg_f1_score:.4f}",
)))

Fold 1
Epoch 1 - Loss: 2.4300, Accuracy: 0.3244
Validation Loss: 3.0236, Validation Accuracy: 0.1712, F1 Score: 0.0594
Epoch 2 - Loss: 2.0443, Accuracy: 0.3534
Validation Loss: 2.0489, Validation Accuracy: 0.3288, F1 Score: 0.0588
Epoch 3 - Loss: 1.9891, Accuracy: 0.3641
Validation Loss: 2.0182, Validation Accuracy: 0.3151, F1 Score: 0.0508
Epoch 4 - Loss: 1.9412, Accuracy: 0.3611
Validation Loss: 1.9834, Validation Accuracy: 0.3082, F1 Score: 0.0854
Epoch 5 - Loss: 1.9273, Accuracy: 0.3603
Validation Loss: 1.9085, Validation Accuracy: 0.3356, F1 Score: 0.0628
Epoch 6 - Loss: 1.8772, Accuracy: 0.3687
Validation Loss: 1.9762, Validation Accuracy: 0.3219, F1 Score: 0.0796
Epoch 7 - Loss: 1.8545, Accuracy: 0.3809
Validation Loss: 2.0505, Validation Accuracy: 0.3151, F1 Score: 0.0578
Epoch 8 - Loss: 1.8148, Accuracy: 0.3916
Validation Loss: 2.0388, Validation Accuracy: 0.3699, F1 Score: 0.1466
Epoch 9 - Loss: 1.7672, Accuracy: 0.4130
Validation Loss: 1.8813, Validation Accuracy: 0.3699, F1