In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

import utils

In [None]:
# Labels retained for each probing task. The matching list is handed to
# utils.get_acc_drop together with that task's evaluation directory.
KEEP_TAGS = {
    "POS": [
        "NOUN",
        "ADJ",
        "VERB",
        "PRON",
        "DET",
        "NUM",
        "ADV",
        "AUX",
    ],
    "DEP": [
        "PUNCT",
        "NSUBJ",
        "OBJ",
        "OBL",
        "ADVCL",
        "CASE",
        "DET",
        "AMOD",
    ],
}

In [None]:
# Global plotting configuration shared by every figure in this notebook.
sns.set_style(style="white")

# Figure widths, in inches.
# NOTE(review): presumably the report's text width and column width -- confirm.
TEXTWIDTH = 6.30045
COLWIDTH = 3.03209

# Diverging colormap used by all heatmaps; anchor hues are 20 for the negative
# side and 145 for the positive side.
cmap = sns.diverging_palette(h_neg=20, h_pos=145, as_cmap=True)

# Font properties passed as `fontdict=` when setting axis labels.
label_dict = {"weight": "bold"}

## RoBERTa

Here we plot filtered cross-neutralization for the DEP and POS tasks on the en_gum dataset using embeddings from RoBERTa.

In [None]:
# Determine the best probing layer and aggregation function ("mode") for each
# task on en_gum with roberta-base: load the experiment results found under
# "lightning_logs", then let utils.select_best_mode pick one.
# pos_mode / dep_mode are interpolated into the evaluation paths in the next cell.
pos_experiments_df = utils.get_experiments_df(
    "POS", "en_gum", "roberta-base", "lightning_logs"
)
pos_mode = utils.select_best_mode(pos_experiments_df)
dep_experiments_df = utils.get_experiments_df(
    "DEP", "en_gum", "roberta-base", "lightning_logs"
)
dep_mode = utils.select_best_mode(dep_experiments_df)

In [None]:
# Compute the accuracy drop due to neutralisation for each task.
# Each evaluation directory is built from the mode chosen for that task, and the
# task's KEEP_TAGS list is passed along to utils.get_acc_drop.
# NOTE(review): presumably KEEP_TAGS restricts the result to those labels -- see
# utils.get_acc_drop for the exact semantics.
pos_eval_path = f"lightning_logs/roberta-base/en_gum/POS/{pos_mode}/evaluation"
pos_acc_drop = utils.get_acc_drop(pos_eval_path, KEEP_TAGS["POS"])

dep_eval_path = f"lightning_logs/roberta-base/en_gum/DEP/{dep_mode}/evaluation"
dep_acc_drop = utils.get_acc_drop(dep_eval_path, KEEP_TAGS["DEP"])

In [None]:
# Two stacked heatmaps of the RoBERTa accuracy drops (POS on top, DEP below),
# shown as percentages on a fixed -100..100 colour scale centred at 0.
sns.set_context("paper", font_scale=0.8)
f, (ax1, ax2) = plt.subplots(
    nrows=2, ncols=1, figsize=(COLWIDTH, 2 * COLWIDTH), dpi=300
)

# Keyword arguments common to both panels.
heatmap_kws = dict(
    annot=True,
    annot_kws={"fontsize": 7},
    fmt=".0f",
    cmap=cmap,
    cbar=False,
    vmin=-100,
    vmax=100,
    center=0,
    square=True,
)

# Top panel (POS): no x-axis label; x tick labels are moved to the top edge
# and rotated by 90 degrees, with the tick marks themselves hidden.
ax1 = sns.heatmap(pos_acc_drop * 100, ax=ax1, **heatmap_kws)
ax1.set_xlabel(None)
ax1.set_ylabel("POS " + ax1.get_ylabel(), fontdict=label_dict)
ax1.tick_params(
    axis="x",
    which="major",
    bottom=False,
    top=False,
    labelbottom=False,
    labeltop=True,
    labelrotation=90,
)

