## 08 · Define Ray Train Loop (with Validation, Logging, and Checkpointing)  
Define the `train_loop_per_worker`, the core function Ray executes on each worker. This loop handles everything from dataset loading and model training to validation, logging, and checkpointing.

Each worker receives its own shard of the training and validation sets with `get_dataset_shard`. Use `iter_torch_batches` to stream batches directly into PyTorch.

For each epoch:
- Compute the average training loss across all batches.
- Evaluate the model on the validation set and compute the average validation loss.
- On rank 0 only, save a checkpoint (model weights + epoch metadata) and append the losses to a shared JSON log.

Finally, call `ray.train.report` to expose metrics and checkpoints to the Ray controller. This enables fault tolerance, auto-resume, and metrics tracking with zero additional setup.

In [None]:
# 08. Define Ray Train Loop (with Val Loss + Checkpointing + Logging)

def train_loop_per_worker(config):
    """Train and validate the matrix-factorization model on one Ray Train worker.

    Every worker runs this function. Per epoch it streams this worker's
    "train" shard through the model, evaluates on the "val" shard, and calls
    ``report`` with the averaged MSE losses. Rank 0 additionally writes a
    checkpoint directory (``model.pt`` + ``meta.pt``) and appends the epoch's
    metrics to a JSON log file.

    Args:
        config: Dict with required keys ``num_users`` and ``num_items`` and
            optional keys ``embedding_dim`` (default 64), ``lr`` (default
            1e-3), ``epochs`` (default 5) and ``batch_size`` (default 512).
    """

    def _mean(values):
        # A shard that yields no batches must not crash the worker with a
        # ZeroDivisionError; NaN marks "no data" in the reported metrics.
        return sum(values) / len(values) if values else float("nan")

    # Get dataset shards for this worker
    batch_size = config.get("batch_size", 512)
    train_ds = get_dataset_shard("train")
    val_ds   = get_dataset_shard("val")
    train_loader = train_ds.iter_torch_batches(batch_size=batch_size, dtypes=torch.float32)
    val_loader   = val_ds.iter_torch_batches(batch_size=batch_size, dtypes=torch.float32)

    # Create model and optimizer
    model = MatrixFactorizationModel(
        num_users=config["num_users"],
        num_items=config["num_items"],
        embedding_dim=config.get("embedding_dim", 64)
    )
    model = prepare_model(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.get("lr", 1e-3))

    # Paths for checkpointing and logging
    CKPT_DIR = "/mnt/cluster_storage/rec_sys_tutorial/checkpoints"
    LOG_PATH = "/mnt/cluster_storage/rec_sys_tutorial/epoch_metrics.json"
    os.makedirs(CKPT_DIR, exist_ok=True)
    rank = int(os.environ.get("RANK", "0"))  # Worker rank
    start_epoch = 0

    # Resume from checkpoint if available
    ckpt = get_checkpoint()
    if ckpt:
        with ckpt.as_directory() as ckpt_dir:
            model.load_state_dict(torch.load(os.path.join(ckpt_dir, "model.pt"), map_location="cpu"))
            # meta.pt stores the last completed epoch; continue with the next one.
            start_epoch = torch.load(os.path.join(ckpt_dir, "meta.pt")).get("epoch", 0) + 1
        if rank == 0:
            print(f"[Rank {rank}] ✅ Resumed from checkpoint at epoch {start_epoch}")

    # Clean up log file on first run (only if not resuming)
    if rank == 0 and start_epoch == 0 and os.path.exists(LOG_PATH):
        os.remove(LOG_PATH)

    # ----------------- Training Loop ----------------- #
    for epoch in range(start_epoch, config.get("epochs", 5)):
        model.train()
        train_losses = []

        # Train over each batch
        for batch in train_loader:
            user = batch["user_idx"].long()
            item = batch["item_idx"].long()
            rating = batch["rating"].float()

            pred = model(user, item)
            loss = F.mse_loss(pred, rating)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())

        avg_train_loss = _mean(train_losses)

        # ---------- Validation Pass ----------
        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch in val_loader:
                user = batch["user_idx"].long()
                item = batch["item_idx"].long()
                rating = batch["rating"].float()

                pred = model(user, item)
                loss = F.mse_loss(pred, rating)
                val_losses.append(loss.item())

        avg_val_loss = _mean(val_losses)

        # Log to stdout
        print(f"[Epoch {epoch}] Train MSE: {avg_train_loss:.4f} | Val MSE: {avg_val_loss:.4f}")

        # ---------- Save Checkpoint (Rank 0 Only) ----------
        ckpt_out = None
        if rank == 0:
            # The uuid suffix keeps a re-run of the same epoch from reusing
            # an earlier checkpoint directory.
            out_dir = os.path.join(CKPT_DIR, f"epoch_{epoch}_{uuid.uuid4().hex}")
            os.makedirs(out_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
            torch.save({"epoch": epoch}, os.path.join(out_dir, "meta.pt"))
            ckpt_out = Checkpoint.from_directory(out_dir)

        # ---------- Append Metrics to JSON Log (Rank 0) ----------
        if rank == 0:
            logs = []
            if os.path.exists(LOG_PATH):
                try:
                    with open(LOG_PATH, "r") as f:
                        logs = json.load(f)
                except json.JSONDecodeError:
                    print("⚠️  JSON log unreadable. Starting fresh.")
                    logs = []

            # The log is written before report() hands over the checkpoint, so
            # after a failure it can already contain this epoch (or later
            # ones). Drop those stale entries so every epoch appears once.
            logs = [entry for entry in logs if entry.get("epoch", -1) < epoch]

            logs.append({
                "epoch": epoch,
                "train_loss": avg_train_loss,
                "val_loss": avg_val_loss
            })
            with open(LOG_PATH, "w") as f:
                json.dump(logs, f)

        # ---------- Report to Ray Train ----------
        # Every worker reports metrics; only rank 0 attaches a checkpoint.
        report({
            "epoch": epoch,
            "train_loss": avg_train_loss,
            "val_loss": avg_val_loss
        }, checkpoint=ckpt_out)

### 09 · Launch Distributed Training with Ray Train  
Now, launch distributed training using `TorchTrainer`, Ray Train’s high-level orchestration interface. Provide it with:

- Your custom `train_loop_per_worker` function
- A `train_config` dictionary that specifies model dimensions, learning rate, and number of epochs
- The sharded `train` and `val` Ray Datasets
- A `ScalingConfig` that sets the number of workers and GPU usage

Also, configure checkpointing and fault tolerance:
- Ray keeps the 3 most recent checkpoints
- Failed workers retry up to two times

Calling `trainer.fit()` kicks off training across the cluster. If any workers fail or disconnect, Ray restarts them and resumes from the latest checkpoint.

In [None]:
# 09. Launch Distributed Training with Ray TorchTrainer

# Model dimensions and hyperparameters passed to train_loop_per_worker as `config`.
# NOTE(review): nunique() sizes the embedding tables, which presumably assumes
# user_idx / item_idx are contiguous 0..N-1 — confirm against the encoding cell.
train_config = dict(
    num_users=df["user_idx"].nunique(),
    num_items=df["item_idx"].nunique(),
    embedding_dim=64,
    lr=1e-3,
    epochs=20,
)

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_config,
    # Dataset names must match the get_dataset_shard(...) calls in the train loop.
    datasets={"train": train_ds, "val": val_ds},
    scaling_config=ScalingConfig(
        use_gpu=True,    # Set to False on CPU-only clusters
        num_workers=8,   # Increase as needed
    ),
    run_config=RunConfig(
        name="mf_ray_train",
        storage_path="/mnt/cluster_storage/rec_sys_tutorial/results",
        # Allow up to two failures before the run is abandoned.
        failure_config=FailureConfig(max_failures=2),
        # Retain only three checkpoints.
        checkpoint_config=CheckpointConfig(num_to_keep=3),
    ),
)

