## 9. Ray Train `train_loop` (Lightning + Ray integration)  
Define the core training logic that runs **once per Ray worker**.  
This version uses **PyTorch Lightning with Ray Train V2** for distributed setup, checkpointing, and metric reporting.  
`RayDDPStrategy` runs DDP across the Ray workers, `RayLightningEnvironment` supplies ranks and addresses, `RayTrainReportCallback` forwards metrics and checkpoints to Ray, and `prepare_trainer()` validates the Trainer configuration.  
Each reported checkpoint is a Lightning checkpoint file (`checkpoint.ckpt`), so a restarted run can resume from the latest one, which keeps multi-GPU training **fault-tolerant and resumable**.

In [None]:
# 09. Train loop for Ray TorchTrainer (RayDDP + Lightning-native, NEW API aligned)

def train_loop(config):
    """
    Lightning-owned loop with Ray integration:
      - RayDDPStrategy for multi-worker DDP
      - RayLightningEnvironment for ranks/addrs
      - RayTrainReportCallback to forward metrics + checkpoints to Ray
      - Resume from the Ray-provided Lightning checkpoint ("checkpoint.ckpt")
    """
    import warnings
    warnings.filterwarnings(
        "ignore",
        message="barrier.*using the device under current context",
    )
    import os
    import tempfile  # needed for CKPT_ROOT below
    import lightning.pytorch as pl
    import ray.train  # binds `ray` for ray.train.get_dataset_shard
    from ray.train import get_checkpoint, get_context
    from ray.train.lightning import (
        RayLightningEnvironment,
        RayDDPStrategy,
        RayTrainReportCallback,
        prepare_trainer,
    )

    # ---- Data shards from Ray Data → iterable loaders ----
    train_ds = ray.train.get_dataset_shard("train")
    val_ds   = ray.train.get_dataset_shard("val")
    train_loader = train_ds.iter_torch_batches(batch_size=config.get("batch_size", 32))
    val_loader   = val_ds.iter_torch_batches(batch_size=config.get("batch_size", 32))

    # ---- Model ----
    model = PixelDiffusion()

    # ---- Lightning Trainer configured for Ray ----
    CKPT_ROOT = os.path.join(tempfile.gettempdir(), "ray_pl_ckpts")
    os.makedirs(CKPT_ROOT, exist_ok=True)

    trainer = pl.Trainer(
        max_epochs=config.get("epochs", 10),
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),
        plugins=[RayLightningEnvironment()],
        callbacks=[
            RayTrainReportCallback(),
            pl.callbacks.ModelCheckpoint(
                dirpath=CKPT_ROOT,         # local scratch is fine (or leave None to use default)
                filename="epoch-{epoch:03d}",
                every_n_epochs=1,
                save_top_k=-1,
                save_last=True,
            ),
        ],
        default_root_dir=CKPT_ROOT,        # also local
        enable_progress_bar=False,
        check_val_every_n_epoch=1,
    )

    # Validate the Trainer configuration for Ray Train
    trainer = prepare_trainer(trainer)

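    # ---- Resume from latest Ray-provided Lightning checkpoint (if any) ----
    # Fit inside the `with` block: for remote storage, the directory yielded by
    # as_directory() is a temporary download that is cleaned up on exit.
    ckpt = get_checkpoint()
    if ckpt:
        with ckpt.as_directory() as d:
            candidate = os.path.join(d, "checkpoint.ckpt")
            ckpt_path = candidate if os.path.exists(candidate) else None
            if ckpt_path and get_context().get_world_rank() == 0:
                print(f"✅ Resuming from Lightning checkpoint: {ckpt_path}")
            # ---- Let Lightning own the loop (resuming if a checkpoint was found) ----
            trainer.fit(
                model,
                train_dataloaders=train_loader,
                val_dataloaders=val_loader,
                ckpt_path=ckpt_path,
            )
    else:
        # ---- Fresh start: let Lightning own the loop ----
        trainer.fit(
            model,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader,
        )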
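
The resume step depends on how long the checkpoint directory stays on disk. `Checkpoint.as_directory()` is a context manager, and for checkpoints on remote storage the directory it yields is a temporary download, so any path inside it should only be used within the `with` block. The sketch below illustrates that lifetime rule with the standard library only. `tempfile.TemporaryDirectory` stands in for the Ray checkpoint, and `find_lightning_checkpoint` is a hypothetical helper, not part of Ray or Lightning.

```python
import os
import tempfile


def find_lightning_checkpoint(directory, filename="checkpoint.ckpt"):
    """Return the path to `filename` inside `directory`, or None if it is absent.

    Hypothetical helper for illustration; not a Ray or Lightning API.
    """
    candidate = os.path.join(directory, filename)
    return candidate if os.path.exists(candidate) else None


# TemporaryDirectory is a stand-in for a checkpoint downloaded to local scratch.
with tempfile.TemporaryDirectory() as ckpt_dir:
    # Simulate a checkpoint file written by a previous run.
    with open(os.path.join(ckpt_dir, "checkpoint.ckpt"), "wb") as f:
        f.write(b"fake checkpoint bytes")

    ckpt_path = find_lightning_checkpoint(ckpt_dir)
    exists_inside = ckpt_path is not None and os.path.exists(ckpt_path)
    # A real training loop would call trainer.fit(..., ckpt_path=ckpt_path) here.

# After the block exits, the temporary directory and its files are gone.
exists_outside = os.path.exists(ckpt_path)

print("usable inside the with block:", exists_inside)
print("usable after the with block:", exists_outside)
```

When the checkpoint already lives in a local directory, Ray returns that directory without copying it, so the path may outlive the block. Relying on that makes the loop sensitive to the storage backend.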

### 10. Launch distributed training with TorchTrainer  
Ask for **8 GPU workers**, keep the five most recent checkpoints, and allow up to one automatic retry.  
`result.checkpoint` holds the latest reported checkpoint, which here comes from the highest epoch. Because `epoch` is the score attribute, Ray retains the five highest-epoch checkpoints. To rank checkpoints by another metric, such as validation loss or training loss, change `checkpoint_score_attribute` and `checkpoint_score_order`.
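
To make the retention policy concrete, here is a minimal pure-Python sketch of keeping the top `num_to_keep` checkpoints ranked by a score attribute. It illustrates the idea only and is not Ray's implementation. `keep_top_checkpoints` and the sample metric values are hypothetical.

```python
def keep_top_checkpoints(reports, num_to_keep, score_attribute, score_order="max"):
    """Return the reports that a top-N retention policy would keep.

    Hypothetical illustration; Ray applies its own logic internally.
    `reports` is a list of metric dicts, one per reported checkpoint.
    """
    if score_order not in ("max", "min"):
        raise ValueError("score_order must be 'max' or 'min'")
    ranked = sorted(
        reports,
        key=lambda r: r[score_attribute],
        reverse=(score_order == "max"),  # best score first
    )
    return ranked[:num_to_keep]


# Made-up metrics for eight epochs (values are illustrative only).
reports = [
    {"epoch": e, "val_loss": loss}
    for e, loss in enumerate([0.90, 0.70, 0.55, 0.50, 0.52, 0.48, 0.51, 0.53])
]

# Scoring by epoch (max) keeps the five most recent checkpoints.
by_epoch = keep_top_checkpoints(reports, 5, "epoch", "max")
print([r["epoch"] for r in by_epoch])      # → [7, 6, 5, 4, 3]

# Scoring by val_loss (min) keeps the five lowest-loss checkpoints instead.
by_val_loss = keep_top_checkpoints(reports, 5, "val_loss", "min")
print([r["epoch"] for r in by_val_loss])   # → [5, 3, 6, 4, 7]
```

In the real configuration, the second case would correspond to `checkpoint_score_attribute="val_loss"` with `checkpoint_score_order="min"`, assuming the model logs a `val_loss` metric.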

In [None]:
# 10. Launch distributed training (same API, now Lightning-native inside)

trainer = TorchTrainer(
    train_loop_per_worker=train_loop,
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
    datasets={"train": train_ds, "val": val_ds},
    run_config=RunConfig(
        name="food101_diffusion_ft",
        storage_path="/mnt/cluster_storage/generative_cv/food101_diffusion_results",
        checkpoint_config=CheckpointConfig(
            checkpoint_frequency=1,
            num_to_keep=5,                   # Ray keeps the 5 top-scoring Lightning ckpts
            checkpoint_score_attribute="epoch",
            checkpoint_score_order="max",
        ),
        failure_config=FailureConfig(max_failures=1),
    ),
)

result = trainer.fit()
print("Training complete →", result.metrics)
best_ckpt = result.checkpoint

### 11. Plot loss curves  
Visualize training and validation loss directly from Ray’s tracked metrics.  
`result.metrics_dataframe` collects the metrics that Lightning reported to Ray during training, one row per report.  
Plotting both losses gives a quick health check: training loss should fall steadily, and validation loss should track it rather than diverge.

In [None]:
# 11. Plot train/val loss curves (Ray + Lightning integration)

# Ray stores all metrics emitted by Lightning in a dataframe
df = result.metrics_dataframe

# Display first few rows (optional sanity check)
print(df.head())

# Validate that the expected loss columns are present
if "train_loss" not in df.columns or "val_loss" not in df.columns:
    raise ValueError("Expected train_loss and val_loss in metrics. "
                     "Did you call self.log('train_loss') / self.log('val_loss') in PixelDiffusion?")

plt.figure(figsize=(7, 4))
plt.plot(df["epoch"], df["train_loss"], marker="o", label="Train")
plt.plot(df["epoch"], df["val_loss"], marker="o", label="Val")
plt.xlabel("Epoch")
plt.ylabel("MSE Loss")
plt.title("Pixel Diffusion – Loss per Epoch (Ray Train + Lightning)")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()
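
A plot can be hard to read when the curves are noisy, so a small numeric summary can complement it. The following is a minimal sketch in plain Python, assuming per-epoch loss lists such as `df["train_loss"].tolist()` and `df["val_loss"].tolist()`. `summarize_losses` and the sample numbers are hypothetical and not part of Ray or Lightning.

```python
def summarize_losses(train_losses, val_losses):
    """Summarize per-epoch losses: best epoch and final train/val gap.

    Hypothetical helper; epochs are 0-indexed positions in the lists.
    """
    if len(train_losses) != len(val_losses) or not val_losses:
        raise ValueError("need two non-empty lists of equal length")
    # Index of the lowest validation loss (first occurrence on ties).
    best_epoch = min(range(len(val_losses)), key=lambda i: val_losses[i])
    return {
        "best_epoch": best_epoch,
        "best_val_loss": val_losses[best_epoch],
        # Positive gap: validation loss is higher than training loss.
        "final_gap": val_losses[-1] - train_losses[-1],
        # Epochs trained after the best validation loss was reached.
        "epochs_since_best": len(val_losses) - 1 - best_epoch,
    }


# Illustrative values only, not results from this run.
train = [0.80, 0.60, 0.45, 0.38, 0.33]
val = [0.85, 0.66, 0.54, 0.52, 0.56]
summary = summarize_losses(train, val)
print(summary)
```

A rising `epochs_since_best` together with a widening `final_gap` can hint at overfitting, although the thresholds that matter depend on the model and dataset.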