## Data Aggregation

We now have a directory full of files, each containing a pickled pandas dataframe. We need to merge these into one big dataframe so that we can then calculate pairwise error consistency (EC) values.

In [None]:
# making sure that updates to imported files are immediately available without restarting the kernel
%reload_ext autoreload
%autoreload 2

In [None]:
import os
import sys
import pickle
from os.path import join as pjoin
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sys.path.append(os.path.abspath(".."))
from utils import fast_cohen

In [None]:
# Directory holding the pickled evaluation dataframes, and the runs to analyse.
root_dir = "/mnt/lustre/work/bethge/tklein16/projects/ec2/train_out/results/evaluations"
runs = ["copper-brook-23", "honest-planet-24"]

# Only the top level of root_dir is inspected (first item yielded by os.walk).
_, dirs, files = next(os.walk(root_dir))
# Keep, in sorted order, every file whose name mentions one of the selected runs.
files = [name for name in sorted(files) if any(run in name for run in runs)]

# Non-numeric epoch tags found in the file names and the integer epoch they map to.
named_epochs = {"initial": -1, "final": 90}

df_list = []
for file in files:
    # Anything that is not a pickled dataframe is ignored.
    if not file.endswith(".pd.pkl"):
        continue

    with open(pjoin(root_dir, file), "rb") as handle:
        df = pickle.load(handle)

    # File names are parsed as "<run>_<epoch>.<suffix>".
    name_parts = file.split("_")
    run_name = name_parts[0]
    epoch = name_parts[1].partition(".")[0]

    # Annotate every row with its correctness, its run and its (integer) epoch.
    df["Correct"] = df["Prediction"] == df["Label"]
    df["Run"] = run_name
    int_epoch = named_epochs[epoch] if epoch in named_epochs else int(epoch)
    df["Epoch"] = int_epoch
    df_list.append(df)

all_df = pd.concat(df_list)
display(all_df)

In [None]:
# figure 1: make a plot of the ec over training for the two runs


def get_ecs_over_training(all_df, run_name, end_epoch) -> pd.DataFrame:
    """Compare every checkpoint of one run against a reference checkpoint of that run.

    For each epoch of ``run_name`` the rows are joined with the rows of the
    reference epoch on ``"Path"``, and the two ``"Correct"`` columns are handed
    to ``fast_cohen`` (imported from ``utils``; presumably Cohen's kappa, to
    be confirmed against that module).

    Args:
        all_df: Dataframe with at least the columns ``"Run"``, ``"Epoch"``,
            ``"Path"`` and ``"Correct"``.
        run_name: Value of the ``"Run"`` column selecting the run to analyse.
        end_epoch: Value of the ``"Epoch"`` column used as the reference.

    Returns:
        A dataframe with the columns ``"Epoch"``, ``"EC"``, ``"Run"`` and
        ``"Accuracy"`` (accuracy of the epoch's checkpoint in percent, measured
        on the samples shared with the reference). Epochs that cannot be
        compared are skipped with a printed message, so the result may be empty.
    """
    # lists to hold the intermediate results
    epochs = []
    ecs = []
    accuracies = []

    run_df = all_df[all_df["Run"] == run_name]
    # Only the join key and the correctness flag are needed; project once,
    # outside the loop, instead of merging all columns for every epoch.
    ref_df = run_df.loc[run_df["Epoch"] == end_epoch, ["Path", "Correct"]]

    for epoch, df in run_df.groupby("Epoch"):

        merged_df = pd.merge(ref_df, df[["Path", "Correct"]], on="Path", how="inner")

        # An empty merge (missing reference epoch or no shared samples) yields
        # NaN accuracies, which would slip past the 0.0/1.0 check below.
        if merged_df.empty:
            print(f"No samples shared with the reference epoch found for {epoch}!")
            continue

        acc1 = merged_df["Correct_x"].mean()
        acc2 = merged_df["Correct_y"].mean()

        # Degenerate accuracies are skipped rather than passed to fast_cohen.
        if acc1 in (0.0, 1.0) or acc2 in (0.0, 1.0):
            print(f"Perfect accuracy or no correct responses found for {epoch}!")
            continue

        real_trials_1 = merged_df["Correct_x"].to_numpy()
        real_trials_2 = merged_df["Correct_y"].to_numpy()

        ec = fast_cohen(real_trials_1, real_trials_2)

        epochs.append(epoch)
        ecs.append(ec)
        accuracies.append(acc2 * 100)

    # "Run" is a scalar and is broadcast to the length of the other columns.
    return pd.DataFrame(
        {"Epoch": epochs, "EC": ecs, "Run": run_name, "Accuracy": accuracies}
    )


# One self-EC table per run (reference: epoch 90), stacked into a single frame.
per_run_self_dfs = [
    get_ecs_over_training(all_df, run_name=run, end_epoch=90) for run in runs
]
self_df = pd.concat(per_run_self_dfs)

