In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from attention_graph_util import *
import seaborn as sns
import itertools 
import matplotlib as mpl
import networkx as nx
import os
from util import constants

from absl import app
from absl import flags
import pandas as pd

from util.models import MODELS
from util.tasks import TASKS
#from dnotebook_utils import *
from attention_graph_util import *
%matplotlib inline
from util.config_util import get_task_params
from notebooks.notebook_utils import *
from util import inflect

from tqdm import tqdm
from scipy.stats import spearmanr
import math


rc={'font.size': 10, 'axes.labelsize': 10, 'legend.fontsize': 10.0, 
    'axes.titlesize': 32, 'xtick.labelsize': 20, 'ytick.labelsize': 16}
plt.rcParams.update(**rc)
mpl.rcParams['axes.linewidth'] = .5 #set the value globally

import torch
from transformers import *
from transformers import BertConfig, BertForMaskedLM, BertTokenizer

[nltk_data] Downloading package punkt to /home/dehghani/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [None]:
!pip install --upgrade transformers
!pip install networkx
!pip install --upgrade matplotlib
!pip install --upgrade seaborn


!pip install torch torchvision

In [2]:
task_name = 'word_sv_agreement_lm'
task_params = get_task_params(batch_size=1)
task = TASKS[task_name](task_params, data_dir='../InDist/data')
cl_token = task.sentence_encoder().encode(constants.bos)
task_tokenizer = task.sentence_encoder()._tokenizer

INFO:absl:Overwrite dataset info from restored data version.
INFO:absl:Constructing tf.data.Dataset for split validation, from ../InDist/data/word_sv_agreement/0.1.0


Vocab len:  10032


INFO:absl:Constructing tf.data.Dataset for split test, from ../InDist/data/word_sv_agreement/0.1.0
INFO:absl:Constructing tf.data.Dataset for split train, from ../InDist/data/word_sv_agreement/0.1.0


In [3]:
from transformers import DistilBertTokenizer, DistilBertModel
import torch

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased',
                                        output_hidden_states=True,
                                        output_attentions=True)



