### Imports

In [1]:
import tqdm
import argparse
import numpy as np
import datetime
import time

import spacy
import pandas as pd
from sklearn.metrics import f1_score

from torch import optim
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from preprocess_utils import *
from train import train_model, test_model
from models import BasicLSTM, BiLSTM

from utils import *

import captum
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization

from typing import Any, Iterable, List, Tuple, Union
from IPython.core.display import HTML, display

spacy_en = spacy.load("en_core_web_sm")

## Before running the cells below, make sure to run:

- `test_save_stats.py --model=MODEL_NAME --saved_model_path=PATH_TO_MODEL` (see source code for more details)

The code saves the samples for which the model is sure of its prediction (i.e. when the predicted probability is either very close to 1 (Hate) or very close to 0 (Neutral)). <br>
We are now going to visualize the explainability of the model (i.e. the importance of each word in the model's decision) for True Positives (TP), False Positives (FP), True Negatives (TN) and False Negatives (FN), respectively.

### Data Import

In [2]:
# Input files: training set, test set and test-set labels.
training_data = "data/training_data/offenseval-training-v1.tsv"
testset_data = "data/test_data/testset-levela.tsv"
test_labels_data = "data/test_data/labels-levela.csv"

# Hyperparameters
batch_size = 1
model_type = "BiLSTM"
saved_model_path = "saved_models/BiLSTM_2021-12-05_17-15-21_trained_testAcc=0.5185.pth"

# Use the GPU when one is visible, otherwise fall back to the CPU.
# NOTE(review): `torch` is not imported by name in the imports cell; it is
# presumably re-exported by one of the star imports -- confirm.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Only the first of the four values returned by get_datasets is kept.
field, _, _, _ = get_datasets(
    training_data,
    testset_data,
    test_labels_data,
    model_type=model_type,
    fix_length=batch_size,  # NOTE(review): batch_size is passed as fix_length -- confirm this is intended
)

print("Loading vocabulary...")
vocab_stoi, vocab_itos = get_vocab_stoi_itos(field)
print("Vocabulary Loaded")

Device: cpu
file loaded and formatted..
data split into train/val/test
field objects created
fields and dataset object created
vocabulary built..
Loading vocabulary...
Vocabulary Loaded


In [3]:
print("Loading Model...")
# Two-step load: the object returned by load_model is handed straight to
# load_trained_model together with the checkpoint path and the device.
model = load_trained_model(
    load_model(model_type, field, device),
    saved_model_path,
    device,
)
print("Model Loaded.")

Loading Model...
saved_models/BiLSTM_2021-12-05_17-15-21_trained_testAcc=0.5185.pth loaded.
Model Loaded.


In [4]:
# NOTE(review): this file name carries a different timestamp (2021-12-03) than
# the checkpoint in saved_model_path (2021-12-05) -- confirm that the stats
# were produced by the model loaded in this notebook.
FILE_PATH = "stats_results/stats_BiLSTM_2021-12-03_23-58-08_test_bcelosswithlogits.csv"

print("Loading Stats Data..")
# Read the CSV and drop the "Unnamed: 0" column in a single chained expression.
df = pd.read_csv(FILE_PATH).drop(columns=["Unnamed: 0"])
df.head()

Loading Stats Data..


Unnamed: 0,original_index,text,true_label,pred_label,prob,loss
0,0,<unk> <unk> <unk> <unk> democrats support anti...,1,1,0.90288,0.102165
1,1,"constitutionday is <unk> by conservatives , ha...",0,0,0.446367,0.591253
2,2,foxnews nra maga potus trump <unk> rnc <unk> v...,0,0,0.251962,0.290302
3,3,watching <unk> getting the news that she is st...,0,0,0.005474,0.005489
4,4,<unk> : unity demo to oppose the far - right i...,1,0,0.107203,2.233032


In [172]:
## Split the stats into the four (true_label, pred_label) combinations: TP, FP, TN, FN

df_tp = df.loc[df["true_label"].eq(1) & df["pred_label"].eq(1)]
df_fp = df.loc[df["true_label"].eq(0) & df["pred_label"].eq(1)]
df_tn = df.loc[df["true_label"].eq(0) & df["pred_label"].eq(0)]
df_fn = df.loc[df["true_label"].eq(1) & df["pred_label"].eq(0)]

print("TP, FP, TN, FN selected from loaded data.")

TP, FP, TN, FN selected from loaded data.


### Definition of methods to Visualize Importance of Words

We modified and adapted code from Captum (in particular, visualization.visualize_text) to fit our context.

In [164]:
def interpret_sentence(model, field, input_data, sentence, vocab_stoi, vocab_itos, device, vis_data_records_ig, \
                       token_reference, lig, min_len = 7, label = 0, class_names=["Neutral","Hate"]):
    """
    Attribute one sample with Layer Integrated Gradients and record the result.

    The token ids found in column 0 of ``input_data`` are right-padded with the
    field's pad index up to ``min_len``, mapped to strings through ``vocab_itos``
    (used for display) and back to ids through ``vocab_stoi``. The resulting
    ``[seq_len, 1]`` tensor, created on ``device``, is used both for the sigmoid
    prediction and for the attribution, so the probability that is displayed
    always refers to the very input that is being attributed.

    Args:
        model: callable model; its output is passed through a sigmoid and must
            hold a single element.
        field: object exposing ``vocab.stoi`` and ``pad_token``.
        input_data: 2-D tensor of token ids; only column 0 is read.
        sentence: unused, kept for call compatibility.
        vocab_stoi: mapping token string -> id.
        vocab_itos: mapping id -> token string.
        device: device on which the model input and the reference are created.
        vis_data_records_ig: list that receives the new visualization record.
        token_reference: object whose ``generate_reference(seq_length, device=...)``
            returns the baseline token ids.
        lig: attribution object; ``attribute(..., return_convergence_delta=True)``
            must return ``(attributions, delta)``.
        min_len: minimum sequence length after padding.
        label: true label, used as an index into ``class_names``.
        class_names: display names indexed by class id.

    Returns:
        None. The result is appended to ``vis_data_records_ig``.
    """
    # Token ids of the sample stored in column 0 of input_data.
    indexed = [int(tok) for tok in input_data[:, 0]]
    if len(indexed) < min_len:
        indexed += [field.vocab.stoi[field.pad_token]] * (min_len - len(indexed))

    # Human-readable tokens for the visualization, then back to ids.
    text = [vocab_itos[tok] for tok in indexed]
    indexed = [vocab_stoi[t] for t in text]

    # input_indices dim: [sequence_length, 1]
    input_indices = torch.tensor(indexed, device=device).unsqueeze(0).permute(1, 0)
    seq_length = input_indices.shape[0]

    model.zero_grad()

    # Predict on the same (padded, on-device) tensor that is attributed below.
    # Feeding the raw `input_data` here would fail when the model lives on a
    # different device and would score a different input for short sentences.
    out = torch.sigmoid(model(input_indices))
    pred = out.item()
    pred_ind = round(pred)

    # generate reference indices for each sample
    reference_indices = token_reference.generate_reference(seq_length, device=device).unsqueeze(0).permute(1, 0)

    # compute attributions and approximation delta using layer integrated gradients
    attributions_ig, delta = lig.attribute(input_indices, reference_indices, \
                                           n_steps=10, return_convergence_delta=True)

    add_attributions_to_visualizer(attributions_ig, vocab_itos, text, pred, pred_ind, label, delta, vis_data_records_ig,\
                                  class_names)
    
def add_attributions_to_visualizer(attributions, vocab_itos, text, pred, pred_ind, label, delta, vis_data_records,\
                                   class_names):
    """
    Reduce raw attributions to one normalized score per token and store a record.

    Args:
        attributions: 3-D tensor of attributions for a single sample; dimension 2
            is summed away, the remaining values are flattened to 1-D.
        vocab_itos: unused, kept for call compatibility.
        text: list of token strings shown in the visualization.
        pred: predicted probability.
        pred_ind: predicted class id, used as an index into ``class_names``.
        label: true class id, used as an index into ``class_names``.
        delta: convergence delta returned by the attribution method.
        vis_data_records: list that receives the new ``VisualizationDataRecord``.
        class_names: display names indexed by class id; ``class_names[1]`` is
            reported as the attribution class.

    Returns:
        None. The record is appended to ``vis_data_records``.
    """
    # Sum over dimension 2, then flatten. `.squeeze(0)` left a [seq_len, 1]
    # tensor untouched, so every per-word score used to be a 1-element array;
    # flattening yields true scalars for both [seq_len, 1] and [1, seq_len].
    attributions = attributions.sum(dim=2).reshape(-1)

    # L2-normalize, but leave an all-zero attribution as zeros instead of NaN.
    norm = torch.norm(attributions)
    if norm > 0:
        attributions = attributions / norm
    attributions = attributions.cpu().detach().numpy()

    # storing couple samples in an array for visualization purposes
    vis_data_records.append(visualization.VisualizationDataRecord(
                            attributions,
                            pred,
                            class_names[pred_ind],
                            class_names[label],
                            class_names[1],
                            attributions.sum(),
                            text,
                            delta))


In [165]:
def format_classname(classname):
    """Return ``classname`` rendered in bold inside a right-padded ``<td>`` cell."""
    cell_template = '<td><text style="padding-right:2em"><b>{}</b></text></td>'
    return cell_template.format(classname)

def format_word_importances(words, importances):
    """
    Build the ``<td>`` cell in which every word is highlighted by its importance.

    Args:
        words: sequence of token strings.
        importances: sequence of per-token scores, at least as long as ``words``;
            only the first ``len(words)`` entries are used.

    Returns:
        An HTML string. ``"<td></td>"`` when ``importances`` is None or empty.

    Raises:
        AssertionError: if there are more words than importances.
    """
    # Local import so this cell does not depend on the notebook's import cell.
    from html import escape

    if importances is None or len(importances) == 0:
        return "<td></td>"
    assert len(words) <= len(importances)
    tags = ["<td>"]
    for word, importance in zip(words, importances[: len(words)]):
        # Rewrite <special> tokens first, then escape what is left: the words
        # come from raw text and must not be interpreted as markup.
        word = escape(format_special_tokens(word))
        color = _get_color(importance)
        unwrapped_tag = '<mark style="background-color: {color}; opacity:1.0; \
                    line-height:1.75"><font color="black"> {word}\
                    </font></mark>'.format(
            color=color, word=word
        )
        tags.append(unwrapped_tag)
    tags.append("</td>")
    return "".join(tags)

def format_special_tokens(token):
    """
    Rewrite a token wrapped in angle brackets (e.g. ``<unk>``) as ``#unk``.

    Tokens that do not both start with ``<`` and end with ``>`` are returned
    unchanged.
    """
    is_special = token.startswith("<") and token.endswith(">")
    return "#" + token.strip("<>") if is_special else token


def _get_color(attr):
    # clip values to prevent CSS errors (Values should be from [-1,1])
    attr = max(-1, min(1, attr))
    hue = 0
    sat = 75
    lig = np.clip(100 - int(-40 * attr), 0, 100)
    return "hsl({}, {}%, {}%)".format(hue, sat, lig)

def visualize_text(
    datarecords, legend: bool = True
) -> "HTML":  # In quotes because this type doesn't exist in standalone mode
    """
    Render visualization records as an HTML table and display it.

    Each record contributes one row with its true class, predicted class and
    probability, attribution class, attribution score and the colour-coded words.

    Args:
        datarecords: iterable of records exposing ``true_class``, ``pred_class``,
            ``pred_prob``, ``attr_class``, ``attr_score``, ``raw_input`` and
            ``word_attributions``.
        legend: when True, a colour legend is emitted above the table.

    Returns:
        The ``HTML`` object that was passed to ``display``.
    """
    dom = []

    # The legend is emitted before the table: a <div> is not valid content
    # inside <table>.
    if legend:
        dom.append(
            '<div style="border-top: 1px solid; margin-top: 5px; '
            'padding-top: 5px; display: inline-block">'
        )
        dom.append("<b>Legend: </b>")

        # NOTE(review): the colour produced for -1 is labelled "Hate" and the
        # colour for +1 "Neutral" -- confirm against the attribution sign.
        for value, label in zip([-1,1], ["Hate","Neutral"]):
            dom.append(
                (
                    '<span style="display: inline-block; width: 10px; height: 10px; '
                    'border: 1px solid; background-color: '
                    '{value}"></span> {label}  '
                ).format(value=_get_color(value), label=label)
            )
        dom.append("</div>")

    # Width has to go through `style`; "<table width: 100%>" is not valid markup.
    dom.append('<table style="width: 100%">')
    rows = [
        "<tr><th>True Label</th>"
        "<th>Predicted Label</th>"
        "<th>Attribution Label</th>"
        "<th>Attribution Score</th>"
        "<th>Word Importance</th></tr>"
    ]

    for datarecord in datarecords:
        rows.append(
            "".join(
                [
                    "<tr>",
                    format_classname(datarecord.true_class),
                    format_classname(
                        "{0} ({1:.2f})".format(
                            datarecord.pred_class, datarecord.pred_prob
                        )
                    ),
                    format_classname(datarecord.attr_class),
                    format_classname("{0:.2f}".format(datarecord.attr_score)),
                    format_word_importances(
                        datarecord.raw_input, datarecord.word_attributions
                    ),
                    "</tr>",  # close the row instead of opening a new one
                ]
            )
        )

    dom.append("".join(rows))
    dom.append("</table>")
    html = HTML("".join(dom))
    display(html)

    return html

In [166]:
def model_explainability(model, vocab_stoi, vocab_itos, df, max_samples, field, device, class_names=["Neutral","Hate"]):
    """
    Compute word importances for the first rows of ``df``.

    Args:
        model: model to explain; must expose the embedding layer as ``model.emb``.
        vocab_stoi: mapping token string -> id.
        vocab_itos: mapping id -> token string.
        df: DataFrame with ``text`` and ``true_label`` columns.
        max_samples: maximum number of rows to process; capped at ``len(df)``.
        field: object exposing ``vocab.stoi`` and ``pad_token``.
        device: device on which the attribution inputs are created.
        class_names: display names indexed by class id.

    Returns:
        List of visualization records, one per processed row.
    """
    print("\n\n**MODEL EXPLAINABILITY**\n")
    print("Computing words importance for each sample... ")

    # NOTE(review): the baseline token is the pad index *plus one*, not the pad
    # index itself -- confirm this offset is intentional.
    PAD_IND = field.vocab.stoi[field.pad_token] +1 #vocab_stoi[field.pad_token]
    token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)

    lig = LayerIntegratedGradients(model, model.emb)

    # accumulate a couple of samples in this array for visualization purposes
    vis_data_records_ig = []

    # Never index past the end of df (df.iloc would raise IndexError).
    n_samples = min(max_samples, len(df))

    # NOTE(review): attribution runs in train mode; if the model contains
    # dropout the scores are not deterministic -- confirm train mode is needed.
    was_training = model.training
    model.train()
    try:
        with torch.set_grad_enabled(True):
            for i in range(n_samples):
                row = df.iloc[i]
                sentence = row.text
                label = row.true_label
                input_tokens = sentence_to_input_tokens(sentence, vocab_stoi)
                interpret_sentence(model, field, input_data=input_tokens, sentence=sentence, vocab_stoi=vocab_stoi, \
                                   vocab_itos=vocab_itos, device=device, vis_data_records_ig=vis_data_records_ig,\
                                   token_reference=token_reference, lig=lig, min_len = 7, label = label, \
                                   class_names=class_names)
    finally:
        # Hand the model back in the mode it was received in.
        model.train(was_training)

    print("Computations completed.")
    return vis_data_records_ig

