Rank and summarize the training performance of the MLP approach across a grid of hyperparameters.

In [1]:
from pathlib import Path
import yaml

import numpy as np
import pandas as pd

In [12]:
# ------------------------------------------------------------------
# Configuration
# ------------------------------------------------------------------
# Directory that holds one sub-folder per run id.
ROOT = Path("multirun") / "2026-02-08" / "09-28-33"
RUN_IDS = range(252)

# Output column name -> key looked up in the parsed overrides.
# Both sides are currently identical, so the mapping is built from one tuple.
HP_KEYS = {
    key: key
    for key in (
        "batch_size",
        "physical_head.hidden_layers_size",
        "mdn_head.hidden_layers_size",
        "trainer_module.learning_rate",
    )
}

# File locations, relative to each run directory.
METRICS_PATH = "lightning_logs/csv_logs/version_0/metrics.csv"
OVERRIDES_PATH = ".hydra/overrides.yaml"

# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def parse_overrides_yaml(path):
    """
    Parse Hydra overrides.yaml into a flat dict.

    The file is expected to contain a YAML list of ``key=value`` strings,
    optionally prefixed with ``+``.

    Args:
        path: Location of the overrides file.

    Returns:
        dict mapping each override key (leading ``+`` removed) to its value
        parsed as YAML (lists, bools, numbers, strings). Entries without an
        ``=`` are skipped. An empty file yields an empty dict.
    """
    import re

    # PyYAML implements the YAML 1.1 float pattern, which needs a decimal
    # point and a signed exponent; values such as "1e-05" or "1.5e5" would
    # otherwise come back as strings and end up mixed with real floats.
    sci_notation = re.compile(r"[-+]?(?:\d+\.?\d*|\.\d+)[eE][-+]?\d+")

    with open(path, "r", encoding="utf-8") as f:
        overrides = yaml.safe_load(f)

    params = {}
    # An empty file loads as None; treat it as "no overrides".
    for item in overrides or []:
        item = item.lstrip("+")  # remove leading +
        if "=" not in item:
            continue
        k, v = item.split("=", 1)
        if sci_notation.fullmatch(v.strip()):
            # Unquoted scientific notation: always a float.
            params[k] = float(v)
        else:
            params[k] = yaml.safe_load(v)  # parse lists, bools, numbers

    return params


def compute_val_overfit(metrics_df):
    """
    Summarize the "val" column of a metrics table.

    NaN entries are dropped first and the remaining values are renumbered
    from 0, so positions and index labels coincide.

    Returns:
        (min_val, max_val_after_min) where
        min_val            -- smallest remaining "val" value
        max_val_after_min  -- largest "val" value strictly after the first
                              occurrence of the minimum; equals min_val when
                              the minimum is the last value
        Both are NaN when no non-NaN "val" value exists.
    """
    losses = metrics_df["val"].dropna().reset_index(drop=True)

    # Nothing logged: report NaN for both quantities.
    if len(losses) == 0:
        return np.nan, np.nan

    best = losses.min()
    best_pos = losses.idxmin()

    # Values recorded after the minimum (the minimum itself is excluded).
    tail = losses.iloc[best_pos + 1 :]
    if tail.empty:
        return best, best

    return best, tail.max()


# ------------------------------------------------------------------
# Collect per-run results
# ------------------------------------------------------------------
# One dict per run that has both a metrics file and an overrides file.
rows = []

for run_id in RUN_IDS:
    run_dir = ROOT / str(run_id)
    metrics_file = run_dir / METRICS_PATH
    overrides_file = run_dir / OVERRIDES_PATH

    # Runs missing either file are left out of the table.
    if not (metrics_file.exists() and overrides_file.exists()):
        continue

    # ---- metrics: best training loss and validation summary
    metrics = pd.read_csv(metrics_file)
    min_loss_epoch = metrics["loss_epoch"].min()
    min_val, max_val_after_min = compute_val_overfit(metrics)

    # ---- hyperparameters: keep only the keys listed in HP_KEYS
    overrides = parse_overrides_yaml(overrides_file)

    row = {
        "run_id": run_id,
        "min_loss_epoch": min_loss_epoch,
        "min_val": min_val,
        "max_val_after_min": max_val_after_min,
    }
    for out_key, in_key in HP_KEYS.items():
        # Hyperparameters absent from the overrides are recorded as None.
        row[out_key] = overrides.get(in_key)
    rows.append(row)

df_runs = pd.DataFrame(rows)


In [13]:
# Runs ordered by ascending best validation value, renumbered from 0.
ranking = df_runs.sort_values("min_val", ignore_index=True)

ranking.head(10)


