# In-class exercise 8: Converting your PyTorch code to PyTorch Lightning

Based on the [Lightning in 15 minutes](https://lightning.ai/docs/pytorch/stable/starter/introduction.html) tutorial.

In this tutorial, we will convert the code from the previous tutorial to PyTorch Lightning. This will allow us to reduce the amount of boilerplate code we need to write, and also make it easier to train our model on multiple GPUs or even TPUs.

PyTorch Lightning is a lightweight **wrapper** for **organizing** your PyTorch code. It's **not a high-level framework**, so you still have to write PyTorch code, but it handles a lot of the details for you. It's especially useful for **standardizing** training loops, logging metrics, and saving checkpoints.

First, we import the relevant classes and functions. If PyTorch Lightning is not installed yet, install it with `pip install pytorch-lightning`.

In [None]:
import copy
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
from typeguard import typechecked

For reference, we also reproduce some code from the previous tutorial.

First, we redefine the model.

In [None]:
class CNN(nn.Module):
    # TODO
    pass
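For orientation, here is a minimal sketch of what such a model could look like. The previous tutorial's exact architecture is not reproduced here, so the layer layout (3x3 convolutions, global average pooling, a linear head) and the name `CNNSketch` are assumptions; the constructor arguments simply mirror `NUM_LAYERS`, `NUM_CHANNELS` and `NUM_CLASSES` from the hyperparameter cell.

```python
import torch
import torch.nn as nn


class CNNSketch(nn.Module):
    """Hypothetical small CNN for 1x28x28 MNIST images (not the previous tutorial's exact model)."""

    def __init__(self, num_layers: int = 3, num_channels: int = 32, num_classes: int = 10):
        super().__init__()
        layers = []
        in_channels = 1  # MNIST images are grayscale
        for _ in range(num_layers):
            # padding=1 keeps the 28x28 spatial size unchanged
            layers += [nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1), nn.ReLU()]
            in_channels = num_channels
        self.features = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool2d(1)  # global average pooling: (B, C, H, W) -> (B, C, 1, 1)
        self.head = nn.Linear(num_channels, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(self.features(x))
        return self.head(x.flatten(1))  # (B, num_classes) unnormalized logits


model = CNNSketch()
logits = model(torch.randn(4, 1, 28, 28))
print(logits.shape)  # torch.Size([4, 10])
```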

Then, we redefine the dataset and dataloaders.

In [None]:
def get_datasets():
    # TODO
    pass

In [None]:
def instantiate_dataloaders():
    # TODO
    pass

Let's see if everything works as expected.

In [None]:
model = CNN()  # TODO
print(model)
train_dataset, val_dataset, test_dataset = get_datasets()  # TODO
dataloaders = instantiate_dataloaders()  # TODO
print(next(iter(dataloaders["train"]))[0].shape)

Finally, let's redefine the training loop.

In [None]:
def run_epoch():
    # TODO
    pass
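For comparison with what Lightning will take over, here is a sketch of a typical hand-written epoch loop. The signature of `run_epoch` is an assumption, not necessarily the one from the previous tutorial; note the manual device placement, the explicit `zero_grad()`/`backward()`/`step()` calls and the train/eval switching, all of which Lightning handles for us.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model: nn.Module, dataloader: DataLoader, optimizer=None, device: str = "cpu"):
    """Hypothetical helper: one pass over `dataloader`. Trains if `optimizer` is given, else evaluates.

    Returns (mean loss, accuracy) over the epoch.
    """
    training = optimizer is not None
    model.train(training)  # toggles dropout/batch-norm behaviour
    total_loss, correct, count = 0.0, 0, 0
    with torch.set_grad_enabled(training):  # no autograd graph during evaluation
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)  # manual device placement
            logits = model(images)
            loss = F.cross_entropy(logits, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * labels.size(0)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            count += labels.size(0)
    return total_loss / count, correct / count


# Usage with a toy linear model and random data
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
loader = DataLoader(TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,))), batch_size=16)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
train_loss, train_acc = run_epoch(model, loader, optimizer)
val_loss, val_acc = run_epoch(model, loader)
print(f"train loss {train_loss:.3f}, val acc {val_acc:.2f}")
```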

## Hyperparameters

In [None]:
# training
EPOCHS = 5
BATCH_SIZE = 64
PATIENCE = 3

# data
NUM_WORKERS = 3

# optimizer
LEARNING_RATE = 0.01
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0001

# model
NUM_LAYERS = 3
NUM_CHANNELS = 32
NUM_CLASSES = 10

# reproducibility
SEED = 42

## Step 1: replace `nn.Module` with `LightningModule`

- 1.1: Model architecture goes in the `__init__` method
- 1.2: Prediction/inference logic goes in the `forward` method
- 1.3: Optimizers go in the `configure_optimizers` hook
- 1.4: Training logic goes in the `training_step` hook
- 1.5: Validation logic goes in the `validation_step` hook
- 1.6: Test logic goes in the `test_step` hook
- 1.7: Remove any `cuda()` or `to(device)` calls
- 1.8: Instantiate the `LightningModule`

In [None]:
class CNNLit(pl.LightningModule):
    # TODO
    pass

In [None]:
# Step 1.8: Create an instance of the `CNNLit` class
lit_module = CNNLit()  # TODO

In [None]:
lit_module

## Step 2: replace the training loop with a `Trainer` instance

### Step 2.1: Training loop

Once the `LightningModule` is defined, we can train it using a `Trainer`.

- 1: Instantiate the `Trainer`
- 2: Call `trainer.fit(model, train_dataloader, val_dataloader)` to train the model

In [None]:
trainer = None  # TODO

# Call `trainer.fit()` to train the model
# TODO

### Step 2.2: Test

In [None]:
# Call `trainer.test()` to test the model
# TODO

## Step 3: replace the dataset and dataloaders with `LightningDataModule`

- 1: Move the dataset and dataloaders into a `LightningDataModule`
- 2: Instantiate the `LightningDataModule`
- 3: Pass the `LightningDataModule` to the `Trainer`

A `LightningDataModule` encapsulates the five steps involved in data processing in PyTorch:

- Download / tokenize / process.
- Clean and (maybe) save to disk.
- Load inside a `Dataset`.
- Apply transforms (rotate, tokenize, etc.).
- Wrap inside a `DataLoader`.

In [None]:
class MNISTDataModule(pl.LightningDataModule):
    # TODO
    pass

In [None]:
datamodule = MNISTDataModule()  # TODO

In [None]:
model = CNNLit()  # TODO

In [None]:
trainer = pl.Trainer()  # TODO

# Fit
# TODO

# Test
# TODO

## Logging

So far, we have only logged the loss. Lightning gives us much finer control over logging: with `self.log`, a metric can be recorded after every batch (`on_step=True`), aggregated over each epoch (`on_epoch=True`), or both. Depending on the logger, we can even log images, audio, text, and arbitrary objects.

We will be using [TorchMetrics](https://torchmetrics.readthedocs.io/en/latest/index.html) to compute metrics. TorchMetrics is a collection of metrics for PyTorch. It allows us to avoid writing boilerplate code for computing metrics like accuracy, precision, recall, etc.

In [None]:
from torchmetrics import Accuracy


class CNNLit(pl.LightningModule):
    # TODO
    pass

In [None]:
pl.seed_everything(SEED)  # reproducibility

model = CNNLit()  # TODO
datamodule = MNISTDataModule()  # TODO

trainer = pl.Trainer()  # TODO

# Fit
# TODO

# Test
# TODO

## Visualizing metrics

We can visualize the metrics logged by Lightning using [TensorBoard](https://www.tensorflow.org/tensorboard) or [Weights & Biases](https://wandb.ai/site). We will use TensorBoard in this tutorial; install it with `pip install tensorboard`. By default, Lightning writes its logs to the `lightning_logs/` directory, so you can start TensorBoard with the following command:

```bash
tensorboard --logdir lightning_logs/
```

In [None]:
class CNNLit(pl.LightningModule):
    # TODO
    pass

In [None]:
pl.seed_everything(SEED)  # reproducibility

model = CNNLit()  # TODO
datamodule = MNISTDataModule()  # TODO

logger = None  # TODO

trainer = pl.Trainer()  # TODO

# Fit
# TODO

# Test
# TODO

## Checkpointing

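By default, Lightning saves a checkpoint of your model at the end of every training epoch, keeping only the most recent one, in the `checkpoints/` subdirectory of the run's log directory (e.g. `lightning_logs/version_0/checkpoints/`). To keep the best model instead, add a `ModelCheckpoint` callback that monitors a validation metric. If training is interrupted, you can resume from a saved checkpoint by passing its path to `trainer.fit(..., ckpt_path=...)`.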

In [None]:
# Retrieve the best model
# TODO

# Test the best model
# TODO