## More Exploration on Usage of PyTorch Lightning

This notebook continues the previous notebook on PyTorch Lightning and explores further usage of the library. It covers the following topics:
- Logging (scalar / visualization)
- Callbacks (Native / Customized)

In [None]:
import os

# Slack settings, presumably read later by demo.slack.client.SlackClient
# (its implementation is not shown here -- confirm the variable names there).
# `setdefault` only fills in an empty placeholder when the variable is missing,
# so values already exported in the environment are no longer overwritten.
# Provide real tokens through the environment instead of pasting them here.
os.environ.setdefault("SLACK_BOT_TOKEN", "")
os.environ.setdefault("SLACK_APP_TOKEN", "")
os.environ.setdefault("SLACK_DEFAULT_CHANNEL_ID", "")

In [None]:
import lightning
import numpy as np
import torch
import torchvision

import demo.lightning.models
import demo.lightning.callbacks.slack
import demo.slack.client

# Select PyTorch's "high" float32 matmul precision mode for the whole session
# (see the PyTorch docs for the precision/speed trade-off of each mode).
torch.set_float32_matmul_precision("high")

## Initialize Dataset

In [None]:
# Directory the CIFAR-10 files are downloaded to / read from.
DATAPATH = "./data"

# Human-readable class names, one per CIFAR-10 category
# (assumed to follow the dataset's label-index order -- TODO confirm).
class_labels = tuple(
    "plane car bird cat deer dog frog horse ship truck".split()
)

# Preprocessing: convert to a tensor, then apply (x - 0.5) / 0.5 per channel.
_half = (0.5, 0.5, 0.5)
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=_half, std=_half),
    ]
)
del _half  # keep the notebook namespace identical to the original cell

# Inverse of the normalization above: (y - (-1)) / 2 == (y + 1) / 2.
denorm_fn = torchvision.transforms.Normalize(mean=(-1, -1, -1), std=(2, 2, 2))

class DataModuleCIFAR10(lightning.LightningDataModule):
    """Lightning data module exposing the CIFAR-10 train and test splits."""

    def __init__(self, batch_size: int, num_workers: int = 4):
        """Create both CIFAR-10 datasets under ``DATAPATH``.

        Both splits are requested with ``download=True`` and use the
        module-level ``transform``.

        Args:
            batch_size: Batch size used by both data loaders.
            num_workers: Worker processes per data loader. Defaults to 4,
                the value that was previously hard-coded.
        """

        super().__init__()
        self.train_dataset = torchvision.datasets.CIFAR10(
            root=DATAPATH, train=True, download=True, transform=transform
        )
        # The ``train=False`` split serves as the validation set here.
        self.val_dataset = torchvision.datasets.CIFAR10(
            root=DATAPATH, train=False, download=True, transform=transform
        )
        self.batch_size = batch_size
        self.num_workers = num_workers

    def _make_loader(self, dataset, shuffle: bool):
        """Build a data loader over ``dataset`` with the shared settings."""

        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return the shuffled data loader over the training split."""

        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled data loader over the validation split."""

        return self._make_loader(self.val_dataset, shuffle=False)

In [None]:
# Single data module (batch size 512) passed to every `trainer.fit` call below.
data_module = DataModuleCIFAR10(batch_size=512)

## Single Training

In [None]:
# Fix the global seed (with worker seeding enabled) before building the model.
lightning.pytorch.seed_everything(42, workers=True)

# Hand-picked optimizer / scheduler settings for the baseline run.
hyperparameters = dict(
    optimizer=dict(lr=1e-7, weight_decay=0.01),
    scheduler=dict(step_size=5, gamma=0.1),
)

resnet = demo.lightning.models.ResNet18(
    hyperparameters=hyperparameters,
    denorm_fn=denorm_fn,
    class_labels=class_labels,
    num_classes=len(class_labels),
)

# 10 epochs, logging on every step and validating after every epoch.
trainer = lightning.pytorch.Trainer(
    check_val_every_n_epoch=1,
    log_every_n_steps=1,
    max_epochs=10,
)

trainer.fit(datamodule=data_module, model=resnet)

## Hyperparameter Search

Note: this section takes around 5 minutes to run (reference GPU: RTX 4060 Ti).

In [None]:
# Sampling bounds for each hyperparameter, stored as [low, high].
search_space = dict(
    optimizer=dict(
        lr=[1e-7, 1e-5],
        weight_decay=[0.01, 0.1],
    ),
    scheduler=dict(
        step_size=[5, 10],
        gamma=[0.1, 0.5],
    ),
)

