## 10. Ray Train training loop (with teacher forcing)  
This is the heart of Ray Train. Each worker executes this loop independently, while Ray orchestrates everything from checkpointing to failure recovery. The loop uses teacher forcing: it feeds the shifted ground-truth sequence to the decoder, which lets the model learn more quickly than decoding from zeros. It also logs training and validation loss per epoch and saves a checkpoint after each epoch.

In [None]:
# 10. Ray Train train_loop_per_worker with checkpointing, teacher forcing, and clean structure

def train_loop_per_worker(config):
    """Run the per-worker Ray Train loop: teacher-forced training, validation, checkpointing.

    Args:
        config: Dict of hyperparameters. Reads ``d_model``, ``nhead`` and
            ``num_layers`` (model), ``lr`` (Adam learning rate), ``bs``
            (batch size for both loaders) and ``epochs`` (total epoch count).

    Every epoch, each worker reports ``epoch``, ``train_loss`` and ``val_loss``
    (averaged over this worker's own batches) through ``train.report``. Rank 0
    additionally attaches a checkpoint directory containing ``model.pt``,
    ``optim.pt`` and ``meta.pt``. If Ray supplies a checkpoint on start-up,
    training resumes from the epoch after the one stored in ``meta.pt``.
    """
    import tempfile
    import torch
    import torch.nn as nn
    import torch.optim as optim
    # Explicitly load the submodule so that `train.torch` below always resolves.
    import ray.train.torch  # noqa: F401
    from ray import train
    from ray.train import Checkpoint, get_checkpoint, get_context

    torch.manual_seed(0)

    # ─────────────────────────────────────────────────────────────
    # 1) Model (DDP-prepared)
    # ─────────────────────────────────────────────────────────────
    model = TimeSeriesTransformer(
        input_window=INPUT_WINDOW,
        horizon=HORIZON,
        d_model=config["d_model"],
        nhead=config["nhead"],
        num_layers=config["num_layers"],
    )
    model = train.torch.prepare_model(model)

    # ─────────────────────────────────────────────────────────────
    # 2) Optimizer / Loss
    # ─────────────────────────────────────────────────────────────
    optimizer = optim.Adam(model.parameters(), lr=config["lr"])
    loss_fn = nn.SmoothL1Loss()

    # ─────────────────────────────────────────────────────────────
    # 3) Resume from checkpoint (if provided by Ray)
    # ─────────────────────────────────────────────────────────────
    rank = get_context().get_world_rank()
    start_epoch = 0
    ckpt = get_checkpoint()
    if ckpt:
        with ckpt.as_directory() as ckpt_dir:
            # Safe CPU load in case of device mismatch on resume
            state = torch.load(os.path.join(ckpt_dir, "model.pt"), map_location="cpu")

            # A DDP-wrapped model saves its keys with a "module." prefix. Reconcile
            # the prefix so a checkpoint still loads when the wrapping differs
            # between the run that saved it and the run that resumes it.
            prefix = "module."
            is_ddp = isinstance(model, nn.parallel.DistributedDataParallel)
            has_prefix = bool(state) and all(key.startswith(prefix) for key in state)
            if is_ddp and not has_prefix:
                model.module.load_state_dict(state)
            elif has_prefix and not is_ddp:
                model.load_state_dict(
                    {key[len(prefix):]: value for key, value in state.items()}
                )
            else:
                model.load_state_dict(state)

            # The optimizer state is optional in a checkpoint directory.
            opt_state_path = os.path.join(ckpt_dir, "optim.pt")
            if os.path.exists(opt_state_path):
                optimizer.load_state_dict(torch.load(opt_state_path, map_location="cpu"))

            # meta.pt stores the last completed epoch; continue with the next one.
            meta = torch.load(os.path.join(ckpt_dir, "meta.pt"), map_location="cpu")
            start_epoch = int(meta.get("epoch", -1)) + 1
        if rank == 0:
            print(f"[Rank {rank}] ✅ Resumed from checkpoint at epoch {start_epoch}")

    # ─────────────────────────────────────────────────────────────
    # 4) Dataloaders for this worker
    # ─────────────────────────────────────────────────────────────
    train_loader = build_dataloader(
        os.path.join(PARQUET_DIR, "train.parquet"),
        batch_size=config["bs"],
        shuffle=True,
    )
    val_loader = build_dataloader(
        os.path.join(PARQUET_DIR, "val.parquet"),
        batch_size=config["bs"],
        shuffle=False,
    )

    # ─────────────────────────────────────────────────────────────
    # 5) Epoch loop
    # ─────────────────────────────────────────────────────────────
    for epoch in range(start_epoch, config["epochs"]):
        # If the loader is backed by a distributed sampler, tell it the epoch so
        # the shuffle order changes between epochs (otherwise it repeats).
        # NOTE(review): depends on what build_dataloader returns — no-op if the
        # sampler has no set_epoch.
        sampler = getattr(train_loader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)

        # ---- Train ----
        model.train()
        train_loss_sum = 0.0
        for past, future in train_loader:
            optimizer.zero_grad()

            # Teacher forcing: shift future targets to use as decoder input
            future = future.unsqueeze(-1)                                # (B, F, 1)
            start_token = torch.zeros_like(future[:, :1])                # (B, 1, 1)
            decoder_input = torch.cat([start_token, future[:, :-1]], 1)  # (B, F, 1)

            pred = model(past, decoder_input)                            # (B, F)
            loss = loss_fn(pred, future.squeeze(-1))                     # (B, F) vs (B, F)

            loss.backward()
            optimizer.step()
            train_loss_sum += loss.item()

        # max(1, ...) guards against division by zero on an empty loader.
        avg_train_loss = train_loss_sum / max(1, len(train_loader))

        # ---- Validate ----
        model.eval()
        val_loss_sum = 0.0
        with torch.no_grad():
            for past, future in val_loader:
                pred = model(past)                                       # zeros-as-decoder-input path
                loss = loss_fn(pred, future)
                val_loss_sum += loss.item()
        avg_val_loss = val_loss_sum / max(1, len(val_loader))

        metrics = {
            "epoch": epoch,
            "train_loss": avg_train_loss,
            "val_loss": avg_val_loss,
        }

        if rank == 0:
            print(metrics)

        # ─────────────────────────────────────────────────────────────
        # 6) Report + temp checkpoint (rank 0 attaches; others metrics-only)
        # ─────────────────────────────────────────────────────────────
        if rank == 0:
            # Report inside the `with` block: the temporary directory must still
            # exist while the checkpoint is handed to train.report.
            with tempfile.TemporaryDirectory() as tmpdir:
                torch.save(model.state_dict(), os.path.join(tmpdir, "model.pt"))
                torch.save(optimizer.state_dict(), os.path.join(tmpdir, "optim.pt"))
                torch.save({"epoch": epoch}, os.path.join(tmpdir, "meta.pt"))
                ckpt_out = Checkpoint.from_directory(tmpdir)
                train.report(metrics, checkpoint=ckpt_out)
        else:
            train.report(metrics, checkpoint=None)


