In [None]:
# reload imported modules before each cell runs, so edits to local code
# are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

import utils

In [None]:
# tags retained for each task
KEEP_TAGS = dict(
    POS="NOUN ADJ VERB PRON DET NUM ADV AUX".split(),
    DEP="PUNCT NSUBJ OBJ OBL ADVCL CASE DET AMOD".split(),
)

In [None]:
# plotting config
sns.set_style("white")
# figure widths, in inches
TEXTWIDTH, COLWIDTH = 6.30045, 3.03209
# diverging colormap used for the heatmaps
cmap = sns.diverging_palette(h_neg=20, h_pos=145, as_cmap=True)
# font settings applied to axis labels
label_dict = dict(weight="bold")

## RoBERTa

Here we plot filtered cross-neutralization for the DEP and POS tasks on the en_gum dataset using embeddings from RoBERTa.

In [None]:
# pick the probing mode for each task with utils.select_best_mode
roberta_run = ("en_gum", "roberta-base", "lightning_logs")

pos_experiments_df = utils.get_experiments_df("POS", *roberta_run)
pos_mode = utils.select_best_mode(pos_experiments_df)

dep_experiments_df = utils.get_experiments_df("DEP", *roberta_run)
dep_mode = utils.select_best_mode(dep_experiments_df)

In [None]:
# accuracy drop due to neutralisation, read from each task's evaluation directory
roberta_eval_dir = "lightning_logs/roberta-base/en_gum/{}/{}/evaluation"

pos_eval_path = roberta_eval_dir.format("POS", pos_mode)
pos_acc_drop = utils.get_acc_drop(pos_eval_path, KEEP_TAGS["POS"])

dep_eval_path = roberta_eval_dir.format("DEP", dep_mode)
dep_acc_drop = utils.get_acc_drop(dep_eval_path, KEEP_TAGS["DEP"])

In [None]:
# plot: POS heatmap on top, DEP heatmap below, on a single column-wide figure
sns.set_context("paper", font_scale=0.8)
f, (ax1, ax2) = plt.subplots(
    nrows=2, ncols=1, figsize=(COLWIDTH, 2 * COLWIDTH), dpi=300
)

# keyword arguments common to both heatmaps
heatmap_kws = dict(
    annot=True,
    annot_kws={"fontsize": 7},
    fmt=".0f",
    cmap=cmap,
    cbar=False,
    vmin=-100,
    vmax=100,
    center=0,
    square=True,
)

# top panel: no x-axis title, tick labels moved above the heatmap
sns.heatmap(pos_acc_drop * 100, ax=ax1, **heatmap_kws)
ax1.set_xlabel(None)
ax1.set_ylabel(f"POS {ax1.get_ylabel()}", fontdict=label_dict)
ax1.tick_params(
    axis="x",
    which="major",
    bottom=False,
    top=False,
    labelbottom=False,
    labeltop=True,
    labelrotation=90,
)

# bottom panel: keeps its x-axis title, set in bold
sns.heatmap(dep_acc_drop * 100, ax=ax2, **heatmap_kws)
ax2.set_xlabel(ax2.get_xlabel(), fontdict=label_dict)
ax2.set_ylabel(f"DEP {ax2.get_ylabel()}", fontdict=label_dict)

# narrow the gap between the panels, then save and display the figure
plt.subplots_adjust(hspace=0.05)
plt.savefig("images/roberta-base_multifigure_sampled.eps", bbox_inches="tight")
plt.show()

## XLM-R

Here we plot filtered cross-neutralization for the DEP and POS tasks on the en_gum, it_vit and el_gdt datasets using embeddings from XLM-R.

In [None]:
# the same steps are repeated for three treebanks, so keep one results dict per treebank
lang_dict = {name: {} for name in ("en_gum", "it_vit", "el_gdt")}

In [None]:
# the logic of the previous section, wrapped in a function
def get_acc_drop(treebank, model):
    """Return the POS and DEP accuracy drops for one treebank/model pair.

    For each task, the probing mode is chosen with ``utils.select_best_mode``
    from the experiments found by ``utils.get_experiments_df`` under
    ``lightning_logs``. The accuracy drop is then obtained from
    ``utils.get_acc_drop`` for that mode's ``evaluation`` directory, passing
    ``KEEP_TAGS[task]`` (looked up in the notebook namespace at call time).

    Args:
        treebank: treebank name as it appears in the log path (e.g. "en_gum").
        model: model name as it appears in the log path.

    Returns:
        A ``(pos_acc_drop, dep_acc_drop)`` tuple of whatever
        ``utils.get_acc_drop`` returns for each task.
    """
    task_names = ("POS", "DEP")

    # first settle the probing mode of every task ...
    best_modes = {
        task: utils.select_best_mode(
            utils.get_experiments_df(task, treebank, model, "lightning_logs")
        )
        for task in task_names
    }

    # ... then fetch the accuracy drop for each of those modes
    drops = [
        utils.get_acc_drop(
            f"lightning_logs/{model}/{treebank}/{task}/{best_modes[task]}/evaluation",
            KEEP_TAGS[task],
        )
        for task in task_names
    ]
    return tuple(drops)

