In [1]:
import pandas as pd
import numpy as np
import tqdm
import matplotlib.pyplot as plt

import gensim
import nltk
import textdistance
from nltk.stem import WordNetLemmatizer
from scipy.stats import pearsonr, kendalltau, spearmanr
import seaborn as sns

import rsa, data_utils, model_utils, representations

KeyboardInterrupt: 

In [None]:
sentences = data_utils.get_noun_noun_compound_sentences(data_loc='../data')
mod_head_words_per_sentence = data_utils.get_noun_noun_mod_head_words_per_sentence(data_loc='../data')
corrected_form_compounds_per_sentence = data_utils.load_corrected_form_compounds_per_sentence(data_loc='../data')

In [None]:
len(sentences)

In [None]:
# paraphrase_ind_tuples = [[i, i+300, i+600] for i in range(300)]
# paraphrase_inds = [item for sublist in paraphrase_ind_tuples for item in sublist]

# ordered_sentences = sentences[paraphrase_inds]

In [None]:
models = list(model_utils.dev_model_configs.keys())

In [None]:
models

## Load data

In [None]:
# Tokenise every sentence with each model's tokeniser and record, per sentence,
# the un-padded input ids, the decoded tokens, and the locations of the head
# noun and modifier word tokens.
# NOTE(review): with load=True the loop body is a no-op, so the four dicts
# below stay empty.
load = False

# All four dicts are keyed by model name.
model_input_id_dict = {}
model_tokens_dict = {}
model_head_noun_locs_per_sent = {}
model_mod_word_locs_per_sent = {}

for model_name in tqdm.tqdm(models):
# for model_name in [model_name]:
    if load:
        pass
    else:
        # Only the tokeniser is used below; the returned model is not referenced in this cell.
        model, tokeniser = model_utils.load_model(model_name)

        # Encode all sentences at once, padded to max_length=512.
        unpack_dict = lambda x: (x['input_ids'], x['attention_mask'])
        input_ids, attention_mask = unpack_dict(tokeniser.batch_encode_plus(sentences, max_length=512, return_tensors='pt', pad_to_max_length=True))

        # Helpers that drop padding positions from one row of input ids.
        # NOTE: get_tokens_to_keep is rebound further down in the non-xlm branch;
        # the uses just below run before that rebinding on every iteration.
        pad_token_mask = lambda x: np.array(x.cpu() == tokeniser.pad_token_id)
        get_tokens_to_keep = lambda x: np.argwhere((pad_token_mask(x) == False)).reshape(-1)
        # Decode each kept id on its own and strip spaces from the decoded text.
        decode_tokens = lambda x: [tokeniser.decode(token).replace(' ', '') for token in x[get_tokens_to_keep(x)].tolist()]

        input_ids_to_keep = [x[get_tokens_to_keep(x)] for x in input_ids]
        model_input_id_dict[model_name] = input_ids_to_keep

        model_tokens_dict[model_name] = [decode_tokens(x) for x in input_ids]

        # Column 0 is used as the modifier word and column 1 as the head noun.
        compounds_per_sample = np.array(corrected_form_compounds_per_sentence)
        if 'xlm' in model_name:
            head_noun_input_ids_per_sent_raw = np.array(tokeniser.batch_encode_plus(compounds_per_sample[:, 1])['input_ids'])
            mod_word_input_ids_per_sent_raw = np.array(tokeniser.batch_encode_plus(compounds_per_sample[:, 0])['input_ids'])

            # Drop the first and last id of every encoding (presumably the special
            # start/end tokens — TODO confirm) and keep the rest as a column vector.
            head_noun_input_ids_per_sent = [np.array(x[1:-1]).reshape(-1, 1) for x in head_noun_input_ids_per_sent_raw]
            mod_word_input_ids_per_sent = [np.array(x[1:-1]).reshape(-1, 1) for x in mod_word_input_ids_per_sent_raw]
        else:
            # Words are encoded with a leading space (presumably so that the ids
            # match the word as it is tokenised mid-sentence — TODO confirm).
            head_noun_input_ids_per_sent_raw = np.array(tokeniser.batch_encode_plus([' ' + x for x in compounds_per_sample[:, 1]])['input_ids'])
            mod_word_input_ids_per_sent_raw = np.array(tokeniser.batch_encode_plus([' ' + x for x in compounds_per_sample[:, 0]])['input_ids'])

            # Remove special tokens
            non_special_token_mask = lambda x: np.array(tokeniser.get_special_tokens_mask(x, already_has_special_tokens=True)) == 0
            get_tokens_to_keep = lambda x: np.argwhere(non_special_token_mask(x)).reshape(-1)
            head_noun_input_ids_per_sent = [np.array(x)[get_tokens_to_keep(x)] for x in head_noun_input_ids_per_sent_raw]
            mod_word_input_ids_per_sent = [np.array(x)[get_tokens_to_keep(x)] for x in mod_word_input_ids_per_sent_raw]

        # Search each sentence's un-padded ids for the word's ids.
        # NOTE(review): search_sequence_numpy is an external helper; presumably it
        # returns the positions of the matched sub-sequence — confirm.
        head_noun_locs_per_sent = [representations.search_sequence_numpy(input_ids_to_keep[i].cpu().numpy().reshape(-1), x.reshape(-1)) for i, x in enumerate(head_noun_input_ids_per_sent)]
        mod_word_locs_per_sent = [representations.search_sequence_numpy(input_ids_to_keep[i].cpu().numpy().reshape(-1), x.reshape(-1)) for i, x in enumerate(mod_word_input_ids_per_sent)]

        model_head_noun_locs_per_sent[model_name] = np.array(head_noun_locs_per_sent)
        model_mod_word_locs_per_sent[model_name] = np.array(mod_word_locs_per_sent)