### 11. Launch training on 8 GPUs  
Construct a `TorchTrainer` and run it. Ray automatically distributes training across 8 GPU workers (each worker runs the training loop on its own copy of the model), prepares the datasets for each worker, and starts training. Also configure checkpointing to retain the top-performing checkpoints and set failure recovery to 3 attempts.

In [None]:
# 11. Launch training

# Configure an 8-worker GPU run of train_loop_per_worker.
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"lr": 1e-3, "bs": 4, "epochs": 20,
                       "d_model": 128, "nhead": 4, "num_layers": 3},
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
    run_config=RunConfig(
        name="nyc_taxi_transformer",
        storage_path=os.path.join(DATA_DIR, "results"),
        checkpoint_config=CheckpointConfig(
            num_to_keep=20,
            # The training loop decides when to checkpoint (each epoch);
            # retained checkpoints are scored by lowest val_loss.
            checkpoint_score_attribute="val_loss",
            checkpoint_score_order="min",
            # No end-of-run checkpoint option is needed: the loop already
            # reports a checkpoint after the final epoch.
        ),
        # Retry the run up to 3 times on failure.
        failure_config=FailureConfig(max_failures=3),
    ),
)

result = trainer.fit()
print("Final metrics:", result.metrics)

# `result.checkpoint` is the most recently reported checkpoint, not the
# best-scoring one, so select the best by val_loss explicitly. The guard keeps
# `best_ckpt` as None when the run produced no checkpoint at all.
best_ckpt = (
    result.get_best_checkpoint(metric="val_loss", mode="min")
    if result.checkpoint is not None
    else None
)

### 12. Plot training and validation loss  
After training, visualize the saved metrics to assess whether the model is over-fitting, under-fitting, or improving steadily. A healthy curve shows decreasing train and validation loss, with convergence over time. This diagnostic is especially useful when comparing different model configurations. Because this tutorial uses only a small amount of data, expect the validation curve to stay mostly flat.

In [None]:
# 12. Plot train/val loss curves (from Ray Train results)

# Copy the per-report metrics history that Ray recorded for this run
df = result.metrics_dataframe.copy()

# Restrict to the columns of interest that are actually present, then drop
# any report rows with missing values
cols = [name for name in ("epoch", "train_loss", "val_loss") if name in df.columns]
df = df[cols].dropna()

# Collapse repeated reports of the same epoch down to the most recent one
if "epoch" in df.columns:
    df = df.sort_index().groupby("epoch", as_index=False).last()

# Draw one line per available loss series
plt.figure(figsize=(7, 4))
for column, label in (("train_loss", "Train"), ("val_loss", "Val")):
    if column in df.columns:
        plt.plot(df["epoch"], df[column], marker="o", label=label)

# Axis labels, title, and layout
plt.xlabel("Epoch")
plt.ylabel("SmoothL1 Loss")
plt.title("TimeSeriesTransformer — Train vs. Val Loss")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()