# Deep Learning Applications: Laboratory #1

In this first laboratory we will work with relatively simple architectures to get a feel for working with Deep Models. This notebook is designed to work with PyTorch, but as I said in the introductory lecture: please feel free to use and experiment with whatever tools you like.

**Important Notes**:
1. Be sure to **document** all of your decisions, as well as your intermediate and final results. Make sure your conclusions and analyses are clearly presented. Don't make us dig into your code or walls of printed results to try to draw conclusions from your code.
2. If you use code from someone else (e.g. Github, Stack Overflow, ChatGPT, etc) you **must be transparent about it**. Document your sources and explain how you adapted any partial solutions to create **your** solution.



## Exercise 1: Warming Up
In this series of exercises I want you to try to duplicate (on a small scale) the results of the ResNet paper:

> [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385), Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun, CVPR 2016.

We will do this in steps using a Multilayer Perceptron on MNIST.

Recall that the main message of the ResNet paper is that **deeper** networks do not **guarantee** more reduction in training loss (or in validation accuracy). Below you will incrementally build a sequence of experiments to verify this for an MLP. A few guidelines:

+ I have provided some **starter** code at the beginning. **NONE** of this code should survive in your solutions. Not only is it **very** badly written, it is also written in my functional style that also obfuscates what it's doing (in part to **discourage** your reuse!). It's just to get you *started*.
+ These exercises ask you to compare **multiple** training runs, so it is **really** important that you factor this into your **pipeline**. Using [Tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) is a **very** good idea -- or, even better [Weights and Biases](https://wandb.ai/site).
+ You may work and submit your solutions in **groups of at most two**. Share your ideas with everyone, but the solutions you submit *must be your own*.

First some boilerplate to get you started, then on to the actual exercises!

### Preface: Some code to get you started

What follows is some **very simple** code for training an MLP on MNIST. The point of this code is to get you up and running (and to verify that your Python environment has all needed dependencies).

**Note**: As you read through my code and execute it, this would be a good time to think about *abstracting* **your** model definition, and training and evaluation pipelines in order to make it easier to compare performance of different models.

In [1]:
# Import standard libraries
import numpy as np
import matplotlib.pyplot as plt
import os
import copy
from functools import reduce
from dataclasses import dataclass, field
from typing import List, Optional, Callable, Dict, Any

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
from torch.optim import Adam, SGD

# Torchvision imports
import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10

# Third-party imports
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import wandb
from tqdm import tqdm

# Import moduli locali
from src.models import MLP, ResidualMLP, SimpleCNN, ResidualCNN
from src.trainer import StreamlinedTrainer
from src.data import get_data_transforms, TransformedSubset
from src.utils import compute_accuracy_metrics
from src.config import TrainingConfig

#### Data preparation

Here is some basic dataset loading, validation splitting code to get you started working with MNIST.

In [None]:
# # Standard MNIST transform.
# transform = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize((0.1307,), (0.3081,))
# ])

# # Load MNIST train and test.
# ds_train = MNIST(root='./data', train=True, download=True, transform=transform)
# ds_test = MNIST(root='./data', train=False, download=True, transform=transform)

# # Split train into train and validation.
# val_size = 5000
# I = np.random.permutation(len(ds_train))
# ds_val = Subset(ds_train, I[:val_size])
# ds_train = Subset(ds_train, I[val_size:])

#### Boilerplate training and evaluation code

This is some **very** rough training, evaluation, and plotting code. Again, just to get you started. I will be *very* disappointed if any of this code makes it into your final submission.

In [None]:
# from tqdm import tqdm
# from sklearn.metrics import accuracy_score, classification_report

# # Function to train a model for a single epoch over the data loader.
# def train_epoch(model, dl, opt, epoch='Unknown', device='cpu'):
#     model.train()
#     losses = []
#     for (xs, ys) in tqdm(dl, desc=f'Training epoch {epoch}', leave=True):
#         xs = xs.to(device)
#         ys = ys.to(device)
#         opt.zero_grad()
#         logits = model(xs)
#         loss = F.cross_entropy(logits, ys)
#         loss.backward()
#         opt.step()
#         losses.append(loss.item())
#     return np.mean(losses)

# # Function to evaluate model over all samples in the data loader.
# def evaluate_model(model, dl, device='cpu'):
#     model.eval()
#     predictions = []
#     gts = []
#     for (xs, ys) in tqdm(dl, desc='Evaluating', leave=False):
#         xs = xs.to(device)
#         preds = torch.argmax(model(xs), dim=1)
#         gts.append(ys)
#         predictions.append(preds.detach().cpu().numpy())

#     # Return accuracy score and classification report.
#     return (accuracy_score(np.hstack(gts), np.hstack(predictions)),
#             classification_report(np.hstack(gts), np.hstack(predictions), zero_division=0, digits=3))

# # Simple function to plot the loss curve and validation accuracy.
# def plot_validation_curves(losses_and_accs):
#     losses = [x for (x, _) in losses_and_accs]
#     accs = [x for (_, x) in losses_and_accs]
#     plt.figure(figsize=(16, 8))
#     plt.subplot(1, 2, 1)
#     plt.plot(losses)
#     plt.xlabel('Epoch')
#     plt.ylabel('Loss')
#     plt.title('Average Training Loss per Epoch')
#     plt.subplot(1, 2, 2)
#     plt.plot(accs)
#     plt.xlabel('Epoch')
#     plt.ylabel('Validation Accuracy')
#     plt.title(f'Best Accuracy = {np.max(accs)} @ epoch {np.argmax(accs)}')

#### A basic, parameterized MLP

This is a very basic implementation of a Multilayer Perceptron. Don't waste too much time trying to figure out how it works -- the important detail is that it allows you to pass in a list of input, hidden layer, and output *widths*. **Your** implementation should also support this for the exercises to come.

In [None]:
# class MLP(nn.Module):
#     def __init__(self, layer_sizes):
#         super().__init__()
#         self.layers = nn.ModuleList([nn.Linear(nin, nout) for (nin, nout) in zip(layer_sizes[:-1], layer_sizes[1:])])

#     def forward(self, x):
#         return reduce(lambda f, g: lambda x: g(F.relu(f(x))), self.layers, lambda x: x.flatten(1))(x)

#### A *very* minimal training pipeline.

Here is some basic training and evaluation code to get you started.

**Important**: I cannot stress enough that this is a **terrible** example of how to implement a training pipeline. You can do better!

In [None]:
# # Training hyperparameters.
# device = 'cuda' if torch.cuda.is_available else 'cpu'
# epochs = 100
# lr = 0.0001
# batch_size = 128

# # Architecture hyperparameters.
# input_size = 28*28
# width = 16
# depth = 2

# # Dataloaders.
# dl_train = torch.utils.data.DataLoader(ds_train, batch_size, shuffle=True, num_workers=4)
# dl_val   = torch.utils.data.DataLoader(ds_val, batch_size, num_workers=4)
# dl_test  = torch.utils.data.DataLoader(ds_test, batch_size, shuffle=True, num_workers=4)

# # Instantiate model and optimizer.
# model_mlp = MLP([input_size] + [width]*depth + [10]).to(device)
# opt = torch.optim.Adam(params=model_mlp.parameters(), lr=lr)

# # Training loop.
# losses_and_accs = []
# for epoch in range(epochs):
#     loss = train_epoch(model_mlp, dl_train, opt, epoch, device=device)
#     (val_acc, _) = evaluate_model(model_mlp, dl_val, device=device)
#     losses_and_accs.append((loss, val_acc))

# # And finally plot the curves.
# plot_validation_curves(losses_and_accs)
# print(f'Accuracy report on TEST:\n {evaluate_model(model_mlp, dl_test, device=device)[1]}')

#### My Streamlined Training Pipeline

In [None]:
from typing import List, Optional, Callable, Any, Dict
from torch import nn
from torchvision import transforms
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import wandb
import os
import copy

# Le classi TrainingConfig, get_data_transforms, TransformedSubset e StreamlinedTrainer sono ora importate dai moduli locali.
# Utilizza direttamente queste classi e funzioni per la configurazione, la preparazione dei dati e il training.



### Exercise 1.1: A baseline MLP

Implement a *simple* Multilayer Perceptron to classify the 10 digits of MNIST (e.g. two *narrow* layers). Use my code above as inspiration, but implement your own training pipeline -- you will need it later. Train this model to convergence, monitoring (at least) the loss and accuracy on the training and validation sets for every epoch. Below I include a basic implementation to get you started -- remember that you should write your *own* pipeline!

**Note**: This would be a good time to think about *abstracting* your model definition, and training and evaluation pipelines in order to make it easier to compare performance of different models.

**Important**: Given the *many* runs you will need to do, and the need to *compare* performance between them, this would **also** be a great point to study how **Tensorboard** or **Weights and Biases** can be used for performance monitoring.

In [None]:
# La classe MLP è ora importata da src.models. Utilizza direttamente MLP per definire i modelli MLP personalizzati.

#### Training Configuration

In [None]:
from src.utils import compute_accuracy_metrics

# Il resto del tuo codice rimane invariato, dato che ora utilizziamo la funzione importata per il calcolo delle metriche.

In [None]:
# Load datasets
ds_train = MNIST(root='./data', train=True, download=True)
ds_test = MNIST(root='./data', train=False, download=True)

# Create streamlined training configuration
config = TrainingConfig(
    num_epochs=60,
    batch_size=256,
    learning_rate=1e-3,
    use_early_stopping=True,
    early_stopping_patience=10,
    validation_split=0.1,
    data_augmentation=["RandomRotation", "RandomAffine"],
    wandb_project="DeepLearningApplication_Lab1"
)

# Define model architecture
model = MLP(
    input_size=28 * 28,
    output_size=10,
    hidden_sizes=[128, 64],
    activation_fn=nn.ReLU,
    dropout_rate=0.1
)

# Create trainer and start training
trainer = StreamlinedTrainer(
    model=model,
    config=config,
    train_dataset=ds_train,
    test_dataset=ds_test,
    compute_metrics_fn=compute_accuracy_metrics
)

print(f"Model: {model}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Start training
final_metrics = trainer.train()

### Exercise 1.2: Adding Residual Connections

Implement a variant of your parameterized MLP network to support **residual** connections. Your network should be defined as a composition of **residual MLP** blocks that have one or more linear layers and add a skip connection from the block input to the output of the final linear layer.

**Compare** the performance (in training/validation loss and test accuracy) of your MLP and ResidualMLP for a range of depths. Verify that deeper networks **with** residual connections are easier to train than a network of the same depth **without** residual connections.

**For extra style points**: See if you can explain by analyzing the gradient magnitudes on a single training batch *why* this is the case.

In [None]:
# Le classi ResidualBlock e ResidualMLP sono ora importate da src.models. Utilizza direttamente ResidualMLP per i modelli MLP con connessioni residue.


#### Training Loop

In [None]:


# Comparison between MLP and ResidualMLP with different depths
def compare_mlp_architectures():
    """Compare standard MLP vs ResidualMLP across different depths."""
    
    # Load datasets
    ds_train = MNIST(root='./data', train=True, download=True)
    ds_test = MNIST(root='./data', train=False, download=True)
    
    # Training configuration
    config = TrainingConfig(
        num_epochs=50,
        batch_size=256,
        learning_rate=1e-3,
        use_early_stopping=True,
        early_stopping_patience=8,
        wandb_project="MLP_vs_ResidualMLP_Comparison"
    )
    
    # Test different architectures
    architectures = [
        {"hidden_sizes": [128, 64], "name": "shallow"},
        {"hidden_sizes": [128, 64, 64, 32], "name": "medium"},
        {"hidden_sizes": [128, 128, 64, 64, 32, 32, 16, 16], "name": "deep"}
    ]
    
    results = {}
    
    for arch in architectures:
        print(f"\n{'='*50}")
        print(f"Testing {arch['name']} architecture")
        print(f"{'='*50}")
        
        # # Standard MLP
        # mlp = MLP(
        #     input_size=28 * 28,
        #     output_size=10,
        #     hidden_sizes=arch["hidden_sizes"],
        #     dropout_rate=0.1
        # )
        
        # print(f"\n--- Standard MLP ---")
        # print(f"Parameters: {sum(p.numel() for p in mlp.parameters()):,}")
        
        # trainer_mlp = StreamlinedTrainer(
        #     model=mlp,
        #     config=config,
        #     train_dataset=ds_train,
        #     test_dataset=ds_test,
        #     compute_metrics_fn=compute_accuracy_metrics
        # )
        
        # mlp_metrics = trainer_mlp.train()
        
        # ResidualMLP (approximate same complexity)
        res_mlp = ResidualMLP(
            input_size=28 * 28,
            output_size=10,
            hidden_sizes=arch["hidden_sizes"],  # Use first hidden size
            # num_blocks=3,
            layers_per_block=2,
            dropout_rate=0.1
        )
        
        print(f"\n--- Residual MLP ---")
        print(f"Parameters: {sum(p.numel() for p in res_mlp.parameters()):,}")
        
        trainer_res = StreamlinedTrainer(
            model=res_mlp,
            config=config,
            train_dataset=ds_train,
            test_dataset=ds_test,
            compute_metrics_fn=compute_accuracy_metrics
        )
        
        res_metrics = trainer_res.train()
        
        # Store results
        results[arch["name"]] = {
            # "mlp": mlp_metrics,
            "residual_mlp": res_metrics
        }
        
        print(f"\n--- Results Summary for {arch['name']} ---")
        # print(f"Standard MLP Test Accuracy: {mlp_metrics['accuracy']:.4f}")
        print(f"Residual MLP Test Accuracy: {res_metrics['accuracy']:.4f}")
    
    return results

# Run the comparison
results = compare_mlp_architectures()

### Exercise 1.3: Rinse and Repeat (but with a CNN)

Repeat the verification you did above, but with **Convolutional** Neural Networks. If you were careful about abstracting your model and training code, this should be a simple exercise. Show that **deeper** CNNs *without* residual connections do not always work better and **even deeper** ones *with* residual connections.

**Hint**: You probably should do this exercise using CIFAR-10, since MNIST is *very* easy (at least up to about 99% accuracy).

**Tip**: Feel free to reuse the ResNet building blocks defined in `torchvision.models.resnet` (e.g. [BasicBlock](https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py#L59) which handles the cascade of 3x3 convolutions, skip connections, and optional downsampling). This is an excellent exercise in code diving.

**Spoiler**: Depending on the optional exercises you plan to do below, you should think *very* carefully about the architectures of your CNNs here (so you can reuse them!).

In [None]:
import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    """
    A simple CNN with configurable depth for image classification.
    
    Args:
        in_channels: Number of input channels (1 for MNIST, 3 for CIFAR)
        num_classes: Number of output classes
        depth: Number of convolutional blocks
        initial_channels: Number of channels in first conv layer
    """
    
    def __init__(self, in_channels: int = 3, num_classes: int = 10, depth: int = 3, initial_channels: int = 16):
        super().__init__()
        
        self.depth = depth
        self.initial_channels = initial_channels
        self.in_channels = in_channels
        self.num_classes = num_classes
        
        # Build feature extractor
        layers = []
        current_channels = initial_channels
        
        # Stem layer
        layers.extend([
            nn.Conv2d(in_channels, current_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(current_channels),
            nn.ReLU(inplace=True)
        ])
        
        # Convolutional blocks with downsampling
        for i in range(depth):
            next_channels = current_channels * 2
            layers.extend([
                nn.Conv2d(current_channels, next_channels, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(next_channels),
                nn.ReLU(inplace=True)
            ])
            current_channels = next_channels
        
        self.features = nn.Sequential(*layers)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(current_channels, num_classes)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.global_pool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)
    
    def __str__(self) -> str:
        return f"simplecnn_ch{self.initial_channels}_d{self.depth}_c{self.num_classes}"


In [None]:
# La classe ResidualCNN è ora importata da src.models. Utilizza direttamente ResidualCNN per i modelli CNN con connessioni residue.

#### Training Loop with CIFAR10

In [None]:
def compare_cnn_architectures():
    """Compare SimpleCNN vs ResidualCNN on CIFAR-10."""
    
    # CIFAR-10 specific transforms
    def get_cifar_transforms(is_training: bool = True):
        if is_training:
            return transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=10),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
            ])
        else:
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
            ])
    
    # Override transforms for CIFAR-10
    def get_data_transforms_cifar(config, is_training=True):
        return get_cifar_transforms(is_training)
    
    # Load CIFAR-10 datasets  
    ds_train = CIFAR10(root='./data', train=True, download=True)
    ds_test = CIFAR10(root='./data', train=False, download=True)
    
    # Configuration for CIFAR-10
    config = TrainingConfig(
        num_epochs=60,
        batch_size=128,  # Smaller batch size for CIFAR-10
        learning_rate=1e-3,
        use_early_stopping=True,
        early_stopping_patience=10,
        validation_split=0.1,
        data_augmentation=[],  # We handle augmentation in transforms
        wandb_project="CNN_Comparison_CIFAR10"
    )
    
    depths = [2, 3, 4]
    results = {}
    
    for depth in depths:
        print(f"\n{'='*50}")
        print(f"Testing depth {depth}")
        print(f"{'='*50}")
        
        # Simple CNN
        simple_cnn = SimpleCNN(
            in_channels=3,
            num_classes=10,
            depth=depth,
            initial_channels=32
        )
        
        print(f"\n--- Simple CNN (depth {depth}) ---")
        print(f"Parameters: {sum(p.numel() for p in simple_cnn.parameters()):,}")
        
        # Temporarily override transform function for CIFAR-10
        import types
        trainer_simple = StreamlinedTrainer(
            model=simple_cnn,
            config=config,
            train_dataset=ds_train,
            test_dataset=ds_test,
            compute_metrics_fn=compute_accuracy_metrics
        )
        
        # Override the transform creation method for CIFAR-10
        trainer_simple._setup_data_loaders = types.MethodType(
            lambda self, train_dataset: self._setup_cifar_loaders(train_dataset), 
            trainer_simple
        )
        trainer_simple._setup_test_loader = types.MethodType(
            lambda self, test_dataset: self._setup_cifar_test_loader(test_dataset), 
            trainer_simple
        )
        
        # Add CIFAR-specific methods
        def _setup_cifar_loaders(self, train_dataset):
            targets = np.array(train_dataset.targets)
            train_indices, val_indices = train_test_split(
                np.arange(len(targets)), test_size=self.config.validation_split,
                stratify=targets, random_state=self.config.seed
            )
            
            train_subset = TransformedSubset(train_dataset, train_indices, get_cifar_transforms(True))
            val_subset = TransformedSubset(train_dataset, val_indices, get_cifar_transforms(False))
            
            train_loader = DataLoader(train_subset, batch_size=self.config.batch_size, 
                                    shuffle=True, num_workers=self.config.num_workers, 
                                    pin_memory=self.config.pin_memory)
            val_loader = DataLoader(val_subset, batch_size=self.config.batch_size, 
                                  shuffle=False, num_workers=self.config.num_workers, 
                                  pin_memory=self.config.pin_memory)
            
            print(f"Training set size: {len(train_subset)}")
            print(f"Validation set size: {len(val_subset)}")
            return train_loader, val_loader
        
        def _setup_cifar_test_loader(self, test_dataset):
            class TransformedDataset(Dataset):
                def __init__(self, dataset, transform):
                    self.dataset = dataset
                    self.transform = transform
                def __len__(self):
                    return len(self.dataset)
                def __getitem__(self, idx):
                    img, label = self.dataset[idx]
                    return self.transform(img), label
            
            transformed_test = TransformedDataset(test_dataset, get_cifar_transforms(False))
            return DataLoader(transformed_test, batch_size=self.config.batch_size, 
                            shuffle=False, num_workers=self.config.num_workers, 
                            pin_memory=self.config.pin_memory)
        
        trainer_simple._setup_cifar_loaders = types.MethodType(_setup_cifar_loaders, trainer_simple)
        trainer_simple._setup_cifar_test_loader = types.MethodType(_setup_cifar_test_loader, trainer_simple)
        
        # Re-setup dataloaders with CIFAR transforms
        trainer_simple.train_loader, trainer_simple.val_loader = trainer_simple._setup_cifar_loaders(ds_train)
        trainer_simple.test_loader = trainer_simple._setup_cifar_test_loader(ds_test)
        
        simple_metrics = trainer_simple.train()
        
        # Residual CNN
        residual_cnn = ResidualCNN(
            in_channels=3,
            num_classes=10,
            depth=depth,
            initial_channels=32
        )
        
        print(f"\n--- Residual CNN (depth {depth}) ---")
        print(f"Parameters: {sum(p.numel() for p in residual_cnn.parameters()):,}")
        
        trainer_res = StreamlinedTrainer(
            model=residual_cnn,
            config=config,
            train_dataset=ds_train,
            test_dataset=ds_test,
            compute_metrics_fn=compute_accuracy_metrics
        )
        
        # Apply same CIFAR-10 specific setup
        trainer_res._setup_cifar_loaders = types.MethodType(_setup_cifar_loaders, trainer_res)
        trainer_res._setup_cifar_test_loader = types.MethodType(_setup_cifar_test_loader, trainer_res)
        trainer_res.train_loader, trainer_res.val_loader = trainer_res._setup_cifar_loaders(ds_train)
        trainer_res.test_loader = trainer_res._setup_cifar_test_loader(ds_test)
        
        res_metrics = trainer_res.train()
        
        # Store results
        results[depth] = {
            "simple_cnn": simple_metrics,
            "residual_cnn": res_metrics
        }
        
        print(f"\n--- Results Summary for depth {depth} ---")
        print(f"Simple CNN Test Accuracy: {simple_metrics['accuracy']:.4f}")
        print(f"Residual CNN Test Accuracy: {res_metrics['accuracy']:.4f}")
    
    return results

