In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from attention_graph_util import *
import seaborn as sns
import itertools 
import matplotlib as mpl
import networkx as nx
import os
from util import constants

from absl import app
from absl import flags
import pandas as pd

from util.models import MODELS
from util.tasks import TASKS
#from dnotebook_utils import *
from attention_graph_util import *
%matplotlib inline
from util.config_util import get_task_params
from notebooks.notebook_utils import *
from util import inflect

from tqdm import tqdm

# Global matplotlib styling applied to every figure drawn in this notebook.
# NOTE(review): `plt` is not imported explicitly in this cell; presumably it is
# brought into scope by one of the star imports above -- confirm.
rc={'font.size': 10, 'axes.labelsize': 10, 'legend.fontsize': 10.0, 
    'axes.titlesize': 32, 'xtick.labelsize': 20, 'ytick.labelsize': 16}
plt.rcParams.update(**rc)
mpl.rcParams['axes.linewidth'] = .5 #set the value globally

import torch
from transformers import *
from transformers import BertConfig, BertForMaskedLM, BertTokenizer

[nltk_data] Downloading package punkt to /home/samira/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
!pip install --upgrade transformers


Requirement already up-to-date: transformers in /home/samira/anaconda3/envs/indist/lib/python3.7/site-packages (2.8.0)


In [3]:
# Build the subject-verb agreement task from the on-disk dataset, one example
# per batch.
task_name = 'word_sv_agreement_lm'
task_params = get_task_params(batch_size=1)
task = TASKS[task_name](task_params, data_dir='../InDist/data')
# Encoding of the beginning-of-sentence constant under the task's own encoder.
cl_token = task.sentence_encoder().encode(constants.bos)
# NOTE(review): reaches into the private `_tokenizer` attribute of the
# sentence encoder; a public accessor would be preferable if one exists.
task_tokenizer = task.sentence_encoder()._tokenizer

INFO:absl:Load dataset info from ../InDist/data/word_sv_agreement/0.1.0


Vocab len:  10032


INFO:absl:Constructing tf.data.Dataset for split validation, from ../InDist/data/word_sv_agreement/0.1.0
INFO:absl:Constructing tf.data.Dataset for split test, from ../InDist/data/word_sv_agreement/0.1.0
INFO:absl:Constructing tf.data.Dataset for split train, from ../InDist/data/word_sv_agreement/0.1.0


In [4]:
# Transformers has a unified API
# for 8 transformer architectures and 30 pretrained weights.
# NOTE(review): this rebinds `MODELS`, which the first cell imported from
# `util.models`; after this cell runs the notebook-level name refers to the
# Hugging Face table below, not the project registry.
#          Model          | Tokenizer          | Pretrained weights shortcut
MODELS = [(BertModel,       BertTokenizer,       'bert-base-uncased'),
          (OpenAIGPTModel,  OpenAIGPTTokenizer,  'openai-gpt'),
          (GPT2Model,       GPT2Tokenizer,       'gpt2'),
          (CTRLModel,       CTRLTokenizer,       'ctrl'),
          (TransfoXLModel,  TransfoXLTokenizer,  'transfo-xl-wt103'),
          (XLNetModel,      XLNetTokenizer,      'xlnet-base-cased'),
          (XLMModel,        XLMTokenizer,        'xlm-mlm-enfr-1024'),
          (DistilBertModel, DistilBertTokenizer, 'distilbert-base-uncased'),
          (RobertaModel,    RobertaTokenizer,    'roberta-base')]

# Each architecture is provided with several classes for fine-tuning on downstream tasks, e.g.
BERT_MODEL_CLASSES = [BertModel, BertForPreTraining, BertForMaskedLM, BertForNextSentencePrediction,
                      BertForSequenceClassification, BertForTokenClassification, BertForQuestionAnswering]

# All the classes for an architecture can be instantiated from the pretrained weights of that architecture.
# Note that additional weights added for fine-tuning are only initialized
# and need to be trained on the downstream task.
# Only the BERT tokenizer is actually loaded here; the tables above are reference material.
pretrained_weights = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(pretrained_weights)

