In [None]:
%load_ext autoreload
%autoreload 2

## Cross Neutralizing

In [None]:
import re
from pathlib import Path

import utils

In [None]:
# Experiment grid to iterate over -- edit these lists to select what gets plotted.
TASKS = ["DEP", "POS"]
TREE_BANKS = [
    "en_gum",
    "it_vit",
    "el_gdt",
]
MODELS = [
    "xlm-roberta-base",
    "roberta-base",
]

# Tag whitelist per task: a value of None keeps every tag.
# Example subsets that can replace a None:
#   POS=["NOUN", "ADJ", "VERB", "PRON", "DET", "NUM", "ADV", "AUX"]
#   DEP=["PUNCT", "NSUBJ", "OBJ", "OBL", "ADVCL", "CASE", "DET", "AMOD"]
KEEP_TAGS = dict(
    POS=None,
    DEP=None,
)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Global plot look: plain white background, "paper" context scaled down to 0.8.
sns.set_style(style="white")
sns.set_context(context="paper", font_scale=0.8)

In [None]:
# Figure dimensions in inches (the unit of matplotlib's `figsize`). These are presumably
# the LaTeX \textwidth / \columnwidth of the target document -- TODO confirm.
TEXTWIDTH, COLWIDTH = 6.30045, 3.03209

In [None]:
def plot_heatmap(
    df,
    save_name=None,
    vmin=None,
    vmax=None,
    center=0.0,
    cbar=True,
    annot_kws=None,
):
    """Draw an annotated square heatmap of ``df * 100`` and optionally save it.

    Cells are annotated with the scaled values formatted as integers (``.0f``).
    They are coloured with a diverging seaborn palette (HUSL hues 20 -> 145)
    centred on ``center``. Axis labels are rendered in bold.

    Args:
        df: 2-D data accepted by ``seaborn.heatmap`` (e.g. a DataFrame). Values
            are presumably fractions, since they are multiplied by 100 -- TODO confirm.
        save_name: if truthy, path the figure is written to. Missing parent
            directories are created.
        vmin, vmax: colour-scale limits, in the scaled (x100) units.
        center: scaled value mapped to the midpoint of the diverging palette.
        cbar: whether to draw the colour bar.
        annot_kws: keyword arguments for the cell annotations. Defaults to
            ``{"fontsize": 7}``.

    Returns:
        The matplotlib Axes holding the heatmap.
    """
    # None sentinel instead of a shared mutable dict default.
    if annot_kws is None:
        annot_kws = {"fontsize": 7}
    bold = {"weight": "bold"}

    fig, ax = plt.subplots(figsize=(TEXTWIDTH, TEXTWIDTH), dpi=300)
    cmap = sns.diverging_palette(20, 145, as_cmap=True)
    sns.heatmap(
        df * 100,
        annot=True,
        fmt=".0f",
        cmap=cmap,
        cbar=cbar,
        vmin=vmin,
        vmax=vmax,
        center=center,
        square=True,
        annot_kws=annot_kws,
        ax=ax,
    )
    ax.set_xlabel(ax.get_xlabel(), fontdict=bold)
    ax.set_ylabel(ax.get_ylabel(), fontdict=bold)

    if save_name:
        # savefig does not create directories; avoid FileNotFoundError on a fresh checkout.
        Path(save_name).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_name, bbox_inches="tight")
    return ax


