## Notebook: Comparison of Native PyTorch Training and PyTorch Lightning

This notebook compares how training is implemented in native PyTorch and in PyTorch Lightning. It aims to show what PyTorch Lightning does under the hood and how it simplifies the training process. The notebook contains the following sections:
- Generate Data: Generate a simple dataset for training
- Native PyTorch Training: Implement training in Native PyTorch
- PyTorch Lightning Training: Implement training in PyTorch Lightning

```python
import lightning
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch

import demo.lightning.models
import demo.torch.models
from demo.torch.utils import get_optimizer, get_scheduler, train_regression_model

# Allow faster, slightly less precise float32 matrix multiplications (e.g. TF32 on supported GPUs)
torch.set_float32_matmul_precision("high")
# Seed the RNGs for reproducibility
torch.manual_seed(42)
np.random.seed(42)
```

## Generate Data: Generate a simple dataset for training

In this section, we generate a simple dataset for training: 500 samples, each with one feature and one target, drawn from the linear function y = x with added Gaussian noise. Since the objective is to demonstrate the difference between native PyTorch and PyTorch Lightning, a simple dataset is sufficient and keeps the focus on the training code.
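Concretely, the data-preparation cell below draws 500 evenly spaced inputs on [0, 1] (`torch.linspace(0, 1, 500)`) and adds scaled standard-normal noise (`0.2 * torch.randn_like(x)`):

$$
x_i = \frac{i}{499}, \qquad y_i = x_i + 0.2\,\varepsilon_i, \qquad \varepsilon_i \sim \mathcal{N}(0, 1), \qquad i = 0, 1, \dots, 499
$$

so the noise has standard deviation 0.2 and the ground-truth line drawn by `visualize` is $y = x$.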

```python
def visualize(x, y, y_pred=None):
    """Scatter plot of the data points, the ground-truth line y = x, and optional predictions"""

    _, ax = plt.subplots(figsize=(6, 6))
    sns.scatterplot(x=x.flatten(), y=y.flatten(), label="Observations", ax=ax)
    sns.lineplot(
        x=x.flatten(),
        y=x.flatten(),
        linestyle="--",
        color="red",
        label="Ground truth",
        ax=ax,
    )
    if y_pred is not None:
        sns.scatterplot(
            x=x.flatten(), y=y_pred.flatten(), color="green", label="Predictions", ax=ax
        )
    ax.set(xlabel="X", ylabel="Y", title="Simple linear regression dataset")
    plt.show()
```

```python
# Prepare a very simple dataset for linear regression: y = x plus Gaussian noise (std 0.2)
x = torch.linspace(0, 1, 500).view(-1, 1)
y = x + 0.2 * torch.randn_like(x)

visualize(x, y)
```

## Native PyTorch Training: Implement training in Native PyTorch

In this section, we implement training in native PyTorch. This involves the following steps:
- Define the model.
- Define the loss function.
- Define the optimizer.
- Define the scheduler.
- Define the dataloaders.
- Define the training loop.
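`train_regression_model` is imported from the project's `demo.torch.utils` module, so its body does not appear in this notebook. For intuition about what the steps above amount to, here is a minimal sketch of a typical plain-PyTorch loop. It is not the actual `demo` implementation: the name `train_regression_model_sketch`, stepping the scheduler once per epoch, and the `torch.nn.Linear` model with SGD and `StepLR` in the usage lines are all assumptions.

```python
from __future__ import annotations

from typing import Callable

import torch
from torch.utils.data import DataLoader, TensorDataset


def train_regression_model_sketch(
    model: torch.nn.Module,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    n_epochs: int,
    device: torch.device,
) -> list[float]:
    """Hypothetical plain-PyTorch training loop; returns the mean validation loss per epoch."""
    model.to(device)
    val_history: list[float] = []
    for epoch in range(n_epochs):
        # Training pass: forward, loss, backward, and parameter update for every batch
        model.train()
        for xb, yb in train_dataloader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(xb), yb)
            loss.backward()
            optimizer.step()
        # Assumption: the learning-rate scheduler is stepped once per epoch
        scheduler.step()

        # Validation pass: eval mode, no gradient tracking
        model.eval()
        total, count = 0.0, 0
        with torch.no_grad():
            for xb, yb in val_dataloader:
                xb, yb = xb.to(device), yb.to(device)
                # Assumes loss_fn averages over the batch, so rescale to a per-batch sum
                total += loss_fn(model(xb), yb).item() * len(xb)
                count += len(xb)
        val_history.append(total / count)
        print(f"epoch {epoch + 1}/{n_epochs}: val_loss={val_history[-1]:.4f}")
    return val_history


# Usage on data like the notebook's (SGD with StepLR is an assumed choice).
# fork_rng keeps this demo from disturbing the notebook's global CPU RNG state.
with torch.random.fork_rng(devices=[]):
    torch.manual_seed(0)
    xs = torch.linspace(0, 1, 500).view(-1, 1)
    ys = xs + 0.2 * torch.randn_like(xs)
    sketch_model = torch.nn.Linear(1, 1)
    sketch_optimizer = torch.optim.SGD(sketch_model.parameters(), lr=0.2)
    history = train_regression_model_sketch(
        model=sketch_model,
        loss_fn=torch.nn.functional.mse_loss,
        train_dataloader=DataLoader(TensorDataset(xs, ys), batch_size=16, shuffle=True),
        val_dataloader=DataLoader(TensorDataset(xs, ys), batch_size=16, shuffle=False),
        optimizer=sketch_optimizer,
        scheduler=torch.optim.lr_scheduler.StepLR(sketch_optimizer, step_size=2, gamma=0.5),
        n_epochs=5,
        device=torch.device("cpu"),
    )
```

Everything in this loop (device moves, `zero_grad`/`backward`/`step`, switching between train and eval mode, disabling gradients for validation) is boilerplate that Lightning's `Trainer` takes over.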

```python
# Use the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch_model = demo.torch.models.LinearRegression(in_features=1, out_features=1)
loss_fn = torch.nn.functional.mse_loss
optimizer = get_optimizer(torch_model)
scheduler = get_scheduler(optimizer)

train_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x, y), batch_size=16, shuffle=True
)
# In practice validation uses held-out data; here we reuse the training data for simplicity
val_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x, y), batch_size=16, shuffle=False
)

train_regression_model(
    model=torch_model,
    loss_fn=loss_fn,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    n_epochs=5,
    device=device,
)
```

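```python
# Predict on the full dataset in eval mode without gradient tracking,
# on whichever device the trained model's parameters live
torch_model.eval()
device = next(torch_model.parameters()).device
with torch.no_grad():
    y_pred = torch_model(x.to(device)).cpu()

visualize(x, y, y_pred)
```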

## PyTorch Lightning Training: Implement training in PyTorch Lightning

In this section, we implement the same training in PyTorch Lightning. This involves the following steps:
- Define a `LightningModule` (model, loss, optimizer, and any scheduler)
- Define a `LightningDataModule` (dataloaders)
- Create a `Trainer`
- Run the trainer with `trainer.fit`

```python
class DataModule(lightning.LightningDataModule):
    """PyTorch Lightning data module for the simple linear regression dataset"""

    def __init__(self, x: torch.Tensor, y: torch.Tensor, batch_size: int):
        super().__init__()

        # As in the native section, the same data is reused for validation for simplicity
        self.train_dataset = torch.utils.data.TensorDataset(x, y)
        self.val_dataset = torch.utils.data.TensorDataset(x, y)
        self.batch_size = batch_size

    def train_dataloader(self):
        """Training data loader (reshuffled every epoch)"""

        return torch.utils.data.DataLoader(
            self.train_dataset, batch_size=self.batch_size, shuffle=True
        )

    def val_dataloader(self):
        """Validation data loader"""

        return torch.utils.data.DataLoader(
            self.val_dataset, batch_size=self.batch_size, shuffle=False
        )


# Seed Python, NumPy and PyTorch RNGs; workers=True also seeds dataloader worker processes
lightning.pytorch.seed_everything(42, workers=True)

lightning_model = demo.lightning.models.LinearRegression(in_features=1, out_features=1)
data_module = DataModule(x, y, batch_size=16)
# Log every step: an epoch has only 32 batches, fewer than the default interval of 50
trainer = lightning.pytorch.Trainer(max_epochs=5, log_every_n_steps=1)

trainer.fit(model=lightning_model, datamodule=data_module)
```

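```python
# Predict in eval mode without gradient tracking, on whichever device the model is on
lightning_model.eval()
with torch.no_grad():
    y_pred = lightning_model(x.to(lightning_model.device)).cpu()

visualize(x, y, y_pred)
```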