## 3. Distributed Data Parallel Training with Ray Train and PyTorch

Let's consider the case where we have a very large dataset of images that would take a long time to train on a single GPU. We would now like to scale this training job to run on multiple GPUs.

|<img src="https://anyscale-public-materials.s3.us-west-2.amazonaws.com/ray-ai-libraries/diagrams/multi_gpu_pytorch_v4.png" width="900px" loading="lazy">|
|:--|
|Schematic overview of DistributedDataParallel (DDP) training: (1) the model is replicated from the <code>GPU rank 0</code> to all other workers; (2) each worker receives a shard of the dataset and processes a mini-batch; (3) during the backward pass, gradients are averaged across GPUs; (4) checkpoint and metrics from rank 0 GPU are saved to the persistent storage.|

<div class="alert alert-block alert-info">
<b>Here is a migration roadmap: from PyTorch DDP to PyTorch with Ray Train</b>

<ol>
    <li>Configure scale and GPUs</li>
    <li>Migrate the model to Ray Train</li>
    <li>Migrate the dataset to Ray Train</li>
    <li>Build checkpoints and metrics reporting</li>
    <li>Configure persistent storage</li>
</ol>
</div>
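
Step (3) of the DDP schematic above, gradient averaging, can be illustrated without any GPU. The following is a minimal sketch under stated assumptions: a hypothetical 1-D model `y_hat = w * x` with a mean-squared-error loss, two equally sized shards, and plain Python lists standing in for tensors. It is not how PyTorch implements DDP; it only shows that averaging the per-worker gradients gives the same value as the gradient over the full batch.

```python
def mse_grad(w: float, xs: list[float], ys: list[float]) -> float:
    """Gradient of the mean squared error w.r.t. w for the toy model y_hat = w * x."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


# Toy data and a toy weight (illustrative values, not from the notebook)
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 0.5
world_size = 2

# Each "worker" computes a gradient on its own equally sized shard
shard_len = len(xs) // world_size
local_grads = [
    mse_grad(w, xs[r * shard_len:(r + 1) * shard_len], ys[r * shard_len:(r + 1) * shard_len])
    for r in range(world_size)
]

# DDP-style averaging: every worker ends up with the same averaged gradient
averaged_grad = sum(local_grads) / world_size
full_batch_grad = mse_grad(w, xs, ys)

print(local_grads, averaged_grad, full_batch_grad)
```

Because every worker applies the same averaged gradient, the model replicas stay identical after each optimizer step.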

### 3.1. Overview of the training loop in Ray Train

Let's see what this data-parallel training loop looks like with Ray Train and PyTorch.

In [None]:
def train_loop_ray_train(config: dict):  # pass in hyperparameters in config

    criterion = CrossEntropyLoss()

    # Use Ray Train to wrap the model with DistributedDataParallel
    model = load_model_ray_train()
    optimizer = Adam(model.parameters(), lr=1e-5)

    # Calculate the batch size for each worker
    global_batch_size = config["global_batch_size"]
    world_size = ray.train.get_context().get_world_size()
    batch_size = global_batch_size // world_size
    print(f"{world_size=}\n{batch_size=}")

    # Use Ray Train to add a DistributedSampler to the data loader
    data_loader = build_data_loader_ray_train(batch_size=batch_size)

    # Main training loop
    for epoch in range(config["num_epochs"]):

        # Tell the DistributedSampler the epoch so that each epoch uses a different shuffle
        if world_size > 1:
            data_loader.sampler.set_epoch(epoch)

        # images, labels are now sharded across the workers
        for images, labels in data_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()

            # gradients are averaged across the workers during the backward pass
            loss.backward()
            optimizer.step()

        # Use Ray Train to report metrics
        metrics = print_metrics_ray_train(loss, epoch)

        # Use Ray Train to save checkpoint and metrics
        save_checkpoint_and_metrics_ray_train(model, metrics)

<div class="alert alert-block alert-info">

<b>Main training loop</b>
<ul>
  <li><strong>global_batch_size</strong>: the total number of samples processed in a single training step of the entire training job.
    <ul>
      <li>It's computed as: <code>per-worker batch size * number of DDP workers * gradient accumulation steps</code>.</li>
    </ul>
  </li>
  <li>Notice that images and labels are no longer manually moved to device (<code>images.to("cuda")</code>). This is done by 
    <a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.prepare_data_loader.html#ray-train-torch-prepare-data-loader" target="_blank">
      prepare_data_loader()
    </a>.
  </li>
  <li>The config passed to this function is defined below. It is handed to Ray Train's <a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.TorchTrainer.html#ray-train-torch-torchtrainer" target="_blank">TorchTrainer</a>.</li>
  <li>
    <a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.v2.api.context.TrainContext.html#ray-train-v2-api-context-traincontext" target="_blank">
      TrainContext
    </a> lets users get useful information about the training run, such as node rank, world size, world rank, and experiment name.
  </li>

  <li><code>load_model_ray_train</code> and <code>build_data_loader_ray_train</code> are implemented below.</li>
</ul>
</div>
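
The relationship between the global batch size and the per-worker batch size can be checked with a few lines of arithmetic. This is a minimal sketch; the helper name `per_worker_batch_size` and the choice to raise an error on an uneven split are assumptions made for illustration (the training loop above uses plain floor division instead).

```python
def per_worker_batch_size(
    global_batch_size: int, world_size: int, grad_accum_steps: int = 1
) -> int:
    """Split a global batch evenly across workers and gradient accumulation steps."""
    denominator = world_size * grad_accum_steps
    if global_batch_size % denominator != 0:
        raise ValueError(
            f"global_batch_size={global_batch_size} is not divisible by "
            f"world_size * grad_accum_steps = {denominator}"
        )
    return global_batch_size // denominator


# Values used in this notebook: global batch of 128, no gradient accumulation
print(per_worker_batch_size(128, world_size=2))  # 64
print(per_worker_batch_size(128, world_size=4))  # 32
```

Keeping the global batch size fixed while changing the number of workers keeps the optimization behavior comparable across runs, since each training step still sees the same total number of samples.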

In [None]:
train_loop_config = {
    "num_epochs": 2, 
    "global_batch_size": 128
}

### 3.2. Configure scale and GPUs
Outside of our training function, we create a `ScalingConfig`.

In [None]:
scaling_config = ScalingConfig(num_workers=2, use_gpu=False)  # Runs on CPU; set use_gpu=True if each worker has a GPU

<div class="alert alert-block alert-info">

<a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.ScalingConfig.html#ray-train-scalingconfig" target="_blank">ScalingConfig</a> configures:

<ul>
  <li><code>num_workers</code>: The number of distributed training worker processes.</li>
  <li><code>use_gpu</code>: Whether each worker should use a GPU (or CPU).</li>
</ul>

See docs on configuring <a href="https://docs.ray.io/en/latest/train/user-guides/using-gpus.html" target="_blank">scale and GPUs</a> for more details.
</div>

#### 3.2.1. Note on Ray Train key concepts

Ray Train is built around [four key concepts](https://docs.ray.io/en/latest/train/overview.html):
1. **Training function** (implemented above as `train_loop_ray_train`): A Python function that contains your model training logic.
1. **Worker**: A process that runs the training function.
1. **Scaling config**: Specifies the number of workers and the compute resources (CPUs, GPUs, or TPUs).
1. **Trainer**: A Python class (Ray Actor) that ties together the training function, workers, and scaling configuration to execute a distributed training job.

|<img src="https://docs.ray.io/en/latest/_images/overview.png" width="700px" loading="lazy">|
|:--|
|High-level architecture of Ray Train|
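
To make the four concepts concrete, here is a toy analogy in plain Python. Everything in it (`ToyScalingConfig`, `ToyTrainer`, `toy_train_fn`) is hypothetical and is not part of the Ray API. It uses threads instead of Ray worker processes, and it only shows how a trainer ties a training function to a number of workers, each of which runs the same function with its own rank.

```python
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass


@dataclass
class ToyScalingConfig:
    """Toy stand-in for a scaling config: only the number of workers."""
    num_workers: int


def toy_train_fn(config: dict, rank: int, world_size: int) -> dict:
    """Toy training function: every worker runs the same code with its own rank."""
    local_batch_size = config["global_batch_size"] // world_size
    return {"rank": rank, "local_batch_size": local_batch_size}


class ToyTrainer:
    """Toy trainer: runs the training function once per worker and collects results."""

    def __init__(self, train_fn, scaling_config: ToyScalingConfig, train_loop_config: dict):
        self.train_fn = train_fn
        self.scaling_config = scaling_config
        self.train_loop_config = train_loop_config

    def fit(self) -> list[dict]:
        n = self.scaling_config.num_workers
        with ThreadPoolExecutor(max_workers=n) as pool:
            futures = [
                pool.submit(self.train_fn, self.train_loop_config, rank, n)
                for rank in range(n)
            ]
            # Blocks until every worker has finished, in rank order
            return [future.result() for future in futures]


results = ToyTrainer(
    toy_train_fn, ToyScalingConfig(num_workers=2), {"global_batch_size": 128}
).fit()
print(results)
```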

### 3.3. Migrating the model to Ray Train

Use the [`prepare_model()`](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.prepare_model.html#ray-train-torch-prepare-model) utility function to:

* automatically move your model to the correct device,
* wrap the model in PyTorch's DDP or FSDP.

In [None]:
def load_model_ray_train() -> torch.nn.Module:
    model = build_resnet18()
    model = ray.train.torch.prepare_model(model) # Instead of model = model.to("cuda")
    return model

<div class="alert alert-block alert-info">
  <a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.prepare_model.html#ray-train-torch-prepare-model" target="_blank">
    prepare_model()
  </a> allows users to specify additional parameters:
  <ul>
    <li><code>parallel_strategy</code>: "ddp", "fsdp" – wrap models in <code>DistributedDataParallel</code> or <code>FullyShardedDataParallel</code></li>
    <li><code>parallel_strategy_kwargs</code>: pass additional arguments to "ddp" or "fsdp"</li>
  </ul>
  <p>
    With <a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.prepare_model.html#ray-train-torch-prepare-model" target="_blank">
      prepare_model()
    </a> you can use the same code regardless of number of workers or the device type being used (CPU, GPU).
  </p>
</div>

### 3.4. Migrating the dataset to Ray Train

Use the [`prepare_data_loader()`](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.prepare_data_loader.html#ray-train-torch-prepare-data-loader) utility function to automatically:

* move the batches to the right device,
* copy data from host (CPU) memory to device (GPU) memory,
* pass PyTorch's [`DistributedSampler`](https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler) to the DataLoader, if using more than 1 worker. Each worker will load a subset of the original dataset that is exclusive to it.

[`prepare_data_loader()`](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.prepare_data_loader.html#ray-train-torch-prepare-data-loader) allows users to use the same code regardless of number of workers or the device type being used (CPU, GPU).
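
To build intuition for how a distributed sampler gives each worker an exclusive subset, and why the training loop sets the epoch on the sampler, consider the simplified imitation below. It is a sketch and not PyTorch's implementation: the function name `shard_indices` is hypothetical, it uses Python's `random` module instead of a PyTorch generator, and it drops leftover samples instead of padding.

```python
import random


def shard_indices(
    dataset_len: int, world_size: int, rank: int, epoch: int, seed: int = 0
) -> list[int]:
    """Return the dataset indices that one worker reads in one epoch (simplified)."""
    indices = list(range(dataset_len))
    # Every worker uses the same seed, so all workers agree on the permutation.
    # Adding the epoch changes the permutation from one epoch to the next.
    random.Random(seed + epoch).shuffle(indices)
    # Drop the tail so that every worker receives the same number of samples
    usable = (dataset_len // world_size) * world_size
    indices = indices[:usable]
    # Strided split: rank r takes every world_size-th index starting at r
    return indices[rank::world_size]


world_size = 2
shards = [shard_indices(10, world_size, rank, epoch=0) for rank in range(world_size)]
print(shards)
```

In this sketch the shards are disjoint and together cover the dataset once per epoch. If the epoch were never updated, every epoch would reuse the same permutation.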

In [None]:
def build_data_loader_ray_train(batch_size: int) -> torch.utils.data.DataLoader:
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    train_data = MNIST(root="./data", train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)

    # Automatically pass a DistributedSampler instance as a DataLoader sampler
    train_loader = ray.train.torch.prepare_data_loader(train_loader)
    return train_loader

<div class="alert alert-block alert-warning">

<b>Ray Data integration</b>

This step isn't necessary if you are integrating your Ray Train workload with Ray Data. Ray Data is especially useful if preprocessing is CPU-heavy and you want to run preprocessing and training on separate instances.
</div>

### 3.5. Reporting checkpoints and metrics

To monitor progress, we can continue to print/log metrics as before. This time we chose to log from all workers.

In [None]:
def print_metrics_ray_train(loss: torch.Tensor, epoch: int) -> dict[str, float]:
    metrics = {"loss": loss.item(), "epoch": epoch}
    world_rank = ray.train.get_context().get_world_rank() # report from all workers
    print(f"{metrics=} {world_rank=}")
    return metrics

<div class="alert alert-block alert-info">

If you want to log only from the rank 0 worker, use this code:

```python
def print_metrics_ray_train(loss: torch.Tensor, epoch: int) -> dict[str, float]:
    metrics = {"loss": loss.item(), "epoch": epoch}
    world_rank = ray.train.get_context().get_world_rank()
    if world_rank == 0:  # print only from the rank 0 worker
        print(f"{metrics=} {world_rank=}")
    return metrics
```

</div>

We will report intermediate metrics and checkpoints using the [`ray.train.report`](https://docs.ray.io/en/latest/train/api/doc/ray.train.report.html#ray.train.report) utility function.

In [None]:
def save_checkpoint_and_metrics_ray_train(
    model: torch.nn.Module, metrics: dict[str, float]
) -> None:
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        torch.save(
            model.module.state_dict(),  # note the `.module` to unwrap the DistributedDataParallel
            os.path.join(temp_checkpoint_dir, "model.pt"),
        )

        ray.train.report(
            metrics,
            checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
        )

<div class="alert alert-block alert-info">
  <p><strong>Quick notes:</strong></p>
  <ul>
    <li>
      Use 
      <a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.report.html#ray.train.report" target="_blank">
        ray.train.report
      </a> to save the metrics and checkpoint.
    </li>
    <li>Only metrics from the rank 0 worker are reported.</li>
  </ul>
</div>
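
The `.module` unwrapping in the checkpoint function above matters because a `DistributedDataParallel` wrapper stores the original model under the attribute `module`, so the wrapper's own state dict keys carry a `module.` prefix. If a checkpoint was saved from the wrapper by mistake, the keys can be repaired before loading into an unwrapped model. The helper below is a minimal sketch that operates on a plain dictionary; the name `strip_ddp_prefix` is hypothetical.

```python
def strip_ddp_prefix(state_dict: dict, prefix: str = "module.") -> dict:
    """Return a copy of state_dict with a leading wrapper prefix removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Keys as they would look when saved from a wrapped model (values are placeholders)
wrapped_state = {"module.conv1.weight": 1, "module.fc.bias": 2}
print(strip_ddp_prefix(wrapped_state))
```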

#### 3.5.1. Note on the checkpoint lifecycle

Here is the lifecycle of a checkpoint from being created using a local path to being uploaded to persistent storage.

<img src="https://docs.ray.io/en/latest/_images/checkpoint_lifecycle.png" width=800>


<div class="alert alert-block alert-info">
  <p><strong>Notes:</strong></p>
  <ul>
    <li>
      Given it is the same model across all workers, we can instead only build the checkpoint on the worker of rank 0.
      Note that we will still need to call 
      <a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.report.html#ray.train.report" target="_blank">
        ray.train.report
      </a> on all workers to ensure that the training loop is synchronized.
    </li>
    <li>Ray Train expects all workers to be able to write files to the same persistent storage location.</li>
    <li>Cloud storage is the recommended persistent storage location.</li>
  </ul>
</div>
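
The first note above says that every worker must still call the report function even when only rank 0 builds a checkpoint. The toy sketch below illustrates why, under the assumption that the report call behaves like a barrier: if one worker skipped the call, the others would wait for it. `toy_report` is a hypothetical stand-in built on `threading.Barrier`, not the Ray API, and threads stand in for worker processes.

```python
import threading

world_size = 2
barrier = threading.Barrier(world_size)
reports = []
reports_lock = threading.Lock()


def toy_report(rank: int, metrics: dict, checkpoint=None) -> None:
    """Stand-in for a synchronizing report call: waits until every worker arrives."""
    barrier.wait(timeout=5)
    with reports_lock:
        reports.append((rank, metrics, checkpoint))


def worker(rank: int) -> None:
    # Only rank 0 builds a checkpoint, but every rank calls toy_report
    checkpoint = "model.pt" if rank == 0 else None
    toy_report(rank, {"epoch": 0}, checkpoint)


threads = [threading.Thread(target=worker, args=(rank,)) for rank in range(world_size)]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()

print(sorted(reports, key=lambda item: item[0]))
```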

In [None]:
def save_checkpoint_and_metrics_ray_train(
    model: torch.nn.Module, metrics: dict[str, float]
) -> None:
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        checkpoint = None

        # checkpoint only from rank 0 worker
        if ray.train.get_context().get_world_rank() == 0:
            torch.save(
                model.module.state_dict(), os.path.join(temp_checkpoint_dir, "model.pt")
            )
            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)

        ray.train.report(
            metrics,
            checkpoint=checkpoint,
        )

Check our guide on [saving and loading checkpoints](https://docs.ray.io/en/latest/train/user-guides/checkpoints.html) for more details and best practices.

You can also configure the [experiment tracking libraries](https://docs.ray.io/en/latest/train/user-guides/experiment-tracking.html) to work out-of-the-box with Ray Train. 

### 3.6. Configure remote storage

Create a `RunConfig` object to specify the path where results (including checkpoints and artifacts) will be saved.

In [None]:
# For a local path, add the file:// prefix to the storage path; if the path is relative, resolve it to an absolute path first
#storage_path = f"file://{Path(storage_folder).resolve()}/training/" # Use local path if it runs on local
storage_path = f"{storage_folder}/training/" 
run_config = RunConfig(storage_path=storage_path, name="distributed-mnist-resnet18")
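
For the local-path case mentioned in the comment above, the standard library can build the `file://` URI directly. This is a minimal sketch of an alternative to the f-string approach; the helper name `to_local_storage_uri` and the example folder `./ray_results` are assumptions for illustration.

```python
from pathlib import Path


def to_local_storage_uri(folder: str) -> str:
    """Resolve a possibly relative folder and return it as a file:// URI."""
    # resolve() makes the path absolute; as_uri() requires an absolute path
    return Path(folder).resolve().as_uri()


uri = to_local_storage_uri("./ray_results")
print(uri)
```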

### 3.7. Launching the distributed training job

The diagram below shows distributed data-parallel training, now using Ray Train.

|<img src="https://anyscale-public-materials.s3.us-west-2.amazonaws.com/ray-ai-libraries/diagrams/multi_gpu_pytorch_annotated_v5.png" width="70%" loading="lazy">|
|:--|
||

We can now launch a distributed training job with a [`TorchTrainer`](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.TorchTrainer.html#ray.train.torch.TorchTrainer).

In [None]:
trainer = TorchTrainer(
    train_loop_ray_train,
    scaling_config=scaling_config,
    run_config=run_config,
    train_loop_config=train_loop_config,
)

Calling `trainer.fit()` will start the run and block until it completes.

While the job runs, we'll be able to observe the relevant logs:

|<img src="https://assets-training.s3.us-west-2.amazonaws.com/ray-intro/ray-train-intro-logs.png" width="80%" loading="lazy">|
|:--|
||

In [None]:
result = trainer.fit()

### 3.8. Access the training results

After training completes, a `Result` object is returned which contains information about the training run, including the metrics and checkpoints reported during training.

In [None]:
result

In [None]:
!ls {storage_folder}/training/distributed-mnist-resnet18/

We can check the metrics produced by the training job.

In [None]:
result.metrics_dataframe

### 3.9. Use checkpointed model to generate predictions

We can also take the latest checkpoint and load it to inspect the model.

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

ckpt = result.checkpoint
with ckpt.as_directory() as ckpt_dir:
    model_path = os.path.join(ckpt_dir, "model.pt")
    loaded_model_ray_train = build_resnet18()
    state_dict = torch.load(model_path, map_location=torch.device(device), weights_only=True)
    loaded_model_ray_train.load_state_dict(state_dict)
    loaded_model_ray_train.to(device)
    loaded_model_ray_train.eval()

loaded_model_ray_train

<div class="alert alert-block alert-info">
  <p>
    To learn more about the training results, see the
    <a href="https://docs.ray.io/en/latest/train/user-guides/results.html" target="_blank">
      docs
    </a> on inspecting the training results.
  </p>
</div>

Generate predictions on 9 randomly selected images from the MNIST dataset.

In [None]:
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3

for i in range(1, cols * rows + 1):
    sample_idx = np.random.randint(0, len(dataset.data))
    img, label = dataset[sample_idx]
    normalized_img = Normalize((0.5,), (0.5,))(ToTensor()(img))
    normalized_img = normalized_img.to(device)

    # use loaded model to generate preds
    with torch.no_grad():        
        prediction = loaded_model_ray_train(normalized_img.unsqueeze(0)).argmax().cpu()

    figure.add_subplot(rows, cols, i)
    plt.title(f"label: {label}; pred: {int(prediction)}")
    plt.axis("off")
    plt.imshow(img, cmap="gray")

### 3.10. Activity: Run the distributed training with more workers

<div class="alert alert-block alert-info">

1. Update the scaling configuration to make use of 4 GPU workers
2. Run the trainer using the same hyperparameters

Use the following code snippets to guide you:

```python
# Hint: Update the scaling configuration
scaling_config = ...

trainer = ray.train.torch.TorchTrainer(
    train_loop_ray_train,
    scaling_config=scaling_config,
    run_config=run_config,
    train_loop_config={"num_epochs": 2, "global_batch_size": 128},
)
result = trainer.fit()
result.metrics_dataframe
```

</div>

In [None]:
# Write your solution here


<div class="alert alert-block alert-info">

<details>

<summary> Click here to see the solution </summary>

```python
scaling_config = ScalingConfig(num_workers=4, use_gpu=True)

trainer = ray.train.torch.TorchTrainer(
    train_loop_ray_train,
    scaling_config=scaling_config,
    run_config=run_config,
    train_loop_config={"num_epochs": 2, "global_batch_size": 128},
)
result = trainer.fit()
result.metrics_dataframe
```

</details>

</div>

## 4. Ray Train in Production

Here are some examples of Ray Train in production:
1. Canva uses Ray Train + Ray Data to cut down Stable Diffusion training costs by 3.7x. Read the [Anyscale blog post](https://www.anyscale.com/blog/scalable-and-cost-efficient-stable-diffusion-pre-training-with-ray) and the [Canva case study](https://www.anyscale.com/resources/case-study/how-canva-built-a-modern-ai-platform-using-anyscale).
2. Anyscale uses Ray Train + DeepSpeed to fine-tune language models. Read more [here](https://github.com/ray-project/ray/tree/master/doc/source/templates/04_finetuning_llms_with_deepspeed).


In [None]:
# Run this cell for file cleanup 
!rm -rf {storage_folder}/training/
!rm -rf {local_path}
!rm -rf {storage_folder}/data