(lightning_mnist_example)=

# Train a Pytorch Lightning Image Classifier

This example introduces how to train a Pytorch Lightning Module using AIR {class}`LightningTrainer <ray.train.lightning.LightningTrainer>`. We will demonstrate how to train a basic neural network on the MNIST dataset with distributed data parallelism.


In [None]:
!pip install "torchmetrics>=0.9" "pytorch_lightning>=1.6" 

In [None]:
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from filelock import FileLock
from torch.utils.data import DataLoader, random_split, Subset
from torchmetrics import Accuracy
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning import trainer
from pytorch_lightning.loggers.csv_logs import CSVLogger

## Prepare Dataset and Module

The Pytorch Lightning Trainer takes either `torch.utils.data.DataLoader` or `pl.LightningDataModule` as data inputs. You can keep using them without any changes for the Ray AIR LightningTrainer. 

In [3]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=100):
        super().__init__()
        self.data_dir = os.getcwd()
        self.batch_size = batch_size
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def setup(self, stage=None):
        with FileLock(f"{self.data_dir}.lock"):
            mnist = MNIST(
                self.data_dir, train=True, download=True, transform=self.transform
            )

            # split data into train and val sets
            self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=4)

    def test_dataloader(self):
        with FileLock(f"{self.data_dir}.lock"):
            self.mnist_test = MNIST(
                self.data_dir, train=False, download=True, transform=self.transform
            )
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4)


datamodule = MNISTDataModule(batch_size=128)


Next, define a simple multi-layer perception as the subclass of `pl.LightningModule`.

In [4]:
class MNISTClassifier(pl.LightningModule):
    def __init__(self, lr=1e-3, feature_dim=128):
        torch.manual_seed(421)
        super(MNISTClassifier, self).__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, 10),
            nn.ReLU(),
        )
        self.lr = lr
        self.accuracy = Accuracy(task="multiclass", num_classes=10)
        self.eval_loss = []
        self.eval_accuracy = []
        self.test_accuracy = []
        pl.seed_everything(888)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.linear_relu_stack(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        loss, acc = self._shared_eval(val_batch)
        self.log("val_accuracy", acc)
        self.eval_loss.append(loss)
        self.eval_accuracy.append(acc)
        return {"val_loss": loss, "val_accuracy": acc}

    def test_step(self, test_batch, batch_idx):
        loss, acc = self._shared_eval(test_batch)
        self.test_accuracy.append(acc)
        self.log("test_accuracy", acc, sync_dist=True, on_epoch=True)
        return {"test_loss": loss, "test_accuracy": acc}

    def _shared_eval(self, batch):
        x, y = batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        return loss, acc

    def on_validation_epoch_end(self):
        avg_loss = torch.stack(self.eval_loss).mean()
        avg_acc = torch.stack(self.eval_accuracy).mean()
        self.log("val_loss", avg_loss, sync_dist=True)
        self.log("val_accuracy", avg_acc, sync_dist=True)
        self.eval_loss.clear()
        self.eval_accuracy.clear()
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

You don't need to make any change to the definition of PyTorch Lightning model and datamodule.

(lightning-config-builder-intro)=

## Define the Cofigurations for AIR LightningTrainer

The {meth}`LightningConfigBuilder <ray.train.lightning.LightningConfigBuilder>` class stores all the parameters involved in training a PyTorch Lightning module. It takes the same parameter lists as those in PyTorch Lightning.

- The `.module()` method takes a subclass of `pl.LightningModule` and its initialization parameters. `LightningTrainer` will instantiate a model instance internally in the workers' training loop.
- The `.trainer()` method takes the initialization parameters of `pl.Trainer`. You can specify training configurations, loggers, and callbacks here.
- The `.fit_params()` method stores all the parameters that will be passed into `pl.Trainer.fit()`, including train/val dataloaders, datamodules, and checkpoint paths.
- The `.checkpointing()` method saves the configurations for a `RayModelCheckpoint` callback. This callback reports the latest metrics to the AIR session along with a newly saved checkpoint.
- The `.build()` method generates a dictionary that contains all the configurations in the builder. This dictionary will be passed to `LightningTrainer` later.

Next, let's go step-by-step to see how to convert your existing PyTorch Lightning training script to a LightningTrainer.

In [5]:
from pytorch_lightning.callbacks import ModelCheckpoint
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig
from ray.train.lightning import (
    LightningTrainer,
    LightningConfigBuilder,
    LightningCheckpoint,
)


def build_lightning_config_from_existing_code(use_gpu):
    # Create a config builder to encapsulate all required parameters.
    # Note that model instantiation and fitting will occur later in the LightingTrainer,
    # rather than in the config builder.
    config_builder = LightningConfigBuilder()

    # 1. define your model
    # model = MNISTClassifier(lr=1e-3, feature_dim=128)
    config_builder.module(cls=MNISTClassifier, lr=1e-3, feature_dim=128)

    # 2. define a ModelCheckpoint callback
    # checkpoint_callback = ModelCheckpoint(
    #     monitor="val_accuracy", mode="max", save_top_k=3
    # )
    config_builder.checkpointing(monitor="val_accuracy", mode="max", save_top_k=3)

    # 3. Define a Lightning trainer
    # trainer = pl.Trainer(
    #     max_epochs=10,
    #     accelerator="cpu",
    #     strategy="ddp",
    #     log_every_n_steps=100,
    #     logger=CSVLogger("logs"),
    #     callbacks=[checkpoint_callback],
    # )
    config_builder.trainer(
        max_epochs=10,
        accelerator="gpu" if use_gpu else "cpu",
        log_every_n_steps=100,
        logger=CSVLogger("logs"),
    )
    # You do not need to provide the checkpoint callback and strategy here,
    # since LightningTrainer configures them automatically.
    # You can also add any other callbacks into LightningConfigBuilder.trainer().

    # 4. Parameters for model fitting
    # trainer.fit(model, datamodule=datamodule)
    config_builder.fit_params(datamodule=datamodule)

    # Finally, compile all the configs into a dictionary for LightningTrainer
    lightning_config = config_builder.build()
    return lightning_config

Now put everything together:

In [6]:
use_gpu = True # Set it to False if you want to run without GPUs
num_workers = 4

In [7]:
lightning_config = build_lightning_config_from_existing_code(use_gpu=use_gpu)

scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)

