In [1]:
import wandb
import re
from pprint import pprint
import pandas as pd

wandb.login()
api = wandb.Api()
experiments = api.runs("rom42pla_team/noisy_eeg")

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mrom42pla[0m ([33mrom42pla_team[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [60]:
table_rows = []
for experiment in experiments:
    date, hour, dataset, validation, signal_type, model = experiment.name.split(
        "_")
    df_experiment = experiment.history()
    runs = {col.split("-")[0] for col in df_experiment.columns if re.fullmatch(r"run_[0-9]+-.*", col)}
    avg_rows = []
    for run in runs:
        # prints all columns with the metrics of the run
        metrics = {col for col in df_experiment.columns if re.fullmatch(f"{run}-.*/.*", col)}
        # get the results for the run
        metrics_val = {col for col in df_experiment.columns if re.fullmatch(f"{run}-.*/val", col)}
        # gets the subset of dataframe for the run
        df_run = df_experiment[sorted(metrics_val)].dropna().reset_index(drop=True)
        df_run = df_run.rename(columns={col: col.split("-")[-1].split("/")[0] for col in df_run})
        # drop columns that end with "loss"
        df_run = df_run.loc[:, ~df_run.columns.str.endswith("loss")]
        # gets the best run based on the cls accuracy
        df_run = df_run.replace("NaN", None)
        best_row_i = df_run["cls_acc"].idxmax()
        best_row = df_run.iloc[best_row_i]
        avg_rows.append(best_row)
    df_rows = pd.DataFrame(avg_rows)
    rows_mean, rows_std = df_rows.mean(), df_rows.std()
    models_renamed = {
        "linear": "Linear",
        "mlp": "MLP",
        "eegnet": "EEGNet\cite{eegnet}",
        "edpnet": "EDPNet\cite{edpnet}",
        "dino": "DINOv2\cite{dinov2}",
        "sateer": "SATEER\cite{sateer}",
    }
    validation_renamed = {
        "kfold": "$k$-fold",
        "loso": "LOSO"
    }
    for df in [rows_mean, rows_std]:
        df["dataset"] = dataset
        df["validation"] = validation_renamed[validation]
        df["signal_type"] = (
            "$<$\\SI{100}{\\hertz}" if signal_type == "eeg" else "$>$\\SI{100}{\\hertz}"
        )
        df["model"] = models_renamed[model]
    rows = pd.DataFrame({
        key: f"{mean * 100:.1f} ± {std*100:.2f}" if isinstance(mean, float) else mean
        for key, mean, std in zip(rows_mean.index, rows_mean, rows_std)
    }, index=[0]).iloc[0]
    table_rows.append(rows)
# merge each run's series
table = pd.DataFrame(table_rows, dtype="object")
# reorder the columns
first_cols = [
    "dataset",
    "model",
    "validation",
    "signal_type",
]
table = table.sort_values(first_cols)
table = table[first_cols +
              [col for col in table.columns if col not in first_cols]]

In [54]:
def print_df_for_latex(df):
    latex_string = df.to_latex(na_rep="-", index=False)
    replacers = {
        "bottomrule": "botrule",
        "midrule": "hline",
    }
    for k, v in replacers.items():
        latex_string = latex_string.replace(k, v)
    print(latex_string)
    return latex_string

In [83]:
for dataset in table["dataset"].unique():
    dataset_table = table[table["dataset"] == dataset]
    dataset_table = dataset_table.drop(columns=["dataset"])
    common_columns_raw = ["model", "validation", "signal_type"]
    dataset_table = dataset_table.sort_values(by=common_columns_raw)
    common_columns_renamed = [
        "Model",
        "Validation",
        "Frequencies",
    ]
    dataset_table = dataset_table.rename(
        columns={
            k: v
            for k, v in zip(common_columns_raw, common_columns_renamed)
        }
    )

    # general table
    print(f"AGGREGATED TABLE FOR {dataset}:")
    agg_table_columns_raw = ["cls_acc", "cls_f1", "ids_acc", "ids_f1"]
    agg_table = dataset_table[common_columns_renamed + agg_table_columns_raw]
    agg_table_columns_renamed = [
        "cls Accuracy (\\%) $\\uparrow$",
        "cls $F_1$ (\\%) $\\uparrow$",
        "ids Accuracy (\\%) $\\uparrow$",
        "ids $F_1$ (\\%) $\\uparrow$",
    ]
    agg_table = agg_table.rename(
        columns={
            k: v
            for k, v in zip(agg_table_columns_raw, agg_table_columns_renamed)
        }
    )
    print_df_for_latex(agg_table)

    # single labels
    print(f"SINGLE LABELS TABLE FOR {dataset}:")
    labels = sorted(set(re.findall(r"cls_label=([a-zA-Z]+)_", " ".join(dataset_table.columns))))
    labels_table_columns_raw = sorted([f"cls_label={label}_{metric}" for metric in ["acc", "f1"] for label in labels])
    assert all([label in dataset_table.columns for label in labels_table_columns_raw])
    labels_table = dataset_table[common_columns_renamed + labels_table_columns_raw]
    metrics_renamed = {"acc": "acc.", "f1": "$F_1$"}
    labels_table_columns_renamed = []
    for column in labels_table_columns_raw:
        _, label, metric = column.split("_")
        label = label.split("=")[-1]
        labels_table_columns_renamed.append(f"{label.capitalize()} {metrics_renamed[metric]}")
    labels_table = labels_table.rename(
        columns={k: v for k, v in zip(labels_table_columns_raw, labels_table_columns_renamed)}
    )
    frequencies = sorted(dataset_table["Frequencies"].unique())
    # since there are too many columns, we split the table into multiple tables
    for metric in sorted(metrics_renamed.values()):
        labels_subtable = labels_table[common_columns_renamed]
        print(labels_subtable)
        raise
        for column in labels_table.columns:
            if metric not in column:
                continue
            for frequency in frequencies:
                labels_subtable.loc[:, f"{label} {frequency}"] = 1
        # labels_subtable = labels_table[common_columns_renamed + [column for column in labels_table if metric in column]]
        print_df_for_latex(labels_subtable)

AGGREGATED TABLE FOR deap:
\begin{tabular}{lllllll}
\toprule
Model & Validation & Frequencies & cls Accuracy (\%) $\uparrow$ & cls $F_1$ (\%) $\uparrow$ & ids Accuracy (\%) $\uparrow$ & ids $F_1$ (\%) $\uparrow$ \\
\hline
DINOv2\cite{dinov2} & $k$-fold & $<$\SI{100}{\hertz} & 99.9 ± 0.05 & 99.9 ± 0.05 & 100.0 ± 0.02 & 100.0 ± 0.02 \\
DINOv2\cite{dinov2} & $k$-fold & $>$\SI{100}{\hertz} & 99.9 ± 0.07 & 99.9 ± 0.05 & 100.0 ± 0.03 & 100.0 ± 0.03 \\
EDPNet\cite{edpnet} & $k$-fold & $<$\SI{100}{\hertz} & 72.2 ± 0.45 & 77.4 ± 0.41 & 99.8 ± 0.08 & 99.8 ± 0.08 \\
EEGNet\cite{eegnet} & $k$-fold & $<$\SI{100}{\hertz} & 61.7 ± 0.71 & 71.6 ± 0.78 & 87.3 ± 5.38 & 87.3 ± 5.38 \\
Linear & $k$-fold & $<$\SI{100}{\hertz} & 66.1 ± 0.62 & 73.4 ± 0.85 & 99.2 ± 0.16 & 99.2 ± 0.16 \\
MLP & $k$-fold & $<$\SI{100}{\hertz} & 77.5 ± 0.79 & 81.3 ± 1.60 & 99.8 ± 0.14 & 99.8 ± 0.14 \\
SATEER\cite{sateer} & $k$-fold & $<$\SI{100}{\hertz} & 94.1 ± 0.49 & 95.0 ± 0.43 & 99.8 ± 0.07 & 99.8 ± 0.07 \\
\botrule
\end{tabul

RuntimeError: No active exception to reraise