# Number of random configurations trained by the search loop.
n_samples = 5

In [None]:
def generate_hyperparameters(space=None):
    """Draw one random hyperparameter configuration from a search space.

    Args:
        space: Mapping ``{group: {name: [low, high]}}``. When omitted, the
            notebook-level ``search_space`` is used.

    Returns:
        Nested dict with the same group / name keys as ``space``. A range whose
        two bounds are both integers yields a Python ``int`` drawn uniformly
        from ``[low, high]`` (inclusive); any other range yields a ``float``
        drawn uniformly from ``[low, high)``.
    """

    if space is None:
        space = search_space

    def _is_integer(value):
        # bool is a subclass of int but is not a meaningful range bound.
        return isinstance(value, (int, np.integer)) and not isinstance(value, bool)

    hyperparameters = {}
    for group, ranges in space.items():
        hyperparameters[group] = {}
        for name, bounds in ranges.items():
            low, high = bounds[0], bounds[1]
            if _is_integer(low) and _is_integer(high):
                # Integer-bounded entries (e.g. "step_size") must stay integers:
                # step-based LR schedulers expect an integer epoch count, and a
                # fractional value would presumably break the decay schedule.
                # randint's upper bound is exclusive, hence the +1.
                hyperparameters[group][name] = int(np.random.randint(low, high + 1))
            else:
                hyperparameters[group][name] = float(np.random.uniform(low, high))
    return hyperparameters

In [None]:
# Random search: every iteration samples a configuration and trains a freshly
# constructed model with its own Trainer.
for _ in range(n_samples):
    hyperparameters = generate_hyperparameters()

    resnet = demo.lightning.models.ResNet18(
        hyperparameters=hyperparameters,
        denorm_fn=denorm_fn,
        class_labels=class_labels,
        num_classes=len(class_labels),
    )

    # Same trainer settings as the baseline run.
    trainer = lightning.pytorch.Trainer(
        check_val_every_n_epoch=1,
        log_every_n_steps=1,
        max_epochs=10,
    )
    trainer.fit(datamodule=data_module, model=resnet)

## Callbacks

In [None]:
# Built-in Lightning callbacks; both monitor "val_loss" and treat lower as better.
callbacks = [
    # Early stopping on "val_loss" with patience=1.
    lightning.pytorch.callbacks.EarlyStopping(
        verbose=True,
        patience=1,
        mode="min",
        monitor="val_loss",
    ),
    # Keep only the single best checkpoint (save_top_k=1) under ./models.
    lightning.pytorch.callbacks.ModelCheckpoint(
        verbose=True,
        mode="min",
        save_top_k=1,
        monitor="val_loss",
        filename="resnet18_cifar10_best_{epoch:02d}",
        dirpath="./models",
    ),
]

In [None]:
# Sample one configuration and train it with the callbacks list attached.
hyperparameters = generate_hyperparameters()

resnet = demo.lightning.models.ResNet18(
    hyperparameters=hyperparameters,
    denorm_fn=denorm_fn,
    class_labels=class_labels,
    num_classes=len(class_labels),
)

# Up to 10 epochs; the attached callbacks may end the run earlier.
trainer = lightning.pytorch.Trainer(
    callbacks=callbacks,
    check_val_every_n_epoch=1,
    log_every_n_steps=1,
    max_epochs=10,
)
trainer.fit(datamodule=data_module, model=resnet)

## Customized Callbacks

In [None]:
# Slack client created with no explicit arguments.
slack_client = demo.slack.client.SlackClient()

# Rebinds `callbacks`: from here on the list holds only the custom Slack
# monitoring callback, configured with log_every_n_epoch=1.
callbacks = [
    demo.lightning.callbacks.slack.MonitorTrainingOnSlack(
        log_every_n_epoch=1,
        slack_client=slack_client,
    )
]

In [None]:
# Sample one configuration and train it with the current `callbacks` list.
hyperparameters = generate_hyperparameters()

resnet = demo.lightning.models.ResNet18(
    hyperparameters=hyperparameters,
    denorm_fn=denorm_fn,
    class_labels=class_labels,
    num_classes=len(class_labels),
)

# Longer run (50 epochs), still logging every step and validating every epoch.
trainer = lightning.pytorch.Trainer(
    callbacks=callbacks,
    check_val_every_n_epoch=1,
    log_every_n_steps=1,
    max_epochs=50,
)
trainer.fit(datamodule=data_module, model=resnet)