# Example 09: PyTorch-Lightning

## 事前準備

In [1]:
import torch

# GPUが使えるか確認してデバイスを設定
# NOTE: `x = x.to(device) ` とすることで対象のデバイスに切り替え可能
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [2]:
# https://qiita.com/north_redwing/items/1e153139125d37829d2d
from typing import Callable

import random
import numpy as np
import torch
import pytorch_lightning as pl

def fixing_seed(seed: int=42) -> tuple[torch.Generator, Callable]:
    # NOTE: 戻り値はDataLoaderのシード固定に使用する。　(ex) loader = DataLoader(..., worker_init_fn=seed_worker, generator=generator)
    #       また、plのTrainer(deterministic=True)とする必要がある。
    # Python のシード固定
    random.seed(seed)
    # Numpy のシード固定
    np.random.seed(seed)
    # PyTorch のシード固定
    torch.manual_seed(seed)
    # CUDA の再現性確保の設定
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # PyTorch Lightning のシード固定
    pl.seed_everything(seed, workers=True)

    # --------------------------------------------------
    # DataLoader 各ワーカー用
    # --------------------------------------------------
    # DataLoader 各ワーカー用の初期化関数
    def seed_worker(worker_id: int) -> None:
        # 各 worker で Python / NumPy も固定
        wseed = torch.initial_seed() % 2**32
        np.random.seed(wseed)
        random.seed(wseed)

    # DataLoader の乱数源
    generator = torch.Generator()
    generator.manual_seed(seed)

    return generator, seed_worker

In [3]:
generator, seed_worker_fn = fixing_seed()

Seed set to 42


In [4]:
NUM_WORKERS = 9

## CNN

In [5]:
from typing import Any

import numpy as np
import optuna
import torch
import torchinfo
import pytorch_lightning as pl
from torch import nn
from torch import optim
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, SubsetRandomSampler
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt

### 自動チューニング設定(optuna)

今回は以下のパラメータをチューニング対象とする。

- 最適化アルゴリズム
- 活性化関数
- ドロップアウト率

#### Optimizer

In [6]:
def get_adam_optimizer(trial: optuna.trial.Trial, model: nn.Module) -> optim.Optimizer:
    lr = trial.suggest_float('adam_lr', 1e-5, 1e-1, log=True)
    weight_decay = trial.suggest_float('adam_weight_decay', 1e-10, 1e-3)
    optimizer = optim.Adam(model.parameters(),
                           lr=lr,
                           weight_decay=weight_decay)
    return optimizer

def get_momentum_sgd_optimizer(trial: optuna.trial.Trial, model: nn.Module) -> optim.Optimizer:
    lr = trial.suggest_float('momentum_sgd_lr', 1e-5, 1e-1, log=True)
    weight_decay = trial.suggest_float('momentum_sgd_weight_decay', 1e-10, 1e-3, log=True)
    optimizer = optim.SGD(model.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=weight_decay)
    return optimizer

def get_rms_prob_optimizer(trial: optuna.trial.Trial, model: nn.Module) -> optim.Optimizer:
    lr = trial.suggest_float('rms_prob_lr', 1e-5, 1e-1, log=True)
    optimizer = optim.RMSprop(model.parameters(), lr=lr)
    return optimizer

In [7]:
def get_optimizer(trial: optuna.trial.Trial, model: nn.Module) -> optim.Optimizer:
    optimizer_names = ['Adam', 'MomentumSGD', 'rmsprop']
    optimizer_name = trial.suggest_categorical('optimizer', optimizer_names)
    
    if optimizer_name == 'Adam': 
        optimizer = get_adam_optimizer(trial, model)
    elif optimizer_name == 'MomentumSGD':
        optimizer = get_momentum_sgd_optimizer(trial, model)
    else:
        optimizer = get_rms_prob_optimizer(trial, model)
    
    return optimizer

#### 活性化関数

In [8]:
def get_activation(trial: optuna.trial.Trial) -> nn.Module:
    activation_names = ['ReLU', 'Tanh']
    activation_name = trial.suggest_categorical('activation', activation_names)
    
    if activation_name == 'ReLU':
        activation = nn.ReLU()
    else:
        activation = nn.Tanh()
    return activation

#### ハイパーパラメータ

In [9]:
def get_hyperparam(trial: optuna.trial.Trial) -> dict[str, Any]:
    return {
        'dropout_prob': trial.suggest_float('dropout_prob', 0.2, 0.8, step=0.1),
    }

