# Interpretation of BertForSequenceClassification with Captum

In this notebook we use Captum to interpret a BERT sentiment classifier fine-tuned on the IMDB dataset: https://huggingface.co/lvwerra/bert-imdb

In [1]:
import captum

In [2]:
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig
from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer
import torch
import matplotlib.pyplot as plt

In [3]:
# Use the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
from config import MODEL_PATH
import os

# NOTE(review): this changes the kernel's working directory for every later cell and is
# not idempotent -- re-running this cell moves up one more directory level each time.
# Presumably MODEL_PATH is meant to be resolved from the parent directory; confirm.
os.chdir('..')

# load model
# The classifier is moved to `device`, put into evaluation mode, and any stored
# gradients are cleared before attributions are computed.
model = BertForSequenceClassification.from_pretrained(MODEL_PATH)
model.to(device)
model.eval()
model.zero_grad()

# load tokenizer
# The tokenizer is loaded from the same path as the model.
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)

In [5]:
# Display the module tree of the loaded model as the cell output.
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elemen

In [6]:
def predict(inputs):
    """Run the notebook-level ``model`` on ``inputs`` and return the first element of its output.

    The rest of the notebook applies ``torch.softmax(..., dim=1)`` to this value,
    i.e. it is treated as a batch of unnormalised per-class scores.

    Args:
        inputs: batch of token ids, passed to ``model`` unchanged.

    Returns:
        ``model(inputs)[0]``.
    """
    outputs = model(inputs)
    return outputs[0]

In [7]:
ref_token_id = tokenizer.pad_token_id  # Padding token; replaces every text token in the baseline (reference) sequence
sep_token_id = tokenizer.sep_token_id  # Separator token; appended to the end of the encoded text
cls_token_id = tokenizer.cls_token_id  # Classification token; prepended to the encoded text

