## 13. Plot training and validation loss curves  

After training completes, visualize the recorded metrics directly from the `Result` object that Ray Train returns. No manual CSV handling is required.  
`result.metrics_dataframe` is a pandas DataFrame with one row per `ray.train.report()` call. It contains every metric reported during training, including the per-epoch loss values (Ray Train keeps the metrics reported by the rank-0 worker).  

The code below extracts the training and validation losses, groups them by epoch, and keeps the most recent report for each epoch.  
Comparing the two curves lets you quickly assess convergence and detect overfitting (for example, when training loss keeps decreasing while validation loss rises).  
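To make the overfitting check concrete, here is a minimal sketch that applies the same keep-the-last-report-per-epoch step and then flags epochs where validation loss has risen while training loss fell for several consecutive epochs. The function names, the `patience` threshold, and the toy data are hypothetical; this is one simple heuristic, not something Ray Train provides.

```python
import pandas as pd


def last_report_per_epoch(df: pd.DataFrame) -> pd.DataFrame:
    """Keep the most recent report for each epoch (same idea as the plotting cell)."""
    return df.sort_index().groupby("epoch", as_index=False).last()


def overfitting_epochs(df: pd.DataFrame, patience: int = 2) -> list[int]:
    """Return epochs where val_loss rose while train_loss fell for at least
    `patience` consecutive epochs. Hypothetical heuristic, not a Ray API."""
    per_epoch = last_report_per_epoch(df).sort_values("epoch")
    train_falling = per_epoch["train_loss"].diff() < 0
    val_rising = per_epoch["val_loss"].diff() > 0
    flagged, streak = [], 0
    for epoch, diverging in zip(per_epoch["epoch"], train_falling & val_rising):
        # Count how many epochs in a row the two curves have moved apart
        streak = streak + 1 if diverging else 0
        if streak >= patience:
            flagged.append(int(epoch))
    return flagged


# Toy stand-in for result.metrics_dataframe; epoch 2 was reported twice
metrics = pd.DataFrame({
    "epoch":      [0,   1,   2,    2,   3,    4,   5],
    "train_loss": [1.0, 0.8, 0.7,  0.6, 0.5,  0.4, 0.35],
    "val_loss":   [1.1, 0.9, 0.85, 0.8, 0.85, 0.9, 0.95],
})
print(overfitting_epochs(metrics))  # → [4, 5]
```

In the notebook you would call `overfitting_epochs(result.metrics_dataframe)`, assuming the run reported `epoch`, `train_loss`, and `val_loss`. Requiring a streak rather than a single uptick avoids flagging ordinary epoch-to-epoch noise in the validation loss.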

Because Ray Train stores the reported metrics alongside the checkpoints, this plot reflects the same values used to select the **best checkpoint** by validation loss, as configured through the `CheckpointConfig` in your `RunConfig`.
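To cross-check the plot against checkpoint selection, you can look up the epoch with the lowest validation loss in the same dataframe. This sketch assumes the run was configured with something like `RunConfig(checkpoint_config=CheckpointConfig(checkpoint_score_attribute="val_loss", checkpoint_score_order="min"))` and reported a checkpoint every epoch; under those assumptions, the epoch it finds should correspond to the top-ranked checkpoint. The helper name `best_epoch_by_val_loss` and the toy data are hypothetical.

```python
import pandas as pd


def best_epoch_by_val_loss(df: pd.DataFrame) -> pd.Series:
    """Return the last-reported row for the epoch with the lowest val_loss.

    Assumes (hypothetically) that checkpoints were scored with
    CheckpointConfig(checkpoint_score_attribute="val_loss",
                     checkpoint_score_order="min").
    """
    per_epoch = df.sort_index().groupby("epoch", as_index=False).last()
    return per_epoch.loc[per_epoch["val_loss"].idxmin()]


# Toy stand-in for result.metrics_dataframe
metrics = pd.DataFrame({
    "epoch": [0, 1, 2, 3],
    "train_loss": [1.0, 0.7, 0.5, 0.4],
    "val_loss": [1.1, 0.8, 0.75, 0.9],
})
best = best_epoch_by_val_loss(metrics)
print(int(best["epoch"]), best["val_loss"])  # → 2 0.75
```

On a real run, `best_epoch_by_val_loss(result.metrics_dataframe)` can be compared with the checkpoints listed in `result.best_checkpoints`, which pairs each retained checkpoint with its reported metrics.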


# 13. Plot training / validation loss curves
import matplotlib.pyplot as plt

# Pull the full metrics history Ray stored for this run
df = result.metrics_dataframe.copy()

# Keep only the columns we need (guard against extra or missing columns)
cols = [c for c in ["epoch", "train_loss", "val_loss"] if c in df.columns]
df = df[cols].dropna()

# If multiple rows per epoch exist, keep the last report per epoch;
# without an "epoch" column, fall back to report order for the x-axis
if "epoch" in df.columns:
    df = df.sort_index().groupby("epoch", as_index=False).last()
    x = df["epoch"]
else:
    x = range(len(df))

# Plot
plt.figure(figsize=(8, 5))
if "train_loss" in df.columns:
    plt.plot(x, df["train_loss"], marker="o", label="Train Loss")
if "val_loss" in df.columns:
    plt.plot(x, df["val_loss"], marker="o", label="Val Loss")
plt.xlabel("Epoch" if "epoch" in df.columns else "Report #")
plt.ylabel("Loss")
plt.title("Train/Val Loss across Epochs")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()