Why PyTorch Lightning?
PyTorch Lightning is a lightweight PyTorch wrapper that simplifies the process of training and organizing complex deep learning models. 
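* Promotes a modular, organized code structure by separating concerns (model, data, training logic) into distinct components
* Streamlines the training process by removing much of the boilerplate (the training loop, gradient zeroing, device placement, logging)
* Makes experiments easier to reproduce (e.g., pl.seed_everything) and to scale (e.g., from a CPU to multiple GPUs by changing Trainer arguments)
* Integrates with popular libraries and tools in the research ecosystem (e.g., TensorBoard and torchmetrics, both used below)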

Big picture:
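* Load the dataset (train/validation/test splits) and wrap it in DataLoaders
* Build a model
* Define the loss function and optimizer
* Define the training step: the model produces predictions -> compute the loss (prediction vs. label) -> backpropagate and update the weights
* Define the validation step, evaluated on held-out data
* Run the trainer, then test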
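For comparison, the steps above can be sketched in plain PyTorch; this is roughly the loop that Lightning's Trainer runs for you. It is a minimal sketch, not the notebook's code: random tensors stand in for MNIST so it runs without a download, run_epoch is a hypothetical helper, and lr=0.1 is an assumption chosen for the random data.

```python
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# 1. Dataset: random stand-in for MNIST (images of 1 x 28 x 28, labels 0-9)
x = torch.rand(512, 1, 28, 28)
y = torch.randint(0, 10, (512,))
train_dl = DataLoader(TensorDataset(x[:448], y[:448]), batch_size=64, shuffle=True)
valid_dl = DataLoader(TensorDataset(x[448:], y[448:]), batch_size=64)

# 2. Model: flatten the 28*28 pixels, then one linear layer -> 10 class scores
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

# 3. Loss function & optimizer
loss_fn = F.cross_entropy
opt = optim.SGD(model.parameters(), lr=0.1)

def run_epoch(dl, train):
    """One pass over dl; updates weights only when train=True. Returns mean loss."""
    model.train(train)
    total_loss, n = 0.0, 0
    with torch.set_grad_enabled(train):      # no gradient tracking during validation
        for xb, yb in dl:
            loss = loss_fn(model(xb), yb)    # prediction vs. label
            if train:
                loss.backward()              # backprop
                opt.step()                   # update weights
                opt.zero_grad()              # reset gradients for the next batch
            total_loss += loss.item() * len(xb)
            n += len(xb)
    return total_loss / n

# 4-6. Train, checking the validation loss after every epoch
history = []
for epoch in range(2):
    train_loss = run_epoch(train_dl, train=True)
    valid_loss = run_epoch(valid_dl, train=False)
    history.append((train_loss, valid_loss))
    print(f"epoch {epoch}: train_loss={train_loss:.3f} valid_loss={valid_loss:.3f}")
```

Everything inside run_epoch and the epoch loop is what a LightningModule's training_step/validation_step plus the Trainer replace.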

In [48]:
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl

In [49]:
# download MNIST; train=False loads the 10,000-image test split, used here for validation:
train_ds = MNIST(root="data-pl", train=True, download=True, transform=ToTensor())
valid_ds = MNIST(root="data-pl", train=False, download=True, transform=ToTensor())

# DataLoaders (dl) yield mini-batches of (images, labels):
bs = 64
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=bs, shuffle=False)  # no need to shuffle evaluation data
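To see what the DataLoaders hand to the model, this small sketch inspects one batch. As an assumption to avoid the download, random tensors shaped like ToTensor() output (1 x 28 x 28, values in [0, 1]) stand in for MNIST; the last line applies the same ceil(N / batch_size) arithmetic to MNIST's 60,000 training and 10,000 test images.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for MNIST: 200 "images" shaped like ToTensor() output (1 x 28 x 28)
images = torch.rand(200, 1, 28, 28)
labels = torch.randint(0, 10, (200,))
dl = DataLoader(TensorDataset(images, labels), batch_size=64, shuffle=False)

xb, yb = next(iter(dl))
print(xb.shape, yb.shape)   # torch.Size([64, 1, 28, 28]) torch.Size([64])

# Number of batches = ceil(N / batch_size); the last batch holds the remainder
print(len(dl), math.ceil(200 / 64))       # 4 4
print([len(b[0]) for b in dl])            # [64, 64, 64, 8]

# Same arithmetic for the real MNIST split sizes (60,000 train / 10,000 test)
print(math.ceil(60_000 / 64), math.ceil(10_000 / 64))   # 938 157
```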


In [50]:
import torchmetrics

# build the model:
class MNIST_Model(pl.LightningModule):  # a LightningModule is an nn.Module with extra training hooks
  def __init__(self, lr=0.5):
    super().__init__()
    self.lr = lr
    self.lin = nn.Linear(784, 10)  # 784 = 28 * 28 input pixels -> 10 digit classes

    # metrics (torchmetrics >= 0.11 requires the task and the number of classes):
    self.train_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=10)
    self.valid_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=10)

  
  # forward: input -> prediction (raw class scores, i.e. logits)
  def forward(self, xb):
    """
    Flatten each image into a vector before the linear layer.
    flatten(1, -1) merges every dimension from dim 1 onward and keeps the batch dimension, e.g.:

    xb = torch.tensor([
      [[1, 2, 3, 4],
       [5, 6, 7, 8],
       [9, 10, 11, 12]],

      [[13, 14, 15, 16],
       [17, 18, 19, 20],
       [21, 22, 23, 24]]
    ])                      # shape (2, 3, 4)

    xb.flatten(1, -1)       # shape (2, 12)
    tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12],
            [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]])

    For an MNIST batch: (64, 1, 28, 28) -> (64, 784).
    """
    xb = xb.flatten(1, -1)
    return self.lin(xb)
  
  # train:
  def training_step(self, batch, batch_idx):
    loss = self.shared_step(batch, train=True)
    return loss  # Lightning runs backward(), optimizer.step() and zero_grad() on this loss

  # validate:
  def validation_step(self, batch, batch_idx):
    self.shared_step(batch, train=False)

  # steps shared by training & validation:
  def shared_step(self, batch, train):
    xb, yb = batch
    pred = self(xb)
    loss = F.cross_entropy(pred, yb)

    # logging:
    if train:
      self.train_accuracy(pred.softmax(dim=-1), yb)
      self.log("train_accuracy", self.train_accuracy, on_step=True, on_epoch=False, prog_bar=True)  # shown in the progress bar
    else:  # validation
      self.valid_accuracy(pred.softmax(dim=-1), yb)
      self.log("valid_accuracy", self.valid_accuracy, on_step=False, on_epoch=True, prog_bar=True)  # accumulated over the whole validation epoch

    return loss

  # (optional) def test_step(self, batch, batch_idx): ...  required by trainer.test()

  # optimizer:
  def configure_optimizers(self):
    return optim.SGD(self.parameters(), lr=self.lr)
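The core of shared_step can be checked outside Lightning. This sketch, assuming a random stand-in batch, traces the shapes through forward, confirms that F.cross_entropy equals the mean negative log-softmax probability of the true class, and computes plain accuracy (the fraction of correct argmax predictions) by hand, which is what a micro-averaged multiclass accuracy reports.

```python
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(0)
lin = nn.Linear(784, 10)                    # same layer shape as MNIST_Model.lin
xb = torch.rand(64, 1, 28, 28)              # random stand-in for one MNIST batch
yb = torch.randint(0, 10, (64,))

flat = xb.flatten(1, -1)                    # (64, 1, 28, 28) -> (64, 784)
logits = lin(flat)                          # (64, 10): one raw score per class

# cross_entropy applies log_softmax internally, then averages the
# negative log-probability of the true class over the batch
loss = F.cross_entropy(logits, yb)
manual = -logits.log_softmax(dim=-1)[torch.arange(64), yb].mean()
print(torch.allclose(loss, manual))         # True

# Plain accuracy: fraction of samples whose highest-scoring class equals the label.
# softmax is monotonic, so argmax over probabilities matches argmax over logits.
acc = (logits.argmax(dim=-1) == yb).float().mean()
print(acc.item())
```

This is also why passing pred.softmax(dim=-1) or the raw logits to an accuracy metric ranks the classes the same way.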

In [None]:
from pytorch_lightning.loggers import TensorBoardLogger

# logger:
tb_logger = TensorBoardLogger("tb_logs")

In [None]:
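# init a model:
mnist_model = MNIST_Model()

# init trainer:
trainer = pl.Trainer(max_epochs=2, logger=tb_logger)  # tb = TensorBoard

# train the model; passing valid_dl makes Lightning run validation_step after each epoch:
trainer.fit(mnist_model, train_dl, valid_dl)

# # run test (optional; requires a test_step in the model):
# trainer.test(mnist_model, dataloaders=valid_dl)  # valid_dl already holds the MNIST test split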