In [None]:
# fill in the XLM-R results for every treebank
for key in lang_dict:
    print(key)
    pos_acc_drop, dep_acc_drop = get_acc_drop(key, "xlm-roberta-base")
    lang_dict[key].update(POS=pos_acc_drop, DEP=dep_acc_drop)

In [None]:
# draw a 2x3 grid from the dict: POS heatmaps in the top row, DEP heatmaps in
# the bottom row, one column per treebank
sns.set_context("paper", font_scale=0.8)
f, (pos_row, dep_row) = plt.subplots(
    nrows=2, ncols=3, figsize=(TEXTWIDTH, (TEXTWIDTH / 3) * 2), dpi=300
)

# keyword arguments common to every heatmap in the grid
grid_heatmap_kws = dict(
    annot=True,
    annot_kws={"fontsize": 6},
    fmt=".0f",
    cmap=cmap,
    cbar=False,
    vmin=-100,
    vmax=100,
    center=0,
    square=True,
)

for i, (treebank, lang) in enumerate(lang_dict.items()):
    ax_pos, ax_dep = pos_row[i], dep_row[i]
    pos_acc_drop, dep_acc_drop = lang["POS"], lang["DEP"]
    # only the leftmost column carries y tick labels and y-axis titles
    leftmost = i == 0

    # top row: no x-axis title, tick labels moved above the heatmap
    sns.heatmap(
        pos_acc_drop * 100, ax=ax_pos, yticklabels=leftmost, **grid_heatmap_kws
    )
    ax_pos.tick_params(
        axis="x",
        which="major",
        bottom=False,
        top=False,
        labelbottom=False,
        labeltop=True,
        labelrotation=90,
    )
    ax_pos.set_xlabel(None)

    # bottom row: bold x-axis title, treebank name as the panel title
    sns.heatmap(
        dep_acc_drop * 100, ax=ax_dep, yticklabels=leftmost, **grid_heatmap_kws
    )
    ax_dep.set_xlabel(ax_dep.get_xlabel(), fontdict=label_dict)
    ax_dep.set_title(treebank)

    if leftmost:
        ax_pos.set_ylabel(f"POS {ax_pos.get_ylabel()}", fontdict=label_dict)
        ax_dep.set_ylabel(f"DEP {ax_dep.get_ylabel()}", fontdict=label_dict)
    else:
        ax_pos.set_ylabel(None)
        ax_dep.set_ylabel(None)

# tighten the spacing between panels, then save and display the figure
plt.subplots_adjust(wspace=0.005, hspace=0.15)
plt.savefig("images/xlm_multifigure_sampled.eps", bbox_inches="tight")
plt.show()

## Cross-Lingual Cross-Neutralisation

Here we plot filtered cross-neutralization for the POS and DEP tasks when neutralising Italian embeddings using English centroids, all from XLM-R.

### Main Text

In [None]:
# probing mode strings per task, keyed by the first treebank of each
# "<a>_from_<b>" pair; both pairs sharing a first treebank use the same modes
_modes_by_first_treebank = {
    "en_gum": {
        "POS": "agg=max_probe=9",
        "DEP": "agg=first_probe=9_concat-mode=ONLY",
    },
    "it_vit": {
        "POS": "agg=first_probe=9",
        "DEP": "agg=mean_probe=9_concat-mode=ONLY",
    },
    "el_gdt": {
        "POS": "agg=max_probe=9",
        "DEP": "agg=mean_probe=9_concat-mode=ONLY",
    },
}
# one entry per ordered pair of distinct treebanks, each with its own copy of the modes
probing_config = {
    f"{first}_from_{second}": dict(modes)
    for first, modes in _modes_by_first_treebank.items()
    for second in _modes_by_first_treebank
    if second != first
}
# language name shown on the plots for each treebank
treebank2lang = dict(en_gum="English", it_vit="Italian", el_gdt="Greek")

In [None]:
# pair plotted in the main text; keys of probing_config read "<target>_from_<neutralizer>"
# (the appendix loop splits them on "_from_" in that order)
treebank = "it_vit_from_en_gum"

In [None]:
# load the experiment tables for this pair; unlike the earlier sections, the
# probing modes are read from the hand-written probing_config rather than
# selected with utils.select_best_mode
xlmr_pair_run = (treebank, "xlm-roberta-base", "lightning_logs")

