# Finetuning a Pytorch Lightning Image Classifier

This example introduces how to train a Pytorch Lightning Module using AIR {class}`LightningTrainer <ray.train.lightning.LightningTrainer>`. We will demonstrate how to train a basic neural network on the MNIST dataset with distributed data parallelism.


In [75]:
SMOKE_TEST = True

In [76]:
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Subset
from torchmetrics import Accuracy
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning import trainer
from pytorch_lightning.core import datamodule
from pytorch_lightning.loggers.csv_logs import CSVLogger

seed = 420
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f45609f9670>

## Prepare Dataset and Module

The Pytorch Lightning Trainer takes either `torch.utils.data.DataLoader` or `pl.LightningDataModule` as data inputs. You can keep using them without any changes for the Ray AIR LightningTrainer. 

In [77]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=100):
        super().__init__()
        self.data_dir = os.getcwd()
        self.batch_size = batch_size
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def setup(self, stage=None):
        # split data into train and val sets
        mnist = MNIST(
            self.data_dir, train=True, download=True, transform=self.transform
        )
        self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=4)

    def test_dataloader(self):
        self.mnist_test = MNIST(
            self.data_dir, train=False, download=True, transform=self.transform
        )
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4)

datamodule = MNISTDataModule(batch_size=128)

Next, define a simple multi-layer perception as the subclass of `pl.LightningModule`.

In [78]:
class MNISTClassifier(pl.LightningModule):
    def __init__(self, lr=1e-3):
        super(MNISTClassifier, self).__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
            nn.ReLU(),
        )
        self.lr = lr
        self.accuracy = Accuracy()

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.linear_relu_stack(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        loss, acc = self._shared_eval(val_batch)
        self.log("val_accuracy", acc)
        return {"val_loss": loss, "val_accuracy": acc}
    
    def test_step(self, test_batch, batch_idx):
        loss, acc = self._shared_eval(test_batch)
        self.log("test_accuracy", acc)
        return {"test_loss": loss, "test_accuracy": acc}
    
    def _shared_eval(self, batch):
        x, y = batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        return loss, acc

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("val_loss", avg_loss, sync_dist=True)
        self.log("val_accuracy", avg_acc, sync_dist=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

You don't need to make any change to the definition of PyTorch Lightning model and datamodule.

## Define the Cofigurations for AIR LightningTrainer

The {meth}`LightningConfigBuilder <ray.train.lightning.LightningConfigBuilder>` class stores all the parameters involved in training a PyTorch Lightning module. It takes the same parameter lists as those in PyTorch Lightning.

- The `.module()` method takes a subclass of `pl.LightningModule` and its initialization parameters. `LightningTrainer` will instantiate a model instance internally in the workers' training loop.
- The `.trainer()` method takes the initialization parameters of `pl.Trainer`. You can specify training configurations, loggers, and callbacks here.
- The `.fit_params()` method stores all the parameters that will be passed into `pl.Trainer.fit()`, including train/val dataloaders, datamodules, and checkpoint paths.
- The `.checkpointing()` method saves the configurations for a `ModelCheckpoint` callback. Note that the `LightningTrainer` reports the latest metrics to the AIR session when a new checkpoint is saved.
- The `.build()` method generates a dictionary that contains all the configurations in the builder. This dictionary will be passed to `LightningTrainer` later.

In [79]:
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig
from ray.train.lightning import LightningTrainer, LightningConfigBuilder, LightningCheckpoint

lightning_config = (
    LightningConfigBuilder()
    .module(
        MNISTClassifier, lr=1e-3
    )
    .trainer(max_epochs=10, accelerator="cpu", log_every_n_steps=100, logger=CSVLogger("logs"))
    .fit_params(datamodule=datamodule)
    .checkpointing(monitor="val_accuracy", mode="max", save_top_k=3)
    .build()
)

In [80]:
scaling_config = ScalingConfig(
    num_workers=4, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1}
)

run_config = RunConfig(
    name="ptl-mnist-example",
    local_dir="/tmp/ray_results",
    checkpoint_config=CheckpointConfig(num_to_keep=3, checkpoint_score_attribute="val_accuracy", checkpoint_score_order="max")
)

In [81]:
if SMOKE_TEST:
    scaling_config = ScalingConfig(
        num_workers=4, use_gpu=False, resources_per_worker={"CPU": 1}
    )

In [82]:
trainer = LightningTrainer(
    lightning_config=lightning_config,
    scaling_config=scaling_config,
    run_config=run_config,
)

Now fit your trainer:

In [83]:
result = trainer.fit()
print("Validation Accuracy: ", result.metrics["val_accuracy"])
result

(pid=233131)   from pandas import MultiIndex, Int64Index


== Status ==
Current time: 2023-03-23 16:23:21 (running for 00:00:07.95)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/16 CPUs, 0/0 GPUs
Result logdir: /tmp/ray_results/ptl-mnist-example
Number of trials: 1/1 (1 RUNNING)
+------------------------------+----------+--------------------+
| Trial name                   | status   | loc                |
|------------------------------+----------+--------------------|
| LightningTrainer_aef1f_00000 | RUNNING  | 10.0.61.115:233131 |
+------------------------------+----------+--------------------+




(RayTrainWorker pid=233458) 2023-03-23 16:23:25,524	INFO config.py:86 -- Setting up process group for: env:// [rank=0, world_size=4]


== Status ==
Current time: 2023-03-23 16:23:26 (running for 00:00:12.95)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/16 CPUs, 0/0 GPUs
Result logdir: /tmp/ray_results/ptl-mnist-example
Number of trials: 1/1 (1 RUNNING)
+------------------------------+----------+--------------------+
| Trial name                   | status   | loc                |
|------------------------------+----------+--------------------|
| LightningTrainer_aef1f_00000 | RUNNING  | 10.0.61.115:233131 |
+------------------------------+----------+--------------------+




(RayTrainWorker pid=233460)   from pandas import MultiIndex, Int64Index
(RayTrainWorker pid=233458)   from pandas import MultiIndex, Int64Index
(RayTrainWorker pid=233461)   from pandas import MultiIndex, Int64Index
(RayTrainWorker pid=233459)   from pandas import MultiIndex, Int64Index
(RayTrainWorker pid=233458) GPU available: False, used: False
(RayTrainWorker pid=233458) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=233458) IPU available: False, using: 0 IPUs
(RayTrainWorker pid=233458) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=233458) Missing logger folder: logs/lightning_logs
(RayTrainWorker pid=233458) 
(RayTrainWorker pid=233458)   | Name              | Type       | Params
(RayTrainWorker pid=233458) -------------------------------------------------
(RayTrainWorker pid=233458) 0 | linear_relu_stack | Sequential | 101 K 
(RayTrainWorker pid=233458) 1 | accuracy          | Accuracy   | 0     
(RayTrainWorker pid=233458) -------------------------------

