## 🔄 02 · Integrating Ray Train with Ray Data  
In this module you’ll extend distributed training with **Ray Train** by adding **Ray Data** to the pipeline. Instead of relying on a local PyTorch DataLoader, you’ll stream batches directly from a distributed **Ray Dataset**, enabling scalable preprocessing and just-in-time data loading across the cluster.  

### What you’ll learn & take away  
* When to integrate **Ray Data** with Ray Train, for example for CPU-heavy preprocessing, online augmentations, or multi-format data ingestion  
* How to replace `DataLoader` with **`iter_torch_batches()`** to stream batches into your training loop  
* How to shard, shuffle, and preprocess data in parallel across the cluster before feeding it into GPUs  
* How to define a **training loop** that consumes Ray Dataset shards instead of DataLoader tuples  
* How to prepare datasets (for example, in Parquet format) so they can be efficiently read and transformed with Ray Data  
* How to pass Ray Datasets into the `TorchTrainer` with the `datasets` parameter  

> With Ray Data, you can scale preprocessing and training independently: CPUs handle input pipelines while GPUs focus on training, which helps keep **utilization and throughput high** in your distributed workloads.  

Note that the code blocks in this module depend on definitions from the previous module, **Introduction to Ray Train**.

### 🔎 Integrating Ray Train with Ray Data  

Use both Ray Train and Ray Data when you face one of the following challenges:  

| Challenge | Detail | Solution |
| --- | --- | --- |
| Need to perform online or just-in-time data processing | The training pipeline requires processing data on the fly, such as data augmentation, normalization, or other transformations that may differ for each training epoch. | Ray Train's integration with Ray Data makes it easy to implement just-in-time data processing. |
| Need to improve hardware utilization | Training and data processing need to be scaled independently to keep GPUs fully utilized, especially when preprocessing is CPU-intensive. | Ray Data can distribute data processing across multiple CPU nodes, while Ray Train runs the training loop on GPUs. |
| Need a consistent interface for loading data | The training process may need to load data from various sources, such as Parquet, CSV, or lakehouses. | Ray Data provides a consistent interface for loading, shuffling, sharding, and batching data for training loops. |
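
As a rough illustration of the first row, the plain-Python sketch below applies a random horizontal flip whose outcome depends on the epoch. This is the kind of transform that cannot be precomputed once and stored. The function name `augment_for_epoch` and the seeding scheme are illustrative assumptions, not part of Ray Data.

```python
import random


def augment_for_epoch(image, epoch, index, flip_prob=0.5):
    """Randomly flip a 2D image (a list of rows) left-to-right.

    The random stream is seeded from (epoch, index), so the same sample can be
    flipped in one epoch and left untouched in another.
    """
    rng = random.Random(epoch * 1_000_003 + index)
    if rng.random() < flip_prob:
        return [row[::-1] for row in image]
    return [row[:] for row in image]


image = [[0, 1, 2], [3, 4, 5]]
for epoch in range(3):
    # The result may differ from epoch to epoch for the same input image
    print(epoch, augment_for_epoch(image, epoch, index=0))
```

In a Ray Data pipeline, a function like this would be applied to rows as they are streamed rather than in a local loop.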

## 01 · Define Training Loop with Ray Data  

Here we reimplement the training loop, but this time using **Ray Data** instead of a PyTorch `DataLoader`.  

Key differences from the previous version:  
- **Data loader** → Built with `build_data_loader_ray_train_ray_data()`, which streams batches from a Ray Dataset shard (details in the following block).  
- **Batching** → Still split by `global_batch_size // world_size`, but batches are now **dictionaries** with keys `"image"` and `"label"`.  
- **No sampler or manual device placement** → There is no `DistributedSampler`, so the `sampler.set_epoch()` call goes away, and batches from the Ray Data iterator arrive on the worker's device, so we no longer call `to("cuda")`.  

The rest of the loop (forward pass, loss computation, backward pass, optimizer step, metric logging, and checkpointing) stays the same.  

This pattern shows how seamlessly **Ray Data integrates with Ray Train**, replacing `DataLoader` while keeping the training logic identical.  
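
The batching rule above is plain integer division, so it is worth checking what each worker actually receives. The sketch below is a standalone helper (the name `per_worker_batch_size` is hypothetical) that does not call Ray. It takes `world_size` as an argument instead of reading it from `ray.train.get_context()`.

```python
def per_worker_batch_size(global_batch_size: int, world_size: int) -> int:
    """Split a global batch evenly across workers using integer division."""
    if world_size <= 0:
        raise ValueError("world_size must be positive")
    if global_batch_size < world_size:
        raise ValueError("global_batch_size must be at least world_size")
    return global_batch_size // world_size


for world_size in (1, 2, 3, 4, 8):
    local = per_worker_batch_size(512, world_size)
    # The effective global batch is local * world_size; it equals 512 only
    # when 512 is divisible by world_size.
    print(world_size, local, local * world_size)
```

With three workers, for instance, each worker gets 170 samples and the effective global batch is 510 rather than 512, so choosing a global batch size divisible by the worker count avoids a silent mismatch.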

In [None]:
# 01. Training loop using Ray Data

import ray.train
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

def train_loop_ray_train_ray_data(config: dict):
    # Same as before: define loss, model, optimizer
    # (load_model_ray_train() is defined in the previous module)
    criterion = CrossEntropyLoss()
    model = load_model_ray_train()
    optimizer = Adam(model.parameters(), lr=1e-3)

    # Different: build data loader from Ray Data instead of PyTorch DataLoader
    global_batch_size = config["global_batch_size"]
    batch_size = global_batch_size // ray.train.get_context().get_world_size()
    data_loader = build_data_loader_ray_train_ray_data(batch_size=batch_size)

    # Same: loop over epochs
    for epoch in range(config["num_epochs"]):
        # Different: no sampler.set_epoch(), because there is no DistributedSampler

        # Different: batches are dicts {"image": ..., "label": ...} not tuples
        for batch in data_loader:
            outputs = model(batch["image"])
            loss = criterion(outputs, batch["label"])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Same: report metrics and save checkpoint each epoch
        # (uses the loss of the last batch in the epoch)
        metrics = print_metrics_ray_train(loss, epoch)
        save_checkpoint_and_metrics_ray_train(model, metrics)

## 02 · Build DataLoader from Ray Data  

Instead of using PyTorch’s `DataLoader`, we now build a loader from a **Ray Dataset shard**.  

- `ray.train.get_dataset_shard("train")` → retrieves the shard of the training dataset assigned to the current worker.  
- `.iter_torch_batches()` → streams the shard as PyTorch-compatible batches.  
  * Each batch is a **dictionary** (e.g., `{"image": tensor, "label": tensor}`).  
  * Supports options like `batch_size` and `prefetch_batches` for performance tuning.  

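This integration ensures that data is **sharded and moved to the right device automatically**, while still looking and feeling like a familiar PyTorch data loader. Shuffling is opt-in: request it explicitly, for example with the `local_shuffle_buffer_size` argument of `iter_torch_batches()` or with `Dataset.random_shuffle()`.  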

**Note:** Use [`iter_torch_batches`](https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.iter_torch_batches.html) to build a PyTorch-compatible data loader from a Ray Dataset. 
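
To make the batch layout concrete, here is a plain-Python stand-in that groups row dictionaries into column-keyed batches. It only mirrors the *shape* of what the training loop consumes. The real `iter_torch_batches()` returns tensors and also handles prefetching and device placement. The helper name `iter_dict_batches` is an assumption for illustration.

```python
from typing import Dict, Iterable, Iterator, List


def iter_dict_batches(rows: Iterable[dict], batch_size: int) -> Iterator[Dict[str, List]]:
    """Group row dicts into column-keyed batches: {"image": [...], "label": [...]}."""
    batch: Dict[str, List] = {}
    count = 0
    for row in rows:
        for key, value in row.items():
            batch.setdefault(key, []).append(value)
        count += 1
        if count == batch_size:
            yield batch
            batch, count = {}, 0
    if count:  # emit the final partial batch
        yield batch


rows = [{"image": [i, i], "label": i % 10} for i in range(5)]
for batch in iter_dict_batches(rows, batch_size=2):
    # Access by key, as the training loop does with batch["image"] / batch["label"]
    print(len(batch["image"]), batch["label"])
```

The key point is that each field is addressed by name rather than by tuple position, which is why the training loop reads `batch["image"]` and `batch["label"]`.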

In [None]:
# 02. Build a Ray Data-backed data loader

def build_data_loader_ray_train_ray_data(batch_size: int, prefetch_batches: int = 2):

    # Different: instead of creating a PyTorch DataLoader,
    # fetch the training dataset shard for this worker
    dataset_iterator = ray.train.get_dataset_shard("train")

    # Convert the shard into a PyTorch-style iterator
    # - Returns dict batches: {"image": ..., "label": ...}
    # - prefetch_batches controls pipeline buffering
    data_loader = dataset_iterator.iter_torch_batches(
        batch_size=batch_size, prefetch_batches=prefetch_batches
    )
    
    return data_loader

## 03 · Prepare Dataset for Ray Data  

Ray Data works best with data in **tabular formats** such as Parquet.  
In this step we:  

- Convert the MNIST dataset into a **pandas DataFrame** with two columns:  
  * `"image"` → raw image arrays  
  * `"label"` → digit class (0–9)  
- Write the DataFrame to disk in **Parquet format** under `/mnt/cluster_storage/`.  

Parquet is efficient for both reading and distributed processing, making it a good fit for Ray Data pipelines.  

In [None]:
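# 03. Convert MNIST dataset into Parquet for Ray Data

import pandas as pd

# Build a DataFrame with image arrays and labels
# (`dataset` is the MNIST dataset object from the previous module)
df = pd.DataFrame({
    "image": dataset.data.tolist(),     # raw image pixels (as nested lists)
    "label": dataset.targets.tolist(),  # digit labels 0–9 (as plain ints)
})

# Persist the dataset in Parquet format (columnar, efficient for Ray Data)
# Note: the file name says "cifar10", but the contents are the MNIST data above
df.to_parquet("/mnt/cluster_storage/cifar10.parquet")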

## 04 · Load Dataset into Ray Data  

Now that the training data is stored as Parquet, we can load it back into a **Ray Dataset**:  

- Use `ray.data.read_parquet()` to create a distributed Ray Dataset from the Parquet file.  
- Each row has two columns: `"image"` (raw pixel array) and `"label"` (digit class).  
- The dataset is automatically **sharded across the Ray cluster**, so multiple workers can read and process it in parallel.  

This Ray Dataset will later be passed to the `TorchTrainer` for distributed training.  
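
Conceptually, each training worker ends up with a disjoint portion of the rows. The sketch below is a plain-Python stand-in that assigns rows round-robin. Ray Data works on blocks rather than individual rows and manages the assignment internally, so treat this only as a mental model. The helper name `split_into_shards` is hypothetical.

```python
from typing import List, Sequence


def split_into_shards(rows: Sequence, world_size: int) -> List[list]:
    """Assign rows to `world_size` disjoint shards in round-robin order."""
    shards: List[list] = [[] for _ in range(world_size)]
    for i, row in enumerate(rows):
        shards[i % world_size].append(row)
    return shards


rows = list(range(10))
shards = split_into_shards(rows, world_size=4)
for rank, shard in enumerate(shards):
    print(f"worker {rank}: {shard}")
```

Every row lands in exactly one shard, and shard sizes differ by at most one row.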


In [None]:
# 04. Load the Parquet dataset into a Ray Dataset

# Read the Parquet file → creates a distributed Ray Dataset
train_ds = ray.data.read_parquet("/mnt/cluster_storage/cifar10.parquet")


## 05 · Define Image Transformation  

To make the dataset usable by PyTorch, we need to preprocess the raw image arrays with the same steps that the PyTorch `DataLoader` pipeline applied in the previous module.  

- Define a function `transform_images(row)` that:  
  * Converts the `"image"` array from `numpy` into a PIL image.  
  * Applies the standard PyTorch transforms:  
    - `ToTensor()` → converts the image to a tensor.  
    - `Normalize((0.5,), (0.5,))` → scales pixel values to the range [-1, 1].  
  * Replaces the `"image"` entry in the row with the transformed tensor.  

This function will later be applied in parallel across the Ray Dataset.  
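
The arithmetic behind the two transforms can be checked by hand. The sketch below reproduces it in plain Python for single pixel values, assuming `uint8` input: `ToTensor()` scales [0, 255] to [0, 1], and `Normalize((0.5,), (0.5,))` then computes `(value - 0.5) / 0.5`. The helper names are illustrative and are not TorchVision functions.

```python
def to_unit_range(pixel: int) -> float:
    """Mimic the scaling step of ToTensor for uint8 input: [0, 255] -> [0.0, 1.0]."""
    return pixel / 255.0


def normalize(value: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Mimic Normalize((0.5,), (0.5,)): (value - mean) / std."""
    return (value - mean) / std


for pixel in (0, 128, 255):
    # Black maps to -1, mid-gray to roughly 0, white to +1
    print(pixel, normalize(to_unit_range(pixel)))
```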


In [None]:
# 05. Define preprocessing transform for Ray Data

import numpy as np
from PIL import Image
from torchvision.transforms import Compose, Normalize, ToTensor

def transform_images(row: dict):
    # Convert numpy array to a PIL image, then apply TorchVision transforms
    transform = Compose([
        ToTensor(),                # convert to tensor, scaling pixels to [0, 1]
        Normalize((0.5,), (0.5,))  # normalize to [-1, 1]
    ])

    # Ensure image is in uint8 before conversion
    image_arr = np.array(row["image"], dtype=np.uint8)

    # Apply transforms and replace the "image" field with tensor
    row["image"] = transform(Image.fromarray(image_arr))
    return row

<div class="alert alert-block alert-info">

**Note**: Unlike the PyTorch DataLoader, the preprocessing can now occur on any node in the cluster.

The data will be passed to training workers via the Ray object store (a distributed in-memory object store).

</div>

## 06 · Apply Transformations with Ray Data  

Now we apply the preprocessing function to the dataset using `map()`:  

- `train_ds.map(transform_images)` → runs the `transform_images` function on every row of the dataset.  
- Transformations are executed **in parallel across the cluster**, so preprocessing can scale independently of training.  
- The transformed dataset now has:  
  * `"image"` → normalized image tensors with values in [-1, 1]  
  * `"label"` → unchanged integer labels  

This makes the dataset ready to be streamed into the training loop.  
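
As a single-machine analogy for row-wise mapping, the sketch below applies a function to every row using a thread pool from the standard library. It is not how Ray Data executes `map()`, which distributes work across nodes, but it shows the same contract: one row dictionary in, one row dictionary out. The transform `scale_row` is a hypothetical stand-in for `transform_images`.

```python
from concurrent.futures import ThreadPoolExecutor


def scale_row(row: dict) -> dict:
    """Hypothetical row transform: scale pixel values to [0, 1], keep the label."""
    return {"image": [p / 255.0 for p in row["image"]], "label": row["label"]}


rows = [{"image": [0, 51, 255], "label": i} for i in range(4)]

# Apply the function to every row; each call is independent of the others,
# which is what makes this kind of transform easy to parallelize.
with ThreadPoolExecutor(max_workers=2) as executor:
    transformed = list(executor.map(scale_row, rows))

print(transformed[0])
```

Because each row is processed independently, the same function can run on any number of workers without coordination.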

In [None]:
# 06. Apply the preprocessing transform across the Ray Dataset

# Run transform_images() on each row (parallelized across cluster workers)
train_ds = train_ds.map(transform_images)

## 07 · Configure `TorchTrainer` with Ray Data  

Now we connect the Ray Dataset to the training loop using the `datasets` parameter in `TorchTrainer`:  

- **`datasets={"train": train_ds}`** → makes the transformed dataset available to the training loop as the `"train"` shard.  
- **`train_loop_ray_train_ray_data`** → the per-worker training loop that consumes Ray Data batches.  
- **`train_loop_config`** → passes hyperparameters (`num_epochs`, `global_batch_size`).  
- **`scaling_config`** → specifies the number of workers and GPUs to use (same as before).  
- **`run_config`** → defines storage for checkpoints and metrics.  

This setup allows Ray Train to automatically shard and stream the Ray Dataset into each worker during training.  
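
For a sense of scale, the sketch below estimates how many batches each worker iterates over per epoch. It assumes an even row split across workers and that the final partial batch is kept. The 60,000 figure is the size of the MNIST training split, and `world_size=4` is a hypothetical choice, since the actual value comes from `scaling_config`.

```python
import math


def batches_per_worker(num_rows: int, world_size: int, global_batch_size: int) -> int:
    """Estimate batches per worker per epoch.

    Assumes rows are split evenly across workers and the last partial batch
    is not dropped.
    """
    rows_per_worker = math.ceil(num_rows / world_size)
    local_batch_size = global_batch_size // world_size
    return math.ceil(rows_per_worker / local_batch_size)


# 60,000 rows, 4 hypothetical workers, global batch size 512 as in the config
print(batches_per_worker(num_rows=60_000, world_size=4, global_batch_size=512))
```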

In [None]:
# 07. Configure TorchTrainer with Ray Data integration

# Wrap Ray Dataset in a dict → accessible as "train" inside the training loop
datasets = {"train": train_ds}

trainer = TorchTrainer(
    train_loop_ray_train_ray_data,  # training loop consuming Ray Data
    train_loop_config={             # hyperparameters
        "num_epochs": 1,
        "global_batch_size": 512,
    },
    scaling_config=scaling_config,  # number of workers + GPU/CPU resources
    run_config=RunConfig(
        storage_path=storage_path, 
        name="dist-cifar-res18-ray-data"
    ),                              # where to store checkpoints/logs
    datasets=datasets,              # provide Ray Dataset shards to workers
)

## 08 · Launch Training with Ray Data  

Finally, call `trainer.fit()` to start the distributed training job.  

- Ray will automatically:  
  * Launch workers according to the `scaling_config`.  
  * Stream sharded, preprocessed batches from the Ray Dataset into each worker.  
  * Run the training loop (`train_loop_ray_train_ray_data`) on every worker in parallel.  
  * Report metrics and save checkpoints to the configured storage path.  

With this call, you now have a fully **end-to-end distributed pipeline** where **Ray Data handles ingestion + preprocessing** and **Ray Train handles multi-GPU training**.  

In [None]:
# 08. Start the distributed training job with Ray Data integration

# Launches the training loop across all workers
# - Streams preprocessed Ray Dataset batches into each worker
# - Reports metrics and checkpoints to cluster storage
trainer.fit()