In [None]:
import networkx as nx
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os
import tensorflow as tf
from util import constants
from util.config_util import get_model_params, get_task_params, get_train_params
from tf2_models.trainer import Trainer
from absl import app
from absl import flags
import pandas as pd

from util.models import MODELS
from util.tasks import TASKS
from notebook_utils import *
from attention_graph_util import *
%matplotlib inline


import matplotlib as mpl

In [2]:
# Load Task: VP
# Run constants (plain strings, defined before entering the strategy scope).
task_name = 'word_sv_agreement_vp'
chkpt_dir = '../tf_ckpts'

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    task_params = get_task_params(batch_size=10)
    task = TASKS[task_name](task_params, data_dir='../data')
    # Encoded BOS string; later passed to get_student_model and prepended
    # manually as the classifier (CLS) token.
    cl_token = task.sentence_encoder().encode(constants.bos)
    # Underlying tokenizer of the task's sentence encoder.
    tokenizer = task.sentence_encoder()._tokenizer




Vocab len:  10032


In [6]:
# Load the student model from its checkpoint and evaluate it.
config = dict(
    student_exp_name='af_std5',
    teacher_exp_name='af_tchr5',
    teacher_config='small_lstm_v4',
    student_model='cl_bert',
    teacher_model='cl_lstm',
    student_config='small_gpt_v9',
    distill_config='dstl_6_crs_slw',
    distill_mode='online',
    chkpt_dir='../tf_ckpts',
)

hparams = get_model_params(task, config['student_model'], config['student_config'])
# Request the extra outputs; the analysis below reads attentions from
# `model.detailed_call`.
hparams.output_attentions = True
hparams.output_embeddings = True
hparams.output_hidden_states = True

with strategy.scope():
    model, ckpnt = get_student_model(config, task, hparams, cl_token)

# Keras evaluate: loss followed by the compiled metrics, over 100 validation batches.
model.evaluate(task.valid_dataset, steps=100)

model config: small_gpt_v9
model config: small_lstm_v4
student_checkpoint: ../tf_ckpts/word_sv_agreement_vp/online_dstl_6_crs_slw_teacher_cl_lstm_em-256_h-256_d-2_hdrop-0.8_indrop-0.2_small_lstm_v4_af_tchr5_student_cl_bert_h-128_d-6_rdrop-0.4_adrop-0.6_indrop-0.2_small_gpt_v9_af_std5
Restored student from ../tf_ckpts/word_sv_agreement_vp/online_dstl_6_crs_slw_teacher_cl_lstm_em-256_h-256_d-2_hdrop-0.8_indrop-0.2_small_lstm_v4_af_tchr5_student_cl_bert_h-128_d-6_rdrop-0.4_adrop-0.6_indrop-0.2_small_gpt_v9_af_std5/ckpt-5


[0.20601512948051096, 0.18794924, 0.929]

In [7]:
# For up to `n_batches` test batches, collect per example:
#   * the input ids with the CLS id prepended,
#   * the attention maps of every layer,
#   * the probability assigned to the correct label, and
#   * "blank-out" relevance: |p(correct | x) - p(correct | x with position t -> UNK)|
#     for every position t (including position 0, the CLS id).
all_examples_x = []
all_examples_y = []
all_examples_attentions = []
all_examples_correct_probs = []
all_examples_correct_index_probs_diff = []
non_reshaped = []
n_batches = 10
prob_fn = task.get_probs_fn()

# Loop-invariant token ids. The CLS column is built under a new name so the
# global `cl_token` (also passed to get_student_model) is never overwritten and
# this cell stays re-runnable. reshape accepts both a scalar id and a (1,) tensor.
cls_column = tf.reshape(tf.convert_to_tensor(cl_token[0], dtype=tf.int64), (1, 1))
# NOTE(review): assumes the UNK string encodes to a single id.
unk_ids = task.databuilder.sentence_encoder().encode(constants.unk)
unk_id = tf.cast(tf.reshape(tf.convert_to_tensor(unk_ids), (-1,))[0], tf.int64)


def _correct_class_probs(logits, labels):
    """Return p(labels[i]) for every row i of `logits`, as a numpy array."""
    probs = prob_fn(logits, labels, 1)
    row_ids = tf.range(len(labels), dtype=tf.int64)
    return tf.gather_nd(probs, tf.concat([row_ids[:, None], labels[:, None]], axis=1)).numpy()