In [None]:
[(i,x) for i, x in enumerate(sentences) if 'allerg' in x]

In [None]:
def get_attention_per_sample(model, layer, data_loc='../data'):
    """Load the saved attention matrices of one model layer, one matrix per sample.

    Every entry of the saved array is treated as a flattened square matrix that
    was padded with -1. The padding is dropped and the remaining values are
    reshaped to (n, n).

    Args:
        model: model name, used to build the file name.
        layer: layer number, used to build the file name.
        data_loc: root data directory. Defaults to '../data', which gives the
            same path this function always used.

    Returns:
        A list with one 2-D numpy array per sample.

    Raises:
        ValueError: if the number of non-padding values of a sample is not a
            perfect square.
    """
    file_name = "{}/representations/nn_compounds_attention/{}_layer_{}_noun_noun_compounds_attention.npy".format(data_loc, model, layer)

    attention_per_sample = []
    for sample_i, flat_attention in enumerate(np.load(file_name)):
        flat_attention = np.asarray(flat_attention)
        # A boolean mask keeps the non-padding values in their original order.
        values = flat_attention[flat_attention != -1]

        side = int(round(np.sqrt(values.size)))
        if side * side != values.size:
            raise ValueError(
                "Sample {} in {} has {} attention values, which is not a square matrix".format(
                    sample_i, file_name, values.size))

        attention_per_sample.append(values.reshape((side, side)))

    return attention_per_sample

In [None]:
def plot_attention(attention, labels, ax=None, title=''):
    """Draw an attention matrix as a viridis heatmap with token labels on both axes.

    Args:
        attention: 2-D attention matrix to plot.
        labels: tick labels, used for both the x and the y axis.
        ax: matplotlib axes to draw into; when None, seaborn picks the axes.
        title: title of the plot.
    """
    g = sns.heatmap(attention, cmap="viridis", ax=ax)
    g.set_yticklabels(labels, rotation=0)
    g.set_xticklabels(labels, rotation=45)
    # sns.heatmap returns the axes it drew on, so the title can be applied
    # whether or not the caller passed `ax` (it used to be dropped when ax was None).
    g.set_title(title)