== Status ==
Current time: 2023-03-23 16:23:31 (running for 00:00:17.95)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/16 CPUs, 0/0 GPUs
Result logdir: /tmp/ray_results/ptl-mnist-example
Number of trials: 1/1 (1 RUNNING)
+------------------------------+----------+--------------------+
| Trial name                   | status   | loc                |
|------------------------------+----------+--------------------|
| LightningTrainer_aef1f_00000 | RUNNING  | 10.0.61.115:233131 |
+------------------------------+----------+--------------------+






Result for LightningTrainer_aef1f_00000:
  _report_on: train_epoch_end
  date: 2023-03-23_16-23-34
  done: false
  epoch: 0
  hostname: ip-10-0-61-115
  iterations_since_restore: 1
  node_ip: 10.0.61.115
  pid: 233131
  should_checkpoint: true
  step: 108
  time_since_restore: 12.394395589828491
  time_this_iter_s: 12.394395589828491
  time_total_s: 12.394395589828491
  timestamp: 1679613814
  train_loss: 0.7441539168357849
  training_iteration: 1
  trial_id: aef1f_00000
  val_accuracy: 0.7299904227256775
  val_loss: -6.072018623352051
  
== Status ==
Current time: 2023-03-23 16:23:38 (running for 00:00:25.00)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/16 CPUs, 0/0 GPUs
Result logdir: /tmp/ray_results/ptl-mnist-example
Number of trials: 1/1 (1 RUNNING)
+------------------------------+----------+--------------------+--------+------------------+--------------+----------------+------------+
| Trial name                   | status   | loc                |   iter |   total

2023-03-23 16:23:57,592	INFO tune.py:817 -- Total run time: 43.72 seconds (43.71 seconds for the tuning loop).


Validation Accuracy:  0.7734813690185547


Result(
  metrics={'_report_on': 'train_epoch_end', 'train_loss': 0.33654841780662537, 'val_accuracy': 0.7734813690185547, 'val_loss': -9.297370910644531, 'epoch': 9, 'step': 1080, 'should_checkpoint': True, 'done': True, 'trial_id': 'aef1f_00000', 'experiment_tag': '0'},
  log_dir=PosixPath('/tmp/ray_results/ptl-mnist-example/LightningTrainer_aef1f_00000_0_2023-03-23_16-23-13'),
  checkpoint=LightningCheckpoint(local_path=/tmp/ray_results/ptl-mnist-example/LightningTrainer_aef1f_00000_0_2023-03-23_16-23-13/checkpoint_000009)
)

## Test your network on the test data

Next, we use PyTorch Lightning's native interface to evaluate the best model: To run the test loop using the ``pl.LightningModule.test_step()`` in your user-defined code, simply pass the loaded model to ``pl.Trainer.test()``. 

For faster inference on large datasets, you can try to use AIR's {class}`BatchPredictor <ray.train.batch_prediction.BatchPredictor>`.

In [84]:
checkpoint: LightningCheckpoint = result.checkpoint
best_model: pl.LightningModule = checkpoint.get_model(MNISTClassifier)

In [85]:
trainer = pl.Trainer()
test_dataloader = datamodule.test_dataloader()
result = trainer.test(best_model, dataloaders=test_dataloader)

  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Testing: 0it [00:00, ?it/s]

You can also use `LightningPredictor` for inference:

In [86]:
from ray.train.lightning import LightningPredictor
predictor = LightningPredictor.from_checkpoint(checkpoint, MNISTClassifier, use_gpu=False)

def accuracy(logits, labels):
    preds = np.argmax(logits, axis=1)
    correct_preds = np.sum(preds == labels)
    return correct_preds

corrects = total = 0
for batch in test_dataloader:
    inputs, labels = batch
    inputs, labels = inputs.numpy(), labels.numpy()
    logits = predictor.predict(inputs)["predictions"]
    total += labels.size
    corrects += accuracy(logits, labels)
    
print("Accuracy: ", corrects / total)

Accuracy:  0.7709
