<a href="https://colab.research.google.com/github/rahiakela/deep-learning-research-and-practice/blob/main/deep-learning-fundamentals/unit05-lightning/04_lightning_with_datamodules.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## PyTorch & Lightning with Data Modules

**Reference**

[Organizing Your Data Loaders with Data Modules](https://lightning.ai/pages/courses/deep-learning-fundamentals/overview-organizing-your-code-with-pytorch-lightning/5-5-organizing-your-data-loaders-with-data-modules/)

## 1) Setup

In [None]:
!pip install torch torchvision torchaudio
!pip install lightning
!pip install torchmetrics

In [3]:
!lightning --version

lightning, version 2.0.1


In [4]:
import torch
import torchvision
from torch.utils.data.dataset import random_split
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F

import lightning as L
import torchmetrics as tm

import numpy as np
from collections import Counter

%matplotlib inline
import matplotlib.pyplot as plt

##2) Data Modules

In [14]:
class MNISTDataModule(L.LightningDataModule):

  def __init__(self, data_dir="mnist", batch_size=64):
    super().__init__()

    self.data_dir = data_dir
    self.batch_size = batch_size

  def prepare_data(self):
    # download MNIST dataset
    datasets.MNIST(self.data_dir, train=True, download=True)
    datasets.MNIST(self.data_dir, train=False, download=True)

  def setup(self, stage: str):
    self.mnist_test = datasets.MNIST(self.data_dir, transform=transforms.ToTensor(), train=False)
    self.mnist_predict = datasets.MNIST(self.data_dir, transform=transforms.ToTensor(), train=False)

    mnist_full = datasets.MNIST(self.data_dir, transform=transforms.ToTensor(), train=True)
    self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

  def train_dataloader(self):
    return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True, drop_last=True)

  def val_dataloader(self):
    return DataLoader(self.mnist_val, batch_size=self.batch_size, shuffle=False)

  def test_dataloader(self):
    return DataLoader(self.mnist_test, batch_size=self.batch_size, shuffle=False)

  def predict_dataloader(self):
    return DataLoader(self.mnist_predict, batch_size=self.batch_size, shuffle=False)

## 3) Implementing the model

In [17]:
class PyTorchMLP(torch.nn.Module):
  def __init__(self, num_features, num_classes):
    super().__init__()

    self.all_layers = torch.nn.Sequential(
      # 1st hidden layer
      torch.nn.Linear(num_features, 50),
      torch.nn.ReLU(),
      # 2nd hidden layer
      torch.nn.Linear(50, 25),
      torch.nn.ReLU(),
      # output layer
      torch.nn.Linear(25, num_classes),
    )

  def forward(self, x):
    x = torch.flatten(x, start_dim=1)
    logits = self.all_layers(x)
    return logits

In [15]:
class LightningModel(L.LightningModule):
    def __init__(self, model, learning_rate):
      super().__init__()

      self.learning_rate = learning_rate
      self.model = model

      # Set up attributes for computing the accuracy
      self.train_accuracy = tm.Accuracy(task="multiclass", num_classes=10)
      self.val_accuracy = tm.Accuracy(task="multiclass", num_classes=10)

      # Set up attribute for computing the test set accuracy
      self.test_accuracy = tm.Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
      return self.model(x)

    def train_model(self, batch):
      features, true_labels = batch
      logits = self(features)

      loss = F.cross_entropy(logits, true_labels)
      predicted_labels = torch.argmax(logits, dim=1)
      return loss, true_labels, predicted_labels

    def training_step(self, batch, batch_idx):
      loss, true_labels, predicted_labels = self.train_model(batch)
      self.log("train_loss", loss)

      # Computes train accuracy on whole train set
      self.train_accuracy(predicted_labels, true_labels)
      self.log("train_accuracy", self.train_accuracy, prog_bar=True, on_epoch=True, on_step=False)

      # this is passed to the optimizer for training
      return loss

    def validation_step(self, batch, batch_idx):
      loss, true_labels, predicted_labels = self.train_model(batch)
      self.log("val_loss", loss, prog_bar=True)

      # Computes validation accuracy on whole validation set
      self.val_accuracy(predicted_labels, true_labels)
      self.log("val_accuracy", self.val_accuracy, prog_bar=True)

    def test_step(self, batch, batch_idx):
      loss, true_labels, predicted_labels = self.train_model(batch)

      # Computes test accuracy on whole test set
      self.test_accuracy(predicted_labels, true_labels)
      self.log("accuracy", self.test_accuracy)

    def configure_optimizers(self):
      optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
      return optimizer

