In [66]:
import wandb
import re
from pprint import pprint
import pandas as pd

# W&B project ("<entity>/<project>") whose runs are summarised in this notebook.
WANDB_PROJECT_PATH = "rom42pla_team/noisy_eeg"

# Public-API client; `experiments` is the collection of runs logged to the project.
api = wandb.Api()
experiments = api.runs(WANDB_PROJECT_PATH)

In [112]:
def _best_val_row(df_experiment, run):
    """Return the validation metrics of `run` at the row with the lowest `cls_loss`.

    Args:
        df_experiment: history dataframe of one W&B experiment; the columns of
            interest are named ``"<run>-<metric>/val"``.
        run: run identifier as it appears in the column prefix (e.g. ``"run_3"``).

    Returns:
        A Series indexed by bare metric name (``"cls_loss"``, ...), or ``None``
        when the run has no complete validation row or no ``cls_loss`` column.
    """
    # escape the run id so it is always matched literally inside the pattern
    pattern = re.compile(rf"{re.escape(run)}-.*/val")
    metrics_val = sorted(col for col in df_experiment.columns if pattern.fullmatch(col))
    # keep only the rows where every validation metric of this run was logged
    df_run = df_experiment[metrics_val].dropna().reset_index(drop=True)
    # "run_3-cls_loss/val" -> "cls_loss"
    df_run = df_run.rename(
        columns={col: col.split("-")[-1].split("/")[0] for col in df_run.columns})
    if df_run.empty or "cls_loss" not in df_run.columns:
        # nothing to select from: skip the run instead of crashing the whole table
        return None
    # idxmin() returns an index *label*, so look it up with .loc
    return df_run.loc[df_run["cls_loss"].idxmin()]


def _summarise_experiment(experiment):
    """Build one table row ("mean ± std" over runs, plus metadata) for an experiment.

    The experiment name must have exactly six "_"-separated fields:
    date, hour, dataset, validation, signal type and model.

    Raises:
        ValueError: if the experiment name does not have six fields.
    """
    name_parts = experiment.name.split("_")
    if len(name_parts) != 6:
        raise ValueError(
            f"cannot parse experiment name {experiment.name!r}: expected 6 "
            f"'_'-separated fields, got {len(name_parts)}")
    _date, _hour, dataset, validation, signal_type, model = name_parts
    # NOTE(review): wandb's history() may return a sampled subset of the logged
    # steps, in which case the true best row could be missed — confirm.
    df_experiment = experiment.history()
    # sorted only to make the iteration order deterministic
    runs = sorted({col.split("-")[0] for col in df_experiment.columns
                   if re.fullmatch(r"run_[0-9]+-.*", col)})
    best_rows = [row
                 for row in (_best_val_row(df_experiment, run) for run in runs)
                 if row is not None]
    df_rows = pd.DataFrame(best_rows)
    # aggregate the best row of every run; std is NaN when there is a single run
    rows_mean, rows_std = df_rows.mean(), df_rows.std()
    summary = {
        metric: f"{rows_mean[metric]:.5f} ± {rows_std[metric]:.5f}"
        for metric in rows_mean.index
    }
    summary.update({
        "dataset": dataset,
        "validation": validation,
        "signal_type": "<100Hz" if signal_type == "eeg" else ">100Hz",
        "model": model,
    })
    return pd.Series(summary, dtype="object")


# one summary row per experiment
table_rows = [_summarise_experiment(experiment) for experiment in experiments]
# merge each experiment's series; metrics an experiment did not log become NaN
table = pd.DataFrame(table_rows, dtype="object")
# sort by the metadata columns and move them to the front
first_cols = ["dataset", "signal_type", "validation", "model"]
table = table.sort_values(first_cols)
table = table[first_cols +
              [col for col in table.columns if col not in first_cols]]

In [113]:
# Render the summary table as LaTeX source; missing metrics are shown as "-".
# NOTE(review): headers such as "signal_type" contain underscores, which LaTeX
# treats as special characters — consider escape=True if pasted verbatim.
latex_table = table.to_latex(na_rep="-", index=False)
print(latex_table)

\begin{tabular}{llllllllll}
\toprule
dataset & signal_type & validation & model & cls_acc & cls_f1 & cls_loss & ids_acc & ids_f1 & ids_loss \\
\midrule
deap & <100Hz & kfold & dino & 0.99880 ± 0.00059 & 0.99897 ± 0.00050 & 0.00433 ± 0.00217 & 0.99975 ± 0.00034 & 0.99975 ± 0.00034 & 0.00138 ± 0.00171 \\
deap & <100Hz & kfold & linear & 0.66379 ± 0.00744 & 0.74116 ± 0.01020 & 0.61167 ± 0.00419 & 0.99649 ± 0.00112 & 0.99649 ± 0.00112 & 0.03828 ± 0.00613 \\
deap & <100Hz & kfold & mlp & 0.83880 ± 0.00962 & 0.86416 ± 0.00758 & 0.36379 ± 0.01643 & 0.99801 ± 0.00116 & 0.99801 ± 0.00116 & 0.01041 ± 0.00600 \\
deap & <100Hz & loso & dino & 0.55558 ± 0.05777 & 0.61426 ± 0.09481 & 1.10795 ± 0.51627 & - & - & - \\
deap & <100Hz & loso & mlp & 0.57012 ± 0.05748 & 0.63720 ± 0.09930 & 0.72073 ± 0.06334 & - & - & - \\
deap & >100Hz & kfold & dino & 0.99890 ± 0.00051 & 0.99906 ± 0.00043 & 0.00438 ± 0.00270 & 0.99975 ± 0.00030 & 0.99975 ± 0.00030 & 0.00136 ± 0.00153 \\
deap & >100Hz & kfold & linear & 0