In [1]:
import os
from typing import Tuple

import numpy as np
import pandas as pd

In [2]:
def mean_corr(df: pd.DataFrame):
    """Mean of the entries strictly above the main diagonal of ``df``.

    The diagonal and everything below it are ignored (offset ``k=1``). For a
    symmetric pairwise correlation matrix (assumed, not checked), this counts
    each pair of columns exactly once.
    """
    row_idx, col_idx = np.triu_indices_from(df, k=1)
    above_diagonal = df.to_numpy()[row_idx, col_idx]
    return above_diagonal.mean()


def std_corr(df: pd.DataFrame):
    """Standard deviation of the entries strictly above the main diagonal.

    Uses NumPy's default population standard deviation (``ddof=0``). The
    diagonal and the lower triangle are excluded.
    """
    upper_positions = np.triu_indices_from(df, k=1)
    pair_values = df.to_numpy()[upper_positions]
    return np.std(pair_values)


def read_mean_std(path: str, sf: str, subset: str) -> Tuple[float, float]:
    """Load ``<path>/<subset>_<sf>.csv`` and summarise its above-diagonal values.

    The first CSV column is used as the row index.

    Returns:
        ``(mean, std)`` as computed by ``mean_corr`` and ``std_corr``.
    """
    csv_file = f"{path}/{subset}_{sf}.csv"
    corr_frame = pd.read_csv(csv_file, index_col=0)
    average = mean_corr(corr_frame)
    spread = std_corr(corr_frame)
    return average, spread


def create_correlation_df(
    path: str,
    sfs: Tuple[str, ...] = ("ce_loss", "cv_loss", "cum_acc", "fit", "pd"),
    model_only_sfs: Tuple[str, ...] = ("tt",),
) -> pd.DataFrame:
    """Build the correlation summary table for one results directory.

    For each scoring function in ``sfs``, the (mean, std) of the above-diagonal
    correlations is read via ``read_mean_std`` for each of the ``seed``,
    ``model`` and ``optim`` subsets. Scoring functions in ``model_only_sfs`` are
    only read for the ``model`` subset; their ``seed``/``optim`` columns are NaN.

    Args:
        path: directory holding the ``<subset>_<sf>.csv`` correlation files.
        sfs: scoring functions read for all three subsets.
        model_only_sfs: scoring functions read for the ``model`` subset only.

    Returns:
        DataFrame indexed by ``sf`` with columns
        ``seed, seed_std, model, model_std, optim, optim_std``.
    """
    subsets = ("seed", "model", "optim")
    # Derive the header from ``subsets`` so column names always match row layout.
    columns = ["sf"] + [
        name for subset in subsets for name in (subset, f"{subset}_std")
    ]

    def _row(sf: str, available: Tuple[str, ...]) -> list:
        # One table row: a (mean, std) pair per subset, NaN pair if not read.
        row = [sf]
        for subset in subsets:
            if subset in available:
                row.extend(read_mean_std(path, sf, subset))
            else:
                row.extend((np.nan, np.nan))
        return row

    rows = [_row(sf, subsets) for sf in sfs]
    rows += [_row(sf, ("model",)) for sf in model_only_sfs]
    return pd.DataFrame(columns=columns, data=rows).set_index("sf")

In [3]:
# One summary frame per dataset, stacked under a "dataset" index level.
DATASETS = ["cifar", "dcase"]
per_dataset = [
    create_correlation_df(f"results/{name}/curriculum") for name in DATASETS
]
df = pd.concat(per_dataset, keys=DATASETS, names=["dataset"])
os.makedirs("results/tables", exist_ok=True)
df.to_csv("results/tables/4_b_1_impact_hparams.csv")

In [4]:
# Display-only rounding; the saved CSV keeps full precision.
df.round(decimals=3)

Unnamed: 0_level_0,Unnamed: 1_level_0,seed,seed_std,model,model_std,optim,optim_std
dataset,sf,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
cifar,ce_loss,0.507,0.026,0.428,0.043,0.483,0.055
cifar,cv_loss,0.676,0.007,0.629,0.042,0.688,0.022
cifar,cum_acc,0.76,0.008,0.557,0.101,0.752,0.019
cifar,fit,0.586,0.033,0.416,0.076,0.623,0.019
cifar,pd,0.79,0.012,0.653,0.076,0.799,0.032
cifar,tt,,,0.648,0.025,,
dcase,ce_loss,0.41,0.06,0.415,0.115,0.369,0.041
dcase,cv_loss,0.591,0.018,0.579,0.044,0.556,0.049
dcase,cum_acc,0.821,0.012,0.59,0.099,0.758,0.048
dcase,fit,0.604,0.02,0.475,0.084,0.513,0.052