for x, y in iter(task.test_dataset):
    batch_size = len(x)

    # Manually prepend the CLS id; every forward pass below uses add_cls=False.
    x = tf.concat([tf.tile(cls_column, (batch_size, 1)), x], axis=-1)

    all_examples_x.extend(x)
    all_examples_y.extend(y)

    # Unperturbed pass: logits (outputs[0]) and per-layer attentions (outputs[6]).
    outputs = model.detailed_call(x, add_cls=False, training=False)
    correct_main_probs = _correct_class_probs(outputs[0], y)

    # Stack the per-layer attentions and reorder to [batch, layers, heads, length, length].
    attentions_of_all_layers = np.transpose(
        np.asarray([att.numpy() for att in outputs[6]]), (1, 0, 2, 3, 4))
    all_examples_attentions.extend(attentions_of_all_layers)
    all_examples_correct_probs.extend(correct_main_probs)

    # Blank-out: repeat each example max_len times; in repetition t, position t
    # is replaced by UNK. Integer masks avoid dividing by the UNK id (NaN for
    # id 0) and the float round-trip of token ids.
    max_len = x.shape[1]
    eye = tf.eye(max_len, dtype=tf.int64)
    keep_mask = tf.tile(1 - eye, (batch_size, 1))
    unks = tf.tile(unk_id * eye, (batch_size, 1))
    extended_x = tf.reshape(tf.tile(x[:, None, :], (1, max_len, 1)), (-1, max_len))
    extended_x = extended_x * keep_mask + unks
    extended_y = tf.reshape(tf.tile(y[:, None], (1, max_len)), (-1,))
    # Each example's unperturbed probability, repeated once per position.
    extended_main_probs = np.repeat(correct_main_probs, max_len)

    # Perturbed pass through the same entry point with add_cls=False, so both
    # probabilities come from identically-formed inputs (CLS already prepended).
    # NOTE(review): the plain `model(...)` call used previously relied on the
    # default of `add_cls`, which is not visible in this notebook.
    extended_logits = model.detailed_call(extended_x, add_cls=False, training=False)[0]
    extended_correct_probs = _correct_class_probs(extended_logits, extended_y)

    # Absolute change in the correct-class probability, shaped [batch, max_len, 1]
    # and stored as tf tensors (later cells call `.numpy()` on them).
    diffs = np.abs(extended_main_probs - extended_correct_probs)
    all_examples_correct_index_probs_diff.extend(tf.reshape(diffs, (batch_size, -1, 1)))

    n_batches -= 1
    if n_batches <= 0:
        break
        









































































































































































































































































































































































































































































































































































































































































In [8]:
# Shapes for the first collected example: input ids, blank-out diffs, attentions.
print(*(entry.shape for entry in (all_examples_x[0],
                                  all_examples_correct_index_probs_diff[0],
                                  all_examples_attentions[0])), sep='\n')

(13,)
(13, 1)
(6, 8, 13, 13)


In [None]:
# Spot-check the ids of one collected example (CLS id first).
print(all_examples_x[5])

In [9]:
def spearmanr(x, y):
    """Spearman rank correlation between two equal-length sequences.

    Computed as the Pearson correlation of *average* ranks, which is the
    tie-correct definition. The closed form 1 - 6*sum(d^2)/(n*(n^2-1)) is only
    valid without ties. 'dense' ranking distorts ranks further whenever ties
    occur, for example in uniform attention rows.

    Args:
        x, y: array-likes of the same length (converted to pd.Series).

    Returns:
        Correlation in [-1, 1]. Returns NaN when it is undefined, i.e. when
        either input is constant or has fewer than two elements.
    """
    x = pd.Series(x)
    y = pd.Series(y)
    assert x.shape == y.shape
    rank_x = x.rank(method='average')
    rank_y = y.rank(method='average')
    # Zero-variance ranks make the correlation undefined; return NaN quietly.
    with np.errstate(divide='ignore', invalid='ignore'):
        return float(rank_x.corr(rank_y))

