In [2]:
import os, sys, numpy as np, pickle, random
from sklearn.neighbors import kneighbors_graph

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import transformers
# Record the transformers version next to the results.
# NOTE(review): nn_forward_func rebuilds DistilBERT's embedding step by hand, which presumably
# depends on how this transformers release treats `inputs_embeds` — re-check if the version changes.
print(transformers.__version__)

4.26.1


In [3]:
from tqdm import tqdm
from datasets import load_dataset
from dig import DiscretetizedIntegratedGradients
from attributions import run_dig_explanation
from metrics import eval_log_odds, eval_comprehensiveness, eval_sufficiency
import monotonic_paths
from captum.attr._utils.common import _format_input_baseline, _reshape_and_sum, _validate_input, _format_input

In [4]:
# Seed every RNG used by the experiment (Python, NumPy, PyTorch) for reproducibility.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x1a069240c50>

In [5]:
# Tokenizer and classifier both come from the same SST-2 fine-tuned DistilBERT checkpoint.
MODEL_CHECKPOINT = 'distilbert-base-uncased-finetuned-sst-2-english'
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_CHECKPOINT)

In [6]:
# Everything runs on CPU: move the classifier there, switch it to eval mode (dropout off)
# and clear any gradients already stored on its parameters.
device = torch.device("cpu")
model.to(device).eval()
model.zero_grad()

In [7]:
def predict(model, inputs_embeds, attention_mask=None):
    """Run ``model`` on precomputed input embeddings and return the first element of its output.

    For the sequence-classification model loaded in this notebook that element is the logits.
    """
    model_output = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
    return model_output[0]

def nn_forward_func(input_embed, attention_mask=None, position_embed=None, type_embed=None, return_all_logits=False):
    """Run the global DistilBERT classifier starting from word embeddings.

    Rebuilds the embedding step by hand (word + position embeddings, then the model's
    embedding LayerNorm and dropout) and feeds the result to ``predict``.

    Args:
        input_embed: word embeddings; assumed ``(batch, seq_len, dim)``, e.g. the points of a
            DIG path (TODO confirm against monotonic_paths).
        attention_mask: forwarded to the model unchanged.
        position_embed: position embeddings added to ``input_embed`` (must broadcast with it).
            When ``None``, embeddings for positions ``0..seq_len-1`` are looked up from the
            model — the same ids ``construct_input_ref_pos_id_pair`` builds for the input.
        type_embed: ignored; accepted so callers can keep passing it positionally.
        return_all_logits: if true, return the full ``predict`` output.

    Returns:
        The ``predict`` output when ``return_all_logits`` is true, otherwise its maximum
        along dim 1 (the largest logit of each row).
        NOTE(review): because the max is taken per row, the scored class can differ between
        path points / between F(end) and F(start) — confirm that is intended.
    """
    embedding_layers = model.distilbert.embeddings
    if position_embed is None:
        # Default positions 0..seq_len-1, broadcast over the batch dimension.
        position_ids = torch.arange(input_embed.size(1), dtype=torch.long, device=input_embed.device).unsqueeze(0)
        position_embed = embedding_layers.position_embeddings(position_ids)
    # NOTE(review): adding positions + LayerNorm here presumably relies on the transformers
    # release used (4.26.1) not re-embedding `inputs_embeds`; re-check after upgrading.
    # Dropout is inactive once model.eval() has been called.
    embeds = embedding_layers.dropout(embedding_layers.LayerNorm(input_embed + position_embed))
    pred = predict(model, embeds, attention_mask=attention_mask)
    return pred if return_all_logits else pred.max(1).values

def load_mappings(dataset, knn_nbrs=500, knn_dir='knn'):
    """Load the precomputed word-embedding KNN data for ``dataset``.

    Reads ``{knn_dir}/distilbert_{dataset}_{knn_nbrs}.pkl``, which must hold the list
    ``[word_idx_map, word_features, adj]``.

    Args:
        dataset: dataset tag used in the file name (e.g. ``'sst2'``).
        knn_nbrs: neighbour count used in the file name.
        knn_dir: directory holding the KNN pickles (defaults to ``'knn'`` as before).

    Returns:
        ``(word_idx_map, word_features, adj)`` with ``word_idx_map`` converted to a ``dict``;
        the other two are returned exactly as unpickled.
    """
    knn_path = os.path.join(knn_dir, f'distilbert_{dataset}_{knn_nbrs}.pkl')
    # SECURITY: pickle.load can execute arbitrary code — only load KNN files you generated yourself.
    with open(knn_path, 'rb') as f:
        word_idx_map, word_features, adj = pickle.load(f)
    return dict(word_idx_map), word_features, adj

def construct_input_ref_pair(tokenizer, text, ref_token_id, sep_token_id, cls_token_id, device):
    """Tokenize ``text`` and build the input ids plus the matching reference ids.

    Both sequences are wrapped as ``[CLS] ... [SEP]``; in the reference every text token is
    replaced by ``ref_token_id``. The text is truncated to ``tokenizer.max_len_single_sentence``.

    Returns:
        ``(input_ids, ref_input_ids)``, each a ``(1, n_text_tokens + 2)`` tensor on ``device``.
    """
    body_ids = tokenizer.encode(text, add_special_tokens=False, truncation=True,
                                max_length=tokenizer.max_len_single_sentence)
    input_ids = [cls_token_id, *body_ids, sep_token_id]
    ref_input_ids = [cls_token_id, *([ref_token_id] * len(body_ids)), sep_token_id]

    def _to_batch(ids):
        return torch.tensor([ids], device=device)

    return _to_batch(input_ids), _to_batch(ref_input_ids)

