# GPU training and model fine-tuning

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch

import torch.nn as nn
from torch.utils.data import random_split, DataLoader

from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

## CUDA

PyTorch relies on CUDA for Nvidia GPUs to perform fast tensor operations.
FYI, there is also ROCm support for AMD GPUs, but I've never tried it, so I don't know how well it works.

In order to train your model on an Nvidia GPU, you need to:
- Install the required CUDA drivers (check https://pytorch.org/ to see which versions are supported by the latest stable PyTorch release)
- Install PyTorch with CUDA support (same link)
- Move tensors (and the model) to the GPU using `tensor.to("cuda")` or `tensor.cuda()`

In [None]:
# Check whether PyTorch can use a CUDA device on this machine
torch.cuda.is_available()

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Seed NumPy's and PyTorch's global random number generators for reproducibility
np.random.seed(42)
torch.manual_seed(42)

In [None]:
# hyperparameters
train_val_split = [0.8, 0.2]  # fractions of the original training set: [train, validation]
n_epochs = 100                # maximum number of epochs (early stopping may end training sooner)
lr = 1e-3                     # learning rate passed to the optimizer
batch_size = 16               # mini-batch size of the training DataLoader

In [None]:
datapath = 'data'

# Load MNIST, converting every image to a tensor with ToTensor().
# download=True fetches the files into `datapath` when they are missing and
# reuses them when they are already there, so the cell also runs on a fresh
# checkout instead of failing with "Dataset not found".
data_train = MNIST(
    root=datapath,
    train=True,
    transform=ToTensor(),
    download=True,
)
data_test = MNIST(
    root=datapath,
    train=False,
    transform=ToTensor(),
    download=True,
)

# Hold out a fraction of the training data for validation (see train_val_split).
# The split draws from its own generator object rather than the global RNG.
data_train, data_val = random_split(data_train, train_val_split, generator=torch.Generator())

### Validation and hyperparameter tuning

Once we have assembled our architecture, there are still many aspects that we can fine-tune in order to make it work better.
For example, we can change the batch size, learning rate, number of training epochs, or other application-specific hyperparameters.

Ideally, we would like to choose the parameters that have the best generalization capability (e.g., we'd like to find a suitable number of epochs so that the model gets properly trained without overfitting).

However, we cannot use our test dataset for that. I know it is tempting to do so (and a lot of people out there are doing it). But if you choose the hyperparameters that yield the best performance on the test data, you're essentially using the test data to "train" your model.

REMEMBER: You should use your test dataset only to assess your final model. Never touch the test dataset before the model is ready-to-deploy.

So, how do we fine-tune our model to reach the best generalization capability?
The typical solution is to further split the training data into a training and a validation dataset.
The validation set is used only to fine-tune hyperparameters.

In [None]:
# Pinned (page-locked) host memory only helps host-to-GPU copies, so request it
# only when we actually run on CUDA; on a CPU-only machine pin_memory=True
# brings no benefit and makes PyTorch emit a warning.
use_pinned_memory = device.type == "cuda"

# Only the training loader shuffles: evaluation does not depend on sample order.
train_loader = DataLoader(
    data_train,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=use_pinned_memory,
    num_workers=2,
)

val_loader = DataLoader(
    data_val,
    batch_size=32,
    shuffle=False,
    pin_memory=use_pinned_memory,
    num_workers=2,
)

test_loader = DataLoader(
    data_test,
    batch_size=32,
    shuffle=False,
    pin_memory=use_pinned_memory,
    num_workers=2,
)

In [None]:
# Report how many samples ended up in each split
for caption, dataset in (
    ("Size of training dataset:", data_train),
    ("Size of validation dataset:", data_val),
    ("Size of test dataset:", data_test),
):
    print(caption, len(dataset))

In [None]:
class SimpleMLP(nn.Module):
    """Fully connected classifier with two hidden ReLU layers (128 and 32 units).

    Args:
        input_shape: shape of a single input data point, e.g. ``(1, 28, 28)``;
            every sample is flattened to ``prod(input_shape)`` features.
        n_classes: number of output classes (size of the logits vector).
    """

    def __init__(self, input_shape, n_classes):
        super().__init__()
        self.input_shape = np.asarray(input_shape)
        self.n_classes = n_classes
        # Compute the flattened size once, as a plain Python int, instead of
        # passing a NumPy scalar to nn.Linear and recomputing it on every forward.
        self.n_features = int(self.input_shape.prod())
        self.seq_model = nn.Sequential(
            nn.Linear(self.n_features, 128),
            nn.ReLU(),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, self.n_classes)
        )

    def forward(self, x):
        """Return the unnormalized class scores (logits), one row per sample."""
        # Flatten to (batch_size, channels*height*width). reshape, unlike view,
        # also accepts non-contiguous inputs (e.g. transposed tensors).
        x = x.reshape(-1, self.n_features)
        logits = self.seq_model(x)
        return logits
    

In [None]:
# Build the MLP for 1x28x28 inputs and 10 classes, then move its parameters to `device`
model = SimpleMLP(input_shape=(1, 28, 28), n_classes=10)
model = model.to(device)
print(model)

In [None]:
def model_accuracy(data_loader, net=None, target_device=None):
    """Return the fraction of samples in ``data_loader`` that are classified correctly.

    Args:
        data_loader: iterable yielding ``(inputs, labels)`` batches.
        net: network to evaluate; defaults to the notebook-level ``model``.
        target_device: device the batches are moved to before the forward
            pass; defaults to the notebook-level ``device``.

    Returns:
        The accuracy as a float between 0 and 1.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    net = model if net is None else net
    target_device = device if target_device is None else target_device

    n_total = 0
    n_correct = 0

    # Evaluate in eval mode, then restore the previous mode so that calling
    # this function in the middle of training does not change how training
    # continues.
    was_training = net.training
    net.eval()
    try:
        # No backward pass is needed here: disabling autograd avoids building
        # a computation graph for every batch.
        with torch.no_grad():
            for x_batch, y_batch in data_loader:
                x_batch = x_batch.to(target_device)
                y_batch = y_batch.to(target_device)
                logits_batch = net(x_batch)  # model's output scores
                n_total += len(y_batch)
                # Tensor.sum() reduces on the device; Python's builtin sum()
                # would iterate over the tensor element by element.
                n_correct += (logits_batch.argmax(dim=-1) == y_batch).sum().item()
    finally:
        net.train(was_training)

    if n_total == 0:
        raise ValueError("data_loader yielded no samples; accuracy is undefined")
    return n_correct / n_total

# Baseline accuracy of the freshly initialised model. Nothing is backpropagated
# here, so turn autograd off instead of tracking gradients for every batch.
with torch.no_grad():
    print(f"Train accuracy before training: {model_accuracy(train_loader):.4f}")
    print(f"Test accuracy before training: {model_accuracy(test_loader):.4f}")

In [None]:
# Adam updates every parameter of `model` with learning rate `lr`;
# the cross-entropy loss is applied to the model's output logits during training.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()

### Overfitting and early stopping

Overfitting is one of the main problems that you will encounter in your deep learning journey.
Essentially, we say that a model is "overfitting" when it achieves good performance on the training data, but it does not generalize to the validation and test data.

When this happens, there are many countermeasures that you can take, for example:
- Decrease the number of parameters of your model
- Insert Dropout layers in your architecture
- Use regularization techniques (L1-L2 regularization, batch normalization, weight normalization)
- Reduce the number of epochs

In particular, a common approach for the last point is to use early stopping.
Just as the name says, early stopping consists in interrupting the training process if the model stops improving based on some criterion (typically if the validation loss or accuracy did not improve).

Parameters of the early stopping are typically:
- `patience`, i.e., how many consecutive epochs without improvement to wait before deciding to stop training
- `min_delta`, a minimum difference between the last recorded best value and the new best value to consider it an improvement (we don't use it in this example)

In [None]:
# Define the early stopping parameters
patience = 5
best_acc_val = 0.0
early_stop_counter = 0

In [None]:
import os

accuracies_train = []
accuracies_val = []

# torch.save does not create missing directories: make sure the checkpoint
# folder exists before the first save.
os.makedirs("saved_models", exist_ok=True)
best_checkpoint = None  # path of the checkpoint with the best validation accuracy so far

for epoch in range(n_epochs):

    # make sure the model is in training mode for the optimisation pass
    model.train()
    for x_batch, y_batch in train_loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        optimizer.zero_grad()

        logits_batch = model(x_batch)
        loss_batch = loss_fn(logits_batch, y_batch)
        loss_batch.backward()

        optimizer.step()

    # evaluate the model at the end of each epoch (no gradients needed)
    with torch.no_grad():
        acc_train = model_accuracy(train_loader)
        acc_val = model_accuracy(val_loader)

    print(f"[Epoch {epoch+1:03d}] train_acc: {acc_train:.3f}, val_acc: {acc_val:.3f}")

    accuracies_train.append(acc_train)
    accuracies_val.append(acc_val)

    if acc_val > best_acc_val:
        best_acc_val = acc_val
        early_stop_counter = 0
        # save model only if accuracy improved
        best_checkpoint = f"saved_models/MLP_GPU_epoch{epoch+1:03d}.pt"
        torch.save(model.state_dict(), best_checkpoint)
    else:
        early_stop_counter += 1

    if early_stop_counter >= patience:
        print(f"Validation accuracy has not improved in {patience} epochs: stop training.")
        break

# When training stops, the weights in memory are those of the last epoch, which
# may already be past the best validation accuracy. Restore the best
# checkpoint so that the final evaluation uses the model that early stopping
# actually selected.
if best_checkpoint is not None:
    model.load_state_dict(torch.load(best_checkpoint, map_location=device))


In [None]:
# Plot the per-epoch training and validation accuracy recorded during training
plt.figure()
plt.plot(accuracies_train, '^-', label="Training")
plt.plot(accuracies_val, 's-', label="Validation")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.ylim(0.8, 1.05)  # fixed y-range: 0.8 to 1.05
plt.grid(linestyle=':')
plt.legend()
plt.show()

In [None]:
# Final report. Nothing is backpropagated here, so turn autograd off instead
# of tracking gradients for every batch.
with torch.no_grad():
    print(f"Train accuracy after training: {model_accuracy(train_loader):.4f}")
    print(f"Test accuracy after training: {model_accuracy(test_loader):.4f}")