### Imports

In [1]:
import tqdm
import argparse
import numpy as np
import datetime
import time

import spacy
import pandas as pd
from sklearn.metrics import f1_score

from torch import optim
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from preprocess_utils import *
from train import train_model, test_model
from models import BasicLSTM, BiLSTM
from test_save_stats import *

from utils import *

import captum
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization

from typing import Any, Iterable, List, Tuple, Union
from IPython.core.display import HTML, display

spacy_en = spacy.load("en_core_web_sm")

## Before running the cells below, make sure to run :

- test_save_stats.py --model=MODEL_NAME--saved_model_path=PATH_TO_MODEL (see source code for more details)

The code saves the samples for which the model is sure of its prediction (ie. when it the probability is either really close to 1 (Hate) or close to 0 (Neutral)). <br>
We are now going to visualize the explainability of the model (ie. the importance of words in the model's decision) respectively for True Positives (TP), False Positives (FP), True Negatives (TN) and False Negatives(FN).

### Data Import

In [2]:
## Put your model hyperparameters here
batch_size=1
model_type = 'BiLSTM'
saved_model_path = 'saved_models/BiLSTM_2021-12-05_17-15-21_trained_testAcc=0.5185.pth'

In [3]:
training_data = "data/training_data/offenseval-training-v1.tsv"
testset_data ="data/test_data/testset-levela.tsv"
test_labels_data = "data/test_data/labels-levela.csv"

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Device:", device)

field,_, _, _, _ = get_datasets(training_data, testset_data, test_labels_data,model_type,batch_size)

print("Loading vocabulary...")
vocab_stoi, vocab_itos = get_vocab_stoi_itos(field)
print("Vocabulary Loaded")

Device: cpu
file loaded and formatted..
data split into train/val/test
field objects created
fields and dataset object created
vocabulary built..
Loading vocabulary...
Vocabulary Loaded


In [4]:
print("Loading Model...")
model = load_model(model_type, field, device)
model = load_trained_model(model, saved_model_path, device)
print("Model Loaded.")

Loading Model...
saved_models/BiLSTM_2021-12-05_17-15-21_trained_testAcc=0.5185.pth loaded.
Model Loaded.


In [5]:
FILE_PATH = "stats_results/stats_BiLSTM_2021-12-05_17-15-21_test_bcelosswithlogits.csv"

print("Loading Stats Data..")
df = pd.read_csv(FILE_PATH)
df = df.drop(columns=["Unnamed: 0"])
df.head()

Loading Stats Data..


Unnamed: 0,original_index,text,true_label,pred_label,prob,loss
0,0,<unk> <unk> <unk> <unk> democrats support anti...,1,0,0.115266,2.16051
1,1,"constitutionday is <unk> by conservatives , ha...",0,0,0.413666,0.533867
2,2,foxnews nra maga potus trump <unk> rnc <unk> v...,0,0,0.019785,0.019984
3,3,watching <unk> getting the news that she is st...,0,0,0.079802,0.083167
4,4,<unk> : unity demo to oppose the far - right i...,1,0,0.082469,2.495334


In [6]:
## Selecting TP, FP, TN, FN

df_tp =   df[(df['true_label']==1) & (df['pred_label']==1) ]
df_fp =   df[(df['true_label']==0) & (df['pred_label']==1) ]
df_tn =   df[(df['true_label']==0) & (df['pred_label']==0) ]
df_fn =   df[(df['true_label']==1) & (df['pred_label']==0) ]

print("TP, FP, TN, FN selected from loaded data.")

TP, FP, TN, FN selected from loaded data.


### Definition of methods to Visualize Importance of Words

We modified and adapted code from Captum (in particular, visualization.visualize_text) to fit our context.

In [7]:
def interpret_sentence(model, field, input_data, sentence, vocab_stoi, vocab_itos, device, vis_data_records_ig, \
                       token_reference, lig, min_len = 7, label = 0, class_names=["Neutral","Hate"]):
    
    indexed = [int(input_data[i,0]) for i in range(input_data.shape[0])]
    if len(indexed) < min_len :
        indexed +=[field.vocab.stoi[field.pad_token]] * (min_len - len(indexed))

#     sentence = convert_token_to_str(indexed, vocab_stoi, vocab_itos)
    text = [vocab_itos[tok] for tok in indexed]
    
    if len(text) < min_len:
        text += [field.pad_token] * (min_len - len(text))

    indexed = [vocab_stoi[t] for t in text]
    input_indices = torch.tensor(indexed, device=device).unsqueeze(0).permute(1,0)
#     text = sentence
    model.zero_grad()

#     input_indices = torch.tensor(inputs, device=device)
#     input_indices = input_indices.unsqueeze(0)

    # input_indices dim: [sequence_length]
    seq_length = input_indices.shape[0]
    #seq_length = input_data.shape[0]

   # input_indices = input_data

    # predict
    out = model.forward(input_data)
    out = torch.sigmoid(out)
    pred = out.item()
    pred_ind = round(pred)
    
  
    # generate reference indices for each sample
    reference_indices = token_reference.generate_reference(seq_length, device=device).unsqueeze(0).permute(1, 0)

    # compute attributions and approximation delta using layer integrated gradients
    attributions_ig, delta = lig.attribute(input_indices, reference_indices, \
                                           n_steps=10, return_convergence_delta=True)

    #print('pred: ', class_names[pred_ind], '(', '%.2f'%pred, ')', ', delta: ', abs(delta))

    add_attributions_to_visualizer(attributions_ig, vocab_itos, text, pred, pred_ind, label, delta, vis_data_records_ig,\
                                  class_names)
    
def add_attributions_to_visualizer(attributions, vocab_itos, text, pred, pred_ind, label, delta, vis_data_records,\
                                   class_names):
    attributions = attributions.sum(dim=2).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.cpu().detach().numpy()
#     print("adding attributions")


    # storing couple samples in an array for visualization purposes
    vis_data_records.append(visualization.VisualizationDataRecord(
                            attributions,
                            pred,
                            class_names[pred_ind],
                            class_names[label],
                            class_names[1],
                            attributions.sum(),
                            text,
                            delta))


In [136]:
def format_classname(classname):

    return '<td><text style="padding-right:2em"><b>{}</b></text></td>'.format(classname)

def format_word_importances(words, importances):
    if importances is None or len(importances) == 0:
        return "<td></td>"
    assert len(words) <= len(importances)
    tags = ["<td>"]
    for word, importance in zip(words, importances[: len(words)]):
        word = format_special_tokens(word)
        color = _get_color(importance)
        unwrapped_tag = '<mark style="background-color: {color}; opacity:1.0; \
                    line-height:1.75"><font color="black"> {word}\
                    </font></mark>'.format(
            color=color, word=word
        )
        tags.append(unwrapped_tag)
    tags.append("</td>")
    return "".join(tags)

def format_special_tokens(token):
    if token.startswith("<") and token.endswith(">"):
        return "#" + token.strip("<>")
    return token


def _get_color(attr):
    # clip values to prevent CSS errors (Values should be from [-1,1])
#     attr = max(-1, min(1, attr))
#     if attr > 0:
#         #(52, 85%, 69%);
#         #yellow
#         hue = 52
#         sat = 85
#         lig = 100 - int(50 * attr)
#     else:
#         # red
#         hue = 0
#         sat = 75
#         lig = 80 - int(-40 * attr)
        
    attr = max(-1, min(1, attr))
    hue = 0
    sat = 75
    lig = np.clip(100 - int(-40 * attr), 0, 100)
    return "hsl({}, {}%, {}%)".format(hue, sat, lig)

def visualize_text(
    datarecords, legend: bool = True
) -> "HTML":  # In quotes because this type doesn't exist in standalone mode
    HAS_IPYTHON = True
    assert HAS_IPYTHON, (
        "IPython must be available to visualize text. "
        "Please run 'pip install ipython'."
    )
    dom = ["<table width: 100%>"]
    rows = [
        "<tr><th>True Label</th>"
        "<th>Predicted Label</th>"
        "<th>Attribution Label</th>"
        "<th>Attribution Score</th>"
        "<th>Word Importance</th>"
    ]
       
    for datarecord in datarecords:
        rows.append(
            "".join(
                [
                    "<tr>",
                    format_classname(datarecord.true_class),
                    format_classname(
                        "{0} ({1:.2f})".format(
                            datarecord.pred_class, datarecord.pred_prob
                        )
                    ),
                    format_classname(datarecord.attr_class),
                    format_classname("{0:.2f}".format(datarecord.attr_score)),
                    format_word_importances(
                        datarecord.raw_input, datarecord.word_attributions
                    ),
                    "<tr>",
                ]
            )
        )

    if legend:
        dom.append(
            '<div style="border-top: 1px solid; margin-top: 5px; \
            padding-top: 5px; display: inline-block">'
        )
        dom.append("<b>Legend: </b>")

        for value, label in zip([-1,1], ["Hate","Neutral"]):
            dom.append(
                '<span style="display: inline-block; width: 10px; height: 10px; \
                border: 1px solid; background-color: \
                {value}"></span> {label}  '.format(
                    value=_get_color(value), label=label
                )
            )
        dom.append("</div>")

    dom.append("".join(rows))
    dom.append("</table>")
    html = HTML("".join(dom))
    display(html)

    return html

In [137]:
def model_explainability(model, vocab_stoi, vocab_itos, df, max_samples, field, device, class_names=["Neutral","Hate"]):
    """
    Computing words importance for each sample in df
    """
    print("\n\n**MODEL EXPLAINABILITY**\n")
    print("Computing words importance for each sample... ")

    PAD_IND = field.vocab.stoi[field.pad_token] +1 #vocab_stoi[field.pad_token]
    token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)

    lig = LayerIntegratedGradients(model, model.emb)

    # accumalate couple samples in this array for visualization purposes
    vis_data_records_ig = []

    phase = "test"
    model.train()
    
    for i in range(max_samples):
        sentence = df.iloc[i].text
        label = df.iloc[i].true_label
        input_tokens = sentence_to_input_tokens(sentence, vocab_stoi)
        with torch.set_grad_enabled(True):
            interpret_sentence(model, field, input_data=input_tokens, sentence=sentence, vocab_stoi=vocab_stoi, \
                               vocab_itos=vocab_itos, device=device, vis_data_records_ig=vis_data_records_ig,\
                               token_reference=token_reference, lig=lig, min_len = 7, label = label, \
                               class_names=class_names)
    
    print("Computations completed.")
    return vis_data_records_ig

