# In-class exercise 8: Converting your PyTorch code to PyTorch Lightning

Based on [Lightning in 15 minutes](https://lightning.ai/docs/pytorch/stable/starter/introduction.html) tutorial.

In this tutorial, we will convert the code from the previous tutorial to PyTorch Lightning. This will allow us to reduce the amount of boilerplate code we need to write, and also make it easier to train our model on multiple GPUs or even TPUs.

PyTorch Lightning is a lightweight **wrapper** for **organizing** your PyTorch code. It's **not a high-level framework**, so you still have to write PyTorch code, but it handles a lot of the details for you. It's especially useful for **standardizing** training loops, logging metrics, and saving checkpoints.

First, we install PyTorch Lightning and import the relevant classes and functions.

In [1]:
import copy
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
from typeguard import typechecked

And we report here some code from the previous tutorial for reference purposes.

First: we re-define the model.

In [2]:
@typechecked
class CNN(nn.Module):
    """Convolutional neural network model."""

    def __init__(self, num_layers: int, num_channels: int, num_classes: int) -> None:
        """Constructor method for CNN.

        Args:
            num_layers: the number of layers
            num_channels: the number of channels
            num_classes: the number of classes
        """
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_channels = 1 if i == 0 else num_channels
            out_channels = num_channels if i < num_layers - 1 else num_classes
            self.layers.append(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                )
            )
            self.layers.append(nn.ReLU())
        self.layers.append(nn.AdaptiveAvgPool2d(1))

        self.init_weights()

    def init_weights(self) -> None:
        """Initialize the parameters.

        The weight is initialized using Xavier uniform initialization and the bias is initialized to zero.
        """
        for layer in self.layers:
            if isinstance(layer, nn.Conv2d):
                nn.init.xavier_normal_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: the input tensor

        Returns:
            the logits
        """
        for layer in self.layers:
            x = layer(x)
        return x

Then, we re-define the dataset and dataloaders.

In [3]:
def get_datasets(
    data_dir: str, train_size: int
) -> Tuple[
    torch.utils.data.Dataset, torch.utils.data.Dataset, torch.utils.data.Dataset
]:
    """Get the datasets.

    Args:
        data_dir: the directory to store the data
        train_size: the size of the training set

    Returns:
        the training set, validation set, and test set
    """
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    dev_set = MNIST(data_dir, train=True, download=True, transform=transform)
    train_set, val_set = random_split(dev_set, [train_size, len(dev_set) - train_size])
    test_set = MNIST(data_dir, train=False, download=True, transform=transform)
    return train_set, val_set, test_set

In [4]:
def instantiate_dataloaders(
    train_dataset: torch.utils.data.Dataset,
    val_dataset: torch.utils.data.Dataset,
    test_dataset: torch.utils.data.Dataset,
    batch_size: int,
    num_workers: int,
    seed: int,
) -> Dict[str, torch.utils.data.DataLoader]:
    """Instantiate dataloaders.

    Args:
        train_dataset: the training dataset
        val_dataset: the validation dataset
        test_dataset: the test dataset
        batch_size: the batch size
        num_workers: the number of workers
        seed: the seed

    Returns:
        the dictionary of dataloaders
    """
    dataloaders = {}
    dataloaders["train"] = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=num_workers,
        generator=torch.Generator().manual_seed(seed),
    )
    dataloaders["val"] = DataLoader(
        val_dataset,
        batch_size=2 * batch_size,
        pin_memory=True,
        num_workers=num_workers,
        generator=torch.Generator().manual_seed(seed),
    )
    dataloaders["test"] = DataLoader(
        test_dataset,
        batch_size=2 * batch_size,
        pin_memory=True,
        num_workers=num_workers,
        generator=torch.Generator().manual_seed(seed),
    )
    return dataloaders

Let's see if everything works as expected.

In [5]:
model = CNN(num_layers=3, num_channels=32, num_classes=10)
print(model)
train_dataset, val_dataset, test_dataset = get_datasets(
    data_dir="data", train_size=50000
)
dataloaders = instantiate_dataloaders(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    test_dataset=test_dataset,
    batch_size=64,
    num_workers=0,
    seed=42,
)
print(next(iter(dataloaders["train"]))[0].shape)

CNN(
  (layers): ModuleList(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(32, 10, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): ReLU()
    (6): AdaptiveAvgPool2d(output_size=1)
  )
)
torch.Size([64, 1, 28, 28])


Let's finally re-define the training loop.

In [6]:
@typechecked
def run_epoch(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    criterion: Callable,
    device: torch.device = torch.device("cpu"),
    optimizer: Optional[torch.optim.Optimizer] = None,
    train: bool = False,
) -> Tuple[float, float]:
    """Run one epoch.

    It runs one epoch of training, validation, or test, and returns the loss and accuracy. If training is True, it also updates the parameters.

    Args:
        model: the model (an instance of torch.nn.Module)
        dataloader: the dataloader (an instance of torch.utils.data.DataLoader)
        criterion: a callable that returns the loss given the logits and the labels
        device: the device (cpu or gpu). Defaults to torch.device("cpu").
        optimizer: the optimizer (an instance of torch.optim.Optimizer). Defaults to None.
        train: whether to train the model. Defaults to False.

    Returns:
        the loss and accuracy
    """
    epoch_loss = 0.0
    epoch_acc = 0.0

    # Move model to the device
    model = model.to(device)

    if train:
        model.train()
    else:
        model.eval()

    for x_batch, y_batch in dataloader:
        if train:
            optimizer.zero_grad()

        # Move data to the device
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        with torch.set_grad_enabled(train):
            pred = model(x_batch)
            loss = criterion(pred, y_batch)

            if train:
                loss.backward()
                optimizer.step()

        epoch_loss += loss.item()
        epoch_acc += (pred.argmax(-1) == y_batch).sum().item()

    epoch_loss /= len(dataloader.dataset)
    epoch_acc /= len(dataloader.dataset)

    return epoch_loss, epoch_acc


@typechecked
def fit(
    model: torch.nn.Module,
    dataloaders: Dict[str, torch.utils.data.DataLoader],
    criterion: Callable,
    optimizer: torch.optim.Optimizer,
    num_epochs: int,
    patience: int,
    device: torch.device,
) -> OrderedDict[str, torch.Tensor]:
    """Train the model.

    Args:
        model: the model (an instance of torch.nn.Module)
        dataloaders: the dictionary of dataloaders (an instance of torch.utils.data.DataLoader)
        criterion: a callable that returns the loss given the logits and the labels
        optimizer: the optimizer (an instance of torch.optim.Optimizer)
        num_epochs: the number of epochs
        patience: the patience for early stopping
        device: the device (cpu or gpu)

    Returns:
        the state dict of the best model
    """
    loss_history = {"train": [], "val": []}
    acc_history = {"train": [], "val": []}

    best_val_acc = 0.0
    curr_patience = patience

    for epoch in range(num_epochs):
        # Training
        train_loss, train_acc = run_epoch(
            model, dataloaders["train"], criterion, device, optimizer, train=True
        )

        # Validation
        val_loss, val_acc = run_epoch(
            model, dataloaders["val"], criterion, device, train=False
        )

        loss = {"train": train_loss, "val": val_loss}
        acc = {"train": train_acc, "val": val_acc}
        print_epoch_summary(epoch, num_epochs, loss, acc)

        loss_history["train"].append(train_loss)
        loss_history["val"].append(val_loss)
        acc_history["train"].append(train_acc)
        acc_history["val"].append(val_acc)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            curr_patience = patience
            # Save the best model state dict as a checkpoint
            ckpt = copy.deepcopy(model.state_dict())
            # Save the best model to disk
            torch.save(ckpt, "ckpt.pt")
        else:
            curr_patience -= 1
            if curr_patience == 0:
                print("Early stopping")
                break

    training_history = {"loss": loss_history["train"], "acc": acc_history["train"]}
    validation_history = {"loss": loss_history["val"], "acc": acc_history["val"]}
    plot_curves(training_history, validation_history)

    return ckpt


@typechecked
def print_epoch_summary(
    epoch: int, num_epochs: int, loss: Dict[str, float], acc: Dict[str, float]
) -> None:
    """Print the epoch summary.

    The summary includes the epoch number, the number of epochs, the loss, and the accuracy.

    Args:
        epoch: the epoch number
        num_epochs: the number of epochs
        loss: the loss
        acc: the accuracy
    """
    print(
        f"Epoch {epoch+1:>{len(str(num_epochs))}}/{num_epochs} | "
        f"Train - loss: {loss['train']:.4f}, acc: {acc['train']:.4f} | "
        f"Val - loss: {loss['val']:.4f}, acc: {acc['val']:.4f}"
    )


@typechecked
def plot_curves(
    training_history: Dict[str, list], validation_history: Dict[str, list]
) -> None:
    """Plot the loss and accuracy curves.

    It plots the loss curve on the left and the accuracy curve on the right via matplotlib.

    Args:
        training_history: the training history
        validation_history: the validation history
    """
    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    ax1.plot(training_history["loss"], label="Training")
    ax1.plot(validation_history["loss"], label="Validation")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss")
    ax1.legend()
    ax2.plot(training_history["acc"], label="Training")
    ax2.plot(validation_history["acc"], label="Validation")
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Accuracy")
    ax2.legend()
    plt.show()

## Hyperparameters

In [7]:
# training
EPOCHS = 5
BATCH_SIZE = 64
PATIENCE = 3

# data
NUM_WORKERS = 3

# optimizer
LEARNING_RATE = 0.01
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0001

# model
NUM_LAYERS = 3
NUM_CHANNELS = 32
NUM_CLASSES = 10

# reproducibility
SEED = 42

## Step 1: replace `nn.Module` with `LightningModule`

- 1.1: Model architecture goes in the `__init__` method
- 1.2: Prediction/inference logic goes in the `forward` hook
- 1.3: Optimizers go in the `configure_optimizers` hook
- 1.4: Training logic goes in the `training_step` hook
- 1.5: Validation logic goes in the `validation_step` hook
- 1.6: Test logic goes in the `test_step` hook
- 1.7: Remove any `cuda()` or `to(device)` calls
- 1.8: Instantiate the `LightningModule`

In [8]:
class CNNLit(pl.LightningModule):
    """Convolutional neural network model in PyTorch Lightning."""

    def __init__(
        self,
        num_layers: int,
        num_channels: int,
        num_classes: int,
        optimizer_args: Dict[str, Any],
    ) -> None:
        """Constructor method for CNNLit.

        Args:
            num_layers: the number of layers
            num_channels: the number of channels
            num_classes: the number of classes
        """
        super().__init__()
        # Step 1.1: Define the model architecture
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_channels = 1 if i == 0 else num_channels
            out_channels = num_channels if i < num_layers - 1 else num_classes
            self.layers.append(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                )
            )
            self.layers.append(nn.ReLU())
        self.layers.append(nn.AdaptiveAvgPool2d(1))

        self.init_weights()

        self.optimizer_args = optimizer_args
        self.criterion = nn.CrossEntropyLoss()

    def init_weights(self) -> None:
        """Initialize the parameters.

        The weight is initialized using Xavier uniform initialization and the bias is initialized to zero.
        """
        for layer in self.layers:
            if isinstance(layer, nn.Conv2d):
                nn.init.xavier_normal_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: the input tensor

        Returns:
            the logits
        """
        # Step 1.2: Define the forward pass (inference/prediction logic)
        for layer in self.layers:
            x = layer(x)
        return x.view(x.size(0), -1)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Configure the optimizer.

        Returns:
            the optimizer
        """
        # Step 1.3: Define the optimizer
        optimizer = torch.optim.SGD(self.parameters(), **self.optimizer_args)
        return optimizer

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Training step.

        This method is called for every batch.

        Args:
            batch: the batch
            batch_idx: the batch index

        Returns:
            the loss
        """
        # Step 1.4: Define the training logic
        x_batch, y_batch = batch
        pred = self(x_batch)
        loss = self.criterion(pred, y_batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        """Validation step.

        This method is called for every batch.

        Args:
            batch: the batch
            batch_idx: the batch index
        """
        # Step 1.5: Define the validation logic
        x_batch, y_batch = batch
        pred = self(x_batch)
        loss = self.criterion(pred, y_batch)
        self.log("val_loss", loss)

    def test_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        """Test step.

        This method is called for every batch.

        Args:
            batch: the batch
            batch_idx: the batch index
        """
        # Step 1.6: Define the test logic
        x_batch, y_batch = batch
        pred = self(x_batch)
        loss = self.criterion(pred, y_batch)
        self.log("test_loss", loss)


# Step 1.7: Remember to remove any call to `cuda()` or `to(device)` in your code

In [9]:
# Step 1.8: Create an instance of the `LitModule` class
lit_module = CNNLit(
    num_layers=NUM_LAYERS,
    num_channels=NUM_CHANNELS,
    num_classes=NUM_CLASSES,
    optimizer_args={
        "lr": LEARNING_RATE,
        "momentum": MOMENTUM,
        "weight_decay": WEIGHT_DECAY,
    },
)

In [10]:
lit_module

CNNLit(
  (layers): ModuleList(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(32, 10, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): ReLU()
    (6): AdaptiveAvgPool2d(output_size=1)
  )
  (criterion): CrossEntropyLoss()
)

## Step 2: replace the training loop with a `Trainer` instance

### Step 2.1: Training loop

Once the `LightningModule` is defined, we can train it using a `Trainer`.

- 1: Instantiate the `Trainer`
- 2: Call `trainer.fit(model, train_dataloader, val_dataloader)` to train the model

In [11]:
trainer = pl.Trainer(
    max_epochs=1,
    accelerator="gpu" if torch.cuda.is_available() else None,
)
trainer.fit(lit_module, dataloaders["train"], dataloaders["val"])

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name      | Type             | Params
-----------------------------------------------
0 | layers    | ModuleList       | 12.5 K
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
12.5 K    Trainable params
0         Non-trainable params
12.5 K    Total params
0.050     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


### Step 2.2: Test

In [12]:
trainer.test(lit_module, dataloaders["test"])

  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.6059901118278503}]