# Bottom panel (DEP): keeps its x-axis label, re-set in bold.
ax2 = sns.heatmap(dep_acc_drop * 100, ax=ax2, **heatmap_kws)
ax2.set_xlabel(ax2.get_xlabel(), fontdict=label_dict)
ax2.set_ylabel("DEP " + ax2.get_ylabel(), fontdict=label_dict)

# Reduce the vertical gap between the panels, then save and display the figure.
plt.subplots_adjust(hspace=0.05)
plt.savefig("images/roberta-base_multifigure_sampled.eps", bbox_inches="tight")
plt.show()

## XLM-R

Here we plot filtered cross-neutralization for the DEP and POS tasks on the en_gum, it_vit and el_gdt datasets using embeddings from XLM-R.

In [None]:
# One (initially empty) result dict per treebank; the same steps are repeated
# for three languages, so results are keyed by treebank name.
lang_dict = {treebank: {} for treebank in ("en_gum", "it_vit", "el_gdt")}

In [None]:
def get_acc_drop(treebank, model):
    """Return the POS and DEP accuracy drops for one treebank/model pair.

    Function form of the per-task steps used in the RoBERTa section: choose a
    mode (probing layer and aggregation function) with
    ``utils.select_best_mode``, then read that mode's evaluation directory
    with ``utils.get_acc_drop``.

    Args:
        treebank: treebank name as it appears in the log path, e.g. ``"en_gum"``.
        model: model name as it appears in the log path,
            e.g. ``"xlm-roberta-base"``.

    Returns:
        Tuple ``(pos_acc_drop, dep_acc_drop)`` holding whatever
        ``utils.get_acc_drop`` returns for ``KEEP_TAGS["POS"]`` and
        ``KEEP_TAGS["DEP"]`` respectively.
    """

    def _best_mode(task):
        # Experiment results for this task are looked up under "lightning_logs".
        experiments = utils.get_experiments_df(task, treebank, model, "lightning_logs")
        return utils.select_best_mode(experiments)

    pos_mode = _best_mode("POS")
    dep_mode = _best_mode("DEP")

    # Accuracy drop due to neutralisation, one evaluation directory per task.
    pos_acc_drop = utils.get_acc_drop(
        f"lightning_logs/{model}/{treebank}/POS/{pos_mode}/evaluation",
        KEEP_TAGS["POS"],
    )
    dep_acc_drop = utils.get_acc_drop(
        f"lightning_logs/{model}/{treebank}/DEP/{dep_mode}/evaluation",
        KEEP_TAGS["DEP"],
    )
    return pos_acc_drop, dep_acc_drop

In [None]:
# Fill lang_dict with the XLM-R accuracy drops of every treebank, printing each
# treebank name as it is processed.
# NOTE: this rebinds the notebook-level pos_acc_drop / dep_acc_drop names that
# the RoBERTa section also uses.
for key in lang_dict:
    print(key)
    pos_acc_drop, dep_acc_drop = get_acc_drop(key, "xlm-roberta-base")
    lang_dict[key].update({"POS": pos_acc_drop, "DEP": dep_acc_drop})

In [None]:
# 2x3 grid of XLM-R heatmaps built from lang_dict: one column per treebank,
# POS in the top row and DEP in the bottom row, shown as percentages on a
# fixed -100..100 colour scale centred at 0.
sns.set_context("paper", font_scale=0.8)
f, (pos_row, dep_row) = plt.subplots(
    nrows=2, ncols=3, figsize=(TEXTWIDTH, 2 * (TEXTWIDTH / 3)), dpi=300
)

# Keyword arguments shared by every panel.
heatmap_kws = dict(
    annot=True,
    annot_kws={"fontsize": 6},
    fmt=".0f",
    cmap=cmap,
    cbar=False,
    vmin=-100,
    vmax=100,
    center=0,
    square=True,
)