In [None]:
def plot_attention_over_layers(model_name, sample_i, layers=[1, 4, 8, 12]):
    """Plot one sample's attention matrix for several layers, side by side.

    Args:
        model_name: key into `model_tokens_dict`, also used to locate the saved
            attention files.
        sample_i: index of the sample to plot.
        layers: layer numbers to plot, one panel each. Layers above 6 are
            skipped for models whose name contains 'distil'; their panel stays empty.
    """
    # squeeze=False always returns a 2-D array of axes, so a single-layer
    # request no longer fails when indexing the result.
    fig, axs = plt.subplots(ncols=len(layers), figsize=(6*len(layers), 6), squeeze=False)

    for i, layer in enumerate(layers):
        if 'distil' in model_name and layer > 6:
            continue
        # No plt.cla() here: it acts on pyplot's current axes rather than on
        # axs[0, i], and every panel is freshly created above.
        attention_per_sample = get_attention_per_sample(model_name, layer)
        plot_attention(attention_per_sample[sample_i], model_tokens_dict[model_name][sample_i], ax=axs[0, i], title='model={}, layer={}'.format(model_name, layer))

In [None]:
sample_i = 165

In [None]:
plot_attention_over_layers('roberta-base', sample_i)

In [None]:
plot_attention_over_layers('bert-base-uncased', sample_i)

In [None]:
plot_attention_over_layers('xlnet-base-cased', sample_i)

In [None]:
plot_attention_over_layers('xlm-mlm-xnli15-1024', 70)

In [None]:
plot_attention_over_layers('xlm-mlm-xnli15-1024', sample_i)

## Calculating Attention Mass

In [None]:
# Cache the attention matrices of every model and layer:
# model_layer_attention_dict[model_name][layer] -> list of per-sample matrices.
model_layer_attention_dict = {}

for model_name in model_utils.dev_model_configs.keys():
    # Layers 1-12 are loaded, except that layers above 6 are skipped for 'distil' models.
    last_layer = 6 if 'distil' in model_name else 12
    model_layer_attention_dict[model_name] = {
        layer_i: get_attention_per_sample(model_name, layer_i)
        for layer_i in range(1, last_layer + 1)
    }

In [None]:
model_name = 'bert-base-uncased'

input_ids = model_input_id_dict[model_name]
tokens = model_tokens_dict[model_name]
head_noun_locs_per_sent = model_head_noun_locs_per_sent[model_name]
mod_word_locs_per_sent = model_mod_word_locs_per_sent[model_name]

In [None]:
sentence = sentences[0]
sentence_tokens = tokens[0]
sentence_input_ids = input_ids[0]
sentence_head_noun_locs_per_sent = head_noun_locs_per_sent[0]
sentence_mod_word_locs_per_sent = mod_word_locs_per_sent[0]
sentence_attention = model_layer_attention_dict[model_name][1][0]

print(sentence)
print(sentence_tokens)
print(sentence_input_ids)
print(sentence_head_noun_locs_per_sent)
print(sentence_mod_word_locs_per_sent)

plot_attention(sentence_attention, sentence_tokens)

### Demonstrating how we can select within the attention matrix using masks

In [None]:
head_attention = sentence_attention[sentence_head_noun_locs_per_sent]
head_head_attention = head_attention[:, sentence_head_noun_locs_per_sent]

head_mask = np.zeros(sentence_attention.shape,dtype=bool)
head_mask[sentence_head_noun_locs_per_sent] = True

fig, axs = plt.subplots(ncols=3, figsize=(10, 10))
axs[0].imshow(head_mask)
axs[1].imshow(head_mask & head_mask.T)
axs[2].imshow(head_mask & ~head_mask.T)
axs[0].set_title('head_mask')
axs[1].set_title('head_mask & head_mask.T')
axs[2].set_title('head_mask & ~head_mask.T')

