https://www.geeksforgeeks.org/deep-learning/training-neural-networks-using-pytorch-lightning/

PyTorch Lightning is a library that provides a high-level interface for PyTorch. Problem with PyTorch is that every time you start a project you have to rewrite those training and testing loop. PyTorch Lightning fixes the problem by not only reducing boilerplate code but also providing added functionality that might come handy while training your neural networks. One of the things I love about Lightning is that the code is very organized and reusable, and not only that but it reduces the training and testing loop while retain the flexibility that PyTorch is known for. And once you learn how to use it you'll see how similar the code is to that of PyTorch.

In [1]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from torch import nn
import torch.nn.functional as F
from torch.optim import SGD

import lightning.pytorch as lpt

In [2]:
mnist_path = "../assets"

In [3]:
transform = transforms.Compose([transforms.ToTensor()])

train = datasets.MNIST(mnist_path, train=True, download=True, transform=transform)
test = datasets.MNIST(mnist_path, train=False, download=True, transform=transform)

trainloader = DataLoader(train, batch_size=32, shuffle=True)
testloader = DataLoader(test, batch_size=32, shuffle=True)

In [4]:
class Data(lpt.LightningDataModule):
    def prepare_data(self):
        transform = transforms.Compose([transforms.ToTensor()])

        self.train_data = datasets.MNIST(
            mnist_path, train=True, download=True, transform=transform
        )
        self.test_data = datasets.MNIST(
            mnist_path, train=False, download=True, transform=transform
        )

    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=32, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.test_data, batch_size=32, shuffle=True)

In [5]:
class model(lpt.LightningModule):
    def __init__(self):
        super(model, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.out = nn.Linear(128, 10)
        self.lr = 0.01
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        batch_size, _, _, _ = x.size()
        x = x.view(batch_size, -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.out(x)

    def configure_optimizers(self):
        return SGD(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.loss(logits, y)
        return loss

    def validation_step(self, valid_batch, batch_idx):
        x, y = valid_batch
        logits = self.forward(x)
        loss = self.loss(logits, y)

In [6]:
# Create Model Object
clf = model()
# Create Data Module Object
mnist = Data()
# Create Trainer Object
trainer = lpt.Trainer(accelerator="gpu", max_epochs=5)
trainer.fit(clf, mnist)

ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
/home/volody/code/study-py/ts-pytorch/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA GeForce RTX 5060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precis

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/volody/code/study-py/ts-pytorch/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:485: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/home/volody/code/study-py/ts-pytorch/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/volody/code/study-py/ts-pytorch/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.
