In [None]:
import torch
import numpy as np
import random
import os

import lightning.pytorch as pl
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import CSVLogger

In [None]:
from modules.lightningCNN import ResNet_pl
from modules.dataModule import CIFAR10_pl

In [None]:
# Fix random seeds for reproducibility
SEED = 42
# Lightning's seed_everything seeds Python's `random`, NumPy and PyTorch in one call.
# workers=True makes the Trainer attach a worker_init_fn so DataLoader workers are also
# seeded reproducibly.
pl.seed_everything(SEED, workers=True)

# Trade cuDNN autotuning speed for deterministic convolution algorithm selection.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# NOTE: PYTHONHASHSEED is only read when a Python interpreter starts, so this does NOT change
# hash randomization in the current kernel; it only propagates to subprocesses started later.
os.environ['PYTHONHASHSEED'] = str(SEED)

# Trainerの準備
`Trainer`とは，PyTorch Lightningにおいて学習・テスト，ログの記録，モデルの保存などを自動で行ってくれるクラスです．
また，GPUの管理も一括で行ってくれるため，GPUを意識せずに学習モデルやデータローダーなどを作成できます（モデルやデータに対してわざわざ`.to(device)`などを記述する必要はありません）．

In [None]:
# CSV logger: logged metrics are written as CSV files under logs/cifar10/ (one sub-directory per run version)
csv_logger = CSVLogger(save_dir='logs', name='cifar10')

In [None]:
# Checkpoint callback: keep only the single checkpoint with the lowest validation loss
checkpoint = ModelCheckpoint(
    dirpath='best_models',                          # directory the checkpoint file is written to
    filename='cifar10-{epoch:02d}-{val_loss:.2f}',  # epoch and val_loss are filled in per save
    monitor='val_loss',                             # presumably logged by ResNet_pl during validation — confirm
    mode='min',                                     # lower val_loss is better
    save_top_k=1,
)

In [None]:
# Early-stopping callback: stop training once val_loss has not improved for `patience` validation checks.
# mode='min' is EarlyStopping's default; it is written out to mirror the checkpoint callback.
early_stopping = EarlyStopping(monitor='val_loss', mode='min', patience=5)

In [None]:
# trainer
# Train on the GPU(s) listed in DEVICES when CUDA is available. Otherwise fall back to a single
# CPU process, so "Run All" still completes (slowly) instead of failing at Trainer construction.
if torch.cuda.is_available():
    ACCELERATOR = 'cuda'
    DEVICES = [0]  # indices of the GPUs to use, given as a list
else:
    ACCELERATOR = 'cpu'
    DEVICES = 1    # the CPU accelerator expects a process count (int), not a list of indices
MAX_EPOCHS = 10

trainer = Trainer(
    accelerator=ACCELERATOR,
    devices=DEVICES,
    max_epochs=MAX_EPOCHS,
    callbacks=[checkpoint, early_stopping],  # best-val_loss checkpointing + early stopping on val_loss
    logger=csv_logger,
)

# 学習

In [None]:
# data module
dataset = CIFAR10_pl(batch_size=512, download=True)

In [None]:
# model
model = ResNet_pl(num_class=10, batch_size=512, lr=0.001)

In [None]:
# Train the model; the data module supplies the train/validation dataloaders.
trainer.fit(model=model, datamodule=dataset)

In [None]:
# Path of the best checkpoint (lowest val_loss) recorded by the ModelCheckpoint callback.
best_model_path = checkpoint.best_model_path
# best_model_path stays an empty string when no checkpoint was ever saved (e.g. training stopped
# before the first validation run); fail here with a clear message rather than at load time.
if not best_model_path:
    raise RuntimeError('No checkpoint was saved during training; cannot load the best model.')
print(best_model_path)

# テスト

In [None]:
# Reload the weights saved in the best checkpoint, then evaluate them on the data module's test split.
model = ResNet_pl.load_from_checkpoint(best_model_path)
trainer.test(model, datamodule=dataset)