### モデル構築

In [10]:
class Net(nn.Module):
    def __init__(self, n_classes: int,
                       dropout_prob: float = 0.5,
                       activation_func: nn.Module = nn.ReLU()):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 8, 5)         # 入力チャネル、出力チャネル、フィルタ数
        self.active = activation_func
        self.pool = nn.MaxPool2d(2, 2)          # 領域のサイズ、領域の間隔
        self.conv2 = nn.Conv2d(8, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 256)
        self.dropout = nn.Dropout(dropout_prob)          # ドロップアウト率
        self.fc2 = nn.Linear(256, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.active(self.conv1(x))
        x = self.pool(x)
        x = self.active(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 16 * 5 * 5)
        x = self.active(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

### PyTorch Lightning モデル用意

#### ロガー

In [11]:
class DictLogger(pl.loggers.Logger):
    """PyTorch Lightning `dict` logger."""
 
    def __init__(self, version):
        super(DictLogger, self).__init__()
        self.metrics = []
        self._version = version
 
    def log_metrics(self, metric, step_num=None):
        self.metrics.append(metric)
 
    @property
    def version(self):
        return self._version

#### データモジュール

In [12]:
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, root_dir: str="../cache/data", batch_size: int=128, valid_size: float=0.2, num_workers: int=0):
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.valid_size = valid_size
        self.num_workers = num_workers

        self.train_transform = transforms.Compose([
            transforms.RandomAffine((-30, 30), scale=(0.8, 1.2)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
        ])
        self.valid_transform = transforms.Compose([
            transforms.RandomAffine((-30, 30), scale=(0.8, 1.2)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
        ])
        self.test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
        ])

    def _prepare_train_valid_index(self, data_length: int, valid_size: float = 0.2) -> tuple[SubsetRandomSampler, SubsetRandomSampler]:
        indices = list(range(data_length))
        split = int(np.floor(valid_size * data_length))
        np.random.shuffle(indices)
        
        train_idx = indices[split:]
        valid_idx = indices[:split]

        train_sampler = SubsetRandomSampler(train_idx)
        validation_sampler = SubsetRandomSampler(valid_idx)
        return (train_sampler, validation_sampler) 
    
    def prepare_data(self) -> None:
        train_dataset = CIFAR10(self.root_dir, train=True, download=True)
        CIFAR10(self.root_dir, train=False, download=True)
        self.train_sampler, self.valid_sampler = self._prepare_train_valid_index(len(train_dataset), self.valid_size)

    def setup(self, stage: str) -> None:
        if stage == "fit":
            self.train_dataset = CIFAR10(self.root_dir, train=True, transform=self.train_transform)
            self.valid_dataset = CIFAR10(self.root_dir, train=True, transform=self.valid_transform)
        elif stage == "test":
            self.test_dataset = CIFAR10(self.root_dir, train=False, transform=self.test_transform)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, sampler=self.train_sampler)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.valid_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, sampler=self.valid_sampler)
    
    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, worker_init_fn=seed_worker_fn, generator=generator)

In [13]:
def check_datamodule():
    dm = CIFAR10DataModule()
    dm.prepare_data()
    for stage in ("fit", "test"):
        dm.setup(stage)
        print(f"///// {stage} /////")
        print(dm)
        print()

check_datamodule()

///// fit /////
{Train dataloader: size=50000}
{Validation dataloader: size=50000}
{Test dataloader: None}
{Predict dataloader: None}

///// test /////
{Train dataloader: size=50000}
{Validation dataloader: size=50000}
{Test dataloader: size=10000}
{Predict dataloader: None}



#### ネットワークモジュール

In [14]:
def calc_accuracy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    pred = output.argmax(dim=1, keepdim=True)
    correct = pred.eq(target.view_as(pred)).sum().item()
    accuracy = correct / output.size(0)
    return accuracy