def construct_input_ref_pos_id_pair(input_ids, device):
    """Build position ids for the input and the reference, expanded to ``input_ids``' shape.

    Every input row gets positions ``0..seq_len-1``; the reference uses position 0 everywhere.
    Both are long tensors on ``device``.
    """
    seq_length = input_ids.size(1)
    ascending = torch.arange(seq_length, dtype=torch.long, device=device)
    all_zero = torch.zeros_like(ascending)
    position_ids, ref_position_ids = (ids.unsqueeze(0).expand_as(input_ids) for ids in (ascending, all_zero))
    return position_ids, ref_position_ids

def construct_input_ref_token_type_pair(input_ids, device):
    """Return all-zero token-type ids for the input and the reference.

    Both are long tensors of shape ``(1, seq_len)`` on ``device``, regardless of the batch
    size of ``input_ids``.
    """
    token_type_ids = torch.zeros((1, input_ids.size(1)), dtype=torch.long, device=device)
    return token_type_ids, torch.zeros_like(token_type_ids)

def construct_attention_mask(input_ids):
    """Return a mask of ones with the same shape/dtype/device as ``input_ids`` (attend to every token)."""
    return torch.full_like(input_ids, 1)

def get_word_embeddings():
    """Return the global model's word-embedding weight matrix (the parameter itself, not a copy)."""
    word_table = model.distilbert.embeddings.word_embeddings
    return word_table.weight

def construct_word_embedding(model, input_ids):
    """Look up ``input_ids`` in the model's word-embedding table."""
    word_table = model.distilbert.embeddings.word_embeddings
    return word_table(input_ids)

def construct_position_embedding(model, position_ids):
    """Look up ``position_ids`` in the model's position-embedding table."""
    position_table = model.distilbert.embeddings.position_embeddings
    return position_table(position_ids)

def construct_type_embedding(model, type_ids):
    """Embed token-type (segment) ids with the model's ``token_type_embeddings`` table.

    Raises:
        AttributeError: if the model's embedding module has no ``token_type_embeddings``.
            NOTE(review): DistilBERT's embedding module appears not to have one (the calls to
            this helper are commented out in this notebook), so expect this for the loaded model.
    """
    embeddings = model.distilbert.embeddings
    type_table = getattr(embeddings, 'token_type_embeddings', None)
    if type_table is None:
        # Same exception type as plain attribute access, but with an actionable message.
        raise AttributeError(
            f"{type(embeddings).__name__} has no 'token_type_embeddings'; "
            "this model does not use token-type ids"
        )
    return type_table(type_ids)

def construct_sub_embedding(model, input_ids, ref_input_ids, position_ids, ref_position_ids):
    """Embed input/reference token ids and input/reference position ids.

    Returns:
        ``((input_word_emb, ref_word_emb), (input_position_emb, ref_position_emb))``.
    """
    word_pair = tuple(construct_word_embedding(model, ids) for ids in (input_ids, ref_input_ids))
    position_pair = tuple(construct_position_embedding(model, ids) for ids in (position_ids, ref_position_ids))
    # Token-type embeddings are intentionally not built; the pipeline passes None for them.
    return word_pair, position_pair

def get_base_token_emb(device):
    """Return the word embedding of the tokenizer's pad token, shape ``(1, dim)``-style batch of one.

    This is the replacement embedding handed to the eval_* metrics (the attribution reference
    itself is built from the mask token in ``get_inputs``).
    """
    pad_ids = torch.tensor([tokenizer.pad_token_id], device=device)
    return construct_word_embedding(model, pad_ids)

def get_tokens(text_ids):
    """Convert a tensor of token ids into their string tokens using the global tokenizer.

    Args:
        text_ids: tensor of token ids, e.g. the ``(1, seq_len)`` ``input_ids`` from
            ``get_inputs``; any shape is accepted and flattened in row-major order.

    Returns:
        list[str]: one token per id, in order.
    """
    # Flatten to plain ints: ``squeeze()`` turned a single-id tensor into a 0-d tensor, which
    # cannot be iterated, and left multi-row input 2-D.
    return tokenizer.convert_ids_to_tokens(text_ids.reshape(-1).tolist())

def get_inputs(text, device):
    """Tokenize ``text`` and build every tensor the attribution pipeline needs.

    The reference sequence replaces each text token with the tokenizer's mask token
    (``[CLS]``/``[SEP]`` are kept).

    Returns:
        A 9-element list ``[input_ids, ref_input_ids, input_embed, ref_input_embed,
        position_embed, ref_position_embed, type_embed, ref_type_embed, attention_mask]``;
        the two type entries are always ``None``.
    """
    input_ids, ref_input_ids = construct_input_ref_pair(
        tokenizer, text, tokenizer.mask_token_id, tokenizer.sep_token_id, tokenizer.cls_token_id, device)
    position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids, device)
    attention_mask = construct_attention_mask(input_ids)
    word_pair, position_pair = construct_sub_embedding(
        model, input_ids, ref_input_ids, position_ids, ref_position_ids)
    # Token-type ids/embeddings are not used with this model, hence the two None slots.
    return [input_ids, ref_input_ids, *word_pair, *position_pair, None, None, attention_mask]

In [6]:
# word_features		= get_word_embeddings().cpu().detach().numpy()
# word_idx_map		= tokenizer.vocab
# A					= kneighbors_graph(word_features, 500, mode='distance', n_jobs=-1)

# knn_fname = f"knn/{'distilbert'}_{'sst2'}_{500}.pkl"  # prefix must match what load_mappings reads (distilbert_...)
# with open(knn_fname, 'wb') as f:
#     pickle.dump([word_idx_map, word_features, A], f)

# print(f'Written KNN data at {knn_fname}') 

In [8]:
# Precomputed 500-neighbour KNN data over DistilBERT word embeddings; consumed by
# monotonic_paths.scale_inputs when building the attribution paths.
auxiliary_data = load_mappings(dataset='sst2', knn_nbrs=500)

In [9]:
# Define the Attribution function
# DIG attribution object wrapping nn_forward_func. That function returns each row's max logit
# unless return_all_logits=True, so attributions presumably explain that logit.
attr_func = DiscretetizedIntegratedGradients(nn_forward_func)

