Import necessary libraries

In [1]:
import os
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights

import torch.optim as optim
import torch.nn as nn
from sklearn.model_selection import KFold
from tqdm import tqdm
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
import seaborn as sns

Data loader

In [2]:
# Custom Dataset class for CIFAR-10 with lazy loading
class CIFAR10Dataset(Dataset):
    def __init__(self, data_path, batch_files, transform=None):
        self.data_path = data_path
        self.batch_files = batch_files
        self.transform = transform
        self.batch_data = None  # Only load the necessary batch when needed
        self.batch_labels = None
        self.batch_index = -1  # Track the currently loaded batch
        self.index_map = []  # Maps dataset index to batch index and in-batch index
        self._create_index_map()

    def _create_index_map(self):
        """Create a map of global indices to batch indices."""
        start_idx = 0
        for batch_num, batch_file in enumerate(self.batch_files):
            with open(os.path.join(self.data_path, batch_file), 'rb') as f:
                batch = pickle.load(f, encoding='bytes')
                batch_size = len(batch[b'labels'])
                self.index_map.extend([(batch_num, i) for i in range(batch_size)])
            start_idx += batch_size

    def _load_batch(self, batch_num):
        """Load a batch given its batch number."""
        batch_file = self.batch_files[batch_num]
        with open(os.path.join(self.data_path, batch_file), 'rb') as f:
            batch = pickle.load(f, encoding='bytes')
            self.batch_data = batch[b'data'].reshape(-1, 3, 32, 32)
            self.batch_labels = batch[b'labels']
        self.batch_index = batch_num  # Update currently loaded batch

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        # Map global index to batch number and in-batch index
        batch_num, in_batch_idx = self.index_map[idx]

        # Load the batch if it's not already loaded
        if batch_num != self.batch_index:
            self._load_batch(batch_num)

        # Fetch image and label from the loaded batch
        image = self.batch_data[in_batch_idx]
        label = self.batch_labels[in_batch_idx]

        # Convert to the expected format (H x W x C)
        image = image.transpose(1, 2, 0)

        if self.transform:
            image = self.transform(image)

        return image, label


# Transformations for CIFAR-10 (ResNet expects 224x224 images)
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),  # ResNet-50 requires 224x224 input size
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize the images
])

Training with cross validation

In [3]:
# Custom function to train and evaluate a model with early stopping and validation accuracy
def train_and_evaluate(trainloader, validloader, model, criterion, optimizer, num_epochs=10, patience=5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # To store loss and accuracy for each epoch
    train_losses, valid_losses = [], []
    valid_accuracies = []  # To store validation accuracy
    best_valid_loss = float('inf')
    epochs_no_improve = 0  # Counter for early stopping

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        total_train_samples = 0

        # Train on training data
        progress_bar = tqdm(trainloader, desc=f'Epoch {epoch+1}/{num_epochs} (Train)')
        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            total_train_samples += labels.size(0)

        # Compute epoch training loss
        epoch_train_loss = running_loss / total_train_samples
        train_losses.append(epoch_train_loss)
        print(f'Training: Loss: {epoch_train_loss:.4f}')

        # Validation loop
        model.eval()
        running_valid_loss = 0.0
        correct_predictions = 0
        total_valid_samples = 0

        with torch.no_grad():
            progress_bar = tqdm(validloader, desc=f'Epoch {epoch+1}/{num_epochs} (Valid)')
            for inputs, labels in progress_bar:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                running_valid_loss += loss.item() * inputs.size(0)

                # Get the predicted labels and calculate accuracy
                _, predicted = torch.max(outputs, 1)
                correct_predictions += (predicted == labels).sum().item()
                total_valid_samples += labels.size(0)

        # Compute epoch validation loss
        epoch_valid_loss = running_valid_loss / total_valid_samples
        valid_losses.append(epoch_valid_loss)
        print(f'Validation: Loss: {epoch_valid_loss:.4f}')

        # Compute validation accuracy
        epoch_valid_acc = correct_predictions / total_valid_samples * 100
        valid_accuracies.append(epoch_valid_acc)
        print(f'Validation: Accuracy: {epoch_valid_acc:.2f}%')

        # Check for early stopping
        if epoch_valid_loss < best_valid_loss:
            best_valid_loss = epoch_valid_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

    # Return the training and validation losses and validation accuracies
    return train_losses, valid_losses, valid_accuracies


# K-Fold Cross Validation
def k_fold_cross_validation(dataset, model_class, num_folds, num_epochs, batch_size, patience):
    kfold = KFold(n_splits=num_folds, shuffle=True)

    # To store validation performance and learning curves across folds
    all_train_losses, all_valid_losses, all_valid_accuracies = [], [], []

    # K-fold Cross Validation
    for fold, (train_idx, valid_idx) in enumerate(kfold.split(dataset)):
        print(f'Fold {fold + 1}/{num_folds}')

        # Subset the data for training and validation
        train_subset = Subset(dataset, train_idx)
        valid_subset = Subset(dataset, valid_idx)

        # Create DataLoaders for this fold
        trainloader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=4)
        validloader = DataLoader(valid_subset, batch_size=batch_size, shuffle=False, num_workers=4)

        # Create a new instance of the model for each fold
        model = model_class()

        # Loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-2)

        # Train and validate
        train_losses, valid_losses, valid_accuracies = train_and_evaluate(
            trainloader, validloader, model, criterion, optimizer, num_epochs=num_epochs, patience=patience
        )

        # Collect learning curves
        all_train_losses.append(train_losses)
        all_valid_losses.append(valid_losses)
        all_valid_accuracies.append(valid_accuracies)

    # Return learning curves for further plotting
    return all_train_losses, all_valid_losses, all_valid_accuracies