In [5]:
# Load the pretrained masked-LM head model and ask it to also return the
# per-layer hidden states and attention maps, which later cells read from the
# last two entries of the model output.
model = BertForMaskedLM.from_pretrained(pretrained_weights,
                                  output_hidden_states=True,
                                  output_attentions=True)

In [6]:
def offset_convertor(encoded_input_task, task_offset, task_encoder, tokenizer):
    """Translate an offset in the task's tokenization into the target tokenizer's.

    The first `task_offset` items of `encoded_input_task` are decoded back to
    text with `task_encoder`, that text is re-tokenized with `tokenizer`, and
    the number of resulting tokens is returned.

    Args:
        encoded_input_task: sliceable sequence of ids produced by the task encoder.
        task_offset: number of leading ids to take from `encoded_input_task`.
        task_encoder: object exposing `decode(ids) -> str`.
        tokenizer: object exposing `tokenize(str) -> sequence of tokens`.

    Returns:
        int: how many tokens `tokenizer` produces for the decoded prefix.
    """
    prefix_ids = encoded_input_task[:task_offset]
    prefix_text = task_encoder.decode(prefix_ids)
    return len(tokenizer.tokenize(prefix_text))


In [7]:
# Sanity check on a single test example: decode it, run it through BERT and
# inspect the shape of the stacked attention maps.
for x,y in task.test_dataset:
    # x[0][1:] drops the first id of the first example before decoding --
    # presumably the task's BOS token; TODO confirm.
    sentence = task.sentence_encoder().decode(x[0][1:])
    print(sentence)
    break

# Display-only token list: 'cls'/'sep' stand in for the special tokens that
# `tokenizer.encode` adds to the raw sentence below, so the lengths line up.
tokens = ['cls']+tokenizer.tokenize(sentence)+['sep']
print(len(tokens), tokens)
tf_input_ids = tokenizer.encode(sentence)
input_ids = torch.tensor([tf_input_ids])
# The last two model outputs are taken as (hidden states, attentions).
all_hidden_states, all_attentions = model(input_ids)[-2:]

# Stack the per-layer attention tensors and select batch element 0; the
# recorded output below shows the result as (12, 12, 21, 21).
_attentions = [att.detach().numpy() for att in all_attentions]
attentions_mat = np.asarray(_attentions)[:,0]
print(attentions_mat.shape)

many NNS of woodland remain and support a JJ sector in the southern portion of the state .
21 ['cls', 'many', 'n', '##ns', 'of', 'woodland', 'remain', 'and', 'support', 'a', 'jj', 'sector', 'in', 'the', 'southern', 'portion', 'of', 'the', 'state', '.', 'sep']
(12, 12, 21, 21)


In [None]:
# For up to `n_batches` validation examples: mask the verb, run BERT, and
# record (a) the input ids, (b) the attention maps, and (c) a blank-out
# relevance score per input position. The score is how much the gap between
# the probabilities of the correct and the inflected verb changes when that
# position is replaced by the mask token.
all_examples_x = []
all_examples_y = []
all_examples_attentions = []
all_examples_blankout_relevance = []

n_batches = 1000  # maximum number of validation examples to process

infl_eng = inflect.engine()
verb_infl, noun_infl = gen_inflect_from_vocab(infl_eng, '../InDist/notebooks/wiki.vocab')

sentence_encoder = task.sentence_encoder()
mask_token_id = tokenizer.mask_token_id

