# Basic Lightning Module

This chapter implements a basic PyTorch Lightning module. It is based on the Lightning documentation [LIGHTNINGMODULE](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html).

A `LightningModule` organizes your `PyTorch` code into six sections:

* Initialization (`__init__` and `setup()`)
* Train Loop (`training_step()`)
* Validation Loop (`validation_step()`)
* Test Loop (`test_step()`)
* Prediction Loop (`predict_step()`)
* Optimizers and LR Schedulers (`configure_optimizers()`)

A `LightningModule` is a `torch.nn.Module` with added functionality, so it can be used like any other `nn.Module`.


## Starter Example

In [None]:
import lightning as L
import torch
import torch.nn.functional as F
from torchmetrics.functional import accuracy

from lightning.pytorch.demos import Transformer


class LightningBasic(L.LightningModule):
    """Dense MNIST classifier organized as a LightningModule.

    Each input is flattened and passed through a 784 -> 128 -> 10
    multilayer perceptron whose outputs are class logits.

    Args:
        vocab_size: Not used by this model; accepted (and optional) only so
            that existing callers passing it keep working.
    """

    def __init__(self, vocab_size=None):
        super().__init__()
        # Dense model for 28x28 MNIST images. Flatten collapses every
        # dimension after the batch dimension, so inputs may be image tensors
        # or already-flattened 784-element vectors.
        self.model = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(28 * 28, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 10),
        )

    def forward(self, x, y=None):
        """Return class logits of shape (batch, 10) for the inputs ``x``.

        ``y`` is ignored. It is optional so that ``model(x)`` works, while
        existing ``model(x, y)`` calls remain valid.
        """
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one ``(x, y)`` batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # Logs the value at every training step and its epoch average,
        # to both the progress bar and the logger.
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log and return validation loss/accuracy for one batch."""
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {"val_acc": acc, "val_loss": loss}
        self.log_dict(metrics)
        return metrics

    def test_step(self, batch, batch_idx):
        """Log and return test loss/accuracy for one batch."""
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {"test_acc": acc, "test_loss": loss}
        self.log_dict(metrics)
        return metrics

    def _shared_eval_step(self, batch, batch_idx):
        """Return ``(loss, accuracy)`` for one labelled ``(x, y)`` batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # Recent torchmetrics versions require an explicit task; the number
        # of classes is taken from the width of the logits.
        acc = accuracy(logits, y, task="multiclass", num_classes=logits.size(-1))
        return loss, acc

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return class logits for one prediction batch.

        The batch may be an ``(x, y)`` pair or bare inputs ``x``. Labels, if
        present, are not needed for prediction and are ignored.
        """
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self(x)

    def configure_optimizers(self):
        """Optimize all network parameters with Adam at lr=0.02."""
        return torch.optim.Adam(self.model.parameters(), lr=0.02)

In [None]:
import torch
import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms

# Load the MNIST training and test splits as tensors, downloading them into
# the local "MNIST" directory if they are not already present.
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)

# Hold out 20% of the official training split for validation; the other
# 80% stays in the training set.
num_train_total = len(train_set)
train_set_size = int(num_train_total * 0.8)
valid_set_size = num_train_total - train_set_size

# A fixed generator seed keeps the train/validation partition identical
# from run to run.
seed = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(
    train_set,
    [train_set_size, valid_set_size],
    generator=seed,
)

In [None]:
from torch.utils.data import DataLoader

# Batch the examples (DataLoader's default batch_size=1 makes training very
# slow) and shuffle only the training data so each epoch sees a new order;
# the validation order stays fixed.
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=64)
# LightningBasic does not use `vocab_size`; the argument is passed only to
# match its constructor signature.
model = LightningBasic(vocab_size=28)

# Train with both splits. Without max_epochs (or max_steps) Lightning falls
# back to 1000 epochs, so an explicit limit is set.
trainer = L.Trainer(max_epochs=5)
trainer.fit(model, train_loader, valid_loader)

In [None]:

# Fit on the training split (batched and shuffled; the default batch_size=1
# is very slow). Without max_epochs Lightning falls back to 1000 epochs.
dataloader = DataLoader(train_set, batch_size=64, shuffle=True)
trainer = L.Trainer(max_epochs=5)
trainer.fit(model=model, train_dataloaders=dataloader)

# Evaluate on the held-out MNIST test split. The module defines no
# test_dataloader(), so the loader is passed explicitly. Passing
# ckpt_path="best" reloads the best checkpoint saved during the fit above;
# without it, trainer.test(model) would use the model's current weights.
# NOTE(review): with the default checkpoint callback (no monitored metric)
# "best" is the most recently saved checkpoint — confirm this is intended.
test_loader = DataLoader(test_set, batch_size=64)
trainer.test(model, dataloaders=test_loader, ckpt_path="best")

## Additional Methods

### Train Epoch-level Operations

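* If you need to use the outputs of every `training_step()` at the end of an epoch, override the `on_train_epoch_end()` method. In Lightning 2.x this hook no longer receives the step outputs as an argument, so collect them yourself: for example, append each step's output to a list attribute in `training_step()`, process it in `on_train_epoch_end()`, and clear it afterwards.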