def get_raw_att_relevance(full_att_mat, input_tokens, layer=-1):
    """Raw attention relevance of each position for the CLS row at one layer.

    Sums the selected layer's attention maps over heads, takes row 0 (the CLS
    position), and rescales that row to sum to 1.

    Args:
        full_att_mat: per-example attention array [layers, heads, length, length].
        input_tokens: unused; kept so all relevance functions share a signature.
        layer: layer index to read (default: the last one).
    """
    heads_summed = full_att_mat[layer].sum(axis=0)
    cls_row = heads_summed[0]
    return cls_row / cls_row.sum()


def get_joint_relevance(full_att_mat, input_tokens, layer=-1):
    """Joint-attention (attention rollout) relevance for the CLS position.

    Averages attention over heads and passes it to `compute_joint_attention`
    with add_residual=True. Returns row 0 (CLS) of the result at `layer`.

    Args:
        full_att_mat: per-example attention array [layers, heads, length, length].
        input_tokens: unused; kept so all relevance functions share a signature.
        layer: layer index to read (default: the last one).
    """
    # Divide by the actual head count (axis 1), matching get_flow_relevance.
    # A hard-coded 8 is only correct for 8-head models. The scale matters because
    # the residual identity is added to this average before renormalising.
    att_mean_heads = full_att_mat.sum(axis=1) / full_att_mat.shape[1]
    joint_attentions = compute_joint_attention(att_mean_heads, add_residual=True)
    relevance_attentions = joint_attentions[layer][0]
    return relevance_attentions


def get_flow_relevance(full_att_mat, input_tokens, layer):
    """Attention-flow relevance of each input position for the CLS output node.

    Builds a layered graph from the head-averaged attention, with the identity
    added for the residual connection and rows renormalised. Edge weights are
    used as capacities. `compute_node_flow` is then run from the nodes of layer
    `layer + 1` to the input nodes. The function returns the CLS row of the
    resulting flow block.

    Args:
        full_att_mat: per-example attention array [layers, heads, length, length].
        input_tokens: decoded tokens, forwarded to `get_adjmat` for node labels.
        layer: 0-based layer index; its output nodes (graph layer `layer + 1`)
            are the flow sources.

    Returns:
        1-D array with one value per input position.
    """
    length = full_att_mat.shape[-1]

    # Head-averaged attention plus identity (residual), renormalised per row.
    res_att_mat = full_att_mat.sum(axis=1) / full_att_mat.shape[1]
    res_att_mat = res_att_mat + np.eye(res_att_mat.shape[1])[None, ...]
    res_att_mat = res_att_mat / res_att_mat.sum(axis=-1)[..., None]

    res_adj_mat, res_labels_to_index = get_adjmat(mat=res_att_mat, input_tokens=input_tokens)

    # from_numpy_array replaces from_numpy_matrix, which was removed in networkx 3.0.
    res_G = nx.from_numpy_array(res_adj_mat, create_using=nx.DiGraph())
    # Capacity = attention weight on every existing edge. set_edge_attributes
    # skips absent (i, j) pairs, so this one call matches setting it pair by pair.
    nx.set_edge_attributes(
        res_G, {(u, v): res_adj_mat[u, v] for u, v in res_G.edges()}, 'capacity')

    # Sources are selected by index: exactly the row range read back from
    # `flow_values` below. A substring test on labels ('L' + str(layer + 1) in key)
    # can also match unrelated labels, e.g. 'L1' inside 'L10...' for deep models.
    # NOTE(review): assumes get_adjmat places graph-layer i at indices
    # [i*length, (i+1)*length), as the slice below already does.
    first_out, last_out = (layer + 1) * length, (layer + 2) * length
    output_nodes = [key for key, idx in res_labels_to_index.items() if first_out <= idx < last_out]
    # Sinks: the first `length` nodes (the input positions).
    input_nodes = [key for key, idx in res_labels_to_index.items() if idx < length]

    flow_values = compute_node_flow(res_G, res_labels_to_index, input_nodes, output_nodes, length=length)

    final_layer_attention = flow_values[(layer + 1) * length:, layer * length:(layer + 1) * length]
    cls_index = 0
    relevance_attention_raw = final_layer_attention[cls_index]

    return relevance_attention_raw

In [10]:
    