## Step 3: replace the dataset and dataloaders with `LightningDataModule`

- 1: Move the dataset and dataloaders into a `LightningDataModule`
- 2: Instantiate the `LightningDataModule`
- 3: Pass the `LightningDataModule` to the `Trainer`

A `LightningDataModule` encapsulates the five steps involved in data processing in PyTorch:
- 2.1: Download / tokenize / process.
- 2.2: Clean and (maybe) save to disk.
- 2.3: Load inside Dataset.
- 2.4: Apply transforms (rotate, tokenize, etc…).
- 2.5: Wrap inside a DataLoader.

In [13]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_dir: str = "./data",
        batch_size: int = 32,
        train_size: int = 50000,
        num_workers: int = 0,
    ) -> None:
        """Constructor method for MNISTDataModule.

        Args:
            data_dir: the directory to store the data. Defaults to "./data".
            batch_size: the batch size. Defaults to 32.
            train_size: the size of the training set. Defaults to 50000.
            num_workers: the number of workers. Defaults to 0.
        """
        super().__init__()

        self.data_dir = data_dir
        self.batch_size = batch_size
        self.train_size = train_size
        self.val_size = 60_000 - train_size
        self.num_workers = num_workers

        self.transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

    def prepare_data(self) -> None:
        # Download
        torchvision.datasets.MNIST(self.data_dir, train=True, download=True)
        torchvision.datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str) -> None:
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            mnist_dev = torchvision.datasets.MNIST(
                self.data_dir, train=True, transform=self.transform
            )
            self.mnist_train, self.mnist_val = torch.utils.data.random_split(
                mnist_dev,
                [self.train_size, self.val_size],
                generator=torch.Generator().manual_seed(42),
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = torchvision.datasets.MNIST(
                self.data_dir, train=False, transform=self.transform
            )

        if stage == "predict":
            # The difference btw "predict" and "test" is that the former does not have labels
            pass  # We will not use this in this exercise

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        return torch.utils.data.DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            generator=torch.Generator().manual_seed(42),
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.mnist_val,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            generator=torch.Generator().manual_seed(42),
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.mnist_test,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            generator=torch.Generator().manual_seed(42),
        )