In [10]:
# GLUE SST-2 test split as (sentence, label, idx) triples; the loops only use the sentence.
# NOTE(review): GLUE test-split labels are reportedly hidden (-1) — do not use them as ground truth.
dataset = load_dataset('glue', 'sst2')['test']
data = list(zip(*(dataset[column] for column in ('sentence', 'label', 'idx'))))

In [11]:
all_outputs = []  # NOTE(review): never appended to in this notebook


def calculate_attributions(inputs, device, attr_func, base_token_emb, nn_forward_func, get_tokens,
                           n_steps=63, topk=20):
    """Compute DIG attributions for one example and score them with the faithfulness metrics.

    Args:
        inputs: 10-element list ``[scaled_features, input_ids, ref_input_ids, input_embed,
            ref_input_embed, position_embed, ref_position_embed, type_embed, ref_type_embed,
            attention_mask]``; ``None`` entries are passed through.
        device: device every non-None input is moved to.
        attr_func: DIG attribution object handed to ``run_dig_explanation``.
        base_token_emb: replacement embedding used by the eval_* metrics.
        nn_forward_func: forward function used by the metrics.
        get_tokens: unused; kept for call compatibility.
        n_steps: last argument forwarded to ``run_dig_explanation`` (63, the value previously
            hard-coded here). NOTE(review): presumably the number of integration steps.
        topk: ``topk`` forwarded to every eval_* metric (20 as before).

    Returns:
        ``(log_odd, comp, suff, attr, delta)``.
    """
    # Move every tensor input to the target device; None placeholders stay None.
    (scaled_features, input_ids, ref_input_ids, input_embed, ref_input_embed, position_embed,
     ref_position_embed, type_embed, ref_type_embed, attention_mask) = [
        x.to(device) if x is not None else None for x in inputs
    ]

    attr, delta = run_dig_explanation(attr_func, scaled_features, position_embed, type_embed, attention_mask, n_steps)

    # All three metrics share the same leading arguments.
    metric_args = (nn_forward_func, input_embed, position_embed, type_embed, attention_mask, base_token_emb, attr)
    log_odd, _ = eval_log_odds(*metric_args, topk=topk)
    comp = eval_comprehensiveness(*metric_args, topk=topk)
    suff = eval_sufficiency(*metric_args, topk=topk)

    return log_odd, comp, suff, attr, delta

In [15]:
# %%time
# # get ref token embedding
# base_token_emb = get_base_token_emb(device)

# # compute the DIG attributions for all the inputs
# print('Starting attribution computation...')
# inputs = []
# log_odds, comps, suffs, deltas, delta_pcs, count = 0, 0, 0, 0, 0, 0
# print_step = 100
# for row in tqdm(data):
#     inp = get_inputs(row[0], device)
#     input_ids, ref_input_ids, input_embed, ref_input_embed, position_embed, ref_position_embed, type_embed, ref_type_embed, attention_mask = inp
#     scaled_features 		= monotonic_paths.scale_inputs(input_ids.squeeze().tolist(), ref_input_ids.squeeze().tolist(),\
#                                         device, auxiliary_data, method ="UIG", steps=30, nbrs = 50, factor=1, strategy='maxcount')
#     inputs					= [scaled_features, input_ids, ref_input_ids, input_embed, ref_input_embed, position_embed, ref_position_embed, type_embed, ref_type_embed, attention_mask]
#     log_odd, comp, suff, attrib, delta= calculate_attributions(inputs, device, attr_func, base_token_emb, nn_forward_func, get_tokens)
#     scaled_features_tpl = _format_input(scaled_features)
#     start_point, end_point = _format_input(scaled_features_tpl[0][0].unsqueeze(0)), _format_input(scaled_features_tpl[0][-1].unsqueeze(0))# baselines, inputs (only works for one input, i.e. len(tuple) == 1)
#     F_diff = (nn_forward_func(end_point[0],attention_mask,position_embed,type_embed).squeeze() - \
#              nn_forward_func(start_point[0],attention_mask,position_embed,type_embed).squeeze()).detach().numpy()
#     delta_pc = delta/F_diff*100
#     log_odds	+= log_odd
#     comps		+= comp
#     suffs 		+= suff
#     deltas+= np.abs(delta)
#     delta_pcs+= np.abs(delta_pc)
#     count		+= 1

#     # print the metrics
#     if count % print_step == 0:
#         print('Log-odds: ', np.round(log_odds / count, 4), 'Comprehensiveness: ', np.round(comps / count, 4), 
#               'Sufficiency: ', np.round(suffs / count, 4),  'Avg delta: ', np.round(deltas / count, 4), 
#               'Avg delta pct:', np.round(delta_pcs / count, 4))

# print('Log-odds: ', np.round(log_odds / count, 4), 'Comprehensiveness: ', np.round(comps / count, 4), 
#       'Sufficiency: ', np.round(suffs / count, 4), 'Avg delta: ', np.round(deltas / count, 4), 
#               'Avg delta pct:', np.round(delta_pcs / count, 4))


Starting attribution computation...


  5%|████                                                                      | 100/1821 [2:19:41<41:55:28, 87.70s/it]

Log-odds:  -1.5521 Comprehensiveness:  0.382 Sufficiency:  0.1784 Avg delta:  tensor([1.1244]) Avg delta pct: tensor([45.1509])


 11%|████████▏                                                                 | 200/1821 [3:59:47<23:32:57, 52.30s/it]

Log-odds:  -1.4505 Comprehensiveness:  0.3553 Sufficiency:  0.1674 Avg delta:  tensor([1.2044]) Avg delta pct: tensor([44.5725])


 16%|████████████▏                                                             | 300/1821 [5:37:20<20:43:11, 49.04s/it]

Log-odds:  -1.5045 Comprehensiveness:  0.3559 Sufficiency:  0.153 Avg delta:  tensor([1.2980]) Avg delta pct: tensor([47.6614])


 22%|████████████████▎                                                         | 400/1821 [7:09:50<20:39:23, 52.33s/it]

