(tune-pytorch-lightning-ref)=

# Using PyTorch Lightning with Tune

PyTorch Lightning is a framework that brings structure to training PyTorch models. It
aims to avoid boilerplate code, so you don't have to write the same training
loops all over again when building a new model.

```{image} /images/pytorch_lightning_full.png
:align: center
```

The main abstraction of PyTorch Lightning is the `LightningModule` class, which
should be extended by your application. There is [a great post on how to transfer your models from vanilla PyTorch to Lightning](https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09).

The class structure of PyTorch Lightning makes it very easy to define and tune model
parameters. This tutorial will show you how to use Tune with AIR {class}`LightningTrainer <ray.train.lightning.LightningTrainer>` to find the best set of
parameters for your application, using the example of training an MNIST classifier. Notably,
the `LightningModule` does not have to be altered at all for this, so you can
use it plug-and-play with your existing models, assuming their parameters are configurable!

:::{note}
If you don't want to use AIR {class}`LightningTrainer <ray.train.lightning.LightningTrainer>`, please refer to this document: {ref}`Using vanilla PyTorch Lightning with Tune <tune-vanilla-pytorch-lightning-ref>`.

:::

:::{note}
To run this example, you will need to install the following:

```bash
$ pip install "ray[tune]" torch torchvision pytorch-lightning
```
:::

```{contents}
:backlinks: none
:local: true
```

## PyTorch Lightning classifier for MNIST

Let's first start with the basic PyTorch Lightning implementation of an MNIST classifier.
This classifier does not include any tuning code at this point.

First, we run some imports:

```python
import os
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from filelock import FileLock
from torchmetrics import Accuracy
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
from pytorch_lightning.loggers.csv_logs import CSVLogger

import ray
import ray.tune as tune
from ray.air.config import CheckpointConfig, ScalingConfig
from ray.train.lightning import LightningTrainer, LightningConfigBuilder
from ray.tune.schedulers import PopulationBasedTraining
```

Our example builds on the MNIST example from the [blog post](https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09) we mentioned before. We adapted the original model and dataset definitions into `MNISTClassifier` and `MNISTDataModule`. 

```python
class MNISTClassifier(pl.LightningModule):
    def __init__(self, config):
        super(MNISTClassifier, self).__init__()
        # torchmetrics >= 0.11 requires the task type and the number of classes
        self.accuracy = Accuracy(task="multiclass", num_classes=10)
        self.layer_1_size = config["layer_1_size"]
        self.layer_2_size = config["layer_2_size"]
        self.lr = config["lr"]

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_size)
        self.layer_2 = torch.nn.Linear(self.layer_1_size, self.layer_2_size)
        self.layer_3 = torch.nn.Linear(self.layer_2_size, 10)

    def cross_entropy_loss(self, logits, labels):
        # forward() returns log-probabilities, so NLL loss equals cross-entropy here
        return F.nll_loss(logits, labels)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)

        x = self.layer_1(x)
        x = torch.relu(x)

        x = self.layer_2(x)
        x = torch.relu(x)

        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)

        return x

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)

        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", accuracy)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)
        # In validation, self.log defaults to on_epoch=True, so Lightning averages
        # these values over the epoch. This replaces the `validation_epoch_end`
        # hook, which was removed in Lightning 2.0.
        self.log("ptl/val_loss", loss)
        self.log("ptl/val_accuracy", accuracy)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=128):
        super().__init__()
        self.data_dir = os.getcwd()
        self.batch_size = batch_size
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def setup(self, stage=None):
        # The file lock prevents concurrent workers from downloading MNIST at the same time
        with FileLock(f"{self.data_dir}.lock"):
            mnist = MNIST(
                self.data_dir, train=True, download=True, transform=self.transform
            )
            self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])

            self.mnist_test = MNIST(
                self.data_dir, train=False, download=True, transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4)
```
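You can check the parameter counts in the model summary printed later on this page by hand. A linear layer with `n_in` inputs and `n_out` outputs has `n_in * n_out` weights plus `n_out` biases. The plain-Python sketch below does that arithmetic for the trial with `layer_1_size=32` and `layer_2_size=64`. The helper names are hypothetical and not part of Lightning. The size estimate assumes 4-byte `float32` parameters and 1 MB = 10^6 bytes, which matches the `0.112` MB in the summary.

```python
from typing import List


def linear_params(n_in: int, n_out: int) -> int:
    # weights (n_in x n_out) plus one bias per output unit
    return n_in * n_out + n_out


def mnist_classifier_params(layer_1_size: int, layer_2_size: int) -> List[int]:
    # Mirrors the three torch.nn.Linear layers of MNISTClassifier
    return [
        linear_params(28 * 28, layer_1_size),
        linear_params(layer_1_size, layer_2_size),
        linear_params(layer_2_size, 10),
    ]


per_layer = mnist_classifier_params(layer_1_size=32, layer_2_size=64)
total = sum(per_layer)
size_mb = total * 4 / 1e6  # assumes float32 (4 bytes) and 1 MB = 10**6 bytes

print(per_layer)          # [25120, 2112, 650]
print(total)              # 27882
print(round(size_mb, 3))  # 0.112
```

With the default configuration (`128` and `256` units), the same formula gives 136,074 parameters. The hidden layer sizes therefore dominate the model size, mostly through the first `784 x layer_1_size` weight matrix.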


```python
default_config = {
    "layer_1_size": 128,
    "layer_2_size": 256,
    "lr": 1e-3,
}
```

## Tuning the model parameters

The parameters above should already give you a good accuracy of over 90%. However, we might improve on this simply by changing some of the hyperparameters. For instance, maybe we would get an even higher accuracy if we used a smaller learning rate and a larger middle layer.

Instead of manually looping through all the parameter combinations, let's use Tune to systematically try out parameter combinations and find the best performing set.

First, we need some additional imports:

```python
from pytorch_lightning.loggers import TensorBoardLogger
from ray import air, tune
from ray.air import session
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
```

### 1. Configuring the search space

Now we configure the parameter search space using {class}`LightningConfigBuilder <ray.train.lightning.LightningConfigBuilder>`. We would like to choose between three different sizes for each of the two hidden layers. The learning rate should be sampled log-uniformly between `0.0001` and `0.1`. The `tune.loguniform()` function is syntactic sugar that makes sampling across these different orders of magnitude easier. Each order of magnitude is equally likely, so small values are sampled as often as large ones.
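To see why log-uniform sampling reaches small learning rates, compare it with plain uniform sampling over the same range. The sketch below uses only Python's `random` module. It illustrates the idea and is not how `tune.loguniform` is implemented internally. A log-uniform draw is `10 ** u`, with `u` uniform between `log10(1e-4)` and `log10(1e-1)`. Each of the three decades therefore receives roughly a third of the samples. The helper `sample_loguniform` is hypothetical.

```python
import math
import random


def sample_loguniform(rng: random.Random, low: float, high: float) -> float:
    # Uniform in log10 space, then mapped back: every order of magnitude is equally likely
    return 10 ** rng.uniform(math.log10(low), math.log10(high))


rng = random.Random(0)  # fixed seed for reproducibility
n = 30_000
log_samples = [sample_loguniform(rng, 1e-4, 1e-1) for _ in range(n)]
uniform_samples = [rng.uniform(1e-4, 1e-1) for _ in range(n)]

# Fraction of samples that land in the smallest decade, [1e-4, 1e-3)
frac_log = sum(x < 1e-3 for x in log_samples) / n
frac_uniform = sum(x < 1e-3 for x in uniform_samples) / n
print(f"log-uniform: {frac_log:.2f}, uniform: {frac_uniform:.3f}")
```

Uniform sampling puts less than 1% of the draws below `1e-3`, because that decade covers only about 0.9% of the interval's length. Log-uniform sampling puts about a third of the draws there.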

:::{note}
In `LightningTrainer`, metrics are reported at the same frequency as checkpoints are saved, as specified in `LightningConfigBuilder`. For example, if you set `builder.checkpointing(..., every_n_epochs=2)`, then every 2 epochs the latest metrics are reported to the Tune session together with a checkpoint.

:::
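As a rough illustration of the note above, the sketch below lists the epochs at which metrics, and a checkpoint, would be reported for `num_epochs = 10` and `every_n_epochs=2`. It assumes Lightning's `ModelCheckpoint` convention: a checkpoint is saved when the zero-based epoch index satisfies `(epoch + 1) % every_n_epochs == 0`. The helper `reporting_epochs` is hypothetical and only mirrors that rule.

```python
def reporting_epochs(num_epochs: int, every_n_epochs: int) -> list:
    # Zero-based epoch indices after which a checkpoint (and the latest metrics) is reported
    return [e for e in range(num_epochs) if (e + 1) % every_n_epochs == 0]


epochs = reporting_epochs(num_epochs=10, every_n_epochs=2)
print(epochs)       # [1, 3, 5, 7, 9]
print(len(epochs))  # 5 reports
```

Each report counts as one Tune `training_iteration`. Keep this in mind when a scheduler measures time in iterations: with `every_n_epochs=2`, ten epochs produce only five iterations.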


:::{note}
`LightningConfigBuilder.checkpointing()` specifies the monitored metric and the checkpoint frequency in the same way as Lightning's `ModelCheckpoint` callback. You should also provide an AIR `CheckpointConfig` to properly save the top-k checkpoints under the trial folder. Otherwise, `LightningTrainer` copies and saves all checkpoints by default.

:::
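The bookkeeping implied by `num_to_keep=2` with `checkpoint_score_order="max"` can be sketched with a min-heap. The heap keeps the two best-scoring checkpoints and evicts the worst one whenever a better one arrives. This is a simplified, hypothetical model of the behaviour with made-up checkpoint names and scores, not Ray's implementation.

```python
import heapq


def keep_top_k(scored_checkpoints, k):
    """Return the names of the k checkpoints with the highest score (mode="max")."""
    heap = []  # min-heap of (score, name); heap[0] is the worst checkpoint kept so far
    for name, score in scored_checkpoints:
        if len(heap) < k:
            heapq.heappush(heap, (score, name))
        elif score > heap[0][0]:
            heapq.heapreplace(heap, (score, name))  # evict the current worst
    return [name for score, name in sorted(heap, reverse=True)]


# Hypothetical per-epoch validation accuracies
history = [("epoch_0", 0.91), ("epoch_1", 0.94), ("epoch_2", 0.93), ("epoch_3", 0.95)]
print(keep_top_k(history, k=2))  # ['epoch_3', 'epoch_1']
```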

```python
num_epochs = 10
dm = MNISTDataModule(batch_size=128)

# Search space for the `config` argument of MNISTClassifier
config = {
    "layer_1_size": tune.choice([32, 64, 128]),
    "layer_2_size": tune.choice([64, 128, 256]),
    "lr": tune.loguniform(1e-4, 1e-1),
}

lightning_config = (
    LightningConfigBuilder()
    .module(cls=MNISTClassifier, config=config)
    .trainer(max_epochs=num_epochs, accelerator="cpu")
    .fit_params(datamodule=dm)
    .checkpointing(monitor="ptl/val_accuracy", save_top_k=2, mode="max")
    .build()
)

# Make sure to also define an AIR CheckpointConfig here in order to save
# only the top-k checkpoints under the trial folder
run_config = RunConfig(
    checkpoint_config=CheckpointConfig(
        num_to_keep=2,
        checkpoint_score_attribute="ptl/val_accuracy",
        checkpoint_score_order="max",
    ),
)
```

### 2. Selecting a scheduler

In this example, we use an [Asynchronous Hyperband](https://blog.ml.cmu.edu/2018/12/12/massively-parallel-hyperparameter-optimization/)
scheduler. This scheduler decides at each iteration which trials are likely to perform
badly, and stops these trials. This way we don't waste any resources on bad hyperparameter
configurations.
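The sketch below is a much-simplified illustration of two ideas behind ASHA, using the settings from the cell below (`max_t=num_epochs=10`, `grace_period=1`, `reduction_factor=2`):

- Trials are compared only at "rungs" spaced geometrically from `grace_period`.
- At each rung, a trial continues only if it is roughly in the top `1 / reduction_factor` of the results recorded there so far.

The helpers are hypothetical. Ray's `ASHAScheduler` uses a percentile-based cutoff and handles more cases, so treat this as an illustration rather than its actual logic.

```python
import math


def rung_milestones(max_t, grace_period, reduction_factor):
    # Resource levels (here: training iterations, i.e. epochs) at which trials are compared
    milestones = []
    t = grace_period
    while t < max_t:
        milestones.append(t)
        t *= reduction_factor
    return milestones


def should_stop(result, recorded, reduction_factor):
    # Simplified rule for mode="max": continue only if `result` ranks within the
    # top 1/reduction_factor of all results seen at this rung (including itself)
    ranked = sorted(recorded + [result], reverse=True)
    keep = max(1, math.floor(len(ranked) / reduction_factor))
    return result < ranked[keep - 1]


print(rung_milestones(max_t=10, grace_period=1, reduction_factor=2))  # [1, 2, 4, 8]

# Hypothetical validation accuracies already recorded at one rung
seen = [0.90, 0.80, 0.70]
print(should_stop(0.60, seen, reduction_factor=2))  # True: below the top half
print(should_stop(0.95, seen, reduction_factor=2))  # False: best so far, keeps training
```

Because the rungs double, most of the compute budget goes to trials that survive the early comparisons at epochs 1 and 2.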

```python
scheduler = ASHAScheduler(
    max_t=num_epochs,
    grace_period=1,
    reduction_factor=2,
)
```

### 3. Changing the CLI output

We instantiate a `CLIReporter` to specify which metrics we would like to see in our
output tables in the command line. This is optional, but can be used to make sure our
output tables only include information we would like to see.
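The trial table at the end of this page suggests that nested entries of the param space appear under slash-separated paths (for example `.../config/layer_1_size`). Short names passed to `parameter_columns` may therefore not match them. The sketch below flattens a nested dictionary in that slash-separated style so you can see the full key paths. `flatten` is a hypothetical helper, and the nesting below `lightning_config` (`module`, `config`) is made up for illustration. The real layout depends on what `LightningConfigBuilder` produces.

```python
def flatten(d, prefix=""):
    """Flatten nested dicts into slash-separated keys, e.g. {"a": {"b": 1}} -> {"a/b": 1}."""
    flat = {}
    for key, value in d.items():
        path = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat


# Hypothetical nesting: the search space sits a few levels below the top-level key
param_space = {"lightning_config": {"module": {"config": {"layer_1_size": 32, "lr": 0.0014}}}}
flat = flatten(param_space)
print(list(flat))
# ['lightning_config/module/config/layer_1_size', 'lightning_config/module/config/lr']

requested = ["layer_1_size", "lr"]
print([c for c in requested if c in flat])  # [] -> short names do not match the full paths
```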

```python
reporter = CLIReporter(
    # batch_size is fixed at 128 in this example, so it is not listed as a tuned parameter
    parameter_columns=["layer_1_size", "layer_2_size", "lr"],
    metric_columns=["ptl/val_loss", "ptl/val_accuracy", "training_iteration"],
)
```

### 4. Training with GPUs

We can specify the number of resources, including GPUs, that Tune should request for each trial.

`LightningTrainer` takes care of the environment setup for Distributed Data Parallel (DDP) training, so the model and data are automatically distributed across GPUs. You only need to set `use_gpu=True` and the number of GPUs per worker in `ScalingConfig`, and also set `accelerator="gpu"` in `LightningConfigBuilder.trainer()`.
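To make the data-parallel split concrete, consider `num_workers=3`:

- Each worker trains on its own shard of the 55,000 training images.
- Each worker uses its own `batch_size=128`, so one optimizer step consumes 3 × 128 = 384 images in total.

The sketch below assumes that Lightning wraps the dataloader in a `DistributedSampler` under DDP, which is its default behaviour (without shuffling here, since the training dataloader above does not shuffle). That sampler pads the index list so every rank gets the same number of samples. `shard_indices` is a hypothetical helper that mimics this rule in plain Python.

```python
import math


def shard_indices(dataset_len, rank, world_size):
    # Pad by repeating the first indices so the total divides evenly across ranks
    num_samples = math.ceil(dataset_len / world_size)
    total_size = num_samples * world_size
    indices = list(range(dataset_len))
    indices += indices[: total_size - dataset_len]
    # Each rank takes every world_size-th index, starting at its own rank
    return indices[rank:total_size:world_size]


world_size, batch_size, train_len = 3, 128, 55_000
shards = [shard_indices(train_len, r, world_size) for r in range(world_size)]
print([len(s) for s in shards])                # [18334, 18334, 18334]
print(math.ceil(len(shards[0]) / batch_size))  # 144 batches per worker per epoch
print(world_size * batch_size)                 # 384 images per optimizer step
```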

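```python
# Use this configuration to train each trial on 3 GPUs (one per worker)
scaling_config = ScalingConfig(
    num_workers=3, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1}
)
```

```python
# This example runs on CPUs only (note `accelerator="cpu"` above),
# so we override the GPU configuration with a CPU-only one
scaling_config = ScalingConfig(
    num_workers=3, use_gpu=False, resources_per_worker={"CPU": 1}
)
```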

```python
# Define a base LightningTrainer without hyperparameters for the Tuner
lightning_trainer = LightningTrainer(
    scaling_config=scaling_config,
    run_config=run_config,
)
```

### Putting it together

Lastly, we need to create a `Tuner()` object and start Ray Tune with `tuner.fit()`.

The final step looks like this:

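```python
tuner = tune.Tuner(
    lightning_trainer,
    param_space={"lightning_config": lightning_config},
    tune_config=tune.TuneConfig(
        metric="ptl/val_accuracy",
        mode="max",
        # Increase num_samples to try more hyperparameter combinations
        num_samples=1,
        # Use the ASHA scheduler defined above to stop bad trials early
        scheduler=scheduler,
    ),
)
results = tuner.fit()
best_result = results.get_best_result(metric="ptl/val_accuracy", mode="max")
best_result
```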

| Current time | Running for | Memory |
| --- | --- | --- |
| 2023-03-28 18:09:10 | 00:00:16.87 | 3.9/30.9 GiB |

| Trial name | # failures | error file |
| --- | --- | --- |
| LightningTrainer_45adb_00000 | 1 | `/home/ray/ray_results/LightningTrainer_2023-03-28_18-08-53/LightningTrainer_45adb_00000_0_layer_1_size=32,layer_2_size=64,lr=0.0014_2023-03-28_18-08-53/error.txt` |

| Trial name | status | loc | `...odule_init_config/config/layer_1_size` | `...odule_init_config/config/layer_2_size` | `..._config/_module_init_config/config/lr` |
| --- | --- | --- | --- | --- | --- |
| LightningTrainer_45adb_00000 | ERROR | 10.0.57.221:201387 | 32 | 64 | 0.00140243 |


(pid=201387)   from pandas import MultiIndex, Int64Index [repeated 3x across cluster]
(RayTrainWorker pid=200311) Missing logger folder: /home/ray/ray_results/LightningTrainer_2023-03-28_18-07-48/LightningTrainer_1f42f_00000_0_layer_1_size=128,layer_2_size=64,lr=0.0005_2023-03-28_18-07-48/rank_2/lightning_logs [repeated 2x across cluster]
(RayTrainWorker pid=201612) 2023-03-28 18:09:03,257	INFO config.py:86 -- Setting up process group for: env:// [rank=0, world_size=3]
(RayTrainWorker pid=201612)   from pandas import MultiIndex, Int64Index
(RayTrainWorker pid=201612)   from pandas import MultiIndex, Int64Index
(RayTrainWorker pid=201614)   from pandas import MultiIndex, Int64Index


(RayTrainWorker pid=201614) Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
(RayTrainWorker pid=201614) Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /home/ray/prj-yx-lightning/doc/source/tune/examples/MNIST/raw/train-images-idx3-ubyte.gz


(RayTrainWorker pid=201612) GPU available: False, used: False
(RayTrainWorker pid=201612) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=201612) IPU available: False, using: 0 IPUs
(RayTrainWorker pid=201612) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=201612) Missing logger folder: /home/ray/ray_results/LightningTrainer_2023-03-28_18-08-53/LightningTrainer_45adb_00000_0_layer_1_size=32,layer_2_size=64,lr=0.0014_2023-03-28_18-08-53/rank_0/lightning_logs
  0%|          | 0/9912422 [00:00<?, ?it/s]


(RayTrainWorker pid=201614) Extracting /home/ray/prj-yx-lightning/doc/source/tune/examples/MNIST/raw/train-images-idx3-ubyte.gz to /home/ray/prj-yx-lightning/doc/source/tune/examples/MNIST/raw


100%|██████████| 9912422/9912422 [00:00<00:00, 93508624.79it/s]


(RayTrainWorker pid=201614) 
(RayTrainWorker pid=201614) Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
(RayTrainWorker pid=201614) Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /home/ray/prj-yx-lightning/doc/source/tune/examples/MNIST/raw/train-labels-idx1-ubyte.gz
(RayTrainWorker pid=201614) Extracting /home/ray/prj-yx-lightning/doc/source/tune/examples/MNIST/raw/train-labels-idx1-ubyte.gz to /home/ray/prj-yx-lightning/doc/source/tune/examples/MNIST/raw
(RayTrainWorker pid=201614) 
(RayTrainWorker pid=201614) Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
(RayTrainWorker pid=201614) Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /home/ray/prj-yx-lightning/doc/source/tune/examples/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 32285632.68it/s]
  0%|          | 0/1648877 [00:00<?, ?it/s]


(RayTrainWorker pid=201614) Extracting /home/ray/prj-yx-lightning/doc/source/tune/examples/MNIST/raw/t10k-images-idx3-ubyte.gz to /home/ray/prj-yx-lightning/doc/source/tune/examples/MNIST/raw
(RayTrainWorker pid=201614) 
(RayTrainWorker pid=201614) Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 25374017.26it/s]


(RayTrainWorker pid=201614) Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /home/ray/prj-yx-lightning/doc/source/tune/examples/MNIST/raw/t10k-labels-idx1-ubyte.gz
(RayTrainWorker pid=201614) Extracting /home/ray/prj-yx-lightning/doc/source/tune/examples/MNIST/raw/t10k-labels-idx1-ubyte.gz to /home/ray/prj-yx-lightning/doc/source/tune/examples/MNIST/raw
(RayTrainWorker pid=201614) 


100%|██████████| 4542/4542 [00:00<00:00, 33305120.22it/s]
(RayTrainWorker pid=201612)   | Name     | Type     | Params
(RayTrainWorker pid=201612) --------------------------------------
(RayTrainWorker pid=201612) 0 | accuracy | Accuracy | 0     
(RayTrainWorker pid=201612) 1 | layer_1  | Linear   | 25.1 K
(RayTrainWorker pid=201612) 2 | layer_2  | Linear   | 2.1 K 
(RayTrainWorker pid=201612) 3 | layer_3  | Linear   | 650   
(RayTrainWorker pid=201612) --------------------------------------
(RayTrainWorker pid=201612) 27.9 K    Trainable params
(RayTrainWorker pid=201612) 0         Non-trainable params
(RayTrainWorker pid=201612) 27.9 K    Total params
(RayTrainWorker pid=201612) 0.112     Total estimated model params size (MB)
2023-03-28 18:09:10,291	ERROR trial_runner.py:1450 -- Trial LightningTrainer_45adb_00000: Error happened when processing _ExecutorEventType.TRAINING_RESULT.
ray.exceptions.RayTaskError(AttributeError): ray::_Inner.train() (pid=201387, ip=10.0.57.221, repr=Light

| Trial name | date | hostname | node_ip | pid | timestamp | trial_id |
| --- | --- | --- | --- | --- | --- | --- |
| LightningTrainer_45adb_00000 | 2023-03-28_18-09-00 | ip-10-0-57-221 | 10.0.57.221 | 201387 | 1680052140 | 45adb_00000 |


```text
2023-03-28 18:09:10,313	ERROR tune.py:941 -- Trials did not complete: [LightningTrainer_45adb_00000]
2023-03-28 18:09:10,314	INFO tune.py:945 -- Total run time: 16.89 seconds (16.86 seconds for the tuning loop).

Result(
  error='RayTaskError(AttributeError)',
  metrics={'trial_id': '45adb_00000'},
  path='/home/ray/ray_results/LightningTrainer_2023-03-28_18-08-53/LightningTrainer_45adb_00000_0_layer_1_size=32,layer_2_size=64,lr=0.0014_2023-03-28_18-08-53',
  checkpoint=None
)
```
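Conceptually, picking the best result from a set of trials is a max (or min) over the chosen metric. Trials that never reported that metric are skipped, like the errored trial above, whose `metrics` contain only `trial_id`. The sketch below is a hypothetical plain-Python helper over made-up result dictionaries. It is not the implementation of `ResultGrid.get_best_result`.

```python
def best_result(results, metric, mode="max"):
    # Ignore trials that never reported the metric, e.g. because they errored
    candidates = [r for r in results if metric in r["metrics"]]
    if not candidates:
        return None
    pick = max if mode == "max" else min
    return pick(candidates, key=lambda r: r["metrics"][metric])


# Made-up results for illustration
results = [
    {"trial_id": "t0", "metrics": {"trial_id": "t0"}},  # errored, no accuracy reported
    {"trial_id": "t1", "metrics": {"ptl/val_accuracy": 0.95}},
    {"trial_id": "t2", "metrics": {"ptl/val_accuracy": 0.97}},
]
print(best_result(results, "ptl/val_accuracy")["trial_id"])  # t2
```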