HBox(children=(FloatProgress(value=0.0, description='Downloading', max=442.0, style=ProgressStyle(description_…




In [4]:
def offset_convertor(encoded_input_task, task_offset, task_encoder, tokenizer):
    string_part1 = task_encoder.decode(encoded_input_task[:task_offset])
    tokens_part1 = tokenizer.tokenize(string_part1)
    
    return len(tokens_part1)


In [5]:
for x,y in task.test_dataset:
    sentence = task.sentence_encoder().decode(x[0][1:])
    print(sentence)
    break

tokens = ['cls']+tokenizer.tokenize(sentence)+['sep']
print(len(tokens), tokens)
tf_input_ids = tokenizer.encode(sentence)
input_ids = torch.tensor([tf_input_ids])
logits, all_hidden_states, all_attentions = model(input_ids)
print(logits.shape)
_attentions = [att.detach().numpy() for att in all_attentions]
attentions_mat = np.asarray(_attentions)[:,0]
print(attentions_mat.shape)

embeded_inputs = torch.autograd.Variable(model.distilbert.embeddings(input_ids), requires_grad=True)
logits, all_hidden_states, all_attentions = model(inputs_embeds=embeded_inputs)
print(embeded_inputs.shape)


lsum = logits.sum()
print(lsum)

lsum.backward()
embeded_inputs.require_grad = True
print(embeded_inputs.grad.shape)

many NNS of woodland remain and support a JJ sector in the southern portion of the state .
21 ['cls', 'many', 'n', '##ns', 'of', 'woodland', 'remain', 'and', 'support', 'a', 'jj', 'sector', 'in', 'the', 'southern', 'portion', 'of', 'the', 'state', '.', 'sep']
torch.Size([1, 21, 30522])
(6, 12, 21, 21)
torch.Size([1, 21, 768])
tensor(-3811819.5000, grad_fn=<SumBackward0>)
torch.Size([1, 21, 768])


In [76]:
all_examples_x = []
all_examples_vp = []
all_examples_y = []

all_examples_attentions = []
all_examples_blankout_relevance = []
all_examples_grads = []
all_examples_inputgrads = []
n_batches = 10

all_examples_accuracies = []

infl_eng = inflect.engine()
verb_infl, noun_infl = gen_inflect_from_vocab(infl_eng, '../InDist/notebooks/wiki.vocab')

test_data = task.databuilder.as_dataset(split='validation', batch_size=1)
for examples in tqdm(test_data):
    sentence = task.sentence_encoder().decode(examples['sentence'][0][1:])
    if len(examples['sentence'][0][1:]) > 20:
        continue
    
    verb_position = examples['verb_position'][0].numpy()-1  #+1 because of adding cls.
    verb_position = offset_convertor(examples['sentence'][0], verb_position, task.sentence_encoder(), tokenizer)
    
    sentence = tokenizer.tokenize(sentence)
    
    all_examples_vp.append(verb_position)
    sentence[verb_position] = tokenizer.mask_token
    
    tf_input_ids = tokenizer.encode(sentence)
    input_ids = torch.tensor([tf_input_ids])
    sentence = tokenizer.tokenize(tokenizer.decode(tf_input_ids))
    print(sentence)

    
    s_shape = input_ids.shape
    batch_size, length = s_shape[0], s_shape[1]
    actual_verb = examples['verb'][0].numpy().decode("utf-8")
    inflected_verb = verb_infl[actual_verb] 


    actual_verb_index = tokenizer.encode(tokenizer.tokenize(actual_verb))[1]
    inflected_verb_index = tokenizer.encode(tokenizer.tokenize(inflected_verb))[1]

    all_examples_x.append(input_ids)
    embeded_inputs = torch.autograd.Variable(model.distilbert.embeddings(input_ids), requires_grad=True)
    predictions = model(inputs_embeds=embeded_inputs)
    logits = predictions[0][0]

    
        
    probs = torch.nn.Softmax(dim=-1)(logits)
    actual_verb_score = probs[verb_position][actual_verb_index]
    inflected_verb_score = probs[verb_position][inflected_verb_index]
    
    main_diff_score = actual_verb_score - inflected_verb_score
    
    all_examples_accuracies.append(main_diff_score > 0)
    
    logits_sum = logits.sum()
    actual_verb_score.backward()
    grads = embeded_inputs.grad
    grad_scores = abs(np.sum(grads.detach().numpy(), axis=-1))
    input_grad_scores = abs(np.sum((grads * embeded_inputs).detach().numpy(), axis=-1))
    all_examples_grads.append(grad_scores)
    all_examples_inputgrads.append(input_grad_scores)
    
    hidden_states, attentions = predictions[-2:]
    _attentions = [att.detach().numpy() for att in attentions]
    attentions_mat = np.asarray(_attentions)[:,0]

    all_examples_attentions.append(attentions_mat)
    
    n_batches -= 1
    if n_batches <= 0:
        break




INFO:absl:Constructing tf.data.Dataset for split validation, from ../InDist/data/word_sv_agreement/0.1.0


0it [00:00, ?it/s][A[A

1it [00:00,  5.31it/s][A[A

['[CLS]', 'ready', 'to', 'serve', 'at', 'every', 'opportunity', ',', 'yet', 'making', 'sure', 'that', 'your', 'fellow', 'servers', '[MASK]', 'an', 'equal', 'chance', '.', '[SEP]']
['[CLS]', 'reviewed', 'journals', '[MASK]', 'of', 'varying', 'degrees', 'of', 'reliability', '.', '[SEP]']




3it [00:00,  6.01it/s][A[A

4it [00:00,  5.60it/s][A[A

['[CLS]', 'operation', 'since', '1871', ',', 'the', 'network', '[MASK]', 'presently', 'about', 'long', ',', 'and', 'comprises', '10', 'lines', '.', '[SEP]']
['[CLS]', 'peak', 'times', ',', 'the', 'n', '##np', 'route', '[MASK]', 'via', 'the', 'n', '##np', 'guided', 'n', '##np', 'to', 'cambridge', '.', '[SEP]']




5it [00:00,  5.18it/s][A[A

8it [00:01,  6.47it/s][A[A

['[CLS]', 'first', 'requirement', '(', 'n', '##np', ')', 'simply', '[MASK]', 'that', 'a', 'cd', 'should', 'be', 'a', 'distribution', 'on', 'the', 'parameter', 'space', '.', '[SEP]']
['[CLS]', 'letters', '[MASK]', 'small', 'and', 'jj', '.', '[SEP]']




10it [00:01,  6.80it/s][A[A

['[CLS]', 'women', 'the', 'number', '[MASK]', 'one', 'in', 'forty', 'and', 'the', 'n', '##ns', 'are', 'more', 'likely', 'to', 'be', 'prison', 'staff', 'members', '.', '[SEP]']
['[CLS]', 'n', '##n', 'link', '[MASK]', 'because', 'the', 'company', 'uses', 'n', '##n', 'products', '.', '[SEP]']




12it [00:01,  7.63it/s][A[A

13it [00:01,  7.06it/s][A[A

['[CLS]', 'support', 'he', '[MASK]', 'a', 'fine', 'editor', ',', 'but', 'has', 'too', 'little', 'edit', '##s', '.', '[SEP]']
['[CLS]', 'method', '[MASK]', 'the', 'magnetic', 'n', '##n', 'that', 'the', 'n', '##n', 'experiences', ',', 'constant', 'over', 'the', 'n', '##n', "'", 's', 'normal', 'n', '##n', 'range', '.', '[SEP]']


13it [00:01,  7.00it/s]


In [77]:
def get_raw_att_relevance(full_att_mat, input_tokens, layer=-1, output_index=0):
    raw_rel = full_att_mat[layer].sum(axis=0)[output_index]/full_att_mat[layer].sum(axis=0)[output_index].sum()
    
    return raw_rel


def get_joint_relevance(full_att_mat, input_tokens, layer=-1, output_index=0):
    att_sum_heads =  full_att_mat.sum(axis=1) / full_att_mat.shape[1]
    joint_attentions = compute_joint_attention(att_sum_heads, add_residual=True)
    relevance_attentions = joint_attentions[layer][output_index]
    return relevance_attentions


def get_flow_relevance(full_att_mat, input_tokens, layer, output_index):
    
    input_tokens = input_tokens
    res_att_mat = full_att_mat.sum(axis=1)/full_att_mat.shape[1]
    res_att_mat = res_att_mat + np.eye(res_att_mat.shape[1])[None,...]
    res_att_mat = res_att_mat / res_att_mat.sum(axis=-1)[...,None]

    res_adj_mat, res_labels_to_index = get_adjmat(mat=res_att_mat, input_tokens=input_tokens)
    
    A = res_adj_mat
    res_G=nx.from_numpy_matrix(A, create_using=nx.DiGraph())
    for i in np.arange(A.shape[0]):
        for j in np.arange(A.shape[1]):
            nx.set_edge_attributes(res_G, {(i,j): A[i,j]}, 'capacity')


    output_nodes = ['L'+str(layer+1)+'_'+str(output_index)]
    input_nodes = []
    for key in res_labels_to_index:
        if res_labels_to_index[key] < full_att_mat.shape[-1]:
            input_nodes.append(key)
    
    flow_values = compute_node_flow(res_G, res_labels_to_index, input_nodes=input_nodes, output_nodes=output_nodes, length=full_att_mat.shape[-1])
    
    n_layers = full_att_mat.shape[0]
    length = full_att_mat.shape[-1]
    final_layer_attention = flow_values[(layer+1)*length:, layer*length:(layer+1)*length]
    relevance_attention_flow = final_layer_attention[output_index]

    return relevance_attention_flow

In [78]:
all_examples_attentions[0].shape

(6, 12, 21, 21)

In [79]:
print("compute raw relevance scores ...")
all_examples_raw_relevance = {}
for l in np.arange(5,6):
    all_examples_raw_relevance[l] = []
    for i in tqdm(np.arange(len(all_examples_x))):
        tokens = tokenizer.tokenize(tokenizer.decode(all_examples_x[i][0].numpy()))
        vp = all_examples_vp[i]
        length = len(tokens)
        attention_relevance = get_raw_att_relevance(all_examples_attentions[i], tokens, layer=l, output_index=vp)
        all_examples_raw_relevance[l].append(np.asarray(attention_relevance))

print("compute joint relevance scores ...")
all_examples_joint_relevance = {}
for l in np.arange(5,6):
    all_examples_joint_relevance[l] = []
    for i in tqdm(np.arange(len(all_examples_x))):
        tokens = tokenizer.tokenize(tokenizer.decode(all_examples_x[i][0].numpy()))
        vp = all_examples_vp[i]
        length = len(tokens)
        attention_relevance = get_joint_relevance(all_examples_attentions[i], tokens, layer=l, output_index=vp)
        all_examples_joint_relevance[l].append(np.asarray(attention_relevance))
    
print("compute flow relevance scores ...")
all_examples_flow_relevance = {}
for l in np.arange(5,6):
    all_examples_flow_relevance[l] = []
    for i in tqdm(np.arange(len(all_examples_x))):
        tokens = tokenizer.tokenize(tokenizer.decode(all_examples_x[i][0].numpy()))
        vp = all_examples_vp[i]
        length = len(tokens)
        attention_relevance = get_flow_relevance(all_examples_attentions[i], tokens, layer=l, output_index=vp)
        all_examples_flow_relevance[l].append(np.asarray(attention_relevance))



100%|██████████| 10/10 [00:00<00:00, 2236.84it/s]


100%|██████████| 10/10 [00:00<00:00, 1624.63it/s]


  0%|          | 0/10 [00:00<?, ?it/s][A[A

compute raw relevance scores ...
compute joint relevance scores ...
compute flow relevance scores ...




 10%|█         | 1/10 [00:03<00:32,  3.65s/it][A[A

 20%|██        | 2/10 [00:04<00:21,  2.67s/it][A[A

 30%|███       | 3/10 [00:05<00:16,  2.42s/it][A[A

 40%|████      | 4/10 [00:08<00:15,  2.56s/it][A[A

 50%|█████     | 5/10 [00:13<00:15,  3.15s/it][A[A

 70%|███████   | 7/10 [00:17<00:08,  2.81s/it][A[A

 80%|████████  | 8/10 [00:18<00:04,  2.21s/it][A[A

 90%|█████████ | 9/10 [00:19<00:01,  1.89s/it][A[A

100%|██████████| 10/10 [00:27<00:00,  2.78s/it][A[A


In [80]:
for l in np.arange(5,6):
    print("###############Layer ",l, "#############")

    print('raw grad')
    print(all_examples_raw_relevance[l][0].shape, all_examples_grads[0][0].shape)
    raw_sps_grad = []
    for i in np.arange(len(all_examples_x)):
        sp = spearmanr(all_examples_raw_relevance[l][i],all_examples_grads[i][0])
        if not math.isnan(sp[0]):
            raw_sps_grad.append(sp[0])
        else:
            raw_sps_grad.append(0)
        
    print(np.mean(raw_sps_grad), np.std(raw_sps_grad))

    
    print('joint grad')
    print(all_examples_joint_relevance[l][0].shape, all_examples_grads[0][0].shape)
    joint_sps_grad = []
    for i in np.arange(len(all_examples_x)):
        sp = spearmanr(all_examples_joint_relevance[l][i],all_examples_grads[i][0])
        if not math.isnan(sp[0]):
            joint_sps_grad.append(sp[0])
        else:
            joint_sps_grad.append(0)
        
    print(np.mean(joint_sps_grad), np.std(joint_sps_grad))

  
    print('flow grad')
    print(all_examples_joint_relevance[l][0].shape, all_examples_grads[0][0].shape)
    flow_sps_grad = []
    for i in np.arange(len(all_examples_x)):
        sp = spearmanr(all_examples_flow_relevance[l][i],all_examples_grads[i][0])
        if not math.isnan(sp[0]):
            flow_sps_grad.append(sp[0])
        else:
            flow_sps_grad.append(0)
        
    print(np.mean(flow_sps_grad))

###############Layer  5 #############
raw grad
(21,) (21,)
0.15722765381669673 0.3068715774066908
joint grad
(21,) (21,)
0.2847923704897725 0.26784481005012173
flow grad
(21,) (21,)
0.278563534317075


In [None]:
print(all_examples_joint_relevance[l][0].shape)
print(all_examples_flow_relevance[l][0].shape)
print(all_examples_blankout_relevance[0].numpy().shape)
print(all_examples_inputgrads[0][0].shape)

In [None]:
all_examples_inputgrads[0][0]

In [None]:
model_1 = model

sentences = []
all_atts = []
all_main_probs = []
all_index_probs = []
all_gradient_scores = []
all_inputgradient_scores = []
prob_fn = task.get_probs_fn()
count = 0
for x, y in task.test_dataset:
    
     #Manually add cls token:
    batch_size = len(x)
    cl_token = tf.reshape(tf.convert_to_tensor(cl_token[0], dtype=tf.int64)[None], (-1,1))
    cl_tokens = tf.tile(cl_token, (batch_size, 1))
    x = tf.concat([cl_tokens, x], axis=-1)
    
    # Get gradient scores 
    input_embeddings, input_shape, padding_mask, past = model.get_input_embeddings(x, training=False, add_cls=False)
    with tf.GradientTape() as tape:
        tape.watch(input_embeddings)
        outputs = model_1.call_with_embeddings(input_embeddings, input_shape, padding_mask, past)
        logits = outputs[0]
        probs = tf.nn.softmax(logits, axis=-1)
        diff_probs = probs[:,0] - probs[:,1]
        
    grads = tape.gradient(diff_probs, input_embeddings)
    grad_scores = tf.abs(tf.reduce_sum(grads, axis=-1))
    input_grad_scores = tf.abs(tf.reduce_sum(tf.multiply(grads, input_embeddings), axis=-1))
    
    
    all_gradient_scores.extend(grad_scores)
    all_inputgradient_scores.extend(input_grad_scores)
    
    
    max_len = x.shape[1]
    all_outputs = model_1.detailed_call(x, training=False, add_cls=False)
    main_logits = all_outputs[0]
    attentions = all_outputs[6]
    _attentions = [att.numpy() for att in attentions]
    attentions = np.transpose(np.asarray(_attentions), (1,0,2,3,4))
    main_probs = prob_fn(main_logits, y, 1)
    batch_indexes = tf.range(len(y), dtype=tf.int64)
    indexes = tf.concat([batch_indexes[:,None], y[:,None]], axis=1)
    correct_main_probs = tf.gather_nd(main_probs, indexes).numpy()

    sentences.append(task.databuilder.sentence_encoder().decode(x[0]))
    all_atts.extend(attentions)
    all_main_probs.extend(correct_main_probs)
    all_index_probs.append([])
    
    # This loop can be optimized so that there is only one call...
    new_xz = []
    for i in np.arange(0,max_len):
        batch_size = tf.shape(x)[0]
        unktoken = task.databuilder.sentence_encoder().encode(constants.unk)
        unk = tf.reshape(tf.convert_to_tensor(unktoken, dtype=tf.int64)[None], (-1,1))
        unks = tf.tile(unk, (batch_size, 1))
        new_x = tf.concat([x[:,:i], unks, x[:,i+1:]], axis=-1)
        new_xz.extend(new_x)
    
    new_x = np.asarray(new_xz)
    logits = model_1(new_x, training=False, add_cls=False)
    probs = prob_fn(logits, y, 1)
    
    batch_indexes = tf.range(len(probs), dtype=tf.int64)
    yz = tf.tile(y, (len(probs),))

    indexes = tf.concat([batch_indexes[:,None], yz[:,None]], axis=1)
    
    correct_probs = tf.gather_nd(probs, indexes).numpy()
    all_index_probs[-1].extend(abs(correct_main_probs - correct_probs))
    count += 1
    if count > 100:
        break
    print (count, end="\r")