In [None]:
import glob
import json
from argparse import ArgumentParser
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from utils.save import load_json_array

# Number of epochs online

In [None]:
# argparse is used only to hold tunable defaults; parse_args(args=[]) ignores the
# notebook kernel's own command line, so the defaults below always apply.
parser = ArgumentParser()
parser.add_argument("--above-version",
                    default="0",
                    help="include only version folders with number equal or higher than it.")
parsed_args = parser.parse_args(args=[])

In [None]:
path_pattern = "../model/MF/mean-field/version_*/val_dict.json"
json_paths = sorted(glob.glob(path_pattern))

In [None]:
# For every version folder at or above --above-version, derive one "best epoch"
# value from every other entry of its val_dict.json, plot them (one line per
# version) and collect them in n_epochs_dict.
n_epochs_dict = {}
min_version = int(parsed_args.above_version)
# Subtracted from each validation-curve length.
# NOTE(review): presumably the early-stopping patience (epochs trained after the
# best one) — confirm against the training configuration.
PATIENCE = 7
# x value of the first plotted point. NOTE(review): meaning of 12 is not visible
# in this notebook — confirm.
FIRST_X = 12

fig, ax = plt.subplots()
for json_path in json_paths:
    # Version folder name (e.g. "version_3") and its number; skip older versions.
    version_name = Path(json_path).parent.name
    version_number = int(version_name.split("_")[-1])
    if version_number < min_version:
        continue

    with open(json_path, "r") as f:
        val_dict = json.load(f)

    # Only the 0th, 2nd, 4th, ... entries of val_dict are used.
    # NOTE(review): presumably entries alternate between two metrics — confirm.
    best_epochs = [
        len(perf_list) - PATIENCE
        for i, perf_list in enumerate(val_dict.values())
        if i % 2 == 0
    ]
    # Size the x-axis from the data rather than trying fixed ranges and relying
    # on matplotlib's ValueError for a length mismatch.
    ax.plot(range(FIRST_X, FIRST_X + len(best_epochs)), best_epochs, label=version_name)
    n_epochs_dict[version_name] = best_epochs

print(n_epochs_dict)
ax.legend()
fig.savefig("../img/n-epoch-val.svg", bbox_inches="tight")


In [None]:
# Number of best-epoch values per version; avoids relying on `best_epochs`
# leaked from the last iteration of the loop above (NameError if no file matched).
{name: len(epochs) for name, epochs in n_epochs_dict.items()}

# BIU results

In [None]:
# Load the BIU results file into a DataFrame via utils.save.load_json_array.
# NOTE(review): presumably one record per training run — confirm.
biu_results_path = "../safebox/TyXe-results-220817.json"
results_df = load_json_array(biu_results_path)

In [None]:
# Mean and standard deviation of each metric, with the metrics as rows.
results_df.loc[:, ["loss", "auc", "train_time"]].agg(["mean", "std"]).transpose()

In [None]:
# Display the complete BIU results table (all rows) for inspection.
results_df

# BIU ablation results

In [None]:
# Load the BIU ablation results; the "ablation" column identifies the variant
# of each record and is the grouping key of the summary table below.
ablation_results_path = "../safebox/TyXe-ablation-220817.json"
ablation_df = load_json_array(ablation_results_path)
ablation_df

In [None]:
# Raw ablation labels (and record counts) in the order groupby("ablation")
# produces. The hard-coded display names assigned to the ablation table's index
# depend on this order, so check it whenever the results file changes.
# (Replaces a reference to `tr_dict`, which is undefined on a fresh kernel.)
ablation_df.groupby("ablation").size()

In [None]:
def print_latex(df, column_format="c", **kwargs):
    """Print ``df`` rendered as a LaTeX tabular.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to render.
    column_format : str, default "c"
        Either a single alignment character, repeated once per rendered column
        (index level(s) included unless ``index=False``), or a full LaTeX
        column specification, which is passed through unchanged.
    **kwargs
        Forwarded to ``DataFrame.to_latex``.
    """
    if len(column_format) == 1:
        # Index levels are rendered as leading columns unless index=False.
        n_index_cols = df.index.nlevels if kwargs.get("index", True) else 0
        column_format = column_format * (n_index_cols + df.shape[1])
    print(df.to_latex(column_format=column_format, **kwargs))

# Mean and std of each metric per ablation variant; the columns are a
# (metric, statistic) MultiIndex.
ablation_table = (
    ablation_df
    .groupby("ablation")[["loss", "auc", "train_time"]]
    .agg(["mean", "std"])
)

# Round loss/AUC to 4 decimals and training time to 1 decimal.
round_dict = {col: 1 if col[0] == "train_time" else 4 for col in ablation_table.columns}
ablation_table = ablation_table.round(round_dict)

# Rename the variants and put them in presentation order.
# NOTE(review): the names are assigned positionally, i.e. they rely on the
# sorted order of the raw "ablation" labels produced by groupby — re-check this
# mapping whenever the labels in the results file change.
ablation_table.index = ["BIFT-init", "BIFT-NIP+init", "BIFT", "BIFT-NIP"]
ablation_table = ablation_table.loc[["BIFT", "BIFT-NIP", "BIFT-init", "BIFT-NIP+init"]]

perf_dict = dict(zip(["loss", "auc", "train_time"], ["NLL", "AUC", "training time"]))
ablation_table = ablation_table.rename(perf_dict, axis=1, level=0)

# One "mean ± std" string column per metric, in the original metric order.
# Built explicitly because DataFrame.groupby(axis=1) is deprecated in pandas.
metric_names = ablation_table.columns.get_level_values(0).unique()
latex_table = pd.DataFrame({
    metric: (
        ablation_table[(metric, "mean")].astype("str")
        + " ± "
        + ablation_table[(metric, "std")].astype("str")
    )
    for metric in metric_names
})
print_latex(latex_table)


In [None]:
sns.set(font_scale=1.5)
fig, axes = plt.subplots(1, 3, figsize=(18, 6), gridspec_kw={"wspace": 0.4})
# One bar chart per metric across ablation variants; ci="sd" draws the standard
# deviation of each group as the error bar.
loss_barplot, auc_barplot, time_barplot = [
    sns.barplot(x="ablation", y=metric, data=ablation_df, ci="sd", ax=panel)
    for metric, panel in zip(("loss", "auc", "train_time"), axes)
]
# Zoom the loss and AUC axes so the differences between variants are visible.
loss_barplot.set_ylim(0.605, 0.629)
auc_barplot.set_ylim(0.69, 0.72);