## 11 · Launch training on 8 GPUs  
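Here you construct a `TorchTrainer` and run it. Ray launches 8 training workers, one per GPU, sets up the PyTorch distributed process group, and runs `train_loop_per_worker` on every worker. Each GPU then trains its own replica of the model on its portion of the data. The run configuration also keeps the 4 checkpoints with the lowest validation loss (`val_loss`) and restarts the run up to 3 times if a failure occurs.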
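To make the data-parallel layout concrete, here is a small pure-Python sketch of how a dataset can be split across the 8 workers. The training loop isn't shown in this section, so two things here are assumptions: that it shards data the way PyTorch's `DistributedSampler` does with `shuffle=False` (the sampler that `ray.train.torch.prepare_data_loader` inserts by default), and that `bs` in `train_loop_config` is a per-worker batch size. `shard_indices` is a hypothetical helper, not a Ray API.

```python
from typing import List


def shard_indices(num_samples: int, world_size: int, rank: int) -> List[int]:
    """Indices that worker `rank` sees in one epoch, without shuffling.

    Follows the scheme of torch.utils.data.DistributedSampler(shuffle=False,
    drop_last=False): pad the index list by wrapping around to the start until
    its length is a multiple of world_size, then take every world_size-th index
    starting at `rank`.
    """
    if num_samples <= 0:
        return []
    per_worker = -(-num_samples // world_size)  # ceil(num_samples / world_size)
    total_size = per_worker * world_size
    indices = list(range(num_samples))
    # Reuse indices from the beginning to pad the list to total_size.
    indices += [indices[i % num_samples] for i in range(total_size - num_samples)]
    return indices[rank:total_size:world_size]


NUM_WORKERS = 8    # ScalingConfig(num_workers=8)
PER_WORKER_BS = 4  # train_loop_config["bs"], assumed to be a per-worker batch size

shards = [shard_indices(100, NUM_WORKERS, rank) for rank in range(NUM_WORKERS)]
print("samples per worker:", [len(s) for s in shards])
# -> samples per worker: [13, 13, 13, 13, 13, 13, 13, 13]
print("global batch per step:", NUM_WORKERS * PER_WORKER_BS)
# -> global batch per step: 32
```

With 100 samples, the index list is padded to 104 by reusing the first 4 indices, so each worker gets 13 samples and a few samples are seen twice per epoch. Under the per-worker `bs` assumption, each optimizer step processes 8 × 4 = 32 samples across the cluster.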

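```python
# 11. Launch training
import os

from ray.train import CheckpointConfig, FailureConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer

# DATA_DIR and train_loop_per_worker are defined in earlier cells.
os.makedirs(os.path.join(DATA_DIR, "results"), exist_ok=True)

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"lr": 1e-3, "bs": 4, "epochs": 20,
                       "d_model": 128, "nhead": 4, "num_layers": 3},
    # One training worker per GPU, 8 workers in total.
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
    run_config=RunConfig(
        name="nyc_taxi_transformer",
        storage_path=os.path.join(DATA_DIR, "results"),
        # Keep the 4 checkpoints with the lowest reported val_loss.
        checkpoint_config=CheckpointConfig(
            num_to_keep=4, checkpoint_frequency=1,
            checkpoint_score_attribute="val_loss", checkpoint_score_order="min"),
        # Restart the run up to 3 times if a worker fails.
        failure_config=FailureConfig(max_failures=3),
    ),
)

result = trainer.fit()
print("Final metrics:", result.metrics)

# result.checkpoint is the most recent checkpoint; select the best one by val_loss instead.
best_ckpt = result.get_best_checkpoint(metric="val_loss", mode="min")
```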
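The checkpoint policy in this cell (`num_to_keep=4`, scored by `val_loss` in `"min"` order) can be pictured with the simplified, pure-Python model below. The helpers `retained_checkpoints` and `best_epoch` are hypothetical, and the per-epoch losses are made-up inputs, not results from this tutorial. It assumes one checkpoint is reported per epoch; Ray's own checkpoint manager may differ in details, for example in how it treats the most recent checkpoint.

```python
from typing import List


def retained_checkpoints(val_losses: List[float], num_to_keep: int = 4) -> List[int]:
    """Epochs (0-based) whose checkpoints a keep-top-k-by-min-val_loss policy retains.

    Simplified model of CheckpointConfig(num_to_keep=4,
    checkpoint_score_attribute="val_loss", checkpoint_score_order="min"),
    assuming one checkpoint is reported per epoch.
    """
    # Rank epochs from lowest to highest val_loss and keep the first num_to_keep.
    ranked = sorted(range(len(val_losses)), key=lambda epoch: val_losses[epoch])
    return sorted(ranked[:num_to_keep])


def best_epoch(val_losses: List[float]) -> int:
    """Epoch with the lowest val_loss."""
    return min(range(len(val_losses)), key=lambda epoch: val_losses[epoch])


# Made-up per-epoch validation losses, for illustration only.
val_losses = [0.90, 0.70, 0.65, 0.66, 0.64, 0.68, 0.63, 0.67]
print("kept epochs:", retained_checkpoints(val_losses))  # -> kept epochs: [2, 3, 4, 6]
print("best epoch:", best_epoch(val_losses))             # -> best epoch: 6
print("latest epoch:", len(val_losses) - 1)              # -> latest epoch: 7
```

In this made-up run the latest epoch (7) is not among the four best, which is why the most recent checkpoint and the best-scoring checkpoint of a run can differ.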

## 12 · Plot training + validation loss  
After training, plot the saved `history.csv` to check whether the model is over-fitting, under-fitting, or improving steadily. In a healthy run, both training and validation loss decrease and then level off, without the validation loss climbing away from the training loss. This diagnostic is especially useful when comparing different model configurations. Because this tutorial trains on only a small amount of data, expect the validation curve to stay mostly flat.

```python
# 12. Plot loss curves
import os

import matplotlib.pyplot as plt
import pandas as pd

hist_path = os.path.join(DATA_DIR, "results", "history.csv")

if os.path.exists(hist_path):
    # names= treats history.csv as headerless, with columns epoch, train_loss, val_loss.
    df_hist = pd.read_csv(hist_path, names=["epoch", "train_loss", "val_loss"])
    plt.figure(figsize=(8, 4))
    plt.plot(df_hist["epoch"], df_hist["train_loss"], label="Train", marker="o")
    plt.plot(df_hist["epoch"], df_hist["val_loss"], label="Val", marker="o")
    plt.xlabel("Epoch")
    plt.ylabel("MSE Loss")
    plt.title("Train vs. Val Loss")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()
else:
    print("No history.csv found. Make sure the training loop writes it.")
```
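Reading the plot can be backed by a rough numeric check. The hypothetical `diagnose` helper below compares the mean validation loss over the last few epochs with the window before it, and compares recent validation loss with recent training loss. The window size and thresholds are arbitrary assumptions, so treat the label as a prompt to look at the plot, not a verdict.

```python
from typing import Sequence


def diagnose(train_loss: Sequence[float], val_loss: Sequence[float],
             window: int = 3, min_rel_improvement: float = 0.01,
             gap_ratio: float = 1.2) -> str:
    """Heuristically label loss curves as 'improving', 'overfitting', or 'plateau'.

    All thresholds are arbitrary choices for illustration.
    """
    train, val = list(train_loss), list(val_loss)
    if len(val) < 2 * window or len(train) < window:
        return "insufficient"  # not enough epochs to compare two windows

    def mean(xs):
        return sum(xs) / len(xs)

    earlier_val = mean(val[-2 * window:-window])
    recent_val = mean(val[-window:])
    recent_train = mean(train[-window:])

    # Validation loss still dropping by more than min_rel_improvement (relative).
    if earlier_val - recent_val > min_rel_improvement * earlier_val:
        return "improving"
    # Validation loss has stalled and sits well above the training loss.
    if recent_val > gap_ratio * recent_train:
        return "overfitting"
    # Both curves flat and close together (possibly under-fitting if loss is high).
    return "plateau"


# Made-up curves, only to exercise each branch.
print(diagnose([1.0, 0.9, 0.8, 0.7, 0.6, 0.5], [1.0, 0.9, 0.8, 0.7, 0.6, 0.5]))    # -> improving
print(diagnose([1.0, 0.8, 0.6, 0.4, 0.3, 0.2], [0.9, 0.85, 0.85, 0.86, 0.88, 0.9]))  # -> overfitting
print(diagnose([0.5] * 6, [0.52] * 6))                                               # -> plateau
```

With the DataFrame from the plotting cell, you could call `diagnose(df_hist["train_loss"].tolist(), df_hist["val_loss"].tolist())`. A "plateau" label only hints at under-fitting when the loss is still high for the task; the helper cannot judge that on its own.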