<a href="https://colab.research.google.com/github/TirendazAcademy/PyTorch-Lightning-Tutorials/blob/main/Lightning_with_Tensorboard.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### What is PyTorch Lightning?
![figure.png](https://img1.daumcdn.net/thumb/R1280x0/?scode=mtistory2&fname=https%3A%2F%2Fblog.kakaocdn.net%2Fdn%2F8qhjh%2Fbtr5eobWvx3%2FXslpIFC0apO8lmSCUe8VVK%2Fimg.png)
PyTorch Lightning is an open-source Python library that provides a high-level interface for PyTorch. Plain PyTorch is enough to build many kinds of AI models, but the code becomes complex once experiments involve GPUs or TPUs, 16-bit precision, or distributed training. PyTorch Lightning addresses this by abstracting that engineering code away. It aims to be more than a framework: a unified style for organizing PyTorch code.

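```python
import pytorch_lightning as pl

# LightningDataset and LightningModel are the template classes sketched below
dataset = LightningDataset()          # a LightningDataModule: prepares the data splits
model = LightningModel()              # a LightningModule: the model and its training logic
trainer = pl.Trainer(max_epochs=10)   # runs the loop, handles devices, logging, etc.
trainer.fit(model=model, datamodule=dataset)
```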
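To make the abstraction concrete, here is a heavily simplified, framework-free sketch of the kind of loop that `trainer.fit` runs for you. It is an illustration under stated assumptions, not Lightning's implementation: the real `Trainer` also handles devices, precision, logging, checkpointing, and many more hooks. `ToyModule`, `toy_fit`, the dict standing in for an optimizer, and the one-parameter model are all hypothetical; only the hook names mirror Lightning's.

```python
# Hypothetical, framework-free sketch of what trainer.fit() automates.
# Not Lightning's code: the real Trainer does much more.

class ToyModule:
    """Fits y = w * x with plain gradient descent on squared error."""

    def __init__(self):
        self.w = 0.0
        self.grad = 0.0

    def configure_optimizers(self):
        return {"lr": 0.1}  # stand-in for a real optimizer object

    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self.w * x
        loss = (pred - y) ** 2
        self.grad = 2 * (pred - y) * x  # d(loss)/dw, computed by hand ("backward")
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return (self.w * x - y) ** 2


def toy_fit(module, train_data, val_data, max_epochs):
    opt = module.configure_optimizers()
    history = []
    for epoch in range(max_epochs):
        for batch_idx, batch in enumerate(train_data):
            module.training_step(batch, batch_idx)   # forward pass + loss + gradient
            module.w -= opt["lr"] * module.grad      # optimizer step
        val_losses = [module.validation_step(b, i) for i, b in enumerate(val_data)]
        history.append(sum(val_losses) / len(val_losses))  # mean validation loss
    return history


train = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # samples from y = 2x
val = [(4.0, 8.0)]
history = toy_fit(ToyModule(), train, val, max_epochs=5)
print(history)
```

With Lightning, you write only the hook bodies; the nested loop, the optimizer step, and the validation pass are the `Trainer`'s job.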

This tutorial draws heavily on earlier PyTorch Lightning tutorials, including:

* [PyTorch Lightning tutorials](https://lightning.ai/docs/pytorch/stable/tutorials.html)
* [Lightning examples](https://github.com/Lightning-AI/tutorials/tree/main/lightning_examples)
* [Why You Should Use PyTorch Lightning and How to Get Started](https://www.sabrepc.com/blog/Deep-Learning-and-AI/why-use-pytorch-lightning)
* [Beginner's guide to PyTorch Lightning](https://www.kaggle.com/code/shivanandmn/beginners-guide-to-pytorch-lightning/notebook)

And the official documentation:
* [PyTorch Lightning documentation](https://lightning.ai/docs/pytorch/LTS/)

### Import required libraries

In [None]:
!pip install lightning pytorch-lightning -q

In [None]:
import os
import torch
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import pytorch_lightning as pl
import torchmetrics
from torchmetrics import Metric
from pytorch_lightning.callbacks import EarlyStopping, Callback
from torchvision.transforms import RandomHorizontalFlip, RandomVerticalFlip
import torchvision
from pytorch_lightning.loggers import TensorBoardLogger
import lightning as L

In [None]:
print("torch version:", torch.__version__)
print("pytorch lightning version:", pl.__version__)

### Define a LightningModule
A LightningModule organizes your PyTorch code: it holds your `nn.Module` layers and defines how they work together inside `training_step` (`validation_step` and `test_step` are optional).

A LightningModule has many reserved methods, called hooks:

- ```configure_optimizers``` - returns the optimizer (e.g. Adam or SGD), optionally together with learning-rate schedulers
- ```training_step``` - one iteration of the training loop; takes `batch` and `batch_idx` and returns the loss
- ```validation_step``` - one iteration of the validation loop; takes `batch` and `batch_idx`
- ```test_step``` - one iteration of the test loop; takes `batch` and `batch_idx`


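```python
import pytorch_lightning as pl

class LightningModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # define the layers here

    def forward(self, x):
        # inference / prediction logic
        pass

    def configure_optimizers(self):
        # return an optimizer, e.g. torch.optim.Adam(self.parameters())
        pass

    def loss_fn(self, output, target):
        # optional helper method, not a Lightning hook
        pass

    def training_step(self, batch, batch_idx):
        # compute and return the training loss for one batch
        pass

    def validation_step(self, batch, batch_idx):
        # compute and log validation metrics for one batch
        pass
```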

In [None]:
class MLP(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1 * 28 * 28, 200)  # MNIST image size
        self.layer2 = nn.Linear(200, 200)
        self.layer3 = nn.Linear(200, 10)  # MNIST has 10 classes

    def forward(self, x):
        # 1. Implement forward pass
        ...

    def training_step(self, batch, batch_idx):
        # 2. Implement training step
        ...

    def configure_optimizers(self):
        # 3. Define optimizer
        ...

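Next, define the `validation_step` and `test_step` methods for the MLP model:
- log "val_loss" and "val_acc" in `validation_step`, and "test_loss" and "test_acc" in `test_step` (the `EarlyStopping` and `ModelCheckpoint` callbacks below monitor "val_loss")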
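For the `val_acc` and `test_acc` metrics, this tutorial presumably means classification accuracy: the fraction of samples whose highest-scoring logit matches the label. The framework-free sketch below shows only that computation on plain lists; the `accuracy` helper is hypothetical, and in the notebook you would compute the same thing on tensors (or with `torchmetrics`, which is already imported).

```python
def accuracy(logits, targets):
    """Fraction of rows whose argmax index equals the target class."""
    correct = 0
    for row, target in zip(logits, targets):
        predicted = max(range(len(row)), key=lambda i: row[i])  # argmax of the row
        correct += int(predicted == target)
    return correct / len(targets)


logits = [
    [2.0, 0.1, -1.0],  # predicts class 0
    [0.3, 1.5, 0.2],   # predicts class 1
    [0.1, 0.2, 3.0],   # predicts class 2
    [1.0, 0.5, 0.0],   # predicts class 0
]
targets = [0, 1, 2, 1]
print(accuracy(logits, targets))  # 0.75
```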

In [None]:
class MLP(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1 * 28 * 28, 200)  # MNIST image size
        self.layer2 = nn.Linear(200, 200)
        self.layer3 = nn.Linear(200, 10)  # MNIST has 10 classes

    def forward(self, x):
        x = x.view(x.shape[0], -1)  # flatten, (B, 1*28*28)
        x = F.relu(self.layer1(x))  # (B, 200)
        x = F.relu(self.layer2(x))  # (B, 200)
        x = self.layer3(x)  # (B, 10)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)  # Forward pass
        loss = nn.functional.cross_entropy(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)

    # 1. Implement the validation_step
    def validation_step(self, batch, batch_idx):
        ...

    # 2. Implement the test_step
    def test_step(self, batch, batch_idx):
        ...

### Define a dataset
Lightning supports any iterable (DataLoader, NumPy array, etc.) for the train/val/test/predict splits.

Hooks:
- ```train_dataloader()```
- ```val_dataloader()```
- ```test_dataloader()```

These LightningDataModule methods return the DataLoader for each split.

- prepare_data(): download, tokenize, or otherwise preprocess the complete dataset. With multiple GPUs, Lightning calls it from only one process (per node, by default), so do not assign state here (e.g. `self.x = ...`); that state would not be shared across GPUs.
- setup(): split the data, apply transforms, and assign the datasets. It runs on every process and takes a `stage` argument (`"fit"`, `"validate"`, `"test"`, or `"predict"`), so you can build only the splits that stage needs.
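MNIST ships only a training split (60,000 images) and a test split (10,000 images), so `setup` usually carves a validation set out of the training data, for example with `torch.utils.data.random_split`. The framework-free sketch below shows what a seeded random split does; the `split_indices` helper, the seed, and the 55,000/5,000 sizes are illustrative assumptions.

```python
import random

def split_indices(n, lengths, seed=42):
    """Shuffle range(n) with a fixed seed and cut it into consecutive chunks."""
    if sum(lengths) != n:
        raise ValueError("lengths must sum to n")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # same seed -> same split on every run
    chunks, start = [], 0
    for length in lengths:
        chunks.append(indices[start:start + length])
        start += length
    return chunks


train_idx, val_idx = split_indices(60_000, [55_000, 5_000])
print(len(train_idx), len(val_idx))  # 55000 5000
```

Fixing the seed matters because `setup` runs on every process: each one must build the same split. In the notebook, the analogous call would be `random_split(entire_dataset, [55000, 5000], generator=torch.Generator().manual_seed(42))`.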

```python
class LightningDataset(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
  
    def prepare_data(self):
        pass
  
    def setup(self, stage=None):
        pass
  
    def train_dataloader(self):
        pass
  
    def val_dataloader(self):
        pass
  
    def test_dataloader(self):
        pass
```

In [None]:
class MnistDataModule(pl.LightningDataModule):
    def __init__(self, data_dir, batch_size, num_workers):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        # 1. prepare data (download MNIST)
        ...

    def setup(self, stage):
        # 2. Preprocess data. Define train, val & test datasets.
        # Image augmentation: RandomVerticalFlip, RandomHorizontalFlip, ToTensor
        entire_dataset = ...
        # There is no validation dataset in MNIST, so we split the training dataset
        self.train_ds, self.val_ds = ...
        self.test_ds = ...

    # 3-1. Implement train_dataloader
    def train_dataloader(self):
        ...

    # 3-2. Implement val_dataloader
    def val_dataloader(self):
        ...

    # 3-3. Implement test_dataloader
    def test_dataloader(self):
        ...

### Callbacks
A PyTorch Lightning Callback is a powerful tool for adding custom behavior that runs when specific events occur during training. Callbacks let you automate many tasks during training, validation, and prediction, for example:
- Early Stopping (stop training early)
- Model Checkpointing (save the best model)
- Logging & Visualization (logging to TensorBoard, WandB, etc.)
- Learning Rate Scheduling (adjust the learning rate)
- Custom Actions (model evaluation, extra data logging, etc.)

```python
trainer = pl.Trainer(
    max_epochs=10,
    callbacks=[
        early_stopping,
        checkpoint_callback,
        PrintLearningRateCallback()
    ]
)
```

1. [Built-in-callbacks](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html#built-in-callbacks)
2. [Callback API](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html#callback-api)
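Conceptually, the Trainer keeps a list of callbacks and, whenever an event fires, calls the matching method on each of them. The framework-free sketch below shows that dispatch pattern. `ToyCallback`, `MiniTrainer`, and `EventRecorder` are hypothetical; real Lightning hooks receive `(trainer, pl_module)` and the real hook list is much longer.

```python
class ToyCallback:
    """Base class: every hook is a no-op unless a subclass overrides it."""
    def on_train_start(self, trainer): pass
    def on_train_epoch_end(self, trainer): pass
    def on_train_end(self, trainer): pass


class MiniTrainer:
    def __init__(self, max_epochs, callbacks=()):
        self.max_epochs = max_epochs
        self.callbacks = list(callbacks)
        self.current_epoch = 0

    def _call(self, hook_name):
        for cb in self.callbacks:
            getattr(cb, hook_name)(self)  # dispatch the event to every callback

    def fit(self):
        self._call("on_train_start")
        for epoch in range(self.max_epochs):
            self.current_epoch = epoch
            # ... training and validation for this epoch would run here ...
            self._call("on_train_epoch_end")
        self._call("on_train_end")


class EventRecorder(ToyCallback):
    def __init__(self):
        self.events = []

    def on_train_start(self, trainer):
        self.events.append("start")

    def on_train_epoch_end(self, trainer):
        self.events.append(f"epoch {trainer.current_epoch}")

    def on_train_end(self, trainer):
        self.events.append("end")


recorder = EventRecorder()
MiniTrainer(max_epochs=2, callbacks=[recorder]).fit()
print(recorder.events)  # ['start', 'epoch 0', 'epoch 1', 'end']
```

Because unimplemented hooks are no-ops, a callback only overrides the events it cares about, which is also how Lightning's `Callback` base class behaves.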

#### 1️⃣ A simple Callback example

Here is a simple Callback that prints a message when training starts and when it ends.

In [None]:
class MyPrintingCallback(Callback):
    def __init__(self):
        super().__init__()

    def on_train_start(self, trainer, pl_module):
        print("Starting to train!")

    def on_train_end(self, trainer, pl_module):
        print("Training is done.")

#### 2️⃣ Early Stopping Callback

PyTorch Lightning provides an `EarlyStopping` callback that stops training automatically when the model's performance stops improving.

In [None]:
from pytorch_lightning.callbacks import EarlyStopping

early_stopping = EarlyStopping(
    monitor="val_loss",  # metric to monitor
    patience=3,          # number of checks (epochs, by default) without improvement before stopping
    verbose=True,
    mode="min"           # lower is better (a smaller loss is better)
)
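The core rule behind `EarlyStopping` with `mode="min"` is a patience counter: stop once the monitored value has failed to improve for `patience` consecutive checks. The framework-free sketch below replays that rule over a list of validation losses; the `early_stopping_epoch` helper is hypothetical, and the real callback adds more options (for example stopping and divergence thresholds).

```python
def early_stopping_epoch(val_losses, patience=3, min_delta=0.0):
    """Return the index of the check at which training would stop, or None.

    Mirrors mode="min": a value counts as an improvement only if it is
    lower than the best value seen so far by more than min_delta.
    """
    best = float("inf")
    wait = 0
    for i, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            wait = 0          # improvement: reset the counter
        else:
            wait += 1         # no improvement at this check
            if wait >= patience:
                return i      # patience exhausted: stop here
    return None


losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.64, 0.68, 0.69, 0.70]
print(early_stopping_epoch(losses, patience=3))  # 8
```

Note how the improvement at index 5 resets the counter, so training continues until three non-improving checks occur in a row.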

#### 3️⃣ Model Checkpointing Callback

You can also use a Callback that saves the best-performing model.

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",       # metric used to rank checkpoints
    dirpath="checkpoints/",   # directory to save to
    filename="best-checkpoint",  # file name
    save_top_k=1,             # keep only the k best models
    mode="min",               # a lower loss means a better model
    verbose=True
)
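With `save_top_k=k` and `mode="min"`, `ModelCheckpoint` keeps only the k checkpoints with the lowest monitored value, discarding the worst kept one when a better one arrives. The framework-free sketch below models just that bookkeeping with a heap; `TopKTracker` is hypothetical and writes no files.

```python
import heapq

class TopKTracker:
    """Keep the k lowest scores seen so far (mode="min")."""

    def __init__(self, k):
        self.k = k
        self._heap = []  # max-heap via negated scores: (-score, epoch)

    def update(self, score, epoch):
        """Return True if this epoch's checkpoint would be kept."""
        if len(self._heap) < self.k:
            heapq.heappush(self._heap, (-score, epoch))
            return True
        worst_score = -self._heap[0][0]  # largest kept score
        if score < worst_score:
            heapq.heapreplace(self._heap, (-score, epoch))  # evict the current worst
            return True
        return False

    def kept(self):
        return sorted((-neg, epoch) for neg, epoch in self._heap)


tracker = TopKTracker(k=2)
for epoch, val_loss in enumerate([0.50, 0.40, 0.45, 0.30, 0.35]):
    tracker.update(val_loss, epoch)
print(tracker.kept())  # [(0.3, 3), (0.35, 4)]
```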

#### 4️⃣ Practice: Custom Callback

You can also write your own Callback.

First, let's build a Callback that prints the current learning rate every 10th batch. (Hint: on_train_batch_end)

In [None]:
class PrintLearningRateCallback(pl.Callback):
    # 1. Print the learning rate every 10th batch
    ...

#### 5️⃣ Practice: Gradient Clipping

In PyTorch Lightning you can enable gradient clipping by setting `gradient_clip_val` on the trainer,
but a custom Callback lets you apply it only under specific conditions.

1. [torch.clamp](https://pytorch.org/docs/stable/generated/torch.clamp.html#torch.clamp)
2. [on_after_backward](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html#on-after-backward)

In [None]:
class GradientClippingCallback(pl.Callback):
    def __init__(self, clip_value=0.5):
        self.clip_value = clip_value

    # 1. Clip the gradients after backpropagation
    ...
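`torch.clamp`, linked above, corresponds to clipping by value: each gradient element is limited to `[-clip_value, clip_value]`. The Trainer's `gradient_clip_val`, by contrast, clips by norm unless you pass `gradient_clip_algorithm="value"`: it rescales the whole gradient so its L2 norm is at most the threshold, which preserves the gradient's direction. The framework-free sketch below contrasts the two on a plain list; the function names are hypothetical and this is not the exercise's solution.

```python
import math

def clip_by_value(grads, clip_value):
    """Element-wise clamp to [-clip_value, clip_value], like torch.clamp."""
    return [max(-clip_value, min(clip_value, g)) for g in grads]


def clip_by_norm(grads, max_norm):
    """Rescale the whole vector so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)  # already small enough: leave unchanged
    scale = max_norm / norm
    return [g * scale for g in grads]


print(clip_by_value([3.0, -4.0, 0.2], 0.5))  # [0.5, -0.5, 0.2]
print(clip_by_norm([3.0, 4.0], 2.5))         # [1.5, 2.0]
```

Clipping by value can change the gradient's direction (here the 0.2 component is untouched while the others shrink), whereas clipping by norm scales every component by the same factor.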

#### 6️⃣ Practice: Callback that measures training time per epoch
This Callback measures how long each epoch takes and prints the time when the epoch ends.

1. [on_train_epoch_start](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html#on-train-epoch-start)
2. [on_train_epoch_end](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html#on-train-epoch-end)

In [None]:
import time

class TimerCallback(pl.Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        self.epoch_start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module):
        elapsed_time = time.time() - self.epoch_start_time
        print(f"Epoch {trainer.current_epoch} took {elapsed_time:.2f} s")

### Setting the hyperparameters
The PyTorch Lightning `Trainer` accepts many flags that configure training (hardware, number of epochs, logging, callbacks, and more).
- [Trainer flags](https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-flags)

In [None]:
model = MLP()
dm = MnistDataModule(
    data_dir="dataset/",
    batch_size=100,
    num_workers=4,
)
trainer = pl.Trainer(
    logger=TensorBoardLogger("tb_logs", name="mnist_model_v0"),
    accelerator="gpu",  # requires a GPU runtime
    devices=[0],        # use GPU 0; on a CPU-only runtime set accelerator="cpu", devices=1
    min_epochs=1,
    max_epochs=2,
    callbacks=[PrintLearningRateCallback(), GradientClippingCallback(), TimerCallback()],
)

### Training the model

In [None]:
trainer.fit(model, dm)
trainer.validate(model, dm)
trainer.test(model, dm)

In [None]:
%load_ext tensorboard
%tensorboard --logdir tb_logs