In [None]:
%load_ext autoreload
%autoreload 2

## Cross Neutralizing

In [None]:
import utils

In [None]:
# NOTE: Change these accordingly.
# The plotting loop further down iterates over every combination of
# TASKS x MODELS x TREE_BANKS, so each entry here multiplies the number of
# figures that get produced.
TASKS = ["DEP", "POS"]
TREE_BANKS = ["en_gum", "it_vit", "el_gdt"]
MODELS = ["xlm-roberta-base"]
# Per-task tag filter that is passed to utils.get_acc_drop. It is indexed as
# KEEP_TAGS[TASK], so it needs one key for every entry of TASKS.
# Set the values to None if you want to keep all the tags. When a value is not
# None, the saved accuracy-drop figure gets a "_sampled" suffix in its name.
KEEP_TAGS = {
    "POS": None,
    # Example of a restricted tag list (uncomment to replace the entry above):
    # "POS": ["NOUN", "ADJ", "VERB", "PRON", "DET", "NUM", "ADV", "AUX"],
    "DEP": None,
    # Example of a restricted tag list (uncomment to replace the entry above):
    # "DEP": ["PUNCT", "NSUBJ", "OBJ", "OBL", "ADVCL", "CASE", "DET", "AMOD"],
}

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Global seaborn look for every figure in this notebook: the "white" style and
# the "paper" context with fonts scaled down to 0.8.
sns.set_style("white")
sns.set_context("paper", font_scale=0.8)

In [None]:
# Figure dimensions. TEXTWIDTH is passed to matplotlib's `figsize` in
# plot_heatmap, so the unit is inches.
# NOTE(review): the values look like the text width / column width of the
# target document (e.g. LaTeX \textwidth and \columnwidth) - confirm.
TEXTWIDTH = 6.30045
# Not used by the cells visible here; presumably meant for single-column figures.
COLWIDTH = 3.03209

In [None]:
def plot_heatmap(df, save_name=None, vmin=None, vmax=None, center=0.0, cbar=True, annot_kws={"fontsize": 7}):
    """Draw ``df`` as an annotated, square heatmap and optionally save it.

    The values are multiplied by 100 before plotting and every cell is
    annotated with the scaled value rounded to zero decimals (``fmt=".0f"``).
    The figure is ``TEXTWIDTH`` x ``TEXTWIDTH`` inches at 300 dpi and uses a
    diverging colormap built with ``sns.diverging_palette(20, 145)``.

    Parameters
    ----------
    df
        Table of values to plot; it must support ``df * 100`` and be accepted
        by ``sns.heatmap`` (presumably a pandas DataFrame - confirm with the
        helpers in ``utils``).
    save_name : str or path-like, optional
        When truthy, the figure is written to this path with
        ``bbox_inches="tight"``. The file format follows the extension.
    vmin, vmax : float, optional
        Colormap limits, expressed on the *scaled* (x100) values. ``None``
        lets seaborn infer them from the data.
    center : float, default 0.0
        Value at which the diverging colormap is centered (scaled units).
    cbar : bool, default True
        Whether to draw the colorbar.
    annot_kws : dict or None
        Keyword arguments for the cell annotations, forwarded to
        ``sns.heatmap``. Defaults to a font size of 7.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the heatmap, so callers can tweak it further or
        close its figure (``plt.close(ax.figure)``) when plotting in bulk.
    """
    # Create the figure and axes explicitly and draw on exactly that axes,
    # instead of relying on pyplot's implicit "current figure" state.
    fig, ax = plt.subplots(figsize=(TEXTWIDTH, TEXTWIDTH), dpi=300)
    cmap = sns.diverging_palette(20, 145, as_cmap=True)

    sns.heatmap(
        df * 100,
        annot=True,
        fmt=".0f",
        cmap=cmap,
        cbar=cbar,
        vmin=vmin,
        vmax=vmax,
        center=center,
        square=True,
        # Hand over a copy so the shared default dict in the signature can
        # never be modified downstream.
        annot_kws=None if annot_kws is None else dict(annot_kws),
        ax=ax,
    )

    # Re-apply whatever axis labels seaborn derived from `df`, now in bold.
    bold = {"weight": "bold"}
    ax.set_xlabel(ax.get_xlabel(), fontdict=bold)
    ax.set_ylabel(ax.get_ylabel(), fontdict=bold)

    if save_name:
        # Save the figure we created, not whichever one happens to be current.
        fig.savefig(save_name, bbox_inches="tight")

    return ax

In [None]:
# For every (task, model, treebank) combination, save two heatmaps as EPS
# files under "experiments/":
#   1. the experiments table returned by utils.get_experiments_df;
#   2. the table returned by utils.get_acc_drop for the mode picked by
#      utils.select_best_mode (presumably accuracy drops - see utils),
#      plotted on a fixed -100..100 color scale.
for TASK in TASKS:
    for MODEL in MODELS:
        for TREE_BANK in TREE_BANKS:
            print(TASK, MODEL, TREE_BANK)

            # --- Heatmap 1: the experiments table -------------------------
            experiments_df = utils.get_experiments_df(TASK, TREE_BANK, MODEL)
            overview_path = f"experiments/{TASK}_{MODEL}_{TREE_BANK}.eps"
            plot_heatmap(experiments_df, save_name=overview_path, cbar=False)

            # --- Heatmap 2: evaluation of the selected mode ---------------
            MODE = utils.select_best_mode(experiments_df)
            eval_path = f"lightning_logs/{MODEL}/{TREE_BANK}/{TASK}/{MODE}/evaluation"
            tags_to_keep = KEEP_TAGS[TASK]
            acc_drop = utils.get_acc_drop(eval_path, tags_to_keep)

            # Files produced with a tag filter are marked with "_sampled".
            acc_drop_path = (
                f"experiments/{TASK}_{MODEL}_{TREE_BANK}_acc_drop_{MODE}"
                f"{'_sampled' if tags_to_keep is not None else ''}"
                ".eps"
            )
            plot_heatmap(
                acc_drop,
                save_name=acc_drop_path,
                vmin=-100,
                vmax=100,
                cbar=False,
            )