for i, (treebank, lang) in enumerate(lang_dict.items()):
    # Only the leftmost column keeps its y tick labels and y-axis labels.
    first_col = i == 0
    pos_acc_drop, dep_acc_drop = lang["POS"], lang["DEP"]

    # Top row (POS): x tick labels on the top edge, rotated, no x-axis label.
    ax_pos = sns.heatmap(
        pos_acc_drop * 100, ax=pos_row[i], yticklabels=first_col, **heatmap_kws
    )
    ax_pos.tick_params(
        axis="x",
        which="major",
        bottom=False,
        top=False,
        labelbottom=False,
        labeltop=True,
        labelrotation=90,
    )
    ax_pos.set_xlabel(None)

    # Bottom row (DEP): bold x-axis label; the treebank name is set as the
    # title of this panel.
    ax_dep = sns.heatmap(
        dep_acc_drop * 100, ax=dep_row[i], yticklabels=first_col, **heatmap_kws
    )
    ax_dep.set_xlabel(ax_dep.get_xlabel(), fontdict=label_dict)
    ax_dep.set_title(treebank)

    if first_col:
        ax_pos.set_ylabel("POS " + ax_pos.get_ylabel(), fontdict=label_dict)
        ax_dep.set_ylabel("DEP " + ax_dep.get_ylabel(), fontdict=label_dict)
    else:
        ax_pos.set_ylabel(None)
        ax_dep.set_ylabel(None)

# Adjust panel spacing, then save and display the figure.
plt.subplots_adjust(wspace=0.005, hspace=0.15)
plt.savefig("images/xlm_multifigure_sampled.eps", bbox_inches="tight")
plt.show()

## Cross-Lingual Cross-Neutralisation

Here we plot filtered cross-neutralization for the POS task when neutralising Italian using English embeddings and when neutralising Greek using English embeddings from XLM-R.

In [None]:
# Probing layer and aggregation function ("mode") for the cross-lingual POS
# evaluations (Italian and Greek, each "from en_gum") with xlm-roberta-base.
# The experiment tables are loaded, but the two modes below are hard-coded
# strings rather than being derived from those tables.
# NOTE(review): it_experiments_df / gr_experiments_df are not read again in this
# notebook -- confirm whether the modes should come from
# utils.select_best_mode(...) as in the other sections, or whether the loads
# are leftovers.
it_experiments_df = utils.get_experiments_df(
    "POS", "it_vit_from_en_gum", "xlm-roberta-base", "lightning_logs"
)
it_mode = "agg=first_probe=9"
gr_experiments_df = utils.get_experiments_df(
    "POS", "el_gdt_from_en_gum", "xlm-roberta-base", "lightning_logs"
)
gr_mode = "agg=max_probe=9"

In [None]:
# Compute the accuracy drop due to neutralisation for the two cross-lingual POS
# settings. Each evaluation directory is built from the corresponding mode
# string (it_mode / gr_mode); both calls pass the POS tag list from KEEP_TAGS.
it_eval_path = (
    f"lightning_logs/xlm-roberta-base/it_vit_from_en_gum/POS/{it_mode}/evaluation"
)
it_acc_drop = utils.get_acc_drop(it_eval_path, KEEP_TAGS["POS"])

gr_eval_path = (
    f"lightning_logs/xlm-roberta-base/el_gdt_from_en_gum/POS/{gr_mode}/evaluation"
)
gr_acc_drop = utils.get_acc_drop(gr_eval_path, KEEP_TAGS["POS"])

In [None]:
# Two stacked heatmaps of the cross-lingual POS accuracy drops (Italian on top,
# Greek below), shown as percentages on a fixed -100..100 colour scale
# centred at 0.
sns.set_context("paper", font_scale=0.8)
f, (ax1, ax2) = plt.subplots(
    nrows=2, ncols=1, figsize=(COLWIDTH, 2 * COLWIDTH), dpi=300
)

# Keyword arguments common to both panels.
heatmap_kws = dict(
    annot=True,
    annot_kws={"fontsize": 7},
    fmt=".0f",
    cmap=cmap,
    cbar=False,
    vmin=-100,
    vmax=100,
    center=0,
    square=True,
)

# Top panel (Italian): no x-axis label; x tick labels are moved to the top
# edge and rotated by 90 degrees, with the tick marks themselves hidden.
ax1 = sns.heatmap(it_acc_drop * 100, ax=ax1, **heatmap_kws)
ax1.set_xlabel(None)
ax1.set_ylabel("Italian Neutralization by English", fontdict=label_dict)
ax1.tick_params(
    axis="x",
    which="major",
    bottom=False,
    top=False,
    labelbottom=False,
    labeltop=True,
    labelrotation=90,
)