Load data and run training

In [None]:
# Path to the dataset in Google Drive
data_path = '/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/cifar-10-python/cifar-10-batches-py/'

# Training and test batch file names
train_batches = [f'data_batch_{i}' for i in range(1, 6)]
test_batches = ['test_batch']

# Define the necessary transforms
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),  # ResNet-50 requires 224x224 input size
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize the images
])

# Create Dataset instance for the full training dataset
train_dataset = CIFAR10Dataset(data_path, train_batches, transform=transform)

# Perform 5-fold cross-validation on the dataset with ResNet-50 and save the losses
train_losses, valid_losses = k_fold_cross_validation(train_dataset, lambda: resnet50(weights=ResNet50_Weights.DEFAULT), num_folds=5, num_epochs=5, batch_size=128, patience=2)

Fold 1/5


Epoch 1/5 (Train): 100%|██████████| 313/313 [10:45<00:00,  2.06s/it]


Training: Loss: 0.5320


Epoch 1/5 (Valid): 100%|██████████| 79/79 [00:19<00:00,  4.13it/s]


Validation: Loss: 0.3920
Validation: Accuracy: 86.43%


Epoch 2/5 (Train): 100%|██████████| 313/313 [09:27<00:00,  1.81s/it]


Training: Loss: 0.2238


Epoch 2/5 (Valid): 100%|██████████| 79/79 [00:19<00:00,  4.14it/s]


Validation: Loss: 0.3420
Validation: Accuracy: 88.73%


Epoch 3/5 (Train): 100%|██████████| 313/313 [09:28<00:00,  1.82s/it]


Training: Loss: 0.1494


Epoch 3/5 (Valid): 100%|██████████| 79/79 [00:18<00:00,  4.16it/s]


Validation: Loss: 0.2910
Validation: Accuracy: 90.60%


Epoch 4/5 (Train): 100%|██████████| 313/313 [09:31<00:00,  1.83s/it]


Training: Loss: 0.1156


Epoch 4/5 (Valid): 100%|██████████| 79/79 [00:19<00:00,  4.15it/s]


Validation: Loss: 0.3300
Validation: Accuracy: 89.95%


Epoch 5/5 (Train): 100%|██████████| 313/313 [09:27<00:00,  1.81s/it]


Training: Loss: 0.0934


Epoch 5/5 (Valid): 100%|██████████| 79/79 [00:18<00:00,  4.16it/s]


Validation: Loss: 0.3087
Validation: Accuracy: 90.25%
Early stopping at epoch 5
Fold 2/5


Epoch 1/5 (Train): 100%|██████████| 313/313 [09:34<00:00,  1.83s/it]


Training: Loss: 0.5362


Epoch 1/5 (Valid): 100%|██████████| 79/79 [00:19<00:00,  4.13it/s]


Validation: Loss: 0.4371
Validation: Accuracy: 85.02%


Epoch 2/5 (Train): 100%|██████████| 313/313 [09:23<00:00,  1.80s/it]


Training: Loss: 0.2266


Epoch 2/5 (Valid): 100%|██████████| 79/79 [00:19<00:00,  4.13it/s]


Validation: Loss: 0.3554
Validation: Accuracy: 87.95%


Epoch 3/5 (Train): 100%|██████████| 313/313 [09:28<00:00,  1.82s/it]


Training: Loss: 0.1544


Epoch 3/5 (Valid): 100%|██████████| 79/79 [00:19<00:00,  4.12it/s]


Validation: Loss: 0.3302
Validation: Accuracy: 89.34%


Epoch 4/5 (Train): 100%|██████████| 313/313 [09:27<00:00,  1.81s/it]


Training: Loss: 0.1201


Epoch 4/5 (Valid): 100%|██████████| 79/79 [00:19<00:00,  4.13it/s]


Validation: Loss: 0.2937
Validation: Accuracy: 90.54%


Epoch 5/5 (Train): 100%|██████████| 313/313 [09:25<00:00,  1.81s/it]


Training: Loss: 0.0954


Epoch 5/5 (Valid): 100%|██████████| 79/79 [00:19<00:00,  4.12it/s]


Validation: Loss: 0.2606
Validation: Accuracy: 91.75%
Fold 3/5


Epoch 1/5 (Train): 100%|██████████| 313/313 [09:26<00:00,  1.81s/it]


Training: Loss: 0.5276


