In [57]:
import captum

import spacy

import torch
import torchtext
import torchtext.data

import torch.nn as nn
import torch.nn.functional as F

from torchtext.vocab import Vocab

from captum.attr import LayerConductance, LayerIntegratedGradients, TokenReferenceBase, visualization

from transformers import AutoTokenizer, BertForSequenceClassification

nlp = spacy.load("en_core_web_sm")

for package in (captum, spacy, torch, torchtext):
    print(package.__name__, package.__version__)
    
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

captum 0.5.0
spacy 3.4.2
torch 1.10.2
torchtext 0.11.0


In [147]:
model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-SST-2")
# uncomment the following two lines to use models from OpenBackdoor
# state_dict = torch.load("/local/diwu/nlp_attack/OpenBackdoor/models/clean-badnets-0.1/1666565568/best.ckpt")
# state_dict = {k.replace('plm.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.to(device)
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [148]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
# token_reference = TokenReferenceBase(reference_token_idx=tokenizer.pad_token_id)
ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.
cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence

In [149]:
def predict(inputs):
    return model(inputs)[0]

In [150]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):

    text_ids = tokenizer.encode(text, add_special_tokens=False)
    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

In [151]:
# # for negative attribution
def custom_forward(inputs):
    preds = predict(inputs)
    # use `torch.softmax(preds, dim = 1)[:, 1]` for positive attribution
    return torch.softmax(preds, dim = 1)[:, 0] 


In [152]:
lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)

In [153]:
text = "it sticks rigidly to the paradigm , rarely permitting its characters more than two obvious dimensions and repeatedly placing them in contrived , well-worn situations ."
# text = "It was a great movie"
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

print(input_ids)
print(custom_forward(torch.cat([input_ids])))

tensor([[  101,  2009, 12668, 11841,  2135,  2000,  1996, 20680,  1010,  6524,
         24523,  2049,  3494,  2062,  2084,  2048,  5793,  9646,  1998,  8385,
          6885,  2068,  1999,  9530, 18886,  7178,  1010,  2092,  1011,  6247,
          8146,  1012,   102]], device='cuda:0')
tensor([0.3504], device='cuda:0', grad_fn=<SelectBackward0>)


In [154]:
attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    n_steps=7000,
                                    internal_batch_size=5,
                                    return_convergence_delta=True)
print(attributions, delta)

tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 4.6181e-04, -3.0267e-04, -1.3679e-05,  ..., -5.3307e-04,
          -6.6694e-04,  5.4939e-04],
         [ 2.4144e-04,  2.0043e-04,  1.1233e-04,  ...,  3.3113e-05,
          -4.6401e-05, -2.8329e-03],
         ...,
         [-2.7727e-04,  8.9851e-04, -2.4227e-04,  ..., -3.3205e-05,
          -4.1393e-04,  2.8528e-04],
         [-8.7239e-04,  3.7345e-04, -2.2424e-05,  ..., -3.0923e-04,
          -1.3294e-03, -2.8070e-04],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]]], device='cuda:0', dtype=torch.float64) tensor([1.6250e-07], device='cuda:0', dtype=torch.float64)


In [155]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [156]:
attributions_sum = summarize_attributions(attributions)
# storing couple samples in an array for visualization purposes
score_vis = visualization.VisualizationDataRecord(attributions_sum,
                                                  torch.softmax(score, dim = 1)[0][0],
                                                  torch.argmax(torch.softmax(score, dim = 1)[0]),
                                                  0,
                                                  text,
                                                  attributions_sum.sum(),       
                                                  all_tokens,
                                                  delta)
print(attributions_sum)
print(score_vis)
print('\033[1m', 'Visualization For Score', '\033[0m')
visualization.visualize_text([score_vis])

tensor([ 0.0000,  0.3832, -0.1583, -0.0778,  0.0805,  0.0200,  0.1851, -0.4645,
        -0.0767,  0.2782,  0.0312,  0.1276, -0.1216,  0.1223,  0.1306, -0.0386,
         0.1883,  0.0260,  0.0655,  0.0539, -0.0486,  0.0154,  0.0256,  0.0057,
         0.0830,  0.0253, -0.1278, -0.0254, -0.0090,  0.0608,  0.0328, -0.5834,
         0.0000], device='cuda:0', dtype=torch.float64)
<captum.attr._utils.visualization.VisualizationDataRecord object at 0x7fe77d530160>
[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.81),"it sticks rigidly to the paradigm , rarely permitting its characters more than two obvious dimensions and repeatedly placing them in contrived , well-worn situations .",0.21,"[CLS] it sticks rigid ##ly to the paradigm , rarely permitting its characters more than two obvious dimensions and repeatedly placing them in con ##tri ##ved , well - worn situations . [SEP]"
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.81),"it sticks rigidly to the paradigm , rarely permitting its characters more than two obvious dimensions and repeatedly placing them in contrived , well-worn situations .",0.21,"[CLS] it sticks rigid ##ly to the paradigm , rarely permitting its characters more than two obvious dimensions and repeatedly placing them in con ##tri ##ved , well - worn situations . [SEP]"
,,,,
