# In [None]:
# Import necessary libraries
import torch
import torch.nn as nn # Neural network modules (layers, activations, etc.)
import torch.optim as optim  # Optimization algorithms (Adam, SGD, etc.)
import torchvision  # Datasets, models, and transforms for computer vision
import torchvision.transforms as transforms  # Image transformations (augmentation, normalization)
from torch.optim.lr_scheduler import CosineAnnealingLR # Learning rate scheduler
import matplotlib.pyplot as plt  # Plotting training curves


# Per-channel normalisation statistics (CIFAR10 stats), shared by both pipelines
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2470, 0.2435, 0.2616)

# Training pipeline: random augmentation first, then tensor conversion + normalisation
train_steps = [
    transforms.RandomHorizontalFlip(),      # Random left/right mirror
    transforms.RandomCrop(32, padding=4),   # Pad by 4 px, then take a random 32x32 crop
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # Colour variations
    transforms.RandomRotation(15),          # Rotation by up to 15 degrees either way
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
]
transform_train = transforms.Compose(train_steps)

#%% Task 1: Read dataset and create data loaders
# Test/validation pipeline: no augmentation, only tensor conversion + normalisation
test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
]
transform_test = transforms.Compose(test_steps)

# Number of samples processed in one forward/backward pass
batch_size = 128

# Dataset loading (both splits live under the same root and are downloaded if missing)
dataset_options = dict(root='./data', download=True)
trainset = torchvision.datasets.CIFAR10(train=True, transform=transform_train, **dataset_options)
testset = torchvision.datasets.CIFAR10(train=False, transform=transform_test, **dataset_options)

# Data loaders: only the training split is shuffled
loader_options = dict(batch_size=batch_size, num_workers=2)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_options)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_options)

#%% Task 2: Create the model
class ExpertBranch(nn.Module):
    """Gating head that maps a feature map to ``k`` softmax weights per sample.

    The input is globally average-pooled to one value per channel, passed
    through a bottleneck ``Linear -> BatchNorm1d -> LeakyReLU`` and projected
    to ``k`` logits, which are normalised with a softmax over the ``k`` axis.

    Args:
        in_channels: Number of channels of the incoming feature map.
        k: Number of weights produced per sample.
        reduction: Divisor applied to ``in_channels`` to size the bottleneck.
    """

    def __init__(self, in_channels, k=4, reduction=16):
        super().__init__()
        # Clamp the bottleneck to at least one unit: with in_channels < reduction
        # the integer division yields 0, which would create zero-width layers
        # whose output cannot depend on the input.
        hidden = max(in_channels // reduction, 1)

        self.avgpool = nn.AdaptiveAvgPool2d(1)  # Global average pooling
        self.fc1 = nn.Linear(in_channels, hidden)  # Channel reduction
        self.bn1 = nn.BatchNorm1d(hidden)
        self.relu = nn.LeakyReLU(0.1)
        self.fc2 = nn.Linear(hidden, k)  # One logit per weight
        self.softmax = nn.Softmax(dim=1)  # Weights of each sample sum to 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the weights with shape ``(batch_size, k)``."""
        pooled = torch.flatten(self.avgpool(x), 1)  # (batch_size, channels)
        hidden = self.relu(self.bn1(self.fc1(pooled)))
        return self.softmax(self.fc2(hidden))

class BackboneBlock(nn.Module):
    """Residual block mixing ``k`` parallel conv branches with per-sample weights.

    An ``ExpertBranch`` computes one weight per branch from the block input.
    Every branch is ``Conv3x3 -> BN -> LeakyReLU -> Dropout2d -> Conv3x3 -> BN``;
    the branch outputs are scaled by their weights and summed, a shortcut of
    the input is added, and the result is activated and max-pooled by 2.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by every branch and by the block.
        k: Number of parallel convolutional branches.
        dropout_prob: Drop probability of the ``Dropout2d`` inside each branch.
    """

    def __init__(self, in_channels, out_channels, k=4, dropout_prob=0.1):
        super().__init__()
        self.expert_branch = ExpertBranch(in_channels, k)

        # Build the K parallel convolutional branches one after another
        branches = []
        for _ in range(k):
            branches.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(out_channels),  # Normalise activations
                    nn.LeakyReLU(0.1),  # Non-linearity
                    nn.Dropout2d(dropout_prob),  # Spatial dropout (regularisation)
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(out_channels),  # Final normalisation
                )
            )
        self.convs = nn.ModuleList(branches)

        # Shortcut: identity when the channel count is unchanged, otherwise a
        # 1x1 conv + BN that projects the input to out_channels
        if in_channels == out_channels:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        self.relu = nn.LeakyReLU(0.1)  # Activation after the residual addition
        self.pool = nn.MaxPool2d(2, 2)  # Spatial downsampling

    def forward(self, x):
        """Return the weighted branch mixture plus shortcut, activated and pooled."""
        gates = self.expert_branch(x)  # Attention weights [batch_size, k]

        mixed = 0
        for idx, branch in enumerate(self.convs):
            gate = gates[:, idx].view(-1, 1, 1, 1)  # Broadcast over C, H, W
            mixed = mixed + gate * branch(x)

        mixed = mixed + self.shortcut(x)  # Residual connection
        return self.pool(self.relu(mixed))  # Activate, then downsample


