In [88]:
import wandb
import re
from pprint import pprint
import pandas as pd

wandb.login()
api = wandb.Api()
experiments = api.runs("rom42pla_team/noisy_eeg")

In [89]:
table_rows = []
for experiment in experiments:
    date, hour, dataset, validation, signal_type, model = experiment.name.split(
        "_")
    df_experiment = experiment.history()
    runs = {col.split("-")[0] for col in df_experiment.columns if re.fullmatch(r"run_[0-9]+-.*", col)}
    avg_rows = []
    for run in runs:
        # prints all columns with the metrics of the run
        metrics = {col for col in df_experiment.columns if re.fullmatch(f"{run}-.*/.*", col)}
        # get the results for the run
        metrics_val = {col for col in df_experiment.columns if re.fullmatch(f"{run}-.*/val", col)}
        # gets the subset of dataframe for the run
        df_run = df_experiment[sorted(metrics_val)].dropna().reset_index(drop=True)
        df_run = df_run.rename(columns={col: col.split("-")[-1].split("/")[0] for col in df_run})
        # drop columns that end with "loss"
        df_run = df_run.loc[:, ~df_run.columns.str.endswith("loss")]
        # gets the best run based on the cls accuracy
        df_run = df_run.replace("NaN", None)
        best_row_i = df_run["cls_acc"].idxmax()
        best_row = df_run.iloc[best_row_i]
        avg_rows.append(best_row)
    df_rows = pd.DataFrame(avg_rows)
    rows_mean, rows_std = df_rows.mean(), df_rows.std()
    models_renamed = {
        "linear": "Linear",
        "mlp": "MLP",
        "eegnet": "EEGNet\cite{eegnet}",
        "edpnet": "EDPNet\cite{edpnet}",
        "dino": "DINOv2\cite{dinov2}",
        "sateer": "SATEER\cite{sateer}",
    }
    validation_renamed = {
        "kfold": "$k$-fold",
        "loso": "LOSO"
    }
    for df in [rows_mean, rows_std]:
        df["dataset"] = dataset
        df["validation"] = validation_renamed[validation]
        df["signal_type"] = (
            "$<$\\SI{100}{\\hertz}" if signal_type == "eeg" else "$>$\\SI{100}{\\hertz}"
        )
        df["model"] = models_renamed[model]
    rows = pd.DataFrame({
        key: f"{mean * 100:.1f} ± {std*100:.2f}" if isinstance(mean, float) else mean
        for key, mean, std in zip(rows_mean.index, rows_mean, rows_std)
    }, index=[0]).iloc[0]
    table_rows.append(rows)
# merge each run's series
table = pd.DataFrame(table_rows, dtype="object")
# reorder the columns
first_cols = [
    "dataset",
    "model",
    "validation",
    "signal_type",
]
table = table.sort_values(first_cols)
table = table[first_cols +
              [col for col in table.columns if col not in first_cols]]

In [90]:
def print_df_for_latex(df, title=None):
    latex_string = df.to_latex(na_rep="-", index=False)
    replacers = {
        "bottomrule": "botrule",
        "midrule": "hline",
    }
    for k, v in replacers.items():
        latex_string = latex_string.replace(k, v)
    if title:
        print(title)
    print(latex_string)
    return latex_string

