In [100]:
import wandb
import torch
import torchvision
import torchmetrics
from torch import nn
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data.dataset import random_split

### WandB 로그인

In [101]:
wandb.login()               

True

### 데이터 모듈 생성 (pl.LightningDataModule 사용)

In [102]:
class CIFAR10DataModule(pl.LightningDataModule):      ## 데이터 모듈은 반드시, LightningDataModule 을 상속토록 함.
    def __init__(self, batch_size, data_path = './data'):
        super().__init__()
        self.data_path = data_path           ## 데이터 셋 디렉토리
        self.batch_size = batch_size       ## 1 batch = batch_size 개 만큼의 데이터 셋으로 구성 

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        
        self.dims = (3, 32, 32)  ## CIFAR10 데이터 셋의 Shape (32*32*3)
    
    ## prepare_data : CIFAR10 데이터 셋 다운로드
    def prepare_data(self):
        CIFAR10(self.data_path, train=True, download=True)  ## Train 데이터 셋 다운로드
        CIFAR10(self.data_path, train=False, download=True) ## Test 데이터 셋 다운로드
    
    ## setup : 다운로드된 데이터 셋들에 Transform 적용하고,, Train, Validation, Test 데이터 셋으로 분리 
    def setup(self):
        cifar_full = CIFAR10(self.data_path, train=True, transform=self.transform)
        self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000]) # Train, Validation 데이터 셋 분리
        self.cifar_test = CIFAR10(self.data_path, train=False, transform=self.transform) # Test 데이터 셋 
            
    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size)            

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)
    

### LightningModule 통한 System (Model 과 유사) 정의

In [103]:
class ToyModel(pl.LightningModule):   ## LightningModule 을 반드시 상속
    def __init__(self, lr=0.001):
        super().__init__()
        
        # 체크포인트까지의 값들을 저장하기 위해 사용(log hyperparameters)
        self.save_hyperparameters()
        
        ## Learning Rate 설정
        self.lr = lr
        
        self.conv1 = nn.Conv2d(3, 32, 3)
        ## 3  (in_channels) : Channel 수 3개 (R,G,B)
        ## 32 (out_channels) : Output 의 Channel 수 32 개 (따라서, 3*3*3 의 이미지 필터가 총 32개 존재)
        ## 3  (kernel_size) : Kernel 사이즈 3*3
        ## 1  (Stride) : 커널 윈도우 Stride 1
        
        self.pool1 = nn.MaxPool2d(2)
        ## 2 (kernel size)
        ## 2 (Stride)
        
        
        self.conv2 = nn.Conv2d(32, 32, 3)
        ## 32  (in_channels) : Channel 수 32개 
        ## 32 (out_channels) : Output 의 Channel 수 32 개 (따라서, 32*3*3 의 이미지 필터가 총 32개 존재)
        ## 3  (kernel_size) : Kernel 사이즈 3*3
        ## 1  (Stride) : 커널 윈도우 Stride 1

        
        self.conv3 = nn.Conv2d(32, 64, 3)
        ## 32  (in_channels) : Channel 수 32개 
        ## 64 (out_channels) : Output 의 Channel 수 64 개 (따라서, 32*3*3 의 이미지 필터가 총 64개 존재)
        ## 3 (kernel_size) : Kernel 사이즈 (3*3)
        ## 1 (Stride)
        
        
        self.conv4 = nn.Conv2d(64, 64, 3)
        ## 64 (in_channels) : Channel 수 64개
        ## 64 (out_channels) : Output 의 Channel 수 64 개 (따라서, 64*3*3 의 이미지 필터가 총 64개 존재)
        ## 3 (kernel_size)
        ## 1 (Stride)
        
        self.pool2 = nn.MaxPool2d(2)
        ## 2 (kernel size)
        ## 2 (Stride)
        
        self.fc1 = nn.Linear(5 * 5 * 64, 512)                    ##
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 10)
        
        self.accuracy = torchmetrics.Accuracy()


        
    # _forward_features : Conv2d + MaxPoll2d 계층 통과 직후의 Tensor 를 반환
    def _forward_features(self, x):
        x = F.relu(self.conv1(x))              ## 첫번째 Conv2d 계층 통과 -> ReLU   
                                               ## (output shape (H,W,C) :  30 * 30 * 32)
        x = self.pool1(F.relu(self.conv2(x)))  ## 두번째 Conv2d 계층 통과 -> ReLU -> 첫번째 MaxPool2d 
                                               ## (output shape (H,W,C) :  14 * 14 * 32)
        x = F.relu(self.conv3(x))              ## 세번째 Conv2d 계층 통과 -> ReLU 
                                               ## (output shape (H,W,C) :  12 * 12 * 64)
        x = self.pool2(F.relu(self.conv4(x)))  ## 마지막 Conv2d 계층 통과 -> ReLU -> 두번째 MaxPool2d
                                               ## (output shape (H,W,C) :  5 * 5 * 64) 
        return x
    
    def forward(self, x):
        x = self._forward_features(x)          ## Conv2d + MaxPool2d 계층 통과
        x = x.view(x.size(0), -1)              ## Tensor 선형화
        x = F.relu(self.fc1(x))                ## 첫번째 Affine 계층 통과 -> ReLU 
        x = F.relu(self.fc2(x))                ## 두번째 Affine 계층 통과 -> ReLU
        x = F.log_softmax(self.fc3(x), dim=1)  ## 세번째 Affine 계층 통과 -> SoftMax
       
        return x        
    
    ##
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        
        # training metrics
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        
        return loss
    
    ##
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss
    
    ##
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss
    
    ## Optimizer 설정 https://neelesh609.github.io/cifar10/ 링크 참고하여, Adam 을 통한 Back-Prop 을 했을 때 성능이 가장 높았음을 참고
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
## https://github.com/rubentea16/pl-mnist/blob/master/model.py 참고