def sentence_to_input_tokens(sentence, vocab_stoi):
    """
    Convert a space-separated sentence into a column tensor of vocabulary ids.

    The sentence is split on single spaces and every piece is looked up in
    ``vocab_stoi``. The ids are returned as a tensor of shape
    ``[number_of_pieces, 1]``.
    """
    token_ids = [vocab_stoi[word] for word in sentence.split(" ")]
    as_row = torch.tensor(token_ids).unsqueeze(0)
    return as_row.permute(1, 0)

In [167]:
def dataset_visualization(model, vocab_stoi, vocab_itos, df,\
                           field, device, max_samples=10,partial_vis=False,class_names=["Neutral","Hate"]):
    """
    Compute and display word importances for the samples in ``df``.

    Args:
        model, vocab_stoi, vocab_itos, field, device, class_names: forwarded
            unchanged to ``model_explainability``.
        df: DataFrame holding the samples to explain.
        max_samples: upper bound on the number of rows when ``partial_vis`` is True.
        partial_vis: when True only the first ``max_samples`` rows are processed
            (or all rows if ``df`` is shorter); when False every row is processed.

    Returns:
        None. The visualization is displayed by ``visualize_text``.
    """
    n = len(df)
    if partial_vis:
        # Cap at the number of available rows so short subsets do not overflow.
        n = min(max_samples, n)

    vis_data_record = model_explainability(model, vocab_stoi, vocab_itos, df, n,\
                                           field, device, class_names=class_names)
    print("\n\n**LOADING VISUALIZATION**\n")
    visualize_text(vis_data_record)
   

