## 3. Distributed Training with Ray Train and PyTorch Lightning

Let's consider the case where we have a very large dataset of images that would take a long time to train on a single GPU. We would now like to scale this training job to run on multiple GPUs.

### 3.1 Distributed Data Parallel Training
Here is a diagram showing the standard distributed data parallel training loop.

|<img src="https://anyscale-public-materials.s3.us-west-2.amazonaws.com/ray-ai-libraries/diagrams/multi_gpu_pytorch_v4.png" width="900px" loading="lazy">|
|:--|
|Schematic overview of DistributedDataParallel (DDP) training: (1) the model is replicated from the <code>GPU rank 0</code> to all other workers; (2) each worker receives a shard of the dataset and processes a mini-batch; (3) during the backward pass, gradients are averaged across GPUs; (4) checkpoint and metrics from rank 0 GPU are saved to the persistent storage.|
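
Step (3) in the caption is the heart of data parallelism, so here is a minimal sketch of it in plain Python. It is a toy under stated assumptions: the one-parameter model, the data points, and the helper `mse_grad` are all hypothetical, and the "workers" are just list slices rather than GPUs. It only shows that averaging per-shard gradients reproduces the full-batch gradient when the shards have equal size.

```python
# Toy illustration of DDP gradient averaging (pure Python, no GPUs involved).
# Hypothetical model: y_hat = w * x, with mean squared error as the loss.

def mse_grad(w, batch):
    """Gradient of mean((w * x - y) ** 2) with respect to w for one batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

w = 0.5
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]

# Step (2): each "worker" receives an equally sized shard of the data.
shards = [data[0:2], data[2:4]]

# Each worker computes a gradient on its own shard only.
local_grads = [mse_grad(w, shard) for shard in shards]

# Step (3): gradients are averaged across workers (the all-reduce step).
averaged_grad = sum(local_grads) / len(local_grads)

# With equal shard sizes, this matches the gradient over the full batch.
full_batch_grad = mse_grad(w, data)
print(local_grads, averaged_grad, full_batch_grad)
```

Because every worker applies the same averaged gradient, the model replicas stay identical after each optimizer step.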

### 3.2 Ray Train Migration

Here are the changes we need to make to the training loop to migrate it to Ray Train.

In [None]:
def train_loop_per_worker(
    config: dict,  # Update the function signature to comply with Ray Train
):
    # Note lightning prepares the dataloader (adding a distributed sampler) for us
    train_dataloader = DataLoader(
        dataset,
        batch_size=config["batch_size_per_worker"],
        shuffle=True,
        num_workers=2,
    )

    # Same model initialization as vanilla lightning
    model = StableDiffusion(
        lr=config["lr"],
        resolution=config["resolution"],
        weight_decay=config["weight_decay"],
        num_warmup_steps=config["num_warmup_steps"],
        model_name=config["model_name"],
    )

    # Same trainer setup as vanilla lightning except we add Ray Train specific arguments
    trainer = pl.Trainer(
        max_steps=config["max_steps"],
        max_epochs=config["max_epochs"],
        accelerator="gpu",
        precision="bf16-mixed",
        devices="auto",  # Set devices to "auto" to use all available GPUs
        strategy=RayDDPStrategy(),  # Use RayDDPStrategy for distributed data parallel training
        plugins=[
            RayLightningEnvironment()
        ],  # Use RayLightningEnvironment to run the Lightning Trainer
        callbacks=[
            RayTrainReportCallback()
        ],  # Use RayTrainReportCallback to report metrics and checkpoints
        enable_checkpointing=False,  # Disable lightning checkpointing
    )

    # Same fit call as vanilla Lightning
    trainer.fit(model, train_dataloaders=train_dataloader)

Here is the same diagram as before but with the Ray Train specific components highlighted.

<img src="https://anyscale-public-materials.s3.us-west-2.amazonaws.com/ray-ai-libraries/diagrams/multi_gpu_lightning_annotated_no_data_v2.png" width=900 loading="lazy">

We made use of:
- A `train_loop_per_worker(config)` function signature so that Ray Train can run the training loop on each worker.
- `RayDDPStrategy` to perform distributed data parallel training.
- `RayLightningEnvironment` to run the Lightning Trainer.
- `RayTrainReportCallback` to report metrics and checkpoints.
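
The training function above relies on Lightning adding a distributed sampler to the dataloader. As a rough sketch of what that sharding looks like, the hypothetical helper `shard_indices` below is modeled on the behaviour of PyTorch's `DistributedSampler` with `shuffle=False` and `drop_last=False`. The real sampler also shuffles with a per-epoch seed, and this sketch assumes the dataset has at least as many samples as there are workers.

```python
import math

def shard_indices(dataset_len, num_replicas, rank):
    """Return the sample indices one worker would see in an epoch (no shuffling)."""
    indices = list(range(dataset_len))
    # Pad by repeating indices from the start so every worker gets the same count.
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices += indices[: total_size - dataset_len]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total_size:num_replicas]

# Hypothetical example: 10 samples split across 4 workers.
example_shards = [shard_indices(10, num_replicas=4, rank=r) for r in range(4)]
for rank, shard in enumerate(example_shards):
    print(rank, shard)
```

Every worker ends up with the same number of samples, which keeps the workers in lockstep during the gradient synchronization.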

### 3.3. Configure scale and GPUs
Outside of our training function, we create a `ScalingConfig`.

In [None]:
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)

<div class="alert alert-block alert-info">

<a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.ScalingConfig.html#ray-train-scalingconfig" target="_blank">ScalingConfig</a> configures:

<ul>
  <li><code>num_workers</code>: The number of distributed training worker processes.</li>
  <li><code>use_gpu</code>: Whether each worker should use a GPU (or CPU).</li>
</ul>

See docs on configuring <a href="https://docs.ray.io/en/latest/train/user-guides/using-gpus.html" target="_blank">scale and GPUs</a> for more details.
</div>
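
Because each worker builds its own dataloader with `batch_size_per_worker`, the number of workers also determines the effective global batch size. The sketch below makes that arithmetic explicit. The helper names and the dataset size of 10,000 samples are hypothetical, and the per-worker batch size of 8 is the value used in the training configuration later in this section.

```python
import math

def global_batch_size(batch_size_per_worker, num_workers):
    """Samples consumed across all workers in one optimizer step."""
    return batch_size_per_worker * num_workers

def steps_per_epoch(dataset_len, batch_size_per_worker, num_workers):
    """Optimizer steps per epoch when the dataset is sharded evenly."""
    samples_per_worker = math.ceil(dataset_len / num_workers)
    return math.ceil(samples_per_worker / batch_size_per_worker)

# Hypothetical dataset size, used only for illustration.
dataset_len = 10_000

for num_workers in (2, 4):
    print(
        num_workers,
        global_batch_size(8, num_workers),
        steps_per_epoch(dataset_len, 8, num_workers),
    )
```

Doubling the workers doubles the global batch size and roughly halves the steps per epoch, so hyperparameters such as the learning rate may need retuning when the scale changes.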

#### 3.3.1. Note on Ray Train key concepts

Ray Train is built around [four key concepts](https://docs.ray.io/en/latest/train/overview.html):
1. **Training function** (implemented above as `train_loop_per_worker`): A Python function that contains your model training logic.
1. **Worker**: A process that runs the training function.
1. **Scaling config**: Specifies the number of workers and the compute resources (CPUs, GPUs, or TPUs) to use.
1. **Trainer**: A Python class (Ray Actor) that ties together the training function, workers, and scaling configuration to execute a distributed training job.

|<img src="https://docs.ray.io/en/latest/_images/overview.png" width="700px" loading="lazy">|
|:--|
|High-level architecture of Ray Train.|
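
To make the relationship between the four concepts concrete, here is a deliberately simplified toy in plain Python. It is not how Ray Train is implemented: `ToyScalingConfig`, `ToyTrainer`, and `toy_train_fn` are hypothetical names, the "workers" are threads rather than Ray processes, and the rank is passed in explicitly, whereas a real Ray Train training function only receives `config`.

```python
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass

@dataclass
class ToyScalingConfig:
    """Scaling config: how many workers to launch."""
    num_workers: int

def toy_train_fn(config, rank):
    """Training function: stand-in for the real per-worker training logic."""
    return {"rank": rank, "lr": config["lr"]}

class ToyTrainer:
    """Trainer: ties together the training function, workers, and scaling config."""

    def __init__(self, train_fn, train_loop_config, scaling_config):
        self.train_fn = train_fn
        self.train_loop_config = train_loop_config
        self.scaling_config = scaling_config

    def fit(self):
        n = self.scaling_config.num_workers
        # Workers: here each one is a thread that runs the same training function.
        with ThreadPoolExecutor(max_workers=n) as pool:
            futures = [
                pool.submit(self.train_fn, self.train_loop_config, rank)
                for rank in range(n)
            ]
            return [future.result() for future in futures]

toy_results = ToyTrainer(
    toy_train_fn,
    train_loop_config={"lr": 0.0001},
    scaling_config=ToyScalingConfig(num_workers=2),
).fit()
print(toy_results)
```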

### 3.4 Create and fit a Ray Train TorchTrainer

We first specify the run configuration to tell Ray Train where to store the checkpoints and metrics.

In [None]:
storage_path = "/mnt/cluster_storage/"
experiment_name = "stable-diffusion-pretraining"

run_config = ray.train.RunConfig(name=experiment_name, storage_path=storage_path)

Now we can create our Ray Train `TorchTrainer`.

In [None]:
train_loop_config = {
    "batch_size_per_worker": 8,
    "prefetch_batches": 1,
    "lr": 0.0001,
    "num_warmup_steps": 10_000,
    "weight_decay": 0.01,
    "max_steps": 550_000,
    "max_epochs": 1,
    "resolution": 256,
    "model_name": "stabilityai/stable-diffusion-2-base",
}

trainer = TorchTrainer(
    train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config=run_config,
)

We call `.fit()` to start the training job.

In [None]:
result = trainer.fit()
result

### 3.5. Access the training results

We can check the metrics produced by the training job.

In [None]:
result.metrics_dataframe

### 3.6. Load the checkpointed model to generate predictions

In [None]:
ckpt = result.checkpoint
with ckpt.as_directory() as ckpt_dir:
    ckpt_path = os.path.join(ckpt_dir, "checkpoint.ckpt")
    loaded_model_ray_train = StableDiffusion.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    loaded_model_ray_train.eval()

loaded_model_ray_train

### 3.7. Activity: Run the distributed training with more workers

<div class="alert alert-block alert-info">

1. Update the scaling configuration to use 4 GPU workers.
2. Run the trainer using the same hyperparameters.

Use the following code snippets to guide you:

```python
# Hint: Update the scaling configuration
scaling_config = ...

trainer = ray.train.torch.TorchTrainer(
    train_loop_per_worker,
    scaling_config=scaling_config,
    run_config=run_config,
    train_loop_config=train_loop_config,
)
result = trainer.fit()
result.metrics_dataframe
```

</div>

In [None]:
# Write your solution here


<div class="alert alert-block alert-info">

<details>

<summary> Click here to see the solution </summary>

```python
scaling_config = ray.train.ScalingConfig(num_workers=4, use_gpu=True)

trainer = ray.train.torch.TorchTrainer(
    train_loop_per_worker,
    scaling_config=scaling_config,
    run_config=run_config,
    train_loop_config=train_loop_config,
)
result = trainer.fit()
result.metrics_dataframe
```