test_data = task.databuilder.as_dataset(split='validation', batch_size=1)
for examples in tqdm(test_data):
    encoded_sentence = examples['sentence'][0]
    sentence_text = sentence_encoder.decode(encoded_sentence)

    verb_position = examples['verb_position'][0].numpy()+1  #+1 because of adding cls.
    verb_position = offset_convertor(encoded_sentence, verb_position, sentence_encoder, tokenizer)

    # Build the token sequence with the tokenizer's real special tokens and
    # convert it to ids directly. `tokenizer.encode` on a token list adds its
    # own [CLS]/[SEP] (which is why index [1] is used on its output below);
    # calling it on a list that already carries 'cls'/'sep' placeholders would
    # add a second pair and shift every position by one, so that
    # `probs[verb_position]` would no longer be the masked position.
    sentence = [tokenizer.cls_token] + tokenizer.tokenize(sentence_text) + [tokenizer.sep_token]
    sentence[verb_position] = tokenizer.mask_token
    tf_input_ids = tokenizer.convert_tokens_to_ids(sentence)
    input_ids = torch.tensor([tf_input_ids])

    actual_verb = examples['verb'][0].numpy().decode("utf-8")
    inflected_verb = verb_infl[actual_verb]

    # Vocabulary id of the first word piece of each verb form
    # (index 0 of the encoded list is the [CLS] id added by `encode`).
    # NOTE(review): verbs that split into several word pieces are scored on
    # their first piece only -- confirm this is intended.
    actual_verb_index = tokenizer.encode(tokenizer.tokenize(actual_verb))[1]
    inflected_verb_index = tokenizer.encode(tokenizer.tokenize(inflected_verb))[1]

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        predictions = model(input_ids)
    logits = predictions[0][0]
    probs = torch.softmax(logits, dim=-1)
    actual_verb_score = probs[verb_position][actual_verb_index]
    inflected_verb_score = probs[verb_position][inflected_verb_index]
    main_diff_score = actual_verb_score - inflected_verb_score

    # Last two outputs are (hidden states, attentions); stack the attention
    # maps over layers and keep batch element 0.
    attentions = predictions[-1]
    _attentions = [att.detach().numpy() for att in attentions]
    attentions_mat = np.asarray(_attentions)[:,0]

    # Repeat the example once per position and blank out a different
    # position (the diagonal) in each copy.
    max_len = input_ids.shape[1]
    extended_x = input_ids.repeat(max_len, 1)
    diagonal = torch.arange(max_len)
    extended_x[diagonal, diagonal] = mask_token_id

    # Get the new output
    with torch.no_grad():
        extended_predictions = model(extended_x)
    extended_probs = torch.softmax(extended_predictions[0], dim=-1)

    extended_correct_probs = extended_probs[:,verb_position,actual_verb_index]
    extended_wrong_probs =  extended_probs[:,verb_position,inflected_verb_index]
    extended_diff_scores = extended_correct_probs - extended_wrong_probs

    # Save the change in the correct-vs-inflected probability gap per position
    diffs = torch.abs(main_diff_score - extended_diff_scores)

    # Append everything together so the three lists stay index-aligned.
    all_examples_x.append(input_ids)
    all_examples_attentions.append(attentions_mat)
    all_examples_blankout_relevance.append(diffs.detach().numpy())

    if len(all_examples_x) >= n_batches:
        break




INFO:absl:Constructing tf.data.Dataset for split validation, from ../InDist/data/word_sv_agreement/0.1.0
552it [09:51,  1.47s/it]

In [None]:
# Peek at the relevance scores of the first example (layer 0 for the
# attention-based measures).
# NOTE(review): `all_examples_raw_relevance` and `all_examples_joint_relevance`
# are only assigned in a later cell, so this cell fails on a fresh
# Restart & Run All; it should be moved below that cell.
print(all_examples_raw_relevance[0][0])
print(all_examples_joint_relevance[0][0])
print(all_examples_blankout_relevance[0])

In [None]:
# Display the stacked attention maps stored for the first example.
# NOTE(review): this prints the entire array; `.shape` would be a lighter check.
all_examples_attentions[0]

In [None]:
    