class LightningNet(pl.LightningModule):
 
    def __init__(self, trial: optuna.trial.Trial, n_class: int, batch_size: int = 128):
        super(LightningNet, self).__init__()
        self.trial = trial
        self.batch_size = batch_size

        hyperparam = get_hyperparam(self.trial)
        activation = get_activation(self.trial)        
        self.model = Net(n_class, hyperparam['dropout_prob'], activation)
        self.criterion = nn.CrossEntropyLoss()
    
    def forward(self, data: torch.Tensor) -> torch.Tensor:
        return self.model(data)
 
    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> dict[str, torch.Tensor]:
        data, target = batch
        output = self.forward(data)
        loss = criterion(output, target)
        accuracy = calc_accuracy(output, target)
        return {
            'train_loss': loss,
            'train_accuracy': accuracy
        }
 
    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> dict[str, torch.Tensor]:
        data, target = batch
        output = self.forward(data)
        loss = criterion(output, target) 
        accuracy = calc_accuracy(output, target)
        return {
            'valid_loss': loss,
            'valid_accuracy': accuracy
        }

    """
    def validation_end(self, outputs) -> dict[str, Any]:
        accuracy = sum(x['valid_accuracy'] for x in outputs) / len(outputs)
        # Pass the accuracy to the `DictLogger` via the `'log'` key.
        return {'log': {'accuracy': accuracy}}
    """
    
    def configure_optimizers(self):
        # Generate the optimizers.
        return get_optimizer(self.trial, self.model)

### TODO

In [15]:
#data_module = CIFAR10DataModule()
#data_module.prepare_data()
#data_module.setup(stage="fit")

#logger = DictLogger(trial.number)

#trainer = pl.Trainer(
#    gpus=1,
#    max_epochs=30,
#    deterministic=True)
trainer = pl.Trainer(
    max_epochs=30,
    deterministic=True)

#model = LightningNet(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
#trainer.fit(model, data_module)

#data_module.setup(stage="test")
#trainer.test(datamodule=data_module)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [16]:
### TEST

In [17]:
class LightningNet_DUMMY(pl.LightningModule):
 
    def __init__(self, n_class: int = 10, batch_size: int = 128):
        super(LightningNet_DUMMY, self).__init__()
        self.batch_size = batch_size
        self.model = Net(n_class)
        self.criterion = nn.CrossEntropyLoss()
    
    def forward(self, data: torch.Tensor) -> torch.Tensor:
        return self.model(data)
 
    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> dict[str, torch.Tensor]:
        data, target = batch
        output = self.forward(data)
        loss = self.criterion(output, target)
        accuracy = calc_accuracy(output, target)
        return {
            'loss': loss,
            'accuracy': accuracy
        }
 
    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> dict[str, torch.Tensor]:
        data, target = batch
        output = self.forward(data)
        loss = self.criterion(output, target) 
        accuracy = calc_accuracy(output, target)
        return {
            'valid_loss': loss,
            'valid_accuracy': accuracy
        }

    # TODO: バッチごとに必要？
    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> dict[str, torch.Tensor]:
        data, target = batch
        output = self.forward(data)
        loss = self.criterion(output, target) 
        accuracy = calc_accuracy(output, target)
        return {
            'test_loss': loss,
            'test_accuracy': accuracy
        }

    """
    def validation_end(self, outputs) -> dict[str, Any]:
        accuracy = sum(x['valid_accuracy'] for x in outputs) / len(outputs)
        # Pass the accuracy to the `DictLogger` via the `'log'` key.
        return {'log': {'accuracy': accuracy}}
    """
    
    def configure_optimizers(self):
        # Generate the optimizers.
        return optim.Adam(self.model.parameters())

In [18]:
dm = CIFAR10DataModule(num_workers=NUM_WORKERS)
dm.prepare_data()
dm.setup(stage="fit")

model = LightningNet_DUMMY()
trainer.fit(model, dm)
trainer.validate(datamodule=dm)

dm.setup(stage="test")
trainer.test(datamodule=dm)

You are using a CUDA device ('NVIDIA GeForce RTX 4060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | model     | Net              | 109 K  | train
1 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
109 K     Trainable params
0         Non-trainable params
109 K     Total params
0.436     Total estimated model params size (MB)
9         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.
Restoring states from the checkpoint path at /example-pytorch/examples/lightning_logs/version_4/checkpoints/epoch=29-step=9390.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /example-pytorch/examples/lightning_logs/version_4/checkpoints/epoch=29-step=9390.ckpt


Validation: |          | 0/? [00:00<?, ?it/s]

Restoring states from the checkpoint path at /example-pytorch/examples/lightning_logs/version_4/checkpoints/epoch=29-step=9390.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /example-pytorch/examples/lightning_logs/version_4/checkpoints/epoch=29-step=9390.ckpt
/opt/venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:484: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Testing: |          | 0/? [00:00<?, ?it/s]

[{}]