In [None]:
head_mask = np.zeros(sentence_attention.shape,dtype=bool)
head_mask[sentence_head_noun_locs_per_sent] = True

mod_mask = np.zeros(sentence_attention.shape,dtype=bool)
mod_mask[sentence_mod_word_locs_per_sent] = True

compound_mask = mod_mask | head_mask
plt.imshow(compound_mask & compound_mask)

In [None]:
compound_mask[:, 0].sum()

In [None]:
def attention_proportion_in_masks(sentence_attention, mask_a, mask_b):
    """Average, over the rows marked by mask_a, of the attention that falls in the columns marked by mask_b.

    Example: mask_a = head_mask and mask_b = mod_mask gives the proportion of
    head attention that lies within the modifier words.

    Args:
        sentence_attention: 2-D attention matrix of one sentence.
        mask_a: boolean mask of the same shape selecting rows.
        mask_b: boolean mask of the same shape; its transpose selects columns.

    Returns:
        The mean of the selected row sums.
    """
    # The span of a mask is the number of tokens it covers, read from its first
    # column (assumes the masks mark whole rows, as they are built in this notebook).
    n_tokens_a = mask_a[:, 0].sum()
    n_tokens_b = mask_b[:, 0].sum()

    selection = mask_a & mask_b.T
    selected_attention = sentence_attention[selection].reshape(n_tokens_a, n_tokens_b)

    attention_per_token = selected_attention.sum(axis=1)
    return np.mean(attention_per_token)

In [None]:
sentence_attention[head_mask & head_mask.T]

In [None]:
attention_proportion_in_masks(sentence_attention, head_mask, head_mask)

## Working out compound/multi-token attention

In [None]:
# In this example, the modifier word is at index 3 and the head noun is at index 4
sentence_attention

In [None]:
# Attention for compound tokens
sentence_attention[compound_mask].reshape(-1, sentence_attention.shape[0])

In [None]:
# Compound-compound attention, same as sentence_attention[compound_mask & compound_mask.T].reshape(2, 2) below
sentence_attention[compound_mask].reshape(-1, sentence_attention.shape[0])[:, 3:5]

In [None]:
compound_mask & compound_mask.T

In [None]:
sentence_attention[compound_mask & compound_mask.T]

In [None]:
sentence_attention[compound_mask & compound_mask.T].reshape(2, 2)

In [None]:
# Sanity check
sentence_attention[compound_mask].reshape(-1, sentence_attention.shape[0]).sum(axis=1)

In [None]:
# Proportion of attention in each compound token that is within the whole compound
sentence_attention[compound_mask & compound_mask.T].reshape(2, 2).sum(axis=1)

In [None]:
# Average compound-compound attention
np.mean(sentence_attention[compound_mask & compound_mask.T].reshape(2, 2).sum(axis=1))

In [None]:
attention_proportion_in_masks(sentence_attention, compound_mask, compound_mask)

In [None]:
def get_mask_dicts(model_name):
    """Build the head / modifier / compound masks of every sentence for one model.

    Each mask is a boolean matrix shaped like the sentence's layer-1 attention
    matrix, with the rows at the word's token locations set to True. The
    compound mask is the union of the head and modifier masks.

    Args:
        model_name: key into the module-level attention and location dicts.

    Returns:
        A list with one dict per sentence, keyed 'head', 'modifier' and 'compound'.
    """
    def rows_mask(shape, row_locs):
        mask = np.zeros(shape, dtype=bool)
        mask[row_locs] = True
        return mask

    def masks_for_sample(sample_i):
        # Layer 1 is only used for the matrix shape.
        shape = model_layer_attention_dict[model_name][1][sample_i].shape
        head_mask = rows_mask(shape, model_head_noun_locs_per_sent[model_name][sample_i])
        mod_mask = rows_mask(shape, model_mod_word_locs_per_sent[model_name][sample_i])
        return {'head': head_mask, 'modifier': mod_mask, 'compound': mod_mask | head_mask}

    return [masks_for_sample(sample_i) for sample_i in range(len(sentences))]