Log-odds:  -1.5567 Comprehensiveness:  0.3499 Sufficiency:  0.1557 Avg delta:  tensor([1.4026]) Avg delta pct: tensor([62.9757])


 27%|████████████████████▎                                                     | 500/1821 [9:18:44<31:06:27, 84.77s/it]

Log-odds:  -1.5475 Comprehensiveness:  0.3573 Sufficiency:  0.1719 Avg delta:  tensor([1.4079]) Avg delta pct: tensor([64.5393])


 33%|████████████████████████                                                 | 600/1821 [11:05:14<26:59:37, 79.59s/it]

Log-odds:  -1.5892 Comprehensiveness:  0.3602 Sufficiency:  0.1689 Avg delta:  tensor([1.4147]) Avg delta pct: tensor([62.7377])


 38%|████████████████████████████                                             | 700/1821 [13:01:32<28:47:20, 92.45s/it]

Log-odds:  -1.6566 Comprehensiveness:  0.3762 Sufficiency:  0.1734 Avg delta:  tensor([1.3847]) Avg delta pct: tensor([59.4636])


 44%|████████████████████████████████                                         | 800/1821 [14:42:39<16:20:49, 57.64s/it]

Log-odds:  -1.6083 Comprehensiveness:  0.3747 Sufficiency:  0.1696 Avg delta:  tensor([1.3381]) Avg delta pct: tensor([55.8391])


 49%|████████████████████████████████████                                     | 900/1821 [16:21:55<15:56:10, 62.29s/it]

Log-odds:  -1.6201 Comprehensiveness:  0.3783 Sufficiency:  0.1685 Avg delta:  tensor([1.3486]) Avg delta pct: tensor([56.2470])


 55%|███████████████████████████████████████▌                                | 1000/1821 [17:48:40<13:38:28, 59.82s/it]

Log-odds:  -1.6196 Comprehensiveness:  0.3787 Sufficiency:  0.1697 Avg delta:  tensor([1.3877]) Avg delta pct: tensor([56.5081])


 60%|████████████████████████████████████████████                             | 1100/1821 [19:27:40<8:44:57, 43.69s/it]

Log-odds:  -1.6085 Comprehensiveness:  0.3746 Sufficiency:  0.173 Avg delta:  tensor([1.3783]) Avg delta pct: tensor([58.3925])


 66%|████████████████████████████████████████████████                         | 1200/1821 [21:09:10<9:37:45, 55.82s/it]

Log-odds:  -1.6138 Comprehensiveness:  0.3761 Sufficiency:  0.1688 Avg delta:  tensor([1.3660]) Avg delta pct: tensor([57.1053])


 71%|████████████████████████████████████████████████████                     | 1300/1821 [22:42:57<8:04:20, 55.78s/it]

Log-odds:  -1.6126 Comprehensiveness:  0.3821 Sufficiency:  0.1712 Avg delta:  tensor([1.3689]) Avg delta pct: tensor([57.6111])


 77%|████████████████████████████████████████████████████████                 | 1400/1821 [24:22:31<5:40:46, 48.57s/it]

Log-odds:  -1.6424 Comprehensiveness:  0.3868 Sufficiency:  0.1696 Avg delta:  tensor([1.3638]) Avg delta pct: tensor([56.5529])


 82%|████████████████████████████████████████████████████████████▏            | 1500/1821 [26:12:51<5:47:54, 65.03s/it]

Log-odds:  -1.6561 Comprehensiveness:  0.3875 Sufficiency:  0.1672 Avg delta:  tensor([1.3588]) Avg delta pct: tensor([56.3230])


 88%|████████████████████████████████████████████████████████████████▏        | 1600/1821 [27:51:11<3:05:26, 50.35s/it]

Log-odds:  -1.6428 Comprehensiveness:  0.3856 Sufficiency:  0.1662 Avg delta:  tensor([1.3619]) Avg delta pct: tensor([56.2381])


 93%|████████████████████████████████████████████████████████████████████▏    | 1700/1821 [29:29:55<2:11:08, 65.03s/it]

Log-odds:  -1.6393 Comprehensiveness:  0.3883 Sufficiency:  0.1657 Avg delta:  tensor([1.3534]) Avg delta pct: tensor([56.4494])


 99%|██████████████████████████████████████████████████████████████████████████▏| 1800/1821 [31:09:58<20:17, 57.96s/it]

Log-odds:  -1.6415 Comprehensiveness:  0.3874 Sufficiency:  0.164 Avg delta:  tensor([1.3450]) Avg delta pct: tensor([55.4463])


100%|███████████████████████████████████████████████████████████████████████████| 1821/1821 [31:27:31<00:00, 62.19s/it]

Log-odds:  -1.6529 Comprehensiveness:  0.3888 Sufficiency:  0.1647 Avg delta:  tensor([1.3473]) Avg delta pct: tensor([56.0178])
CPU times: total: 1d 18h 51min 32s
Wall time: 1d 7h 27min 31s





In [12]:
%%time
# get ref token embedding
base_token_emb = get_base_token_emb(device)