### 데이터셋 준비/로드

In [104]:
data_set = CIFAR10DataModule(batch_size=32)       ## CIFAR-10 데이터셋 설정 (1 Batch Size, Directory, Class 수, 1 Data Sample Shape)
data_set.prepare_data()                           ## CIFAR-10 데이터셋 다운로드
data_set.setup()                                  ## Train, Test, Validation 데이터 셋으로 분리

val_samples = next(iter(data_set.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape

Files already downloaded and verified
Files already downloaded and verified


(torch.Size([32, 3, 32, 32]), torch.Size([32]))

In [105]:
PROJ_NAME = '2017125033_GeonWooBaek_pytorch_lightning_Cifar10'

## Lightning Model 초기화 
model = ToyModel(data_set.size())

# Wandb Logger 초기화
wandb_logger = WandbLogger(project=PROJ_NAME, job_type='train')

# Initialize Callbacks
early_stop = pl.callbacks.EarlyStopping(monitor="val_loss")
checkpt = pl.callbacks.ModelCheckpoint()

# Trainer 초기화 (Epoch 10 회, 1개의 GPU 사용(2개 사용가능한데, Jupyter 에서는 DDP 오류),)
trainer = pl.Trainer(max_epochs=10, gpus=1, logger=wandb_logger,callbacks=[early_stop,checkpt])

# 모델 학습 시작
trainer.fit(model, data_set)
# 모델 테스트 시작 
trainer.test(dataloaders=data_set.test_dataloader())

# Close wandb run
wandb.finish()


## https://wandb.ai/wandb_fc/korean/reports/Weights-Biases-Pytorch-Lightning---VmlldzozNzAxOTg 코드 참고
## https://wikidocs.net/157552 코드 참고

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]



  | Name     | Type      | Params
---------------------------------------
0 | conv1    | Conv2d    | 896   
1 | pool1    | MaxPool2d | 0     
2 | conv2    | Conv2d    | 9.2 K 
3 | conv3    | Conv2d    | 18.5 K
4 | conv4    | Conv2d    | 36.9 K
5 | pool2    | MaxPool2d | 0     
6 | fc1      | Linear    | 819 K 
7 | fc2      | Linear    | 65.7 K
8 | fc3      | Linear    | 1.3 K 
9 | accuracy | Accuracy  | 0     
---------------------------------------
952 K     Trainable params
0         Non-trainable params
952 K     Total params
3.809     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Restoring states from the checkpoint path at /home/jovyan/2017125033/2017125033_GeonWooBaek_pytorch_lightning_Cifar10/3jz852bg/checkpoints/epoch=7-step=11255.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loaded model weights from checkpoint at /home/jovyan/2017125033/2017125033_GeonWooBaek_pytorch_lightning_Cifar10/3jz852bg/checkpoints/epoch=7-step=11255.ckpt


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.7293999791145325, 'test_loss': 0.9915676712989807}
--------------------------------------------------------------------------------


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇█
test_acc,▁
test_loss,▁
train_acc_epoch,▁▄▅▆▇▇██
train_acc_step,▂▃▁▁▄▃▄▃▅▅▅▄▅▆▆▅▆▅▆▇▆▆▆▆▆▅▆▇▆▆▇▇▆▅▅▇█▇█▆
train_loss_epoch,█▅▄▃▃▂▁▁
train_loss_step,█▆▇▇▆█▅▇▄▄▄▅▅▃▃▅▃▄▃▃▃▃▂▂▃▄▄▂▃▃▃▃▃▄▄▂▁▂▂▃
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_acc,▁▅▇▇████
val_loss,█▃▂▁▁▁▃▅

0,1
epoch,8.0
test_acc,0.7294
test_loss,0.99157
train_acc_epoch,0.89184
train_acc_step,0.9375
train_loss_epoch,0.30422
train_loss_step,0.22703
trainer/global_step,11256.0
val_acc,0.7298
val_loss,1.00187