In [None]:
plt.imshow(sentence_attention)

In [None]:
# Build the head / modifier / compound masks of every sentence for the model
# currently held in `model_name`.
# NOTE(review): this repeats the body of get_mask_dicts(model_name). Unlike the
# function, it also leaves `sentence_attention`, `head_mask`, `mod_mask` and
# `compound_mask` bound to the values of the LAST sample, and later cells read
# `sentence_attention`.
mask_dicts = []

for sample_i in range(len(sentences)):
    # Layer 1 is only used for the shape of the attention matrix.
    sentence_attention = model_layer_attention_dict[model_name][1][sample_i]

    # Mark the rows at the head noun locations.
    head_mask = np.zeros(sentence_attention.shape,dtype=bool)
    head_mask[model_head_noun_locs_per_sent[model_name][sample_i]] = True

    # Mark the rows at the modifier word locations.
    mod_mask = np.zeros(sentence_attention.shape,dtype=bool)
    mod_mask[model_mod_word_locs_per_sent[model_name][sample_i]] = True

    # The compound mask is the union of both words.
    compound_mask = mod_mask | head_mask

    mask_dicts.append({'head': head_mask, 'modifier': mod_mask, 'compound': compound_mask})



## Calculating attention proportion features

In [None]:
def calculate_attention_feature(name, attention, mask_dict):
    """Compute one named attention-proportion feature for a single sample.

    Args:
        name: feature name of the form '<target>_<source>', e.g. 'head_modifier';
            both parts must be keys of `mask_dict`.
        attention: attention matrix of the sample.
        mask_dict: masks of the sample, keyed by mask name.

    Returns:
        The value of attention_proportion_in_masks for the target and source masks.
    """
    target_key, source_key = name.split('_')
    target_mask = mask_dict[target_key]
    source_mask = mask_dict[source_key]
    return attention_proportion_in_masks(attention, target_mask, source_mask)

In [None]:
# mask_name_dict = ['compound', "head", "modifier"]
mask_name_dict = ["head", "modifier"]

# One feature per ordered pair of mask names, written '<mask_a>_<mask_b>'.
# calculate_attention_feature splits the name again to pick the two masks.
feature_names = [
    '{}_{}'.format(mask_a_name, mask_b_name)
    for mask_a_name in mask_name_dict
    for mask_b_name in mask_name_dict
]

print(feature_names)

In [None]:
# model_layer_feature_dict[model_name][layer][feature] -> feature value averaged over all sentences.
model_layer_feature_dict = {}

for model_name in model_layer_attention_dict.keys():
    print(model_name)
    model_layer_feature_dict[model_name] = {}

    # The masks depend only on the model, so they are built once per model.
    mask_dicts = get_mask_dicts(model_name)

    for layer in tqdm.tqdm(model_layer_attention_dict[model_name].keys()):
        layer_features = {}
        model_layer_feature_dict[model_name][layer] = layer_features

        for feature in feature_names:
            values_per_sample = [
                calculate_attention_feature(feature, model_layer_attention_dict[model_name][layer][sample_i], mask_dicts[sample_i])
                for sample_i in range(len(sentences))
            ]
            layer_features[feature] = np.mean(values_per_sample)

In [None]:
plt.imshow(sentence_attention)

In [None]:
sentence_attention.shape

In [None]:
pd.DataFrame.from_dict(model_layer_feature_dict['bert-base-uncased'])

In [None]:
# Long-format table of the same averages: one row per (model, layer, feature).
rows = []