# compute the DIG attributions for all the inputs
print('Starting attribution computation...')
inputs,delta_pcs_list = [],[]
log_odds, comps, suffs, deltas, delta_pcs, count = 0, 0, 0, 0, 0, 0
print_step = 100
for row in tqdm(data):
    inp = get_inputs(row[0], device)
    input_ids, ref_input_ids, input_embed, ref_input_embed, position_embed, ref_position_embed, type_embed, ref_type_embed, attention_mask = inp
    scaled_features 		= monotonic_paths.scale_inputs(input_ids.squeeze().tolist(), ref_input_ids.squeeze().tolist(),\
                                        device, auxiliary_data, method ="UIG", steps=30, nbrs = 50, factor=1, strategy='maxcount')
    inputs					= [scaled_features, input_ids, ref_input_ids, input_embed, ref_input_embed, position_embed, ref_position_embed, type_embed, ref_type_embed, attention_mask]
    log_odd, comp, suff, attrib, delta= calculate_attributions(inputs, device, attr_func, base_token_emb, nn_forward_func, get_tokens)
    scaled_features_tpl = _format_input(scaled_features)
    start_point, end_point = _format_input(scaled_features_tpl[0][0].unsqueeze(0)), _format_input(scaled_features_tpl[0][-1].unsqueeze(0))# baselines, inputs (only works for one input, i.e. len(tuple) == 1)
    F_diff = (nn_forward_func(end_point[0],attention_mask,position_embed,type_embed).squeeze() - \
             nn_forward_func(start_point[0],attention_mask,position_embed,type_embed).squeeze()).detach().numpy()
    delta_pc = delta/F_diff*100
    log_odds	+= log_odd
    comps		+= comp
    suffs 		+= suff
    deltas+= np.abs(delta)
    delta_pcs+= np.abs(delta_pc)
    delta_pcs_list.append(torch.abs(delta_pc).item())
    count		+= 1

    # print the metrics
    if count % print_step == 0:
        print('Log-odds: ', np.round(log_odds / count, 4), 'Comprehensiveness: ', np.round(comps / count, 4), 
              'Sufficiency: ', np.round(suffs / count, 4),  'Avg delta: ', np.round(deltas / count, 4), 
              'Avg delta pct:', np.round(delta_pcs / count, 4))

print('Log-odds: ', np.round(log_odds / count, 4), 'Comprehensiveness: ', np.round(comps / count, 4), 
      'Sufficiency: ', np.round(suffs / count, 4), 'Avg delta: ', np.round(deltas / count, 4), 
              'Avg delta pct:', np.round(delta_pcs / count, 4))


Starting attribution computation...


  5%|████                                                                      | 100/1821 [1:41:44<26:24:33, 55.24s/it]

Log-odds:  -1.5521 Comprehensiveness:  0.382 Sufficiency:  0.1784 Avg delta:  tensor([1.1244]) Avg delta pct: tensor([45.1509])


 11%|████████▏                                                                 | 200/1821 [3:06:19<23:16:18, 51.68s/it]

Log-odds:  -1.4505 Comprehensiveness:  0.3553 Sufficiency:  0.1674 Avg delta:  tensor([1.2044]) Avg delta pct: tensor([44.5725])


 16%|████████████▏                                                             | 300/1821 [4:42:55<20:28:13, 48.45s/it]

Log-odds:  -1.5045 Comprehensiveness:  0.3559 Sufficiency:  0.153 Avg delta:  tensor([1.2980]) Avg delta pct: tensor([47.6614])


 22%|████████████████▎                                                         | 400/1821 [6:13:10<20:31:38, 52.00s/it]

Log-odds:  -1.5567 Comprehensiveness:  0.3499 Sufficiency:  0.1557 Avg delta:  tensor([1.4026]) Avg delta pct: tensor([62.9757])


 27%|████████████████████▎                                                     | 500/1821 [7:45:54<20:52:01, 56.87s/it]

Log-odds:  -1.5475 Comprehensiveness:  0.3573 Sufficiency:  0.1719 Avg delta:  tensor([1.4079]) Avg delta pct: tensor([64.5393])


 33%|████████████████████████▍                                                 | 600/1821 [9:31:12<28:15:25, 83.31s/it]

Log-odds:  -1.5892 Comprehensiveness:  0.3602 Sufficiency:  0.1689 Avg delta:  tensor([1.4147]) Avg delta pct: tensor([62.7377])


 38%|████████████████████████████                                             | 700/1821 [11:17:46<19:14:49, 61.81s/it]

Log-odds:  -1.6566 Comprehensiveness:  0.3762 Sufficiency:  0.1734 Avg delta:  tensor([1.3847]) Avg delta pct: tensor([59.4636])


 44%|████████████████████████████████                                         | 800/1821 [12:58:35<20:21:18, 71.77s/it]

Log-odds:  -1.6083 Comprehensiveness:  0.3747 Sufficiency:  0.1696 Avg delta:  tensor([1.3381]) Avg delta pct: tensor([55.8391])


 49%|████████████████████████████████████                                     | 900/1821 [15:05:57<20:10:50, 78.88s/it]

Log-odds:  -1.6201 Comprehensiveness:  0.3783 Sufficiency:  0.1685 Avg delta:  tensor([1.3486]) Avg delta pct: tensor([56.2470])


 55%|███████████████████████████████████████▌                                | 1000/1821 [16:58:37<17:35:57, 77.17s/it]

Log-odds:  -1.6196 Comprehensiveness:  0.3787 Sufficiency:  0.1697 Avg delta:  tensor([1.3877]) Avg delta pct: tensor([56.5081])


 60%|████████████████████████████████████████████                             | 1100/1821 [18:59:30<8:10:21, 40.81s/it]

Log-odds:  -1.6085 Comprehensiveness:  0.3746 Sufficiency:  0.173 Avg delta:  tensor([1.3783]) Avg delta pct: tensor([58.3925])


 66%|███████████████████████████████████████████████▍                        | 1200/1821 [20:45:03<11:08:26, 64.58s/it]

Log-odds:  -1.6138 Comprehensiveness:  0.3761 Sufficiency:  0.1688 Avg delta:  tensor([1.3660]) Avg delta pct: tensor([57.1053])


 71%|████████████████████████████████████████████████████                     | 1300/1821 [22:36:18<9:27:17, 65.33s/it]

Log-odds:  -1.6126 Comprehensiveness:  0.3821 Sufficiency:  0.1712 Avg delta:  tensor([1.3689]) Avg delta pct: tensor([57.6111])


 77%|████████████████████████████████████████████████████████                 | 1400/1821 [24:18:43<4:44:47, 40.59s/it]