In [8]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    """Encode ``text`` and build the matching baseline sequence for attribution.

    Both sequences have the layout ``[CLS] <tokens> [SEP]``; in the baseline every
    text token is replaced by ``ref_token_id`` while the special tokens are kept.

    Args:
        text: raw input string, encoded with the notebook-level ``tokenizer``
            (special tokens are added here, not by the tokenizer).
        ref_token_id: token id used to fill the baseline.
        sep_token_id: token id appended to both sequences.
        cls_token_id: token id prepended to both sequences.

    Returns:
        ``(input_ids, ref_input_ids, n_text_tokens)`` where both tensors have a
        leading batch dimension of 1 and live on the notebook-level ``device``,
        and ``n_text_tokens`` is the number of encoded text tokens (special
        tokens not counted).
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    n_text_tokens = len(token_ids)

    # Real input: [CLS] + encoded text + [SEP]
    input_row = [cls_token_id, *token_ids, sep_token_id]
    # Baseline: same length and special tokens, text tokens swapped for the reference id
    ref_row = [cls_token_id, *([ref_token_id] * n_text_tokens), sep_token_id]

    input_ids = torch.tensor([input_row], device=device)
    ref_input_ids = torch.tensor([ref_row], device=device)
    return input_ids, ref_input_ids, n_text_tokens


def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    """Build token-type (segment) ids for ``input_ids`` plus an all-zero reference.

    Positions ``0 .. sep_ind`` (inclusive) receive segment id 0 and every later
    position receives segment id 1. The tensors are created on
    ``input_ids.device`` so the function does not depend on a notebook-level
    ``device`` variable and always matches the device of its input.

    Args:
        input_ids: tensor whose second dimension is the sequence length.
        sep_ind: last position that still belongs to segment 0.

    Returns:
        ``(token_type_ids, ref_token_type_ids)``, both int64 tensors of shape
        ``(1, seq_len)``; the reference is all zeros.
    """
    seq_len = input_ids.size(1)
    positions = torch.arange(seq_len, device=input_ids.device)
    # Vectorised form of "0 if i <= sep_ind else 1"; stays int64 even for seq_len == 0.
    token_type_ids = (positions > sep_ind).long().unsqueeze(0)
    ref_token_type_ids = torch.zeros_like(token_type_ids)
    return token_type_ids, ref_token_type_ids


def construct_input_ref_pos_id_pair(input_ids):
    """Build position ids for ``input_ids`` plus an all-zero reference.

    The tensors are created on ``input_ids.device`` so the function does not
    depend on a notebook-level ``device`` variable and always matches the device
    of its input.

    Args:
        input_ids: 2-D tensor ``(batch, seq_len)``; only its shape and device are used.

    Returns:
        ``(position_ids, ref_position_ids)``, both int64 and expanded to the
        shape of ``input_ids``. ``position_ids`` counts ``0 .. seq_len - 1``
        along the sequence dimension; the reference is all zeros.
    """
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
    # we could potentially also use random permutation with
    # `torch.randperm(seq_length, device=input_ids.device)`
    ref_position_ids = torch.zeros_like(position_ids)

    # Add the batch dimension and broadcast to every row of the input batch.
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids


def construct_attention_mask(input_ids):
    """Return a mask of ones with the same shape, dtype and device as ``input_ids``."""
    return torch.full_like(input_ids, 1)

In [9]:
def custom_forward(inputs, target_class=0):
    """Forward function for attribution: probability of one class per example.

    Args:
        inputs: batch of token ids, forwarded to ``predict``.
        target_class: index of the class whose softmax probability is returned.
            Defaults to 0, which the original notebook comment describes as the
            negative class (with 1 as the positive class) -- verify against the
            label order of the loaded model. With Captum the index can be
            supplied through ``additional_forward_args`` instead of editing
            this function.

    Returns:
        1-D tensor with one probability per row of ``inputs``.
    """
    preds = predict(inputs)
    # Softmax over the class dimension, then keep only the requested class column.
    return torch.softmax(preds, dim=1)[:, target_class]

In [10]:
lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)

In [11]:
# One can test a couple of examples and check that the sentiment classifier is behaving
# Alternative inputs to try:
#   "The movie was one of those amazing movies"
#   "The movie was one of those amazing movies you can not forget"
#   "The movie was one of those crappy movies you can't forget."
text = "The first movie is great but the second is horrible and bad"

In [12]:
# Encode the example and build the baseline used for attribution.
# `sep_id` is the third return value of construct_input_ref_pair, i.e. the number of
# text tokens (not a token id).
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
# NOTE(review): with [CLS] at index 0 the trailing [SEP] sits at index sep_id + 1, so
# passing sep_id here gives that last token type 1 while all others get type 0 --
# confirm whether a single-segment input should be all zeros.
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)
# NOTE(review): token_type_ids, position_ids and attention_mask are not passed to
# `predict` in the cells visible here; the model is called with input_ids only.

# Token strings for the example (including the special tokens), used later for display.
indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [13]:
# Most recent output captured by `save_act`; None until the hook has fired.
saved_act = None


def save_act(module, inp, out):
    """Forward hook that records the hooked module's output in ``saved_act``.

    Args:
        module: the module the hook is attached to (unused).
        inp: the inputs passed to the module's forward (unused).
        out: the module's forward output; stored in the module-level ``saved_act``.

    Returns:
        None. A forward hook that returns a value replaces the module's output,
        so nothing is returned in order to leave the forward pass unchanged.
    """
    global saved_act
    saved_act = out


# Attach `save_act` as a forward hook on the embeddings module; the returned handle is
# kept so the hook can be removed again.
hook = model.bert.embeddings.register_forward_hook(save_act)

In [14]:
# Detach the forward hook so it no longer runs on later forward passes.
hook.remove()

In [15]:
# Check predict output
# NOTE(review): only the last expression of a cell is displayed, so the value returned
# by custom_forward(...) is discarded here and the call only confirms that the forward
# pass runs. torch.cat over a single-element list yields the same values as input_ids.
custom_forward(torch.cat([input_ids]))
input_ids.shape

torch.Size([1, 22])

In [16]:
# Raw model scores for the example, then the per-class probabilities (softmax over dim 1).
pred = predict(input_ids)
torch.softmax(pred, dim=1)


tensor([[0.2443, 0.2552, 0.1969, 0.1114, 0.1922]], device='cuda:0',
       grad_fn=<SoftmaxBackward>)

In [17]:
# Check output of custom_forward: the softmax probability of class index 0 for the example
custom_forward(input_ids)

tensor([0.2443], device='cuda:0', grad_fn=<SelectBackward>)

In [18]:
# Show the token ids of the example, including the leading [CLS] id and trailing [SEP] id.
input_ids

tensor([[   101,   6821,  10934,  47424,  11822,    241,  10636,  31030,  11745,
            271,  10985,  10617,  38425,  18520,  10636, 115654,  47628,   7159,
          10623,  16200,    239,    102]], device='cuda:0')

In [None]:
# First attribution run: Layer Integrated Gradients on the embeddings layer, integrating
# from the baseline `ref_input_ids` to `input_ids`, also returning the convergence delta.
# NOTE(review): n_steps=7000 with internal_batch_size=3 means a very large number of
# forward/backward passes; presumably chosen to shrink the convergence delta -- confirm
# it is needed, and consider timing or caching this cell.
attributions_main, delta_main = lig.attribute(inputs=input_ids,
                                              baselines=ref_input_ids,
                                              n_steps=7000,
                                              internal_batch_size=3,
                                              return_convergence_delta=True)

In [None]:
# Second attribution run with the same inputs, baseline and n_steps; only
# internal_batch_size differs (5 instead of 3). These are the attributions that the
# summarising and visualisation cells use.
attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    n_steps=7000,
                                    internal_batch_size=5,
                                    return_convergence_delta=True)

In [None]:
# Total attribution of each run; presumably the two sums should (nearly) agree if the
# internal batch size does not affect the result -- confirm.
torch.sum(attributions_main), torch.sum(attributions)

In [None]:
# Convergence deltas returned by the two attribution runs, side by side.
delta, delta_main

In [None]:
# Compute the model scores for the example here so that this cell (and the cells that
# follow) can run on a fresh kernel with Restart & Run All; without this assignment
# `score` is read before it has been defined.
score = predict(input_ids)
# Index of the highest-scoring class for the (single) example.
torch.argmax(score[0]).cpu().numpy()

In [None]:
# Softmax probability of class index 1 for the example, as a NumPy value on the CPU.
torch.softmax(score, dim=1)[0][1].cpu().detach().numpy()

In [None]:
# Model scores for the example sentence.
score = predict(input_ids)

# Report the arg-max class index and the softmax probability of class index 1.
# NOTE(review): class index 1 is labelled "positive" here; confirm this matches the
# label order (and number of classes) of the loaded model.
print('Sentence: ', text)
print('Sentiment: ' + str(torch.argmax(score[0]).cpu().numpy()) +
      ', Probability positive: ' + str(torch.softmax(score, dim=1)[0][1].cpu().detach().numpy()))

In [None]:
def summarize_attributions(attributions):
    """Reduce attributions to one normalised score per token.

    The last dimension is summed, a leading dimension of size 1 is squeezed
    away, and the result is divided by its norm (``torch.norm``).

    Args:
        attributions: attribution tensor whose last dimension is summed out.

    Returns:
        The summed attributions scaled to unit norm. If the norm is exactly
        zero (all summed attributions are zero) the unscaled zeros are returned
        instead of dividing by zero, which would produce NaNs.
    """
    token_scores = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(token_scores)
    if norm == 0:
        # Nothing to normalise; avoid 0 / 0 -> NaN.
        return token_scores
    return token_scores / norm

In [None]:
# Collapse the attributions of the second run (`attributions`) with summarize_attributions.
attributions_sum = summarize_attributions(attributions)

In [None]:
# Build a single visualization record for the example sentence.
# Positional arguments, per Captum's documented VisualizationDataRecord signature
# (confirm for the installed version): word_attributions, pred_prob, pred_class,
# true_class, attr_class, attr_score, raw_input_ids, convergence_score.
# NOTE(review): pred_prob is the probability of class index 0 while pred_class is the
# arg-max class, so the two may refer to different classes -- confirm this is intended.
# NOTE(review): true_class is hard-coded to 1 and the raw sentence is passed in the
# attr_class slot -- confirm.
score_vis = viz.VisualizationDataRecord(attributions_sum,
                                        torch.softmax(score, dim=1)[0][0],
                                        torch.argmax(torch.softmax(score, dim=1)[0]),
                                        1,
                                        text,
                                        attributions_sum.sum(),
                                        all_tokens,
                                        delta)


In [None]:
# Print a heading ('\033[1m' / '\033[0m' are ANSI escape codes for bold on / reset),
# then pass the record to Captum's text visualisation.
print('\033[1m', 'Visualization For Score', '\033[0m')
viz.visualize_text([score_vis])

In [None]:
# Index of the highest-probability class for the example (same value as passed to the record).
torch.argmax(torch.softmax(score, dim=1)[0])

In [None]:
# Show the raw model scores for the example.
score