Import necessary libraries

In [1]:
import os
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights

import torch.optim as optim
import torch.nn as nn
from sklearn.model_selection import KFold
from tqdm import tqdm
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
import seaborn as sns

Data loader

In [2]:
# Custom Dataset class for CIFAR-10 with lazy loading
class CIFAR10Dataset(Dataset):
    """CIFAR-10 dataset backed by the python-pickle batch files.

    Batch files are read from disk lazily, on first access. They are then kept in a
    per-instance cache, so each file is read at most once per dataset object. With
    ``num_workers > 0``, every DataLoader worker holds its own replica of the dataset
    and therefore its own cache.

    The cache matters because ``DataLoader(shuffle=True)`` visits samples in random order.
    With a single-slot cache, almost every sample lands in a different batch file than
    the previous one and forces a full re-read of that file.

    Each item is ``(image, label)``:
    - ``image`` is the sample as an ``H x W x C`` numpy array, or the result of
      ``transform`` applied to that array.
    - ``label`` is the matching entry of the batch's ``b'labels'`` list.

    Security note: batch files are loaded with ``pickle``, which can execute arbitrary
    code. Only point ``data_path`` at trusted files.
    """

    def __init__(self, data_path, batch_files, transform=None):
        """
        Args:
            data_path: directory containing the batch files.
            batch_files: file names (relative to ``data_path``) to include, in order.
            transform: optional callable applied to each ``H x W x C`` image.
        """
        self.data_path = data_path
        self.batch_files = batch_files
        self.transform = transform
        self.batch_data = None  # images of the most recently accessed batch, shape (N, 3, 32, 32)
        self.batch_labels = None  # labels of the most recently accessed batch
        self.batch_index = -1  # batch number held in batch_data / batch_labels (-1: none yet)
        self.index_map = []  # global index -> (batch number, index within that batch)
        self._batch_cache = {}  # batch number -> (images, labels); filled lazily
        self._create_index_map()

    def _read_batch_file(self, batch_num):
        """Unpickle and return the raw dict stored in batch file number ``batch_num``."""
        # pickle.load can execute arbitrary code: only use trusted batch files.
        with open(os.path.join(self.data_path, self.batch_files[batch_num]), 'rb') as f:
            return pickle.load(f, encoding='bytes')

    def _create_index_map(self):
        """Map every global index to a ``(batch number, in-batch index)`` pair."""
        for batch_num in range(len(self.batch_files)):
            # Only the sample count is needed here; images are loaded later, on demand.
            num_samples = len(self._read_batch_file(batch_num)[b'labels'])
            self.index_map.extend((batch_num, i) for i in range(num_samples))

    def _load_batch(self, batch_num):
        """Make ``batch_num`` the current batch, reading it from disk only on first use."""
        if batch_num not in self._batch_cache:
            batch = self._read_batch_file(batch_num)
            self._batch_cache[batch_num] = (batch[b'data'].reshape(-1, 3, 32, 32), batch[b'labels'])
        self.batch_data, self.batch_labels = self._batch_cache[batch_num]
        self.batch_index = batch_num  # Update currently loaded batch

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        # Map global index to batch number and in-batch index
        batch_num, in_batch_idx = self.index_map[idx]

        if batch_num != self.batch_index:
            self._load_batch(batch_num)

        # Stored channel-first (C x H x W); convert to H x W x C for the transforms.
        image = self.batch_data[in_batch_idx].transpose(1, 2, 0)
        label = self.batch_labels[in_batch_idx]

        if self.transform:
            image = self.transform(image)

        return image, label


# Transformations for CIFAR-10 images fed to an ImageNet-pretrained ResNet-50.
# The pretrained weights (ResNet50_Weights.DEFAULT) were trained on ImageNet images
# normalised with the ImageNet channel statistics below, so inputs are normalised the
# same way rather than with a generic (0.5, 0.5, 0.5) mean/std.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