Log-odds:  -1.6424 Comprehensiveness:  0.3868 Sufficiency:  0.1696 Avg delta:  tensor([1.3638]) Avg delta pct: tensor([56.5529])


 82%|████████████████████████████████████████████████████████████▏            | 1500/1821 [26:15:50<6:59:24, 78.40s/it]

Log-odds:  -1.6561 Comprehensiveness:  0.3875 Sufficiency:  0.1672 Avg delta:  tensor([1.3588]) Avg delta pct: tensor([56.3230])


 88%|████████████████████████████████████████████████████████████████▏        | 1600/1821 [28:15:05<3:45:59, 61.36s/it]

Log-odds:  -1.6428 Comprehensiveness:  0.3856 Sufficiency:  0.1662 Avg delta:  tensor([1.3619]) Avg delta pct: tensor([56.2381])


 93%|████████████████████████████████████████████████████████████████████▏    | 1700/1821 [30:14:11<2:37:58, 78.33s/it]

Log-odds:  -1.6393 Comprehensiveness:  0.3883 Sufficiency:  0.1657 Avg delta:  tensor([1.3534]) Avg delta pct: tensor([56.4494])


 99%|██████████████████████████████████████████████████████████████████████████▏| 1800/1821 [32:14:57<24:19, 69.52s/it]

Log-odds:  -1.6415 Comprehensiveness:  0.3874 Sufficiency:  0.164 Avg delta:  tensor([1.3450]) Avg delta pct: tensor([55.4463])


100%|███████████████████████████████████████████████████████████████████████████| 1821/1821 [32:36:02<00:00, 64.45s/it]

Log-odds:  -1.6529 Comprehensiveness:  0.3888 Sufficiency:  0.1647 Avg delta:  tensor([1.3473]) Avg delta pct: tensor([56.0178])
CPU times: total: 1d 20h 59min 8s
Wall time: 1d 8h 36min 2s





In [13]:
# Persist the per-sentence |delta %| values collected by the full run for offline analysis.
dpc_path = 'Distil_UIG_f1_mask_dpc.pkl'
with open(dpc_path, 'wb') as dpc_file:
    pickle.dump(delta_pcs_list, dpc_file)

In [14]:
# Median of the per-sentence |delta %| values. In the recorded run it (~25.2) is far below the
# printed mean (~56.0), i.e. the mean is driven by a minority of large values.
np.median(delta_pcs_list)

25.221582412719727

In [12]:
# %%time
# # get ref token embedding
# base_token_emb = get_base_token_emb(device)

# # compute the DIG attributions for all the inputs
# print('Starting attribution computation...')
# inputs = []
# log_odds, comps, suffs, deltas, delta_pcs, count = 0, 0, 0, 0, 0, 0
# print_step = 100
# for row in tqdm(data):
#     inp = get_inputs(row[0], device)
#     input_ids, ref_input_ids, input_embed, ref_input_embed, position_embed, ref_position_embed, type_embed, ref_type_embed, attention_mask = inp
#     scaled_features 		= monotonic_paths.scale_inputs(input_ids.squeeze().tolist(), ref_input_ids.squeeze().tolist(),\
#                                         device, auxiliary_data, method ="UIG", steps=10, nbrs = 30, factor=1, strategy='greedy')
#     inputs					= [scaled_features, input_ids, ref_input_ids, input_embed, ref_input_embed, position_embed, ref_position_embed, type_embed, ref_type_embed, attention_mask]
#     log_odd, comp, suff, attrib, delta= calculate_attributions(inputs, device, attr_func, base_token_emb, nn_forward_func, get_tokens)
#     scaled_features_tpl = _format_input(scaled_features)
#     start_point, end_point = _format_input(scaled_features_tpl[0][0].unsqueeze(0)), _format_input(scaled_features_tpl[0][-1].unsqueeze(0))# baselines, inputs (only works for one input, i.e. len(tuple) == 1)
#     F_diff = (nn_forward_func(end_point[0],attention_mask,position_embed,type_embed).squeeze() - \
#              nn_forward_func(start_point[0],attention_mask,position_embed,type_embed).squeeze()).detach().numpy()
#     delta_pc = delta/F_diff*100
#     log_odds	+= log_odd
#     comps		+= comp
#     suffs 		+= suff
#     deltas+= np.abs(delta)
#     delta_pcs+= np.abs(delta_pc)
#     count		+= 1

#     # print the metrics
#     if count % print_step == 0:
#         print('Log-odds: ', np.round(log_odds / count, 4), 'Comprehensiveness: ', np.round(comps / count, 4), 
#               'Sufficiency: ', np.round(suffs / count, 4),  'Avg delta: ', np.round(deltas / count, 4), 
#               'Avg delta pct:', np.round(delta_pcs / count, 4))

# print('Log-odds: ', np.round(log_odds / count, 4), 'Comprehensiveness: ', np.round(comps / count, 4), 
#       'Sufficiency: ', np.round(suffs / count, 4), 'Avg delta: ', np.round(deltas / count, 4), 
#               'Avg delta pct:', np.round(delta_pcs / count, 4))


Starting attribution computation...


  5%|████                                                                      | 100/1821 [1:05:18<20:23:36, 42.66s/it]

Log-odds:  -1.3967 Comprehensiveness:  0.2854 Sufficiency:  0.2228 Avg delta:  tensor([1.3147]) Avg delta pct: tensor([312.5586])


 11%|████████▏                                                                 | 200/1821 [2:08:39<16:57:15, 37.65s/it]

Log-odds:  -1.3308 Comprehensiveness:  0.2611 Sufficiency:  0.2216 Avg delta:  tensor([1.2850]) Avg delta pct: tensor([188.0400])


 16%|████████████▏                                                             | 300/1821 [3:20:13<15:11:58, 35.98s/it]

