# Finetuning a Pytorch Lightning Image Classifier

This example introduces how to train a Pytorch Lightning Module using AIR {class}`LightningTrainer <ray.train.lightning.LightningTrainer>`. We will demonstrate how to train a basic neural network on the MNIST dataset with distributed data parallelism.


In [1]:
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Subset
from torchmetrics import Accuracy
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning import trainer
from pytorch_lightning.core import datamodule
from pytorch_lightning.loggers.csv_logs import CSVLogger

from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig
from ray.train.lightning import LightningTrainer, LightningConfigBuilder

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
SMOKE_TEST = True

## Prepare Dataset and Module

The Pytorch Lightning Trainer takes either `torch.utils.data.DataLoader` or `pl.LightningDataModule` as data inputs. You can keep using them without any changes for the Ray AIR LightningTrainer. 

In [3]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=100):
        super().__init__()
        self.data_dir = os.getcwd()
        self.batch_size = batch_size
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def setup(self, stage=None):
        # split data into train and val sets
        mnist = MNIST(
            self.data_dir, train=True, download=True, transform=self.transform
        )
        self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])

        self.mnist_test = MNIST(
            self.data_dir, train=False, download=True, transform=self.transform
        )
        
        if SMOKE_TEST:
            self.mnist_train = Subset(self.mnist_train, range(5000))
            self.mnist_val = Subset(self.mnist_val, range(1000))


    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4)

datamodule = MNISTDataModule(batch_size=128)

Next, define a simple multi-layer perception as the subclass of `pl.LightningModule`.

In [4]:
class MNISTClassifier(pl.LightningModule):
    def __init__(self, lr, feature_dim):
        super(MNISTClassifier, self).__init__()
        self.fc1 = torch.nn.Linear(28 * 28, feature_dim)
        self.fc2 = torch.nn.Linear(feature_dim, 10)
        self.lr = lr
        self.accuracy = Accuracy()

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": acc}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("val_loss", avg_loss, sync_dist=True)
        self.log("val_accuracy", avg_acc, sync_dist=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

## Define the Cofigurations for AIR LightningTrainer

The {meth}`LightningConfigBuilder <ray.train.lightning.LightningConfigBuilder>` class stores all the parameters involved in training a PyTorch Lightning module. It takes the same parameter lists as those in PyTorch Lightning.

- The `.module()` method takes a subclass of `pl.LightningModule` and its initialization parameters. `LightningTrainer` will instantiate a model instance internally in the training loop.
- The `.trainer()` method takes the initialization parameters of `ptl.Trainer`. You can specify training configurations, loggers, and callbacks here.
- The `.fit_params()` method stores all the parameters that will be passed into `ptl.Trainer.fit()`, including train/val dataloaders, datamodules, and checkpoint paths.
- The `.checkpointing()` method saves the configurations for a `ModelCheckpoint` callback. Note that the `LightningTrainer` reports the latest metrics to the AIR session when a new checkpoint is saved.
- The `.build()` method converts the configurations into a dictionary that is readable for `LightningTrainer`.

In [5]:
lightning_config = (
    LightningConfigBuilder()
    .module(
        MNISTClassifier, feature_dim=128, lr=0.001
    )
    .trainer(max_epochs=7, accelerator="cpu", log_every_n_steps=100, logger=CSVLogger("logs"))
    .fit_params(datamodule=datamodule)
    .checkpointing(monitor="val_accuracy", mode="max", save_top_k=3)
    .build()
)

In [6]:
scaling_config = ScalingConfig(
    num_workers=2, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1}
)

run_config = RunConfig(
    name="ptl-mnist-example",
    local_dir="/tmp/ray_results",
    checkpoint_config=CheckpointConfig(num_to_keep=3, checkpoint_score_attribute="val_accuracy", checkpoint_score_order="max")
)

In [7]:
if SMOKE_TEST:
    scaling_config = ScalingConfig(
        num_workers=2, use_gpu=False, resources_per_worker={"CPU": 1}
    )

In [8]:
trainer = LightningTrainer(
    lightning_config=lightning_config,
    scaling_config=scaling_config,
    run_config=run_config,
)

