## Cross Neutralizing

In [None]:
# NOTE: Adjust the settings below to choose which runs get analysed.
MODELS = ["xlm-roberta-base"]
TREE_BANKS = ["en_gum", "it_vit", "el_gdt"]
TASKS = ["POS"]

# Per-task whitelist of tags to keep in the detailed heatmap.
# A value of None keeps every tag. To restrict a task, replace None with a
# list of tags, for example:
#     "POS": ["NOUN", "ADJ", "VERB", "PRON", "DET", "NUM", "ADV", "AUX"],
KEEP_TAGS = {"POS": None, "DEP": None}

In [None]:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

import glob
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import re
import seaborn as sns

# Global seaborn appearance for every figure drawn below: the "white" style
# and the "paper" context with fonts scaled by a factor of 2.
sns.set_style("white")
sns.set_context("paper", font_scale=2)

In [None]:
# Captures the word characters that follow "evaluation_" in a path component
# (e.g. ".../evaluation_noun/..." -> "noun"); used by get_xn_df.
TAG_REGEX = r".*\/evaluation_(\w+)\/.*"
# Captures text containing "=" that sits between two slashes of a path;
# used by get_experiments_df as the experiment name.
EXPERIMENT_REGEX = r".*\/(.+=.+)\/.*"


def get_base_series(path):
    """Read the test scalars stored in a single TensorBoard event file.

    Args:
        path: Path to an event file. If it contains ``*`` it is treated as a
            glob pattern and the first match (in sorted order) is used.

    Returns:
        A ``pd.Series`` with one entry per scalar whose name contains
        ``"test"``. The value is the first one logged for that scalar. A
        scalar whose name splits into exactly two parts on ``"_"`` is stored
        under ``"avg"``; any other scalar is stored under the upper-cased
        last ``"_"``-separated part of its name.

    Raises:
        FileNotFoundError: If ``path`` is a glob pattern that matches nothing.
    """
    if "*" in path:
        # glob returns matches in arbitrary order; sort so the selected file
        # is the same on every run and every file system.
        matches = sorted(glob.glob(path))
        if not matches:
            raise FileNotFoundError(f"No event file matches {path!r}")
        path = matches[0]

    ea = EventAccumulator(path)
    ea.Reload()

    scalars = {}
    for metric in ea.Tags()["scalars"]:
        if "test" not in metric:
            continue

        # Only the first logged value of each scalar is used.
        scalar = ea.Scalars(metric)[0].value

        parts = metric.split("_")
        tag = "avg" if len(parts) == 2 else parts[-1].upper()

        scalars[tag] = scalar

    return pd.Series(scalars)


def get_xn_df(glob_path):
    """Gather the test scalars of every event file matching ``glob_path``.

    The run label is the upper-cased group captured by ``TAG_REGEX`` from the
    file path. For each file, every scalar whose name contains ``"test"``
    contributes its first logged value, stored under ``"avg"`` when the
    scalar name has no ``"/"`` and otherwise under the upper-cased last
    ``"_"``-separated part of the name.

    Args:
        glob_path: Glob pattern selecting the event files to read.

    Returns:
        A ``pd.DataFrame`` whose rows (index named ``"Neutralizer"``) are the
        run labels and whose columns (named ``"Target"``) are the scalar tags.
    """
    per_neutralizer = {}
    for event_file in glob.glob(glob_path):
        neutralizer = re.search(TAG_REGEX, event_file).group(1).upper()

        accumulator = EventAccumulator(event_file)
        accumulator.Reload()

        test_metrics = [
            name for name in accumulator.Tags()["scalars"] if "test" in name
        ]
        per_neutralizer[neutralizer] = {
            ("avg" if "/" not in name else name.split("_")[-1].upper()):
                accumulator.Scalars(name)[0].value
            for name in test_metrics
        }

    # Built with targets as rows, then transposed so neutralizers are rows.
    table = pd.DataFrame(per_neutralizer)
    table.index.name = "Target"
    table = table.T
    table.index.name = "Neutralizer"
    return table


def get_acc_drop(eval_path, keep_cols=None):
    """Compute the relative change of each test scalar against the base run.

    The base values come from ``{eval_path}/events*`` and the compared runs
    from ``{eval_path}_*/events*``. Each cell of the result is
    ``(value - base) / base`` for one (neutralizer, target) pair.

    Args:
        eval_path: Path of the base evaluation directory (without trailing
            slash).
        keep_cols: Optional list of tags. When given, the result is reduced
            to exactly these rows and these columns.

    Returns:
        A ``pd.DataFrame`` sorted on both axes. Rows without a matching
        column, columns without a matching row (except ``"avg"``), and the
        ``"DEP"`` / ``"APPOS"`` labels are removed.
    """
    baseline = get_base_series(f"{eval_path}/events*")
    neutralized = get_xn_df(f"{eval_path}_*/events*")

    # Discard every target for which the base run has no value.
    missing = baseline.isnull()
    baseline = baseline[~missing]
    neutralized = neutralized.T[~missing].T

    acc_drop = (neutralized - baseline) / baseline
    acc_drop = acc_drop.sort_index(axis=0).sort_index(axis=1)

    row_labels = set(acc_drop.index)
    column_labels = set(acc_drop.columns)
    hidden = {"DEP", "APPOS"}

    # Keep only labels present on both axes ("avg" is allowed as a column
    # only), and always hide the labels listed in `hidden`.
    drop_indices = (row_labels - column_labels) | (hidden & row_labels)
    drop_columns = (column_labels - row_labels - {"avg"}) | (hidden & column_labels)

    acc_drop.drop(index=drop_indices, inplace=True)
    acc_drop.drop(columns=drop_columns, inplace=True)

    if keep_cols is None:
        return acc_drop
    return (acc_drop.loc[keep_cols])[keep_cols]