In [14]:
datamodule = MNISTDataModule(
    data_dir="data", batch_size=BATCH_SIZE, train_size=50000, num_workers=NUM_WORKERS
)

In [15]:
model = CNNLit(
    num_layers=NUM_LAYERS,
    num_channels=NUM_CHANNELS,
    num_classes=NUM_CLASSES,
    optimizer_args={
        "lr": LEARNING_RATE,
        "momentum": MOMENTUM,
        "weight_decay": WEIGHT_DECAY,
    },
)

In [16]:
trainer = pl.Trainer(
    max_epochs=1,
    accelerator="gpu" if torch.cuda.is_available() else None,
)
trainer.fit(model, datamodule)
trainer.test(model, datamodule)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name      | Type             | Params
-----------------------------------------------
0 | layers    | ModuleList       | 12.5 K
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
12.5 K    Trainable params
0         Non-trainable params
12.5 K    Total params
0.050     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.6921523213386536}]

## Logging

So far, we were only able to log the loss. However, Lightning allows us to have much more control over logging. For example, we can log the loss and accuracy after each epoch, and also log the loss and accuracy after each batch. We can even log images, audio, text, and arbitrary objects.

We will be using [TorchMetrics](https://torchmetrics.readthedocs.io/en/latest/index.html) to compute metrics. TorchMetrics is a collection of metrics for PyTorch. It allows us to avoid writing boilerplate code for computing metrics like accuracy, precision, recall, etc.

In [17]:
from torchmetrics import Accuracy


class CNNLit(pl.LightningModule):
    """Convolutional neural network model in PyTorch Lightning."""

    def __init__(
        self,
        num_layers: int,
        num_channels: int,
        num_classes: int,
        optimizer_args: Dict[str, Any],
    ) -> None:
        """Constructor method for CNNLit.

        Args:
            num_layers: the number of layers
            num_channels: the number of channels
            num_classes: the number of classes
        """
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_channels = 1 if i == 0 else num_channels
            out_channels = num_channels if i < num_layers - 1 else num_classes
            self.layers.append(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                )
            )
            self.layers.append(nn.ReLU())
        self.layers.append(nn.AdaptiveAvgPool2d(1))

        self.init_weights()

        self.optimizer_args = optimizer_args
        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)

    def init_weights(self) -> None:
        """Initialize the parameters.

        The weight is initialized using Xavier uniform initialization and the bias is initialized to zero.
        """
        for layer in self.layers:
            if isinstance(layer, nn.Conv2d):
                nn.init.xavier_normal_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: the input tensor

        Returns:
            the logits
        """
        for layer in self.layers:
            x = layer(x)
        return x.view(x.size(0), -1)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Configure the optimizer.

        Returns:
            the optimizer
        """
        optimizer = torch.optim.SGD(self.parameters(), **self.optimizer_args)
        return optimizer

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Training step.

        This method is called for every batch.

        Args:
            batch: the batch
            batch_idx: the batch index

        Returns:
            the loss
        """
        x_batch, y_batch = batch
        pred = self(x_batch)
        loss = self.criterion(pred, y_batch)
        acc = self.accuracy(pred, y_batch)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        """Validation step.

        This method is called for every batch.

        Args:
            batch: the batch
            batch_idx: the batch index
        """
        x_batch, y_batch = batch
        pred = self(x_batch)
        loss = self.criterion(pred, y_batch)
        acc = self.accuracy(pred, y_batch)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True, on_step=False, on_epoch=True)

    def test_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        """Test step.

        This method is called for every batch.

        Args:
            batch: the batch
            batch_idx: the batch index
        """
        x_batch, y_batch = batch
        pred = self(x_batch)
        loss = self.criterion(pred, y_batch)
        acc = self.accuracy(pred, y_batch)
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", acc, prog_bar=True, on_step=False, on_epoch=True)

In [18]:
pl.seed_everything(SEED)  # reproducibility

model = CNNLit(
    num_layers=NUM_LAYERS,
    num_channels=NUM_CHANNELS,
    num_classes=NUM_CLASSES,
    optimizer_args={
        "lr": LEARNING_RATE,
        "momentum": MOMENTUM,
        "weight_decay": WEIGHT_DECAY,
    },
)
datamodule = MNISTDataModule(
    data_dir="data", batch_size=BATCH_SIZE, train_size=50000, num_workers=NUM_WORKERS
)

trainer = pl.Trainer(
    max_epochs=1,
    accelerator="gpu" if torch.cuda.is_available() else None,
)

trainer.fit(model, datamodule)
trainer.test(model, datamodule)

Global seed set to 42
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type               | Params
-------------------------------------------------
0 | layers    | ModuleList         | 12.5 K
1 | criterion | CrossEntropyLoss   | 0     
2 | accuracy  | MulticlassAccuracy | 0     
-------------------------------------------------
12.5 K    Trainable params
0         Non-trainable params
12.5 K    Total params
0.050     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.604171872138977, 'test_acc': 0.8240000009536743}]

## Visualizing metrics

We can visualize the metrics logged by Lightning using [TensorBoard](https://www.tensorflow.org/tensorboard) or [Weights & Biases](https://wandb.ai/site). We will be using TensorBoard in this tutorial. To install TensorBoard, you can use `pip install tensorboard`. To start TensorBoard, you can use the following command:

```bash
tensorboard --logdir lightning_logs/
```

In [19]:
from torchmetrics import Accuracy


class CNNLit(pl.LightningModule):
    """Convolutional neural network model in PyTorch Lightning."""

    def __init__(
        self,
        num_layers: int,
        num_channels: int,
        num_classes: int,
        optimizer_args: Dict[str, Any],
    ) -> None:
        """Constructor method for CNNLit.

        Args:
            num_layers: the number of layers
            num_channels: the number of channels
            num_classes: the number of classes
        """
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_channels = 1 if i == 0 else num_channels
            out_channels = num_channels if i < num_layers - 1 else num_classes
            self.layers.append(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                )
            )
            self.layers.append(nn.ReLU())
        self.layers.append(nn.AdaptiveAvgPool2d(1))

        self.init_weights()

        self.optimizer_args = optimizer_args
        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)

    def init_weights(self) -> None:
        """Initialize the parameters.

        The weight is initialized using Xavier uniform initialization and the bias is initialized to zero.
        """
        for layer in self.layers:
            if isinstance(layer, nn.Conv2d):
                nn.init.xavier_normal_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: the input tensor

        Returns:
            the logits
        """
        for layer in self.layers:
            x = layer(x)
        return x.view(x.size(0), -1)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Configure the optimizer.

        Returns:
            the optimizer
        """
        optimizer = torch.optim.SGD(self.parameters(), **self.optimizer_args)
        return optimizer

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Training step.

        This method is called for every batch.

        Args:
            batch: the batch
            batch_idx: the batch index

        Returns:
            the loss
        """
        x_batch, y_batch = batch
        pred = self(x_batch)
        loss = self.criterion(pred, y_batch)
        acc = self.accuracy(pred, y_batch)
        self.log("train_loss", loss, prog_bar=True)
        self.log(
            "train_acc", acc, prog_bar=True, on_step=False, on_epoch=True, logger=True
        )
        return loss

    def validation_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        """Validation step.

        This method is called for every batch.

        Args:
            batch: the batch
            batch_idx: the batch index
        """
        x_batch, y_batch = batch
        pred = self(x_batch)
        loss = self.criterion(pred, y_batch)
        acc = self.accuracy(pred, y_batch)
        self.log("val_loss", loss, prog_bar=True)
        self.log(
            "val_acc", acc, prog_bar=True, on_step=False, on_epoch=True, logger=True
        )

    def test_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        """Test step.

        This method is called for every batch.

        Args:
            batch: the batch
            batch_idx: the batch index
        """
        x_batch, y_batch = batch
        pred = self(x_batch)
        loss = self.criterion(pred, y_batch)
        acc = self.accuracy(pred, y_batch)
        self.log("test_loss", loss, prog_bar=True)
        self.log(
            "test_acc", acc, prog_bar=True, on_step=False, on_epoch=True, logger=True
        )

