## 01 · Define Training Loop with Ray Data  

Here we reimplement the training loop, but this time using **Ray Data** instead of a PyTorch `DataLoader`.  

Key differences from the previous version:  
- **Data loader** → Built with `build_data_loader_ray_train_ray_data()`, which streams batches from a Ray Dataset shard (details in the following block).  
- **Batching** → The per-worker batch size is still `global_batch_size // world_size`, but each batch is now a **dictionary** with keys `"image"` and `"label"` instead of an `(image, label)` tuple.  
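- **No device transfer or sampler bookkeeping** → Ray Data moves batches to the worker's device, so we no longer call `.to("cuda")`. Because Ray Data also shards the dataset across workers, there is no `DistributedSampler` and therefore no `sampler.set_epoch()` call.  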
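The batching arithmetic and the change of batch format described above can be checked in plain Python. This is a minimal sketch. `per_worker_batch_size` and `effective_global_batch_size` are hypothetical helpers, not part of Ray, and the placeholder strings stand in for image tensors.

```python
from typing import Dict, List, Tuple


def per_worker_batch_size(global_batch_size: int, world_size: int) -> int:
    """Hypothetical helper: split the global batch evenly across workers (same formula as the loop)."""
    if world_size <= 0:
        raise ValueError("world_size must be a positive integer")
    return global_batch_size // world_size


def effective_global_batch_size(global_batch_size: int, world_size: int) -> int:
    """Hypothetical helper: samples consumed per step once integer division drops the remainder."""
    return per_worker_batch_size(global_batch_size, world_size) * world_size


# Tuple-style batch, as a PyTorch DataLoader over (image, label) pairs would yield it
tuple_batch: Tuple[List[str], List[int]] = (["img0", "img1"], [3, 7])
images, labels = tuple_batch  # old loop: unpack by position

# Dict-style batch, with the keys used in the Ray Data loop
dict_batch: Dict[str, list] = {"image": ["img0", "img1"], "label": [3, 7]}
images_by_key, labels_by_key = dict_batch["image"], dict_batch["label"]  # new loop: index by key

print(per_worker_batch_size(512, 4))        # 128
print(effective_global_batch_size(500, 3))  # 498
```

Integer division means a global batch size that is not a multiple of the world size silently loses the remainder. For example, 500 split across 3 workers gives 166 per worker, so only 498 samples are processed per synchronous step. Choosing a multiple of the worker count avoids this.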

The rest of the loop (forward pass, loss computation, backward pass, optimizer step, metric logging, and checkpointing) stays the same.  

In short, **Ray Data plugs into Ray Train** as a replacement for `DataLoader`. Only the data-loading and batch-indexing lines change, and the training logic itself is identical.  

```python
# 01. Training loop using Ray Data

import ray.train
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

# load_model_ray_train, build_data_loader_ray_train_ray_data, print_metrics_ray_train,
# and save_checkpoint_and_metrics_ray_train come from other cells of this notebook
# (the Ray Data loader is defined in the next block).

def train_loop_ray_train_ray_data(config: dict):
    # Same as before: define loss, model, optimizer
    criterion = CrossEntropyLoss()
    model = load_model_ray_train()
    optimizer = Adam(model.parameters(), lr=1e-3)

    # Different: build the data loader from Ray Data instead of a PyTorch DataLoader.
    # Each worker receives an equal share of the global batch.
    global_batch_size = config["global_batch_size"]
    batch_size = global_batch_size // ray.train.get_context().get_world_size()
    data_loader = build_data_loader_ray_train_ray_data(batch_size=batch_size)

    # Same: loop over epochs
    for epoch in range(config["num_epochs"]):
        # Different: no sampler.set_epoch(). Ray Data already shards the data per worker,
        # and any shuffling is configured on the Ray Data side.

        # Different: batches are dicts {"image": ..., "label": ...}, not tuples,
        # and they are already on the worker's device (no .to("cuda")).
        for batch in data_loader:
            outputs = model(batch["image"])
            loss = criterion(outputs, batch["label"])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Same: report metrics and save a checkpoint at the end of each epoch
        metrics = print_metrics_ray_train(loss, epoch)
        save_checkpoint_and_metrics_ray_train(model, metrics)
```
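The loop can drop `sampler.set_epoch()` because each worker iterates over its own disjoint shard of the data. The toy simulation below illustrates that property without Ray. It is a sketch under stated assumptions: `toy_shard` and `iter_dict_batches` are hypothetical helpers, and the round-robin row split is only a stand-in. Ray Data's real splitting operates on blocks of data and is more involved.

```python
from typing import Dict, Iterator, List


def toy_shard(records: List[dict], rank: int, world_size: int) -> List[dict]:
    """Hypothetical round-robin split: worker `rank` keeps every world_size-th record."""
    return records[rank::world_size]


def iter_dict_batches(records: List[dict], batch_size: int) -> Iterator[Dict[str, list]]:
    """Yield column-oriented batches shaped like {"image": [...], "label": [...]}."""
    for start in range(0, len(records), batch_size):
        chunk = records[start:start + batch_size]
        yield {
            "image": [r["image"] for r in chunk],
            "label": [r["label"] for r in chunk],
        }


# Ten fake records; integers stand in for images
dataset = [{"image": i, "label": i % 2} for i in range(10)]
world_size = 2

shards = [toy_shard(dataset, rank, world_size) for rank in range(world_size)]
for rank, shard in enumerate(shards):
    for batch in iter_dict_batches(shard, batch_size=2):
        print(rank, batch["image"], batch["label"])

# Every record lands in exactly one shard, so no two workers train on the same sample
seen = sorted(r["image"] for shard in shards for r in shard)
```

A `DistributedSampler` achieves the same disjointness by index arithmetic inside each process, which is why it needs `set_epoch()` to vary its shuffle. When the data pipeline itself hands each worker its shard, that bookkeeping moves out of the training loop.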