In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import os

# --- Settings --- #
HISTORY_PATH = "results/step_1/tune_full_history.csv"
RESULTS_PATH = "results/step_1/tune_results.csv"
PLOTS_DIR = "results/plots"
TOP_N        = 10 

# --- Load data --- #
full_df = pd.read_csv(HISTORY_PATH)
results_df = pd.read_csv(RESULTS_PATH)

# --- Get column names (reported by train_fn via train.report) --- #
# Expected columns: training_iteration, val_loss, best_val_loss, train_loss,
#                   val_loss_loc, val_loss_str, train_loss_loc, train_loss_str

os.makedirs(PLOTS_DIR, exist_ok=True)

# --- Plot per trial --- #
for rank, (_, row) in enumerate(results_df.sort_values("best_val_loss", ascending=True).head(TOP_N).iterrows(), start=1):
    trial_id = "fn_" + os.path.basename(row["logdir"])
    trial_df = full_df[full_df["trial_id"] == trial_id].sort_values("training_iteration")

    if trial_df.empty:
        print(f"No history found for trial {trial_id}, skipping.")
        continue

    # --- Extract metrics --- #
    epochs          = trial_df["training_iteration"].values
    train_loss      = trial_df["train_loss"].values
    val_loss        = trial_df["val_loss"].values
    train_loss_loc  = trial_df["train_loss_loc"].values
    val_loss_loc    = trial_df["val_loss_loc"].values
    train_loss_str  = trial_df["train_loss_str"].values
    val_loss_str    = trial_df["val_loss_str"].values
    best_val_loss   = trial_df["best_val_loss"].min()
    best_epoch      = trial_df.loc[trial_df["best_val_loss"].idxmin(), "training_iteration"]

    # --- Extract key config values for title --- #
    config_cols = {col.replace("config/", ""): row[col]
                   for col in results_df.columns if col.startswith("config/")}
    title = (f"Rank {rank} | lr={config_cols.get('lr', '?'):.2e} "
             f"| mpnn_layers={int(config_cols.get('mpnn_num_layers', '?'))} "
             f"| attn_layers={int(config_cols.get('attn_num_layers', '?'))} "
             f"| token_dim={int(config_cols.get('token_dim', '?'))} "
             f"| mpnn_dim={int(config_cols.get('mpnn_hidden_dim', '?'))} "
             f"| pool={config_cols.get('pooling_strategy', '?')}")

    # --- Plot --- #
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    fig.suptitle(title, fontsize=9)

    # Total loss
    axes[0].plot(epochs, train_loss, label="train")
    axes[0].plot(epochs, val_loss, label="val")
    axes[0].axvline(best_epoch, color="red", linestyle="--", alpha=0.5, label=f"best @ ep{best_epoch}")
    axes[0].set_title("Total Loss")
    axes[0].legend()

    # Location loss
    axes[1].plot(epochs, train_loss_loc, label="train")
    axes[1].plot(epochs, val_loss_loc, label="val")
    axes[1].axvline(best_epoch, color="red", linestyle="--", alpha=0.5)
    axes[1].set_title("Location Loss")
    axes[1].legend()

    # Strength loss
    axes[2].plot(epochs, train_loss_str, label="train")
    axes[2].plot(epochs, val_loss_str, label="val")
    axes[2].axvline(best_epoch, color="red", linestyle="--", alpha=0.5)
    axes[2].set_title("Strength Loss")
    axes[2].legend()

    for ax in axes:
        ax.set_xlabel("Epoch")
        ax.set_ylabel("MSE")
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f"{PLOTS_DIR}/rank_{rank:02d}_{trial_id}.png", dpi=150, bbox_inches="tight")
    plt.close()


# --- Plot all trials on one plot --- #
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
fig.suptitle("All Trials Overview", fontsize=11)