In [91]:
for dataset in table["dataset"].unique():
    dataset_table = table[table["dataset"] == dataset]
    dataset_table = dataset_table.drop(columns=["dataset"])
    common_columns_raw = ["model", "validation", "signal_type"]
    dataset_table = dataset_table.sort_values(by=common_columns_raw)
    common_columns_renamed = [
        "Validation",
        "Model",
        "Frequencies",
    ]
    dataset_table = dataset_table.rename(
        columns={
            k: v
            for k, v in zip(common_columns_raw, common_columns_renamed)
        }
    )

    # general table
    agg_table_columns_raw = ["cls_acc", "cls_f1", "ids_acc", "ids_f1"]
    agg_table = dataset_table[common_columns_renamed + agg_table_columns_raw]
    agg_table_columns_renamed = [
        "cls Accuracy (\\%) $\\uparrow$",
        "cls $F_1$ (\\%) $\\uparrow$",
        "ids Accuracy (\\%) $\\uparrow$",
        "ids $F_1$ (\\%) $\\uparrow$",
    ]
    agg_table = agg_table.rename(
        columns={
            k: v
            for k, v in zip(agg_table_columns_raw, agg_table_columns_renamed)
        }
    )
    # print_df_for_latex(agg_table, title=f"AGGREGATED TABLE 1 FOR {dataset}:")
    # now remove the frequencies column
    agg_table = agg_table.pivot(index=["Model", "Validation"], columns="Frequencies")
    agg_table.columns = [f"{col[0]} {col[1]}" for col in agg_table.columns]
    agg_table = agg_table.reset_index()
    print_df_for_latex(agg_table, title=f"AGGREGATED TABLE FOR {dataset}:")

    # single labels
    labels = sorted(set(re.findall(r"cls_label=([a-zA-Z]+)_", " ".join(dataset_table.columns))))
    labels_table_columns_raw = sorted([f"cls_label={label}_{metric}" for metric in ["acc", "f1"] for label in labels])
    assert all([label in dataset_table.columns for label in labels_table_columns_raw])
    labels_table = dataset_table[common_columns_renamed + labels_table_columns_raw]
    metrics_renamed = {"acc": "acc.", "f1": "$F_1$"}
    labels_table_columns_renamed = []
    for column in labels_table_columns_raw:
        _, label, metric = column.split("_")
        label = label.split("=")[-1]
        labels_table_columns_renamed.append(f"{label.capitalize()} {metrics_renamed[metric]}")
    labels_table = labels_table.rename(
        columns={k: v for k, v in zip(labels_table_columns_raw, labels_table_columns_renamed)}
    )
    frequencies = sorted(dataset_table["Frequencies"].unique())
    # since there are too many columns, we split the table into multiple tables
    for metric in sorted(metrics_renamed.values()):
        # labels_subtable = labels_table[common_columns_renamed]
        columns_per_label = [column for column in labels_table.columns if metric in column]        
        labels_subtable = labels_table[common_columns_renamed + columns_per_label]
        # now remove the frequencies column
        labels_subtable = labels_subtable.pivot(index=["Model", "Validation"], columns="Frequencies")
        labels_subtable.columns = [f"{col[0]} {col[1]}" for col in labels_subtable.columns]
        labels_subtable = labels_subtable.reset_index()
        print_df_for_latex(labels_subtable, f"INDIVIDUAL LABELS TABLES FOR DATASET {dataset}:")

AGGREGATED TABLE FOR deap:
\begin{tabular}{llllllllll}
\toprule
Model & Validation & cls Accuracy (\%) $\uparrow$ $<$\SI{100}{\hertz} & cls Accuracy (\%) $\uparrow$ $>$\SI{100}{\hertz} & cls $F_1$ (\%) $\uparrow$ $<$\SI{100}{\hertz} & cls $F_1$ (\%) $\uparrow$ $>$\SI{100}{\hertz} & ids Accuracy (\%) $\uparrow$ $<$\SI{100}{\hertz} & ids Accuracy (\%) $\uparrow$ $>$\SI{100}{\hertz} & ids $F_1$ (\%) $\uparrow$ $<$\SI{100}{\hertz} & ids $F_1$ (\%) $\uparrow$ $>$\SI{100}{\hertz} \\
\hline
$k$-fold & DINOv2\cite{dinov2} & 99.7 ± 0.09 & 99.8 ± 0.08 & 99.8 ± 0.07 & 99.8 ± 0.07 & 99.9 ± 0.06 & 99.9 ± 0.05 & 99.9 ± 0.06 & 99.9 ± 0.05 \\
$k$-fold & EDPNet\cite{edpnet} & 62.7 ± 0.57 & 68.6 ± 0.88 & 71.3 ± 1.05 & 75.1 ± 0.79 & 98.4 ± 0.39 & 99.4 ± 0.18 & 98.4 ± 0.39 & 99.4 ± 0.18 \\
$k$-fold & EEGNet\cite{eegnet} & 59.9 ± 0.69 & 61.0 ± 0.43 & 71.0 ± 0.93 & 71.6 ± 0.52 & 72.0 ± 15.69 & 96.8 ± 0.48 & 72.0 ± 15.69 & 96.8 ± 0.48 \\
$k$-fold & Linear & 64.4 ± 0.66 & 63.8 ± 0.62 & 72.4 ± 1.05 & 73.1 ± 0.