## 05. Distributed Train Loop with Checkpointing

This per-worker function demonstrates:

* **Ray Data ➜ PyTorch**: `iter_torch_batches()` each epoch  
* **Lightning-on-Ray**: single-GPU trainer per worker  
* **Fault tolerance**: resume from the latest Ray Train checkpoint  
* **Manual checkpoint**: save `model.pt` + `meta.pt` every epoch

In [None]:
# 05. Training loop

# Training function that runs on each Ray worker
def train_loop(config):
    """Per-worker training function: train one epoch at a time and checkpoint.

    Each worker reads its "train" and "val" dataset shards, optionally restores
    model weights and the epoch counter from the checkpoint returned by
    ``get_checkpoint()``, and then runs a Lightning trainer one epoch per loop
    iteration. After every epoch, rank 0 writes ``model.pt`` and ``meta.pt``
    into a fresh directory under ``CKPT_ROOT``; every rank then calls
    ``report`` (only rank 0 attaches a checkpoint).

    Args:
        config: Mapping of training options. Only ``"epochs"`` is read
            (total number of epochs, default 10).
    """

    # ---------- Paths for logs & checkpoints ----------
    LOG_PATH  = "/mnt/cluster_storage/pendulum_diffusion/epoch_metrics.json"
    CKPT_ROOT = "/mnt/cluster_storage/pendulum_diffusion/pendulum_diffusion_ckpts"

    rank = get_context().get_world_rank()
    max_epochs = config.get("epochs", 10)

    # Look the restore checkpoint up once and reuse it for both the log reset
    # decision and the resume logic below.
    ckpt = get_checkpoint()

    # Create the checkpoint dir on rank 0; on a fresh (non-resumed) run also
    # drop any metrics log left over from a previous run.
    if rank == 0:
        os.makedirs(CKPT_ROOT, exist_ok=True)
        if not ckpt:
            try:
                os.remove(LOG_PATH)
            except FileNotFoundError:
                pass  # nothing to reset

    # ---------- Load Ray Dataset shards ----------
    train_ds = ray.train.get_dataset_shard("train")
    val_ds   = ray.train.get_dataset_shard("val")

    # ---------- Instantiate model ----------
    model = DiffusionPolicy()
    start_epoch = 0

    # ---------- Resume from checkpoint ----------
    if ckpt:
        with ckpt.as_directory() as d:
            model.load_state_dict(
                torch.load(os.path.join(d, "model.pt"), map_location="cpu")
            )
            meta = torch.load(os.path.join(d, "meta.pt"), map_location="cpu")
            # meta stores the last *completed* epoch, so continue with the next one.
            start_epoch = meta.get("epoch", 0) + 1
        if rank == 0:
            print(f"[Rank {rank}] Resumed from checkpoint at epoch {start_epoch}")

    # ---------- Lightning Trainer ----------
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        plugins=[RayLightningEnvironment()],
        enable_progress_bar=False,
        check_val_every_n_epoch=1,
    )

    # ---------- Training Loop ----------
    for epoch in range(start_epoch, max_epochs):

        # Re-materialize fresh batches each epoch (avoid stale iterator errors)
        train_data = list(train_ds.iter_torch_batches(
            batch_size=32,
            local_shuffle_buffer_size=1024,
            prefetch_batches=1,
            drop_last=True
        ))
        # Validation uses every row in a stable order: no shuffling and no
        # dropped final batch, so the metric covers the whole shard and a
        # shard smaller than one batch still yields validation data.
        val_data = list(val_ds.iter_torch_batches(
            batch_size=32,
            prefetch_batches=1,
        ))

        # Wrap lists in PyTorch DataLoaders (batch_size=None: already batched)
        train_loader = DataLoader(train_data, batch_size=None)
        val_loader   = DataLoader(val_data, batch_size=None)

        # Run one epoch (advance trainer manually)
        # NOTE(review): relies on Lightning's fit_loop accepting these
        # assignments; confirm against the installed Lightning version.
        trainer.fit_loop.max_epochs = epoch + 1
        trainer.fit_loop.current_epoch = epoch
        trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

        # ---------- Save checkpoint ----------
        # Only rank 0 writes files; the uuid suffix keeps directories from
        # different attempts at the same epoch apart.
        ckpt_out = None
        if rank == 0:
            out_dir = os.path.join(CKPT_ROOT, f"epoch_{epoch}_{uuid.uuid4().hex}")
            os.makedirs(out_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
            torch.save({"epoch": epoch}, os.path.join(out_dir, "meta.pt"))
            ckpt_out = Checkpoint.from_directory(out_dir)

        # Report metrics + checkpoint back to Ray Train (called on every rank)
        report({"epoch": epoch}, checkpoint=ckpt_out)

### 06. Launch Ray TorchTrainer

Eight A10G workers train in parallel.  
`RunConfig` keeps the **five most recent checkpoints** and automatically restarts
up to **3** times on failure.

In [None]:
# 06. Launch Ray Trainer

# Configure Ray TorchTrainer to run the distributed training job
# NOTE: `checkpoint_frequency` is intentionally not set. Checkpoints are saved
# by `train_loop` itself through `report(..., checkpoint=...)` after every
# epoch; Ray only supports `checkpoint_frequency` for trainers that do not take
# a custom training loop and rejects it for a TorchTrainer.
trainer = TorchTrainer(
    train_loop,
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
    datasets={"train": train_ds, "val": val_ds},
    run_config=RunConfig(
        name="pendulum_diffusion_ft",
        storage_path="/mnt/cluster_storage/pendulum_diffusion/pendulum_diffusion_results",
        checkpoint_config=CheckpointConfig(
            # Rank checkpoints by the reported "epoch" (higher is better), so
            # the five kept are the five most recent ones.
            num_to_keep=5,
            checkpoint_score_attribute="epoch",
            checkpoint_score_order="max",
        ),
        # Restart the run up to 3 times on failure.
        failure_config=FailureConfig(max_failures=3),
    ),
)

result = trainer.fit()
print("Training complete →", result.metrics)
best_ckpt = result.checkpoint        # checkpoint from last reported epoch

### 07. Plot Train / Val Loss

Now, read the shared `epoch_metrics.json` and visualize convergence curves.

In [None]:
# 07. Plot training and validation loss

# Load training logs from shared file and plot losses
# Location of the shared per-epoch metrics file
log_path = "/mnt/cluster_storage/pendulum_diffusion/epoch_metrics.json"

with open(log_path, "r") as f:
    logs = json.load(f)

# Build a DataFrame from the logged records; the code below reads the
# "epoch", "train_loss" and "val_loss" columns.
df = pd.DataFrame(logs)
# Non-numeric val_loss entries (e.g. None) become NaN instead of raising
df["val_loss"] = pd.to_numeric(df["val_loss"], errors="coerce")
# Average the numeric columns over all records that share an epoch
df_grouped = df.groupby("epoch", as_index=False).mean(numeric_only=True)


# One curve per loss column, train first and validation second
plt.figure(figsize=(8, 5))
for column, label in (("train_loss", "Train Loss"), ("val_loss", "Val Loss")):
    plt.plot(df_grouped["epoch"], df_grouped[column], marker="o", label=label)
plt.title("Training & Validation Loss per Epoch")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()