def sentence_to_input_tokens(sentence, vocab_stoi):
    input_tokens = []
    for word in sentence.split(" "):
        token = vocab_stoi[word]
        input_tokens.append(token)
    input_tokens= torch.tensor(input_tokens).unsqueeze(0).permute(1, 0)
    return input_tokens

In [138]:
def dataset_visualization(model, vocab_stoi, vocab_itos, df,\
                           field, device, max_samples=10,partial_vis=False,class_names=["Neutral","Hate"]):
    n = len(df)
    if partial_vis:
        n = min(n,max_samples)
    
    vis_data_record = model_explainability(model, vocab_stoi, vocab_itos, df, n,\
                                           field, device, class_names=class_names)
    print("\n\n**LOADING VISUALIZATION**\n")
    visualize_text(vis_data_record)
   

# Data Visualization

We are now going to visualize words' importances in the decision process. <br>
For each category (TP, FP, TN, FN), we visualize importances for both the highest scores and lowest scores.

## True Positives

First we retrieve the highest and lowest scores.

In [139]:
lowest_stats_df_tp, highest_stats_df_tp = get_highest_lowest_metric_indexes(df_tp, stats_metric='prob', stats_topk=5)

#### Highest Scores 

In [140]:
dataset_visualization(model, vocab_stoi, vocab_itos, highest_stats_df_tp,\
                      field, device, max_samples=10,partial_vis=True,class_names=["Neutral","Hate"])



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Hate,Hate (0.97),Hate,0.08,@user oh fuck off : face_with_tears_of_joy : : face_with_tears_of_joy : #unk
,,,,
Hate,Hate (0.97),Hate,0.05,#unk ? fuck you all : middle_finger : : middle_finger : : middle_finger : : middle_finger : : middle_finger : #unk #unk mil baby #unk : #unk : : #unk : : #unk : #unk
,,,,
Hate,Hate (0.97),Hate,1.24,hypocrites democratshateamerica liberalismisamentaldisorder everything liberals touch turns to absolute shit ... url #unk
,,,,
Hate,Hate (0.96),Hate,-0.28,liberals are so pathetic ! innocent until proven guilty mean much ? democrats should look at their own party if they need a rapist to hang ! #unk walkawayfromdemocrats . #unk are straight from the #unk of hell : fire : url #unk
,,,,
Hate,Hate (0.96),Hate,1.46,// so want to kick #unk ass . #unk
,,,,


