# PyTorch Lightning with CIFAR10

Easier to managing experiments, and hardware settings

https://www.pytorchlightning.ai/

In [12]:
import os

import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import pytorch_lightning as pl
import torchmetrics

In [20]:
class CNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Construct Layers
        self.convs = nn.Sequential(
            # Conv1
            nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(8),
            #(B, 3, 32, 32) > (B, 8, 32, 32)
            nn.ReLU(),
            # (B, 8, 32, 32)
            nn.MaxPool2d(kernel_size=2, stride=2),
            # (B, 8, 16, 16)
            # Conv2
            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(16),
            # (B, 8, 16, 16) > (B, 16, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # (B, 16, 14, 14) > (B, 16, 7, 7)
            # Conv3
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=2, stride=1, padding=0),
            nn.BatchNorm2d(32),
            # (7+0-2)/1 + 1
            # (B, 16, 7, 7) > (B, 32, 6, 6)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # (B, 32, 3, 3)
        )
        self.fc = nn.Sequential(
            nn.Linear(in_features=32*3*3, out_features=100),  # tensor size ?
            nn.ReLU(), #(B, 100)
            nn.Linear(100, 100), #(B, 100) > (B, 100) weight(100, 100)
            nn.ReLU(),  
            nn.Linear(100, 10), 
        )
        
    def forward(self, x):
        # forward propagation
        # Conv Layers
        x = self.convs(x)
        # resize to (batch_size, 32*3*3)
        x = x.view(x.size(0), -1)
        # FC Layers
        x = self.fc(x)
        return x