transform = transforms.Compose([
    transforms.ToPILImage(),  # H x W x C numpy array -> PIL image
    # Upsample the 32x32 images to ResNet-50's 224x224 pretraining resolution. torchvision's
    # ResNet accepts other sizes too (adaptive average pooling), but features transfer best
    # at the resolution the weights were trained on.
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # PIL image -> float C x H x W tensor scaled to [0, 1]
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

Training with k-fold cross-validation

In [3]:
# Custom function to train and evaluate a model with early stopping
def _run_epoch(model, batches, criterion, device, optimizer=None):
    """Run one pass over ``batches`` and return the sample-weighted mean loss.

    When ``optimizer`` is given, the model is trained: train mode, backward pass and
    optimizer step. Otherwise it is evaluated in eval mode with gradients disabled.

    Raises:
        ValueError: if ``batches`` yields no samples (the mean loss is undefined).
    """
    training = optimizer is not None
    model.train(training)
    running_loss = 0.0
    num_samples = 0
    # Gradients are only needed (and only worth their memory) when training.
    with torch.set_grad_enabled(training):
        for inputs, labels in batches:
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            # Assumes criterion returns the batch mean (nn.CrossEntropyLoss default), so
            # re-weight by batch size to get an exact per-sample mean over the epoch.
            running_loss += loss.item() * inputs.size(0)
            num_samples += labels.size(0)
    if num_samples == 0:
        raise ValueError('DataLoader yielded no samples; cannot compute the epoch loss.')
    return running_loss / num_samples


def train_and_evaluate(trainloader, validloader, model, criterion, optimizer, num_epochs=10, patience=5,
                       device=None, restore_best_weights=True, progress=True):
    """Train ``model`` with per-epoch validation and early stopping.

    Args:
        trainloader: iterable of ``(inputs, labels)`` training batches.
        validloader: iterable of ``(inputs, labels)`` validation batches.
        model: the network to train; it is moved to ``device`` and updated in place.
        criterion: loss function ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: maximum number of epochs.
        patience: stop after this many consecutive epochs without a lower validation loss.
        device: device to train on; defaults to CUDA when available, else CPU.
        restore_best_weights: if True, reload the weights from the epoch with the lowest
            validation loss before returning, instead of keeping the last epoch's weights.
        progress: show tqdm progress bars.

    Returns:
        ``(train_losses, valid_losses)``: per-epoch mean losses, one entry per completed epoch.

    Raises:
        ValueError: if a loader yields no samples.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    train_losses, valid_losses = [], []
    best_valid_loss = float('inf')
    best_state = None
    epochs_no_improve = 0  # Counter for early stopping

    for epoch in range(num_epochs):
        epoch_label = f'Epoch {epoch+1}/{num_epochs}'

        train_batches = tqdm(trainloader, desc=f'{epoch_label} (Train)') if progress else trainloader
        epoch_train_loss = _run_epoch(model, train_batches, criterion, device, optimizer=optimizer)
        train_losses.append(epoch_train_loss)
        print(f'Training: Loss: {epoch_train_loss:.4f}')

        valid_batches = tqdm(validloader, desc=f'{epoch_label} (Valid)') if progress else validloader
        epoch_valid_loss = _run_epoch(model, valid_batches, criterion, device)
        valid_losses.append(epoch_valid_loss)
        print(f'Validation: Loss: {epoch_valid_loss:.4f}')

        # Check for early stopping
        if epoch_valid_loss < best_valid_loss:
            best_valid_loss = epoch_valid_loss
            epochs_no_improve = 0
            if restore_best_weights:
                # Snapshot on the CPU so the copy does not occupy accelerator memory.
                best_state = {name: tensor.detach().to('cpu', copy=True)
                              for name, tensor in model.state_dict().items()}
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

    if best_state is not None:
        # Without this the model would keep the (worse) weights of the last epoch run.
        model.load_state_dict(best_state)

    return train_losses, valid_losses


# K-Fold Cross Validation
def k_fold_cross_validation(dataset, model_class, num_folds=5, num_epochs=10, batch_size=32, patience=5,
                            random_state=None, lr=0.001, weight_decay=1e-2, num_workers=4):
    """Run k-fold cross-validation and return per-fold learning curves.

    Args:
        dataset: map-style Dataset supporting ``len()`` and integer indexing.
        model_class: zero-argument callable returning a fresh model. It is called once per
            fold so that no trained weights leak from one fold into the next.
        num_folds: number of folds (KFold ``n_splits``).
        num_epochs: maximum epochs per fold.
        batch_size: DataLoader batch size.
        patience: early-stopping patience passed to ``train_and_evaluate``.
        random_state: seed for the KFold shuffle. ``None`` (the default) draws from numpy's
            global RNG; pass an int for a reproducible split.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        num_workers: DataLoader worker processes.

    Returns:
        ``(all_train_losses, all_valid_losses)``: one list of per-epoch losses per fold.
    """
    kfold = KFold(n_splits=num_folds, shuffle=True, random_state=random_state)
    # Split an index array rather than the Dataset itself: KFold only needs the sample count.
    indices = np.arange(len(dataset))
    # Pinned host memory speeds up host-to-GPU copies; it is pointless without CUDA.
    pin_memory = torch.cuda.is_available()

    all_train_losses, all_valid_losses = [], []

    for fold, (train_idx, valid_idx) in enumerate(kfold.split(indices)):
        print(f'Fold {fold + 1}/{num_folds}')

        trainloader = DataLoader(Subset(dataset, train_idx), batch_size=batch_size, shuffle=True,
                                 num_workers=num_workers, pin_memory=pin_memory)
        validloader = DataLoader(Subset(dataset, valid_idx), batch_size=batch_size, shuffle=False,
                                 num_workers=num_workers, pin_memory=pin_memory)

        # Create a new instance of the model for each fold
        model = model_class()

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

        train_losses, valid_losses = train_and_evaluate(trainloader, validloader, model, criterion, optimizer,
                                                        num_epochs=num_epochs, patience=patience)

        all_train_losses.append(train_losses)
        all_valid_losses.append(valid_losses)

    return all_train_losses, all_valid_losses

Load data and run training

In [None]:
# Path to the CIFAR-10 "python version" batch files. Defaults to the Google Drive location
# used on Colab; set the CIFAR10_DIR environment variable to run from somewhere else.
data_path = os.environ.get(
    'CIFAR10_DIR',
    '/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/cifar-10-python/cifar-10-batches-py/',
)

# Training and test batch file names
train_batches = [f'data_batch_{i}' for i in range(1, 6)]
test_batches = ['test_batch']

# Seed the global RNGs so a fresh run reproduces the results:
# - KFold(shuffle=True) without a random_state draws from numpy's global RNG.
# - DataLoader shuffling and the new classification head's initialisation draw from torch's.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

NUM_CLASSES = 10  # CIFAR-10


def build_resnet50_cifar10():
    """Return an ImageNet-pretrained ResNet-50 with its 1000-way head replaced by a NUM_CLASSES-way head."""
    model = resnet50(weights=ResNet50_Weights.DEFAULT)
    model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
    return model


# Full training dataset. It uses the `transform` defined in the "Data loader" section,
# which is deliberately not redefined here.
train_dataset = CIFAR10Dataset(data_path, train_batches, transform=transform)

# 5-fold cross-validation. Each fold gets a freshly built model; the result is one list of
# per-epoch losses per fold.
train_losses, valid_losses = k_fold_cross_validation(
    train_dataset,
    build_resnet50_cifar10,
    num_folds=5,
    num_epochs=5,
    batch_size=128,
    patience=2,
)

# The learning curves are plotted in the "Plotting learning curves" section, where
# plot_learning_curves is defined. Calling it here would raise NameError on a fresh kernel.


Fold 1/5


Epoch 1/10 (Train):  50%|█████     | 157/313 [05:37<05:35,  2.15s/it]

Evaluation on the test set

In [None]:
def _collect_predictions(model, batches, device):
    """Return ``(true labels, predicted labels)`` as flat lists over all ``batches``."""
    all_labels, all_preds = [], []
    with torch.no_grad():
        for inputs, labels in batches:
            outputs = model(inputs.to(device))
            # The predicted class is the index of the largest score along dim 1.
            all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
            all_labels.extend(labels.tolist())
    return all_labels, all_preds


def test_model_with_confusion_matrix(model, testloader, device=None, num_classes=10):
    """Evaluate ``model`` on ``testloader``, print metrics and plot the confusion matrix.

    Args:
        model: classifier whose outputs hold per-class scores along dim 1 (predictions are
            the argmax over that dim).
        testloader: iterable of ``(inputs, labels)`` batches.
        device: device to run inference on. Defaults to the device of the model's first
            parameter (CPU for a parameterless model), so no global ``device`` is needed.
        num_classes: number of classes. Fixes the confusion matrix to
            ``num_classes x num_classes`` so it always matches the tick labels, even when a
            class never appears in the labels or predictions.

    Returns:
        dict with keys ``'accuracy'``, ``'precision'``, ``'recall'`` and ``'f1'``
        (precision/recall/F1 are support-weighted averages).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()  # Set model to evaluation mode
    all_labels, all_preds = _collect_predictions(model, tqdm(testloader, desc='Testing'), device)

    class_ids = list(range(num_classes))
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    print(f"Confusion Matrix:\n{cm}")

    # zero_division=0: a class that is never predicted scores 0 instead of emitting a warning.
    metrics = {
        'accuracy': accuracy_score(all_labels, all_preds),
        'precision': precision_score(all_labels, all_preds, average='weighted', zero_division=0),
        'recall': recall_score(all_labels, all_preds, average='weighted', zero_division=0),
        'f1': f1_score(all_labels, all_preds, average='weighted', zero_division=0),
    }

    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1 Score: {metrics['f1']:.4f}")

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_ids, yticklabels=class_ids, ax=ax)
    ax.set_ylabel('Actual')
    ax.set_xlabel('Predicted')
    ax.set_title('Confusion Matrix')
    plt.show()

    return metrics

# k_fold_cross_validation only returns loss curves (its per-fold models are discarded), so
# train a final model here on a 90/10 train/validation split of the training set. The test
# set is used only for the final evaluation below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

test_dataset = CIFAR10Dataset(data_path, test_batches, transform=transform)
testloader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)