for model_name in model_layer_attention_dict.keys():
    print(model_name)

    mask_dicts = get_mask_dicts(model_name)

    for layer in tqdm.tqdm(model_layer_attention_dict[model_name].keys()):

        for feature in feature_names:
            values_per_sample = [
                calculate_attention_feature(feature, model_layer_attention_dict[model_name][layer][sample_i], mask_dicts[sample_i])
                for sample_i in range(len(sentences))
            ]
            # Feature names are '<target>_<source>'.
            feature_parts = feature.split('_')
            rows.append({
                'model': model_name,
                'layer': layer,
                'feature': feature,
                'mean_attention_proportion': np.mean(values_per_sample),
                'representation_target': feature_parts[0],
                'representation_source': feature_parts[1],
            })

feature_df = pd.DataFrame(rows)

In [None]:
feature_df

In [None]:

plt.figure(figsize=(14, 10))
sns.lineplot(x='layer', y='mean_attention_proportion', hue='representation_target', style='representation_source', markers=True, data=feature_df[feature_df.model == 'bert-base-uncased'])
plt.show()

In [None]:

plt.figure(figsize=(14, 10))
sns.lineplot(x='layer', y='mean_attention_proportion', hue='representation_target', style='representation_source', markers=True, data=feature_df[feature_df.model == 'roberta-base'])
plt.show()

In [None]:

plt.figure(figsize=(14, 10))
sns.lineplot(x='layer', y='mean_attention_proportion', hue='representation_target', style='representation_source', markers=True, data=feature_df[feature_df.model == 'xlnet-base-cased'])
plt.show()

In [None]:

plt.figure(figsize=(14, 10))
sns.lineplot(x='layer', y='mean_attention_proportion', hue='representation_target', style='representation_source', markers=True, data=feature_df[feature_df.model == 'xlm-mlm-xnli15-1024'])
plt.show()

In [None]:

plt.figure(figsize=(14, 10))
sns.lineplot(x='layer', y='mean_attention_proportion', hue='representation_target', style='representation_source', markers=True, data=feature_df[feature_df.model == 'distilroberta-base'])
plt.show()

In [None]:

plt.figure(figsize=(14, 10))
sns.lineplot(x='layer', y='mean_attention_proportion', hue='representation_target', style='representation_source', markers=True, data=feature_df[feature_df.model == 'xlm-mlm-xnli15-1024'])
plt.show()

In [None]:
df = pd.read_csv('../data/results/nn_compound_transformer_relations_per_word.csv')
df

In [None]:
fig, ax = plt.subplots(figsize=(20, 15))
sns.lineplot(x='layer',y='mean_attention_proportion', style='feature', hue='model', data=feature_df, ax=ax)

In [None]:
# Correlate, across layers, each attention feature with each '*corr*' column of
# `df`, for every model, layer band and representation.
#
# Layer bands ('layer_selection'):
#   'all'        every layer
#   'middle'     start_layer_boundary <  layer <= end_layer_boundary
#   'not_middle' layer <= start_layer_boundary  or  layer > end_layer_boundary
#   'start'      layer <= start_layer_boundary
#   'end'        layer >  end_layer_boundary
start_layer_boundary = 2
end_layer_boundary = 10


def select_layer_band(frame, layer_selection):
    """Return the rows of `frame` whose `layer` lies in the named band, sorted by model and layer.

    The mask is built on and applied to the same frame before sorting, so no
    implicit re-alignment between a sorted frame and an unsorted mask is needed.

    Raises:
        ValueError: if `layer_selection` is not one of the known band names.
    """
    if layer_selection == 'all':
        return frame.sort_values(['model', 'layer'])
    elif layer_selection == 'middle':
        in_band = (frame.layer > start_layer_boundary) & (frame.layer <= end_layer_boundary)
    elif layer_selection == 'not_middle':
        in_band = (frame.layer <= start_layer_boundary) | (frame.layer > end_layer_boundary)
    elif layer_selection == 'start':
        in_band = frame.layer <= start_layer_boundary
    elif layer_selection == 'end':
        in_band = frame.layer > end_layer_boundary
    else:
        raise ValueError('Unknown layer_selection: {}'.format(layer_selection))
    return frame[in_band].sort_values(['model', 'layer'])