pos_experiments_df = utils.get_experiments_df("POS", *xlmr_pair_run)
pos_mode = probing_config[treebank]["POS"]

dep_experiments_df = utils.get_experiments_df("DEP", *xlmr_pair_run)
dep_mode = probing_config[treebank]["DEP"]

In [None]:
# use slightly different keep tags: compared with the KEEP_TAGS defined at the
# top of the notebook, the DEP list has CONJ in place of PUNCT.
# NOTE: this rebinds the notebook-level KEEP_TAGS, so every cell executed after
# this one (including re-runs of earlier cells) sees these tags.
KEEP_TAGS = dict(
    POS="NOUN ADJ VERB PRON DET NUM ADV AUX".split(),
    DEP="CONJ NSUBJ OBJ OBL ADVCL CASE DET AMOD".split(),
)

In [None]:
# accuracy drop due to neutralisation, read from each task's evaluation directory
pair_eval_dir = "lightning_logs/xlm-roberta-base/{}/{}/{}/evaluation"

pos_eval_path = pair_eval_dir.format(treebank, "POS", pos_mode)
pos_acc_drop = utils.get_acc_drop(pos_eval_path, KEEP_TAGS["POS"])

dep_eval_path = pair_eval_dir.format(treebank, "DEP", dep_mode)
dep_acc_drop = utils.get_acc_drop(dep_eval_path, KEEP_TAGS["DEP"])

In [None]:
# plot: POS heatmap on top, DEP heatmap below, on a single column-wide figure
sns.set_context("paper", font_scale=0.8)
f, (ax1, ax2) = plt.subplots(
    nrows=2, ncols=1, figsize=(COLWIDTH, 2 * COLWIDTH), dpi=300
)

# keyword arguments common to both heatmaps
heatmap_kws = dict(
    annot=True,
    annot_kws={"fontsize": 7},
    fmt=".0f",
    cmap=cmap,
    cbar=False,
    vmin=-100,
    vmax=100,
    center=0,
    square=True,
)

# top panel: no x-axis title, tick labels moved above the heatmap
sns.heatmap(pos_acc_drop * 100, ax=ax1, **heatmap_kws)
ax1.set_xlabel(None)
ax1.set_ylabel(f"English POS {ax1.get_ylabel()}", fontdict=label_dict)
ax1.tick_params(
    axis="x",
    which="major",
    bottom=False,
    top=False,
    labelbottom=False,
    labeltop=True,
    labelrotation=90,
)

# bottom panel: axis titles are prefixed with the language of each axis
sns.heatmap(dep_acc_drop * 100, ax=ax2, **heatmap_kws)
ax2.set_xlabel(f"Italian {ax2.get_xlabel()}", fontdict=label_dict)
ax2.set_ylabel(f"English DEP {ax2.get_ylabel()}", fontdict=label_dict)

# narrow the gap between the panels, then save and display the figure
plt.subplots_adjust(hspace=0.05)
plt.savefig(
    "images/cross-lingual_it_from_en_multifigure_sampled.eps", bbox_inches="tight"
)
plt.show()

### Appendix

We now plot the complete (unsampled) cross-lingual cross-neutralisation results for each language combination, for inclusion in the report.

In [None]:
# re-usable plotting function
def plot_heatmap(
    df,
    neutralizer,
    target,
    save_name=None,
    vmin=None,
    vmax=None,
    center=0.0,
    cbar=True,
    annot_kws=None,
):
    """Draw ``df * 100`` as an annotated square heatmap on a new figure.

    The figure is ``TEXTWIDTH`` x ``TEXTWIDTH`` inches at 300 dpi and uses the
    same diverging palette as the rest of the notebook. The axis titles that
    seaborn sets are kept and prefixed: ``target`` on the x axis and
    ``neutralizer`` on the y axis, both in bold.

    Args:
        df: values to plot; multiplied by 100 before being passed to
            ``sns.heatmap``.
        neutralizer: text prepended to the y-axis title.
        target: text prepended to the x-axis title.
        save_name: if truthy, path the figure is saved to (tight bounding box).
        vmin, vmax, center, cbar: forwarded unchanged to ``sns.heatmap``.
        annot_kws: keyword arguments for the cell annotations; defaults to
            ``{"fontsize": 7}`` when ``None``. The dict is copied, so the
            caller's object is never modified.

    Returns:
        The matplotlib axes the heatmap was drawn on.
    """
    # a None default (instead of a dict literal in the signature) avoids sharing
    # one mutable dict between calls; always work on a private copy
    annot_kws = {"fontsize": 7} if annot_kws is None else dict(annot_kws)
    bold = {"weight": "bold"}

    # explicit figure/axes rather than relying on pyplot's "current figure"
    fig, ax = plt.subplots(figsize=(TEXTWIDTH, TEXTWIDTH), dpi=300)
    cmap = sns.diverging_palette(20, 145, as_cmap=True)
    sns.heatmap(
        df * 100,
        annot=True,
        fmt=".0f",
        cmap=cmap,
        cbar=cbar,
        vmin=vmin,
        vmax=vmax,
        center=center,
        square=True,
        annot_kws=annot_kws,
        ax=ax,
    )
    # read the titles only after seaborn has drawn, so its labels are the ones prefixed
    ax.set_xlabel(target + " " + ax.get_xlabel(), fontdict=bold)
    ax.set_ylabel(neutralizer + " " + ax.get_ylabel(), fontdict=bold)

    if save_name:
        fig.savefig(save_name, bbox_inches="tight")

    return ax

