## Notebook: Exploring More PyTorch Lightning Features

This notebook continues the previous notebook on PyTorch Lightning and explores more of its features. We will cover the following topics:
- Training Monitoring (Visualization using TensorBoard)
- Callbacks (Lightning built-in callbacks and customized callbacks)

In [None]:
import os

# Set the environment variables for the Slack API (fill in your own credentials)
os.environ["SLACK_APP_TOKEN"] = ""
os.environ["SLACK_BOT_TOKEN"] = ""
os.environ["SLACK_DEFAULT_CHANNEL_ID"] = ""

In [None]:
import lightning
import numpy as np
import torch
import torchvision

import demo.lightning.models
import demo.lightning.callbacks.slack
import demo.slack.client

torch.set_float32_matmul_precision("high")

## Initialize Dataset: CIFAR10 for Image Classification

In this section, we initialize the CIFAR10 dataset and its dataloaders for an image classification task. We will use this dataset to demonstrate the usage of PyTorch Lightning.

#### Notes on the CIFAR10 dataset:
- The CIFAR10 dataset consists of 60,000 32x32 color images in 10 classes, with 6,000 images per class.
- There are 50,000 training images and 10,000 test images.
- The classes are completely mutually exclusive. There is no overlap between automobiles and trucks. "Automobile" includes sedans, SUVs, things of that sort. "Truck" includes only big trucks. Neither includes pickup trucks.
- More details can be found at: https://www.cs.toronto.edu/~kriz/cifar.html
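
The next cell normalizes images with `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` and defines `denorm_fn` as `Normalize((-1, -1, -1), (2, 2, 2))` to undo it. A minimal sketch of why those parameters invert each other, using plain Python arithmetic on a single pixel value instead of tensors (the `normalize` helper is hypothetical and only mirrors the per-channel formula `(value - mean) / std`):

```python
def normalize(value: float, mean: float, std: float) -> float:
    """Per-channel formula applied by torchvision.transforms.Normalize."""
    return (value - mean) / std


pixel = 0.25  # a pixel value in [0, 1], as produced by ToTensor

# Forward transform: maps [0, 1] to [-1, 1]
normalized = normalize(pixel, mean=0.5, std=0.5)

# Same formula with mean=-1 and std=2 maps [-1, 1] back to [0, 1]
restored = normalize(normalized, mean=-1.0, std=2.0)

print(normalized)  # -0.5
print(restored)    # 0.25
```

Denormalized images are back in the [0, 1] range, which is what image viewers and loggers typically expect.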

In [None]:
DATAPATH = "./data"

class_labels = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
denorm_fn = torchvision.transforms.Normalize((-1, -1, -1), (2, 2, 2))

class DataModuleCIFAR10(lightning.LightningDataModule):
    """Lightning data module wrapping the CIFAR10 train and test splits"""

    def __init__(self, batch_size: int):
        """Load (and download if needed) CIFAR10 and store the batch size"""

        super().__init__()
        self.train_dataset = torchvision.datasets.CIFAR10(
            root=DATAPATH, train=True, download=True, transform=transform
        )
        self.val_dataset = torchvision.datasets.CIFAR10(
            root=DATAPATH, train=False, download=True, transform=transform
        )
        self.batch_size = batch_size

    def train_dataloader(self):
        """Training data loader"""

        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
        )

    def val_dataloader(self):
        """Validation data loader"""

        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=4,
        )

In [None]:
data_module = DataModuleCIFAR10(batch_size=512)

## PyTorch Lightning: Train a ResNet

In this section, we fine-tune a ResNet18 model using PyTorch Lightning, covering both training and validation. We also monitor the training process using TensorBoard.

In [None]:
lightning.pytorch.seed_everything(42, workers=True)

hyperparameters = {
    "optimizer": {
        "lr": 1e-7,
        "weight_decay": 0.01,
    },
    "scheduler": {
        "step_size": 5,
        "gamma": 0.1,
    },
}

resnet = demo.lightning.models.ResNet18(
    num_classes=len(class_labels),
    class_labels=class_labels,
    denorm_fn=denorm_fn,
    hyperparameters=hyperparameters,
)
trainer = lightning.pytorch.Trainer(
    max_epochs=10,
    log_every_n_steps=1,
    check_val_every_n_epoch=1,
)

trainer.fit(model=resnet, datamodule=data_module)

## PyTorch Lightning: Hyperparameter Search

In this section, we demonstrate a simple random hyperparameter search with PyTorch Lightning. We randomly sample several sets of hyperparameters, train the model once per set, and then compare the runs in TensorBoard.

#### Note:
- This section takes around 5 minutes to run (reference GPU: RTX 4060 Ti).
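
The search in this notebook draws every value uniformly between its bounds. For a learning-rate range that spans two orders of magnitude (`1e-7` to `1e-5`), uniform draws land mostly in the top decade. A common alternative, shown here only as a sketch and not as what this notebook does, is to sample uniformly in log space. The `sample_log_uniform` helper is hypothetical, and the standard-library `random` module is used to keep the example self-contained.

```python
import math
import random


def sample_log_uniform(low: float, high: float, rng: random.Random) -> float:
    """Draw a value whose logarithm is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(42)
uniform_draws = [rng.uniform(1e-7, 1e-5) for _ in range(1000)]
log_draws = [sample_log_uniform(1e-7, 1e-5, rng) for _ in range(1000)]

# Fraction of draws that fall in the lowest decade, [1e-7, 1e-6)
frac_uniform = sum(d < 1e-6 for d in uniform_draws) / len(uniform_draws)
frac_log = sum(d < 1e-6 for d in log_draws) / len(log_draws)
print(f"uniform: {frac_uniform:.2f}, log-uniform: {frac_log:.2f}")
```

With log-uniform sampling, roughly half of the draws fall in each decade, so small learning rates are explored as often as large ones.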

In [None]:
search_space = {
    "optimizer": {
        # [low, high]
        "lr": [1e-7, 1e-5],
        "weight_decay": [0.01, 0.1],
    },
    "scheduler": {
        # [low, high]
        "step_size": [5, 10],
        "gamma": [0.1, 0.5],
    },
}

n_samples = 5

In [None]:
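def generate_hyperparameters():
    """Sample one set of hyperparameters uniformly from `search_space`"""

    hyperparameters = {}
    for key, value in search_space.items():
        hyperparameters[key] = {}
        for subkey, (low, high) in value.items():
            if isinstance(low, int) and isinstance(high, int):
                # Integer bounds (e.g. step_size, an epoch count): draw an integer, high inclusive
                hyperparameters[key][subkey] = int(np.random.randint(low, high + 1))
            else:
                hyperparameters[key][subkey] = float(np.random.uniform(low, high))
    return hyperparameters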

In [None]:
for _ in range(n_samples):
    hyperparameters = generate_hyperparameters()
    resnet = demo.lightning.models.ResNet18(
        num_classes=len(class_labels),
        class_labels=class_labels,
        denorm_fn=denorm_fn,
        hyperparameters=hyperparameters,
    )
    trainer = lightning.pytorch.Trainer(
        max_epochs=10,
        log_every_n_steps=1,
        check_val_every_n_epoch=1,
    )

    trainer.fit(model=resnet, datamodule=data_module)

## PyTorch Lightning: Built-in Callbacks

In this section, we will demonstrate the usage of PyTorch Lightning built-in callbacks. We will use the following callbacks:
- ModelCheckpoint
- EarlyStopping
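
The `patience` argument of `EarlyStopping` controls how many consecutive validation checks without improvement are tolerated before training stops. The sketch below reproduces that counting rule in plain Python for `mode="min"`. It is a simplification that assumes validation runs once per epoch and `min_delta=0`, and it ignores the callback's other stopping criteria; `first_stopping_epoch` is a hypothetical helper, not part of Lightning.

```python
from typing import Optional, Sequence


def first_stopping_epoch(val_losses: Sequence[float], patience: int) -> Optional[int]:
    """Return the epoch index at which training would stop, or None if it never does."""
    best = float("inf")
    wait_count = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            # Improvement: remember the new best value and reset the counter
            best = loss
            wait_count = 0
        else:
            # No improvement: count this check against the patience budget
            wait_count += 1
            if wait_count >= patience:
                return epoch
    return None


losses = [1.0, 0.8, 0.9, 0.7]
print(first_stopping_epoch(losses, patience=1))  # 2
print(first_stopping_epoch(losses, patience=2))  # None
```

With `patience=1`, as configured in the cell below, a single validation check without improvement is enough to end training, so a larger value may be preferable when the validation loss is noisy.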

In [None]:
callbacks = [
    lightning.pytorch.callbacks.EarlyStopping(
        monitor="val_loss",
        mode="min",
        patience=1,
        verbose=True,
    ),
    lightning.pytorch.callbacks.ModelCheckpoint(
        dirpath="./models",
        filename="resnet18_cifar10_best_{epoch:02d}",
        monitor="val_loss",
        save_top_k=1,
        mode="min",
        verbose=True,
    ),
]

In [None]:
hyperparameters = {
    "optimizer": {
        "lr": 1e-1,
        "weight_decay": 0.01,
    },
    "scheduler": {
        "step_size": 5,
        "gamma": 0.1,
    },
}

resnet = demo.lightning.models.ResNet18(
    num_classes=len(class_labels),
    class_labels=class_labels,
    denorm_fn=denorm_fn,
    hyperparameters=hyperparameters,
)
trainer = lightning.pytorch.Trainer(
    max_epochs=10,
    log_every_n_steps=1,
    check_val_every_n_epoch=1,
    callbacks=callbacks,
)
trainer.fit(model=resnet, datamodule=data_module)

## PyTorch Lightning: Customized Callbacks

In this section, we will create a customized callback for PyTorch Lightning. With the implemented callback, we will be able to monitor the training process on Slack and terminate the training process remotely. This showcases the flexibility of PyTorch Lightning in customizing the training process.

In [None]:
slack_client = demo.slack.client.SlackClient()
callbacks = [
    demo.lightning.callbacks.slack.MonitorTrainingOnSlack(
        slack_client=slack_client, log_every_n_epoch=1
    )
]

In [None]:
hyperparameters = generate_hyperparameters()

resnet = demo.lightning.models.ResNet18(
    num_classes=len(class_labels),
    class_labels=class_labels,
    denorm_fn=denorm_fn,
    hyperparameters=hyperparameters,
)
trainer = lightning.pytorch.Trainer(
    max_epochs=50,
    log_every_n_steps=1,
    check_val_every_n_epoch=1,
    callbacks=callbacks,
)
trainer.fit(model=resnet, datamodule=data_module)