rows = []

for model_name in list(model_utils.dev_model_configs.keys()):
    model_feature_df = feature_df[feature_df.model == model_name]
    model_relation_df = df[df.model == model_name]

    for layer_selection in ['all', 'middle', 'not_middle', 'start', 'end']:
        # Filter and sort once per (model, band) instead of inside the innermost loop.
        band_feature_df = select_layer_band(model_feature_df, layer_selection)
        band_relation_df = select_layer_band(model_relation_df, layer_selection)

        for experimental_condition in [x for x in df.columns if 'corr' in x]:
            for attention_feature in feature_df.feature.unique():
                attention_feature_values = band_feature_df[band_feature_df.feature == attention_feature].mean_attention_proportion.tolist()

                for representation in ['compound_mean', 'head_noun_mean', 'mod_word_mean']:
                    transformer_same_relation_corr_str_for_rep = band_relation_df[band_relation_df.representation == representation][experimental_condition].tolist()

                    # Both lists are ordered by layer, so the values are paired layer by layer.
                    corr, p_val = spearmanr(attention_feature_values, transformer_same_relation_corr_str_for_rep)

                    rows.append({'attention_feature': attention_feature, 'model': model_name, 'representation': representation, 'experimental_condition': experimental_condition, 'corr': corr, 'p_val': p_val, 'layer_selection': layer_selection})

results_df = pd.DataFrame(rows)

Below we see positive correlations between compound attention and the representation of thematic relation within compound groups.

In [None]:
def is_self_attention_feature(row):
    """Return True when the row's attention feature has the same target and source, e.g. 'head_head'."""
    feature_parts = row.attention_feature.split('_')
    return feature_parts[0] == feature_parts[1]


# itertuples yields one record per row with the columns as attributes; iterating
# over `results_df.iloc` only worked through the legacy __getitem__ fallback and
# built a full Series for every row.
results_df['self_attention_feature'] = [is_self_attention_feature(row) for row in results_df.itertuples(index=False)]

In [None]:
# Correlations for the 'same_relation_group_rdm_corr' condition over all layers, strongest first.
is_same_relation_all_layers = (results_df.experimental_condition=='same_relation_group_rdm_corr') & (results_df.layer_selection=='all')

# .loc aligns the boolean mask with the sorted frame by index label, which keeps
# the sorted order; plain [] relies on an implicit reindex of the mask.
# drop() returns a new frame, so columns added to new_df later do not write into a slice of results_df.
new_df = (
    results_df.sort_values('corr', ascending=False)
    .loc[is_same_relation_all_layers]
    .drop(columns=['experimental_condition', 'layer_selection'])
)
new_df

In [None]:
models

In [None]:
model_order = np.array(models)

new_df['sort_val'] = [np.argwhere(model_order == x.model)[0][0] for x in new_df.iloc]

In [None]:
sns.catplot(y='corr', x='attention_feature', hue='model', data=new_df.sort_values('sort_val'), kind='bar', col='representation', height=6, aspect=1).set_titles('Correlation between attention features and \nthematic relation signal in "{col_name}" representation')

In [None]:
pd.set_option('display.max_rows', None)
results_df.sort_values('corr')[(results_df.experimental_condition=='same_relation_group_rdm_corr') & ((results_df.attention_feature=='modifier_modifier') | (results_df.attention_feature=='head_head') | (results_df.attention_feature=='compound_compound'))]

In [None]:
# Correlations for the within-compound-sentences condition, strongest first.
is_within_compound_sentences = results_df.experimental_condition=='same_relation_group_rdm_corr_within_compound_sentences'

