# Example 09: PyTorch-Lightning

## 事前準備

In [1]:
import torch

# Pick the compute device: CUDA when a GPU is visible to torch, CPU otherwise.
# NOTE: move a tensor or module onto it with `x = x.to(device)`.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [5]:
import optuna
import torch
import torchinfo
import pytorch_lightning as pl
from torch import nn
from torch import optim
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt

In [6]:
class Net(nn.Module):
    """Small LeNet-style CNN producing class logits for 3x32x32 images.

    Args:
        n_classes: Number of output logits produced by ``fc2``.
        dropout_prob: Dropout probability applied after ``fc1``.
        activation_func: Activation module applied after ``conv1``, ``conv2``
            and ``fc1``. When omitted, a fresh ``nn.ReLU()`` is created for
            each instance.
    """

    def __init__(self, n_classes: int,
                 dropout_prob: float = 0.5,
                 activation_func: "nn.Module | None" = None):
        super().__init__()
        if activation_func is None:
            # Build the default per instance; a default of `nn.ReLU()` in the
            # signature would be created once and shared by every Net.
            activation_func = nn.ReLU()
        self.conv1 = nn.Conv2d(3, 8, 5)          # in_channels, out_channels, kernel_size
        self.active = activation_func
        self.pool = nn.MaxPool2d(2, 2)           # window size, stride
        self.conv2 = nn.Conv2d(8, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 256)    # 16 channels x 5x5 map for 32x32 input
        self.dropout = nn.Dropout(dropout_prob)  # dropout probability
        self.fc2 = nn.Linear(256, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (batch, n_classes).

        An unbatched (C, H, W) input is treated as a batch of one, giving a
        (1, n_classes) result.
        """
        if x.dim() == 3:
            x = x.unsqueeze(0)
        x = self.pool(self.active(self.conv1(x)))  # 32x32 -> 28x28 -> 14x14
        x = self.pool(self.active(self.conv2(x)))  # 14x14 -> 10x10 -> 5x5
        # Flatten per sample (keep the batch dim). With view(-1, 400) a wrongly
        # sized input could be silently regrouped into a different number of
        # rows; here fc1 raises on a size mismatch instead.
        x = torch.flatten(x, 1)
        x = self.dropout(self.active(self.fc1(x)))
        return self.fc2(x)

In [13]:
class DictLogger(pl.loggers.Logger):
    """PyTorch Lightning logger that keeps every logged metrics dict in memory.

    Each ``log_metrics`` call appends the received dict to ``self.metrics``
    so the history can be inspected after training.

    Args:
        version: Value reported by the ``version`` property.
    """

    def __init__(self, version):
        super().__init__()
        self.metrics = []  # one dict per log_metrics call, in call order
        self._version = version

    @property
    def name(self):
        # Abstract in the Lightning Logger base class; must be provided.
        return "DictLogger"

    @property
    def version(self):
        return self._version

    @property
    def experiment(self):
        # Abstract in older Lightning versions; there is no backing experiment
        # object here because metrics are stored in `self.metrics`.
        return None

    def log_hyperparams(self, params, *args, **kwargs):
        # Abstract in the Lightning Logger base class; hyperparameters are
        # deliberately not recorded by this logger.
        pass

    def log_metrics(self, metrics, step=None):
        # Parameter names follow the base-class signature because Lightning
        # may call this with keyword arguments (metrics=..., step=...).
        self.metrics.append(metrics)

In [None]:
class LightningNet(pl.LightningModule):
    """LightningModule that trains ``Net`` on CIFAR-10 with Optuna-sampled hyperparameters.

    Validation accuracy is reported via ``self.log`` under the key
    ``"accuracy"``, so any attached logger receives it in ``log_metrics``.

    Args:
        trial: Optuna trial used to sample hyperparameters.
        data_dir: Root directory where CIFAR-10 is stored/downloaded.
        batch_size: Mini-batch size for the train and validation loaders.
    """

    def __init__(self, trial, data_dir: str = "./data", batch_size: int = 128):
        super().__init__()
        self.trial = trial  # for optuna
        self.data_dir = data_dir
        self.batch_size = batch_size
        # NOTE(review): placeholder search space -- TODO confirm ranges.
        dropout_prob = trial.suggest_float("dropout_prob", 0.2, 0.5)
        self.lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
        # CIFAR-10 has 10 classes and 3x32x32 images, matching Net's layers.
        self.model = Net(10, dropout_prob=dropout_prob)

    def forward(self, data):
        return self.model(data)

    def training_step(self, batch, batch_nb):
        data, target = batch
        output = self.forward(data)
        # Net returns raw logits; cross_entropy applies log-softmax + NLL.
        # (nll_loss alone would expect log-probabilities.)
        loss = nn.functional.cross_entropy(output, target)
        return {'loss': loss}

    def validation_step(self, batch, batch_nb):
        data, target = batch
        output = self.forward(data)
        pred = output.argmax(dim=1)
        accuracy = (pred == target).float().mean()
        # Lightning averages this over the epoch (weighted by batch_size) and
        # forwards the result to the attached logger(s).
        self.log("accuracy", accuracy, on_step=False, on_epoch=True,
                 batch_size=data.size(0))
        return {'validation_accuracy': accuracy.item()}

    def validation_end(self, outputs):
        """Average per-batch ``validation_step`` outputs (legacy hook).

        Current Lightning versions do not call this hook; accuracy is logged
        from ``validation_step`` instead. Kept for interface compatibility.
        """
        accuracy = sum(x['validation_accuracy'] for x in outputs) / len(outputs)
        return {'log': {'accuracy': accuracy}}

    def configure_optimizers(self):
        # NOTE(review): Adam replaces the undefined `Get_Optimizer` helper --
        # TODO confirm the intended optimizer choice.
        return optim.Adam(self.model.parameters(), lr=self.lr)

    def train_dataloader(self):
        dataset = CIFAR10(self.data_dir, train=True, download=True,
                          transform=transforms.ToTensor())
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        dataset = CIFAR10(self.data_dir, train=False, download=True,
                          transform=transforms.ToTensor())
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=False)