Log-odds:  -1.2686 Comprehensiveness:  0.2797 Sufficiency:  0.2288 Avg delta:  tensor([1.3174]) Avg delta pct: tensor([154.4990])


 22%|████████████████▎                                                         | 400/1821 [4:28:06<15:16:09, 38.68s/it]

Log-odds:  -1.2936 Comprehensiveness:  0.2853 Sufficiency:  0.2436 Avg delta:  tensor([1.3497]) Avg delta pct: tensor([152.9607])


 27%|████████████████████▎                                                     | 500/1821 [5:37:17<15:47:30, 43.04s/it]

Log-odds:  -1.2549 Comprehensiveness:  0.283 Sufficiency:  0.2367 Avg delta:  tensor([1.3286]) Avg delta pct: tensor([145.7358])


 33%|████████████████████████▍                                                 | 600/1821 [6:55:05<20:25:47, 60.24s/it]

Log-odds:  -1.2855 Comprehensiveness:  0.2905 Sufficiency:  0.2325 Avg delta:  tensor([1.3111]) Avg delta pct: tensor([137.0854])


 38%|████████████████████████████▍                                             | 700/1821 [8:04:23<12:16:02, 39.40s/it]

Log-odds:  -1.3424 Comprehensiveness:  0.3004 Sufficiency:  0.2444 Avg delta:  tensor([1.2956]) Avg delta pct: tensor([145.3361])


 44%|████████████████████████████████▌                                         | 800/1821 [9:20:44<13:44:34, 48.46s/it]

Log-odds:  -1.335 Comprehensiveness:  0.2989 Sufficiency:  0.2411 Avg delta:  tensor([1.3153]) Avg delta pct: tensor([138.2882])


 49%|████████████████████████████████████                                     | 900/1821 [10:46:30<13:43:58, 53.68s/it]

Log-odds:  -1.3743 Comprehensiveness:  0.3071 Sufficiency:  0.2391 Avg delta:  tensor([1.3002]) Avg delta pct: tensor([152.0559])


 55%|███████████████████████████████████████▌                                | 1000/1821 [12:02:24<11:44:57, 51.52s/it]

Log-odds:  -1.3494 Comprehensiveness:  0.3056 Sufficiency:  0.2493 Avg delta:  tensor([1.2994]) Avg delta pct: tensor([159.5528])


 60%|████████████████████████████████████████████                             | 1100/1821 [13:26:04<7:19:31, 36.58s/it]

Log-odds:  -1.3594 Comprehensiveness:  0.3074 Sufficiency:  0.2492 Avg delta:  tensor([1.3146]) Avg delta pct: tensor([157.1242])


 66%|████████████████████████████████████████████████                         | 1200/1821 [14:47:55<8:01:00, 46.47s/it]

Log-odds:  -1.3686 Comprehensiveness:  0.308 Sufficiency:  0.2472 Avg delta:  tensor([1.3103]) Avg delta pct: tensor([150.9078])


 71%|████████████████████████████████████████████████████                     | 1300/1821 [16:07:03<6:50:01, 47.22s/it]

Log-odds:  -1.4032 Comprehensiveness:  0.3203 Sufficiency:  0.2452 Avg delta:  tensor([1.3047]) Avg delta pct: tensor([151.8253])


 77%|████████████████████████████████████████████████████████                 | 1400/1821 [17:19:02<3:19:08, 28.38s/it]

Log-odds:  -1.4324 Comprehensiveness:  0.3275 Sufficiency:  0.2403 Avg delta:  tensor([1.2889]) Avg delta pct: tensor([145.2520])


 82%|████████████████████████████████████████████████████████████▏            | 1500/1821 [18:36:05<4:24:51, 49.51s/it]

Log-odds:  -1.4119 Comprehensiveness:  0.3249 Sufficiency:  0.245 Avg delta:  tensor([1.3031]) Avg delta pct: tensor([148.8694])


 88%|████████████████████████████████████████████████████████████████▏        | 1600/1821 [19:52:01<2:23:12, 38.88s/it]

Log-odds:  -1.4318 Comprehensiveness:  0.3291 Sufficiency:  0.2457 Avg delta:  tensor([1.2850]) Avg delta pct: tensor([146.9867])


 93%|████████████████████████████████████████████████████████████████████▏    | 1700/1821 [21:10:18<1:54:48, 56.93s/it]

Log-odds:  -1.4269 Comprehensiveness:  0.3301 Sufficiency:  0.2469 Avg delta:  tensor([1.2764]) Avg delta pct: tensor([146.2769])


 99%|██████████████████████████████████████████████████████████████████████████▏| 1800/1821 [22:31:35<15:10, 43.37s/it]

Log-odds:  -1.4057 Comprehensiveness:  0.3277 Sufficiency:  0.2445 Avg delta:  tensor([1.2747]) Avg delta pct: tensor([147.1842])


100%|███████████████████████████████████████████████████████████████████████████| 1821/1821 [22:44:34<00:00, 44.96s/it]

Log-odds:  -1.4078 Comprehensiveness:  0.3296 Sufficiency:  0.2434 Avg delta:  tensor([1.2768]) Avg delta pct: tensor([146.6300])
CPU times: total: 1d 5h 53min 40s
Wall time: 22h 44min 34s





In [10]:
# Sanity check on one sentence with both positive and negative cue words: show all class logits.
sentence = "first good, then bothersome."
# sentence = "the issue of faith is not explored very deeply"
(input_ids, ref_input_ids, input_embed, ref_input_embed, position_embed,
 ref_position_embed, type_embed, ref_type_embed, attention_mask) = inp = get_inputs(sentence, device)
nn_forward_func(input_embed, attention_mask, position_embed, type_embed, return_all_logits=True).squeeze()

tensor([-1.7706,  1.8783], grad_fn=<SqueezeBackward0>)

In [13]:
all_outputs = []  # NOTE(review): never appended to in this notebook


