# CIS6930 Week 12: PyTorch Lightning Quick Start

---

Preparation: go to `Runtime > Change runtime type` and choose `GPU` as the hardware accelerator.



In [None]:
gpu_info = !nvidia-smi -L
gpu_info = "\n".join(gpu_info)
if gpu_info.find("failed") >= 0:
    print("Not connected to a GPU")
else:
    print(gpu_info)

In [None]:
!pip install pytorch-lightning 

## Lightning Modules 

A `LightningModule` organizes your PyTorch code into five sections:
+ Computations (`__init__`)
+ Train loop (`training_step`)
+ Validation loop (`validation_step`)
+ Test loop (`test_step`)
+ Optimizers (`configure_optimizers`)

See [the official `LightningModule` documentation](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html) for details.

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl

class LitAutoEncoder(pl.LightningModule):
    """Fully connected MNIST autoencoder: 784 -> 64 -> 3 -> 64 -> 784.

    The encoder compresses a flattened 28x28 image into a 3-dimensional
    embedding; the decoder reconstructs the flattened image from it.
    Training and validation both minimise the mean-squared reconstruction
    error, logged as ``train_loss`` / ``val_loss``.

    Args:
        learning_rate: Adam learning rate (default ``1e-3``, the value that
            used to be hard-coded). It is kept on ``self.learning_rate`` and in
            ``self.hparams`` so Lightning's learning-rate finder can locate
            and overwrite it.
    """

    def __init__(self, learning_rate=1e-3):
        super().__init__()
        # Record `learning_rate` in self.hparams (and in checkpoints / logger hparams).
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.encoder = nn.Sequential(
                nn.Linear(28 * 28, 64),
                nn.ReLU(),
                nn.Linear(64, 3))
        self.decoder = nn.Sequential(
                nn.Linear(3, 64),
                nn.ReLU(),
                nn.Linear(64, 28 * 28))

    def forward(self, x):
        """Return the 3-d embedding of ``x``.

        ``x`` must already be flattened: its last dimension has to be 28*28,
        the input size of the encoder's first ``nn.Linear``.
        """
        embedding = self.encoder(x)
        return embedding

    def configure_optimizers(self):
        """Adam over all parameters, using the (possibly tuned) ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def _reconstruction_loss(self, batch):
        """MSE between a batch of images and their reconstructions (labels are ignored)."""
        x, _ = batch
        x = x.view(x.size(0), -1)  # flatten each image: (B, 28*28)
        x_hat = self.decoder(self.encoder(x))
        return F.mse_loss(x_hat, x)

    def training_step(self, train_batch, batch_idx):
        """Compute, log and return the training reconstruction loss."""
        loss = self._reconstruction_loss(train_batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute and log the validation reconstruction loss (monitored by early stopping)."""
        loss = self._reconstruction_loss(val_batch)
        self.log('val_loss', loss)

## Run experiments with `Trainer`

PyTorch Lightning offers a highly customizable `Trainer`.

[The official tutorial videos](https://www.pytorchlightning.ai/tutorials) explain the Trainer's features. They take about 30 minutes in total and are highly recommended.


In [None]:
# Launch tensorboard
# Launch TensorBoard inline. It watches ./logs, the same root directory the
# TensorBoardLogger in the training cell writes to ("logs"), so that run's
# logged metrics show up here.
%load_ext tensorboard
%tensorboard --logdir logs

In [None]:
# Reproducibility: seed the global RNGs (weight init, DataLoader shuffling) and give
# the train/val split its own seeded generator so the split is identical on every run.
SEED = 42
pl.seed_everything(SEED)

# data
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(
    dataset, [55000, 5000], generator=torch.Generator().manual_seed(SEED))

# Reshuffle the training set every epoch; the validation order does not matter.
train_loader = DataLoader(mnist_train, batch_size=32, shuffle=True)
val_loader = DataLoader(mnist_val, batch_size=32)

# model
model = LitAutoEncoder()

# early stopping: stop once val_loss has not improved for 3 consecutive validation checks
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
early_stopping = EarlyStopping(
    monitor="val_loss",
    patience=3,
    strict=False,  # do not raise if "val_loss" is missing from the logged metrics
    verbose=False,
    mode="min")

# TensorBoard Logger: writes under logs/ae_mnist (the %tensorboard cell watches ./logs)
from pytorch_lightning.loggers import TensorBoardLogger
logger = TensorBoardLogger("logs", name="ae_mnist")

# training
# Use the GPU when one is attached; otherwise fall back to CPU instead of crashing
# (the old `gpus=1` required a GPU). 16-bit mixed precision is only requested on GPU.
use_gpu = torch.cuda.is_available()

# NOTE: `auto_lr_find=True` / `auto_scale_batch_size=True` were removed from this call.
# They only take effect through `trainer.tune(...)`, never through `trainer.fit(...)`.
# The batch-size finder also cannot resize DataLoaders passed directly to `fit`.
# Both flags were dropped from the Trainer in Lightning 2.0 (see the Lightning Tuner docs).
trainer = pl.Trainer(callbacks=[early_stopping],
                     max_epochs=50,
                     accelerator="gpu" if use_gpu else "cpu",
                     devices=1,
                     precision=16 if use_gpu else 32,
                     limit_train_batches=0.5,  # each epoch uses half of the training batches
                     logger=logger)

trainer.fit(model, train_loader, val_loader)