# PyTorch Lightning Toy Example

Redo `mnist.ipynb` using PyTorch Lightning. Original notebook was pure pytorch with manual loops and so on.

Again, instrument with both TensorBoard and Weights & Biases.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from collections import OrderedDict

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor

from lightning.pytorch.utilities.model_summary import ModelSummary

from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.loggers import TensorBoardLogger

### Lightning DataModule

In [20]:
class MNISTDataModule(L.LightningDataModule):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = Path('.') / 'data' / 'MNIST'
        self.transform = ToTensor()

    def prepare_data(self):  # this step done on single (main) CPU thread
        datasets.MNIST(str(self.data_dir), train=True, download=True, transform=self.transform)
        datasets.MNIST(str(self.data_dir), train=False, download=True, transform=self.transform)

    def setup(self, stage):  # this step run on every CPU
        if stage == "fit":
            train_data = datasets.MNIST(str(p), train=True, download=False, transform=ToTensor())
            train_data, val_data = random_split(train_data, lengths=[0.9, 0.1])
            self.train_data = train_data
            self.val_data = val_data
        if stage == "test":
            test_data = datasets.MNIST(str(p), train=False, download=False, transform=ToTensor())
            self.test_data = test_data

    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, drop_last=True)

    def val_dataloader(self):
        return DataLoader(self.val_data, batch_size=self.batch_size, drop_last=True)

### PyTorch nn.Module

In [21]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Sequential accepts OrderedDict to give layers meaningful names
        self.cnn_stack = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=1, out_channels=8, kernel_size=(3, 3), stride=2, padding='valid')),
            ('relu1', nn.ReLU()),
            ('conv2', nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), padding='same')),
            ('relu2', nn.ReLU()),
            ('conv3', nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), stride=2, padding='valid')),
            ('relu3', nn.ReLU()),
            ('avgpool', nn.AdaptiveAvgPool2d(1)),
        ]))       
        self.fc_stack = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(32, 32)),
            ('relu1', nn.ReLU()),
            ('fc2', nn.Linear(32, 10)),
        ]))

    def forward(self, x):
        x = self.cnn_stack(x)
        x = torch.squeeze(x)
        logits = self.fc_stack(x)

        return logits

### LightningModule

In [22]:
class LitCNN(L.LightningModule):
    def __init__(self, model):
        super().__init__()   
        self.model = model
        self.example_input_array = torch.Tensor(32, 1, 28, 28)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)  # calls forward
        loss = F.cross_entropy(logits, y)
        self.log("train/loss", loss)
        if self.global_step % 100 == 0:
            tb_logger = self.logger.experiment
            num_correct = (logits.detach().argmax(dim=1) == y).sum().type(torch.float).item()
            self.log("train/acc", num_correct/len(y))
            for n, t in self.named_parameters():
                tb_logger.add_histogram("param/" + n, t.detach(), global_step=self.global_step)
                if t.grad is not None:
                    tb_logger.add_histogram("grad/" + n, t.grad.detach(), global_step=self.global_step)      
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        val_loss = F.cross_entropy(logits, y)
        self.log("val_loss", val_loss, on_epoch=True)
        num_correct = (logits.detach().argmax(dim=1) == y).sum().type(torch.float).item()
        self.log("val_acc", num_correct / len(y))
        return val_loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        return optimizer        

In [23]:
cnn = LitCNN(CNN())
mnist_data = MNISTDataModule(batch_size=64)
# print(ModelSummary(cnn, max_depth=-1))
tb_logger = TensorBoardLogger("tb_logs", log_graph=True)
# wb_logger = WandbLogger(log_model="all")
# wb_logger.watch(cnn, log="all")
trainer = L.Trainer(max_epochs=5, enable_progress_bar=False, logger=tb_logger, val_check_interval=0.2)
trainer.fit(cnn, mnist_data)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


NameError: name 'p' is not defined

In [6]:
help(trainer.fit)

Help on method fit in module lightning.pytorch.trainer.trainer:

fit(model: 'pl.LightningModule', train_dataloaders: Union[Any, lightning.pytorch.core.datamodule.LightningDataModule, NoneType] = None, val_dataloaders: Optional[Any] = None, datamodule: Optional[lightning.pytorch.core.datamodule.LightningDataModule] = None, ckpt_path: Optional[str] = None) -> None method of lightning.pytorch.trainer.trainer.Trainer instance
    Runs the full optimization routine.
    
    Args:
        model: Model to fit.
    
        train_dataloaders: An iterable or collection of iterables specifying training samples.
            Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
            the :class:`~lightning.pytorch.core.hooks.DataHooks.train_dataloader` hook.
    
        val_dataloaders: An iterable or collection of iterables specifying validation samples.
    
        datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that 

In [13]:
help(L.LightningDataModule)

Help on class LightningDataModule in module lightning.pytorch.core.datamodule:

class LightningDataModule(lightning.pytorch.core.hooks.DataHooks, lightning.pytorch.core.mixins.hparams_mixin.HyperparametersMixin)
 |  LightningDataModule() -> None
 |  
 |  A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is
 |  consistent data splits, data preparation and transforms across models.
 |  
 |  Example::
 |  
 |      import lightning as L
 |      import torch.utils.data as data
 |      from lightning.pytorch.demos.boring_classes import RandomDataset
 |  
 |      class MyDataModule(L.LightningDataModule):
 |          def prepare_data(self):
 |              # download, IO, etc. Useful with shared filesystems
 |              # only called on 1 GPU/TPU in distributed
 |              ...
 |  
 |          def setup(self, stage):
 |              # make assignments here (val/train/test split)
 |              # called on every process in DDP