def get_experiments_df(task, treebank, model):
    """Summarise, per experiment, the accuracy drop of each tag on itself.

    Every directory matching
    ``../lightning_logs/{model}/{treebank}/{task}/*/evaluation`` is one
    experiment, named by the group captured by ``EXPERIMENT_REGEX``. For each
    experiment the diagonal of ``get_acc_drop`` (neutralizer == target) is
    collected, plus an ``"avg"`` entry holding the NaN-ignoring mean of those
    diagonal values.

    Args:
        task: Task directory name (e.g. ``"POS"``).
        treebank: Treebank directory name.
        model: Model directory name.

    Returns:
        A ``pd.DataFrame`` with experiments as rows (index named
        ``"Experiment"``), ordered by ascending ``"avg"``, and tags as columns
        (named ``"Neutralizer"``). Tags that are null for every experiment
        are removed.

    Raises:
        FileNotFoundError: If no evaluation directory matches the pattern.
    """
    pattern = f"../lightning_logs/{model}/{treebank}/{task}/*/evaluation"
    eval_paths = glob.glob(pattern)
    if not eval_paths:
        # Without this check the failure surfaces later as an opaque
        # KeyError: 'avg' raised by sort_values on an empty frame.
        raise FileNotFoundError(f"No evaluation directory matches {pattern!r}")

    experiments = {}
    for path in eval_paths:
        res = re.search(EXPERIMENT_REGEX, path)
        experiment = res.group(1)
        acc_drop = get_acc_drop(path)

        self_neutr = {}
        for tag in acc_drop.columns:
            # Skip the aggregate column by label rather than by assuming it
            # is the last column after sorting.
            if tag == "avg":
                continue
            tag = tag.upper()
            self_neutr[tag] = acc_drop.loc[tag, tag]

        self_neutr["avg"] = np.nanmean(list(self_neutr.values()))
        experiments[experiment] = self_neutr

    df = pd.DataFrame(experiments)
    # Order the experiments (columns at this point) by their "avg" row.
    df.sort_values(by="avg", axis=1, inplace=True)

    # Drop tags that have no value in any experiment.
    nulls = df.isnull().all(1)
    df = df[~nulls]

    df.index.name = "Neutralizer"
    df = df.T
    df.index.name = "Experiment"
    return df


def plot_heatmap(df, save_name=None, vmin=None, vmax=None, center=0.0):
    """Draw ``df`` as an annotated, square-celled seaborn heatmap.

    The values are multiplied by 100 before plotting and annotated with zero
    decimals. The annotation font size shrinks with the number of rows.

    Args:
        df: Table to plot.
        save_name: If truthy, the figure is saved to this file name.
        vmin: Lower bound of the colour scale (applies to the scaled values).
        vmax: Upper bound of the colour scale (applies to the scaled values).
        center: Value at which the diverging colour map is centred.
    """
    plt.figure(figsize=(20, 20), dpi=300)

    heatmap_options = dict(
        annot=True,
        fmt=".0f",
        cmap=sns.diverging_palette(20, 145, as_cmap=True),
        vmin=vmin,
        vmax=vmax,
        center=center,
        square=True,
        cbar_kws={"shrink": 0.75},
        annot_kws={"size": 80 / np.sqrt(len(df))},
    )
    axes = sns.heatmap(df * 100, **heatmap_options)

    # Re-apply the existing axis labels in bold.
    label_font = {"weight": "bold"}
    axes.set_xlabel(axes.get_xlabel(), fontdict=label_font)
    axes.set_ylabel(axes.get_ylabel(), fontdict=label_font)

    if not save_name:
        return
    plt.savefig(save_name)

In [None]:
# For every task / model / treebank combination, save two heatmaps:
#   1. the per-experiment overview returned by get_experiments_df, and
#   2. the full neutralizer-vs-target table of the first experiment in that
#      overview (the one with the smallest "avg").
for TASK in TASKS:
    for MODEL in MODELS:
        for TREE_BANK in TREE_BANKS:
            print(TASK, MODEL, TREE_BANK)

            experiments_df = get_experiments_df(TASK, TREE_BANK, MODEL)
            overview_file = f"{TASK}_{MODEL}_{TREE_BANK}.eps"
            plot_heatmap(experiments_df, save_name=overview_file)

            MODE = experiments_df.index[0]
            eval_path = (
                f"../lightning_logs/{MODEL}/{TREE_BANK}/{TASK}/{MODE}/evaluation"
            )
            keep_tags = KEEP_TAGS[TASK]
            acc_drop = get_acc_drop(eval_path, keep_tags)

            # Mark the file name when only a subset of the tags is plotted.
            sampled_suffix = "" if keep_tags is None else "_sampled"
            detail_file = (
                f"{TASK}_{MODEL}_{TREE_BANK}_acc_drop_{MODE}"
                f"{sampled_suffix}"
                ".eps"
            )
            plot_heatmap(acc_drop, save_name=detail_file, vmin=-100, vmax=100)