In [None]:
# one full-size heatmap per pair and task; each key reads "<target>_from_<neutralizer>"
for treebank in probing_config.keys():
    target_tb, neutralizer_tb = treebank.split("_from_")
    target = treebank2lang[target_tb]
    neutralizer = treebank2lang[neutralizer_tb]
    for task in ["POS", "DEP"]:
        print(task, treebank)
        # pass the log directory explicitly, like every other call in this notebook,
        # instead of depending on whatever default utils.get_experiments_df may have
        experiments_df = utils.get_experiments_df(
            task, treebank, "xlm-roberta-base", "lightning_logs"
        )
        # the probing mode comes from the hand-written configuration
        mode = probing_config[treebank][task]
        eval_path = (
            f"lightning_logs/xlm-roberta-base/{treebank}/{task}/{mode}/evaluation"
        )
        # None in place of a tag list, unlike the main-text plots which pass KEEP_TAGS
        acc_drop = utils.get_acc_drop(eval_path, None)
        plot_heatmap(
            acc_drop,
            neutralizer,
            target,
            save_name=f"experiments/{task}-crosslingual-{treebank}-accdrop.eps",
            vmin=-100,
            vmax=100,
            cbar=False,
        )

## Best Probing Configurations

We now report what we found to be the best probing configuration for each language, model and task combination.

In [None]:
tasks = ["POS", "DEP"]
models = ["roberta-base", "xlm-roberta-base"]
treebanks = ["en_gum", "it_vit", "el_gdt"]

# best mode per task -> model -> treebank; every slot starts as None and
# roberta-base only has an en_gum slot
data = {
    task_name: {
        "roberta-base": {"en_gum": None},
        "xlm-roberta-base": dict.fromkeys(treebanks),
    }
    for task_name in tasks
}

for task in tasks:
    for model in models:
        for treebank in treebanks:
            # only treebanks that have a slot for this model are looked up
            if treebank in data[task][model]:
                experiments_df = utils.get_experiments_df(
                    task, treebank, model, "lightning_logs"
                )
                mode = utils.select_best_mode(experiments_df)
                print(task, model, treebank, mode)
                data[task][model][treebank] = mode

In [None]:
import pandas as pd

In [None]:
def parse_config_str(config):
    """Turn a probing-mode string into a human-readable summary.

    ``config`` is split on ``_`` into ``key=value`` fields. The value of the
    first field is reported as the aggregation and the value of the second as
    the layer; when there are exactly three fields, the value of the third is
    reported as the concatenation mode. For example
    ``"agg=first_probe=9_concat-mode=ONLY"`` becomes
    ``"Aggregation: first; layer: 9; Concatenation: ONLY"``.

    Args:
        config: the mode string, or any non-string placeholder (e.g. ``None``
            or ``NaN`` for a missing table cell).

    Returns:
        The summary string, or ``"-"`` when ``config`` is not a string.

    Raises:
        ValueError: if the string has fewer than two ``_``-separated fields.
        IndexError: if one of the first two fields (or the third, when there
            are exactly three) contains no ``=``.
    """
    # isinstance (not an exact type comparison) so that str subclasses are
    # parsed as well instead of being reported as missing
    if not isinstance(config, str):
        return "-"
    info = config.split("_")
    agg, layer = [el.split("=")[1] for el in info[:2]]
    result = f"Aggregation: {agg}; layer: {layer}"
    if len(info) == 3:
        concat = info[2].split("=")[1]
        result += f"; Concatenation: {concat}"
    return result

In [None]:
# flatten task -> model -> {treebank: mode} into one column per (task, model) pair
pandas_data = {
    (task, model): modes_by_treebank
    for task, models_for_task in data.items()
    for model, modes_by_treebank in models_for_task.items()
}
config_table = pd.DataFrame(pandas_data)
# DataFrame.applymap is deprecated since pandas 2.1 in favour of DataFrame.map;
# use whichever one the installed pandas provides
elementwise = (
    config_table.map if hasattr(config_table, "map") else config_table.applymap
)
# last expression of the cell, so the formatted table is displayed
elementwise(parse_config_str)