In [20]:
pl.seed_everything(SEED)  # reproducibility

model = CNNLit(
    num_layers=NUM_LAYERS,
    num_channels=NUM_CHANNELS,
    num_classes=NUM_CLASSES,
    optimizer_args={
        "lr": LEARNING_RATE,
        "momentum": MOMENTUM,
        "weight_decay": WEIGHT_DECAY,
    },
)
datamodule = MNISTDataModule(
    data_dir="data", batch_size=BATCH_SIZE, train_size=50000, num_workers=NUM_WORKERS
)

logger = pl.loggers.TensorBoardLogger("lightning_logs/", name="mnist")

trainer = pl.Trainer(
    max_epochs=20,
    accelerator="gpu" if torch.cuda.is_available() else None,
    logger=logger,
)

trainer.fit(model, datamodule)
trainer.test(model, datamodule)

Global seed set to 42
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: lightning_logs/mnist

  | Name      | Type               | Params
-------------------------------------------------
0 | layers    | ModuleList         | 12.5 K
1 | criterion | CrossEntropyLoss   | 0     
2 | accuracy  | MulticlassAccuracy | 0     
-------------------------------------------------
12.5 K    Trainable params
0         Non-trainable params
12.5 K    Total params
0.050     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.12229259312152863, 'test_acc': 0.9617999792098999}]

## Checkpointing

Lightning automatically saves checkpoints of your model at the end of every epoch. If training is interrupted, you can resume from the last saved checkpoint.

In [21]:
# Retrieve the best model
best_model_path = trainer.checkpoint_callback.best_model_path
ckpt = torch.load(best_model_path)
model.load_state_dict(ckpt["state_dict"])

trainer.test(model, datamodule)

Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.12229259312152863, 'test_acc': 0.9617999792098999}]