## 3. Training a NN

In this part of the tutorial, we'll train a simple MLP on MNIST.
We'll go over the basic components for the training part:
- Datasets
- DataLoaders
- The basic training loop

In [None]:
import torch
from torch import nn

# tqdm adds support for progress bars
from tqdm import tqdm

# torchvision provides support for computer vision (datasets, transformations, models,...)
import torchvision
from torchvision import transforms as T

# torcheval provides support for evaluation metrics
from torcheval import metrics

# timm is a HuggingFace library providing a large collection of pre-trained image models (>> torchvision)
import timm

# extra libraries
from PIL import Image
import matplotlib.pyplot as plt
import os

We'll first download the MNIST dataset and then define the model and the training loop.

In [None]:
mnist_train = torchvision.datasets.MNIST(root="data", train=True, download=True)
mnist_test = torchvision.datasets.MNIST(root="data", train=False, download=True)

Datasets to be used in PyTorch should be subclasses of `torch.utils.data.Dataset`.

In [None]:
isinstance(mnist_train, torch.utils.data.Dataset)

The `Dataset` provides the basic interface to access the data and defines data augmentation/preprocessing steps:
- the `transform` attribute provides the information about the preprocessing steps to be applied to the data
- `__len__` should return the size of the dataset
- `__getitem__` should return the item at the given index **after applying preprocessing**

![](img/dataset.png)
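
As a rough illustration of this interface, here is a minimal sketch of a custom dataset. The class `SquaresDataset` and its contents are hypothetical and exist only for this example. The `transform` handling follows the convention used by torchvision datasets; `Dataset` itself does not enforce it.

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: item i is the pair (i, i**2)."""

    def __init__(self, size, transform=None):
        self.size = size
        self.transform = transform  # optional callable applied to the input

    def __len__(self):
        # Number of items in the dataset
        return self.size

    def __getitem__(self, index):
        if not 0 <= index < self.size:
            raise IndexError(index)
        x = torch.tensor(float(index))
        y = index ** 2
        if self.transform is not None:
            x = self.transform(x)  # preprocessing is applied on access
        return x, y


plain = SquaresDataset(5)
scaled = SquaresDataset(5, transform=lambda x: x / 10)

print(len(plain))  # 5
print(plain[3])    # input 3.0 with label 9
print(scaled[3])   # same item, but the input went through the transform
```

Because the transform runs inside `__getitem__`, swapping it changes what every later access returns without touching the stored data.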

In [None]:
mnist_train.transform is None

Right now, since no transform is set, each item is a raw PIL image paired with its integer label.

In [None]:
img, label = mnist_train[0]
img = img.copy()

img, label

In [None]:
plt.imshow(img, cmap="gray")

We can attach some basic transformations to the dataset so that they are applied every time an item is accessed:

- `ToTensor` converts the image to a tensor and scales it to the 0-1 range. This is a **preprocessing step** since it's deterministic.
  - We could just as well convert the whole dataset to tensors beforehand. **Question**: What prevents us from doing this?
- `RandomAffine` applies a random affine transformation (rotation + translation + scaling) to the image. This is a **data augmentation step** since it's non-deterministic and we want to apply it on the fly.

In [None]:
transforms = T.Compose([
    T.ToTensor(),
    T.RandomAffine(degrees=10, scale=(0.9, 1.1)),
])

mnist_train.transform = transforms

In [None]:
img2, label = mnist_train[0]

# plot the original and the transformed image side by side
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].imshow(img, cmap="gray")
axs[0].set_title("Original")
axs[1].imshow(img2.squeeze(0).numpy(), cmap="gray")  # drop the channel dimension: (1, 28, 28) -> (28, 28)
axs[1].set_title("Transformed")

We also need to add the `ToTensor` transformation to the test dataset, since the images must be converted to tensors before being fed to the model.

In [None]:
mnist_test.transform = T.ToTensor()

To train a model using SGD, we need to pack the data into batches. This is done by wrapping the dataset in a `torch.utils.data.DataLoader`.

In [None]:
trainloader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(mnist_test, batch_size=64, shuffle=False)


A DataLoader implements the iterator pattern in a **lazy way** (batches are assembled on demand while iterating):
- We cannot access the batches directly (e.g., `trainloader[0]` does not work)
- We obtain the batches by iterating over the DataLoader or by calling `next` on `iter(dataloader)`
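
A small sketch of this behaviour, using a hypothetical toy `TensorDataset` instead of MNIST so that it runs without any download:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 10 samples with 3 features each, plus integer labels
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

loader = DataLoader(dataset, batch_size=4, shuffle=False)

# len() gives the number of batches: ceil(10 / 4) = 3
print(len(loader))

# Indexing is not supported ...
try:
    loader[0]
except TypeError as err:
    print("TypeError:", err)

# ... but we can pull a single batch from a fresh iterator
first_x, first_y = next(iter(loader))
print(first_x.shape, first_y)

# A full pass yields batches of size 4, 4 and 2 (the last one is smaller)
batch_sizes = [x.size(0) for x, _ in loader]
print(batch_sizes)
```

Every call to `iter(loader)` starts a new pass over the dataset, so with `shuffle=True` the batches would be different each time.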

In [None]:
for i, batch in enumerate(trainloader):
    print(f"Batch {i+1}", batch[0].shape, batch[1].shape)
    

Let's now define the model. We'll be using a simple MLP with 3 hidden layers, ReLU activations, and batch normalization.

In [None]:
mlp = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 16),
    nn.ReLU(),
    nn.BatchNorm1d(16),
    nn.Linear(16, 16),
    nn.ReLU(),
    nn.BatchNorm1d(16),
    nn.Linear(16, 16),
    nn.ReLU(),
    nn.BatchNorm1d(16),
    nn.Linear(16, 10)
)

Additionally, our model needs an **optimizer** to update the weights. We'll use the `torch.optim.SGD` optimizer.

In [None]:
optimizer = torch.optim.SGD(mlp.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001)

We can now build the training loop.

We define a number of epochs for training. In each epoch, we iterate over the trainloader again, which reshuffles the data into new batches and re-applies the random transformations to the images.

```python
for epoch in range(epochs):
    for data, labels in trainloader:
        ...
```

Within the inner loop, we will need to do the following things:
- Zero out the gradients from the previous iteration
- Do the forward pass
- Compute the loss
- Do the backward pass
- Update the weights



In [None]:
epochs = 3

In [None]:
mlp.train() # IMPORTANT: this sets the model to training mode --- useful for batchnorm and dropout

for epoch in range(epochs):
    for data, labels in trainloader:
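        # inner loop
        optimizer.zero_grad()  # zero out the gradients from the previous iteration
        predictions = mlp(data)  # forward pass
        loss = nn.functional.cross_entropy(predictions, labels)  # compute the loss
        loss.backward()  # backward pass
        optimizer.step()  # update the weights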

It is often useful to compute some other metrics while the model is training.

We can use `torcheval` functionalities to keep track of accuracy and loss.

In addition, we can use `tqdm` to display a progress bar.

We will also anneal the learning rate at the end of each epoch.

In [None]:
mlp.train() # IMPORTANT: this sets the model to training mode --- useful for batchnorm and dropout

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1) # multiplies the LR by 0.1 each time scheduler.step() is called

for epoch in range(epochs):
    accuracy_counter = metrics.MulticlassAccuracy()
    loss_counter = metrics.Mean()

    progress_bar = tqdm(trainloader, desc=f"Epoch {epoch+1}/{epochs}")
    for data, labels in progress_bar:
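        optimizer.zero_grad()
        predictions = mlp(data)
        loss = nn.functional.cross_entropy(predictions, labels)
        loss.backward()
        optimizer.step()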

        accuracy_counter.update(predictions, labels)
        loss_counter.update(loss, weight=data.size(0))

        progress_bar.set_postfix(
            loss=loss_counter.compute().item(),
            accuracy=accuracy_counter.compute().item()
        )
    scheduler.step() # anneal LR by 0.1
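
To see what the scheduler does to the learning rate, the sketch below records the value in use at the start of each epoch. The one-layer model and the single dummy update per epoch are assumptions made only to keep the example short.

```python
import torch
from torch import nn

# Hypothetical one-layer model, used only to give the optimizer some parameters
model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

lrs = []
for epoch in range(3):
    lrs.append(scheduler.get_last_lr()[0])  # LR used during this epoch

    # Stand-in for one epoch of training: a single dummy update
    optimizer.zero_grad()
    model(torch.randn(4, 2)).sum().backward()
    optimizer.step()

    scheduler.step()  # multiply the LR by gamma

print(lrs)  # approximately [0.01, 0.001, 0.0001]
```

With `step_size=1` and `gamma=0.1`, three epochs are enough to shrink the learning rate by a factor of 100, which is worth keeping in mind when training for longer.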


    


Test evaluation is done in a similar way to the training loop.

In [None]:
mlp.eval() # IMPORTANT: this sets the model to evaluation mode --- useful for batchnorm and dropout

accuracy_counter = metrics.MulticlassAccuracy()
loss_counter = metrics.Mean()

progress_bar = tqdm(testloader, desc="Eval 1/1")
for data, labels in progress_bar:
    with torch.no_grad(): # force no gradients
        predictions = mlp(data)
        loss = nn.functional.cross_entropy(predictions, labels)

        accuracy_counter.update(predictions, labels)
        loss_counter.update(loss, weight=data.size(0))

        progress_bar.set_postfix(
            loss=loss_counter.compute().item(),
            accuracy=accuracy_counter.compute().item()
        )
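
The effect of switching modes can be seen on a single batch normalization layer. This is a minimal sketch on random data: in training mode the layer normalizes with the statistics of the current batch, while in evaluation mode it uses the running statistics accumulated so far. It also shows that `torch.no_grad()` is a separate switch that only disables gradient tracking.

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4) * 3 + 5  # batch with mean around 5 and std around 3

bn.train()
out_train = bn(x)  # normalized with the statistics of this batch
bn.eval()
out_eval = bn(x)   # normalized with the running statistics

print(out_train.mean(dim=0))  # close to 0 for every feature
print(out_eval.mean(dim=0))   # far from 0: the running stats have seen only one batch

# eval() does not disable autograd; no_grad() does
with torch.no_grad():
    out_no_grad = bn(x)
print(out_eval.requires_grad, out_no_grad.requires_grad)
```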

Let's bundle all the components together into functions to train and evaluate the model.

In [None]:
def train(model, optimizer, scheduler, trainloader, epochs):
    model.train()
    for epoch in range(epochs):
        accuracy_counter = metrics.MulticlassAccuracy()
        loss_counter = metrics.Mean()

        progress_bar = tqdm(trainloader, desc=f"Epoch {epoch+1}/{epochs}")
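        for data, labels in progress_bar:
            optimizer.zero_grad()
            predictions = model(data)
            loss = nn.functional.cross_entropy(predictions, labels)
            loss.backward()
            optimizer.step()

            accuracy_counter.update(predictions, labels)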
            loss_counter.update(loss, weight=data.size(0))

            progress_bar.set_postfix(
                loss=loss_counter.compute().item(),
                accuracy=accuracy_counter.compute().item()
            )
        scheduler.step()

def evaluate(model, testloader):
    model.eval()
    accuracy_counter = metrics.MulticlassAccuracy()
    loss_counter = metrics.Mean()

    progress_bar = tqdm(testloader, desc="Eval 1/1")
    for data, labels in progress_bar:
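        with torch.no_grad():
            predictions = model(data)
            loss = nn.functional.cross_entropy(predictions, labels)

            accuracy_counter.update(predictions, labels)
            loss_counter.update(loss, weight=data.size(0))

            progress_bar.set_postfix(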
                loss=loss_counter.compute().item(),
                accuracy=accuracy_counter.compute().item()
            )

### Saving and loading a model

We can save the model using `torch.save(model.state_dict(), save_path)` and load it using `model.load_state_dict(torch.load(save_path))`.

The `state_dict` is a dictionary mapping names to the model's parameters and buffers (e.g., the BatchNorm running statistics).
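
A minimal round-trip sketch, assuming a hypothetical tiny model and a temporary directory so that nothing is left on disk: the weights are saved, loaded into a freshly built copy of the same architecture, and the two models are checked to give the same output.

```python
import os
import tempfile

import torch
from torch import nn


def make_model():
    # Hypothetical tiny model standing in for the MLP above
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.BatchNorm1d(8), nn.Linear(8, 2))


model = make_model()
print(list(model.state_dict().keys()))  # parameters and BatchNorm buffers

with tempfile.TemporaryDirectory() as tmp_dir:
    save_path = os.path.join(tmp_dir, "model.pth")
    torch.save(model.state_dict(), save_path)

    # The architecture must be rebuilt first: the file only holds tensors
    restored = make_model()
    restored.load_state_dict(torch.load(save_path))

x = torch.randn(3, 4)
model.eval()
restored.eval()
same = torch.allclose(model(x), restored(x))
print(same)
```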


In [None]:
mlp.state_dict()

In [None]:
save_path = "weights/mlp_mnist.pth"

os.makedirs(os.path.dirname(save_path), exist_ok=True)

torch.save(mlp.state_dict(), save_path)

In [None]:
mlp.load_state_dict(torch.load(save_path))

Note that optimizers, schedulers, and metrics also have a `state_dict`:

In [None]:
optimizer.state_dict()

In [None]:
scheduler.state_dict()

In [None]:
accuracy_counter.state_dict()

#### Saving checkpoints

You can also save checkpoints during training. This is useful in case the training is interrupted and you want to resume from the last checkpoint.

The checkpoint should contain all the information needed to resume training:
- The epoch number
  - In case you're saving the checkpoint mid-epoch, you must save the iteration number as well, and also the partial metrics
- The model's state_dict
- The optimizer's state_dict
- If you're using a learning rate scheduler, you should save its state_dict as well
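
The list above can be sketched as a pair of helper functions. The names `save_checkpoint` and `load_checkpoint`, the dictionary keys, and the tiny model are all assumptions made for this example, and only the end-of-epoch case is covered (no iteration number or partial metrics).

```python
import os
import tempfile

import torch
from torch import nn


def build():
    # Hypothetical model/optimizer/scheduler trio for the example
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
    return model, optimizer, scheduler


def save_checkpoint(path, epoch, model, optimizer, scheduler):
    torch.save(
        {
            "epoch": epoch,  # last completed epoch
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        },
        path,
    )


def load_checkpoint(path, model, optimizer, scheduler):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    scheduler.load_state_dict(checkpoint["scheduler"])
    return checkpoint["epoch"] + 1  # epoch to resume from


model, optimizer, scheduler = build()

# Stand-in for one finished epoch
optimizer.zero_grad()
model(torch.randn(8, 4)).sum().backward()
optimizer.step()
scheduler.step()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoint.pth")
    save_checkpoint(path, 0, model, optimizer, scheduler)

    # Simulate a restart: rebuild everything, then restore the saved state
    new_model, new_optimizer, new_scheduler = build()
    start_epoch = load_checkpoint(path, new_model, new_optimizer, new_scheduler)

print(start_epoch, new_scheduler.get_last_lr())
```

After restoring, the training loop would run over `range(start_epoch, epochs)` instead of starting from zero.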

### Transfer learning

We can do transfer learning by loading a pretrained model and changing the last layer to match the number of classes in our dataset.

`timm` and `torchvision` provide a number of pretrained models that can be used for transfer learning.

With `timm.create_model`, passing `pretrained=True` loads the pretrained weights, and the `num_classes` argument replaces the classification head with one that matches the number of classes in our dataset.

In [None]:
model_pretrained = timm.create_model("resnet18", pretrained=True, num_classes=10)

model_pretrained

In [None]:
model_scratch = timm.create_model("resnet18", pretrained=False, num_classes=10)

We additionally need to modify our dataset, since ResNet18 expects a 3-channel input while MNIST images have a single channel.

In [None]:
mnist_train.transform = T.Compose([
    T.ToTensor(),
    T.RandomAffine(degrees=10, scale=(0.9, 1.1)),
    T.Lambda(lambda x: x.repeat(3, 1, 1)), # repeat the single channel 3 times to get 3 channels
])

mnist_test.transform = T.Compose([
    T.ToTensor(),
    T.Lambda(lambda x: x.repeat(3, 1, 1)),
])

In [None]:
num_epochs = 1

optimizer_pretrained = torch.optim.SGD(model_pretrained.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001)
scheduler_pretrained = torch.optim.lr_scheduler.StepLR(optimizer_pretrained, step_size=1, gamma=0.1)

train(model_pretrained, optimizer_pretrained, scheduler_pretrained, trainloader, num_epochs)

optimizer_scratch = torch.optim.SGD(model_scratch.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001)
scheduler_scratch = torch.optim.lr_scheduler.StepLR(optimizer_scratch, step_size=1, gamma=0.1)

train(model_scratch, optimizer_scratch, scheduler_scratch, trainloader, num_epochs)

We can additionally freeze the hidden layers of the pretrained model by setting `requires_grad=False` for their parameters, so that only the last layer is trained.

In [None]:
for param in model_pretrained.parameters():
    param.requires_grad = False

# Unfreeze the last layer
for param in model_pretrained.fc.parameters():
    param.requires_grad = True

# you can proceed to retrain the model. Note: you have to reload the parameters!
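
A quick way to check that freezing worked is to count the trainable parameters and confirm that one optimizer step changes only the last layer. The sketch below uses a hypothetical `TinyNet` with a `backbone` and an `fc` head in place of the pretrained ResNet, so it runs without downloading weights. Passing only the trainable parameters to the optimizer is one possible design choice, not a requirement.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a pretrained network with a final `fc` layer."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        return self.fc(self.backbone(x))


model = TinyNet()

# Freeze everything, then unfreeze the last layer
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True

trainable = [p for p in model.parameters() if p.requires_grad]
n_trainable = sum(p.numel() for p in trainable)
n_total = sum(p.numel() for p in model.parameters())
print(f"trainable parameters: {n_trainable} / {n_total}")

# Only the trainable parameters are handed to the optimizer
optimizer = torch.optim.SGD(trainable, lr=0.01, momentum=0.9)

backbone_before = model.backbone[0].weight.clone()
fc_before = model.fc.weight.clone()

loss = nn.functional.cross_entropy(model(torch.randn(4, 8)), torch.tensor([0, 1, 2, 3]))
optimizer.zero_grad()
loss.backward()
optimizer.step()

backbone_unchanged = torch.equal(backbone_before, model.backbone[0].weight)
head_changed = not torch.equal(fc_before, model.fc.weight)
print(backbone_unchanged, head_changed)
```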

### Using acceleration

We can use the GPU to accelerate the training process.

We can move the model and the data between the CPU and the GPU using the `to` method.

`to` accepts a device argument such as `cuda` or `cpu`.

If multiple GPUs are available, we can select a specific one by adding the device index (e.g., `cuda:0` for the first GPU).

In [None]:
x = torch.randn(1, 3, 224, 224)

if torch.cuda.is_available():
    x = x.to("cuda")
else:
    raise RuntimeError("CUDA not available!")

To train or evaluate a model on CUDA, we need to move the following to the GPU:
- The model (`model.to(device)`)
- The data (`data = data.to(device)`)
- The labels (`labels = labels.to(device)`)
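
The three moves above can be sketched as a single device-agnostic training step. The tiny model and the random batch are hypothetical stand-ins for the MLP and MNIST. The `device` selection falls back to the CPU, so the same code also runs on machines without a GPU.

```python
import torch
from torch import nn

# Fall back to the CPU when no GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical tiny model and random batch, standing in for the MLP and MNIST
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

data = torch.randn(16, 1, 28, 28)
labels = torch.randint(0, 10, (16,))

# Model, data and labels must live on the same device
data, labels = data.to(device), labels.to(device)

optimizer.zero_grad()
predictions = model(data)
loss = nn.functional.cross_entropy(predictions, labels)
loss.backward()
optimizer.step()

print(device, predictions.shape, loss.item())
```

Note that `model.to(device)` moves the module's parameters in place, whereas `tensor.to(device)` returns a new tensor, which is why the data and labels are reassigned.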

In addition, for single-node multi-GPU training, we can use `torch.nn.DataParallel` to parallelize the training process.

We can wrap the model using `torch.nn.DataParallel(model)`, and PyTorch will automatically split each batch across the available GPUs.

```python
model = nn.DataParallel(model)
```
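
In practice the wrapping is often made conditional, so the same script runs on one GPU, several GPUs, or the CPU. The following is a sketch under that assumption, with a hypothetical small model. It also shows that the original network remains reachable through the wrapper's `module` attribute.

```python
import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical small model, used only for illustration
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))

# Wrap the model only when more than one GPU is visible
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)

batch = torch.randn(8, 32, device=device)
output = model(batch)  # with DataParallel, the batch is split along dim 0
print(output.shape)

# The wrapped network stays reachable through `.module`,
# e.g. to save a state_dict without the "module." key prefix
inner = model.module if isinstance(model, nn.DataParallel) else model
print(list(inner.state_dict().keys())[:2])
```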

#### Using mixed precision

We can use mixed precision (`FLOAT16/FLOAT32`) training to speed up the training process.

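We can use the `torch.cuda.amp` module to enable mixed precision training (recent PyTorch releases expose the same functionality as `torch.amp.autocast("cuda")` and `torch.amp.GradScaler("cuda")`).

The `torch.cuda.amp.autocast` context manager automatically runs eligible operations (e.g., matrix multiplications and convolutions) in `FLOAT16`, while keeping precision-sensitive ones in `FLOAT32`.

The `torch.cuda.amp.GradScaler` scales the loss before the backward pass so that small gradients do not underflow to zero in `FLOAT16`; the gradients are unscaled again before the optimizer step.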

```python
def train_mixed_precision(model, optimizer, scheduler, trainloader, epochs, device):
    model.to(device)
    model.train()
    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(epochs):
        accuracy_counter = metrics.MulticlassAccuracy().to(device)
        loss_counter = metrics.Mean().to(device)

        progress_bar = tqdm(trainloader, desc=f"Epoch {epoch+1}/{epochs}")
        for data, labels in progress_bar:
            data, labels = data.to(device), labels.to(device)
            optimizer.zero_grad()

            with torch.cuda.amp.autocast():
                predictions = model(data)
                loss = nn.functional.cross_entropy(predictions, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            accuracy_counter.update(predictions, labels)
            loss_counter.update(loss, weight=data.size(0))

            progress_bar.set_postfix(
                loss=loss_counter.compute().item(),
                accuracy=accuracy_counter.compute().item()
            )
        scheduler.step()
```