## 4) The training model

In [22]:
torch.manual_seed(123)

print("Torch CUDA available?", torch.cuda.is_available())

data_module = MNISTDataModule()

pytorch_model = PyTorchMLP(num_features=784, num_classes=10)
lightning_model = LightningModel(model=pytorch_model, learning_rate=0.05)

trainer = L.Trainer(
  max_epochs=10,
  accelerator="auto", # set to "auto" or "gpu" to use GPUs if available
  devices="auto",      # Uses all available GPUs if applicable
  deterministic=True
)

trainer.fit(
  model=lightning_model,
  datamodule=data_module
)

# Evaluate model based on test_step
train_acc = trainer.validate(datamodule=data_module)[0]["val_accuracy"]
val_acc = trainer.validate(datamodule=data_module)[0]["val_accuracy"]
test_acc = trainer.test(datamodule=data_module)[0]["accuracy"]

INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


Torch CUDA available? False


INFO: 
  | Name           | Type               | Params
------------------------------------------------------
0 | model          | PyTorchMLP         | 40.8 K
1 | train_accuracy | MulticlassAccuracy | 0     
2 | val_accuracy   | MulticlassAccuracy | 0     
3 | test_accuracy  | MulticlassAccuracy | 0     
------------------------------------------------------
40.8 K    Trainable params
0         Non-trainable params
40.8 K    Total params
0.163     Total estimated model params size (MB)
INFO:lightning.pytorch.callbacks.model_summary:
  | Name           | Type               | Params
------------------------------------------------------
0 | model          | PyTorchMLP         | 40.8 K
1 | train_accuracy | MulticlassAccuracy | 0     
2 | val_accuracy   | MulticlassAccuracy | 0     
3 | test_accuracy  | MulticlassAccuracy | 0     
------------------------------------------------------
40.8 K    Trainable params
0         Non-trainable params
40.8 K    Total params
0.163     Total estimate

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=10` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.
INFO: Restoring states from the checkpoint path at /content/lightning_logs/version_5/checkpoints/epoch=9-step=8590.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Restoring states from the checkpoint path at /content/lightning_logs/version_5/checkpoints/epoch=9-step=8590.ckpt
INFO: Loaded model weights from the checkpoint at /content/lightning_logs/version_5/checkpoints/epoch=9-step=8590.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Loaded model weights from the checkpoint at /content/lightning_logs/version_5/checkpoints/epoch=9-step=8590.ckpt


Validation: 0it [00:00, ?it/s]

INFO: Restoring states from the checkpoint path at /content/lightning_logs/version_5/checkpoints/epoch=9-step=8590.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Restoring states from the checkpoint path at /content/lightning_logs/version_5/checkpoints/epoch=9-step=8590.ckpt
INFO: Loaded model weights from the checkpoint at /content/lightning_logs/version_5/checkpoints/epoch=9-step=8590.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Loaded model weights from the checkpoint at /content/lightning_logs/version_5/checkpoints/epoch=9-step=8590.ckpt


Validation: 0it [00:00, ?it/s]

INFO: Restoring states from the checkpoint path at /content/lightning_logs/version_5/checkpoints/epoch=9-step=8590.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Restoring states from the checkpoint path at /content/lightning_logs/version_5/checkpoints/epoch=9-step=8590.ckpt
INFO: Loaded model weights from the checkpoint at /content/lightning_logs/version_5/checkpoints/epoch=9-step=8590.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Loaded model weights from the checkpoint at /content/lightning_logs/version_5/checkpoints/epoch=9-step=8590.ckpt


Testing: 0it [00:00, ?it/s]

In [23]:
print(f"Train Accuracy {train_acc*100:.2f}% | Val Accuracy {val_acc*100:.2f}% | Test Accuracy {test_acc*100:.2f}%")

Train Accuracy 97.52% | Val Accuracy 97.46% | Test Accuracy 96.99%


Train Accuracy 98.45% | Val Accuracy 96.84% | Test Accuracy 97.00%

In [24]:
# save the model
PATH = "lightning.pt"
torch.save(pytorch_model.state_dict(), PATH)

In [25]:
# load the model
model = PyTorchMLP(num_features=784, num_classes=10)
model.load_state_dict(torch.load(PATH))
model.eval()

PyTorchMLP(
  (all_layers): Sequential(
    (0): Linear(in_features=784, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=25, bias=True)
    (3): ReLU()
    (4): Linear(in_features=25, out_features=10, bias=True)
  )
)