# Bottom panel (Greek): keeps its x-axis label, re-set in bold.
ax2 = sns.heatmap(gr_acc_drop * 100, ax=ax2, **heatmap_kws)
ax2.set_xlabel(ax2.get_xlabel(), fontdict=label_dict)
ax2.set_ylabel("Greek Neutralization by English", fontdict=label_dict)

# Reduce the vertical gap between the panels, then save and display the figure.
plt.subplots_adjust(hspace=0.05)
plt.savefig("images/cross-lingual_pos_multifigure_sampled.eps", bbox_inches="tight")
plt.show()

Let's also plot them separately, in case we only wish to include one in the report.

In [None]:
# Stand-alone Italian heatmap (cross-lingual POS accuracy drop, in percent).
# This cell does not call sns.set_context itself, so it uses whatever context
# an earlier plotting cell left in place.
f, ax1 = plt.subplots(figsize=(COLWIDTH, COLWIDTH), dpi=300)

heatmap_kws = dict(
    annot=True,
    annot_kws={"fontsize": 7},
    fmt=".0f",
    cmap=cmap,
    cbar=False,
    vmin=-100,
    vmax=100,
    center=0,
    square=True,
)
ax1 = sns.heatmap(it_acc_drop * 100, ax=ax1, **heatmap_kws)

# Bold axis labels; the y label names the setting shown.
ax1.set_xlabel(ax1.get_xlabel(), fontdict=label_dict)
ax1.set_ylabel("Italian Neutralization by English", fontdict=label_dict)

plt.savefig("images/cross-lingual_pos_italian_sampled.eps", bbox_inches="tight")
plt.show()

In [None]:
# Stand-alone Greek heatmap (cross-lingual POS accuracy drop, in percent).
# This cell does not call sns.set_context itself, so it uses whatever context
# an earlier plotting cell left in place.
f, ax2 = plt.subplots(figsize=(COLWIDTH, COLWIDTH), dpi=300)

heatmap_kws = dict(
    annot=True,
    annot_kws={"fontsize": 7},
    fmt=".0f",
    cmap=cmap,
    cbar=False,
    vmin=-100,
    vmax=100,
    center=0,
    square=True,
)
ax2 = sns.heatmap(gr_acc_drop * 100, ax=ax2, **heatmap_kws)

# Bold axis labels; the y label names the setting shown.
ax2.set_xlabel(ax2.get_xlabel(), fontdict=label_dict)
ax2.set_ylabel("Greek Neutralization by English", fontdict=label_dict)

plt.savefig("images/cross-lingual_pos_greek_sampled.eps", bbox_inches="tight")
plt.show()

## XLM-R

Here we plot filtered cross-neutralization for the DEP and POS tasks on the en_gum, it_vit and el_gdt datasets using embeddings from XLM-R.

In [None]:
# One (initially empty) result dict per treebank; the same steps are repeated
# for three languages, so results are keyed by treebank name.
lang_dict = {treebank: {} for treebank in ("en_gum", "it_vit", "el_gdt")}

