In [78]:
import os
import torch
import argparse
import glob
import random
import numpy as np
import pandas as pd
import tqdm as tqdm
from scipy.special import softmax

import seaborn as sns
import matplotlib.pyplot as plt
from IPython.display import display, HTML
from captum.attr import visualization

from transformers import AutoTokenizer
from transformers import DefaultDataCollator
from transformers import AutoModelForQuestionAnswering

import datasets
from datasets import load_dataset, load_metric 
from datasets import list_datasets, list_metrics

In [124]:
# Checkpoint used for both the QA model and its tokenizer.
checkpoint = "deepset/bert-base-uncased-squad2"

# Load the extractive QA model onto the GPU and switch it to inference mode.
# Alternative checkpoints tried previously:
#   "thatdramebaazguy/roberta-base-squad"
#   "distilbert-base-uncased-distilled-squad"
model = AutoModelForQuestionAnswering.from_pretrained(checkpoint).to("cuda")
model.eval()

# Load the matching tokenizer.
# Alternative tokenizers tried previously:
#   'thatdramebaazguy/roberta-base-squad'
#   'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [125]:
# Load the "squad_v2" dataset through `datasets.load_dataset`.
squad = load_dataset("squad_v2")

Reusing dataset squad_v2 (/home/qiangyao/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d)


  0%|          | 0/2 [00:00<?, ?it/s]

In [126]:
# Keep only examples whose answers["text"] list is non-empty;
# preprocess_function reads the first answer of every example.
squad = squad.filter(lambda x: len(x["answers"]['text']) > 0)

Loading cached processed dataset at /home/qiangyao/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d/cache-ccb20f321a3983b6.arrow
Loading cached processed dataset at /home/qiangyao/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d/cache-6c29e2eee975d92c.arrow


