The torch community has built several libraries and APIs on top of PyTorch with various additional features. Some popular ones include fastai, Catalyst, PyTorch Lightning, and PyTorch Ignite. We will explore PyTorch Lightning in this notebook, which removes the manual aspects of backprop, mixed-precision, multi-GPU support, etc.

In [1]:
import pytorch_lightning as pl

import torch
import torch.nn as nn

# Auto installed with Lightning
from torchmetrics import Accuracy

from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST

from torchvision import transforms

In [2]:
class MLP(pl.LightningModule):
    def __init__(self, image_shape=(1, 28, 28), hidden_units=(32, 16)):
        super().__init__()

        self.train_acc = Accuracy(task='multiclass', num_classes=10)
        self.valid_acc = Accuracy(task='multiclass', num_classes=10)
        self.test_acc = Accuracy(task='multiclass', num_classes=10)

        input_size = image_shape[0] * image_shape[1] * image_shape[2]
        all_layers = [nn.Flatten()]

        for hidden in hidden_units:
            layer = nn.Linear(input_size, hidden)
            all_layers.append(layer)
            all_layers.append(nn.ReLU())
            input_size = hidden
        
        all_layers.append(nn.Linear(hidden_units[-1], 10))
        self.model = nn.Sequential(*all_layers)
    
    def forward(self, x):
        x = self.model(x)
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch

        # self here refers to the MLP instance, which runs the model on the input
        loss = nn.functional.cross_entropy(self(x), y)

        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        self.train_acc.update(preds, y)
        
        # log stores the loss for analysis later
        # We don't log the accuracy until the end of an epoch
        self.log("train_loss", loss, prog_bar=True)

        return loss
    
    def on_epoch_end(self, outs):
        self.log("train_acc", self.train_acc.compute())
    
    def validation_step(self, batch, batch_idx):
        x, y = batch

        # self here refers to the MLP instance, which runs the model on the input
        loss = nn.functional.cross_entropy(self(x), y)

        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        self.valid_acc.update(preds, y)
        
        self.log("valid_loss", loss, prog_bar=True)
        self.log("valid_acc", self.valid_acc.compute(), prog_bar=True)

        return loss
    
    def test_step(self, batch, batch_idx):
        x, y = batch

        # self here refers to the MLP instance, which runs the model on the input
        loss = nn.functional.cross_entropy(self(x), y)

        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        self.test_acc.update(preds, y)
        
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_acc.compute(), prog_bar=True)

        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [3]:
# Loading the data:
# Option 1: Make the dataset part of the model (above)
# Option 2: Set up the usual DataLoaders and feed them into the Lightning's fit function (need to create a Trainer instance)
# Option
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_path='./'):
        super().__init__()
        self.data_path = data_path
        self.transform = transforms.Compose([transforms.ToTensor()])

    def prepare_data(self):
        MNIST(root=self.data_path, download=False)
    
    def setup(self, stage=None):
        # stage can take 'fit', 'validate', 'test', 'predict'
        mnist_all = MNIST(root=self.data_path, train=True, transform=self.transform, download=False)

        self.train, self.val = random_split(mnist_all, [55000, 5000], generator=torch.Generator().manual_seed(42))

        self.test = MNIST(root=self.data_path, train=False, transform=self.transform, download=False)
    
    def train_dataloader(self):
        return DataLoader(self.train, batch_size=64, num_workers=4)
    
    def val_dataloader(self):
        return DataLoader(self.val, batch_size=64, num_workers=4)
    
    def test_dataloader(self):
        return DataLoader(self.test, batch_size=64, num_workers=4)

torch.manual_seed(42)
mnist_dm = MNISTDataModule()

In [4]:
mnistclassifier = MLP()

if torch.cuda.is_available(): # if you GPUs are available
    trainer = pl.Trainer(max_epochs=10, accelerator='gpu')
else:
    trainer = pl.Trainer(max_epochs=10)

trainer.fit(model=mnistclassifier, datamodule=mnist_dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | train_acc | MulticlassAccuracy | 0      | train
1 | valid_acc | MulticlassAccuracy | 0      | train
2 | test_acc  | MulticlassAccuracy | 0      | train
3 | model     | Sequential         | 25.8 K | train
---------------------------------------------------------
25.8 K    Trainable params
0         Non-trainable params
25.8 K    Total params
0.103     Total esti

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\sadit\miniconda3\envs\dataexercises\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:420: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


                                                                           

c:\Users\sadit\miniconda3\envs\dataexercises\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:420: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 9: 100%|██████████| 860/860 [00:18<00:00, 45.74it/s, v_num=2, train_loss=0.0178, valid_loss=0.152, valid_acc=0.940] 

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 860/860 [00:18<00:00, 45.74it/s, v_num=2, train_loss=0.0178, valid_loss=0.152, valid_acc=0.940]


In [6]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 3736), started 0:00:35 ago. (Use '!kill 3736' to kill it.)

In [12]:
# We can resume training from a checkpoint in this model folder
if torch.cuda.is_available(): # if you have GPUs
    trainer = pl.Trainer(max_epochs=15, accelerator='gpu')
else:
    trainer = pl.Trainer(max_epochs=15)

trainer.fit(model=mnistclassifier, datamodule=mnist_dm, ckpt_path='./lightning_logs/version_2/checkpoints/epoch=9-step=8600.ckpt')

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Restoring states from the checkpoint path at ./lightning_logs/version_2/checkpoints/epoch=9-step=8600.ckpt
c:\Users\sadit\miniconda3\envs\dataexercises\Lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:362: The dirpath has changed from 'd:\\Primary Storage\\Programming\\Machine_Learning_Sebastian_Raschka\\Neural_Networks\\pytroch_fundamentals\\lightning_logs\\version_2\\checkpoints' to 'd:\\Primary Storage\\Programming\\Machine_Learning_Sebastian_Raschka\\Neural_Networks\\pytroch_fundamentals\\lightning_logs\\version_5\\checkpoints', therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | train_acc | MulticlassAccuracy | 0      | train
1 | valid_acc | MulticlassAccuracy | 0      | train
2 | test_acc  | Multicl

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\sadit\miniconda3\envs\dataexercises\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:420: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


                                                                            

c:\Users\sadit\miniconda3\envs\dataexercises\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:420: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 14: 100%|██████████| 860/860 [00:18<00:00, 46.11it/s, v_num=5, train_loss=0.0129, valid_loss=0.146, valid_acc=0.946] 

`Trainer.fit` stopped: `max_epochs=15` reached.


Epoch 14: 100%|██████████| 860/860 [00:18<00:00, 46.07it/s, v_num=5, train_loss=0.0129, valid_loss=0.146, valid_acc=0.946]


In [13]:
# LEt's check if the additional training was worth it
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 3736), started 0:13:56 ago. (Use '!kill 3736' to kill it.)

In [14]:
# The gain in an additional 5 epochs was under 1% - whether this is worth it is completely up to the individual and the case at hand
trainer.test(model=mnistclassifier, datamodule=mnist_dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\sadit\miniconda3\envs\dataexercises\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:420: Consider setting `persistent_workers=True` in 'test_dataloader' to speed up the dataloader worker initialization.


Testing DataLoader 0: 100%|██████████| 157/157 [00:00<00:00, 198.41it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9600574970245361
        test_loss           0.12282264232635498
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.12282264232635498, 'test_acc': 0.9600574970245361}]