In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models
from sklearn.model_selection import KFold
from sklearn.metrics import f1_score
import numpy as np
import json

# Pick the compute device: CUDA when a usable GPU is present, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Store the class mappings
def store_class_mappings(dataset, file_path):
    """Write the dataset's index -> class-name table to ``file_path`` as JSON.

    Args:
        dataset: Any object exposing a ``classes`` sequence; the position of
            each entry becomes its key.
        file_path: Destination file; it is opened in write mode, so an
            existing file is overwritten.

    Returns:
        None. A confirmation line is printed once the file is written.
    """
    index_to_name = dict(enumerate(dataset.classes))
    with open(file_path, 'w') as handle:
        # JSON object keys are always strings, so the integer indices are
        # stored as "0", "1", ... in the file.
        json.dump(index_to_name, handle)
    print(f"Class mappings saved to {file_path}")

# File that receives the index -> class-name table (JSON)
class_mappings_path = "class_mappings.json"

# Create the ResNet34 model and load pretrained weights, excluding the final fc layer
def create_resnet_model(num_classes, pretrained_path=None):
    """Build a ResNet-34 with a ``num_classes``-way head, optionally warm-started.

    The network is created without torchvision's bundled weights and its final
    fully connected layer is replaced by a new ``nn.Linear`` sized for
    ``num_classes``. When ``pretrained_path`` is given, the checkpoint at that
    path is loaded into every layer except the classifier head, which keeps
    its fresh initialisation.

    Args:
        num_classes: Number of output classes for the new classifier head.
        pretrained_path: Optional path to a checkpoint holding a state dict
            (as consumed by ``load_state_dict``). Falsy values skip loading.

    Returns:
        The model, moved to the module-level ``device``.
    """
    # NOTE(review): recent torchvision releases deprecate `pretrained=` in
    # favour of `weights=`; kept as-is for compatibility with the installed
    # version -- confirm before upgrading.
    model = models.resnet34(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    if pretrained_path:
        # map_location remaps the stored tensors onto the active device, so a
        # checkpoint saved on a GPU still loads on a CPU-only machine.
        # NOTE(review): torch.load unpickles the file -- only load checkpoints
        # that come from a trusted source.
        state_dict = torch.load(pretrained_path, map_location=device)

        # Drop the checkpoint's classifier weights: their shape belongs to the
        # original task. pop() with a default tolerates checkpoints that were
        # saved without a head instead of raising KeyError.
        for head_key in ('fc.weight', 'fc.bias'):
            state_dict.pop(head_key, None)

        # strict=False accepts the now-missing fc.* keys.
        model.load_state_dict(state_dict, strict=False)

    return model.to(device)

# Preprocessing pipelines.
# Both pipelines resize to the same input size and normalise with the same
# per-channel statistics; only the training pipeline adds a random flip.
_INPUT_SIZE = (224, 224)
_CHANNEL_MEAN = [0.43636918, 0.38563913, 0.34477144]
_CHANNEL_STD = [0.29639485, 0.2698132, 0.26158142]

train_transform = transforms.Compose([
    transforms.Resize(_INPUT_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
])

# Alternative, heavier augmentation pipeline (currently disabled):
# train_transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.RandomHorizontalFlip(),
#     transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
#     transforms.RandomRotation(degrees=15),
#     transforms.RandomCrop(size=(224, 224), padding=4),
#     transforms.RandomVerticalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.43636918, 0.38563913, 0.34477144],
#                          std=[0.29639485, 0.2698132, 0.26158142])
# ])

# Validation pipeline: deterministic (no random augmentation)
val_transform = transforms.Compose([
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
])

# Root folder handed to ImageFolder (which expects one sub-directory per class)
dataset_folder = '../data/lfw/'

# Build the dataset from the folder; it is constructed with train_transform,
# so samples read from this object go through the training pipeline
dataset = datasets.ImageFolder(root=dataset_folder, transform=train_transform)

# Persist the index -> class-name table as JSON right after loading
store_class_mappings(dataset, class_mappings_path)

# Number of classes; used to size the model's final layer
num_classes = len(dataset.classes)

# Checkpoint (.pth) used to warm-start the backbone of each fold's model
pretrained_model_path = './resnet_34_pretrained.pth'  # Adjust to where the checkpoint actually lives

# Shuffled 5-fold splitter; the fixed random_state makes the splits reproducible
num_folds = 5
kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)

# Function to train and evaluate the model on a single fold
def train_and_evaluate_fold(train_idx, val_idx, dataset, fold,
                            num_epochs=10, batch_size=32, learning_rate=0.001):
    """Train a fresh model on one fold's training split and evaluate the held-out split.

    A new model, loss and Adam optimizer are created on every call, so folds do
    not share weights. Per-epoch training loss/accuracy and the final
    validation metrics are printed.

    Args:
        train_idx: Indices into ``dataset`` used for training.
        val_idx: Indices into ``dataset`` held out for validation.
        dataset: The full dataset the indices refer to.
        fold: Zero-based fold number (printed one-based).
        num_epochs: Number of passes over the training split.
        batch_size: Batch size for both data loaders.
        learning_rate: Learning rate passed to Adam.

    Returns:
        The trained model for this fold.
    """
    import copy  # function-scope import: the file's import block stays untouched

    train_subset = Subset(dataset, train_idx)

    # Subset only forwards indices to the dataset it wraps, so assigning a
    # `.transform` attribute on the Subset (as the previous code did) is never
    # read and validation images kept going through the training augmentation.
    # Instead, evaluate on a shallow copy of the dataset that carries
    # val_transform; the copy shares the sample list, and the training view is
    # left unchanged.
    # NOTE(review): relies on the dataset applying `self.transform` when an
    # item is fetched, as torchvision's ImageFolder does -- confirm for other
    # dataset types.
    val_dataset = copy.copy(dataset)
    val_dataset.transform = val_transform
    val_subset = Subset(val_dataset, val_idx)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

    # Fresh model, loss function and optimizer for this fold
    model = create_resnet_model(num_classes, pretrained_path=pretrained_model_path)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    num_train_samples = len(train_subset)

    # Training loop: `num_epochs` passes over this fold's training split
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct_preds = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean; weight it by the batch size so
            # the epoch figure is a per-sample average even when the last
            # batch is smaller.
            running_loss += loss.item() * images.size(0)
            correct_preds += (outputs.argmax(1) == labels).sum().item()

        epoch_loss = running_loss / num_train_samples
        epoch_acc = correct_preds / num_train_samples

        print(f'Fold {fold+1}, Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

    # Evaluate on the held-out split of this fold
    print(f'Fold {fold+1} Validation Results:')
    evaluate_model(model, val_loader)
    return model

# Test loop for evaluation on validation set
def evaluate_model(model, test_loader):
    """Run ``model`` over ``test_loader`` and print loss, accuracy and weighted F1.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        test_loader: Loader yielding ``(images, labels)`` batches. Its
            ``dataset`` length is used as the sample count for averaging.

    Returns:
        None. The metrics are printed on a single line.
    """
    loss_fn = nn.CrossEntropyLoss()
    summed_loss = 0.0
    pred_chunks = []
    label_chunks = []

    model.eval()
    # No gradients are needed for inference
    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            # The criterion yields a batch mean; scale by the batch size so
            # the total can be averaged per sample afterwards.
            summed_loss += loss_fn(logits, batch_labels).item() * batch_images.size(0)

            pred_chunks.append(logits.argmax(1).cpu().numpy())
            label_chunks.append(batch_labels.cpu().numpy())

    sample_count = len(test_loader.dataset)
    test_loss = summed_loss / sample_count

    # Flatten the per-batch arrays once, then derive accuracy and F1 from them
    all_preds = np.concatenate(pred_chunks)
    all_labels = np.concatenate(label_chunks)
    test_acc = int((all_preds == all_labels).sum()) / sample_count
    f1 = f1_score(all_labels, all_preds, average='weighted')

    print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}, F1 Score: {f1:.4f}')


# Cross-validation driver: train and evaluate once per split from the KFold splitter
fold_splits = kf.split(dataset)
for fold, (train_idx, val_idx) in enumerate(fold_splits):
    print(f'Fold {fold+1}/{num_folds}')
    model = train_and_evaluate_fold(train_idx, val_idx, dataset, fold)

# `model` is rebound on every iteration, so the weights written below are
# those of the last fold's model only.
final_state = model.state_dict()
torch.save(final_state, 'final_resnet34_cv.pth')
print("Cross-validation completed and final model saved as 'final_resnet34_cv.pth'.")


Class mappings saved to class_mappings.json
Fold 1/5
Fold 1, Epoch 1/10 - Loss: 1.9437, Accuracy: 0.3808
Fold 1, Epoch 2/10 - Loss: 1.5665, Accuracy: 0.4812
Fold 1, Epoch 3/10 - Loss: 1.3129, Accuracy: 0.6059
Fold 1, Epoch 4/10 - Loss: 1.1168, Accuracy: 0.6494
Fold 1, Epoch 5/10 - Loss: 0.9416, Accuracy: 0.6837
Fold 1, Epoch 6/10 - Loss: 0.8521, Accuracy: 0.7180
Fold 1, Epoch 7/10 - Loss: 0.7962, Accuracy: 0.7331
Fold 1, Epoch 8/10 - Loss: 0.6868, Accuracy: 0.7824
Fold 1, Epoch 9/10 - Loss: 0.6474, Accuracy: 0.7983
Fold 1, Epoch 10/10 - Loss: 0.5436, Accuracy: 0.8142
Fold 1 Validation Results:
Test Loss: 0.6930, Test Accuracy: 0.7625, F1 Score: 0.7495
Fold 2/5
Fold 2, Epoch 1/10 - Loss: 2.0189, Accuracy: 0.3573
Fold 2, Epoch 2/10 - Loss: 1.6264, Accuracy: 0.4678
Fold 2, Epoch 3/10 - Loss: 1.3810, Accuracy: 0.5607
Fold 2, Epoch 4/10 - Loss: 1.1310, Accuracy: 0.6444
Fold 2, Epoch 5/10 - Loss: 1.0028, Accuracy: 0.6644
Fold 2, Epoch 6/10 - Loss: 0.8531, Accuracy: 0.7138
Fold 2, Epoch 7/10 