#### Lowest scores

In [141]:
dataset_visualization(model, vocab_stoi, vocab_itos, lowest_stats_df_tp,\
                      field, device, max_samples=10,partial_vis=True,class_names=["Neutral","Hate"])



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Hate,Hate (0.63),Hate,-0.21,#unk - ratings tank as expected . who can possibly stand to watch so many liberals in one place at one time . #unk
,,,,
Hate,Hate (0.64),Hate,-0.26,""" the dregs of society "" "" ... more #unk #unk conservatives . i suggest , we , conservatives keep our ammo dry and our guns #unk .. if the fascist liberal left think conservatives will go #unk into #unk #unk or slaughtered in the streets they need a reality check . url "" #unk"
,,,,
Hate,Hate (0.65),Hate,-0.53,""" #unk #unk - omg , she is crazy #unk - #unk .... yes , i am interested in assistant manager #unk #unk - you guys look like a good pair #unk - are you having some "" "" ? just sit over here #unk - don't do it ! #unk - lets live our own life . lets do it news - collision on #unk , there is #unk #unk - woo - #unk url "" #unk"
,,,,
Hate,Hate (0.72),Hate,-0.01,#unk traitors are worse than fortnite players url conservatives #unk #unk
,,,,
Hate,Hate (0.66),Hate,-0.14,"#unk yup , white #unk privilege at it 's finest . always trying to keep a black #unk down blacklivesmatter qanon antifa #unk infowars #unk #unk #unk podcast radio #unk url #unk"
,,,,