## Data Visualization

## True Positives

In [168]:
# Explain the first (at most) 10 true positives.
dataset_visualization(
    model, vocab_stoi, vocab_itos, df_tp, field, device,
    max_samples=10, partial_vis=True, class_names=["Neutral", "Hate"],
)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Hate,Neutral (0.11),Hate,-0.5,"#unk #unk #unk #unk democrats support antifa , muslim brotherhood , ms13 , isis , pedophilia , child #unk , taxpayer funded abortion ’s , election fraud , #unk and treason ! ! ! lockthemallup wwg1wga qanon ⁦@user url #unk"
,,,,
Hate,Hate (0.95),Hate,0.67,. . . what the fuck did he do this time ? #unk
,,,,
Hate,Hate (0.96),Hate,0.53,@user nigga #unk da hits at #unk
,,,,
Hate,Hate (0.95),Hate,2.6,stopkavanaugh he is liar like the rest of the gop url #unk
,,,,
Hate,Neutral (0.41),Hate,-0.19,""" @user @user put #unk in a police #unk #unk the bag lady "" "" - she would be picked everytime ! she has to be proof either #unk voters are incompetent to vote or she is part of a vast voter fraud conspiracy ! no one votes for a woman that gross ! "" "" "" #unk"
,,,,


## False Positives

In [169]:
# Explain the first (at most) 10 false positives.
dataset_visualization(
    model, vocab_stoi, vocab_itos, df_fp, field, device,
    max_samples=10, partial_vis=True, class_names=["Neutral", "Hate"],
)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Neutral,Neutral (0.21),Hate,-0.35,rap is a form of art ! used to express yourself #unk . it does not #unk the green light or excuse the behavior of acting like an animal ! she is not in the streets of the #unk where violence is a way of living . #unk yourself boo and get on @user level for #unk ! queen : crown : #unk
,,,,
Neutral,Hate (0.62),Hate,-0.23,""" #unk kills #unk in his own apartment ? what went wrong ? ? see : officer who shot man in his own apartment was involved in 2017 shooting of a suspect . “ she is devastated , said a dallas police officer , adds she is "" "" so so sorry for his family ? "" "" url "" #unk"
,,,,
Neutral,Neutral (0.49),Hate,0.03,kavanaugh so a wild claim from 36 years ago of groping has #unk into a rape a violent sexual event by move url a soros based #unk . that supports blm antifa etc . #unk ! #unk
,,,,
Neutral,Hate (0.52),Hate,-0.24,hiac damn matt hardy and #unk orton put on one hell in a cell match ! ! #unk ! ! ! i hope he is okay ! ! #unk
,,,,
Neutral,Hate (0.82),Hate,-0.08,@user @user @user are you referring to how they #unk with gun control as their kids get slaughtered in schools ? #unk
,,,,


