## 09 · Ray Train `train_loop` (Checkpoint + Resume)  
Core training logic, run **once per Ray worker**:  
1. Shard-aware data loading: `get_dataset_shard` + `iter_torch_batches`.  
2. Auto-resume from the latest Ray Checkpoint (if present).  
3. Manual per-epoch checkpointing: save `model.pt` + `meta.pt`, then call `report(metrics, checkpoint=…)`.  

This makes the run **fault-tolerant**: if a worker crashes, Ray restarts the worker group and re-enters the loop from the latest reported checkpoint, up to the retry limit set in `FailureConfig`.
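The resume arithmetic and the crash-restart cycle can be sketched in plain Python. This is a toy simulation under stated assumptions, not Ray's implementation: `resume_epoch`, `run_worker`, and `fit_with_retries` are hypothetical helpers, small JSON files stand in for `torch.save`, and a raised exception stands in for the worker crash that Ray would handle through `FailureConfig(max_failures=...)`.

```python
import json
import os
import tempfile


def resume_epoch(ckpt_root):
    """Return 0 if no checkpoint exists, else (last saved epoch + 1)."""
    ckpts = sorted(d for d in os.listdir(ckpt_root) if d.startswith("epoch_"))
    if not ckpts:
        return 0
    with open(os.path.join(ckpt_root, ckpts[-1], "meta.json")) as f:
        return json.load(f)["epoch"] + 1


def run_worker(ckpt_root, total_epochs, attempt_log, crash_at=None):
    """Toy stand-in for train_loop: resume, 'train', and checkpoint once per epoch."""
    start_epoch = resume_epoch(ckpt_root)
    attempt_log.append(start_epoch)
    for epoch in range(start_epoch, total_epochs):
        if epoch == crash_at:
            raise RuntimeError(f"simulated crash during epoch {epoch}")
        # Zero-padded names keep lexicographic order equal to epoch order
        out_dir = os.path.join(ckpt_root, f"epoch_{epoch:04d}")
        os.makedirs(out_dir, exist_ok=True)
        with open(os.path.join(out_dir, "meta.json"), "w") as f:
            json.dump({"epoch": epoch}, f)


def fit_with_retries(ckpt_root, total_epochs, crash_at, max_failures=3):
    """Restart the toy worker after each crash, up to max_failures times."""
    attempt_log, failures = [], 0
    while True:
        try:
            run_worker(ckpt_root, total_epochs, attempt_log, crash_at)
            return attempt_log, failures
        except RuntimeError:
            failures += 1
            if failures > max_failures:
                raise
            crash_at = None  # assume the fault is transient and does not recur


with tempfile.TemporaryDirectory() as root:
    starts, failures = fit_with_retries(root, total_epochs=5, crash_at=3)
    next_epoch = resume_epoch(root)

print(starts, failures, next_epoch)  # [0, 3] 1 5
```

The first attempt starts at epoch 0 and saves epochs 0 to 2 before crashing; the restart reads `epoch = 2` from the newest checkpoint and resumes at epoch 3. The crashed epoch reruns from scratch: only completed, checkpointed epochs survive a restart.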

```python
# 09. Train Loop for Ray TorchTrainer
# Assumes earlier cells define `PixelDiffusion` (a LightningModule) and `pl` (Lightning's pytorch package).

def train_loop(config):
    """Ray Train per-worker function with checkpointing and resume support."""
    import os
    import uuid

    import torch
    from ray.train import Checkpoint, get_checkpoint, get_context, get_dataset_shard, report
    from ray.train.lightning import RayDDPStrategy, RayLightningEnvironment, prepare_trainer

    # Paths on shared cluster storage
    LOG_PATH = "/mnt/cluster_storage/generative_cv/epoch_metrics.json"
    CKPT_ROOT = "/mnt/cluster_storage/generative_cv/food101_diffusion_ckpts"
    num_epochs = config.get("epochs", 10)

    rank = get_context().get_world_rank()
    if rank == 0:
        os.makedirs(CKPT_ROOT, exist_ok=True)
        # Fresh run (nothing to resume from): drop the stale metrics log
        if not get_checkpoint() and os.path.exists(LOG_PATH):
            os.remove(LOG_PATH)

    # Data: each worker iterates only over its own shard of the Ray Datasets
    train_loader = get_dataset_shard("train").iter_torch_batches(batch_size=32)
    val_loader = get_dataset_shard("val").iter_torch_batches(batch_size=32)

    # Model
    model = PixelDiffusion()
    start_epoch = 0

    # Resume from checkpoint if present (every worker loads the same weights)
    ckpt = get_checkpoint()
    if ckpt:
        with ckpt.as_directory() as d:
            model.load_state_dict(torch.load(os.path.join(d, "model.pt"), map_location="cpu"))
            start_epoch = torch.load(os.path.join(d, "meta.pt")).get("epoch", 0) + 1
        if rank == 0:
            print(f"[Rank {rank}] Resuming from checkpoint at epoch {start_epoch}")

    # Trainer: RayDDPStrategy + RayLightningEnvironment + prepare_trainer make the
    # Ray workers train one model with synchronized gradients (DDP)
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices="auto",
        strategy=RayDDPStrategy(),
        plugins=[RayLightningEnvironment()],
        enable_progress_bar=False,
        check_val_every_n_epoch=1,
    )
    trainer = prepare_trainer(trainer)

    # Train loop: run one epoch per fit() call, then checkpoint manually
    for epoch in range(start_epoch, num_epochs):
        trainer.fit_loop.max_epochs = epoch + 1
        trainer.fit_loop.current_epoch = epoch
        trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

        if rank == 0:
            # Only rank 0 writes the checkpoint; DDP keeps the weights identical on all ranks
            out_dir = os.path.join(CKPT_ROOT, f"epoch_{epoch}_{uuid.uuid4().hex}")
            os.makedirs(out_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
            torch.save({"epoch": epoch}, os.path.join(out_dir, "meta.pt"))
            ckpt_out = Checkpoint.from_directory(out_dir)
        else:
            ckpt_out = None

        # Every rank calls report(); Ray persists the rank-0 checkpoint to storage_path
        report({"epoch": epoch}, checkpoint=ckpt_out)
```

## 10 · Launch Distributed Training with TorchTrainer  
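Ask for **8 GPU workers**, keep the five most recent checkpoints, and allow up to three automatic retries.  
`checkpoint_score_attribute="epoch"` with `checkpoint_score_order="max"` controls which checkpoints survive the `num_to_keep` limit; because `epoch` only increases, these are the five newest. `result.checkpoint` is the latest reported checkpoint, which here is also the highest-epoch one. To rank by another metric, such as validation loss, report that metric from `train_loop`, use it as the score attribute (with `checkpoint_score_order="min"`), and fetch the best checkpoint with `result.get_best_checkpoint(metric="val_loss", mode="min")`.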
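To see why the score attribute matters, here is a toy model of top-k checkpoint retention in plain Python. `retained_epochs` is a hypothetical helper that mimics the idea behind `num_to_keep` + `checkpoint_score_attribute` + `checkpoint_score_order`; it is not Ray's checkpoint manager and ignores any special handling Ray may apply to the most recent checkpoint. The loss values are made up for illustration.

```python
def retained_epochs(reports, num_to_keep, score_attr, order):
    """Toy top-k retention: keep the num_to_keep best reports ranked by score_attr."""
    ranked = sorted(reports, key=lambda r: r[score_attr], reverse=(order == "max"))
    return sorted(r["epoch"] for r in ranked[:num_to_keep])


# Made-up per-epoch validation losses, for illustration only
val_losses = [0.30, 0.22, 0.18, 0.19, 0.21, 0.20, 0.24]
reports = [{"epoch": e, "val_loss": v} for e, v in enumerate(val_losses)]

by_epoch = retained_epochs(reports, num_to_keep=5, score_attr="epoch", order="max")
by_val = retained_epochs(reports, num_to_keep=5, score_attr="val_loss", order="min")
latest = reports[-1]["epoch"]
best = min(reports, key=lambda r: r["val_loss"])["epoch"]

print(by_epoch)      # [2, 3, 4, 5, 6]
print(by_val)        # [1, 2, 3, 4, 5]
print(latest, best)  # 6 2
```

Scored by `epoch`, the five newest checkpoints survive and the latest is also the best. Scored by validation loss, the best checkpoint (epoch 2) and the latest (epoch 6) differ, so the two have to be retrieved separately.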

```python
# 10. Launch Distributed Training
from ray.train import CheckpointConfig, FailureConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer

# `train_ds` and `val_ds` are the Ray Datasets prepared in earlier cells
trainer = TorchTrainer(
    train_loop,
    train_loop_config={"epochs": 10},
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
    datasets={"train": train_ds, "val": val_ds},
    run_config=RunConfig(
        name="food101_diffusion_ft",
        storage_path="/mnt/cluster_storage/generative_cv/food101_diffusion_results",
        checkpoint_config=CheckpointConfig(
            checkpoint_frequency=1,
            num_to_keep=5,
            checkpoint_score_attribute="epoch",
            checkpoint_score_order="max",
        ),
        failure_config=FailureConfig(max_failures=3),
    ),
)

result = trainer.fit()
print("Training complete →", result.metrics)
best_ckpt = result.checkpoint  # latest reported checkpoint; with "epoch" as the score it is also the best one
```

## 11 · Plot Loss Curves  
Parse the JSON written by `PixelDiffusion.on_train_epoch_end`, convert to a DataFrame, and render Train vs. Val MSE loss.  
Good practice for quick health checks without external tooling.

**Why is validation loss lower than training loss?**  
The training loss is computed *before* each weight update, with freshly sampled noise at every step, so an epoch-level training figure reflects the model as it was over the course of the epoch. Validation runs afterwards in `eval()` mode with the end-of-epoch weights and no gradient updates, which often makes it slightly lower, especially early in training.  
This is normal in this setting and usually indicates that the model is not over-fitting.
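The first reason (training loss is recorded while the weights are still improving) can be made concrete with a few lines of arithmetic. This is an idealized sketch with made-up numbers, assuming the per-epoch training loss is the mean of per-step losses, that the loss falls linearly within the epoch, and that training and validation data are equally hard; it isolates the averaging effect and says nothing about noise sampling or `eval()` mode.

```python
# Idealized assumptions: loss falls linearly within the epoch, the logged training
# loss is the mean of per-step losses, and train/val data have the same difficulty.
steps_per_epoch = 10
step_losses = [0.50 - 0.02 * s for s in range(steps_per_epoch)]  # each measured before its update

epoch_train_loss = sum(step_losses) / steps_per_epoch  # averaged over the whole epoch
end_of_epoch_loss = 0.50 - 0.02 * steps_per_epoch      # what validation sees after the last update

print(round(epoch_train_loss, 3), round(end_of_epoch_loss, 3))  # 0.41 0.3
```

Even with no generalization gap at all, the epoch-averaged training loss (0.41) sits above the end-of-epoch loss (0.30) because it includes the earlier, worse steps.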


```python
# 11. Plot Train/Val Loss Curves
import json

import matplotlib.pyplot as plt
import pandas as pd

LOG_PATH = "/mnt/cluster_storage/generative_cv/epoch_metrics.json"
with open(LOG_PATH, "r") as f:
    logs = json.load(f)

df = pd.DataFrame(logs)
# Coerce missing or non-numeric validation entries to NaN so they plot as gaps
df["val_loss"] = pd.to_numeric(df["val_loss"], errors="coerce")

plt.figure(figsize=(7, 4))
plt.plot(df["epoch"], df["train_loss"], marker="o", label="Train")
plt.plot(df["epoch"], df["val_loss"], marker="o", label="Val")
plt.xlabel("Epoch")
plt.ylabel("MSE Loss")
plt.title("Pixel Diffusion - Loss per Epoch")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()
```