In [42]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [43]:
# Global plotting style shared by every figure in this notebook.
font_size = 13  # base font size; tick labels in the heatmaps reuse it
figsize = (6, 3)  # (width, height) passed to matplotlib as figsize

# rc overrides layered on top of the seaborn theme: Times-style serif text.
rcParams = dict(
    [
        ("font.size", font_size),
        ("font.family", "serif"),
        ("font.serif", "Times New Roman"),
    ]
)
sns.set_theme(
    rc=rcParams,
    palette="colorblind",
    style="whitegrid",
    context="notebook",
)

In [44]:
def mean_corr(df: pd.DataFrame):
    """Mean of the entries strictly above the main diagonal of ``df``.

    The diagonal (offset k=0) is excluded, so for a symmetric agreement matrix
    each off-diagonal pair is counted exactly once.
    """
    matrix = df.to_numpy()
    n_rows, n_cols = matrix.shape
    row_idx, col_idx = np.triu_indices(n_rows, k=1, m=n_cols)
    return matrix[row_idx, col_idx].mean()


def std_corr(df: pd.DataFrame):
    """Standard deviation (numpy default, ddof=0) of the entries strictly above the diagonal of ``df``."""
    # Boolean mask selecting cells (i, j) with j > i; masked indexing keeps row-major order.
    above_diagonal = np.triu(np.ones(df.shape, dtype=bool), k=1)
    return df.to_numpy()[above_diagonal].std()


def read_df(path) -> pd.DataFrame:
    """Load a CSV file whose first column holds the row labels."""
    frame = pd.read_csv(filepath_or_buffer=path, index_col=0)
    return frame


def clean_index_columns(
    df: pd.DataFrame,
    suffixes: tuple = ("-Seed", "-Model", "-Optim"),
) -> pd.DataFrame:
    """Return a copy of ``df`` with the variant markers removed from its labels.

    Every occurrence of each string in ``suffixes`` is deleted from both the
    column labels and the index labels. Matching is literal (not regex). The
    input frame is left untouched.

    Args:
        df: Frame with string column and index labels.
        suffixes: Literal substrings to strip from the labels.

    Returns:
        A new DataFrame with the cleaned labels.
    """

    def _strip(labels: pd.Index) -> pd.Index:
        # Remove each marker from every label in ``labels``.
        for suffix in suffixes:
            # regex=False: treat the marker literally on every pandas version
            # (pandas < 2.0 defaulted to regex=True).
            labels = labels.str.replace(suffix, "", regex=False)
        return labels

    # Work on a copy so the caller's frame keeps its original labels.
    cleaned = df.copy()
    cleaned.columns = _strip(df.columns)
    cleaned.index = _strip(df.index)
    return cleaned


def plot_matrix(
    df: pd.DataFrame,
    variation: str,
    dataset: str,
    out_dir: str = "results/figures",
) -> None:
    """Render ``df`` as an annotated heatmap and save it as a PDF.

    The figure is written to
    ``{out_dir}/4_b_4_difficulty_agreement_{dataset}_{variation}.pdf``.
    ``out_dir`` is created if it does not exist. The figure is always closed
    afterwards, so nothing is displayed inline.

    Args:
        df: Agreement matrix. Its columns label the x axis and its index labels
            the y axis.
        variation: Variation name used in the output filename.
        dataset: Dataset name used in the output filename.
        out_dir: Directory the PDF is written to.
    """

    def _format_corr(corr: float) -> str:
        # An exact 1 (e.g. the diagonal) is shown as "1". Other values get
        # three decimals with the leading zero dropped and the sign kept,
        # e.g. ".724" or "-.130".
        if corr == 1:
            return "1"
        text = f"{corr:.3f}"
        if text.startswith("-0"):
            return "-" + text[2:]
        return text.lstrip("0")

    fig, ax = plt.subplots(figsize=figsize)
    try:
        sns.heatmap(
            df,
            annot=df.map(_format_corr),
            fmt="",
            cbar=False,
            cmap="crest",
            # Label every column (x) and every index entry (y). With the
            # default "auto", seaborn may thin out labels on larger matrices.
            xticklabels=True,
            yticklabels=True,
            ax=ax,
        )
        ax.tick_params(axis="x", labelrotation=0, labelsize=font_size)
        ax.tick_params(axis="y", labelrotation=25, labelsize=font_size)
        fig.tight_layout(pad=0)
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(
            f"{out_dir}/4_b_4_difficulty_agreement_{dataset}_{variation}.pdf",
            dpi=300,
        )
    finally:
        # Close even on failure so a loop over many matrices does not leak figures.
        plt.close(fig)

In [45]:
# For every (dataset, variation) pair: load the agreement matrix, save its
# heatmap, and record the mean / std of its above-diagonal entries.
records = []
for dataset in ["cifar", "dcase"]:
    for variation in ["seed", "model", "optim"]:
        df = clean_index_columns(
            read_df(f"results/{dataset}/curriculum/{variation}_macro.csv")
        )
        plot_matrix(df, variation, dataset)
        records.append(
            dict(
                Dataset=dataset,
                Variation=variation,
                Mean=mean_corr(df),
                Std=std_corr(df),
            )
        )

os.makedirs("results/tables", exist_ok=True)
pd.DataFrame(records).to_csv(
    "results/tables/4_b_4_difficulty_agreement.csv",
    index=False,
)

In [46]:
# Summary: one row per (dataset, variation), rounded to three decimals for display.
(
    pd.DataFrame(records)
    .set_index(["Dataset", "Variation"])
    .round(3)
)

Unnamed: 0_level_0,Unnamed: 1_level_0,Mean,Std
Dataset,Variation,Unnamed: 2_level_1,Unnamed: 3_level_1
cifar,seed,0.724,0.097
cifar,model,0.673,0.085
cifar,optim,0.732,0.112
dcase,seed,0.689,0.129
dcase,model,0.652,0.132
dcase,optim,0.724,0.08