## False Positives

In [142]:
lowest_stats_df_fp, highest_stats_df_fp = get_highest_lowest_metric_indexes(df_fp, stats_metric='prob', stats_topk=5)

#### Highest Scores

In [143]:
dataset_visualization(model, vocab_stoi, vocab_itos, highest_stats_df_tp,\
                      field, device, max_samples=10,partial_vis=True,class_names=["Neutral","Hate"])



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Hate,Hate (0.97),Hate,0.08,@user oh fuck off : face_with_tears_of_joy : : face_with_tears_of_joy : #unk
,,,,
Hate,Hate (0.97),Hate,0.05,#unk ? fuck you all : middle_finger : : middle_finger : : middle_finger : : middle_finger : : middle_finger : #unk #unk mil baby #unk : #unk : : #unk : : #unk : #unk
,,,,
Hate,Hate (0.97),Hate,1.24,hypocrites democratshateamerica liberalismisamentaldisorder everything liberals touch turns to absolute shit ... url #unk
,,,,
Hate,Hate (0.96),Hate,-0.28,liberals are so pathetic ! innocent until proven guilty mean much ? democrats should look at their own party if they need a rapist to hang ! #unk walkawayfromdemocrats . #unk are straight from the #unk of hell : fire : url #unk
,,,,
Hate,Hate (0.96),Hate,1.46,// so want to kick #unk ass . #unk
,,,,


#### Lowest Scores

In [144]:
dataset_visualization(model, vocab_stoi, vocab_itos, lowest_stats_df_fp,\
                      field, device, max_samples=10,partial_vis=True,class_names=["Neutral","Hate"])



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Neutral,Hate (0.63),Hate,-0.4,"#unk n his empire : remember one thing , you may own the #unk central govt of india , but you do not own nature . if the life i , my ancestors and parents have lived is true , nature will get you . if you are #unk me or nature is #unk you , only time will tell ! #unk"
,,,,
Neutral,Hate (0.63),Hate,-0.21,"#unk ! ! with #unk events starting in just a week , we of course have to find out what events you are looking forward to most this year . obviously , this is n't all of them . there are so many out there now . url #unk"
,,,,
Neutral,Hate (0.63),Hate,-0.16,#unk security incidents rarely emerge fully formed with #unk lights to alert you . see if you ’re prepared by testing your skills in the following scenario . url url #unk
,,,,
Neutral,Hate (0.63),Hate,-0.28,#unk century hit : do n't #unk by the #unk #unk #unk . #unk #unk . url #unk
,,,,
Neutral,Hate (0.66),Hate,-0.25,"texas wake up ! a socialist is #unk the weak - minded in your state ! you are * * texas * * for crying out loud ! of all the places that #unk the american spirit of individual freedom , texas is the standard #unk make a damn stand for your great senator @user #unk url #unk"
,,,,


## True Negatives

In [145]:
lowest_stats_df_tn, highest_stats_df_tn = get_highest_lowest_metric_indexes(df_tn, stats_metric='prob', stats_topk=5)

#### Highest Scores