In [None]:
def get_acc_drop(treebank, model):
    """Return the POS and DEP accuracy drops for one treebank/model pair.

    The mode (probing layer and aggregation function) of each task is chosen
    with ``utils.select_best_mode``, exactly as in the RoBERTa section and in
    the identically named function defined earlier in this notebook. This
    definition shadows that earlier one, so it must select modes the same way;
    taking the first index entry of the experiment table instead would make
    the XLM-R figure depend on which of the two definitions ran last.

    Args:
        treebank: treebank name as it appears in the log path, e.g. ``"en_gum"``.
        model: model name as it appears in the log path,
            e.g. ``"xlm-roberta-base"``.

    Returns:
        Tuple ``(pos_acc_drop, dep_acc_drop)`` holding whatever
        ``utils.get_acc_drop`` returns for ``KEEP_TAGS["POS"]`` and
        ``KEEP_TAGS["DEP"]`` respectively.
    """
    # determining best probing layer and aggregation function
    pos_experiments_df = utils.get_experiments_df(
        "POS", treebank, model, "lightning_logs"
    )
    pos_mode = utils.select_best_mode(pos_experiments_df)
    dep_experiments_df = utils.get_experiments_df(
        "DEP", treebank, model, "lightning_logs"
    )
    dep_mode = utils.select_best_mode(dep_experiments_df)

    # computing accuracy drop due to neutralisation
    pos_eval_path = f"lightning_logs/{model}/{treebank}/POS/{pos_mode}/evaluation"
    pos_acc_drop = utils.get_acc_drop(pos_eval_path, KEEP_TAGS["POS"])

    dep_eval_path = f"lightning_logs/{model}/{treebank}/DEP/{dep_mode}/evaluation"
    dep_acc_drop = utils.get_acc_drop(dep_eval_path, KEEP_TAGS["DEP"])

    return pos_acc_drop, dep_acc_drop

In [None]:
# Fill lang_dict with the XLM-R accuracy drops of every treebank, printing each
# treebank name as it is processed.
# NOTE: this rebinds the notebook-level pos_acc_drop / dep_acc_drop names that
# the RoBERTa section also uses.
for key in lang_dict:
    print(key)
    pos_acc_drop, dep_acc_drop = get_acc_drop(key, "xlm-roberta-base")
    lang_dict[key].update({"POS": pos_acc_drop, "DEP": dep_acc_drop})

In [None]:
# 2x3 grid of XLM-R heatmaps built from lang_dict: one column per treebank,
# POS in the top row and DEP in the bottom row, shown as percentages on a
# fixed -100..100 colour scale centred at 0.
# NOTE(review): the output path below is the same one the earlier XLM-R figure
# cell writes to, so running this cell overwrites that file -- confirm intended.
sns.set_context("paper", font_scale=0.8)
f, (pos_row, dep_row) = plt.subplots(
    nrows=2, ncols=3, figsize=(TEXTWIDTH, 2 * (TEXTWIDTH / 3)), dpi=300
)

# Keyword arguments shared by every panel.
heatmap_kws = dict(
    annot=True,
    annot_kws={"fontsize": 6},
    fmt=".0f",
    cmap=cmap,
    cbar=False,
    vmin=-100,
    vmax=100,
    center=0,
    square=True,
)

for i, (treebank, lang) in enumerate(lang_dict.items()):
    # Only the leftmost column keeps its y tick labels and y-axis labels.
    first_col = i == 0
    pos_acc_drop, dep_acc_drop = lang["POS"], lang["DEP"]

    # Top row (POS): x tick labels on the top edge, rotated, no x-axis label.
    ax_pos = sns.heatmap(
        pos_acc_drop * 100, ax=pos_row[i], yticklabels=first_col, **heatmap_kws
    )
    ax_pos.tick_params(
        axis="x",
        which="major",
        bottom=False,
        top=False,
        labelbottom=False,
        labeltop=True,
        labelrotation=90,
    )
    ax_pos.set_xlabel(None)

    # Bottom row (DEP): bold x-axis label; the treebank name is set as the
    # title of this panel.
    ax_dep = sns.heatmap(
        dep_acc_drop * 100, ax=dep_row[i], yticklabels=first_col, **heatmap_kws
    )
    ax_dep.set_xlabel(ax_dep.get_xlabel(), fontdict=label_dict)
    ax_dep.set_title(treebank)

    if first_col:
        ax_pos.set_ylabel("POS " + ax_pos.get_ylabel(), fontdict=label_dict)
        ax_dep.set_ylabel("DEP " + ax_dep.get_ylabel(), fontdict=label_dict)
    else:
        ax_pos.set_ylabel(None)
        ax_dep.set_ylabel(None)

# Adjust panel spacing, then save and display the figure.
plt.subplots_adjust(wspace=0.005, hspace=0.15)
plt.savefig("images/xlm_multifigure_sampled.eps", bbox_inches="tight")
plt.show()

## Cross-Lingual Cross-Neutralisation

TODO