## 8. Define Ray Train loop (with validation, checkpointing, and Ray-managed metrics)

Define the `train_loop_per_worker`, the core function executed by each Ray Train worker.  
This loop handles distributed training, validation, and checkpointing with Ray-managed metrics.

Each worker receives its own shard of the training and validation datasets using `get_dataset_shard()`.  
Batches are streamed directly into PyTorch via `iter_torch_batches()`, ensuring efficient, fully distributed data loading.

During each epoch:
- Compute average **training** and **validation** MSE losses.  
- On **rank 0** only, save a temporary checkpoint (model weights + epoch metadata) using `tempfile.TemporaryDirectory()`.  
- Call `ray.train.report()` to report metrics and attach the checkpoint; other workers report metrics only.
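
The loop in this section averages per-batch losses with equal weight. When the final batch of a shard is smaller than `batch_size`, that average differs slightly from the true per-sample MSE. The following is a minimal pure-Python sketch with made-up batch values; `weighted_mean_loss` is a hypothetical helper, not part of Ray or this tutorial's code:

```python
def mean_of_batch_means(batch_losses):
    """Equal-weight average of per-batch losses (the approach used in this section)."""
    return sum(batch_losses) / max(1, len(batch_losses))


def weighted_mean_loss(batch_losses, batch_sizes):
    """Hypothetical helper: weight each batch loss by its number of samples."""
    total = sum(batch_sizes)
    if total == 0:
        return 0.0
    return sum(loss * n for loss, n in zip(batch_losses, batch_sizes)) / total


# Illustrative values only: two full batches of 512 and a short final batch of 100.
batch_losses = [1.0, 1.0, 4.0]
batch_sizes = [512, 512, 100]

equal_weight = mean_of_batch_means(batch_losses)
sample_weight = weighted_mean_loss(batch_losses, batch_sizes)
print(f"mean of batch means:  {equal_weight:.4f}")
print(f"sample-weighted mean: {sample_weight:.4f}")
```

For large shards the gap is usually negligible, which is likely why the simpler average is a reasonable choice for a tutorial.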

Ray captures every reported metric and exposes the history in `result.metrics_dataframe`, so you can track progress without extra logging logic. The attached checkpoints, not the metrics, are what make fault-tolerant recovery possible.
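
The resume logic behind that recovery can be sketched without Ray or PyTorch. This is a minimal illustration that assumes JSON files stand in for the `torch.save` calls; `save_checkpoint` and `load_start_epoch` are hypothetical helpers, not Ray APIs:

```python
import json
import os
import tempfile


def save_checkpoint(ckpt_dir, epoch, weights):
    """Hypothetical helper: write weights and epoch metadata into ckpt_dir."""
    with open(os.path.join(ckpt_dir, "model.json"), "w") as f:
        json.dump(weights, f)
    with open(os.path.join(ckpt_dir, "meta.json"), "w") as f:
        json.dump({"epoch": epoch}, f)


def load_start_epoch(ckpt_dir):
    """Hypothetical helper: return the epoch to resume from (0 if no checkpoint)."""
    meta_path = os.path.join(ckpt_dir, "meta.json")
    if not os.path.exists(meta_path):
        return 0
    with open(meta_path) as f:
        # The saved epoch already finished, so training resumes at the next one.
        return json.load(f).get("epoch", 0) + 1


with tempfile.TemporaryDirectory() as ckpt_dir:
    fresh_start = load_start_epoch(ckpt_dir)      # no checkpoint written yet
    save_checkpoint(ckpt_dir, epoch=3, weights=[0.1, 0.2])
    resumed_start = load_start_epoch(ckpt_dir)    # checkpoint from epoch 3

print(fresh_start, resumed_start)
```

The same "saved epoch + 1" rule is what `start_epoch` implements in the training loop of this section.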

In [None]:
# 08. Define Ray Train loop (with val loss, checkpointing, and Ray-managed metrics)

def train_loop_per_worker(config):
    import tempfile
    # ---------------- Dataset shards -> PyTorch-style iterators ---------------- #
    train_ds = get_dataset_shard("train")
    val_ds   = get_dataset_shard("val")
    train_loader = train_ds.iter_torch_batches(batch_size=512, dtypes=torch.float32)
    val_loader   = val_ds.iter_torch_batches(batch_size=512, dtypes=torch.float32)

    # ---------------- Model / Optimizer ---------------- #
    model = MatrixFactorizationModel(
        num_users=config["num_users"],
        num_items=config["num_items"],
        embedding_dim=config.get("embedding_dim", 64),
    )
    model = prepare_model(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.get("lr", 1e-3))

    # ---------------- Checkpointing setup ---------------- #
    rank = get_context().get_world_rank()
    start_epoch = 0

    # If a checkpoint exists (auto-resume), load it
    ckpt = get_checkpoint()
    if ckpt:
        with ckpt.as_directory() as ckpt_dir:
            model.load_state_dict(
                torch.load(os.path.join(ckpt_dir, "model.pt"), map_location="cpu")
            )
            start_epoch = torch.load(os.path.join(ckpt_dir, "meta.pt")).get("epoch", 0) + 1
        if rank == 0:
            print(f"[Rank {rank}] ✅ Resumed from checkpoint at epoch {start_epoch}")

    # ---------------- Training loop ---------------- #
    for epoch in range(start_epoch, config.get("epochs", 5)):
        # ---- Train ----
        model.train()
        train_losses = []
        for batch in train_loader:
            user = batch["user_idx"].long()
            item = batch["item_idx"].long()
            rating = batch["rating"].float()

            pred = model(user, item)
            loss = F.mse_loss(pred, rating)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())

        avg_train_loss = sum(train_losses) / max(1, len(train_losses))

        # ---- Validate ----
        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch in val_loader:
                user = batch["user_idx"].long()
                item = batch["item_idx"].long()
                rating = batch["rating"].float()

                pred = model(user, item)
                loss = F.mse_loss(pred, rating)
                val_losses.append(loss.item())

        avg_val_loss = sum(val_losses) / max(1, len(val_losses))

        # Console log (optional)
        if rank == 0:
            print(f"[Epoch {epoch}] Train MSE: {avg_train_loss:.4f} | Val MSE: {avg_val_loss:.4f}")

        metrics = {
            "epoch": epoch,
            "train_loss": avg_train_loss,
            "val_loss": avg_val_loss,
        }

        # ---- Save checkpoint & report (rank 0 attaches checkpoint; others report metrics only) ----
        if rank == 0:
            with tempfile.TemporaryDirectory() as tmpdir:
                torch.save(model.state_dict(), os.path.join(tmpdir, "model.pt"))
                torch.save({"epoch": epoch}, os.path.join(tmpdir, "meta.pt"))
                ckpt_out = Checkpoint.from_directory(tmpdir)
                report(metrics, checkpoint=ckpt_out)
        else:
            report(metrics, checkpoint=None)

### 9. Launch distributed training with Ray Train

Now, launch distributed training using `TorchTrainer`, Ray Train’s high-level orchestration interface. Provide it with:

- Your custom `train_loop_per_worker` function
- A `train_config` dictionary that specifies model dimensions, learning rate, and number of epochs
- The sharded `train` and `val` Ray Datasets
- A `ScalingConfig` that sets the number of workers and GPU usage

Also, configure checkpointing and fault tolerance:
- Ray keeps up to 20 checkpoints (`num_to_keep=20`), which covers every epoch of this 20-epoch run
- Ray Train retries up to two times after a failure (`max_failures=2`)

Calling `trainer.fit()` kicks off training across the cluster. If any workers fail or disconnect, Ray restarts them and resumes from the latest checkpoint.
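
To make the retry-and-resume behavior concrete, here is a toy simulation in plain Python. It does not reflect how Ray implements fault tolerance internally; `run_with_retries`, the `fail_at` set, and the in-memory `checkpoint` dict are hypothetical stand-ins for `FailureConfig(max_failures=2)` and checkpoint-based resume:

```python
def run_with_retries(total_epochs, fail_at, max_failures=2):
    """Toy model: restart after a failure and resume from the last saved epoch.

    fail_at is a set of epochs that crash the first time they are attempted.
    Returns (completed epochs in order, number of failures seen).
    """
    checkpoint = {"epoch": -1}   # stand-in for the latest persisted checkpoint
    pending_failures = set(fail_at)
    completed = []
    failures = 0

    while True:
        start_epoch = checkpoint["epoch"] + 1   # resume after the last saved epoch
        try:
            for epoch in range(start_epoch, total_epochs):
                if epoch in pending_failures:
                    pending_failures.discard(epoch)
                    raise RuntimeError(f"simulated worker failure at epoch {epoch}")
                completed.append(epoch)
                checkpoint["epoch"] = epoch     # "report" a checkpoint each epoch
            return completed, failures
        except RuntimeError:
            failures += 1
            if failures > max_failures:
                raise                           # retry budget exhausted


completed, failures = run_with_retries(total_epochs=5, fail_at={2, 4})
print(completed, failures)
```

Because a checkpoint is saved after every epoch, no completed epoch is repeated after a restart; a third failure would exceed the budget and surface the error.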

In [None]:
# 09. Launch distributed training with Ray TorchTrainer

# Define config params (use Ray-derived counts)
train_config = {
    "num_users": NUM_USERS,
    "num_items": NUM_ITEMS,
    "embedding_dim": 64,
    "lr": 1e-3,
    "epochs": 20,
}

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_config,
    scaling_config=ScalingConfig(
        num_workers=8,       # Increase as needed
        use_gpu=True         # Set to False to train on CPU only
    ),
    datasets={"train": train_ds, "val": val_ds},
    run_config=RunConfig(
        name="mf_ray_train",
        storage_path="/mnt/cluster_storage/rec_sys_tutorial/results",
        checkpoint_config=CheckpointConfig(num_to_keep=20),
        failure_config=FailureConfig(max_failures=2)
    )
)

# Run distributed training
result = trainer.fit()

### 10. Plot train and validation loss curves

After training, retrieve the full metrics history directly from **Ray Train’s internal tracking** via `result.metrics_dataframe`.

This DataFrame includes the metrics from every call to `ray.train.report()` (for example, `train_loss` and `val_loss`).  
Use it to visualize model convergence and confirm that the training loop, checkpointing, and reporting worked correctly.

The plotted curves show how the **training** and **validation** MSE losses evolve over epochs, indicating whether the model is learning effectively and when it begins to stabilize.
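
One simple way to make "begins to stabilize" concrete is to find the first epoch where validation loss improves by less than a tolerance. The following is a minimal sketch with made-up loss values; `first_plateau_epoch` and the `tol` threshold are hypothetical and not part of Ray:

```python
def first_plateau_epoch(val_losses, tol=0.01):
    """Return the first epoch whose improvement over the previous epoch is below tol.

    Returns None if the loss improves by at least tol at every epoch.
    """
    for epoch in range(1, len(val_losses)):
        improvement = val_losses[epoch - 1] - val_losses[epoch]
        if improvement < tol:
            return epoch
    return None


# Illustrative values only, not results from this tutorial.
val_losses = [1.50, 1.10, 0.95, 0.90, 0.895, 0.894]

plateau = first_plateau_epoch(val_losses, tol=0.01)
best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__)
print(plateau, best_epoch)
```

With the DataFrame built in this section, the input could be `df["val_loss"].tolist()`, assuming the rows are sorted by epoch.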

In [None]:
# 10. Plot train/val loss curves (from Ray Train results)

# Pull the full metrics history Ray stored for this run
df = result.metrics_dataframe.copy()

# Keep only the columns we need (guard against extra columns)
cols = [c for c in ["epoch", "train_loss", "val_loss"] if c in df.columns]
df = df[cols].dropna()

# If multiple rows per epoch exist, keep the last report per epoch
if "epoch" in df.columns:
    df = df.sort_index().groupby("epoch", as_index=False).last()

# Plot
plt.figure(figsize=(7, 4))
if "train_loss" in df.columns:
    plt.plot(df["epoch"], df["train_loss"], marker="o", label="Train")
if "val_loss" in df.columns:
    plt.plot(df["epoch"], df["val_loss"], marker="o", label="Val")
plt.xlabel("Epoch")
plt.ylabel("MSE Loss")
plt.title("Matrix Factorization - Loss per Epoch")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()