# Interpreting BERT Models using XAI

In this notebook, we demonstrate how to use the `Captum` library to interpret some of the predictions of our BERT models.

This notebook was mainly inspired by this [Captum Tutorial](https://captum.ai/tutorials/Bert_SQUAD_Interpret) and a [GitHub discussion](https://github.com/pytorch/captum/issues/150#issuecomment-549022512).

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
module_path = os.path.dirname(os.path.dirname(os.path.abspath(os.path.join('.'))))
if module_path not in sys.path:
    print('Add root path to system path: ', module_path)
    sys.path.append(module_path)
module_path += '/'

Add root path to system path:  D:\Projets\Georgia Tech\Deep Learning\Final project\cs7643-hate-speech


In [3]:
import gc
import tqdm
import argparse
import numpy as np
import datetime
import time

import spacy
import pandas as pd
from sklearn.metrics import f1_score

import torch
from torch import optim
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from src.utils.preprocess_utils import *
from src.training.train_utils import train_model, test_model
from src.evaluation.test_save_stats import *

from src.utils.utils import *
from src.evaluation.xai_utils import *

import captum
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization, IntegratedGradients

from typing import Any, Iterable, List, Tuple, Union
from IPython.display import HTML, display

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Richard\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


## ⚠️ Before running the cells below, make sure to run:

- `test_save_stats.py --model=MODEL_NAME --saved_model_path=PATH_TO_MODEL`, plus any model parameters needed (see the source code for details)

The script saves, for each test sample, the true and predicted labels, the predicted probability of the Hate class, and the loss. We focus on the samples for which the model is most confident in its prediction (i.e., when the probability is either very close to 1 (Hate) or very close to 0 (Neutral)). <br>
We then visualize the model's explanations (i.e., the importance of each word in the model's decision) for True Positives (TP), False Positives (FP), True Negatives (TN), and False Negatives (FN).
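The two ideas in this paragraph, confidence and outcome category, can be sketched in pandas as follows. The column names (`true_label`, `pred_label`, `prob`) come from the stats CSV loaded below; the `margin` of 0.1 and the helper name are hypothetical, and the notebook itself selects samples differently (it keeps the top-k rows by `prob` within each category).

```python
import numpy as np
import pandas as pd


def confident_outcomes(df: pd.DataFrame, margin: float = 0.1) -> pd.DataFrame:
    """Keep rows whose Hate probability is within `margin` of 0 or 1 and
    label each one as TP, FP, TN or FN (Hate = 1 is the positive class)."""
    confident = df[(df["prob"] >= 1 - margin) | (df["prob"] <= margin)].copy()
    true, pred = confident["true_label"], confident["pred_label"]
    confident["outcome"] = np.select(
        [(true == 1) & (pred == 1),   # TP: hateful and predicted Hate
         (true == 0) & (pred == 1),   # FP: neutral but predicted Hate
         (true == 0) & (pred == 0),   # TN: neutral and predicted Neutral
         (true == 1) & (pred == 0)],  # FN: hateful but predicted Neutral
        ["TP", "FP", "TN", "FN"],
        default="?",
    )
    return confident


# Toy rows (not real data) using the stats CSV column names.
toy = pd.DataFrame({
    "true_label": [1, 0, 0, 1, 1],
    "pred_label": [1, 1, 0, 0, 1],
    "prob":       [0.95, 0.92, 0.03, 0.06, 0.55],
})
print(confident_outcomes(toy)[["prob", "outcome"]])
```

The last toy row (probability 0.55) is dropped because it is not close enough to either extreme.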

In [4]:
## Put your model hyperparameters here
model_type = 'DistillBert'
saved_model_path = module_path + SAVED_MODELS_PATH + 'DistillBert_2021-12-08_16-39-08_trained_testAcc=0.7960.pth'
stats_path = module_path + STATS_CSV + "stats_DistillBert_2021-12-08_16-39-08_test_crossentropy.csv"

In [5]:
# Specific model parameters
fix_length = None
context_size = 0
pyramid = []
fcs = []
batch_norm = 0
alpha = 0

In [6]:
training_data = "data/training_data/offenseval-training-v1.tsv"
testset_data = "data/test_data/testset-levela.tsv"
test_labels_data = "data/test_data/labels-levela.csv"

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Device:", device)

_,field, tokenizer, _, _, _ = get_datasets(training_data, testset_data, test_labels_data, model_type, fix_length, 
                                module_path=module_path)

print("Loading vocabulary...")
vocab_stoi, vocab_itos = get_vocab_stoi_itos(field)
print("Vocabulary Loaded")

Device: cuda
file loaded and formatted..
data split into train/val/test
pad_index 0
field objects created
fields and dataset object created
vocabulary built..
Loading vocabulary...
Vocabulary Loaded


In [7]:
print("Loading Model...")
model = load_model(model_type, field, device, fix_length=fix_length,
                   context_size=context_size, pyramid=pyramid, fcs=fcs,
                   batch_norm=batch_norm, alpha=alpha)
model = load_trained_model(model, saved_model_path, device)
print("Model Loaded.")

Loading Model...


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


D:\Projets\Georgia Tech\Deep Learning\Final project\cs7643-hate-speech/saved_models/DistillBert_2021-12-08_16-39-08_trained_testAcc=0.7960.pth loaded.
Model Loaded.


In [8]:
print("Loading Stats Data..")
df = pd.read_csv(stats_path)
df = df.drop(columns=["Unnamed: 0"])
df.head()

Loading Stats Data..


| | original_index | text | true_label | pred_label | prob | loss |
|---|---|---|---|---|---|---|
| 0 | 0 | [CLS] whoisq wherestheserver dumpnike declasfi... | 1 | 0 | 0.487116 | 0.719252 |
| 1 | 1 | [CLS] constitutionday is revered by conservati... | 0 | 0 | 0.472872 | 0.640312 |
| 2 | 2 | [CLS] foxnews nra maga potus trump 2ndamendmen... | 0 | 0 | 0.192378 | 0.213661 |
| 3 | 3 | [CLS] watching boomer getting the news that sh... | 0 | 0 | 0.111601 | 0.118334 |
| 4 | 4 | [CLS] nopasaran : unity demo to oppose the far... | 1 | 0 | 0.14125 | 1.957221 |


In [9]:
## Selecting TP, FP, TN, FN

df_tp =   df[(df['true_label']==1) & (df['pred_label']==1) ]
df_fp =   df[(df['true_label']==0) & (df['pred_label']==1) ]
df_tn =   df[(df['true_label']==0) & (df['pred_label']==0) ]
df_fn =   df[(df['true_label']==1) & (df['pred_label']==0) ]

print("TP, FP, TN, FN selected from loaded data.")

TP, FP, TN, FN selected from loaded data.


## Preprocessing and gradient methods

In [10]:
ref_token_id = tokenizer.pad_token_id # [PAD] id, used to build the reference (baseline) sequence
sep_token_id = tokenizer.sep_token_id # [SEP] id, marks the end of a sentence
cls_token_id = tokenizer.cls_token_id # [CLS] id, marks the start of a sentence

print('ref_token_id', ref_token_id)
print('sep_token_id', sep_token_id)
print('cls_token_id', cls_token_id)

def construct_input_ref_pair_from_raw(text, ref_token_id, sep_token_id, cls_token_id):
    # `text` is raw: wrap it in [CLS] ... [SEP]; the reference pads every content token.
    text_ids = tokenizer.encode(text, add_special_tokens=False)

    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    input_ids = torch.tensor(input_ids, device=device).unsqueeze(0)

    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]
    ref_input_ids = torch.tensor(ref_input_ids, device=device).unsqueeze(0)

    return input_ids, ref_input_ids


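def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    # `text` already contains [CLS] ... [SEP] (as in the stats CSV), so no special tokens
    # are added; the reference keeps [CLS]/[SEP] and pads every position in between.
    text_ids = tokenizer.encode(text, add_special_tokens=False)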

    # construct input token ids
    input_ids = text_ids
    input_ids = torch.tensor(input_ids, device=device).unsqueeze(0)

    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * (len(text_ids) - 2) + [sep_token_id]
    ref_input_ids = torch.tensor(ref_input_ids, device=device).unsqueeze(0)

    return input_ids, ref_input_ids

ref_token_id 0
sep_token_id 102
cls_token_id 101
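The two helpers differ only in whether `[CLS]`/`[SEP]` are added or assumed to be present already. The tokenizer-free sketch below mirrors them on plain id lists, using the special-token ids printed above and placeholder content ids (the function names are hypothetical). Assuming the tokenizer maps the literal `[CLS]` and `[SEP]` strings in the stats CSV to their special ids, both paths should yield the same input/reference pair, with the reference keeping the boundary tokens and replacing every content token with `[PAD]`.

```python
from typing import List, Tuple

# Special-token ids printed by the cell above; the content ids below are placeholders.
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0


def pair_from_raw(text_ids: List[int]) -> Tuple[List[int], List[int]]:
    """Mirror of construct_input_ref_pair_from_raw on plain id lists:
    wrap the ids in [CLS] ... [SEP]; the reference pads every content position."""
    input_ids = [CLS_ID] + text_ids + [SEP_ID]
    ref_ids = [CLS_ID] + [PAD_ID] * len(text_ids) + [SEP_ID]
    return input_ids, ref_ids


def pair_from_wrapped(ids: List[int]) -> Tuple[List[int], List[int]]:
    """Mirror of construct_input_ref_pair: the ids already start with [CLS] and
    end with [SEP], so only the len(ids) - 2 inner positions are padded."""
    ref_ids = [CLS_ID] + [PAD_ID] * (len(ids) - 2) + [SEP_ID]
    return list(ids), ref_ids


content_ids = [5000, 5001, 5002]  # hypothetical ids for a three-token sentence
raw_pair = pair_from_raw(content_ids)
wrapped_pair = pair_from_wrapped([CLS_ID] + content_ids + [SEP_ID])
print(raw_pair)                  # ([101, 5000, 5001, 5002, 102], [101, 0, 0, 0, 102])
print(raw_pair == wrapped_pair)  # True
```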


In [11]:
%%time
# https://github.com/pytorch/captum/issues/150#issuecomment-549022512
def compute_bert_outputs(model_bert, embedding_output, attention_mask=None, head_mask=None):
    if attention_mask is None:
        attention_mask = torch.ones(embedding_output.shape[0], embedding_output.shape[1]).to(embedding_output)

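    # Kept from the BERT recipe in the linked issue. DistilBERT's `transformer` takes the
    # 2-D (batch, seq_len) `attention_mask` directly, so this extended mask is not used below.
    extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)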

    extended_attention_mask = extended_attention_mask.to(dtype=next(model_bert.parameters()).dtype) # fp16 compatibility
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

    if head_mask is not None:
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(model_bert.config.num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
        head_mask = head_mask.to(dtype=next(model_bert.parameters()).dtype) # switch to float if needed + fp16 compatibility
    else:
        head_mask = [None] * model_bert.config.num_hidden_layers

    encoder_outputs = model_bert.transformer(embedding_output, attention_mask, head_mask=head_mask)
    sequence_output = encoder_outputs[0]
    pooled_output = sequence_output.mean(axis=1)
    outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
    return outputs  # sequence_output, pooled_output, (hidden_states), (attentions) 


class BertModelWrapper(nn.Module):
    def __init__(self, model):
        super(BertModelWrapper, self).__init__()
        self.model = model
        
    def forward(self, embeddings):        
        outputs = compute_bert_outputs(self.model.bert, embeddings)
        pooled_output = outputs[1]
        pooled_output = self.model.dropout(pooled_output)
        out = self.model.relu(self.model.linear1(pooled_output))
        logits = self.model.linear2(out)
        return torch.softmax(logits, dim=1)[:, 1].unsqueeze(1)


def interpret_sentence(model_wrapper, sentence, label, original_idx, vis_data_records_ig, device, raw_text=False):
    torch.cuda.empty_cache()
    gc.collect()
    model_wrapper.eval()
    model_wrapper.zero_grad()

    # print('sentence: ', sentence)

    if raw_text:
        input_ids, ref_input_ids = construct_input_ref_pair_from_raw(sentence, ref_token_id, 
                                                                     sep_token_id, cls_token_id)
    else:
        input_ids, ref_input_ids = construct_input_ref_pair(sentence, ref_token_id, 
                                                            sep_token_id, cls_token_id)
    input_embedding = model_wrapper.model.bert.embeddings(input_ids).to(device)
    

    # predict
    pred = model_wrapper(input_embedding).item()
    pred_ind = round(pred)

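    # compute attributions and the approximation delta using Integrated Gradients.
    # No `baselines` are passed, so Captum integrates from an all-zero embedding;
    # `ref_input_ids` is built above but not used here.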
    attributions_ig, delta = ig.attribute(input_embedding, return_convergence_delta=True)

    # print('pred: ', pred_ind, '(', '%.2f' % pred, ')', ', delta: ', abs(delta))
    

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].detach().cpu().numpy().tolist()) 
    # print('tokens:', tokens)
    add_attributions_to_visualizer(attributions_ig, tokens, pred, pred_ind, label, delta, original_idx, vis_data_records_ig)
    
    torch.cuda.empty_cache()
    del attributions_ig, tokens, input_ids, input_embedding, pred, pred_ind
    gc.collect()

def add_attributions_to_visualizer(attributions, tokens, pred, pred_ind, label, delta, original_idx, 
                                   vis_data_records, class_names=["Neutral","Hate"]):
    attributions = attributions.sum(dim=2).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.detach().cpu().numpy()
    
    # store this sample's record (a customized visualization.VisualizationDataRecord) for later display
    datarecord = VisualizationDataRecordCustom(attributions,
                                                pred,
                                                class_names[pred_ind],
                                                class_names[label],
                                                class_names[1],
                                                attributions.sum(),       
                                                tokens[:len(attributions)],
                                                delta, 
                                                original_idx,)
    vis_data_records.append(datarecord)


bert_model_wrapper = BertModelWrapper(model).to(device)
ig = IntegratedGradients(bert_model_wrapper)

# accumulate a few samples in this list for visualization purposes
vis_data_records_ig = []

sentence, label = "you little chicken", 1

interpret_sentence(bert_model_wrapper, sentence=sentence, label=label, original_idx=0, 
                   vis_data_records_ig=vis_data_records_ig, device=device, raw_text=True)
visualize_text(vis_data_records_ig)

| Original Index | True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|---|
| 0.0 | Hate | Hate (0.91) | Hate | 1.55 | [CLS] you little chicken [SEP] |


Wall time: 584 ms
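Integrated Gradients attributes an output f(x) to each input dimension by averaging the gradients along the straight path from a baseline x' to x and scaling by (x - x'); the attributions should sum to f(x) - f(x'), and `return_convergence_delta=True` reports how far they are from that. Note that `interpret_sentence` calls `ig.attribute` without `baselines`, so Captum integrates from an all-zero embedding. The NumPy sketch below illustrates both points on a made-up three-dimensional function, using finite-difference gradients and a trapezoidal rule. It is not Captum's implementation (Captum defaults to a Gauss-Legendre rule with `n_steps=50`), and the toy function and the alternative baseline are assumptions.

```python
import numpy as np

W = np.array([1.5, -2.0, 0.5])  # arbitrary weights of the toy model


def toy_model(x: np.ndarray) -> float:
    """Toy stand-in for the wrapper's Hate probability (not the BERT model)."""
    score = W @ x + 0.3 * x[0] * x[2]
    return float(1.0 / (1.0 + np.exp(-score)))


def numerical_grad(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Central finite-difference gradient of toy_model at x."""
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        grad[i] = (toy_model(x + step) - toy_model(x - step)) / (2 * eps)
    return grad


def integrated_gradients(x, baseline, n_steps=300):
    """Trapezoidal approximation of IG along the path baseline -> x.
    Returns per-dimension attributions and the convergence delta
    sum(attributions) - (f(x) - f(baseline))."""
    alphas = np.linspace(0.0, 1.0, n_steps + 1)
    grads = np.stack([numerical_grad(baseline + a * (x - baseline)) for a in alphas])
    avg_grad = ((grads[:-1] + grads[1:]) / 2.0).mean(axis=0)
    attributions = (x - baseline) * avg_grad
    delta = attributions.sum() - (toy_model(x) - toy_model(baseline))
    return attributions, delta


x = np.array([0.8, -0.4, 1.2])
zero_attr, zero_delta = integrated_gradients(x, np.zeros_like(x))
# A different (hypothetical) baseline, playing the role of a [PAD]-token embedding.
pad_attr, pad_delta = integrated_gradients(x, np.array([0.2, 0.0, -0.1]))
print(zero_attr, zero_delta)
print(pad_attr, pad_delta)
```

Both deltas are close to zero, yet the two attribution vectors differ because each explains the output relative to a different reference. In the notebook, passing an embedding of `ref_input_ids` as `baselines=` to `ig.attribute` would be the analogous change.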




# Data Visualization

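We now visualize the importance of each word in the model's decision. <br>
For each category (TP, FP, TN, FN), we show the samples with the most extreme predicted Hate probability (the `prob` column): the highest for TP and FP, and the lowest for TN and FN, i.e., the predictions the model is most confident about.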
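The "Word Importance" column is built from one score per token. In `add_attributions_to_visualizer`, the attributions (one vector per token over the embedding dimension) are summed over that dimension and divided by their L2 norm, and the "Attribution Score" is the sum of the resulting scores. The sketch below reproduces that aggregation in NumPy (without the batch dimension, and with a guard against a zero norm) and adds a hypothetical HTML rendering: green for evidence towards Hate, red towards Neutral, opacity proportional to the magnitude. Captum's `visualize_text` uses its own colour scale, so treat the colours as illustrative; the tokens and attribution values are made up.

```python
import html

import numpy as np


def token_scores(attributions: np.ndarray) -> np.ndarray:
    """(seq_len, hidden_dim) -> (seq_len,): sum over the embedding dimension,
    then divide by the L2 norm, as add_attributions_to_visualizer does."""
    scores = attributions.sum(axis=-1)
    norm = np.linalg.norm(scores)
    return scores / norm if norm > 0 else scores


def render_html(tokens, scores) -> str:
    """Hypothetical rendering: one coloured <span> per token."""
    spans = []
    for token, score in zip(tokens, scores):
        rgb = "0,160,0" if score >= 0 else "200,0,0"  # green = Hate, red = Neutral
        spans.append(
            f'<span style="background-color: rgba({rgb},{abs(score):.2f})">'
            f"{html.escape(token)}</span>"
        )
    return " ".join(spans)


# Made-up attributions for 4 tokens with a 2-dimensional "embedding".
tokens = ["[CLS]", "tok_a", "tok_b", "[SEP]"]
attributions = np.array([[0.0, 0.0], [0.3, 0.1], [-0.2, 0.0], [0.0, 0.0]])
scores = token_scores(attributions)
attribution_score = scores.sum()  # the "Attribution Score" column
print(scores.round(3), round(float(attribution_score), 3))
print(render_html(tokens, scores))
```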

## True Positives

In [12]:
lowest_stats_df_tp, highest_stats_df_tp = get_highest_lowest_metric_indexes(df_tp, stats_metric='prob', stats_topk=10)
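`get_highest_lowest_metric_indexes` comes from one of the star-imported `src` modules and its body is not shown here. Judging from the outputs (`highest_stats_df_tp` holds the TPs with the highest `prob`), it returns the lowest-k rows first and the highest-k rows second. The sketch below is a hypothetical pandas equivalent built on `DataFrame.nsmallest`/`nlargest`; it is an assumption, not the project's implementation.

```python
import pandas as pd


def lowest_highest_by_metric(df: pd.DataFrame, metric: str = "prob", k: int = 10):
    """Hypothetical equivalent of get_highest_lowest_metric_indexes:
    returns (k rows with the lowest metric, k rows with the highest metric)."""
    lowest = df.nsmallest(k, metric)   # sorted ascending
    highest = df.nlargest(k, metric)   # sorted descending
    return lowest, highest


# Toy TP-like rows (not real data).
toy_tp = pd.DataFrame({"original_index": [0, 1, 2, 3, 4],
                       "prob": [0.91, 0.55, 0.97, 0.62, 0.88]})
low, high = lowest_highest_by_metric(toy_tp, k=2)
print(low["original_index"].tolist(), high["original_index"].tolist())  # [1, 3] [2, 0]
```

For TP and FP the most confident predictions are in `highest`, and for TN and FN they are in `lowest`, which is how the cells below use the two outputs.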

#### Highest Scores 

In [13]:
%%time
dataset_visualization_bert(interpret_sentence, ig, visualize_text, bert_model_wrapper, highest_stats_df_tp,\
                           device, max_samples=10, partial_vis=True)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



| Original Index | True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|---|
| 701.0 | Hate | Hate (0.96) | Hate | 2.6 | [CLS] @ user ni ##gga are you stupid your trash don ##t play with him play with your bitch : face _ with _ tears _ of _ joy : [SEP] |
| 791.0 | Hate | Hate (0.96) | Hate | 2.33 | [CLS] $ 1500 for a phone . you all are fucking dumb . [SEP] |
| 371.0 | Hate | Hate (0.96) | Hate | 2.3 | [CLS] and have a bitch thinking you ni ##gga ##s have money and you t ##f don ’ t . [SEP] |
| 730.0 | Hate | Hate (0.95) | Hate | 1.61 | [CLS] and she has a pet ? ? ? fucking disgusting ur ##l [SEP] |
| 406.0 | Hate | Hate (0.95) | Hate | 1.68 | [CLS] alright let me get right with god bc mother nature is like fuck humans ur ##l [SEP] |


Wall time: 3.34 s


## False Positives

In [14]:
lowest_stats_df_fp, highest_stats_df_fp = get_highest_lowest_metric_indexes(df_fp, stats_metric='prob', stats_topk=10)

#### Highest Scores

In [15]:
%%time
dataset_visualization_bert(interpret_sentence, ig, visualize_text, bert_model_wrapper, highest_stats_df_fp,\
                           device, max_samples=10,partial_vis=True)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



| Original Index | True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|---|
| 674.0 | Neutral | Hate (0.93) | Hate | 2.25 | [CLS] alex jones be sm ##oki ##n be dick ##s out here but mans really got supporters out here : face _ with _ tears _ of _ joy : [SEP] |
| 716.0 | Neutral | Hate (0.93) | Hate | 1.38 | [CLS] @ user i guess that ’ s where swamp ass originated [SEP] |
| 58.0 | Neutral | Hate (0.92) | Hate | 1.15 | [CLS] are you fucking serious ? ur ##l [SEP] |
| 599.0 | Neutral | Hate (0.92) | Hate | 1.15 | [CLS] room ##25 is actually incredible , non ##ame is the shit , always has been , and i ’ m see ##in her in like 5 days in melbourne . life is good . have a nice day . [SEP] |
| 278.0 | Neutral | Hate (0.91) | Hate | 1.32 | [CLS] @ user exactly it ’ s bc slick woods has that unconventional look that she ’ s who she is ye ##a she ain ’ t attractive to most of us but that shit don ’ t matter @ all lo ##l [SEP] |


Wall time: 3.66 s


## True Negatives

In [16]:
lowest_stats_df_tn, highest_stats_df_tn = get_highest_lowest_metric_indexes(df_tn, stats_metric='prob', stats_topk=10)

#### Lowest Scores

In [17]:
%%time
dataset_visualization_bert(interpret_sentence, ig, visualize_text, bert_model_wrapper, lowest_stats_df_tn,\
                           device, max_samples=10,partial_vis=True)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



| Original Index | True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|---|
| 44.0 | Neutral | Neutral (0.03) | Hate | -2.99 | [CLS] we ##lov ##ese ##ung ##che ##ol @ user i am happy and proud of the work you have done to train seventeen along with the other members . i see you and you are wonderful and incredible . i really love u [UNK] : two _ hearts : . ur ##l [SEP] |
| 689.0 | Neutral | Neutral (0.03) | Hate | -0.82 | [CLS] sundays ##ho ##uto ##uts to @ user because . . . . because she is amazing and has my bound ##less support and endless appreciation for her support in return , it ' s just a loop of love and respect … ur ##l [SEP] |
| 399.0 | Neutral | Neutral (0.03) | Hate | -1.85 | [CLS] @ user $ 6 million buy ##out and she could easily go to cable news if that ’ s she path she chooses . good for her ! [SEP] |
| 496.0 | Neutral | Neutral (0.03) | Hate | -1.47 | [CLS] she ! : herb : mo ##sa ##di ! : herb : she can disc ##ern her season she is a wise builder she is a team player she is full of fire and full of the holy spirit thanks to @ user and dr chant ##al for such an inspiring event : gem _ stone : thank you @ user : gem _ stone : ur ##l [SEP] |
| 434.0 | Neutral | Neutral (0.03) | Hate | -1.38 | [CLS] my ##pro ##te ##ge , a k ##12 master teacher was very excited to report she is the proud owner and supporter of products made by @ user @ user bravo ! us ##open [SEP] |


Wall time: 4.31 s


## False Negatives

In [18]:
lowest_stats_df_fn, highest_stats_df_fn = get_highest_lowest_metric_indexes(df_fn, stats_metric='prob', stats_topk=10)

#### Lowest Scores 

In [19]:
%%time
dataset_visualization_bert(interpret_sentence, ig, visualize_text, bert_model_wrapper, lowest_stats_df_fn,\
                           device, max_samples=10,partial_vis=True)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



| Original Index | True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|---|
| 260.0 | Hate | Neutral (0.05) | Hate | -0.96 | [CLS] liberals are reaching peak desperation to call on phillip ##rud ##dock to talk with turnbull to convince him to help with wentworth ##vot ##es 18 sept 2018 @ user aus ##pol l ##np nsw ##pol @ user @ user @ user l ##np ##me ##mes ur ##l [SEP] |
| 429.0 | Hate | Neutral (0.06) | Hate | -0.46 | [CLS] good ##mo ##rn ##inga ##meric ##a tuesday ##mot ##ivation thank ##you father ##go ##d the american ##pe ##ople will begin to see the truth about our government ##cor ##rup ##tion @ user @ user @ user media ##cor ##rup ##tion nor ##eda ##ction dec ##lass ##ify ##fi ##sa dec ##lass ##ification completely obama hillary ##cl ##inton lies leak ##ers liberals ur ##l [SEP] |
| 80.0 | Hate | Neutral (0.06) | Hate | -0.66 | [CLS] @ user anti ##fa has ts level influence . it ' s scary . [SEP] |
| 152.0 | Hate | Neutral (0.07) | Hate | -0.36 | [CLS] sierra ##burg ##ess ##isa ##los ##er she is me when my phone ding ##s : face _ with _ tears _ of _ joy : [SEP] |
| 778.0 | Hate | Neutral (0.07) | Hate | -0.87 | [CLS] br ##ex ##it deal has been reached - and will be unveiled at special summit in november , has @ user sold out the uk to the eu ? ? ? she better have not or the @ user are finished ! ! @ user ur ##l [SEP] |


Wall time: 3.79 s


# Visualize sentences by index

### True Positive

In [20]:
list_indexes = [433, 730, 259, 406]
df_by_indexes = df.iloc[list_indexes]
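`df.iloc[list_indexes]` selects rows by position. That matches the `original_index` column here because the stats CSV appears to be stored in original order (in `df.head()` the row index equals `original_index`), but the two would diverge if the CSV were ever filtered or reordered. A label-based alternative is sketched below; the helper name is hypothetical.

```python
import pandas as pd


def rows_by_original_index(df: pd.DataFrame, indexes) -> pd.DataFrame:
    """Label-based selection on the `original_index` column, in the requested
    order (raises KeyError if an index is missing), unlike positional df.iloc."""
    return df.set_index("original_index").loc[list(indexes)].reset_index()


# Toy frame whose row order differs from original_index (not real data).
toy = pd.DataFrame({"original_index": [11, 13, 10, 12],
                    "text": ["b", "d", "a", "c"]})
selected = rows_by_original_index(toy, [12, 10])
print(selected["text"].tolist())  # ['c', 'a']
```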

In [21]:
%%time
dataset_visualization_bert(interpret_sentence, ig, visualize_text, bert_model_wrapper, df_by_indexes,\
                           device, max_samples=10,partial_vis=True)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



| Original Index | True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|---|
| 433.0 | Hate | Hate (0.95) | Hate | 1.95 | [CLS] @ user damn i felt this shit . why you so loud lo ##l [SEP] |
| 730.0 | Hate | Hate (0.95) | Hate | 1.61 | [CLS] and she has a pet ? ? ? fucking disgusting ur ##l [SEP] |
| 259.0 | Hate | Hate (0.95) | Hate | 1.5 | [CLS] ! ! ! ! bitch i ’ m fucking coming back ur ##l [SEP] |
| 406.0 | Hate | Hate (0.95) | Hate | 1.68 | [CLS] alright let me get right with god bc mother nature is like fuck humans ur ##l [SEP] |


Wall time: 1.25 s


### False Positive

In [22]:
list_indexes = [674, 599, 278, 700]
df_by_indexes = df.iloc[list_indexes]

In [23]:
%%time
dataset_visualization_bert(interpret_sentence, ig, visualize_text, bert_model_wrapper, df_by_indexes,\
                           device, max_samples=10,partial_vis=True)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



| Original Index | True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|---|
| 674.0 | Neutral | Hate (0.93) | Hate | 2.25 | [CLS] alex jones be sm ##oki ##n be dick ##s out here but mans really got supporters out here : face _ with _ tears _ of _ joy : [SEP] |
| 599.0 | Neutral | Hate (0.92) | Hate | 1.15 | [CLS] room ##25 is actually incredible , non ##ame is the shit , always has been , and i ’ m see ##in her in like 5 days in melbourne . life is good . have a nice day . [SEP] |
| 278.0 | Neutral | Hate (0.91) | Hate | 1.32 | [CLS] @ user exactly it ’ s bc slick woods has that unconventional look that she ’ s who she is ye ##a she ain ’ t attractive to most of us but that shit don ’ t matter @ all lo ##l [SEP] |
| 700.0 | Neutral | Hate (0.88) | Hate | 1.52 | [CLS] an american tail really is one of the most under ##rated animation ##s ever ever ever . fuck i cried in this scene [SEP] |


Wall time: 1.65 s