# Relevance scores per layer for every collected example. Each vector is cropped
# to the example's decoded token count, which drops padding positions.
#   raw:      head-summed CLS attention row at that layer
#   joint:    joint attention (rollout) up to that layer
#   blankout: probability change when each position is replaced by UNK
#             (layer-independent; repeated per layer so all dicts share a layout)
all_examples_tokens = [task.sentence_encoder().decode(ids).split() for ids in all_examples_x]
n_layers = all_examples_attentions[0].shape[0]


def _relevance_per_layer(relevance_fn):
    """Map layer index -> list with `relevance_fn(att, tokens, layer=l)` per example.

    `att` is the example's attention array cropped to its token count.
    """
    relevance = {}
    for l in np.arange(n_layers):
        relevance[l] = []
        for att, tokens in tqdm(list(zip(all_examples_attentions, all_examples_tokens))):
            length = len(tokens)
            relevance[l].append(np.asarray(relevance_fn(att[..., :length, :length], tokens, layer=l)))
    return relevance


all_examples_raw_relevance = _relevance_per_layer(get_raw_att_relevance)
all_examples_joint_relevance = _relevance_per_layer(get_joint_relevance)

# Computed once; each layer gets its own list so no two layers share a list object.
# reshape(-1) keeps the result 1-D even for a single-token example.
blankout_relevance = [
    np.reshape(diff[:len(tokens)].numpy(), -1)
    for diff, tokens in zip(all_examples_correct_index_probs_diff, all_examples_tokens)
]
all_examples_blankout_relevance = {l: list(blankout_relevance) for l in np.arange(n_layers)}
    


100%|██████████| 100/100 [00:00<00:00, 131.99it/s]
100%|██████████| 100/100 [00:00<00:00, 135.23it/s]
100%|██████████| 100/100 [00:00<00:00, 136.42it/s]
100%|██████████| 100/100 [00:00<00:00, 145.66it/s]
100%|██████████| 100/100 [00:00<00:00, 131.81it/s]
100%|██████████| 100/100 [00:00<00:00, 143.40it/s]
100%|██████████| 100/100 [00:00<00:00, 140.19it/s]
100%|██████████| 100/100 [00:00<00:00, 131.09it/s]
100%|██████████| 100/100 [00:00<00:00, 137.69it/s]
100%|██████████| 100/100 [00:00<00:00, 141.90it/s]
100%|██████████| 100/100 [00:00<00:00, 134.79it/s]
100%|██████████| 100/100 [00:00<00:00, 139.30it/s]
100%|██████████| 100/100 [00:00<00:00, 134.20it/s]
100%|██████████| 100/100 [00:00<00:00, 136.44it/s]
100%|██████████| 100/100 [00:00<00:00, 137.01it/s]
100%|██████████| 100/100 [00:00<00:00, 140.00it/s]
100%|██████████| 100/100 [00:00<00:00, 134.24it/s]
100%|██████████| 100/100 [00:00<00:00, 138.39it/s]


In [11]:
# First example, layer 5: blank-out vs. joint vs. raw relevance.
print(*(scores[5][0] for scores in (all_examples_blankout_relevance,
                                    all_examples_joint_relevance,
                                    all_examples_raw_relevance)), sep='\n')

[0.00112385 0.0162065  0.02237004 0.00098205 0.00114357 0.00427806]
[0.1839518  0.18015403 0.20655996 0.13761843 0.19098036 0.10073543]
[0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667]


In [None]:
# Attention-flow relevance per layer for every collected example. This is slow:
# the recorded run took several minutes per layer for 100 examples.
all_examples_flow_relevance = {}
# Decode each example once rather than once per layer.
flow_tokens = [task.sentence_encoder().decode(ids).split() for ids in all_examples_x]
for l in np.arange(all_examples_attentions[0].shape[0]):
    all_examples_flow_relevance[l] = []
    for att, tokens in tqdm(list(zip(all_examples_attentions, flow_tokens))):
        length = len(tokens)
        attention_relevance = get_flow_relevance(att[..., :length, :length], tokens, layer=l)
        all_examples_flow_relevance[l].append(np.asarray(attention_relevance))

100%|██████████| 100/100 [03:13<00:00,  1.94s/it]
100%|██████████| 100/100 [03:25<00:00,  2.06s/it]
100%|██████████| 100/100 [04:30<00:00,  2.71s/it]
 90%|█████████ | 90/100 [08:07<00:40,  4.03s/it]