for rank, (_, row) in enumerate(results_df.sort_values("best_val_loss", ascending=True).iterrows(), start=1):
    trial_id = "fn_" + os.path.basename(row["logdir"])
    trial_df = full_df[full_df["trial_id"] == trial_id].sort_values("training_iteration")
    if trial_df.empty:
        continue
    epochs         = trial_df["training_iteration"].values
    label          = f"Rank {rank}"
    axes[0, 0].plot(epochs, trial_df["train_loss"].values,     label=label)
    axes[0, 1].plot(epochs, trial_df["train_loss_loc"].values, label=label)
    axes[0, 2].plot(epochs, trial_df["train_loss_str"].values, label=label)
    axes[1, 0].plot(epochs, trial_df["val_loss"].values,       label=label)
    axes[1, 1].plot(epochs, trial_df["val_loss_loc"].values,   label=label)
    axes[1, 2].plot(epochs, trial_df["val_loss_str"].values,   label=label)

# --- Titles --- #
axes[0, 0].set_title("Train Total Loss")
axes[0, 1].set_title("Train Location Loss")
axes[0, 2].set_title("Train Strength Loss")
axes[1, 0].set_title("Val Total Loss")
axes[1, 1].set_title("Val Location Loss")
axes[1, 2].set_title("Val Strength Loss")

for ax in axes.flat:
    ax.set_xlabel("Epoch")
    ax.set_ylabel("MSE")
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=7)

plt.tight_layout()
plt.savefig(f"{PLOTS_DIR}/all_trials_overview.png", dpi=150, bbox_inches="tight")
plt.close()
print(f"Overview plot saved to {PLOTS_DIR}/all_trials_overview.png")

print(f"Rank {rank:2d} | best_val_loss: {best_val_loss:.6f} @ ep{best_epoch} "
        f"| epochs after best: {len(epochs) - best_epoch} "
        f"| trial: {trial_id}")

print(f"\nAll plots saved to {PLOTS_DIR}/")

# --- Summary table of top configs --- #
print("\n--- Top 10 Configs ---")
summary_cols = (["train_loss","val_loss", "best_val_loss"] +
                [col for col in results_df.columns if col.startswith("config/")
                 and col.replace("config/", "") in
                 ["lr", "mpnn_num_layers", "attn_num_layers", "attn_num_heads",
                  "token_dim", "mpnn_hidden_dim", "pooling_strategy"]])
print(results_df.sort_values("best_val_loss", ascending=True).head(TOP_N)[summary_cols].to_string(index=False))


  plt.tight_layout()


Overview plot saved to results/plots/all_trials_overview.png
Rank 80 | best_val_loss: 0.008861 @ ep86 | epochs after best: -61 | trial: fn_5520cc37

All plots saved to results/plots/

--- Top 10 Configs ---
 train_loss  val_loss  best_val_loss  config/lr  config/mpnn_hidden_dim  config/mpnn_num_layers  config/attn_num_heads  config/attn_num_layers  config/token_dim config/pooling_strategy
   0.004089  0.009367       0.008326   0.000989                      64                       4                      8                       6                64            mean_pooling
   0.005043  0.009540       0.008418   0.001347                      64                       4                      8                       6                64            mean_pooling
   0.004403  0.010603       0.008557   0.001125                      64                       4                      8                       6                64            mean_pooling
   0.003109  0.009435       0.008591   0.001290      

In [None]:
import pandas as pd
import os

full_df = pd.read_csv("results/step_1/tune_full_history.csv")
results_df = pd.read_csv("results/step_1/tune_results.csv")

# --- Check what trial_ids look like in each file --- #
print("=== trial_ids in full_history ===")
print(full_df["trial_id"].unique())

print("\n=== logdir basenames in results ===")
print(results_df["logdir"].apply(lambda x: os.path.basename(x)).unique())

print("\n=== raw logdir values ===")
print(results_df["logdir"].unique())

=== trial_ids in full_history ===
<ArrowStringArray>
['fn_51912113', 'fn_c4fac421', 'fn_f2826e90']
Length: 3, dtype: str

=== logdir basenames in results ===
<ArrowStringArray>
['51912113', 'c4fac421', 'f2826e90']
Length: 3, dtype: str

=== raw logdir values ===
<ArrowStringArray>
['51912113', 'c4fac421', 'f2826e90']
Length: 3, dtype: str