# Run the comparison (uncomment to execute)
results = compare_cnn_architectures()

-----
## Exercise 2: Choose at Least One

Below are **three** exercises that ask you to deepen your understanding of Deep Networks for visual recognition. You must choose **at least one** of the below for your final submission -- feel free to do **more**, but at least **ONE** you must submit. Each exercise is designed to require you to dig your hands **deep** into the guts of your models in order to do new and interesting things.

**Note**: These exercises are designed to use your small, custom CNNs and small datasets. This is to keep training times reasonable. If you have a decent GPU, feel free to use pretrained ResNets and larger datasets (e.g. the [Imagenette](https://pytorch.org/vision/0.20/generated/torchvision.datasets.Imagenette.html#torchvision.datasets.Imagenette) dataset at 160px).

### Exercise 2.2: *Distill* the knowledge from a large model into a smaller one
In this exercise you will see if you can derive a *small* model that performs comparably to a larger one on CIFAR-10. To do this, you will use [Knowledge Distillation](https://arxiv.org/abs/1503.02531):

> Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the Knowledge in a Neural Network, NeurIPS 2015.

To do this:
1. Train one of your best-performing CNNs on CIFAR-10 from Exercise 1.3 above. This will be your **teacher** model.
2. Define a *smaller* variant with about half the number of parameters (change the width and/or depth of the network). Train it on CIFAR-10 and verify that it performs *worse* than your **teacher**. This small network will be your **student** model.
3. Train the **student** using a combination of **hard labels** from the CIFAR-10 training set (cross entropy loss) and **soft labels** from predictions of the **teacher** (Kulback-Leibler loss between teacher and student).

Try to optimize training parameters in order to maximize the performance of the student. It should at least outperform the student trained only on hard labels in Setp 2.

**Tip**: You can save the predictions of the trained teacher network on the training set and adapt your dataloader to provide them together with hard labels. This will **greatly** speed up training compared to performing a forward pass through the teacher for each batch of training.

## Exercise 2.2: Knowledge Distillation

In this exercise, we'll implement knowledge distillation to transfer knowledge from a large **teacher** model to a smaller **student** model.

**Steps:**
1. Train a ResidualCNN with depth 4 as the **teacher** model
2. Create a smaller **student** model with ~half the parameters  
3. Train the student using both hard labels and soft labels from the teacher

In [3]:
### Step 1: Train the Teacher Model (ResidualCNN depth=4)

cifar10_train = CIFAR10(root='./data', train=True, download=True)
cifar10_test = CIFAR10(root='./data', train=False, download=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create teacher model
teacher_model = ResidualCNN(depth=4, num_classes=10)
print(f"Teacher model parameters: {sum(p.numel() for p in teacher_model.parameters()):,}")

# Create configuration for teacher training
teacher_config = TrainingConfig(
    num_epochs=15,
    batch_size=128,
    learning_rate=0.001,
    weight_decay=1e-4,
    loss_fn=nn.CrossEntropyLoss(),
    device=device,
    use_wandb=True,
    wandb_project="dl-lab1-knowledge-distillation",
    use_early_stopping=True,
    early_stopping_patience=5,
    early_stopping_metric="accuracy",
    maximize_metric=True,
    output_dir="models/"
)

# Train teacher model
print("Training teacher model...")
teacher_trainer = StreamlinedTrainer(
    model=teacher_model,
    config=teacher_config,
    train_dataset=cifar10_train,
    test_dataset=cifar10_test
)

teacher_results = teacher_trainer.train()
print(f"\nTeacher model test accuracy: {teacher_results['accuracy']:.4f}")

# Save teacher model and store results
teacher_path = teacher_trainer.save_model("teacher")
teacher_model_state = teacher_model.state_dict().copy()
teacher_accuracy = teacher_results['accuracy']

Teacher model parameters: 2,793,114
Training teacher model...
Training set size: 45000
Validation set size: 5000


[34m[1mwandb[0m: Currently logged in as: [33mvincenzo-civale[0m ([33mvincenzo-civale-universi-degli-studi-di-firenze[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Starting training for 15 epochs
Device: cuda
Batch size: 128

Epoch 1/15


                                                                        

Train Loss: 1.5340, Val Metrics: {'loss': 1.2635868400335313, 'accuracy': 0.5462}
✓ Best model updated!

Epoch 2/15


                                                                        

Train Loss: 1.2023, Val Metrics: {'loss': 1.074976746737957, 'accuracy': 0.6298}
✓ Best model updated!

Epoch 3/15


                                                                        

Train Loss: 1.0197, Val Metrics: {'loss': 0.9488953307271004, 'accuracy': 0.6704}
✓ Best model updated!

Epoch 4/15


                                                                        

Train Loss: 0.9014, Val Metrics: {'loss': 0.7755203813314437, 'accuracy': 0.7304}
✓ Best model updated!

Epoch 5/15


                                                                        

Train Loss: 0.8267, Val Metrics: {'loss': 0.7082339152693748, 'accuracy': 0.7504}
✓ Best model updated!

Epoch 6/15


                                                                        

Train Loss: 0.7570, Val Metrics: {'loss': 0.6670993015170097, 'accuracy': 0.7724}
✓ Best model updated!

Epoch 7/15


                                                                        

Train Loss: 0.7132, Val Metrics: {'loss': 0.6120938830077648, 'accuracy': 0.7834}
✓ Best model updated!

Epoch 8/15


                                                                        

Train Loss: 0.6761, Val Metrics: {'loss': 0.6535413131117821, 'accuracy': 0.777}
No improvement for 1 epochs

Epoch 9/15


                                                                        

Train Loss: 0.6389, Val Metrics: {'loss': 0.5610129199922085, 'accuracy': 0.8074}
✓ Best model updated!

Epoch 10/15


                                                                        

Train Loss: 0.6075, Val Metrics: {'loss': 0.6785904765129089, 'accuracy': 0.7712}
No improvement for 1 epochs

Epoch 11/15


                                                                        

Train Loss: 0.5799, Val Metrics: {'loss': 0.5659338288009167, 'accuracy': 0.8034}
No improvement for 2 epochs

Epoch 12/15


                                                                        

Train Loss: 0.5631, Val Metrics: {'loss': 0.5993236765265465, 'accuracy': 0.8004}
No improvement for 3 epochs

Epoch 13/15


                                                                        

Train Loss: 0.5373, Val Metrics: {'loss': 0.5487134583294392, 'accuracy': 0.8142}
✓ Best model updated!

Epoch 14/15


                                                                        

Train Loss: 0.5211, Val Metrics: {'loss': 0.5354428112506866, 'accuracy': 0.8228}
✓ Best model updated!

Epoch 15/15


                                                                        

Train Loss: 0.5019, Val Metrics: {'loss': 0.5323229275643826, 'accuracy': 0.8208}
No improvement for 1 epochs
Loading best model weights


                                                                


Final Test Metrics: {'loss': 0.5461378218252447, 'accuracy': 0.8164}

Teacher model test accuracy: 0.8164
Model saved to models/rescnn_ch16_d4_c10_teacher.pth


In [4]:
### Step 2: Create and Train Student Model (smaller version)

# Create student model with ~half the parameters
student_model = ResidualCNN(depth=2,  num_classes=10)  # Reduced depth and width
print(f"Student model parameters: {sum(p.numel() for p in student_model.parameters()):,}")
print(f"Parameter reduction: {sum(p.numel() for p in student_model.parameters()) / sum(p.numel() for p in teacher_model.parameters()):.2%}")

# Train student model without distillation
student_config = TrainingConfig(
    num_epochs=15,
    batch_size=128,
    learning_rate=0.001,
    weight_decay=1e-4,
    loss_fn=nn.CrossEntropyLoss(),
    device=device,
    use_wandb=True,
    wandb_project="dl-lab1-knowledge-distillation",
    use_early_stopping=True,
    early_stopping_patience=5,
    early_stopping_metric="accuracy",
    maximize_metric=True,
    output_dir="models/"
)

print("Training student model (without distillation)...")
student_trainer = StreamlinedTrainer(
    model=student_model,
    config=student_config,
    train_dataset=cifar10_train,
    test_dataset=cifar10_test
)

student_baseline_results = student_trainer.train()
print(f"\nStudent baseline test accuracy: {student_baseline_results['accuracy']:.4f}")

# Save student baseline
student_baseline_path = student_trainer.save_model("student_baseline")
student_baseline_accuracy = student_baseline_results['accuracy']

Student model parameters: 165,914
Parameter reduction: 5.94%
Training student model (without distillation)...
Training set size: 45000
Validation set size: 5000


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
test/accuracy,▁
test/loss,▁
train/train_loss,██▆▆▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
val/accuracy,▁▃▄▆▆▇▇▇█▇█▇███
val/loss,█▆▅▃▃▂▂▂▁▂▁▂▁▁▁

0,1
test/accuracy,0.8164
test/loss,0.54614
train/train_loss,0.49944
val/accuracy,0.8208
val/loss,0.53232


Starting training for 15 epochs
Device: cuda
Batch size: 128

Epoch 1/15


                                                                        

Train Loss: 1.5143, Val Metrics: {'loss': 1.2407567590475082, 'accuracy': 0.5504}
✓ Best model updated!

Epoch 2/15


                                                                        

Train Loss: 1.1961, Val Metrics: {'loss': 1.1029451370239258, 'accuracy': 0.6088}
✓ Best model updated!

Epoch 3/15


                                                                        

Train Loss: 1.0453, Val Metrics: {'loss': 1.0837183237075805, 'accuracy': 0.613}
✓ Best model updated!

Epoch 4/15


                                                                        

Train Loss: 0.9569, Val Metrics: {'loss': 0.9742089167237282, 'accuracy': 0.6594}
✓ Best model updated!

Epoch 5/15


                                                                        

Train Loss: 0.8898, Val Metrics: {'loss': 0.8519542336463928, 'accuracy': 0.7024}
✓ Best model updated!

Epoch 6/15


                                                                        

Train Loss: 0.8233, Val Metrics: {'loss': 0.7638009175658226, 'accuracy': 0.733}
✓ Best model updated!

Epoch 7/15


                                                                        

Train Loss: 0.7851, Val Metrics: {'loss': 0.8307042688131332, 'accuracy': 0.7124}
No improvement for 1 epochs

Epoch 8/15


                                                                        

Train Loss: 0.7479, Val Metrics: {'loss': 0.7266128897666931, 'accuracy': 0.7528}
✓ Best model updated!

Epoch 9/15


                                                                        

Train Loss: 0.7125, Val Metrics: {'loss': 0.6902000844478607, 'accuracy': 0.7646}
✓ Best model updated!

Epoch 10/15


                                                                        

Train Loss: 0.6795, Val Metrics: {'loss': 0.86166403144598, 'accuracy': 0.7192}
No improvement for 1 epochs

Epoch 11/15


                                                                        

Train Loss: 0.6551, Val Metrics: {'loss': 0.694909043610096, 'accuracy': 0.766}
✓ Best model updated!

Epoch 12/15


                                                                        

Train Loss: 0.6367, Val Metrics: {'loss': 0.691471978276968, 'accuracy': 0.7664}
✓ Best model updated!

Epoch 13/15


                                                                        

Train Loss: 0.6202, Val Metrics: {'loss': 0.5978748992085456, 'accuracy': 0.7922}
✓ Best model updated!

Epoch 14/15


                                                                        

Train Loss: 0.6011, Val Metrics: {'loss': 0.6620557852089405, 'accuracy': 0.769}
No improvement for 1 epochs

Epoch 15/15


                                                                        

Train Loss: 0.5846, Val Metrics: {'loss': 0.625201053917408, 'accuracy': 0.7804}
No improvement for 2 epochs
Loading best model weights


                                                                


Final Test Metrics: {'loss': 0.6288939273055596, 'accuracy': 0.7833}

Student baseline test accuracy: 0.7833
Model saved to models/rescnn_ch16_d2_c10_student_baseline.pth


In [5]:
### Step 3: Knowledge Distillation Training

# Create a custom trainer for knowledge distillation
class KnowledgeDistillationTrainer(StreamlinedTrainer):
    def __init__(self, student_model, teacher_model, config, train_dataset, test_dataset, 
                 temperature=3.0, alpha=0.7, compute_metrics_fn=None):
        # Initialize parent with student model
        super().__init__(student_model, config, train_dataset, test_dataset, compute_metrics_fn)
        
        
        self.teacher_model = teacher_model
        self.teacher_model.to(config.device)
        self.teacher_model.eval()  # Teacher is always in eval mode
        
        # Distillation parameters
        self.temperature = temperature
        self.alpha = alpha  # Weight for distillation loss
        
        # Loss functions
        self.hard_loss_fn = nn.CrossEntropyLoss()
        self.soft_loss_fn = nn.KLDivLoss(reduction='batchmean')
        
    def train_epoch(self):
        self.model.train()
        self.teacher_model.eval()
        
        total_loss = 0.0
        total_hard_loss = 0.0
        total_soft_loss = 0.0
        num_batches = 0
        
        from tqdm import tqdm
        progress_bar = tqdm(self.train_loader, desc="Distillation Training", leave=False)
        
        for batch in progress_bar:
            inputs, labels = batch
            inputs, labels = inputs.to(self.config.device), labels.to(self.config.device)
            
            self.optimizer.zero_grad()
            
            # Forward pass through student
            student_outputs = self.model(inputs)
            
            # Forward pass through teacher (no gradients)
            with torch.no_grad():
                teacher_outputs = self.teacher_model(inputs)
            
            # Calculate losses using the distillation_loss function from models.py
            loss, hard_loss, soft_loss = distillation_loss(
                student_outputs, teacher_outputs, labels,
                temperature=self.temperature, alpha=self.alpha
            )
            
            loss.backward()
            self.optimizer.step()
            
            total_loss += loss.item()
            total_hard_loss += hard_loss.item()
            total_soft_loss += soft_loss.item()
            num_batches += 1
            self.global_step += 1
            
            avg_total_loss = total_loss / num_batches
            avg_hard_loss = total_hard_loss / num_batches
            avg_soft_loss = total_soft_loss / num_batches
            
            progress_bar.set_postfix({
                "total": f"{avg_total_loss:.4f}",
                "hard": f"{avg_hard_loss:.4f}", 
                "soft": f"{avg_soft_loss:.4f}"
            })
            
            if self.config.logging_steps > 0 and self.global_step % self.config.logging_steps == 0:
                self._log_metrics({
                    "train_loss": avg_total_loss,
                    "train_hard_loss": avg_hard_loss,
                    "train_soft_loss": avg_soft_loss
                }, "train/")
        
        return total_loss / num_batches

print("Knowledge Distillation Trainer created!")

Knowledge Distillation Trainer created!


In [7]:
# Create a fresh student model for distillation training
student_distilled = ResidualCNN(depth=2, num_classes=10)

# Load the trained teacher model
teacher_for_distillation = ResidualCNN(depth=4,num_classes=10)
teacher_for_distillation.load_state_dict(teacher_model_state)

# Configuration for distillation training
distillation_config = TrainingConfig(
    num_epochs=20,  # More epochs for distillation
    batch_size=128,
    learning_rate=0.001,
    weight_decay=1e-4,
    loss_fn=nn.CrossEntropyLoss(),  # This won't be used, but needed for parent class
    device=device,
    use_wandb=True,
    wandb_project="dl-lab1-knowledge-distillation",
    use_early_stopping=True,
    early_stopping_patience=7,
    early_stopping_metric="accuracy",
    maximize_metric=True,
    output_dir="models/"
)

# Train student with knowledge distillation
print("Training student model with knowledge distillation...")
distillation_trainer = KnowledgeDistillationTrainer(
    student_model=student_distilled,
    teacher_model=teacher_for_distillation,
    config=distillation_config,
    train_dataset=cifar10_train,
    test_dataset=cifar10_test,
    temperature=3.0,  # Temperature for softmax
    alpha=0.7  # Weight for distillation loss (0.7 distillation + 0.3 hard labels)
)

student_distilled_results = distillation_trainer.train()
print(f"\nStudent (distilled) test accuracy: {student_distilled_results['accuracy']:.4f}")

# Save distilled student model
distilled_path = distillation_trainer.save_model("student_distilled")
student_distilled_accuracy = student_distilled_results['accuracy']

Training student model with knowledge distillation...
Training set size: 45000
Validation set size: 5000


0,1
test/accuracy,▁
test/loss,▁
train/train_loss,█▇▇▅▅▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
val/accuracy,▁▃▃▄▅▆▆▇▇▆▇▇█▇█
val/loss,█▆▆▅▄▃▄▂▂▄▂▂▁▂▁

0,1
test/accuracy,0.7833
test/loss,0.62889
train/train_loss,0.58367
val/accuracy,0.7804
val/loss,0.6252


Starting training for 20 epochs
Device: cuda
Batch size: 128

Epoch 1/20


                                                              

NameError: name 'distillation_loss' is not defined

In [None]:
### Step 4: Compare Results

print("="*60)
print("KNOWLEDGE DISTILLATION RESULTS")
print("="*60)
print(f"Teacher Model (ResidualCNN depth=4):     {teacher_accuracy:.4f}")
print(f"Student Baseline (ResidualCNN depth=2):  {student_baseline_accuracy:.4f}")
print(f"Student Distilled (ResidualCNN depth=2): {student_distilled_accuracy:.4f}")
print("="*60)

improvement = student_distilled_accuracy - student_baseline_accuracy
print(f"Improvement from Knowledge Distillation: {improvement:.4f} ({improvement*100:.2f}%)")

# Calculate parameter reduction
teacher_params = sum(p.numel() for p in teacher_model.parameters())
student_params = sum(p.numel() for p in student_distilled.parameters())
print(f"Parameter reduction: {student_params/teacher_params:.2%} of teacher size")

# Visualize results
import matplotlib.pyplot as plt

models = ['Teacher\n(ResNet-4)', 'Student\nBaseline', 'Student\nDistilled']
accuracies = [teacher_accuracy, student_baseline_accuracy, student_distilled_accuracy]
colors = ['#2E8B57', '#CD5C5C', '#4169E1']

plt.figure(figsize=(10, 6))
bars = plt.bar(models, accuracies, color=colors, alpha=0.7, edgecolor='black', linewidth=1)

# Add value labels on bars
for bar, acc in zip(bars, accuracies):
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2., height + 0.002,
             f'{acc:.3f}', ha='center', va='bottom', fontweight='bold')

plt.ylabel('Test Accuracy', fontsize=12)
plt.title('Knowledge Distillation Results on CIFAR-10', fontsize=14, fontweight='bold')
plt.ylim(0, max(accuracies) + 0.05)
plt.grid(axis='y', alpha=0.3)

# Add parameter count annotations
param_counts = [teacher_params, student_params, student_params]
for i, (bar, params) in enumerate(zip(bars, param_counts)):
    plt.text(bar.get_x() + bar.get_width()/2., 0.02,
             f'{params:,}\nparams', ha='center', va='bottom', 
             fontsize=9, style='italic')

plt.tight_layout()
plt.show()

# Summary
if improvement > 0:
    print(f"\nKnowledge distillation improved student performance by {improvement:.4f}!")
    print(f"   The distilled student achieved {student_distilled_accuracy:.4f} vs {student_baseline_accuracy:.4f} baseline")
else:
    print(f"\nThe distilled student did not outperform the baseline.")
    print("   Try adjusting hyperparameters (temperature, alpha, learning rate, epochs)")
    
print(f"\n📊 Knowledge transfer efficiency: {(student_distilled_accuracy/teacher_accuracy)*100:.1f}%")
print(f"   (Student achieved {(student_distilled_accuracy/teacher_accuracy)*100:.1f}% of teacher performance with {student_params/teacher_params*100:.1f}% of parameters)")

### Knowledge Distillation Analysis

**What we implemented:**

1. **Teacher Model**: ResidualCNN with depth=4, width=64 (~larger model)
2. **Student Model**: ResidualCNN with depth=2, width=32 (~50% fewer parameters)
3. **Knowledge Distillation**: Combined loss using:
   - **Hard loss**: Cross-entropy with true labels (weight = 1-α)
   - **Soft loss**: KL-divergence with teacher predictions (weight = α)
   - **Temperature**: Softens probability distributions for better knowledge transfer

**Key Hyperparameters:**
- **Temperature (T=3.0)**: Controls softness of probability distributions
- **Alpha (α=0.7)**: Balance between distillation loss and hard labels
- **Extended training**: More epochs to allow proper knowledge transfer

**Expected Outcome:**
The distilled student should outperform the baseline student, demonstrating that soft labels from the teacher provide additional information beyond hard labels alone.

**Why Knowledge Distillation Works:**
- Teacher provides "dark knowledge" - information about wrong classes
- Soft targets reveal similarities between classes
- Student learns from teacher's feature representations
- More efficient than training large models from scratch