## True Negatives

In [170]:
# Explain the first (at most) 10 true negatives.
dataset_visualization(
    model, vocab_stoi, vocab_itos, df_tn, field, device,
    max_samples=10, partial_vis=True, class_names=["Neutral", "Hate"],
)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Neutral,Neutral (0.40),Hate,-0.08,"constitutionday is #unk by conservatives , hated by progressives / socialist / democrats that want to change it . #unk"
,,,,
Neutral,Neutral (0.02),Hate,-0.03,"foxnews nra maga potus trump #unk rnc #unk veterans @user @user @user @user @user #unk #unk dnc liberals @user @user #unk @user first , it reduces the ca url #unk"
,,,,
Neutral,Neutral (0.08),Hate,-0.16,watching #unk getting the news that she is still up for #unk always makes me smile . #unk #unk ... @user is such a treasure . url #unk
,,,,
Neutral,Neutral (0.02),Hate,-0.05,5 #unk to #unk audience connection on facebook url @user #unk #unk url #unk
,,,,
Neutral,Neutral (0.07),Hate,-0.33,#unk #unk won the task . she is going to first final list : clapping_hands : : clapping_hands : : clapping_hands : : clapping_hands : #unk
,,,,


## False Negatives

In [171]:
# Explain the first (at most) 10 false negatives.
dataset_visualization(
    model, vocab_stoi, vocab_itos, df_fn, field, device,
    max_samples=10, partial_vis=True, class_names=["Neutral", "Hate"],
)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Hate,Neutral (0.08),Hate,-0.28,#unk : unity demo to oppose the far - right in london – antifa #unk — enough is enough ! url #unk
,,,,
Hate,Neutral (0.03),Hate,-0.17,@user do you get the feeling he is kissing @user behind so he can #unk him later ? #unk
,,,,
Hate,Neutral (0.03),Hate,-0.55,christineblaseyford is your kavanaugh accuser ... liberals try this every time ... confirmjudgekavanaugh url #unk
,,,,
Hate,Neutral (0.05),Hate,-0.16,@user @user @user @user @user @user @user @user @user i ’m shocked to learn human #unk had guns . some probably illegal too . ca needs more gun control . but do n’t worry about the actual crime . the pic of black guns is worse . #unk
,,,,
Hate,Hate (0.86),Hate,-0.12,conservatives @user - you 're a clown ! url #unk
,,,,