final_indices = np.random.permutation(len(train_dataset))
num_final_valid = len(final_indices) // 10
final_trainloader = DataLoader(Subset(train_dataset, final_indices[num_final_valid:]),
                               batch_size=128, shuffle=True, num_workers=4)
final_validloader = DataLoader(Subset(train_dataset, final_indices[:num_final_valid]),
                               batch_size=128, shuffle=False, num_workers=4)

model = resnet50(weights=ResNet50_Weights.DEFAULT)
# Replace the 1000-class ImageNet head with a 10-class head for CIFAR-10.
model.fc = nn.Linear(model.fc.in_features, 10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-2)
train_and_evaluate(final_trainloader, final_validloader, model, criterion, optimizer, num_epochs=5, patience=2)
model = model.to(device)

# Test the model and display confusion matrix and metrics
test_metrics = test_model_with_confusion_matrix(model, testloader)

Plotting learning curves

In [None]:
def _as_fold_curves(losses):
    """Normalise ``losses`` to a list of per-fold curves.

    A flat sequence of numbers is treated as one curve. A sequence whose first element is
    itself a list/tuple/array is treated as one curve per fold.
    """
    if len(losses) > 0 and isinstance(losses[0], (list, tuple, np.ndarray)):
        return [list(curve) for curve in losses]
    return [list(losses)]


