# Introduction to Ray Train: Ray Train + PyTorch Lightning

© 2025, Anyscale. All Rights Reserved

💻 **Launch Locally**: Not recommended. You may encounter out-of-memory (OOM) errors when running certain cells locally.

🚀 **Cloud Required**: A Ray Cluster with 4 GPUs is recommended to run this notebook. Click [here](http://console.anyscale.com/register) to start a Ray cluster on Anyscale.

This notebook demonstrates how to train a Stable Diffusion model using PyTorch Lightning and Ray Train. 

<div class="alert alert-block alert-info">

<b>Here is the roadmap for this notebook:</b>

<ol>
    <li>When to use Ray Train</li>
    <li>Single GPU Training with PyTorch Lightning</li>
    <li>Distributed Training with Ray Train and PyTorch Lightning</li>
    <li>Ray Train in Production</li>
</ol>

</div>

**Imports**

In [None]:
import os

import lightning.pytorch as pl
import numpy as np
import pandas as pd
import s3fs
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from diffusers import DDPMScheduler, UNet2DConditionModel
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from transformers import get_linear_schedule_with_warmup

import ray.train
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
)
from ray.train.torch import TorchTrainer, get_device

## 1. When to use Ray Train


Use Ray Train when you face one of the following challenges:

|Challenge|Detail|Solution|
|---|---|---|
|**Need to speed up or scale up training**| Training jobs might take a long time to complete, or require a lot of compute | Ray Train provides a **distributed training** framework that allows engineers to scale training jobs to multiple GPUs |
|**Minimize overhead of setting up clusters**| Engineers need to manage the underlying infrastructure | Ray Train **provisions the underlying infrastructure** via Ray's cluster autoscaler. |
|**Achieve observability**| Engineers need to connect to different nodes and GPUs to find the root cause of failures, fetch logs, traces, etc | Ray Train **provides observability** via Ray's dashboard, metrics, and traces that allow engineers to monitor the training job |
|**Ensure reliable training**| Training jobs can fail due to hardware failures, network issues, or other unexpected events | Ray Train **ensures fault tolerance** via checkpointing, automatic retries, and the ability to resume training from the last checkpoint |
|**Avoid significant code rewrite**| Engineers might need to fully rewrite their training loop to support distributed training | Ray Train has **built-in integrations** with the PyTorch ecosystem (Torch, Lightning, Hugging Face), tree-based methods (XGBoost, LightGBM), and more to minimize the amount of code changes needed |


## 2. Single GPU Training with PyTorch Lightning

### 2.1 Overview

We will start by fitting a Stable Diffusion U-Net model to a preprocessed image and text dataset.

This diagram shows the full-scale training architecture.

<img src="https://anyscale-materials.s3.us-west-2.amazonaws.com/stable-diffusion/training_architecture_v3.jpeg" width="700px">

Regardless of scale, the process is primarily composed of three main stages:
1. **Loading the preprocessed data**
2. **Training the model**
3. **Storing the model checkpoints**


### 2.2. Create a torch dataloader 

Let's start by defining a dataset we want to use. 

We'll use `parquet` data that was generated ahead of time by the preprocessing pipeline. Each row holds precomputed image latents and caption latents.

In [None]:
class ParquetDataset(Dataset):
    """Minimal PyTorch Dataset for data stored in parquet."""

    def __init__(
        self,
        s3_path: str,
        dtype: np.dtype = np.float16,
    ):
        self.columns = ["image_latents_256", "caption_latents"]
        loaded_df = pd.read_parquet(
            s3_path,
            columns=["image_latents_256", "caption_latents"],
            filesystem=s3fs.S3FileSystem(anon=False),
        )
        loaded_df["image_latents_256"] = loaded_df["image_latents_256"].apply(
            lambda x: x.reshape(4, 32, 32).astype(dtype)
        )
        loaded_df["caption_latents"] = loaded_df["caption_latents"].apply(
            lambda x: x.reshape(77, 1024).astype(dtype)
        )
        self.df = loaded_df

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        # Convert to tensors so default collate stacks them automatically.
        return {col: torch.as_tensor(row[col]) for col in self.columns}

<div class="alert alert-block alert-warning">

**Note:** This `Dataset` implementation is deliberately simple and loads the entire dataset into memory, so it is not recommended for large datasets.

</div>
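
To get a feel for what loading everything into memory costs, the following back-of-the-envelope sketch computes the raw tensor payload per row from the shapes and the `float16` dtype used in `ParquetDataset`. The row count is a hypothetical placeholder, and the estimate ignores pandas and Python object overhead, so treat it as a lower bound rather than a measurement.

```python
from math import prod

# Shapes and dtype width taken from ParquetDataset above.
IMAGE_LATENT_SHAPE = (4, 32, 32)
CAPTION_LATENT_SHAPE = (77, 1024)
BYTES_PER_FLOAT16 = 2


def bytes_per_row() -> int:
    """Raw payload of one row: image latents plus caption latents."""
    n_elements = prod(IMAGE_LATENT_SHAPE) + prod(CAPTION_LATENT_SHAPE)
    return n_elements * BYTES_PER_FLOAT16


def dataset_bytes(num_rows: int) -> int:
    """Raw payload of the whole dataset held in memory."""
    return num_rows * bytes_per_row()


# Hypothetical row count, for illustration only.
num_rows = 10_000
print(f"{bytes_per_row() / 1024:.0f} KiB per row")
print(f"{dataset_bytes(num_rows) / 1024**3:.2f} GiB for {num_rows:,} rows")
```

The caption latents dominate the footprint, which is why a streaming or sharded loader becomes attractive as the row count grows.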

Let's proceed to build our custom map-style dataset.

In [None]:
artifact_storage_path = os.environ["ANYSCALE_ARTIFACT_STORAGE"]
s3_path = f"{artifact_storage_path}/stable-diffusion/256/6_000108_000000.parquet"
dataset = ParquetDataset(s3_path)

We construct a torch dataloader that will be used to train the model.

In [None]:
data_loader = DataLoader(
    dataset,
    batch_size=8,  # Adjust based on memory constraints
    shuffle=True,
    num_workers=2,  # Adjust based on system's CPU cores
)

We can inspect the batches to verify their shape and type.

In [None]:
for batch in data_loader:
    print(batch["image_latents_256"].shape, batch["caption_latents"].shape)
    print(batch["caption_latents"].dtype, batch["image_latents_256"].dtype)
    break  # Remove this to process the entire dataset

### 2.3 Define a stable diffusion model

Apart from the `get_device()` helper used in `on_fit_start`, this "standard" LightningModule does not refer to Ray or Ray Train, which makes migrating workloads easier.

In [None]:
small_unet_model_config = {
    "_class_name": "UNet2DConditionModel",
    "_diffusers_version": "0.2.2",
    "act_fn": "silu",
    "attention_head_dim": 8,
    "block_out_channels": [160, 320, 640, 640],
    "center_input_sample": False,
    "cross_attention_dim": 1024,
    "down_block_types": [
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ],
    "downsample_padding": 1,
    "flip_sin_to_cos": True,
    "freq_shift": 0,
    "in_channels": 4,
    "layers_per_block": 2,
    "mid_block_scale_factor": 1,
    "norm_eps": 1e-05,
    "norm_num_groups": 32,
    "out_channels": 4,
    "sample_size": 64,
    "up_block_types": [
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
    ],
}

class StableDiffusion(pl.LightningModule):
    def __init__(
        self,
        lr: float,
        resolution: int,
        weight_decay: float,
        num_warmup_steps: int,
        model_name: str,
    ) -> None:
        # Initialize the parent module before assigning any attributes.
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.resolution = resolution
        self.weight_decay = weight_decay
        self.num_warmup_steps = num_warmup_steps
        # Initialize U-Net.
        # model_config = PretrainedConfig.get_config_dict(model_name, subfolder="unet")[0]
        model_config = small_unet_model_config
        self.unet = UNet2DConditionModel(**model_config)
        # Define the training noise scheduler.
        self.noise_scheduler = DDPMScheduler.from_pretrained(
            model_name, subfolder="scheduler"
        )
        # Setup loss function.
        self.loss_fn = F.mse_loss
        self.current_training_steps = 0

    def on_fit_start(self) -> None:
        """Move cumprod tensor to GPU in advance to avoid data movement on each step."""
        self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(
            get_device()
        )

    def forward(
        self, batch: dict[str, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the model."""
        # Extract inputs.
        latents = batch["image_latents_256"]
        conditioning = batch["caption_latents"]
        # Sample the diffusion timesteps.
        timesteps = self._sample_timesteps(latents)
        # Add noise to the inputs (forward diffusion).
        noise = torch.randn_like(latents)
        noised_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        # Forward through the model.
        outputs = self.unet(noised_latents, timesteps, conditioning)["sample"]
        return outputs, noise

    def training_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Training step of the model."""
        outputs, targets = self.forward(batch)
        loss = self.loss_fn(outputs, targets)
        self.log(
            "train/loss_mse", loss.item(), prog_bar=False, on_step=True, sync_dist=False
        )
        self.current_training_steps += 1
        return loss

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Configure the optimizer and learning rate scheduler."""
        optimizer = torch.optim.AdamW(
            self.trainer.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        # Set a large training step here to keep lr constant after warm-up.
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.num_warmup_steps,
            num_training_steps=100_000_000_000,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }

    def _sample_timesteps(self, latents: torch.Tensor) -> torch.Tensor:
        return torch.randint(
            0, len(self.noise_scheduler), (latents.shape[0],), device=latents.device
        )
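
The `add_noise` call in `forward` applies the closed-form DDPM forward process, `x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise`, where `alpha_bar_t` is the cumulative product of `(1 - beta_t)`. The following is a minimal pure-Python sketch of that formula on a short list of numbers. The linear beta schedule and its endpoints are assumptions made for illustration; the scheduler loaded from `model_name` defines its own schedule.

```python
import math
import random


def make_alphas_cumprod(num_steps: int, beta_start: float, beta_end: float) -> list[float]:
    """Cumulative product of (1 - beta_t) for a linear beta schedule (assumed here)."""
    alphas_cumprod, running = [], 1.0
    for t in range(num_steps):
        beta_t = beta_start + (beta_end - beta_start) * t / (num_steps - 1)
        running *= 1.0 - beta_t
        alphas_cumprod.append(running)
    return alphas_cumprod


def add_noise(x0: list[float], noise: list[float], alpha_bar_t: float) -> list[float]:
    """x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise, element-wise."""
    signal_scale = math.sqrt(alpha_bar_t)
    noise_scale = math.sqrt(1.0 - alpha_bar_t)
    return [signal_scale * x + noise_scale * n for x, n in zip(x0, noise)]


# Illustrative schedule parameters, not read from any pretrained config.
alphas_cumprod = make_alphas_cumprod(num_steps=1000, beta_start=1e-4, beta_end=2e-2)

random.seed(0)
x0 = [1.0, -0.5, 0.25]
noise = [random.gauss(0.0, 1.0) for _ in x0]

early = add_noise(x0, noise, alphas_cumprod[0])   # close to the clean input
late = add_noise(x0, noise, alphas_cumprod[-1])   # close to pure noise
print(early)
print(late)
```

Because the model in `forward` returns `outputs` alongside the sampled `noise`, the MSE loss in `training_step` trains the U-Net to predict the noise that was mixed in at the sampled timestep.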

### 2.4. Define a PyTorch Lightning training loop

Here is a training loop that is specific to PyTorch Lightning.

It performs the following steps:
1. **Model Initialization:**
   - Instantiate the diffusion model.
2. **Trainer Setup:**
   - Instantiate the Lightning Trainer with `accelerator="gpu"` and `bf16` mixed precision.
3. **Training Execution:**
   - Run the trainer using the `fit` method.

In [None]:
def lightning_training_loop(
    train_loader: torch.utils.data.DataLoader,
    storage_path: str,
    model_name: str = "stabilityai/stable-diffusion-2-base",
    resolution: int = 256,
    lr: float = 1e-4,
    max_epochs: int = 1,
    num_warmup_steps: int = 10_000,
    weight_decay: float = 1e-2,
) -> None:
    # 1. Initialize the model
    model = StableDiffusion(
        model_name=model_name,
        resolution=resolution,
        lr=lr,
        num_warmup_steps=num_warmup_steps,
        weight_decay=weight_decay,
    )

    # 2. Initialize the Lightning Trainer
    trainer = pl.Trainer(
        accelerator="gpu",
        devices="auto",
        precision="bf16-mixed",
        max_epochs=max_epochs,
        default_root_dir=storage_path,
        log_every_n_steps=8,
    )

    # 3. Run the trainer
    trainer.fit(model=model, train_dataloaders=train_loader)
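
The `lr` and `num_warmup_steps` arguments above feed the warm-up schedule set up in `configure_optimizers`. The sketch below re-implements the learning-rate multiplier in plain Python, following the documented behavior of `get_linear_schedule_with_warmup` (a linear ramp from 0 to 1, then a linear decay to 0 at `num_training_steps`). It is an illustration rather than the library code, and the step values printed are arbitrary.

```python
def linear_warmup_multiplier(step: int, num_warmup_steps: int, num_training_steps: int) -> float:
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        # Linear ramp from 0 to 1 during warm-up.
        return step / max(1, num_warmup_steps)
    # Linear decay from 1 to 0 over the remaining steps.
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


base_lr = 1e-4
num_warmup_steps = 10_000
num_training_steps = 100_000_000_000  # very large, so the decay is negligible

for step in (0, 5_000, 10_000, 1_000_000):
    lr = base_lr * linear_warmup_multiplier(step, num_warmup_steps, num_training_steps)
    print(f"step {step:>9,}: lr = {lr:.6e}")
```

With such a large `num_training_steps`, the decay phase is so slow that the learning rate stays effectively constant once warm-up ends.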


Here is how we would run the lightning training loop on a single GPU.

In [None]:
storage_path = "/mnt/local_storage/lightning/stable-diffusion-pretraining/"
lightning_training_loop(train_loader=data_loader, storage_path=storage_path)

Let's inspect the storage path to see what files were created.

In [None]:
!tree {storage_path}

## 3. Distributed Training with Ray Train and PyTorch Lightning

Let's consider the case where we have a very large dataset of images that would take a long time to train on a single GPU. We would now like to scale this training job to run on multiple GPUs.

### 3.1 Distributed Data Parallel Training
Here is a diagram showing the standard distributed data parallel training loop.

|<img src="https://anyscale-public-materials.s3.us-west-2.amazonaws.com/ray-ai-libraries/diagrams/multi_gpu_pytorch_v4.png" width="900px" loading="lazy">|
|:--|
|Schematic overview of DistributedDataParallel (DDP) training: (1) the model is replicated from the <code>GPU rank 0</code> to all other workers; (2) each worker receives a shard of the dataset and processes a mini-batch; (3) during the backward pass, gradients are averaged across GPUs; (4) checkpoint and metrics from rank 0 GPU are saved to the persistent storage.|
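
The four steps in the caption can be sketched without any GPU. The toy example below is a plain-Python simulation under simplifying assumptions: a one-parameter model `y = w * x`, two simulated workers, and a hypothetical `all_reduce_mean` helper standing in for the gradient all-reduce. It is not how PyTorch DDP is implemented, only an illustration of why the replicas stay identical.

```python
def mse_gradient(w: float, xs: list[float], ys: list[float]) -> float:
    """d/dw of mean((w * x - y)^2) over one mini-batch."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values: list[float]) -> float:
    """Stand-in for the gradient all-reduce: every worker receives the mean."""
    return sum(values) / len(values)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]  # the target weight is 2.0
world_size, lr = 2, 0.01

# (1) Every worker starts from the same replica of the model.
replicas = [0.0] * world_size

for _ in range(50):
    # (2) Each worker computes a gradient on its own shard of the data.
    local_grads = [
        mse_gradient(replicas[rank], xs[rank::world_size], ys[rank::world_size])
        for rank in range(world_size)
    ]
    # (3) Gradients are averaged across workers.
    grad = all_reduce_mean(local_grads)
    # Every worker applies the same update, so the replicas stay in sync.
    replicas = [w - lr * grad for w in replicas]

# (4) Only one copy (rank 0) needs to be checkpointed.
print(replicas[0])
```

With equally sized shards, the average of the per-worker gradients equals the gradient over the combined batch, which is why data parallel training behaves like single-device training with a larger batch.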

### 3.2 Ray Train Migration

Here are the changes we need to make to the training loop to migrate it to Ray Train.

In [None]:
def train_loop_per_worker(
    config: dict,  # Update the function signature to comply with Ray Train
):
    # Note lightning prepares the dataloader (adding a distributed sampler) for us
    train_dataloader = DataLoader(
        dataset,
        batch_size=config["batch_size_per_worker"],
        shuffle=True,
        num_workers=2,
    )

    # Same model initialization as vanilla lightning
    model = StableDiffusion(
        lr=config["lr"],
        resolution=config["resolution"],
        weight_decay=config["weight_decay"],
        num_warmup_steps=config["num_warmup_steps"],
        model_name=config["model_name"],
    )

    # Same trainer setup as vanilla lightning except we add Ray Train specific arguments
    trainer = pl.Trainer(
        max_steps=config["max_steps"],
        max_epochs=config["max_epochs"],
        accelerator="gpu",
        precision="bf16-mixed",
        devices="auto",  # Set devices to "auto" to use all available GPUs
        strategy=RayDDPStrategy(),  # Use RayDDPStrategy for distributed data parallel training
        plugins=[
            RayLightningEnvironment()
        ],  # Use RayLightningEnvironment to run the Lightning Trainer
        callbacks=[
            RayTrainReportCallback()
        ],  # Use RayTrainReportCallback to report metrics and checkpoints
        enable_checkpointing=False,  # Disable lightning checkpointing
    )

    # Same as vanilla lightning
    trainer.fit(model, train_dataloaders=train_dataloader)

Here is the same diagram as before but with the Ray Train specific components highlighted.

<img src="https://anyscale-public-materials.s3.us-west-2.amazonaws.com/ray-ai-libraries/diagrams/multi_gpu_lightning_annotated_no_data_v2.png" width=900 loading="lazy">

We made use of:
- A standard PyTorch `DataLoader`, to which Lightning adds a distributed sampler so that each worker reads its own shard of the data.
- `RayDDPStrategy` to perform distributed data parallel training.
- `RayLightningEnvironment` to run the Lightning Trainer.
- `RayTrainReportCallback` to report metrics and checkpoints.
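
The training function relies on Lightning adding a distributed sampler to the dataloader. As a rough illustration of what such a sampler does, the sketch below partitions indices in plain Python using the general approach of `torch.utils.data.DistributedSampler` (shuffle with a shared seed, pad to a multiple of the world size, then take a strided slice). The function name `shard_indices` is hypothetical, and this is not the library's code.

```python
import math
import random


def shard_indices(
    num_samples: int, world_size: int, rank: int, seed: int = 0, epoch: int = 0
) -> list[int]:
    """Indices that one worker iterates over during one epoch."""
    indices = list(range(num_samples))
    # Same seed on every worker, so all workers agree on the shuffled order.
    random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating samples so every worker gets the same number of items.
    per_worker = math.ceil(num_samples / world_size)
    total = per_worker * world_size
    indices = (indices * math.ceil(total / num_samples))[:total]
    # Strided slice: rank r takes positions r, r + world_size, ...
    return indices[rank:total:world_size]


shards = [shard_indices(num_samples=10, world_size=4, rank=r) for r in range(4)]
for rank, shard in enumerate(shards):
    print(rank, shard)
```

Each worker sees an equally sized slice, and together the slices cover every sample, with a few repeats when the dataset size is not divisible by the number of workers.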

### 3.3. Configure scale and GPUs
Outside of our training function, we create a `ScalingConfig`.

In [None]:
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)

<div class="alert alert-block alert-info">

<a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.ScalingConfig.html#ray-train-scalingconfig" target="_blank">ScalingConfig</a> configures:

<ul>
  <li><code>num_workers</code>: The number of distributed training worker processes.</li>
  <li><code>use_gpu</code>: Whether each worker should use a GPU (or CPU).</li>
</ul>

See docs on configuring <a href="https://docs.ray.io/en/latest/train/user-guides/using-gpus.html" target="_blank">scale and GPUs</a> for more details.
</div>

#### 3.3.1. Note on Ray Train key concepts

Ray Train is built around [four key concepts](https://docs.ray.io/en/latest/train/overview.html):
1. **Training function** (implemented above as `train_loop_per_worker`): A Python function that contains your model training logic.
1. **Worker**: A process that runs the training function.
1. **Scaling config**: Specifies the number of workers and the compute resources (CPUs, GPUs, or TPUs).
1. **Trainer**: A Python class (Ray Actor) that ties together the training function, workers, and scaling configuration to execute a distributed training job.

|<img src="https://docs.ray.io/en/latest/_images/overview.png" width="700px" loading="lazy">|
|:--|
|High-level architecture of Ray Train|

### 3.4 Create and fit a Ray Train TorchTrainer

We first specify the run configuration to tell Ray Train where to store the checkpoints and metrics.

In [None]:
storage_path = "/mnt/cluster_storage/"
experiment_name = "stable-diffusion-pretraining"

run_config = ray.train.RunConfig(name=experiment_name, storage_path=storage_path)

Now we can create our Ray Train `TorchTrainer`.

In [None]:
train_loop_config = {
    "batch_size_per_worker": 8,
    "prefetch_batches": 1,
    "lr": 0.0001,
    "num_warmup_steps": 10_000,
    "weight_decay": 0.01,
    "max_steps": 550_000,
    "max_epochs": 1,
    "resolution": 256,
    "model_name": "stabilityai/stable-diffusion-2-base",
}

trainer = TorchTrainer(
    train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config=run_config,
)
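
Since `batch_size_per_worker` is a per-worker value, the effective batch size and the number of optimizer steps both depend on the number of workers in the scaling config. The helper functions below are a hypothetical sketch of that arithmetic. They assume the data is split evenly across workers, the last partial batch is kept, there is no gradient accumulation, and training stops at whichever of `max_epochs` or `max_steps` is reached first. The dataset size is a placeholder.

```python
import math


def global_batch_size(num_workers: int, batch_size_per_worker: int) -> int:
    """Samples consumed per optimizer step across all workers."""
    return num_workers * batch_size_per_worker


def planned_steps(
    num_samples: int,
    num_workers: int,
    batch_size_per_worker: int,
    max_epochs: int,
    max_steps: int,
) -> int:
    """Optimizer steps taken before the first stopping condition is hit."""
    samples_per_worker = math.ceil(num_samples / num_workers)
    steps_per_epoch = math.ceil(samples_per_worker / batch_size_per_worker)
    return min(max_steps, max_epochs * steps_per_epoch)


# Hypothetical dataset size, for illustration only.
num_samples = 10_000
for num_workers in (2, 4):
    print(
        num_workers,
        global_batch_size(num_workers, batch_size_per_worker=8),
        planned_steps(num_samples, num_workers, 8, max_epochs=1, max_steps=550_000),
    )
```

Under these assumptions, doubling the number of workers doubles the global batch size and roughly halves the steps per epoch, so hyperparameters such as the learning rate and warm-up length may need revisiting when scaling out.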

We call `.fit()` to start the training job.

In [None]:
result = trainer.fit()
result

### 3.5. Access the training results

We can check the metrics produced by the training job.

In [None]:
result.metrics_dataframe

### 3.6. Load the checkpointed model to generate predictions

In [None]:
ckpt = result.checkpoint
with ckpt.as_directory() as ckpt_dir:
    ckpt_path = os.path.join(ckpt_dir, "checkpoint.ckpt")
    loaded_model_ray_train = StableDiffusion.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    loaded_model_ray_train.eval()

loaded_model_ray_train

### 3.7. Activity: Run the distributed training with more workers

<div class="alert alert-block alert-info">

1. Update the scaling configuration to make use of 4 GPU workers
2. Run the trainer using the same hyperparameters

Use the following code snippets to guide you:

```python
# Hint: Update the scaling configuration
scaling_config = ...

trainer = ray.train.torch.TorchTrainer(
    train_loop_per_worker,
    scaling_config=scaling_config,
    run_config=run_config,
    train_loop_config=train_loop_config,
)
result = trainer.fit()
result.metrics_dataframe
```

</div>

In [None]:
# Write your solution here


<div class="alert alert-block alert-info">

<details>

<summary> Click here to see the solution </summary>

```python
scaling_config = ray.train.ScalingConfig(num_workers=4, use_gpu=True)

trainer = ray.train.torch.TorchTrainer(
    train_loop_per_worker,
    scaling_config=scaling_config,
    run_config=run_config,
    train_loop_config=train_loop_config,
)
result = trainer.fit()
result.metrics_dataframe
```

</details>

</div>

## 4. Ray Train in Production

Here are some examples of Ray Train in production:
1. Canva uses Ray Train + Ray Data to cut down Stable Diffusion training costs by 3.7x. Read this [Anyscale blog post here](https://www.anyscale.com/blog/scalable-and-cost-efficient-stable-diffusion-pre-training-with-ray) and the [Canva case study here](https://www.anyscale.com/resources/case-study/how-canva-built-a-modern-ai-platform-using-anyscale).
2. Anyscale uses Ray Train + DeepSpeed to fine-tune language models. Read more [here](https://github.com/ray-project/ray/tree/master/doc/source/templates/04_finetuning_llms_with_deepspeed).


In [None]:
# Run this cell for file cleanup
!rm -rf /mnt/cluster_storage/stable-diffusion-pretraining