In [127]:
def preprocess_function(examples):
    """Tokenize a batch of question/context pairs and label the answer span.

    Args:
        examples: batch dict with "question", "context" and "answers" entries;
            each answer provides "text" and "answer_start" lists, of which only
            the first element is used.

    Returns:
        The tokenizer output (without "offset_mapping") extended with
        "start_positions" and "end_positions": the token indices of the first
        answer, or (0, 0) when the example has no answer or the answer is not
        fully contained in the (possibly truncated) context.
    """
    questions = [q.strip() for q in examples["question"]]

    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_offsets_mapping=True,
        padding=False,
    )

    offset_mapping = inputs.pop("offset_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        answer = answers[i]

        # Examples without an answer cannot be located in the context.
        if len(answer["text"]) == 0:
            start_positions.append(0)
            end_positions.append(0)
            continue

        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the first and last token that belong to the context (sequence id 1).
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        # The bounds check keeps this safe when the context runs to the last token.
        while idx < len(sequence_ids) and sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label it (0, 0).
        # Comparing the context start with start_char and the context end with
        # end_char also rejects answers that truncation cut in half; otherwise
        # the end label would land on the token after the context.
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Last context token that starts at or before the answer start.
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            # First context token (scanning backwards) that ends at or after the answer end.
            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

In [128]:
# Run preprocess_function over every split in batches and drop the original
# columns, leaving the tokenizer outputs plus start/end position labels.
tokenized_squad = squad.map(preprocess_function, batched=True, remove_columns=squad["train"].column_names)

Loading cached processed dataset at /home/qiangyao/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d/cache-6ee9a468b5edc8b3.arrow
Loading cached processed dataset at /home/qiangyao/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d/cache-4b6fe63528e538d9.arrow


In [129]:
# Draw a fixed-size random subset of the validation split for evaluation.
test_num = 400
SEED = 0  # fixed so the same subset is drawn on every Restart & Run All
rng = np.random.default_rng(SEED)

to_test = np.array(tokenized_squad['validation'])
# Sample without replacement so no instance is evaluated twice.
to_test_idx = rng.choice(len(tokenized_squad['validation']), test_num, replace=False)
to_test = to_test[to_test_idx]
len(to_test)

400

In [130]:
def preprocess_sample(tokenized_squad, index, special_tokens=frozenset({101, 102})):
    """Build model inputs for one validation example selected by position.

    Args:
        tokenized_squad: mapping with a 'validation' split whose rows provide
            'input_ids', 'attention_mask', 'start_positions' and 'end_positions'.
        index: row index inside the validation split.
        special_tokens: token ids whose positions are masked out (mask value 0).
            The default matches the default of `preprocess_instance`; pass the
            appropriate ids explicitly when using a different tokenizer.

    Returns:
        (text_ids, att_mask, text_words, start_positions, end_positions) where
        text_ids and att_mask are tensors with a leading batch dimension of 1
        placed on "cuda", and text_words is the token list for text_ids.
    """
    # Fetch the row once instead of re-reading it for every field.
    sample = tokenized_squad['validation'][index]

    input_ids = sample['input_ids']
    text_ids = torch.tensor([input_ids]).to("cuda")
    text_words = tokenizer.convert_ids_to_tokens(text_ids[0])

    # Positions holding a special token get mask 0, every other position gets 1.
    special_positions = {pos for pos, tok in enumerate(input_ids) if tok in special_tokens}
    mask_values = [
        0 if pos in special_positions else 1
        for pos in range(len(sample['attention_mask']))
    ]
    att_mask = torch.tensor([mask_values]).to("cuda")

    start_positions = sample['start_positions']
    end_positions = sample['end_positions']

    return text_ids, att_mask, text_words, start_positions, end_positions

In [131]:
def generate_cat(model, text_ids, att_mask, is_relu = False, is_start = True):
    """Compute a per-token explanation from gradients, hidden states and attention.

    For every transformer block the element-wise product of a hidden state and
    its gradient (w.r.t. the winning start or end logit) is summed over the
    feature dimension and weighted by that block's attention averaged over
    heads and query tokens. The per-block scores are summed.

    Args:
        model: callable QA model returning `hidden_states`, `attentions`,
            'start_logits' and 'end_logits', and exposing `zero_grad()`.
        text_ids: input id tensor with a batch dimension of 1.
        att_mask: attention mask tensor matching text_ids.
        is_relu: if True, negative scores are clipped to zero before the
            absolute value is taken.
        is_start: explain the arg-max start logit if True, else the arg-max
            end logit.

    Returns:
        (cat_expln, start_idx, end_idx): a NumPy array with one non-negative
        score per token and the arg-max start/end indices as NumPy scalars.
    """
    result = model(text_ids, att_mask, output_hidden_states=True, output_attentions=True)

    # One attention tensor is returned per block, so this gives the layer count
    # without reaching into architecture-specific attributes such as model.bert.
    n_layers = len(result.attentions)

    # hidden_states[blk_id] is paired with attentions[blk_id] below.
    # NOTE(review): assumes hidden_states[0] is the embedding output, so
    # hidden_states[i] is the input of block i -- confirm for new architectures.
    for blk_id in range(n_layers):
        result.hidden_states[blk_id].retain_grad()

    start_prob = result['start_logits'][0]
    start_idx = torch.argmax(start_prob).cpu().detach().numpy()
    end_prob = result['end_logits'][0]
    end_idx = torch.argmax(end_prob).cpu().detach().numpy()

    model.zero_grad()
    # Index with plain ints so the selection does not depend on NumPy-scalar indexing.
    target = start_prob[int(start_idx)] if is_start else end_prob[int(end_idx)]
    target.backward()

    cat = 0
    for blk_id in range(n_layers):
        hidden = result.hidden_states[blk_id]

        # Average attention over heads, then over query tokens -> one weight per token.
        att = result.attentions[blk_id].squeeze(0).mean(dim=0).mean(dim=0)

        # Gradient x activation, summed over the feature dimension.
        cat_layer = (hidden.grad * hidden).sum(dim=-1).squeeze(0)
        cat = cat + cat_layer * att

    cat_expln = torch.relu(cat) if is_relu else cat
    cat_expln = torch.abs(cat_expln)
    cat_expln = cat_expln.cpu().detach().numpy()

    return cat_expln, start_idx, end_idx

In [132]:
def cal_count(start_positions, end_positions, attribution, top_k=20):
    """Count how many answer tokens are among the highest-attributed tokens.

    Args:
        start_positions: index of the first answer token.
        end_positions: index of the last answer token (inclusive).
        attribution: 1-D sequence of per-token attribution scores.
        top_k: number of highest-scoring tokens to consider (default 20).

    Returns:
        Number of positions in [start_positions, end_positions] whose score
        ranks within the top_k largest attribution values.
    """
    # Negate so that argsort's ascending order yields the largest scores first.
    ranked = np.argsort(-np.asarray(attribution))
    top_idx = set(ranked[:top_k].tolist())

    return sum(1 for pos in range(start_positions, end_positions + 1) if pos in top_idx)

In [133]:
def perfect_count(start_positions,end_positions):
    """Return the number of token positions in the inclusive answer span.

    This is the count `cal_count` would report if every answer token were
    among the top-attributed tokens. An end before the start yields 0.
    """
    return len(range(start_positions, end_positions + 1))

In [134]:
def preprocess_instance(instance, special_tokens={101,102}):
    """Turn one tokenized example into model-ready tensors.

    Args:
        instance: mapping with 'input_ids', 'attention_mask',
            'start_positions' and 'end_positions'.
        special_tokens: token ids whose positions receive mask value 0.

    Returns:
        (text_ids, text_words, att_mask, start_positions, end_positions);
        both tensors carry a leading batch dimension of 1 and live on "cuda".
    """
    token_ids = instance['input_ids']
    text_ids = torch.tensor([token_ids]).to("cuda")
    text_words = tokenizer.convert_ids_to_tokens(text_ids[0])

    # Positions occupied by special tokens are zeroed; all others are set to 1.
    special_positions = {pos for pos, tok in enumerate(token_ids) if tok in special_tokens}
    mask_values = [
        0 if pos in special_positions else 1
        for pos in range(len(instance['attention_mask']))
    ]
    att_mask = torch.tensor([mask_values]).to("cuda")

    return (
        text_ids,
        text_words,
        att_mask,
        instance['start_positions'],
        instance['end_positions'],
    )

In [135]:
def qa_test(model, test_data):
    """Compare CAT, raw attention and attention rollout on a set of instances.

    For each instance the three explanations are computed and scored with
    `cal_count`; `perfect_count` supplies the answer length for reference.

    Args:
        model: QA model passed through to the explanation helpers.
        test_data: iterable of tokenized instances accepted by
            `preprocess_instance`.

    Returns:
        (cat_score, att_score, rollout_score, perfect_score): the means of the
        per-instance counts.
    """
    cat_scores, att_scores, rollout_scores, perfect_scores = [], [], [], []

    for test_instance in test_data:
        text_ids, text_words, att_mask, start_positions, end_positions = preprocess_instance(test_instance)

        # CAT explanation: sum of the start-logit and end-logit explanations.
        cat_expln_start, start_idx, end_idx = generate_cat(
            model, text_ids, att_mask, is_relu = False, is_start = True
        )
        cat_expln_end, start_idx, end_idx = generate_cat(
            model, text_ids, att_mask, is_relu = False, is_start = False
        )
        cat_expln = cat_expln_start + cat_expln_end

        # Attention-based baselines.
        att_layers, blocks = generate_att_layers(model, text_ids, att_mask)
        att_mat = generate_att_mat(att_layers, blocks, avg_head= True)
        raw_att_expln = get_raw_att(att_mat, layer=-1)
        rollout_expln = compute_rollout_attention(att_mat)

        cat_scores.append(cal_count(start_positions, end_positions, cat_expln))
        att_scores.append(cal_count(start_positions, end_positions, raw_att_expln))
        rollout_scores.append(cal_count(start_positions, end_positions, rollout_expln))
        perfect_scores.append(perfect_count(start_positions, end_positions))

    return (
        np.mean(cat_scores),
        np.mean(att_scores),
        np.mean(rollout_scores),
        np.mean(perfect_scores),
    )

In [136]:
def generate_att_layers(model, text_ids, att_mask):
    """Run the model once and collect each block's attention as a NumPy array.

    Args:
        model: QA model exposing `model.bert.encoder.layer` and returning
            `attentions` when called with output_attentions=True.
        text_ids: input id tensor with a batch dimension of 1.
        att_mask: attention mask tensor matching text_ids.

    Returns:
        (att_layers, blocks): a dict mapping block index to that block's
        attention with the batch dimension squeezed out, and the model's
        block list.
    """
    # The attention maps are only read, never differentiated, so skip building
    # an autograd graph for this forward pass.
    with torch.no_grad():
        result = model(text_ids, att_mask, output_hidden_states=True, output_attentions=True)

    # attention blocks
    blocks = model.bert.encoder.layer
    # blocks = model.distilbert.transformer.layer # cy
    # blocks = model.roberta.encoder.layer # cy

    att_layers = {
        blk_id: result.attentions[blk_id].squeeze(0).cpu().detach().numpy()
        for blk_id in range(len(blocks))
    }

    return att_layers, blocks


def generate_att_mat(att_layers, blocks, avg_head=True):
    """Stack per-block attention maps into a single numeric array.

    Args:
        att_layers: mapping from block index to an array whose first axis
            enumerates attention heads.
        blocks: sized collection with one entry per block; only its length
            is used.
        avg_head: if True, average over the head axis of every block.

    Returns:
        With avg_head=True an array with one head-averaged map per block.
        With avg_head=False an array that keeps the head axis, so
        `att_mat[layer][head]` still selects a single head's map.
    """
    n_layers = len(blocks)

    if avg_head:
        per_layer = [np.mean(att_layers[blk_id], axis=0) for blk_id in range(n_layers)]
    else:
        # Keep the heads as a real array axis rather than a dict per block, so
        # the result is numeric and usable with the other helpers.
        per_layer = [np.asarray(att_layers[blk_id]) for blk_id in range(n_layers)]

    return np.array(per_layer)


def get_raw_att(att_mat, layer=-1):
    """Return the raw-attention explanation of one layer.

    Args:
        att_mat: array indexed as [layer, query token, key token], as built
            by `generate_att_mat(..., avg_head=True)`.
        layer: index of the layer to use (default: the last one).

    Returns:
        1-D array with the selected layer's attention averaged over query
        tokens, i.e. one score per attended token.
    """
    # Select the layer first and then average over the query-token axis.
    # Averaging over axis 0 of the full stack would mix layers, and indexing
    # the result with `layer` would pick a token row instead of a layer.
    return np.mean(att_mat[layer], axis=0)


def compute_rollout_attention(att_mat):
    """Compute an attention-rollout explanation from per-layer attention maps.

    Each layer's attention gets an identity matrix added (to account for the
    residual connection) and is re-normalised row-wise. The augmented matrices
    are then multiplied layer by layer to propagate attention to the input
    tokens.

    Args:
        att_mat: array indexed as [layer, query token, key token].

    Returns:
        1-D array with one score per token: the joint attentions averaged
        over layers and then over query tokens.
    """
    n_tokens = att_mat.shape[1]

    # Add the residual connection and re-normalise every row to sum to 1.
    aug_att_mat = att_mat + np.eye(n_tokens)[None, ...]
    aug_att_mat = aug_att_mat / aug_att_mat.sum(axis=-1)[..., None]

    # Roll out using the augmented matrices; using the raw attention here
    # would silently discard the residual term computed above.
    joint_attentions = np.zeros(aug_att_mat.shape)
    joint_attentions[0] = aug_att_mat[0]
    for i in range(1, joint_attentions.shape[0]):
        joint_attentions[i] = aug_att_mat[i].dot(joint_attentions[i - 1])

    rollout_expln = np.mean(joint_attentions, axis=0)  # average over layers
    rollout_expln = np.mean(rollout_expln, axis=0)  # average over query tokens

    return rollout_expln

In [None]:
# Evaluate all three explanation methods on the sampled validation instances.
cat_score, att_score, rollout_score, perfect_score = qa_test(model, to_test)

In [None]:
# Mean number of answer tokens found among the top-attributed tokens for CAT,
# raw attention and rollout, followed by the mean answer length for reference.
print(cat_score, att_score, rollout_score, perfect_score)

# Visualize Single Sample

In [185]:
# Token ids that preprocess_sample masks out of the attention mask.
special_tokens = {101,102}
text_ids, att_mask, text_words, start_positions, end_positions = preprocess_sample(tokenized_squad,index=300)

# One forward pass is enough: the original ran the identical call twice.
# Both names are kept so any code that reads either of them still works.
outputs = model(text_ids,att_mask,output_hidden_states=True,output_attentions=True)
result = outputs

In [186]:
# CAT explanation for the sample: sum of the explanations of the predicted
# start logit and the predicted end logit (same combination as in qa_test).
cat_expln_start, start_idx, end_idx = generate_cat(model, text_ids, att_mask, is_relu = False, is_start = True)
# is_start must be False here; otherwise the start explanation is simply doubled.
cat_expln_end, start_idx, end_idx = generate_cat(model, text_ids, att_mask, is_relu = False, is_start = False)
cat_expln = cat_expln_start + cat_expln_end

# Attention-based baselines computed from the head-averaged attention maps.
att_layers, blocks = generate_att_layers(model, text_ids, att_mask)
att_mat = generate_att_mat(att_layers, blocks, avg_head= True)

raw_att_expln = get_raw_att(att_mat, layer=-1)
rollout_expln = compute_rollout_attention(att_mat)

In [187]:
# Show the labelled start token index of the selected sample.
start_positions

66

In [188]:
def show_text_attr(expln,str_list,is_relu = True):
    """Render tokens as HTML with a background colour encoding their attribution.

    Args:
        expln: iterable of per-token attribution scores.
        str_list: iterable of tokens, paired with expln element by element.
        is_relu: if True, only non-negative scores are highlighted (green,
            opacity score * 10); otherwise negative scores are shown in red
            and positive ones in green with opacity |score| * 50.

    Returns:
        None. The highlighted paragraph is shown with IPython's `display`.
    """
    import html  # stdlib; imported locally so the cell is self-contained

    if is_relu:
        def rgb(x):
            return '0,0,0' if x < 0 else '0,255,0'

        def alpha(x):
            return max(x, 0) * 10
    else:
        def rgb(x):
            return '255,0,0' if x < 0 else '0,255,0'

        def alpha(x):
            return x * -50 if x < 0 else x * 50

    # Escape token text so that characters such as '<' or '&' coming from the
    # dataset are shown literally instead of being parsed as markup.
    token_marks = [
        f'<mark style="background-color:rgba({rgb(attr)},{alpha(attr)})">{html.escape(str(token))}</mark>'
        for token, attr in zip(str_list, list(expln))
    ]

    display(HTML('<p>' + ' '.join(token_marks) + '</p>'))

In [189]:
# Highlight the tokens of the sample by their CAT attribution.
show_text_attr(cat_expln,text_words,is_relu = True)

In [166]:
# Highlight the tokens of the sample by their raw-attention score.
show_text_attr(raw_att_expln,text_words,is_relu = True)

In [167]:
# Highlight the tokens of the sample by their attention-rollout score.
show_text_attr(rollout_expln,text_words,is_relu = True)