Epoch 1/5 (Valid): 100%|██████████| 79/79 [00:19<00:00,  4.13it/s]


Validation: Loss: 0.3914
Validation: Accuracy: 86.76%


Epoch 2/5 (Train): 100%|██████████| 313/313 [09:43<00:00,  1.86s/it]


Training: Loss: 0.2216


Epoch 2/5 (Valid): 100%|██████████| 79/79 [00:19<00:00,  4.13it/s]


Validation: Loss: 0.3794
Validation: Accuracy: 87.66%


Epoch 3/5 (Train): 100%|██████████| 313/313 [09:34<00:00,  1.84s/it]


Training: Loss: 0.1527


Epoch 3/5 (Valid): 100%|██████████| 79/79 [00:19<00:00,  4.13it/s]


Validation: Loss: 0.3306
Validation: Accuracy: 89.33%


Epoch 4/5 (Train): 100%|██████████| 313/313 [09:31<00:00,  1.82s/it]


Training: Loss: 0.1149


Epoch 4/5 (Valid): 100%|██████████| 79/79 [00:19<00:00,  4.12it/s]


Validation: Loss: 0.2794
Validation: Accuracy: 91.05%


Epoch 5/5 (Train): 100%|██████████| 313/313 [09:36<00:00,  1.84s/it]


Training: Loss: 0.0932


Epoch 5/5 (Valid): 100%|██████████| 79/79 [00:19<00:00,  4.11it/s]


Validation: Loss: 0.3304
Validation: Accuracy: 90.50%
Fold 4/5


Epoch 1/5 (Train): 100%|██████████| 313/313 [09:41<00:00,  1.86s/it]


Training: Loss: 0.5159


Epoch 1/5 (Valid): 100%|██████████| 79/79 [00:19<00:00,  4.13it/s]


Validation: Loss: 0.3801
Validation: Accuracy: 87.08%


Epoch 2/5 (Train): 100%|██████████| 313/313 [09:31<00:00,  1.83s/it]


Training: Loss: 0.2203


Epoch 2/5 (Valid): 100%|██████████| 79/79 [00:19<00:00,  4.12it/s]


Validation: Loss: 0.3031
Validation: Accuracy: 89.82%


Epoch 3/5 (Train): 100%|██████████| 313/313 [09:34<00:00,  1.84s/it]


Training: Loss: 0.1550


Epoch 3/5 (Valid): 100%|██████████| 79/79 [00:19<00:00,  4.13it/s]


Validation: Loss: 0.3154
Validation: Accuracy: 89.56%


Epoch 4/5 (Train): 100%|██████████| 313/313 [09:31<00:00,  1.82s/it]


Training: Loss: 0.1186


Epoch 4/5 (Valid): 100%|██████████| 79/79 [00:19<00:00,  4.13it/s]


Validation: Loss: 0.2686
Validation: Accuracy: 91.11%


Epoch 5/5 (Train): 100%|██████████| 313/313 [09:30<00:00,  1.82s/it]


Training: Loss: 0.0933


Epoch 5/5 (Valid): 100%|██████████| 79/79 [00:19<00:00,  4.13it/s]


Validation: Loss: 0.3183
Validation: Accuracy: 90.43%
Fold 5/5


Epoch 1/5 (Train): 100%|██████████| 313/313 [09:32<00:00,  1.83s/it]


Training: Loss: 0.5416


Epoch 1/5 (Valid): 100%|██████████| 79/79 [00:19<00:00,  4.11it/s]


Validation: Loss: 0.3298
Validation: Accuracy: 88.71%


Epoch 2/5 (Train):  20%|██        | 64/313 [01:59<05:31,  1.33s/it]

Evaluation on test set

In [None]:
def test_model_with_confusion_matrix(model, testloader):
    model.eval()  # Set model to evaluation mode
    all_preds = []
    all_labels = []

    with torch.no_grad():
        progress_bar = tqdm(enumerate(testloader), total=len(testloader), desc='Testing')
        for i, (inputs, labels) in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)

            # Append predictions and labels for confusion matrix and metrics
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Generate the confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    print(f"Confusion Matrix:\n{cm}")

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted')
    recall = recall_score(all_labels, all_preds, average='weighted')
    f1 = f1_score(all_labels, all_preds, average='weighted')

    # Print the computed metrics
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")

    # Plot confusion matrix using seaborn
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=[i for i in range(10)], yticklabels=[i for i in range(10)])
    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.title('Confusion Matrix')
    plt.show()

# Test the model and display confusion matrix and metrics
test_model_with_confusion_matrix(model, testloader)

Plotting learning curves

In [None]:
def plot_learning_curves(train_losses, valid_losses):
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(valid_losses, label='Validation Loss')
    plt.plot(valid_accuracies, label='Validation Accuracy')
    plt.title('Learning Curves')
    plt.xlabel('Epochs')
    plt.ylabel('Loss' if 'Loss' in train_losses[0] else 'Accuracy')
    plt.legend()
    plt.show()

plot_learning_curves(train_losses, valid_losses)