# Interpreting BERT Models using XAI

In this notebook, we demonstrate how to use the `Captum` library to interpret some of the predictions of our BERT models.

This notebook was mainly inspired by this [Captum Tutorial](https://captum.ai/tutorials/Bert_SQUAD_Interpret) and a [GitHub discussion](https://github.com/pytorch/captum/issues/150#issuecomment-549022512).

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
module_path = os.path.dirname(os.path.dirname(os.path.abspath('.')))
if module_path not in sys.path:
    print('Add root path to system path: ', module_path)
    sys.path.append(module_path)
module_path += '/'

Add root path to system path:  D:\Projets\Georgia Tech\Comp Social Science\cs6471-project


In [3]:
import re
import gc
import tqdm
import argparse
import numpy as np
import datetime
import time

import spacy
import pandas as pd
from sklearn.metrics import f1_score

import torch  # used explicitly below (torch.device, torch.cuda)
from torch import optim
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from src.utils.preprocess_utils import *
from src.training.train_utils import train_model, test_model
from src.evaluation.test_save_stats import *

from src.utils.utils import *
from src.evaluation.xai_utils import *
from src.evaluation.xai_bert_utils import BertModelWrapper, interpret_sentence

import captum
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization, IntegratedGradients

from typing import Any, Iterable, List, Tuple, Union
from IPython.display import HTML, display

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Richard\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


## ⚠️ Before running the cells below, make sure to run:

- `test_save_stats.py --model=MODEL_NAME --saved_model_path=PATH_TO_MODEL` (see the source code for more details), plus any model parameters needed

The script saves the samples for which the model is confident in its prediction (i.e., when the predicted probability is very close to either 1 (Hate) or 0 (Neutral)). <br>
We are now going to visualize the explainability of the model (i.e., the importance of each word in the model's decision) for True Positives (TP), False Positives (FP), True Negatives (TN) and False Negatives (FN), respectively.

In [4]:
## Put your model hyperparameters here
model_type = 'DistillBert'

# OffensEval
test_dataset_name = 'offenseval'
saved_model_path = module_path + SAVED_MODELS_PATH + 'DistillBert_2022-04-15_02-48-34_trained_testAcc=0.8026.pth'
stats_path = module_path + STATS_CSV + "stats_DistillBert_2022-04-15_02-48-34_test_crossentropy_offenseval.csv"

# Implicit Hate
#test_dataset_name = 'implicithate'
#saved_model_path = module_path + SAVED_MODELS_PATH + 'DistillBert_2022-04-18_02-48-16_trained_testAcc=0.7585.pth'
#stats_path = module_path + STATS_CSV + "stats_DistillBert_2022-04-18_02-48-16_test_crossentropy_implicithate.csv"

# Covid Hate
#test_dataset_name = 'covidhate'
#saved_model_path = module_path + SAVED_MODELS_PATH + 'DistillBert_2022-04-18_02-24-40_trained_testAcc=0.8397.pth'
#stats_path = module_path + STATS_CSV + "stats_DistillBert_2022-04-18_02-24-40_test_crossentropy_covidhate.csv"

In [5]:
# Specific other parameters
fix_length = None

# Get model_id
regex = r'\d+-\d+-\d+_\d+-\d+-\d+'  # raw string avoids invalid escape sequence warnings
find_list = re.findall(regex, saved_model_path)
assert len(find_list) > 0, "Cannot find model_id (YYYY-MM-DD_HH-MM-SS) in saved_model_path's filename"
model_id = find_list[0]
print("model_id:", model_id)

model_id: 2022-04-15_02-48-34


In [6]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Device:", device)

field, tokenizer, _, _, _ = get_datasets(test_dataset_name, test_dataset_name, model_type, fix_length, 
                                         module_path=module_path)

print("Loading vocabulary...")
vocab_stoi, vocab_itos = get_vocab_stoi_itos(field)
print("Vocabulary Loaded")

Device: cuda
pad_index 0
field objects created
fields and dataset object created
vocabulary built..
Loading vocabulary...
Vocabulary Loaded


In [7]:
print("Loading Model...")
model = load_model(model_type, field, device)
model = load_trained_model(model, saved_model_path, device)
print("Model Loaded.")

Loading Model...


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


D:\Projets\Georgia Tech\Comp Social Science\cs6471-project/saved-models/DistillBert_2022-04-15_02-48-34_trained_testAcc=0.8026.pth loaded.
Model Loaded.


In [8]:
print("Loading Stats Data..")
df = pd.read_csv(stats_path)
df = df.drop(columns=["Unnamed: 0"])
print(df.shape)
df.head()

Loading Stats Data..
(860, 6)


| | original_index | text | true_label | pred_label | prob | loss |
|---|---|---|---|---|---|---|
| 0 | 0 | [CLS] whoisq wherestheserver dumpnike declasfi... | 1 | 1 | 0.548699 | 0.600205 |
| 1 | 1 | [CLS] constitutionday is revered by conservati... | 0 | 0 | 0.432062 | 0.565743 |
| 2 | 2 | [CLS] foxnews nra maga potus trump 2ndamendmen... | 0 | 0 | 0.091411 | 0.095862 |
| 3 | 3 | [CLS] watching boomer getting the news that sh... | 0 | 0 | 0.128311 | 0.137322 |
| 4 | 4 | [CLS] nopasaran : unity demo to oppose the far... | 1 | 0 | 0.123655 | 2.090264 |


In [9]:
## Selecting TP, FP, TN, FN

df_tp =   df[(df['true_label']==1) & (df['pred_label']==1) ]
df_fp =   df[(df['true_label']==0) & (df['pred_label']==1) ]
df_tn =   df[(df['true_label']==0) & (df['pred_label']==0) ]
df_fn =   df[(df['true_label']==1) & (df['pred_label']==0) ]

print("TP, FP, TN, FN selected from loaded data.")

TP, FP, TN, FN selected from loaded data.


## Preprocessing and gradient methods
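
For readers unfamiliar with the attribution method used in the next cell, here is a minimal, dependency-free sketch of Integrated Gradients on a toy model. It is not the project's `BertModelWrapper` nor Captum's implementation: the toy classifier, its weights, the embedding values and the all-zero baseline are assumptions made up for illustration. The sketch approximates the path integral with a midpoint Riemann sum, then sums the attributions over the embedding dimension to obtain one score per token, which is a common way to reduce embedding-level attributions to word-level importances.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def toy_model(embeddings, weights):
    """Toy classifier (hypothetical): sigmoid of a weighted sum over all token embeddings."""
    z = sum(w * e for token in embeddings for w, e in zip(weights, token))
    return sigmoid(z)


def toy_model_grad(embeddings, weights):
    """Analytic gradient of the toy model with respect to every embedding value."""
    p = toy_model(embeddings, weights)
    return [[p * (1.0 - p) * w for w in weights] for _ in embeddings]


def integrated_gradients(embeddings, baseline, weights, n_steps=200):
    """Approximate IG: (x - baseline) * average gradient along the straight path."""
    n_tokens, dim = len(embeddings), len(weights)
    grad_sum = [[0.0] * dim for _ in range(n_tokens)]
    for step in range(n_steps):
        alpha = (step + 0.5) / n_steps  # midpoint of each sub-interval of [0, 1]
        point = [[b + alpha * (x - b) for x, b in zip(tok, base)]
                 for tok, base in zip(embeddings, baseline)]
        grad = toy_model_grad(point, weights)
        for t in range(n_tokens):
            for d in range(dim):
                grad_sum[t][d] += grad[t][d]
    return [[(x - b) * g / n_steps for x, b, g in zip(tok, base, gs)]
            for tok, base, gs in zip(embeddings, baseline, grad_sum)]


# Made-up inputs: three tokens with 3-dimensional embeddings.
tokens = ["[CLS]", "word", "[SEP]"]
weights = [0.5, -1.0, 2.0]
embeddings = [[0.1, 0.2, 0.0], [1.0, -0.5, 0.8], [0.0, 0.3, -0.1]]
baseline = [[0.0] * len(weights) for _ in embeddings]  # assumed all-zero reference

attributions = integrated_gradients(embeddings, baseline, weights)
token_scores = [sum(row) for row in attributions]  # one importance per token
total_score = sum(token_scores)
output_delta = toy_model(embeddings, weights) - toy_model(baseline, weights)

for token, score in zip(tokens, token_scores):
    print(f"{token:>6}: {score:+.4f}")
print(f"sum of attributions = {total_score:.4f}, "
      f"f(x) - f(baseline) = {output_delta:.4f}")
```

The last line illustrates the completeness property of Integrated Gradients: the attributions sum (up to the approximation error) to the difference between the model output on the input and on the baseline. A positive token score pushes the toy prediction towards the positive class, a negative one away from it.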

In [10]:
bert_model_wrapper = BertModelWrapper(model).to(device)
ig = IntegratedGradients(bert_model_wrapper)

# accumulate a few samples in this list for visualization purposes
vis_data_records_ig = []

sentence, label = "you little chicken", 1

interpret_sentence(bert_model_wrapper, tokenizer=tokenizer, ig=ig, sentence=sentence, label=label, original_idx=0, 
                   vis_data_records_ig=vis_data_records_ig, device=device, raw_text=True)
visualize_text(vis_data_records_ig)

| Original Index | True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|---|
| 0.0 | Hate | Hate (0.87) | Hate | 1.37 | [CLS] you little chicken [SEP] |


# Data Visualization

We now visualize the importance of each word in the model's decision. <br>
For each category (TP, FP, TN, FN), we visualize word importances for the samples with the highest scores and for those with the lowest scores (here, the score is the predicted probability `prob`).
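
As a rough illustration of the selection step described above, the sketch below picks the lowest-scoring and highest-scoring samples of one category. It is only a stand-in for the project's `get_highest_lowest_metric_indexes` helper, whose actual implementation is not shown in this notebook: the function name `highest_lowest_by_metric`, the plain-dictionary records and all values are hypothetical. The return order (lowest first, then highest) mirrors how the helper's result is unpacked in the cells below.

```python
import heapq


def highest_lowest_by_metric(records, metric="prob", topk=2):
    """Return (lowest, highest): the `topk` records with the smallest and largest `metric`."""
    lowest = heapq.nsmallest(topk, records, key=lambda r: r[metric])
    highest = heapq.nlargest(topk, records, key=lambda r: r[metric])
    return lowest, highest


# Made-up records with the same columns as the stats file (text omitted).
records = [
    {"original_index": 0, "true_label": 1, "pred_label": 1, "prob": 0.55},
    {"original_index": 1, "true_label": 1, "pred_label": 1, "prob": 0.95},
    {"original_index": 2, "true_label": 1, "pred_label": 1, "prob": 0.80},
    {"original_index": 3, "true_label": 1, "pred_label": 1, "prob": 0.62},
    {"original_index": 4, "true_label": 0, "pred_label": 0, "prob": 0.02},
    {"original_index": 5, "true_label": 1, "pred_label": 0, "prob": 0.04},
]

# Keep only the true positives, then rank them by predicted probability.
true_positives = [r for r in records
                  if r["true_label"] == 1 and r["pred_label"] == 1]
lowest, highest = highest_lowest_by_metric(true_positives, metric="prob", topk=2)

print("lowest :", [r["original_index"] for r in lowest])
print("highest:", [r["original_index"] for r in highest])
```

For true positives, the highest-probability samples are those the model is most confident about, while the lowest-probability ones sit close to the decision threshold.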

## True Positives

In [11]:
lowest_stats_df_tp, highest_stats_df_tp = get_highest_lowest_metric_indexes(df_tp, stats_metric='prob', stats_topk=10)

#### Highest Scores 

In [12]:
%%time
dataset_visualization_bert(interpret_sentence, bert_model_wrapper, tokenizer, ig, visualize_text, highest_stats_df_tp,
                           device, max_samples=10, partial_vis=True)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



| Original Index | True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|---|
| 701.0 | Hate | Hate (0.95) | Hate | 2.17 | [CLS] @ user ni ##gga are you stupid your trash don ##t play with him play with your bitch : face _ with _ tears _ of _ joy : [SEP] |
| 791.0 | Hate | Hate (0.95) | Hate | 2.28 | [CLS] $ 1500 for a phone . you all are fucking dumb . [SEP] |
| 371.0 | Hate | Hate (0.95) | Hate | 2.61 | [CLS] and have a bitch thinking you ni ##gga ##s have money and you t ##f don ’ t . [SEP] |
| 488.0 | Hate | Hate (0.95) | Hate | 0.76 | [CLS] @ user fuck [SEP] |
| 259.0 | Hate | Hate (0.95) | Hate | 1.5 | [CLS] ! ! ! ! bitch i ’ m fucking coming back ur ##l [SEP] |


Wall time: 3.35 s


## False Positives

In [13]:
lowest_stats_df_fp, highest_stats_df_fp = get_highest_lowest_metric_indexes(df_fp, stats_metric='prob', stats_topk=10)

#### Highest Scores

In [14]:
%%time
dataset_visualization_bert(interpret_sentence, bert_model_wrapper, tokenizer, ig, visualize_text, highest_stats_df_fp,
                           device, max_samples=10, partial_vis=True)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



| Original Index | True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|---|
| 58.0 | Neutral | Hate (0.93) | Hate | 1.29 | [CLS] are you fucking serious ? ur ##l [SEP] |
| 674.0 | Neutral | Hate (0.92) | Hate | 1.92 | [CLS] alex jones be sm ##oki ##n be dick ##s out here but mans really got supporters out here : face _ with _ tears _ of _ joy : [SEP] |
| 599.0 | Neutral | Hate (0.91) | Hate | 1.17 | [CLS] room ##25 is actually incredible , non ##ame is the shit , always has been , and i ’ m see ##in her in like 5 days in melbourne . life is good . have a nice day . [SEP] |
| 716.0 | Neutral | Hate (0.91) | Hate | 1.69 | [CLS] @ user i guess that ’ s where swamp ass originated [SEP] |
| 700.0 | Neutral | Hate (0.90) | Hate | 1.25 | [CLS] an american tail really is one of the most under ##rated animation ##s ever ever ever . fuck i cried in this scene [SEP] |


Wall time: 3.85 s


## True Negatives

In [15]:
lowest_stats_df_tn, highest_stats_df_tn = get_highest_lowest_metric_indexes(df_tn, stats_metric='prob', stats_topk=10)

#### Lowest Scores

In [16]:
%%time
dataset_visualization_bert(interpret_sentence, bert_model_wrapper, tokenizer, ig, visualize_text, lowest_stats_df_tn,
                           device, max_samples=10, partial_vis=True)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



| Original Index | True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|---|
| 636.0 | Neutral | Neutral (0.02) | Hate | -1.12 | [CLS] @ user @ user there she is [SEP] |
| 337.0 | Neutral | Neutral (0.02) | Hate | -1.62 | [CLS] self ##care ##se ##pt ##em ##ber you are making a difference ! @ user @ user @ user @ user @ user @ user @ user @ user @ user ur ##l [SEP] |
| 715.0 | Neutral | Neutral (0.02) | Hate | -1.65 | [CLS] feels ##good ##girl @ user @ user you should be proud at her she is amazing in ga ##mbit ur ##l [SEP] |
| 496.0 | Neutral | Neutral (0.02) | Hate | -2.57 | [CLS] she ! : herb : mo ##sa ##di ! : herb : she can disc ##ern her season she is a wise builder she is a team player she is full of fire and full of the holy spirit thanks to @ user and dr chant ##al for such an inspiring event : gem _ stone : thank you @ user : gem _ stone : ur ##l [SEP] |
| 44.0 | Neutral | Neutral (0.03) | Hate | -2.2 | [CLS] we ##lov ##ese ##ung ##che ##ol @ user i am happy and proud of the work you have done to train seventeen along with the other members . i see you and you are wonderful and incredible . i really love u [UNK] : two _ hearts : . ur ##l [SEP] |


Wall time: 3.97 s


## False Negatives

In [17]:
lowest_stats_df_fn, highest_stats_df_fn = get_highest_lowest_metric_indexes(df_fn, stats_metric='prob', stats_topk=10)

#### Lowest Scores 

In [18]:
%%time
dataset_visualization_bert(interpret_sentence, bert_model_wrapper, tokenizer, ig, visualize_text, lowest_stats_df_fn,
                           device, max_samples=10, partial_vis=True)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



| Original Index | True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|---|
| 429.0 | Hate | Neutral (0.04) | Hate | -1.07 | [CLS] good ##mo ##rn ##inga ##meric ##a tuesday ##mot ##ivation thank ##you father ##go ##d the american ##pe ##ople will begin to see the truth about our government ##cor ##rup ##tion @ user @ user @ user media ##cor ##rup ##tion nor ##eda ##ction dec ##lass ##ify ##fi ##sa dec ##lass ##ification completely obama hillary ##cl ##inton lies leak ##ers liberals ur ##l [SEP] |
| 778.0 | Hate | Neutral (0.05) | Hate | -0.17 | [CLS] br ##ex ##it deal has been reached - and will be unveiled at special summit in november , has @ user sold out the uk to the eu ? ? ? she better have not or the @ user are finished ! ! @ user ur ##l [SEP] |
| 80.0 | Hate | Neutral (0.05) | Hate | -0.22 | [CLS] @ user anti ##fa has ts level influence . it ' s scary . [SEP] |
| 260.0 | Hate | Neutral (0.05) | Hate | -0.84 | [CLS] liberals are reaching peak desperation to call on phillip ##rud ##dock to talk with turnbull to convince him to help with wentworth ##vot ##es 18 sept 2018 @ user aus ##pol l ##np nsw ##pol @ user @ user @ user l ##np ##me ##mes ur ##l [SEP] |
| 152.0 | Hate | Neutral (0.05) | Hate | 0.02 | [CLS] sierra ##burg ##ess ##isa ##los ##er she is me when my phone ding ##s : face _ with _ tears _ of _ joy : [SEP] |


Wall time: 3.72 s


# Visualize a sentence by its index

### True Positive

In [19]:
list_indexes = [433, 730, 259, 406]
df_by_indexes = df.iloc[list_indexes]

In [20]:
%%time
dataset_visualization_bert(interpret_sentence, bert_model_wrapper, tokenizer, ig, visualize_text, df_by_indexes,
                           device, max_samples=10, partial_vis=True)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



| Original Index | True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|---|
| 433.0 | Hate | Hate (0.94) | Hate | 1.6 | [CLS] @ user damn i felt this shit . why you so loud lo ##l [SEP] |
| 730.0 | Hate | Hate (0.94) | Hate | 1.52 | [CLS] and she has a pet ? ? ? fucking disgusting ur ##l [SEP] |
| 259.0 | Hate | Hate (0.95) | Hate | 1.5 | [CLS] ! ! ! ! bitch i ’ m fucking coming back ur ##l [SEP] |
| 406.0 | Hate | Hate (0.94) | Hate | 1.59 | [CLS] alright let me get right with god bc mother nature is like fuck humans ur ##l [SEP] |


Wall time: 1.26 s


### False Positive

In [21]:
list_indexes = [674, 599, 278, 700]
df_by_indexes = df.iloc[list_indexes]

In [22]:
%%time
dataset_visualization_bert(interpret_sentence, bert_model_wrapper, tokenizer, ig, visualize_text, df_by_indexes,
                           device, max_samples=10, partial_vis=True)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



| Original Index | True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|---|
| 674.0 | Neutral | Hate (0.92) | Hate | 1.92 | [CLS] alex jones be sm ##oki ##n be dick ##s out here but mans really got supporters out here : face _ with _ tears _ of _ joy : [SEP] |
| 599.0 | Neutral | Hate (0.91) | Hate | 1.17 | [CLS] room ##25 is actually incredible , non ##ame is the shit , always has been , and i ’ m see ##in her in like 5 days in melbourne . life is good . have a nice day . [SEP] |
| 278.0 | Neutral | Hate (0.85) | Hate | 1.19 | [CLS] @ user exactly it ’ s bc slick woods has that unconventional look that she ’ s who she is ye ##a she ain ’ t attractive to most of us but that shit don ’ t matter @ all lo ##l [SEP] |
| 700.0 | Neutral | Hate (0.90) | Hate | 1.25 | [CLS] an american tail really is one of the most under ##rated animation ##s ever ever ever . fuck i cried in this scene [SEP] |


Wall time: 1.62 s