In [146]:
dataset_visualization(model, vocab_stoi, vocab_itos, highest_stats_df_tn,\
                      field, device, max_samples=10,partial_vis=True,class_names=["Neutral","Hate"])



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Neutral,Hate (0.62),Hate,-0.09,#unk please share this #unk did nothing except his owner died and he was #unk at #unk he is in grief for his owner and terrified in #unk #unk please #unk is what we need but lets find this poor #unk a #unk url #unk
,,,,
Neutral,Hate (0.62),Hate,-0.23,""" #unk kills #unk in his own apartment ? what went wrong ? ? see : officer who shot man in his own apartment was involved in 2017 shooting of a suspect . “ she is devastated , said a dallas police officer , adds she is "" "" so so sorry for his family ? "" "" url "" #unk"
,,,,
Neutral,Hate (0.61),Hate,0.02,"kavanaugh #unk this started a story he said , she said . now it 's become liberals vs conservatives url #unk"
,,,,
Neutral,Neutral (0.49),Hate,-0.08,#unk i am muslim you are a muslim than why #unk and #unk #unk
,,,,
Neutral,Neutral (0.31),Hate,-0.07,@user be sure to send out the left 's antifa thugs . #unk
,,,,


#### Lowest Scores

In [147]:
dataset_visualization(model, vocab_stoi, vocab_itos, lowest_stats_df_tn,\
                      field, device, max_samples=10,partial_vis=True,class_names=["Neutral","Hate"])



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Neutral,Neutral (0.01),Hate,-0.61,@user this girl is so talented and she is my best friend because you and #unk brought us together . please notice this talent . @user @user @user please : purple_heart : : orange_heart : : purple_heart : : orange_heart : : purple_heart : : orange_heart : : purple_heart : : orange_heart : : purple_heart : : orange_heart : : purple_heart : : orange_heart : : purple_heart : : orange_heart : : purple_heart : : orange_heart : url #unk
,,,,
Neutral,Neutral (0.01),Hate,-0.19,@user she is beyond famous . she is #unk #unk : smiling_face_with_sunglasses : url #unk
,,,,
Neutral,Neutral (0.01),Hate,-0.4,florence looks like she is getting ready to explode . notice the eye starting to #unk . url #unk
,,,,
Neutral,Neutral (0.01),Hate,-0.68,@user his fans are happy . he #unk love and because of him i learned to appreciate and support the lgbtq . i love him because he never failed to make me happy . i love him because he is harry #unk and harry #unk is the best . #unk
,,,,
Neutral,Neutral (0.01),Hate,-0.36,greatawakening qanon patriotsunited wwg1wga #unk when you are awake you see clearly : red_heart : : orange_heart : : yellow_heart : : green_heart : check out this video url url #unk
,,,,


## False Negatives

In [148]:
lowest_stats_df_fn, highest_stats_df_fn = get_highest_lowest_metric_indexes(df_fn, stats_metric='prob', stats_topk=5)

#### Highest Scores

In [149]:
dataset_visualization(model, vocab_stoi, vocab_itos, highest_stats_df_fn,\
                      field, device, max_samples=10,partial_vis=True,class_names=["Neutral","Hate"])



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Hate,Hate (0.62),Hate,-0.27,theview #unk i do n't believe #unk she is a lying pos ! i live near a police station .. i hope and pray a cop does n't #unk up in my house thinking it 's his / hers and kills me .. my god .. we / people of color are police prey .. #unk always in police cross - #unk .. no matter what #unk
,,,,
Hate,Hate (0.52),Hate,-0.05,antifa is #unk for a #unk terrorist organization url #unk
,,,,
Hate,Hate (0.62),Hate,-0.22,"... what about death _ metal _ ? i want more death metal ... tell me , you hateful #unk #unk , where can i have more death metal ? ? ? ? url #unk"
,,,,
Hate,Hate (0.59),Hate,-0.11,#unk are liberals so metoo only counts when they blame a non - liberal man . #unk kag2018 : red_heart : : united_states : url #unk
,,,,
Hate,Hate (0.57),Hate,-0.06,conservatives are more upset about #unk being gay than having a #unk in the @user a #unk as @user and a rapist as #unk #unk
,,,,


#### Lowest Scores

In [150]:
dataset_visualization(model, vocab_stoi, vocab_itos, lowest_stats_df_fn,\
                      field, device, max_samples=10,partial_vis=True,class_names=["Neutral","Hate"])



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Hate,Neutral (0.01),Hate,-0.11,@user holder needed to be impeached #unk
,,,,
Hate,Neutral (0.01),Hate,-0.21,"4 . ( part 2 ) did n't have to think too hard to come up with #unk . they 're a #unk #unk , #unk #unk , antifa #unk #unk using their #unk powers to take down corrupt #unk companies . they 're a buff lesbian with hearts on their #unk . literally what 's not to love ? #unk"
,,,,
Hate,Neutral (0.01),Hate,-0.13,boycottnfl .@user you are dead to me . @user @user url #unk
,,,,
Hate,Neutral (0.01),Hate,-0.38,"resist traitors caught on video breaking our laws to further their communist socialism ways and slow down the trump agenda . they are the dc #unk of democrat socialists of america . forget antifa , that is a distraction from the real resistance . maga url #unk"
,,,,
Hate,Neutral (0.01),Hate,-0.18,conservatives govt have run up debt in #unk of austerity cuts while the rich have #unk their wealth . #unk url via @user #unk
,,,,