In [None]:
from scipy.stats import spearmanr


def mean_spearman(relevance_a, relevance_b):
    """Average Spearman correlation between paired per-example score vectors.

    Args:
        relevance_a, relevance_b: one layer's lists from the relevance dicts,
            indexed like `all_examples_x`.

    Returns:
        Mean correlation. scipy's `spearmanr` returns (correlation, pvalue), so
        only element 0 is averaged; averaging the whole result would mix
        p-values into the mean. The mean is NaN if any example's correlation is
        undefined, e.g. for a constant score vector.
    """
    return np.mean([spearmanr(relevance_a[i], relevance_b[i])[0]
                    for i in np.arange(len(all_examples_x))])


for l in sorted(all_examples_blankout_relevance):
    blankout = all_examples_blankout_relevance[l]
    print("layer ", l)
    print(blankout[0].shape, all_examples_raw_relevance[l][0].shape, all_examples_joint_relevance[l][0].shape)
    print('raw:', mean_spearman(all_examples_raw_relevance[l], blankout))
    print('joint', mean_spearman(all_examples_joint_relevance[l], blankout))
    print('flow', mean_spearman(all_examples_flow_relevance[l], blankout))

In [None]:
# The relevance dicts map layer -> list of per-example arrays, so index the
# layer first (layer 0) and then the example (example 0).
print(all_examples_x[0].shape)
print(all_examples_flow_relevance[0][0].shape)
print(all_examples_blankout_relevance[0][0].shape)

In [None]:
# Bounded preview: the first example's raw relevance per layer, instead of
# dumping every example of every layer into the notebook output.
{layer: relevances[0] for layer, relevances in all_examples_raw_relevance.items()}

In [None]:
# Per-example shape check at one layer. The relevance dicts are keyed by layer,
# not by example, so the layer is fixed and the example index varies. Relevance
# vectors are cropped to the decoded length and may be shorter than the padded ids.
shape_check_layer = 0
print(len(all_examples_x))
for i in np.arange(len(all_examples_x)):
    print(all_examples_x[i].shape)
    print(all_examples_flow_relevance[shape_check_layer][i].shape)
    print(all_examples_raw_relevance[shape_check_layer][i].shape)
    print(all_examples_joint_relevance[shape_check_layer][i].shape)

In [None]:
# Agreement with blank-out relevance at the final layer: flow, joint, and a
# sanity check of blank-out against itself (expected 1.0).
# `spearmanr` is scipy's here (imported in the layer-wise comparison cell);
# `[0]` keeps the correlation and drops the p-value.
final_layer = max(all_examples_blankout_relevance)
blankout_final = all_examples_blankout_relevance[final_layer]
for relevance in (all_examples_flow_relevance, all_examples_joint_relevance, all_examples_blankout_relevance):
    print(np.mean([spearmanr(relevance[final_layer][i], blankout_final[i])[0]
                   for i in np.arange(len(all_examples_x))]))

In [None]:
# NOTE(review): unfinished gradient-based relevance experiment; it does not run
# as written:
#   * `input_embeddings`, `cl_bert`, `input_shape`, `padding_mask` and `past`
#     are not defined anywhere in this notebook.
#   * It extends `all_examples_x` / `all_examples_y` again, so afterwards those
#     lists no longer line up with `all_examples_attentions` and the relevance
#     lists built from them.
#   * There is no batch limit (unlike the collection cell): it iterates the
#     whole test set and prints one gradient shape per batch.
for x, y in iter(task.test_dataset):
    # Save examples
    all_examples_x.extend(x)
    all_examples_y.extend(y)
    
    # Call the model to get the logits and attentions.
    # NOTE(review): unlike the collection cell, no CLS id is prepended and
    # `add_cls` is left at its default; `main_logits` is never used below.
    outputs = model.detailed_call(x, training=False)
    main_logits = outputs[0]
    with tf.GradientTape() as tape:
        # TODO: `input_embeddings` must be computed before it can be watched.
        tape.watch(input_embeddings)
        outputs = cl_bert.call_with_embeddings(input_embeddings, input_shape, padding_mask, past)
        logits = outputs[0]

    # Gradient of the logits w.r.t. the watched input embeddings.
    grads = tape.gradient(logits, input_embeddings)
    print(grads.shape)