def _relevance_by_layer(relevance_fn, n_layers=6):
    """Apply `relevance_fn` to every stored example for layers 0..n_layers-1.

    `relevance_fn` is called as `relevance_fn(attention, tokens, layer=l)`.
    Returns a dict mapping the layer index to a list with one numpy array per
    entry of `all_examples_x`.
    """
    by_layer = {}
    for l in range(n_layers):
        by_layer[l] = []
        for i in tqdm(range(len(all_examples_x))):
            # Recover the token list (not the decoded string) so that
            # `length` is the number of tokens: `tokenizer.decode` returns a
            # single string, whose `len` is a character count.
            tokens = tokenizer.convert_ids_to_tokens(all_examples_x[i][0].tolist())
            length = len(tokens)
            attention_relevance = relevance_fn(all_examples_attentions[i][...,:length, :length], tokens, layer=l)
            by_layer[l].append(np.asarray(attention_relevance))
    return by_layer


# Raw attention of the first position, per layer.
all_examples_raw_relevance = _relevance_by_layer(get_raw_att_relevance)

# Joint attention (via compute_joint_attention) for the first position, per layer.
all_examples_joint_relevance = _relevance_by_layer(get_joint_relevance)
    

# all_examples_flow_relevance = {}
# for l in np.arange(6):
#     all_examples_flow_relevance[l] = []
#     for i in tqdm(np.arange(len(all_examples_x))):
#         tokens = tokenizer.decode(all_examples_x[i][0].numpy())
#         length = len(tokens)
#         attention_relevance = get_flow_relevance(all_examples_attentions[i][...,:length, :length], tokens, layer=l)
#         all_examples_flow_relevance[l].append(np.asarray(attention_relevance))

In [None]:
from scipy.stats import spearmanr
# print(np.mean([spearmanr(all_examples_flow_relevance[i], all_examples_joint_relevance[i]) for i in np.arange(len(all_examples_x))]))
# print(np.mean([spearmanr(all_examples_flow_relevance[i], all_examples_blankout_relevance[i]) for i in np.arange(len(all_examples_x))]))

# Mean Spearman rank correlation, over examples, between each attention-based
# relevance measure and the blank-out relevance, reported per layer.
# scipy's `spearmanr` returns (correlation, p-value); only element [0] is
# averaged -- taking np.mean of the whole result would mix the p-values in.
for l in np.arange(6):
    print("layer ",l)
    # The blank-out scores are stored per example (not per layer), so they
    # are indexed by example only.
    print(all_examples_blankout_relevance[0].shape, all_examples_raw_relevance[l][0].shape, all_examples_joint_relevance[l][0].shape)
    print('raw:',np.mean([spearmanr(np.asarray(all_examples_raw_relevance[l][i]), np.asarray(all_examples_blankout_relevance[i]))[0] for i in np.arange(len(all_examples_x))]))
    print('joint',np.mean([spearmanr(np.asarray(all_examples_joint_relevance[l][i]), np.asarray(all_examples_blankout_relevance[i]))[0] for i in np.arange(len(all_examples_x))]))
    #print('flow',np.mean([spearmanr(all_examples_flow_relevance[l][i], all_examples_blankout_relevance[l][i]) for i in np.arange(len(all_examples_x))]))

In [None]:
def spearmanr(x, y):
    """Spearman rank correlation coefficient between `x` and `y`.

    NOTE(review): this shares its name with `scipy.stats.spearmanr`, which an
    earlier cell imports; whichever is bound last wins. Unlike scipy's
    function this returns only the coefficient, not a (rho, p-value) pair.

    Args:
        x, y: equal-length 1-D sequences (lists, arrays or pd.Series); they
            are compared position by position.

    Returns:
        float: the Pearson correlation of the tie-averaged ranks, which is
        the Spearman coefficient with or without ties. For untied data this
        equals ``1 - 6 * sum(d**2) / (n * (n**2 - 1))``. The result is NaN
        when either input is constant or has fewer than two elements.
    """
    x = pd.Series(np.asarray(x))
    y = pd.Series(np.asarray(y))
    assert x.shape == y.shape
    # Average ranks are required for ties: dense ranks fed into the d**2
    # shortcut formula give wrong values as soon as a value repeats.
    rx = x.rank(method='average')
    ry = y.rank(method='average')
    return float(rx.corr(ry))

