# pytorch lightning & wandb
---
## image classification task
> cifar10 dataset

In [7]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger

import wandb

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import random_split
from torch.utils.data import DataLoader

## DataModule 
- 데이터 준비
- 디스크에 저장
- 데이터 전처리
- dataloader 내부 래핑

### init
- 하이퍼파라미터 설정
- 데이터 변환 파이프라인 정의

### prepare_data
- 디스크에 저장하거나 분산설정에서 단일 GPU에서만 수행하는 작업
- 이 단계에서는 상태를 할당하지 않는다 (예: `self.tmp = ...` 금지) — 분산 환경에서는 단일 프로세스에서만 호출되므로 다른 프로세스에는 반영되지 않는다
- 데이터셋 다운로드 

### setup
- 데이터 로드
- split train, valid, test => stage("fit", "test")

### dataloader
- setup에서 준비한 데이터세트를 dataloader로 래핑

In [5]:
class DataModule(pl.LightningDataModule):
    """CIFAR-10 data module that downloads, splits and serves the dataset.

    The training set is split into 45000 training and 5000 validation
    samples; the official CIFAR-10 test set is used for testing.

    Args:
        batch_size: Batch size used by every dataloader.
        data_dir: Directory where CIFAR-10 is downloaded to / loaded from.
    """

    def __init__(self, batch_size, data_dir='./'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        # Convert images to tensors, then normalize every channel with
        # mean 0.5 and std 0.5.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        self.dims = (3, 32, 32)
        self.num_classes = 10

        # Populated lazily by setup(); None means "not built yet".
        self.cifar_train = None
        self.cifar_val = None
        self.cifar_test = None

    def prepare_data(self):
        """Download both CIFAR-10 splits to ``data_dir`` (no state is assigned)."""
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets required for ``stage``.

        Args:
            stage: ``'fit'``, ``'validate'``, ``'test'`` or ``None``
                (``None`` builds everything).
        """
        # 'validate' needs the validation subset as well; without it
        # val_dataloader() would fail when only validation is run.
        # The split is created once: random_split is unseeded here, so
        # re-running it on a later setup() call would reshuffle samples
        # between the training and validation subsets.
        if stage in ('fit', 'validate', None) and self.cifar_train is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

        if stage in ('test', None) and self.cifar_test is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the shuffled dataloader over the training subset."""
        return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the dataloader over the validation subset."""
        return DataLoader(self.cifar_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the dataloader over the CIFAR-10 test set."""
        return DataLoader(self.cifar_test, batch_size=self.batch_size)

## Custom Callback

### init

### callback Hooks
- callback 호출 위치

### ImagePredictionLogger
- custom callback
- 일부 이미지 샘플에 대한 모델의 예측 시각화에 사용

In [8]:
class ImagePredictionLogger(Callback):
    """Callback that logs model predictions for a few validation images.

    Args:
        val_samples: ``(images, labels)`` pair of tensors to visualize.
        num_samples: Maximum number of samples logged per validation epoch.
    """

    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_labels = val_samples

    def on_validation_epoch_end(self, trainer, pl_module):
        """Run the model on the stored samples and log them as ``"examples"``."""
        # Slice before the device transfer and the forward pass: only the
        # first num_samples items are ever logged, so the remainder of the
        # stored batch does not need to be moved or evaluated.
        val_imgs = self.val_imgs[:self.num_samples].to(device=pl_module.device)
        # Labels are only formatted into captions, so they stay where they are.
        val_labels = self.val_labels[:self.num_samples]

        logits = pl_module(val_imgs)
        preds = torch.argmax(logits, -1)

        trainer.logger.experiment.log({
            "examples": [
                wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
                for x, pred, y in zip(val_imgs, preds, val_labels)
            ]
        })

## LightningModule
- Computations => `__init__`
- train loop => training_step
- validation loop => validation_step
- test loop => test_step
- optimizers => configure_optimizers