# pytorch lightning & wandb
---
## image classification task
> cifar10 dataset

In [1]:
# !pip install pytorch_lightning
# !pip install wandb

In [2]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import torchmetrics

import wandb

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import random_split
from torch.utils.data import DataLoader

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

## DataModule 
- 데이터 준비
- 디스크에 저장
- 데이터 전처리
- dataloader 내부 래핑

### init
- 하이퍼파라미터 설정
- 데이터 변환 파이프라인 정의

### prepare_data
- 디스크에 저장하거나 분산설정에서 단일 GPU에서만 수행하는 작업
- 상태 할당은 x (self.tmp = ...)
- 데이터셋 다운로드 

### setup_data
- 데이터 로드
- split train, valid, test => stage("fit", "test")

### dataloader
- setup에서 준비한 데이터세트를 dataloader로 래핑

In [3]:
class DataModule(pl.LightningDataModule):
    def __init__(self, batch_size, data_dir='./'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        self.transform = transforms.Compose([
            transforms.ToTensor(), 
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        self.dims = (3, 32, 32)
        self.num_classes = 10

    def prepare_data(self):
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])
        
        if stage == 'test' or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)
    
    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size)
    
    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)

## Custom Callback

### init

### callback Hooks
- callback 호출 위치

### ImagePredictionLogger
- custom callback
- 일부 이미지 샘플에 대한 모델의 예측 시각화에 사용

In [4]:
class ImagePredictionLogger(Callback):
    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_labels = val_samples

    def on_validation_epoch_end(self, trainer, pl_module):
        val_imgs = self.val_imgs.to(device=pl_module.device)
        val_labels = self.val_labels.to(device=pl_module.device)

        logits = pl_module(val_imgs)
        preds = torch.argmax(logits, -1)

        trainer.logger.experiment.log({
            "examples": [wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
                            for x, pred, y in zip(val_imgs[:self.num_samples], 
                                                  preds[:self.num_samples], 
                                                  val_labels[:self.num_samples])]
        })

## LightningModule
- Computations => __init__
- train loop => training_step
- validation loop => validation_step
- test loop => test_step
- optimizers => configure_optimizers

### init => Computations
- model architecture
- pytorch model class의 init과 비슷
- save_hyperparameters   
=> init 안에있는 모든 값을 체크포인트에 저장

### forward
- 추론작업 정의하는데 사용

### training_step
- args: bacth, batch_idx
- 한 배치에 대한 학습과정
- .log(on_epoch=True) => 에포크 메트릭 계산
- .log(on_step=True) => 배치 메트릭 계산 (on_step=True가 default)

### validation_step
- 한 배치에 대한 검증과정

### test_step
- trainer.test() 사용시 수행

### configure_optimizers
- 옵티마이저 및 학습 속도 스케줄러 정의

In [8]:
class ClassifyModel(pl.LightningModule):
    def __init__(self, input_shape, num_classes, learning_rate=2e-4):
        super().__init__()

        self.save_hyperparameters()
        self.learning_rate = learning_rate

        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.conv3 = nn.Conv2d(32, 64, 3, 1)
        self.conv4 = nn.Conv2d(64, 64, 3, 1)

        self.pool1 = nn.MaxPool2d(2)
        self.pool2 = nn.MaxPool2d(2)

        n_sizes = self._get_conv_output(input_shape)

        self.fc1 = nn.Linear(n_sizes, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)

        self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=10)

    def _get_conv_output(self, shape):
        batch_size = 1
        input = torch.rand(batch_size, *shape)

        output_feat = self._forward_features(input)
        n_size = output_feat.data.view(batch_size, -1).size(1)
        return n_size

    def _forward_features(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool2(F.relu(self.conv4(x)))
        return x
    
    def forward(self, x):
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.log_softmax(self.fc3(x), dim=1)
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)

        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True) # on_step=False, on_epoch=True
        self.log('val_acc', acc, prog_bar=True) # on_step=False, on_epoch=True
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('test_loss', loss, prog_bar=True) # on_step=False, on_epoch=True
        self.log('test_acc', acc, prog_bar=True) # on_step=False, on_epoch=True
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