def plot_learning_curves(train_losses, valid_losses):
    """Plot training and validation loss against epoch number.

    Args:
        train_losses: per-epoch training losses. Either a flat list for a single run, or a
            list of such lists (one per fold, as returned by k_fold_cross_validation).
        valid_losses: validation losses in the same layout as ``train_losses``.
    """
    train_curves = _as_fold_curves(train_losses)
    valid_curves = _as_fold_curves(valid_losses)
    multiple_folds = len(train_curves) > 1

    fig, ax = plt.subplots(figsize=(10, 6))
    for fold, (train_curve, valid_curve) in enumerate(zip(train_curves, valid_curves), start=1):
        suffix = f' (fold {fold})' if multiple_folds else ''
        # Epochs are numbered from 1 to match the "Epoch k/N" training log.
        train_line, = ax.plot(range(1, len(train_curve) + 1), train_curve,
                              marker='o', label=f'Training Loss{suffix}')
        # Same colour per fold; the dashed line marks validation.
        ax.plot(range(1, len(valid_curve) + 1), valid_curve, marker='o', linestyle='--',
                color=train_line.get_color(), label=f'Validation Loss{suffix}')
    ax.set_title('Learning Curves')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.legend()
    plt.show()


# k_fold_cross_validation returns one list of per-epoch losses per fold (lengths can differ
# when early stopping triggers), so plot each fold's curves separately.
for fold_train_losses, fold_valid_losses in zip(train_losses, valid_losses):
    plot_learning_curves(fold_train_losses, fold_valid_losses)