In [21]:
class PLNetwork(pl.LightningModule):
    def __init__(self, **kwargs) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.model = CNNModel()
        self.loss_function = nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.Accuracy(num_classes=10)
        
    def forward(self, inputs, targets):
        outputs = self.model(inputs)
        loss = self.loss_function(outputs, targets)  # when calculate loss must need gold answer
        preds = outputs.argmax(1)
        self.accuracy(preds.view(-1), targets.view(-1))
        return loss
    
    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        loss = self(inputs, targets)
        self.log("tr_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=False)
        return {"loss": loss}

    def train_epoch_end(self, outputs):
        loss = 0
        for out in outputs:
            loss += out["loss"].detach()
        loss = loss / len(outputs)
        
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("train_accuracy", self.accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
    
    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        loss = self(inputs, targets)
        
        return {"loss": loss}

    def validation_epoch_end(self, outputs):
        loss = 0
        for out in outputs:
            loss += out["loss"].detach()
        loss = loss / len(outputs)
        
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("val_accuracy", self.accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
    
    def create_dataloader(self, mode):  
        self.cifar10_mean = (0.4914, 0.4822, 0.4465)
        self.cifar10_std = (0.2470, 0.2435, 0.2616)        
        num_workers = 0 if os.name == "nt" else self.hparams.num_workers
        if mode == "train":
            shuffle = True
            batch_size = self.hparams.train_batch_size
            img_transformer = transforms.Compose([
                transforms.RandomHorizontalFlip(),  # random(with probability of 0.5) horizontal flip
                transforms.ToTensor(),  # change PIL Ojbect into Tensor (32, 32, 3) > (3, 32, 32)
                transforms.Normalize(mean=self.cifar10_mean, std=self.cifar10_std)  # normalizing with mean & std
            ])
            mode_train = True
        else:
            shuffle = False
            batch_size = self.hparams.eval_batch_size
            img_transformer = transforms.Compose([
                transforms.ToTensor(),  # change PIL Ojbect into Tensor (32, 32, 3) > (3, 32, 32)
                transforms.Normalize(mean=self.cifar10_mean, std=self.cifar10_std)  # normalizing with mean & std
            ])
            mode_train = False
            
        dataset = datasets.CIFAR10(
            root=self.hparams.data_path,
            train=mode_train, 
            transform=img_transformer,
            download=True
        )
        data_loader = torch.utils.data.DataLoader(
            batch_size=batch_size,
            dataset=dataset,
            shuffle=shuffle,
            num_workers=num_workers,
        )
        return data_loader

    def train_dataloader(self):
        return self.create_dataloader(mode="train")

    def val_dataloader(self):
        return self.create_dataloader(mode="eval")
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        return optimizer

In [22]:
args_dict = dict(
    data_path = './data/cifar10/',
    # Dataloader
    train_batch_size = 256,
    eval_batch_size = 256,
    num_workers = 0,
    # Optimizer
    lr = 0.001,
    weight_decay = 1e-5,
)
# Checkpoint callback for model saving
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    filename="{epoch:02d}-{val_accuracy:.4f}",
    monitor="val_accuracy",
    save_top_k=3,
    verbose=True,
    mode="max",
)
# Tensor Board Logger
tb_logger = pl.loggers.TensorBoardLogger(save_dir="./logs", default_hp_metric=False)
# Early Stop Callback
earlystop_callback = pl.callbacks.EarlyStopping(
    "val_accuracy",
    min_delta=0.01,
    patience=3,
    verbose=True,
    mode="max"
)
# Seed all include numpy, torch and torch.cuda
pl.seed_everything(884)

model = PLNetwork(**args_dict)

trainer = pl.Trainer(
    callbacks=[checkpoint_callback, earlystop_callback],
    max_epochs=20,
    deterministic=torch.cuda.is_available(),
    gpus = 1 if torch.cuda.is_available() else None,
    num_sanity_val_steps=0,
    logger=tb_logger,
    log_every_n_steps=20,
    profiler="simple"
)

trainer.fit(model)

Global seed set to 884
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | model         | CNNModel         | 43.6 K
1 | loss_function | CrossEntropyLoss | 0     
2 | accuracy      | Accuracy         | 0     
---------------------------------------------------
43.6 K    Trainable params
0         Non-trainable params
43.6 K    Total params
0.174     Total estimated model params size (MB)


Files already downloaded and verified
Files already downloaded and verified


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Metric val_accuracy improved. New best score: 0.452
Epoch 0, global step 195: val_accuracy reached 0.45237 (best 0.45237), saving model to "./logs\default\version_3\checkpoints\epoch=00-val_accuracy=0.4524.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.113 >= min_delta = 0.01. New best score: 0.565
Epoch 1, global step 391: val_accuracy reached 0.56492 (best 0.56492), saving model to "./logs\default\version_3\checkpoints\epoch=01-val_accuracy=0.5649.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.040 >= min_delta = 0.01. New best score: 0.605
Epoch 2, global step 587: val_accuracy reached 0.60455 (best 0.60455), saving model to "./logs\default\version_3\checkpoints\epoch=02-val_accuracy=0.6046.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.026 >= min_delta = 0.01. New best score: 0.631
Epoch 3, global step 783: val_accuracy reached 0.63095 (best 0.63095), saving model to "./logs\default\version_3\checkpoints\epoch=03-val_accuracy=0.6309.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.022 >= min_delta = 0.01. New best score: 0.653
Epoch 4, global step 979: val_accuracy reached 0.65252 (best 0.65252), saving model to "./logs\default\version_3\checkpoints\epoch=04-val_accuracy=0.6525.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.015 >= min_delta = 0.01. New best score: 0.667
Epoch 5, global step 1175: val_accuracy reached 0.66743 (best 0.66743), saving model to "./logs\default\version_3\checkpoints\epoch=05-val_accuracy=0.6674.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.013 >= min_delta = 0.01. New best score: 0.680
Epoch 6, global step 1371: val_accuracy reached 0.68028 (best 0.68028), saving model to "./logs\default\version_3\checkpoints\epoch=06-val_accuracy=0.6803.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.012 >= min_delta = 0.01. New best score: 0.692
Epoch 7, global step 1567: val_accuracy reached 0.69190 (best 0.69190), saving model to "./logs\default\version_3\checkpoints\epoch=07-val_accuracy=0.6919.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.011 >= min_delta = 0.01. New best score: 0.703
Epoch 8, global step 1763: val_accuracy reached 0.70285 (best 0.70285), saving model to "./logs\default\version_3\checkpoints\epoch=08-val_accuracy=0.7028.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 9, global step 1959: val_accuracy reached 0.71030 (best 0.71030), saving model to "./logs\default\version_3\checkpoints\epoch=09-val_accuracy=0.7103.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 10, global step 2155: val_accuracy reached 0.71200 (best 0.71200), saving model to "./logs\default\version_3\checkpoints\epoch=10-val_accuracy=0.7120.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.018 >= min_delta = 0.01. New best score: 0.721
Epoch 11, global step 2351: val_accuracy reached 0.72133 (best 0.72133), saving model to "./logs\default\version_3\checkpoints\epoch=11-val_accuracy=0.7213.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 12, global step 2547: val_accuracy reached 0.72277 (best 0.72277), saving model to "./logs\default\version_3\checkpoints\epoch=12-val_accuracy=0.7228.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.012 >= min_delta = 0.01. New best score: 0.733
Epoch 13, global step 2743: val_accuracy reached 0.73348 (best 0.73348), saving model to "./logs\default\version_3\checkpoints\epoch=13-val_accuracy=0.7335.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 14, global step 2939: val_accuracy reached 0.73830 (best 0.73830), saving model to "./logs\default\version_3\checkpoints\epoch=14-val_accuracy=0.7383.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 15, global step 3135: val_accuracy reached 0.74120 (best 0.74120), saving model to "./logs\default\version_3\checkpoints\epoch=15-val_accuracy=0.7412.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Monitored metric val_accuracy did not improve in the last 3 records. Best score: 0.733. Signaling Trainer to stop.
Epoch 16, global step 3331: val_accuracy reached 0.74345 (best 0.74345), saving model to "./logs\default\version_3\checkpoints\epoch=16-val_accuracy=0.7434.ckpt" as top 3
FIT Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
----------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  244.41         	|  100 %          	|
----------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                 	|  14.285         	|17             	|  242.84         	|  99.361         	|
get_train_batch                    	|  0.049501       	|3332           	|  164.94

See TensorBoard after training

```
$ tensorboard --logdir ./logs/ --port 6006
```

visit `http://localhost:6006/` in Chrome