## Train and Evaluate
### Trainer automates:
- Epoch and batch iteration
- calling optimizer.step(), backward, zero_grad()
- calling .eval(), enabling/disabling grads
- saving and loading weights
- weights & biases logging
- multi gpu training
- tpu support
- 16-bit training support 

1. 데이터 파이프라인 (DataModule) 초기화   
=> ImagePredictionLogger에 sample data를 주기위해 prepare_data, setup 수동 호출

In [9]:
dm = DataModule(batch_size=32)
dm.prepare_data()
dm.setup()

val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape

Files already downloaded and verified
Files already downloaded and verified


(torch.Size([32, 3, 32, 32]), torch.Size([32]))

2. model & logger 초기화 그리고 학습

In [13]:
model = ClassifyModel((3, 32, 32), dm.num_classes)
wandb_logger = WandbLogger(project='wandb-lightning', job_type='train')

trainer = pl.Trainer(max_epochs=50, 
                     accelerator='gpu', 
                     logger=wandb_logger,
                     callbacks=[ImagePredictionLogger(val_samples), 
                                EarlyStopping(monitor='val_loss'), 
                                ModelCheckpoint()])

trainer.fit(model, datamodule=dm)
trainer.test(datamodule=dm)
wandb.finish()

  rank_zero_warn(
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name     | Type               | Params
------------------------------------------------
0 | conv1    | Conv2d             | 896   
1 | conv2    | Conv2d             | 9.2 K 
2 | conv3    | Conv2d             | 18.5 K
3 | conv4    | Conv2d             | 36.9 K
4 | pool1    | MaxPool2d          | 0     
5 | pool2    | MaxPool2d          | 0     
6 | fc1      | Linear             | 819 K 
7 | fc2      | Linear             | 65.7 K
8 | fc3      | Linear             | 1.3 K 
9 | accuracy | MulticlassAccuracy | 0     
------------------------------------------------
952 K     Trainable params
0         Non-trainable params
952 K     Total params
3.809     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn(


Files already downloaded and verified
Files already downloaded and verified


INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at ./wandb-lightning/ep69jiy7/checkpoints/epoch=12-step=18291-v1.ckpt
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at ./wandb-lightning/ep69jiy7/checkpoints/epoch=12-step=18291-v1.ckpt


Testing: 0it [00:00, ?it/s]

0,1
epoch,▁▂▃▃▅▅▆▇▁▂▃▄▅▆▆▇█▂▃▃▄▅▆▇▂▂▃▅▅▆▇▇▂▃▃▅▅▆▇█
test_acc,▁
test_loss,▁
train_acc_epoch,▁▃▄▅▆▆▇▇▁▃▄▅▆▇▇██▃▄▅▅▆▇▇▁▃▅▅▆▆▇▇▃▄▅▅▆▇▇█
train_acc_step,▂▄▃▆▆▅▇▇▁▅▆▅▇███▁▅▅▅▇▆█▇▃▄▅▇▇▇█▃▄▆▅▆▆▇▇█
train_loss_epoch,█▆▅▄▄▃▂▂█▆▅▄▃▃▂▁▁▆▅▄▄▃▃▂█▆▄▄▃▃▂▂▆▅▄▄▃▂▂▁
train_loss_step,▆▅▆▄▃▃▁▁▇▄▃▄▂▂▂▁█▄▄▄▂▂▂▂▅▅▄▃▃▂▂▆▅▄▅▂▃▂▁▁
trainer/global_step,▁▂▃▄▅▆▆▇▁▂▃▄▅▆▇▇▁▂▃▄▄▅▆▇▂▃▃▄▅▆▇▁▂▃▄▅▅▆▇█
val_acc,▁▃▅▆▇▇██▁▃▅▆▇▇███▃▄▆▆▇▇▇▁▄▆▇▇▇▇▇▄▄▆▇▇███
val_loss,█▆▄▃▂▂▁▁█▆▃▃▁▂▁▂▂▆▅▃▃▂▂▂█▅▃▂▁▁▂▂▆▄▂▂▁▁▂▃

0,1
epoch,13.0
test_acc,0.7309
test_loss,0.98429
train_acc_epoch,0.89622
train_acc_step,0.84375
train_loss_epoch,0.29303
train_loss_step,0.43859
trainer/global_step,18291.0
val_acc,0.7316
val_loss,0.96498