class CNNModel(nn.Module):
    """Image classifier: conv stem, three ``BackboneBlock`` stages, MLP head.

    Args:
        num_classes: Number of output logits. Defaults to 10 (CIFAR-10), which
            reproduces the previously hard-coded head.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Stem: 3x3 conv with stride 1 / padding 1 keeps the spatial size
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1),
            nn.Dropout2d(0.1),  # Used for regularisation
        )

        # Backbone: channel count doubles at every stage (64 -> 128 -> 256 -> 512)
        self.backbone = nn.Sequential(
            BackboneBlock(64, 128, k=4),
            BackboneBlock(128, 256, k=4),
            BackboneBlock(256, 512, k=4),
        )

        # Classifier: global pooling followed by a two-hidden-layer MLP
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # (batch_size, 512, 1, 1)
            nn.Flatten(),  # (batch_size, 512)
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),  # One logit per class
        )

    def forward(self, x):
        """Return class logits of shape ``(batch_size, num_classes)``."""
        x = self.stem(x)  # Initial feature extraction
        x = self.backbone(x)  # Backbone processing
        return self.classifier(x)  # Final classification

#%% Task 3: Create loss and optimizer
# Detect available device (GPU if available, else use the CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the model and move its parameters to the selected device
model=CNNModel()
model.to(device)
# Loss function (using CrossEntropyLoss for classification)
criterion = nn.CrossEntropyLoss()
# Optimiser: AdamW, which applies decoupled weight decay (not an L2 term in the loss)
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=5e-4)
# Cosine learning-rate schedule that reaches its minimum after T_max scheduler steps.
# NOTE(review): the scheduler is stepped once per epoch and the training call in
# this file passes epochs=35, so with T_max=100 the learning rate is only
# partially annealed by the end of training - confirm this is intended.
scheduler = CosineAnnealingLR(optimizer, T_max=100)

#%% Task 4: The Training script
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``: train if ``optimizer`` is given, else evaluate.

    Returns:
        Tuple ``(mean of the per-batch losses, accuracy in percent)``.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    loss_sum = 0.0  # Cumulative per-batch loss
    n_correct = 0  # Correct predictions
    n_seen = 0  # Total samples

    # Gradients are only tracked while training
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()  # Clear gradients
            outputs = model(images)  # Forward pass
            loss = criterion(outputs, labels)
            if training:
                loss.backward()  # Backpropagation
                optimizer.step()  # Update the weights

            loss_sum += loss.item()
            _, predicted = torch.max(outputs, 1)  # Predicted class per sample
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    return loss_sum / len(loader), 100 * n_correct / n_seen


def train_model(model, trainloader, testloader, criterion, optimizer, scheduler, epochs=35):
    """Train ``model`` for ``epochs`` epochs, evaluating on ``testloader`` after each.

    Batches are moved to the device of the model's first parameter (CPU if the
    model has no parameters), so the function does not rely on a module-level
    ``device`` variable. Models spread over several devices are not handled.

    Args:
        model: Network to optimise; expected to already be on its target device.
        trainloader: Sized iterable of ``(images, labels)`` training batches.
        testloader: Sized iterable of ``(images, labels)`` evaluation batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimiser that updates ``model``'s parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        epochs: Number of passes over ``trainloader``.

    Returns:
        Tuple ``(train_losses, test_losses, train_accuracies, test_accuracies)``
        with one entry per epoch; losses are means of the per-batch losses and
        accuracies are percentages.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    train_losses = []
    train_accuracies = []
    test_losses = []
    test_accuracies = []

    for epoch in range(epochs):
        # Training phase
        avg_train_loss, train_accuracy = _run_epoch(
            model, trainloader, criterion, device, optimizer)
        scheduler.step()  # Update the learning rate
        train_losses.append(avg_train_loss)
        train_accuracies.append(train_accuracy)

        # Test evaluation (no gradient tracking, model in eval mode)
        avg_test_loss, test_accuracy = _run_epoch(model, testloader, criterion, device)
        test_losses.append(avg_test_loss)
        test_accuracies.append(test_accuracy)

        print(f"Epoch {epoch+1}/{epochs}, "
              f"Train Loss: {avg_train_loss:.4f}, Test Loss: {avg_test_loss:.4f}, "
              f"Train Acc: {train_accuracy:.2f}%, Test Acc: {test_accuracy:.2f}%")

    return train_losses, test_losses, train_accuracies, test_accuracies