In [None]:
def get_pairwise_ecs_over_training(all_df) -> pd.DataFrame:
    """Compare the runs with each other, epoch by epoch.

    For every epoch the rows of each pair of runs are joined on ``"Path"`` and
    the two ``"Correct"`` columns are handed to ``fast_cohen`` (imported from
    ``utils``; presumably Cohen's kappa, to be confirmed against that module).

    Args:
        all_df: Dataframe with at least the columns ``"Run"``, ``"Epoch"``,
            ``"Path"`` and ``"Correct"``.

    Returns:
        A dataframe with the columns ``"Epoch"`` and ``"EC"``. With exactly two
        runs there is at most one row per epoch; with more runs there is one
        row per pair of runs and epoch. Epochs or pairs that cannot be compared
        are skipped with a printed message, so the result may be empty.
    """
    from itertools import combinations

    # lists to hold the intermediate results
    epochs = []
    ecs = []

    for epoch, epoch_df in all_df.groupby("Epoch"):

        # Only the join key and the correctness flag are needed for the merge.
        per_run = [
            run_df[["Path", "Correct"]] for _, run_df in epoch_df.groupby("Run")
        ]

        # A single run at this epoch has no partner to be compared with.
        if len(per_run) < 2:
            print(f"Fewer than two runs found for {epoch}, nothing to compare!")
            continue

        for first_df, second_df in combinations(per_run, 2):

            merged_df = pd.merge(first_df, second_df, on="Path", how="inner")

            # An empty merge yields NaN accuracies, which would slip past the
            # 0.0/1.0 check below.
            if merged_df.empty:
                print(f"No shared samples between runs found for {epoch}!")
                continue

            acc1 = merged_df["Correct_x"].mean()
            acc2 = merged_df["Correct_y"].mean()

            # Degenerate accuracies are skipped rather than passed to fast_cohen.
            if acc1 in (0.0, 1.0) or acc2 in (0.0, 1.0):
                print(f"Perfect accuracy or no correct responses found for {epoch}!")
                continue

            real_trials_1 = merged_df["Correct_x"].to_numpy()
            real_trials_2 = merged_df["Correct_y"].to_numpy()

            ec = fast_cohen(real_trials_1, real_trials_2)

            epochs.append(epoch)
            ecs.append(ec)

    return pd.DataFrame(
        {
            "Epoch": epochs,
            "EC": ecs,
        }
    )


# Between-run error consistency per epoch, computed on the aggregated dataframe.
pairwise_df = get_pairwise_ecs_over_training(all_df=all_df)

In [None]:
def plot_pairwise_ec_over_training(df, ax):
    """Draw the "EC" column of ``df`` against "Epoch" on ``ax``.

    The line is labelled "Pairwise". Returns the axis handed back by seaborn.
    """
    return sns.lineplot(data=df, x="Epoch", y="EC", label="Pairwise", ax=ax)

In [None]:
def plot_self_ec_over_training(df, ax):
    """Draw one "EC" over "Epoch" line per value of the "Run" column on ``ax``.

    Lines are coloured with the "flare" palette. Returns the axis handed back
    by seaborn.
    """
    return sns.lineplot(
        data=df,
        x="Epoch",
        y="EC",
        hue="Run",
        palette="flare",
        ax=ax,
    )


def plot_accuracy_over_training(df, ax):
    """Draw the "Accuracy" column of ``df`` against "Epoch" on a twin y-axis.

    A second y-axis sharing the x-axis of ``ax`` is created, so accuracy
    (in percent) can be shown next to the values already plotted on ``ax``.

    Args:
        df: Dataframe with the columns "Epoch" and "Accuracy".
        ax: Axis whose x-axis the new twin axis shares.

    Returns:
        The newly created twin axis, like the other plot helpers in this file
        return their axis, so callers can adjust it further.
    """
    ax2 = ax.twinx()
    sns.lineplot(
        data=df, x="Epoch", y="Accuracy", label="Accuracy", color="green", ax=ax2
    )
    # Place the accuracy legend away from the default legend location.
    sns.move_legend(ax2, "lower right")
    ax2.set_ylabel("Accuracy [%]")
    return ax2

In [None]:
# Assemble the figure: pairwise EC, per-run self EC and accuracy over training.
fig, ax = plt.subplots(figsize=(12, 6))

ax = plot_pairwise_ec_over_training(pairwise_df, ax)
ax = plot_self_ec_over_training(self_df, ax)
plot_accuracy_over_training(self_df, ax)
ax.grid(axis="y")
# Keep the right spine: it carries the accuracy axis.
sns.despine(right=False)
ax.set_title("(Self-) Error Consistency over Training")
ax.set_xlabel("Epoch")
ax.set_ylabel("Kappa")
plt.tight_layout()

# savefig does not create missing directories, so make sure the target exists.
figure_path = "../figures/ec_over_training.pdf"
os.makedirs(os.path.dirname(figure_path), exist_ok=True)
plt.savefig(figure_path, bbox_inches="tight")
plt.show()