In [9]:
result = trainer.fit()
print(result)
print(result.metrics["val_accuracy"])

find: ‘.git’: No such file or directory
2023-03-20 18:42:05,508	INFO worker.py:1362 -- Connecting to existing Ray cluster at address: 10.0.36.162:6379...
2023-03-20 18:42:05,518	INFO worker.py:1556 -- Connected to Ray cluster. View the dashboard at [1m[32mhttps://console.anyscale-staging.com/api/v2/sessions/ses_lakhijrn6mdv9hqpld5trwzlxz/services?redirect_to=dashboard [39m[22m
2023-03-20 18:42:05,668	INFO packaging.py:346 -- Pushing file package 'gcs://_ray_pkg_8ae886bbd7be749872ca0e248c38b60f.zip' (52.43MiB) to Ray cluster...
2023-03-20 18:42:06,553	INFO packaging.py:359 -- Successfully pushed file package 'gcs://_ray_pkg_8ae886bbd7be749872ca0e248c38b60f.zip'.
  "The `local_dir` argument of `Experiment is deprecated. "


0,1
Current time:,2023-03-20 18:42:25
Running for:,00:00:18.88
Memory:,4.4/61.4 GiB

Trial name,status,loc,iter,total time (s),train_loss,val_loss,val_accuracy
LightningTrainer_966f8_00000,TERMINATED,10.0.36.162:11733,7,9.77929,0.185041,-6.80121,0.931573


(RayTrainWorker pid=11984) 2023-03-20 18:42:16,052	INFO config.py:87 -- Setting up process group for: env:// [rank=0, world_size=2]
(RayTrainWorker pid=11984) GPU available: False, used: False
(RayTrainWorker pid=11984) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=11984) IPU available: False, using: 0 IPUs
(RayTrainWorker pid=11984) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=11985) Missing logger folder: logs/lightning_logs
(RayTrainWorker pid=11984) Missing logger folder: logs/lightning_logs
(RayTrainWorker pid=11984) 
(RayTrainWorker pid=11984)   | Name     | Type     | Params
(RayTrainWorker pid=11984) --------------------------------------
(RayTrainWorker pid=11984) 0 | fc1      | Linear   | 100 K 
(RayTrainWorker pid=11984) 1 | fc2      | Linear   | 1.3 K 
(RayTrainWorker pid=11984) 2 | accuracy | Accuracy | 0     
(RayTrainWorker pid=11984) --------------------------------------
(RayTrainWorker pid=11984) 101 K     Trainable params
(RayTrainWorker pid

Trial name,_report_on,date,done,epoch,experiment_tag,hostname,iterations_since_restore,node_ip,pid,should_checkpoint,step,time_since_restore,time_this_iter_s,time_total_s,timestamp,train_loss,training_iteration,trial_id,val_accuracy,val_loss
LightningTrainer_966f8_00000,train_epoch_end,2023-03-20_18-42-22,True,6,0,ip-10-0-36-162,7,10.0.36.162,11733,True,140,9.77929,0.698755,9.77929,1679362942,0.185041,7,966f8_00000,0.931573,-6.80121


2023-03-20 18:42:25,603	INFO tune.py:826 -- Total run time: 18.96 seconds (18.88 seconds for the tuning loop).


Result(
  metrics={'_report_on': 'train_epoch_end', 'train_loss': 0.18504135310649872, 'val_loss': -6.801209926605225, 'val_accuracy': 0.9315732717514038, 'epoch': 6, 'step': 140, 'should_checkpoint': True, 'done': True, 'trial_id': '966f8_00000', 'experiment_tag': '0'},
  log_dir=PosixPath('/tmp/ray_results/ptl-mnist-example/LightningTrainer_966f8_00000_0_2023-03-20_18-42-08'),
  checkpoint=LightningCheckpoint(local_path=/tmp/ray_results/ptl-mnist-example/LightningTrainer_966f8_00000_0_2023-03-20_18-42-08/checkpoint_000006)
)
0.9315732717514038


## Test the network on the test data