def get_raw_att_relevance(full_att_mat, input_tokens, layer=-1):
    cls_index = 0
    raw_rel = full_att_mat[layer].sum(axis=0)[cls_index]/full_att_mat[layer].sum(axis=0)[cls_index].sum()
    
    return raw_rel


def get_joint_relevance(full_att_mat, input_tokens, layer=-1):
    """Joint-attention relevance of position 0 at a given layer.

    The attention maps are averaged over axis 1 and passed to
    `compute_joint_attention` (with the residual connection added); the row
    at index 0 of the requested layer is returned.

    The average divides by the actual size of axis 1 rather than a fixed 8,
    so it is a true mean for any number of heads (the BERT model loaded in
    this notebook yields 12 per layer) and matches `get_flow_relevance`.

    Args:
        full_att_mat: array assumed to be (layers, heads, length, length)
            -- TODO confirm against callers.
        input_tokens: unused here; kept for a signature parallel to the other
            relevance functions.
        layer: index of the layer to read (defaults to the last one).

    Returns:
        Row 0 of the joint attention at `layer`.
    """
    n_heads = full_att_mat.shape[1]
    att_mean_heads = full_att_mat.sum(axis=1) / n_heads
    joint_attentions = compute_joint_attention(att_mean_heads, add_residual=True)
    relevance_attentions = joint_attentions[layer][0]
    return relevance_attentions


def get_flow_relevance(full_att_mat, input_tokens, layer):
    """Flow-based relevance for position 0 at layer `layer`.

    Steps, as performed below:
      1. average the attention maps over axis 1, add the identity and
         re-normalize each row to sum to one;
      2. turn the result into an adjacency matrix with `get_adjmat` and build
         a directed graph whose existing edges get a 'capacity' attribute
         equal to the matrix entry;
      3. pick output nodes (labels containing 'L<layer+1>') and input nodes
         (index below the sequence length) and hand them to
         `compute_node_flow`;
      4. slice the returned matrix and return its first row.

    Args:
        full_att_mat: array assumed to be (layers, heads, length, length)
            -- TODO confirm against callers.
        input_tokens: token labels forwarded to `get_adjmat`.
        layer: zero-based layer index.

    Returns:
        The first row of the sliced flow matrix.
    """
    seq_len = full_att_mat.shape[-1]

    # Head-averaged attention with the identity added, rows re-normalized.
    mean_att = full_att_mat.sum(axis=1)/full_att_mat.shape[1]
    mean_att = mean_att + np.eye(mean_att.shape[1])[None,...]
    mean_att = mean_att / mean_att.sum(axis=-1)[...,None]

    adj_mat, labels_to_index = get_adjmat(mat=mean_att, input_tokens=input_tokens)

    graph = nx.from_numpy_matrix(adj_mat, create_using=nx.DiGraph())
    # One capacity per matrix cell, set in a single call.
    capacities = {
        (i, j): adj_mat[i, j]
        for i in np.arange(adj_mat.shape[0])
        for j in np.arange(adj_mat.shape[1])
    }
    nx.set_edge_attributes(graph, capacities, 'capacity')

    # NOTE(review): this is a substring match, so e.g. 'L1' would also match
    # labels containing 'L10' -- confirm against the label format of get_adjmat.
    layer_tag = 'L'+str(layer+1)
    output_nodes = [key for key in labels_to_index if layer_tag in key]
    input_nodes = [key for key in labels_to_index if labels_to_index[key] < seq_len]

    flow_values = compute_node_flow(graph, labels_to_index, input_nodes, output_nodes, length=seq_len)

    # Rows from layer+1 onward, columns belonging to block `layer`.
    layer_flow = flow_values[(layer+1)*seq_len:,layer*seq_len:(layer+1)*seq_len]
    cls_index = 0
    return layer_flow[cls_index]