In [57]:
import wandb
import re
from pprint import pprint
import pandas as pd

# Log in to Weights & Biases before creating the API client.
wandb.login()
# Client object used to query the W&B backend.
api = wandb.Api()
# Runs stored under the "rom42pla_team/noisy_eeg" path; the cells below iterate
# over them and treat each one as a single experiment.
experiments = api.runs("rom42pla_team/noisy_eeg")

In [62]:
# Build one summary row per experiment. For every "run_<k>" logged inside an
# experiment, keep the validation step with the highest "cls_acc"; the best
# steps of all runs are then aggregated as "mean ± std" (in percent).
table_rows = []
for experiment in experiments:
    # the name is expected to hold exactly six "_"-separated fields:
    # <date>_<hour>_<dataset>_<validation>_<signal_type>_<model>
    name_parts = experiment.name.split("_")
    if len(name_parts) != 6:
        raise ValueError(
            f"unexpected experiment name {experiment.name!r}: expected 6 "
            f"'_'-separated fields, got {len(name_parts)}"
        )
    _date, _hour, dataset, validation, signal_type, model = name_parts

    # NOTE(review): history() may return only a sampled subset of the logged
    # steps by default - confirm that the best step cannot be missed.
    df_experiment = experiment.history()
    # run prefixes ("run_0", "run_1", ...); sorted so that the aggregation
    # order (and thus floating-point rounding) is reproducible
    runs = sorted(
        {
            col.split("-")[0]
            for col in df_experiment.columns
            if re.fullmatch(r"run_[0-9]+-.*", col)
        }
    )

    best_rows = []
    for run in runs:
        # validation metrics of this run, i.e. columns named "<run>-<metric>/val"
        metrics_val = sorted(
            col for col in df_experiment.columns if re.fullmatch(f"{run}-.*/val", col)
        )
        # keep only the steps where every validation metric was logged
        df_run = df_experiment[metrics_val].dropna().reset_index(drop=True)
        # "run_0-cls_acc/val" -> "cls_acc"
        df_run = df_run.rename(
            columns={col: col.split("-")[-1].split("/")[0] for col in df_run}
        )
        # drop columns that end with "loss"
        df_run = df_run.loc[:, ~df_run.columns.str.endswith("loss")]
        # Missing values can show up as the literal string "NaN". Coerce every
        # column to a numeric dtype so they become real NaNs; the previous
        # replace("NaN", None) behaved differently across pandas versions and
        # left object-dtype columns behind.
        df_run = df_run.apply(pd.to_numeric, errors="coerce")
        # a run without any usable cls accuracy cannot be ranked: skip it
        # instead of crashing the whole aggregation
        if "cls_acc" not in df_run or df_run["cls_acc"].isna().all():
            print(
                f"skipping {experiment.name}/{run}: no usable 'cls_acc' validation values"
            )
            continue
        # idxmax() returns an index label, hence .loc rather than .iloc
        best_rows.append(df_run.loc[df_run["cls_acc"].idxmax()])

    df_rows = pd.DataFrame(best_rows)
    rows_mean, rows_std = df_rows.mean(), df_rows.std()

    # format every metric as "mean ± std" in percent
    record = {}
    for key, mean in rows_mean.items():
        std = rows_std.get(key, float("nan"))
        if pd.isna(mean):
            # metric missing in every run: keep a real NaN so that the LaTeX
            # export can substitute its na_rep placeholder
            record[key] = float("nan")
        elif pd.isna(std):
            # a single run has no sample standard deviation
            record[key] = f"{mean * 100:.3f}"
        else:
            record[key] = f"{mean * 100:.3f} ± {std*100:.3f}"

    # descriptive fields parsed from the experiment name
    record["dataset"] = dataset
    record["validation"] = validation
    record["signal_type"] = (
        "$<$\\SI{100}{\\hertz}" if signal_type == "eeg" else "$>$\\SI{100}{\\hertz}"
    )
    record["model"] = model
    table_rows.append(pd.Series(record, dtype="object"))
# Descriptive columns: they drive the row order and are placed first.
first_cols = ["dataset", "model", "validation", "signal_type"]
# Stack the per-experiment series, sort the rows, move the descriptive columns
# to the front and finally switch to the headers used in the LaTeX tables.
# NOTE(review): columns that are not listed in the rename mapping keep their
# raw names (underscores included) - confirm this is fine for the LaTeX output.
table = (
    pd.DataFrame(table_rows, dtype="object")
    .sort_values(first_cols)
    .pipe(
        lambda frame: frame[
            first_cols + [col for col in frame.columns if col not in first_cols]
        ]
    )
    .rename(
        columns={
            "dataset": "Dataset",
            "signal_type": "Frequencies",
            "validation": "Validation",
            "model": "Model",
            "cls_acc": "cls Accuracy (\\%) $\\uparrow$",
            "cls_f1": "cls $F_1$ (\\%) $\\uparrow$",
            "ids_acc": "ids Accuracy (\\%) $\\uparrow$",
            "ids_f1": "ids $F_1$ (\\%) $\\uparrow$",
        }
    )
)

In [65]:
# Print one LaTeX table per dataset, in order of first appearance in `table`.
for dataset in table["Dataset"].unique():
    # rows of this dataset only; the now-constant "Dataset" column is removed
    dataset_table = table.loc[table["Dataset"] == dataset].drop(columns=["Dataset"])
    print(f"TABLE FOR {dataset}:")
    dataset_table = dataset_table.sort_values(by=["Model", "Validation", "Frequencies"])
    # missing metrics are rendered as "-"
    print(
        dataset_table.to_latex(
            na_rep="-",
            index=False,
            multirow=True,
            multicolumn=True,
        )
    )

TABLE FOR deap:
\begin{tabular}{lllllllllllllll}
\toprule
Model & Validation & Frequencies & cls Accuracy (\%) $\uparrow$ & cls $F_1$ (\%) $\uparrow$ & ids Accuracy (\%) $\uparrow$ & ids $F_1$ (\%) $\uparrow$ & cls_label=arousal_acc & cls_label=arousal_f1 & cls_label=dominance_acc & cls_label=dominance_f1 & cls_label=liking_acc & cls_label=liking_f1 & cls_label=valence_acc & cls_label=valence_f1 \\
\midrule
dino & kfold & $<$\SI{100}{\hertz} & 99.894 ± 0.042 & 99.908 ± 0.035 & 99.971 ± 0.033 & 99.971 ± 0.033 & - & - & - & - & - & - & - & - \\
dino & kfold & $>$\SI{100}{\hertz} & 99.898 ± 0.043 & 99.912 ± 0.036 & 99.967 ± 0.036 & 99.967 ± 0.036 & - & - & - & - & - & - & - & - \\
dino & loso & $<$\SI{100}{\hertz} & 56.626 ± 5.341 & 61.929 ± 11.087 & - & - & - & - & - & - & - & - & - & - \\
dino & loso & $>$\SI{100}{\hertz} & 56.120 ± 4.671 & 61.753 ± 9.120 & - & - & - & - & - & - & - & - & - & - \\
edpnet & kfold & $<$\SI{100}{\hertz} & 76.947 ± 0.831 & 81.221 ± 0.660 & 99.830 ± 0.082 & 