def plot_config(data, title, model):
    """Plot accuracy decrease against probing layer for every experiment config and save it.

    Line style always distinguishes the "aggregation" value. Colour (hue) is
    chosen as follows:
    - "concat. mode", when more than one concat. mode is present;
    - otherwise "treebank", when more than one tree bank is present;
    - otherwise no hue is used.

    The y-axis spans from a margin below the smallest decrease up to 100. The
    figure is written to ``images/{model}_config_selection_{title}.pdf``.

    Args:
        data: list of record dicts. Each record needs the keys "Probing Layer",
            "Accuracy Decrease (%)" and "aggregation". The keys "concat. mode"
            and "treebank" are optional.
        title: plot title; also embedded in the output file name.
        model: model name embedded in the output file name.

    Returns:
        The matplotlib Axes holding the plot, or None when ``data`` is empty.
    """
    if not data:
        # min()/max() below would raise on an empty sequence.
        print(f"No experiment records for {model} / {title}, skipping plot")
        return None

    df = pd.DataFrame(data).sort_values(by="Probing Layer")

    # Optional columns may be missing entirely; treat that as "not varying".
    n_modes = df["concat. mode"].nunique(dropna=False) if "concat. mode" in df else 0
    n_treebanks = df["treebank"].nunique(dropna=False) if "treebank" in df else 0
    if n_modes > 1:
        hue = "concat. mode"
    elif n_treebanks > 1:
        hue = "treebank"
    else:
        hue = None

    xticks = df["Probing Layer"].unique()
    drops = df["Accuracy Decrease (%)"]
    lowest, highest = drops.min(), drops.max()
    # Bottom margin that leaves room for the legend; doubled when a hue adds entries.
    gap = (highest - lowest) * 0.75
    if hue:
        gap *= 2

    fig, ax = plt.subplots(figsize=(TEXTWIDTH * 1.5, TEXTWIDTH), dpi=300)
    sns.lineplot(
        data=df,
        x="Probing Layer",
        y="Accuracy Decrease (%)",
        style="aggregation",
        hue=hue,
        legend="auto",
        markers=True,
        markersize=20,
        dashes=False,
        alpha=0.5,
        linewidth=2,
        markeredgecolor="black",
        ax=ax,
    )
    ax.set_title(f"{title}", fontsize=24, weight="bold")
    ax.set_xlabel(ax.get_xlabel(), fontsize=22)
    ax.set_ylabel(ax.get_ylabel(), fontsize=22)
    ax.set_xticks(xticks)
    ax.tick_params(axis="both", labelsize=18)
    ax.legend(fontsize=12)
    ax.set_ylim(lowest - gap, 100)

    out_path = Path("images") / f"{model}_config_selection_{title}.pdf"
    # savefig does not create directories; avoid FileNotFoundError on a fresh checkout.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, bbox_inches="tight", transparent=True)
    return ax

In [None]:
# Experiment names look like "agg=<agg>_probe=<layer>[_concat-mode=<mode>]".
EXPERIMENT_REGEX = r"agg=(\w+)_probe=(\d+)(?:_concat-mode=(\w+))?"

plot_configs = True
plot_heatmaps = False

# Tree banks not evaluated for a given model. roberta-base is only run on en_gum,
# presumably because it is an English-only model -- TODO confirm.
SKIPPED_TREE_BANKS = {"roberta-base": {"it_vit", "el_gdt"}}


def _experiment_records(experiments_df, tree_bank):
    """Convert each row of ``experiments_df`` into a record for ``plot_config``.

    The index values are parsed with EXPERIMENT_REGEX. Rows whose name does not
    match are reported and skipped. When the optional concat-mode part is
    absent, "concat. mode" is set to "".
    """
    records = []
    for experiment_name in experiments_df.index:
        match = re.search(EXPERIMENT_REGEX, experiment_name)
        if match is None:
            print(f"Weird experiment name, skipping: {experiment_name}")
            continue
        # groups() always returns 3 items; the unmatched optional group becomes "".
        agg, probe, concat_mode = match.groups(default="")
        avg = experiments_df.loc[experiment_name, "avg"]
        records.append(
            {
                "legend": experiment_name,
                "aggregation": agg,
                "Probing Layer": int(probe),
                "concat. mode": concat_mode,
                # round() instead of int(): int() truncates float noise,
                # e.g. 0.29 * 100 == 28.999999999999996 -> 28.
                "Accuracy Decrease (%)": int(round(-avg * 100)),
                "treebank": tree_bank,
            }
        )
    return records


for MODEL in MODELS:
    for TASK in TASKS:
        data = []
        for TREE_BANK in TREE_BANKS:
            if TREE_BANK in SKIPPED_TREE_BANKS.get(MODEL, set()):
                continue
            print(TASK, MODEL, TREE_BANK)
            experiments_df = utils.get_experiments_df(TASK, TREE_BANK, MODEL)
            if plot_configs:
                data += _experiment_records(experiments_df, TREE_BANK)
            if plot_heatmaps:
                plot_heatmap(
                    experiments_df,
                    save_name=f"experiments/{TASK}_{MODEL}_{TREE_BANK}.eps",
                    cbar=False,
                )
                MODE = utils.select_best_mode(experiments_df)
                eval_path = f"lightning_logs/{MODEL}/{TREE_BANK}/{TASK}/{MODE}/evaluation"
                acc_drop = utils.get_acc_drop(eval_path, KEEP_TAGS[TASK])
                plot_heatmap(
                    acc_drop,
                    save_name=f"experiments/{TASK}_{MODEL}_{TREE_BANK}_acc_drop_{MODE}"
                    f"{'_sampled' if KEEP_TAGS[TASK] is not None else ''}"
                    ".eps",
                    vmin=-100,
                    vmax=100,
                    cbar=False,
                )
        # Skip when no tree bank produced any records; there is nothing to plot.
        if plot_configs and data:
            plot_config(data, TASK, MODEL)