# Run the full training loop and keep the per-epoch histories
train_losses, test_losses, train_accuracies, test_accuracies = train_model(
    model, trainloader, testloader, criterion, optimizer, scheduler, epochs=35,
)

# Plot training curves: loss on the top panel, accuracy on the bottom panel
fig, (loss_ax, acc_ax) = plt.subplots(2, 1, figsize=(12, 10))

# Loss panel (x axis is the 1-based epoch number)
loss_ax.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss', marker='o')
loss_ax.plot(range(1, len(test_losses) + 1), test_losses, label='Test Loss', marker='o')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Test Loss')
loss_ax.legend()

# Accuracy panel
acc_ax.plot(range(1, len(train_accuracies) + 1), train_accuracies, label='Train Accuracy', marker='o')
acc_ax.plot(range(1, len(test_accuracies) + 1), test_accuracies, label='Test Accuracy', marker='o')
acc_ax.set_xlabel('Epochs')
acc_ax.set_ylabel('Accuracy (%)')
acc_ax.set_title('Training and Test Accuracy')
acc_ax.legend()

fig.tight_layout()
plt.show()

# Final evaluation
def evaluate_model(model, testloader):
    """Compute, print and return the top-1 accuracy of ``model`` on ``testloader``.

    Batches are moved to the device of the model's first parameter (CPU if the
    model has no parameters), so the function does not rely on a module-level
    ``device`` variable. The model is left in eval mode.

    Args:
        model: Network to evaluate; expected to already be on its target device.
        testloader: Iterable of ``(images, labels)`` batches.

    Returns:
        Accuracy in percent.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # No gradients needed for inference
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)  # Predicted class per sample
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Final Accuracy on CIFAR-10 test set: {accuracy:.2f}%')
    return accuracy

# Evaluate the trained model on the test loader; the result is an accuracy in percent
final_accuracy = evaluate_model(model, testloader)