# Blocks until training finishes (or the failure budget is exhausted).
result = trainer.fit()

### 10 · Plot Train and Validation Loss Curves  
After training, load the logged epoch metrics from the shared JSON file and plot the train and validation loss curves.

This visualization helps you evaluate model behavior across epochs, whether it’s under-fitting, over-fitting, or converging steadily. You compute both curves using MSE (Mean Squared Error), which is the same loss function used during training.

Plotting these curves also serves as a sanity check to ensure that checkpointing, logging, and training progressed as expected.

In [None]:
# 10. Plot Train/Val Loss Curves

# JSON log that rank 0 appends to after every epoch
LOG_PATH = "/mnt/cluster_storage/rec_sys_tutorial/epoch_metrics.json"

# Read the per-epoch records
with open(LOG_PATH, "r") as f:
    logs = json.load(f)

# NOTE(review): this rebinds `df`, which the launch cell (09) reads as the
# ratings table via df["user_idx"]; re-running that cell after this one would
# fail — consider a distinct name if no later cell relies on this binding.
df = pd.DataFrame(logs)

# Force both loss columns to numeric; unparsable values become NaN
df[["train_loss", "val_loss"]] = df[["train_loss", "val_loss"]].apply(
    pd.to_numeric, errors="coerce"
)

# One line per split, epoch on the x-axis
plt.figure(figsize=(7, 4))
plt.plot(df["epoch"], df["train_loss"], label="Train", marker="o")
plt.plot(df["epoch"], df["val_loss"], label="Val", marker="o")
plt.title("Matrix Factorization - Loss per Epoch")
plt.xlabel("Epoch")
plt.ylabel("MSE Loss")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()