## 5. Distributed training loop with checkpointing

This per-worker function demonstrates:

* **Ray Data to PyTorch**: `iter_torch_batches()` each epoch  
* **Lightning-on-Ray**: single-GPU trainer per worker  
* **Fault tolerance**: resume from the latest Ray Train checkpoint  
* **Checkpointing**: reports metrics and a checkpoint to Ray Train every epoch

In [None]:
# 05. Ray Train Lightning-native training loop

def train_loop(config):
    """Per-worker training function run by Ray Train.

    Turns this worker's "train" and "val" Ray Dataset shards into torch batch
    iterables, builds a Lightning ``Trainer`` configured with the Ray
    strategy/environment/report callback, and calls ``trainer.fit``. If Ray
    Train provides a checkpoint (e.g. after a worker restart), training
    resumes from the Lightning checkpoint file stored inside it.

    Args:
        config: Dict of hyperparameters. Reads ``"batch_size"`` (default 32)
            and ``"epochs"`` (default 10).
    """
    import contextlib
    import os
    import tempfile
    import warnings

    import lightning.pytorch as pl
    from ray.train import get_checkpoint, get_context, get_dataset_shard
    from ray.train.lightning import (
        RayLightningEnvironment,
        RayDDPStrategy,
        RayTrainReportCallback,
        prepare_trainer,
    )

    warnings.filterwarnings(
        "ignore", message="barrier.*using the device under current context"
    )

    # ---- Ray Dataset shards → iterable torch batches ----
    batch_size = config.get("batch_size", 32)
    train_loader = get_dataset_shard("train").iter_torch_batches(batch_size=batch_size)
    val_loader = get_dataset_shard("val").iter_torch_batches(batch_size=batch_size)

    # ---- Model (DiffusionPolicy is defined outside this function) ----
    model = DiffusionPolicy()

    # ---- Local scratch for PL checkpoints (Ray will persist to storage_path) ----
    CKPT_ROOT = os.path.join(tempfile.gettempdir(), "ray_pl_ckpts")
    os.makedirs(CKPT_ROOT, exist_ok=True)

    # ---- Lightning Trainer configured for Ray ----
    trainer = pl.Trainer(
        max_epochs=config.get("epochs", 10),
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),
        plugins=[RayLightningEnvironment()],
        callbacks=[
            RayTrainReportCallback(),       # forwards metrics + ckpt to Ray
            pl.callbacks.ModelCheckpoint(   # local PL checkpoints each epoch
                dirpath=CKPT_ROOT,
                filename="epoch-{epoch:03d}",
                every_n_epochs=1,
                save_top_k=-1,
                save_last=True,
            ),
        ],
        default_root_dir=CKPT_ROOT,
        enable_progress_bar=False,
        check_val_every_n_epoch=1,
    )

    # ---- Prepare trainer for Ray environment ----
    trainer = prepare_trainer(trainer)

    # ---- Resume from Ray checkpoint if available, then train ----
    # The directory yielded by `Checkpoint.as_directory()` may be a temporary
    # download that is cleaned up when its context exits, so the context is
    # kept open (via the ExitStack) until `trainer.fit` has finished with it.
    with contextlib.ExitStack() as stack:
        ckpt_path = None
        ckpt = get_checkpoint()
        if ckpt:
            ckpt_dir = stack.enter_context(ckpt.as_directory())
            candidate = os.path.join(ckpt_dir, "checkpoint.ckpt")
            if os.path.exists(candidate):
                ckpt_path = candidate
                # Log once per job rather than once per worker.
                if get_context().get_world_rank() == 0:
                    print(f"✅ Resuming from Lightning checkpoint: {ckpt_path}")

        # ---- Run training (Lightning owns the loop) ----
        trainer.fit(
            model,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader,
            ckpt_path=ckpt_path,
        )


### 6. Launch Ray TorchTrainer

Eight A10G workers train in parallel.  
`RunConfig` keeps the **five most recent checkpoints** and automatically restarts
up to **3** times on failure.

In [None]:
# 06. Launch distributed training with Ray TorchTrainer

# Build the distributed trainer: 8 GPU workers each run `train_loop` on their
# shard of the "train" / "val" datasets.
trainer = TorchTrainer(
    train_loop,
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
    datasets={"train": train_ds, "val": val_ds},
    run_config=RunConfig(
        name="pendulum_diffusion_ft",
        storage_path="/mnt/cluster_storage/pendulum_diffusion/pendulum_diffusion_results",
        checkpoint_config=CheckpointConfig(
            # No `checkpoint_frequency` here: that option only applies to
            # trainers without a custom training loop. With TorchTrainer the
            # cadence is set by the loop itself, where RayTrainReportCallback
            # reports a checkpoint to Ray.
            num_to_keep=5,
            # Rank checkpoints by epoch so the retained ones are the most recent.
            checkpoint_score_attribute="epoch",
            checkpoint_score_order="max",
        ),
        # Restart the run up to 3 times on failure.
        failure_config=FailureConfig(max_failures=3),
    ),
)

result = trainer.fit()
print("Training complete →", result.metrics)
best_ckpt = result.checkpoint  # latest Ray-managed Lightning checkpoint

### 7. Plot train / val loss  

Visualize convergence using Ray’s built-in metrics.  
`result.metrics_dataframe` automatically collects the losses logged by Lightning each epoch.

In [None]:
# 07. Plot training and validation loss (Ray + Lightning integration)

# Per-report metrics history collected by Ray Train; can be None when the run
# reported no metrics.
df = result.metrics_dataframe
if df is None:
    raise ValueError("result.metrics_dataframe is None. Were any metrics reported?")
print(df.head())  # optional sanity check

# Validate every column the plot reads (including the x-axis) up front, and
# name the ones that are absent so the failure is actionable.
required_cols = ("epoch", "train_loss", "val_loss")
missing_cols = [col for col in required_cols if col not in df.columns]
if missing_cols:
    raise ValueError(
        f"Missing metric columns: {missing_cols}. "
        "Did you log train_loss / val_loss via self.log()?"
    )

plt.figure(figsize=(7, 4))
plt.plot(df["epoch"], df["train_loss"], marker="o", label="Train")
plt.plot(df["epoch"], df["val_loss"], marker="o", label="Val")
plt.xlabel("Epoch")
plt.ylabel("MSE Loss")
plt.title("Pendulum Diffusion - Loss per Epoch (Ray Train + Lightning)")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()