Unnamed: 0,run_id,min_loss_epoch,min_val,max_val_after_min,batch_size,physical_head.hidden_layers_size,mdn_head.hidden_layers_size,trainer_module.learning_rate
0,15,15.183027,15.256106,15.263334,32000,"[128, 64]","[512, 512, 512]",0.0001
1,12,15.189435,15.256913,15.260602,32000,"[128, 64]","[512, 512]",0.0001
2,33,15.192606,15.256991,15.259939,64000,"[128, 64]","[512, 512, 512]",0.0001
3,0,15.190053,15.25702,15.260676,32000,[128],"[512, 512]",0.0001
4,18,15.192271,15.257034,15.261193,64000,[128],"[512, 512]",0.0001
5,9,15.186585,15.257055,15.262091,32000,[64],"[512, 512, 512]",0.0001
6,51,15.195846,15.257409,15.260657,128000,"[128, 64]","[512, 512, 512]",0.0001
7,3,15.18583,15.257415,15.263082,32000,[128],"[512, 512, 512]",0.0001
8,6,15.190191,15.257603,15.262123,32000,[64],"[512, 512]",0.0001
9,30,15.193144,15.257628,15.259387,64000,"[128, 64]","[512, 512]",0.0001


In [14]:
# Lists are unhashable and cannot serve as group keys, so list-valued
# entries are replaced by their string form; other values are kept as-is.
df_runs["physical_head.hidden_layers_size"] = df_runs["physical_head.hidden_layers_size"].map(
    lambda value: value if not isinstance(value, list) else str(value)
)

In [15]:
# Lists are unhashable and cannot serve as group keys, so list-valued
# entries are replaced by their string form; other values are kept as-is.
df_runs["mdn_head.hidden_layers_size"] = df_runs["mdn_head.hidden_layers_size"].map(
    lambda value: value if not isinstance(value, list) else str(value)
)

In [16]:
metrics_of_interest = ["min_loss_epoch", "min_val", "max_val_after_min"]

# For each hyperparameter: mean and variance of every metric per
# hyperparameter value, ordered by ascending mean of min_val.
summaries = {
    hp: (
        df_runs.groupby(hp)[metrics_of_interest]
        .agg(["mean", "var"])
        .sort_values(("min_val", "mean"))
    )
    for hp in HP_KEYS
}


In [19]:
# Per-batch-size mean/variance of each metric, sorted by mean min_val.
summaries["batch_size"]

Unnamed: 0_level_0,min_loss_epoch,min_loss_epoch,min_val,min_val,max_val_after_min,max_val_after_min
Unnamed: 0_level_1,mean,var,mean,var,mean,var
batch_size,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
32000,15.207213,0.000366,15.262178,3.2e-05,15.264053,1.8e-05
64000,15.211994,0.000375,15.263812,4.7e-05,15.264901,3.5e-05
128000,15.217061,0.000508,15.266706,9.6e-05,15.267831,7.9e-05


In [20]:
# Mean/variance of each metric per physical-head layer configuration, sorted by mean min_val.
summaries["physical_head.hidden_layers_size"]

Unnamed: 0_level_0,min_loss_epoch,min_loss_epoch,min_val,min_val,max_val_after_min,max_val_after_min
Unnamed: 0_level_1,mean,var,mean,var,mean,var
physical_head.hidden_layers_size,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
"[128, 64]",15.212138,0.000423,15.263957,6.3e-05,15.265282,4.8e-05
[128],15.211721,0.000443,15.264251,6e-05,15.265709,4.4e-05
[64],15.212409,0.000434,15.264489,6.2e-05,15.265793,4.7e-05


In [21]:
# Mean/variance of each metric per MDN-head layer configuration, sorted by mean min_val.
summaries["mdn_head.hidden_layers_size"]

Unnamed: 0_level_0,min_loss_epoch,min_loss_epoch,min_val,min_val,max_val_after_min,max_val_after_min
Unnamed: 0_level_1,mean,var,mean,var,mean,var
mdn_head.hidden_layers_size,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
"[512, 512, 512]",15.210461,0.000416,15.263787,5.3e-05,15.26538,3.7e-05
"[512, 512]",15.213718,0.000428,15.264677,6.8e-05,15.26581,5.5e-05


In [22]:
# Mean/variance of each metric per learning rate, sorted by mean min_val.
summaries["trainer_module.learning_rate"]

Unnamed: 0_level_0,min_loss_epoch,min_loss_epoch,min_val,min_val,max_val_after_min,max_val_after_min
Unnamed: 0_level_1,mean,var,mean,var,mean,var
trainer_module.learning_rate,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
0.0001,15.190741,1.2e-05,15.257653,5.353046e-07,15.261308,1.276282e-06
1e-05,15.20734,2.4e-05,15.260829,9.763872e-07,15.261262,7.004104e-07
1e-06,15.238187,3.7e-05,15.274214,2.0441e-05,15.274214,2.0441e-05