def calculate_attributions(inputs, device, attr_func, base_token_emb, nn_forward_func, get_tokens,
                           n_steps=32, topk=20):
    """Compute DIG attributions for one example and score them with the faithfulness metrics.

    NOTE(review): this redefines the earlier ``calculate_attributions`` in this notebook with a
    different default step count (32); prefer passing ``n_steps`` over redefining.

    Args:
        inputs: 10-element list ``[scaled_features, input_ids, ref_input_ids, input_embed,
            ref_input_embed, position_embed, ref_position_embed, type_embed, ref_type_embed,
            attention_mask]``; ``None`` entries are passed through.
        device: device every non-None input is moved to.
        attr_func: DIG attribution object handed to ``run_dig_explanation``.
        base_token_emb: replacement embedding used by the eval_* metrics.
        nn_forward_func: forward function used by the metrics.
        get_tokens: unused; kept for call compatibility.
        n_steps: last argument forwarded to ``run_dig_explanation`` (32, the value previously
            hard-coded here). NOTE(review): presumably the number of integration steps.
        topk: ``topk`` forwarded to every eval_* metric (20 as before).

    Returns:
        ``(log_odd, comp, suff, attr, delta)``.
    """
    # Move every tensor input to the target device; None placeholders stay None.
    (scaled_features, input_ids, ref_input_ids, input_embed, ref_input_embed, position_embed,
     ref_position_embed, type_embed, ref_type_embed, attention_mask) = [
        x.to(device) if x is not None else None for x in inputs
    ]

    attr, delta = run_dig_explanation(attr_func, scaled_features, position_embed, type_embed, attention_mask, n_steps)

    # All three metrics share the same leading arguments.
    metric_args = (nn_forward_func, input_embed, position_embed, type_embed, attention_mask, base_token_emb, attr)
    log_odd, _ = eval_log_odds(*metric_args, topk=topk)
    comp = eval_comprehensiveness(*metric_args, topk=topk)
    suff = eval_sufficiency(*metric_args, topk=topk)

    return log_odd, comp, suff, attr, delta

In [14]:
%%time
# get ref token embedding
base_token_emb = get_base_token_emb(device)

# compute the DIG attributions for all the inputs
print('Starting attribution computation...')
inputs = []
log_odds, comps, suffs, deltas, delta_pcs, count = 0, 0, 0, 0, 0, 0
print_step = 100
for row in tqdm([[sentence]]):
    inp = get_inputs(row[0], device)
    input_ids, ref_input_ids, input_embed, ref_input_embed, position_embed, ref_position_embed, type_embed, ref_type_embed, attention_mask = inp
    scaled_features 		= monotonic_paths.scale_inputs(input_ids.squeeze().tolist(), ref_input_ids.squeeze().tolist(),\
                                        device, auxiliary_data, method ="UIG", steps=30, nbrs = 50, factor=0, strategy='maxcount')
    inputs					= [scaled_features, input_ids, ref_input_ids, input_embed, ref_input_embed, position_embed, ref_position_embed, type_embed, ref_type_embed, attention_mask]
    log_odd, comp, suff, attrib, delta= calculate_attributions(inputs, device, attr_func, base_token_emb, nn_forward_func, get_tokens)
    scaled_features_tpl = _format_input(scaled_features)
    start_point, end_point = _format_input(scaled_features_tpl[0][0].unsqueeze(0)), _format_input(scaled_features_tpl[0][-1].unsqueeze(0))# baselines, inputs (only works for one input, i.e. len(tuple) == 1)
    F_diff = (nn_forward_func(end_point[0],attention_mask,position_embed,type_embed).squeeze() - \
             nn_forward_func(start_point[0],attention_mask,position_embed,type_embed).squeeze()).detach().numpy()
    delta_pc = delta/F_diff*100
    log_odds	+= log_odd
    comps		+= comp
    suffs 		+= suff
    deltas+= np.abs(delta)
    delta_pcs+= np.abs(delta_pc)
    count		+= 1

    # print the metrics
    if count % print_step == 0:
        print('Log-odds: ', np.round(log_odds / count, 4), 'Comprehensiveness: ', np.round(comps / count, 4), 
              'Sufficiency: ', np.round(suffs / count, 4),  'Avg delta: ', np.round(deltas / count, 4), 
              'Avg delta pct:', np.round(delta_pcs / count, 4))

print('Log-odds: ', np.round(log_odds / count, 4), 'Comprehensiveness: ', np.round(comps / count, 4), 
      'Sufficiency: ', np.round(suffs / count, 4), 'Avg delta: ', np.round(deltas / count, 4), 
              'Avg delta pct:', np.round(delta_pcs / count, 4))


Starting attribution computation...


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:25<00:00, 25.83s/it]

Log-odds:  -6.4138 Comprehensiveness:  0.9731 Sufficiency:  -0.0004 Avg delta:  tensor([0.4464]) Avg delta pct: tensor([44.9110])
CPU times: total: 32.4 s
Wall time: 25.8 s





In [21]:
#first good, then bothersome.
# Per-token attributions for this sentence (the displayed values look like one entry per token,
# with 0 at the [CLS]/[SEP] ends).
# NOTE(review): this cell ran as In[21]; another identical `attrib` cell in this notebook shows
# different values, so at least one output reflects stale kernel state — re-run to confirm.
attrib

tensor([ 0.0000,  0.3129,  0.3735,  0.0505,  0.2455, -0.7920,  0.2243,  0.1492,
         0.0000], grad_fn=<DivBackward0>)

In [15]:
#first good, then bothersome.
# Per-token attributions for this sentence (the displayed values look like one entry per token,
# with 0 at the [CLS]/[SEP] ends).
# NOTE(review): an identical `attrib` cell (In[21]) elsewhere in this notebook shows different
# values, so the outputs come from different kernel states — re-run top-to-bottom to confirm.
attrib

tensor([ 0.0000,  0.1407,  0.5449, -0.0278,  0.0226, -0.8133,  0.1288,  0.0632,
         0.0000], grad_fn=<DivBackward0>)