run_config = RunConfig(
    name="ptl-mnist-example",
    storage_path="/tmp/ray_results",
    checkpoint_config=CheckpointConfig(
        num_to_keep=3,
        checkpoint_score_attribute="val_accuracy",
        checkpoint_score_order="max",
    ),
)

trainer = LightningTrainer(
    lightning_config=lightning_config,
    scaling_config=scaling_config,
    run_config=run_config,
)

Now fit your trainer:

In [8]:
result = trainer.fit()
print("Validation Accuracy: ", result.metrics["val_accuracy"])
result

2023-06-13 16:05:12,869	INFO worker.py:1452 -- Connecting to existing Ray cluster at address: 10.0.28.253:6379...
2023-06-13 16:05:12,877	INFO worker.py:1627 -- Connected to Ray cluster. View the dashboard at [1m[32mhttps://console.anyscale-staging.com/api/v2/sessions/ses_15dlj65vax84ljl7ayeplubryd/services?redirect_to=dashboard [39m[22m
2023-06-13 16:05:13,036	INFO packaging.py:347 -- Pushing file package 'gcs://_ray_pkg_488e346d50f332edaa288fdaa22b2bdc.zip' (52.65MiB) to Ray cluster...
2023-06-13 16:05:13,221	INFO packaging.py:360 -- Successfully pushed file package 'gcs://_ray_pkg_488e346d50f332edaa288fdaa22b2bdc.zip'.
2023-06-13 16:05:13,314	INFO tune.py:226 -- Initializing Ray automatically. For cluster usage or custom Ray initialization, call `ray.init(...)` before `Trainer(...)`.


0,1
Current time:,2023-06-13 16:05:52
Running for:,00:00:39.29
Memory:,5.5/30.9 GiB

Trial name,status,loc,iter,total time (s),train_loss,val_accuracy,val_loss
LightningTrainer_c0d28_00000,TERMINATED,10.0.28.253:16995,10,28.5133,0.0315991,0.970002,-12.3467


[2m[36m(pid=16995)[0m   from pandas import MultiIndex, Int64Index
[2m[36m(LightningTrainer pid=16995)[0m 2023-06-13 16:05:24,007	INFO backend_executor.py:137 -- Starting distributed worker processes: ['17232 (10.0.28.253)', '6371 (10.0.1.80)', '7319 (10.0.58.90)', '6493 (10.0.26.229)']
[2m[36m(RayTrainWorker pid=17232)[0m 2023-06-13 16:05:24,966	INFO config.py:86 -- Setting up process group for: env:// [rank=0, world_size=4]
[2m[36m(RayTrainWorker pid=17232)[0m   from pandas import MultiIndex, Int64Index
[2m[36m(RayTrainWorker pid=17232)[0m   from pandas import MultiIndex, Int64Index
[2m[36m(RayTrainWorker pid=7319, ip=10.0.58.90)[0m   from pandas import MultiIndex, Int64Index
[2m[36m(RayTrainWorker pid=17232)[0m Global seed set to 888
[2m[36m(RayTrainWorker pid=17232)[0m GPU available: True, used: True
[2m[36m(RayTrainWorker pid=17232)[0m TPU available: False, using: 0 TPU cores
[2m[36m(RayTrainWorker pid=17232)[0m IPU available: False, using: 0 IPUs
[2m

Sanity Checking: 0it [00:00, ?it/s][0m 
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  1.33it/s]
Epoch 0:   0%|          | 0/118 [00:00<?, ?it/s]                           


[2m[36m(RayTrainWorker pid=6371, ip=10.0.1.80)[0m   from pandas import MultiIndex, Int64Index[32m [repeated 2x across cluster][0m


Epoch 0:   3%|▎         | 4/118 [00:00<00:07, 16.07it/s, loss=2.09, v_num=0]
Epoch 0:   4%|▍         | 5/118 [00:00<00:05, 19.42it/s, loss=2.09, v_num=0]
Epoch 0:  12%|█▏        | 14/118 [00:00<00:02, 39.49it/s, loss=1.55, v_num=0]
Epoch 0:  12%|█▏        | 14/118 [00:00<00:02, 38.73it/s, loss=1.5, v_num=0] 
Epoch 0:  21%|██        | 25/118 [00:00<00:01, 53.89it/s, loss=0.933, v_num=0]
Epoch 0:  30%|██▉       | 35/118 [00:00<00:01, 61.80it/s, loss=0.522, v_num=0]
Epoch 0:  38%|███▊      | 45/118 [00:00<00:01, 67.21it/s, loss=0.425, v_num=0]
Epoch 0:  45%|████▍     | 53/118 [00:00<00:00, 69.59it/s, loss=0.379, v_num=0]
Epoch 0:  46%|████▌     | 54/118 [00:00<00:00, 69.65it/s, loss=0.373, v_num=0]
Epoch 0:  54%|█████▍    | 64/118 [00:00<00:00, 73.24it/s, loss=0.364, v_num=0]
Epoch 0:  62%|██████▏   | 73/118 [00:00<00:00, 74.68it/s, loss=0.341, v_num=0]
Epoch 0:  63%|██████▎   | 74/118 [00:00<00:00, 75.21it/s, loss=0.341, v_num=0]
Epoch 0:  70%|███████   | 83/118 [00:01<00:00, 76.62it/s, 

Trial name,_report_on,date,done,epoch,experiment_tag,hostname,iterations_since_restore,node_ip,pid,should_checkpoint,step,time_since_restore,time_this_iter_s,time_total_s,timestamp,train_loss,training_iteration,trial_id,val_accuracy,val_loss
LightningTrainer_c0d28_00000,train_epoch_end,2023-06-13_16-05-50,True,9,0,ip-10-0-28-253,10,10.0.28.253,16995,True,1080,28.5133,1.73311,28.5133,1686697550,0.0315991,10,c0d28_00000,0.970002,-12.3467


Epoch 1:   0%|          | 0/118 [00:00<?, ?it/s, loss=0.284, v_num=0]          
Epoch 1:   2%|▏         | 2/118 [00:00<00:15,  7.71it/s, loss=0.283, v_num=0]
Epoch 1:  11%|█         | 13/118 [00:00<00:02, 35.75it/s, loss=0.268, v_num=0]
Epoch 1:  20%|██        | 24/118 [00:00<00:01, 51.49it/s, loss=0.253, v_num=0]
Epoch 1:  28%|██▊       | 33/118 [00:00<00:01, 57.86it/s, loss=0.252, v_num=0]
Epoch 1:  36%|███▋      | 43/118 [00:00<00:01, 64.22it/s, loss=0.244, v_num=0]
Epoch 1:  37%|███▋      | 44/118 [00:00<00:01, 64.96it/s, loss=0.244, v_num=0]
Epoch 1:  37%|███▋      | 44/118 [00:00<00:01, 64.66it/s, loss=0.245, v_num=0]
Epoch 1:  46%|████▌     | 54/118 [00:00<00:00, 69.28it/s, loss=0.241, v_num=0]
Epoch 1:  55%|█████▌    | 65/118 [00:00<00:00, 73.79it/s, loss=0.245, v_num=0]
Epoch 1:  64%|██████▎   | 75/118 [00:00<00:00, 75.85it/s, loss=0.22, v_num=0] 
Epoch 1:  64%|██████▎   | 75/118 [00:00<00:00, 75.83it/s, loss=0.222, v_num=0]
Epoch 1:  72%|███████▏  | 85/118 [00:01<00:00, 78.43

2023-06-13 16:05:52,777	INFO tune.py:1111 -- Total run time: 39.46 seconds (39.28 seconds for the tuning loop).


Validation Accuracy:  0.9700015783309937


Result(
  metrics={'_report_on': 'train_epoch_end', 'train_loss': 0.03159911185503006, 'val_accuracy': 0.9700015783309937, 'val_loss': -12.346744537353516, 'epoch': 9, 'step': 1080, 'should_checkpoint': True, 'done': True, 'trial_id': 'c0d28_00000', 'experiment_tag': '0'},
  path='/tmp/ray_results/ptl-mnist-example/LightningTrainer_c0d28_00000_0_2023-06-13_16-05-13',
  checkpoint=LightningCheckpoint(local_path=/tmp/ray_results/ptl-mnist-example/LightningTrainer_c0d28_00000_0_2023-06-13_16-05-13/checkpoint_000009)
)

## Evaluate your model on test dataset

Next, let's evaluate the model's performance on the MNIST test set. We will first retrieve the best checkpoint from the fitting results and load it into the model.

If you lost the in-memory result object, you can also restore the model from the checkpoint file. Here the checkpoint path is: `/tmp/ray_results/ptl-mnist-example/LightningTrainer_c0d28_00000_0_2023-06-13_16-05-13/checkpoint_000009/model`.

In [None]:
checkpoint: LightningCheckpoint = result.checkpoint
best_model: pl.LightningModule = checkpoint.get_model(MNISTClassifier)

### Single-node Testing

If you have a relatively small test set, like MNIST, the easiest way is to use PyTorch Lightning's native interface to evaluate the best model. Pass the loaded model and test data loader to ``pl.Trainer.test()``, which will execute the test loop using your custom ``pl.LightningModule.test_step()`` method on your head node.

In [11]:
# Download and setup MNIST datamodule on the head node
datamodule.setup()
test_dataloader = datamodule.test_dataloader()

trainer = pl.Trainer()
result = trainer.test(best_model, dataloaders=test_dataloader)

  rank_zero_warn(
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

2023-06-13 16:05:53.932195: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-06-13 16:05:54.097738: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-06-13 16:05:55.022170: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-06-13 16:05:55.022249: W tensorflow/

## Multi-node Testing

Alternatively, if you have a large test set and want to speed up the testing process in parallel, you can create a group of Ray Actors to leverage multiple GPUs across multiple nodes for distributed inference. Here we demonstrate how to set up a process group and do evaluation using 4 GPUs.

In [12]:
import ray
import pytorch_lightning as pl

from pytorch_lightning.plugins.environments.lightning_environment import (
    LightningEnvironment,
)
from ray.air.util.torch_dist import (
    TorchDistributedWorker,
    init_torch_dist_process_group,
    shutdown_torch_dist_process_group,
)


class RayEnvironment(LightningEnvironment):
    """Setup Lightning DDP training environment for Ray cluster."""

    def world_size(self) -> int:
        return int(os.environ["WORLD_SIZE"])

    def global_rank(self) -> int:
        return int(os.environ["RANK"])

    def local_rank(self) -> int:
        return int(os.environ["LOCAL_RANK"])

    def set_world_size(self, size: int) -> None:
        # Disable it since `world_size()` directly returns data from AIR session.
        pass

    def set_global_rank(self, rank: int) -> None:
        # Disable it since `global_rank()` directly returns data from AIR session.
        pass

    def teardown(self):
        pass


@ray.remote
class TestWorker(TorchDistributedWorker):
    def run(self):
        trainer = pl.Trainer(
            num_nodes=num_workers,
            accelerator="gpu",
            strategy="ddp",
            plugins=[RayEnvironment()],
        )
        return trainer.test(best_model, dataloaders=test_dataloader)


# Create 4 remote Ray Actors, each with 1 GPU
workers = [TestWorker.options(num_gpus=1).remote() for _ in range(num_workers)]

# Initialize the Torch distributed group among the 4 actors.
# This will set up the required environment variables including 
# RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDRESS, ...
init_torch_dist_process_group(workers=workers, backend="nccl")

# Execute the testing run in parallel
results = ray.get([worker.run.remote() for worker in workers])

# Shutdown the process group
shutdown_torch_dist_process_group(workers=workers)


[2m[36m(RayTrainWorker pid=17232)[0m Global seed set to 888[32m [repeated 7x across cluster][0m
[2m[36m(RayTrainWorker pid=7319, ip=10.0.58.90)[0m Missing logger folder: logs/lightning_logs[32m [repeated 3x across cluster][0m
[2m[36m(RayTrainWorker pid=7319, ip=10.0.58.90)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0][32m [repeated 3x across cluster][0m
[2m[36m(pid=9162, ip=10.0.26.229)[0m   from pandas import MultiIndex, Int64Index
[2m[36m(pid=9162, ip=10.0.26.229)[0m   from pandas import MultiIndex, Int64Index
[2m[36m(pid=9976, ip=10.0.58.90)[0m   from pandas import MultiIndex, Int64Index
[2m[36m(TestWorker pid=20600)[0m   rank_zero_warn(
[2m[36m(TestWorker pid=20600)[0m GPU available: True, used: True
[2m[36m(TestWorker pid=20600)[0m TPU available: False, using: 0 TPU cores
[2m[36m(TestWorker pid=20600)[0m IPU available: False, using: 0 IPUs
[2m[36m(TestWorker pid=20600)[0m HPU available: False, using: 0 HPUs
[2m[36m(TestWorker pid=20600)[0m 

Testing: 0it [00:00, ?it/s]600)[0m 
Testing DataLoader 0:   0%|          | 0/20 [00:00<?, ?it/s]
Testing DataLoader 0:  10%|█         | 2/20 [00:00<00:13,  1.36it/s]
Testing DataLoader 0:  75%|███████▌  | 15/20 [00:00<00:00, 22.10it/s]


[2m[36m(TestWorker pid=20600)[0m 2023-06-13 16:06:07.550225: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
[2m[36m(TestWorker pid=20600)[0m To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
[2m[36m(TestWorker pid=9976, ip=10.0.58.90)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0][32m [repeated 4x across cluster][0m
[2m[36m(pid=20600)[0m   from pandas import MultiIndex, Int64Index[32m [repeated 2x across cluster][0m
[2m[36m(TestWorker pid=20600)[0m 2023-06-13 16:06:07.708119: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.

Testing DataLoader 0: 100%|██████████| 20/20 [00:00<00:00, 22.10it/s]


[2m[36m(TestWorker pid=20600)[0m 2023-06-13 16:06:08.680418: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
[2m[36m(TestWorker pid=20600)[0m 2023-06-13 16:06:08.680524: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64


Testing DataLoader 0: 100%|██████████| 20/20 [00:02<00:00,  7.31it/s]
[2m[36m(TestWorker pid=20600)[0m ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
[2m[36m(TestWorker pid=20600)[0m ┃        Test metric        ┃       DataLoader 0        ┃
[2m[36m(TestWorker pid=20600)[0m ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
[2m[36m(TestWorker pid=20600)[0m │       test_accuracy       │    0.9740999937057495     │
[2m[36m(TestWorker pid=20600)[0m └───────────────────────────┴───────────────────────────┘


## What's next?

- {ref}`Use LightningTrainer with Ray Data and Batch Predictor <lightning_advanced_example>`
- {ref}`Fine-tune a Large Language Model with LightningTrainer and FSDP <dolly_lightning_fsdp_finetuning>`
- {ref}`Hyperparameter searching with LightningTrainer + Ray Tune. <tune-pytorch-lightning-ref>`
- {ref}`Experiment Tracking with Wandb, CometML, MLFlow, and Tensorboard in LightningTrainer <lightning_experiment_tracking>`