# .loc aligns the boolean mask with the sorted frame by index label, which keeps
# the sorted order; plain [] relies on an implicit reindex of the mask.
new_df = (
    results_df.sort_values('corr', ascending=False)
    .loc[is_within_compound_sentences]
    .drop(columns=['experimental_condition', 'self_attention_feature'])
)

new_df

In [None]:
plot_attention_over_layers('bert-base-uncased', 37)

In [None]:
plot_attention_over_layers('roberta-base', 37)

In [None]:
plot_attention_over_layers('xlnet-base-cased', 37)

In [None]:
plot_attention_over_layers('xlnet-base-cased', 337)

In [None]:
plot_attention_over_layers('xlnet-base-cased', 637)

In [None]:
plot_attention_over_layers('xlnet-base-cased', 37, layers=[1, 2, 3, 4])

In [None]:
plot_attention_over_layers('distilroberta-base', 37, layers=[1, 2, 3, 4])

In [None]:
plot_attention_over_layers('bert-base-japanese', 43, layers=[1, 6, 12])

In [None]:
# Attention features are named '<target>_<source>'; split them once with the
# vectorised string accessor instead of iterating over `results_df.iloc` twice.
attention_feature_parts = results_df.attention_feature.str.split('_')
results_df['attention_target'] = attention_feature_parts.str[0]
results_df['attention_source'] = attention_feature_parts.str[1]

In [None]:
results_df.experimental_condition.unique()

In [None]:
rows = [(x[0], x[1].groupby('attention_target')['corr'].mean()) for x in results_df.sort_values('corr', ascending=False)[(results_df.experimental_condition=='same_relation_group_rdm_corr')].groupby('layer_selection')]
# pd.DataFrame([{'layer_selection': x[0], 'mean_compound_corr': x[1]['compound'], 'mean_head_corr': x[1]['head'], 'mean_modifier_corr': x[1]['modifier']} for x in rows])
pd.DataFrame([{'layer_selection': x[0], 'mean_head_corr': x[1]['head'], 'mean_modifier_corr': x[1]['modifier']} for x in rows])


In [None]:
rows = [(x[0], x[1].groupby('attention_source')['corr'].mean()) for x in results_df.sort_values('corr', ascending=False)[(results_df.experimental_condition=='same_relation_group_rdm_corr')].groupby('layer_selection')]
pd.DataFrame([{'layer_selection': x[0], 'mean_head_corr': x[1]['head'], 'mean_modifier_corr': x[1]['modifier']} for x in rows])


In [None]:
results_df.sort_values('corr', ascending=False)[(results_df.experimental_condition=='same_relation_group_rdm_corr') & (results_df.layer_selection=='all')]


In [None]:
results_df.experimental_condition.unique()

In [None]:
for condition in results_df.experimental_condition.unique():
    print(condition)
    print('\t{}'.format([(x[0], x[1]['corr'].mean()) for x in results_df.sort_values('corr', ascending=False)[(results_df.experimental_condition==condition) & (results_df.layer_selection=='all')].groupby('model')]))

In [None]:
[(x[0], x[1]['corr'].mean()) for x in results_df.sort_values('corr', ascending=False)[(results_df.experimental_condition=='same_relation_group_rdm_corr') & (results_df.layer_selection=='all')].groupby('model')]

In [None]:
[(x[0], x[1]['corr'].mean()) for x in results_df.sort_values('corr', ascending=False)[(results_df.experimental_condition=='same_relation_group_rdm_corr_within_compound_sentences') & (results_df.layer_selection=='all')].groupby('model')]

In [None]:


for condition in [x for x in results_df.experimental_condition.unique() if 'within' in x]:
    print(condition)
    print('\t{}'.format([(x[0], x[1]['corr'].mean()) for x in results_df.sort_values('corr', ascending=False)[(results_df.experimental_condition==condition) & (results_df.layer_selection=='all')].groupby('model')]))