### Imports

In [1]:
import tqdm
import argparse
import numpy as np
import datetime
import time

import spacy
import pandas as pd
from sklearn.metrics import f1_score

from torch import optim
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from preprocess_utils import *
from train import train_model, test_model
from models import BasicLSTM, BiLSTM
from test_save_stats import *

from utils import *

import captum
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization

from typing import Any, Iterable, List, Tuple, Union
from IPython.core.display import HTML, display

spacy_en = spacy.load("en_core_web_sm")

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Richard\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


## ⚠️ Before running the cells below, make sure to run:

- `test_save_stats.py --model=MODEL_NAME --saved_model_path=PATH_TO_MODEL` (see the source code for more details) + any model parameters needed

The code saves the samples for which the model is sure of its prediction (i.e. when the predicted probability is either very close to 1 (Hate) or very close to 0 (Neutral)). <br>
We are now going to visualize the explainability of the model (i.e. the importance of each word in the model's decision) for True Positives (TP), False Positives (FP), True Negatives (TN) and False Negatives (FN), respectively.

### Hyperparameters

In [2]:
## Model hyperparameters -- edit these to point at your own run.
# Architecture name; passed to get_datasets() and load_model() in later cells.
model_type = 'BasicLSTM'
# Checkpoint restored by load_trained_model() in a later cell.
saved_model_path = (
    'saved_models/'
    'BasicLSTM_2021-12-08_01-04-25_trained_testAcc=0.7107.pth'
)
# Per-sample statistics CSV read with pandas in a later cell.
stats_path = (
    "stats_results/"
    "stats_BasicLSTM_2021-12-08_01-04-25_test_bcelosswithlogits.csv"
)

In [3]:
# Architecture-specific settings, forwarded unchanged to get_datasets() / load_model() below.
# NOTE(review): their exact meaning is defined by the project's own modules, not by this notebook.
fix_length = None
context_size = batch_norm = alpha = 0
pyramid, fcs = [], []

### Data Import

In [4]:
# Dataset locations, relative to the notebook's working directory.
training_data = "data/training_data/offenseval-training-v1.tsv"
testset_data = "data/test_data/testset-levela.tsv"
test_labels_data = "data/test_data/labels-levela.csv"

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)

# Only the first of the five returned values (the field) is used in this notebook.
field, _, _, _, _ = get_datasets(
    training_data, testset_data, test_labels_data, model_type, fix_length
)

print("Loading vocabulary...")
vocab_stoi, vocab_itos = get_vocab_stoi_itos(field)
print("Vocabulary Loaded")

Device: cuda
file loaded and formatted..
data split into train/val/test
field objects created
fields and dataset object created
vocabulary built..
Loading vocabulary...
Vocabulary Loaded


In [5]:
print("Loading Model...")
# Step 1: instantiate the architecture with the settings from the cells above.
model = load_model(
    model_type,
    field,
    device,
    fix_length=fix_length,
    context_size=context_size,
    pyramid=pyramid,
    fcs=fcs,
    batch_norm=batch_norm,
    alpha=alpha,
)
# Step 2: restore the saved checkpoint into that instance.
model = load_trained_model(model, saved_model_path, device)
print("Model Loaded.")

Loading Model...
saved_models/BasicLSTM_2021-12-08_01-04-25_trained_testAcc=0.7107.pth loaded.
Model Loaded.


In [6]:
print("Loading Stats Data..")
# Read the per-sample statistics and discard the unnamed column written alongside them.
df = pd.read_csv(stats_path).drop(columns=["Unnamed: 0"])
df.head()

Loading Stats Data..


Unnamed: 0,original_index,text,true_label,pred_label,prob,loss
0,0,<unk> <unk> <unk> <unk> democrats support anti...,1,1,0.723092,0.324219
1,1,"constitutionday <unk> conservatives , hated pr...",0,0,0.070297,0.07289
2,2,foxnews nra maga potus trump <unk> rnc <unk> v...,0,0,0.004778,0.00479
3,3,watching <unk> getting news still <unk> always...,0,0,0.020048,0.020251
4,4,<unk> : unity demo oppose far - right london –...,1,1,0.724287,0.322568


In [7]:
## Split the samples into the four confusion-matrix groups (1 = Hate, 0 = Neutral)

# Correctly flagged as Hate.
df_tp = df.loc[df['true_label'].eq(1) & df['pred_label'].eq(1)]
# Neutral samples wrongly flagged as Hate.
df_fp = df.loc[df['true_label'].eq(0) & df['pred_label'].eq(1)]
# Correctly left as Neutral.
df_tn = df.loc[df['true_label'].eq(0) & df['pred_label'].eq(0)]
# Hate samples the model missed.
df_fn = df.loc[df['true_label'].eq(1) & df['pred_label'].eq(0)]

print("TP, FP, TN, FN selected from loaded data.")

TP, FP, TN, FN selected from loaded data.


### Definition of methods to Visualize Importance of Words

We modified and adapted code from Captum (in particular, visualization.visualize_text) to fit our context.

In [8]:
def interpret_sentence(model, field, pad_ind, input_data, sentence, vocab_stoi, vocab_itos, device, vis_data_records_ig, \
                       token_reference, lig, min_len = 7, label = 0, class_names=["Neutral","Hate"]):
    """Attribute one sample's prediction to its tokens and record the result.

    Args:
        model: model whose ``forward`` output is passed through a sigmoid and
            read as a single probability.
        field: unused; kept so existing callers keep working.
        pad_ind: token id appended when the sample has fewer than ``min_len`` tokens.
        input_data: 2-D tensor of token ids with one token per row; the ids are
            read from column 0.
        sentence: unused; kept so existing callers keep working.
        vocab_stoi: mapping from token string to token id.
        vocab_itos: mapping from token id to token string.
        device: device on which the prediction and attribution inputs are placed.
        vis_data_records_ig: list that receives the new visualization record.
        token_reference: object whose ``generate_reference`` builds the baseline sequence.
        lig: attribution object whose ``attribute`` returns (attributions, delta).
        min_len: minimum length of the attributed sequence.
        label: index into ``class_names`` of the true class.
        class_names: display names indexed by class id.

    Returns:
        None. One record is appended to ``vis_data_records_ig``.
    """
    # Token ids of the sample, right-padded with `pad_ind` up to `min_len`.
    indexed = [int(input_data[i,0]) for i in range(input_data.shape[0])]
    if len(indexed) < min_len :
        indexed +=[pad_ind] * (min_len - len(indexed))

    # Token strings shown in the visualization. `text` always has the same
    # length as the already padded `indexed`, so it needs no padding of its own.
    text = [vocab_itos[tok] for tok in indexed]

    # The ids are rebuilt from the token strings, exactly as before.
    indexed = [vocab_stoi[t] for t in text]
    input_indices = torch.tensor(indexed, device=device).unsqueeze(0).permute(1,0)

    model.zero_grad()

    # input_indices has one row per (padded) token.
    seq_length = input_indices.shape[0]

    # predict
    # NOTE(review): the probability comes from the *unpadded* `input_data`, while
    # the attributions below are computed on the padded `input_indices` -- confirm
    # that this difference is intended for samples shorter than `min_len`.
    out = model.forward(input_data.to(device))
    out = torch.sigmoid(out)
    pred = out.item()
    pred_ind = round(pred)

    # generate reference indices for each sample
    reference_indices = token_reference.generate_reference(seq_length, device=device).unsqueeze(0).permute(1, 0)

    # compute attributions and approximation delta using layer integrated gradients
    attributions_ig, delta = lig.attribute(input_indices, reference_indices,\
                                           n_steps=200, return_convergence_delta=True)

    add_attributions_to_visualizer(attributions_ig, vocab_itos, text, pred, pred_ind, label, delta, vis_data_records_ig,\
                                  class_names)

def add_attributions_to_visualizer(attributions, vocab_itos, text, pred, pred_ind, label, delta, vis_data_records,\
                                   class_names):
    """Reduce raw attributions to one score per token and store a visualization record.

    Args:
        attributions: 3-D attribution tensor for a single sample; the last axis
            is summed away.
        vocab_itos: unused; kept so existing callers keep working.
        text: token strings of the sample.
        pred: predicted probability.
        pred_ind: index into ``class_names`` of the predicted class.
        label: index into ``class_names`` of the true class.
        delta: convergence delta returned by the attribution method.
        vis_data_records: list that receives the new record.
        class_names: display names indexed by class id.

    Returns:
        None. One ``VisualizationDataRecord`` is appended to ``vis_data_records``.
    """
    # Sum over the last axis, then flatten so there is exactly one scalar per
    # token. The caller passes a [seq_len, 1, dim] tensor, for which the former
    # `squeeze(0)` was a no-op and left every importance as a 1-element array.
    attributions = attributions.sum(dim=2).reshape(-1)

    # L2-normalise, leaving an all-zero attribution untouched rather than
    # dividing by zero and handing NaNs to the colour mapping.
    norm = torch.norm(attributions)
    if norm > 0:
        attributions = attributions / norm
    attributions = attributions.cpu().detach().numpy()

    # storing couple samples in an array for visualization purposes
    vis_data_records.append(visualization.VisualizationDataRecord(
                            attributions,
                            pred,
                            class_names[pred_ind],
                            class_names[label],
                            class_names[1],
                            attributions.sum(),
                            text,
                            delta))


In [9]:
def format_classname(classname):
    """Wrap ``classname`` in a bold, right-padded HTML table cell."""
    cell_template = '<td><text style="padding-right:2em"><b>{}</b></text></td>'
    return cell_template.format(classname)

def format_word_importances(words, importances):
    """Build the table cell in which every word is highlighted by its importance.

    Returns an empty cell when ``importances`` is None or empty. Otherwise each
    word is paired with the importance at the same position; importances beyond
    the last word are ignored.
    """
    if importances is None or len(importances) == 0:
        return "<td></td>"
    assert len(words) <= len(importances)

    mark_template = '<mark style="background-color: {color}; opacity:1.0; \
                    line-height:1.75"><font color="black"> {word}\
                    </font></mark>'
    marks = [
        mark_template.format(
            word=format_special_tokens(token), color=_get_color(score)
        )
        for token, score in zip(words, importances[: len(words)])
    ]
    return "".join(["<td>", *marks, "</td>"])

def format_special_tokens(token):
    """Turn an angle-bracketed token such as ``<unk>`` into ``#unk``; leave others as-is."""
    is_bracketed = token[:1] == "<" and token[-1:] == ">"
    return "#" + token.strip("<>") if is_bracketed else token


def _get_color(attr):
    # clip values to prevent CSS errors (Values should be from [-1,1])
    attr = max(-1, min(1, attr))
    if attr > 0:
        #(52, 85%, 69%);
        # red
        hue = 0
        sat = 75
        lig = 100 - int(50 * attr)
    else:
        #yellow
        hue = 52
        sat = 85
        lig = 100 - int(-40 * attr)
        
#     attr = max(-1, min(1, attr))
#     hue = 0
#     sat = 75
#     lig = np.clip(100 - int(50 * attr), 0, 100)
    return "hsl({}, {}%, {}%)".format(hue, sat, lig)

def visualize_text(
    datarecords, legend: bool = True
) -> "HTML":  # In quotes because this type doesn't exist in standalone mode
    """Render attribution records as an HTML table, display it, and return it.

    Args:
        datarecords: iterable of records exposing ``true_class``, ``pred_class``,
            ``pred_prob``, ``attr_class``, ``attr_score``, ``raw_input`` and
            ``word_attributions``.
        legend: when True, a colour legend for the "Neutral" and "Hate" ends of
            the scale is rendered with the table.

    Returns:
        The ``HTML`` object that was passed to ``display``.
    """
    dom = []

    if legend:
        # The legend goes ahead of the table: a <div> is not valid content
        # directly inside <table>.
        dom.append(
            '<div style="border-top: 1px solid; margin-top: 5px; \
            padding-top: 5px; display: inline-block">'
        )
        dom.append("<b>Legend: </b>")

        for value, label in zip([-1,1], ["Neutral", "Hate"]):
            dom.append(
                '<span style="display: inline-block; width: 10px; height: 10px; \
                border: 1px solid; background-color: \
                {value}"></span> {label}  '.format(
                    value=_get_color(value), label=label
                )
            )
        dom.append("</div>")

    # Header row (now properly closed).
    rows = [
        "<tr><th>True Label</th>"
        "<th>Predicted Label</th>"
        "<th>Attribution Label</th>"
        "<th>Attribution Score</th>"
        "<th>Word Importance</th></tr>"
    ]

    for datarecord in datarecords:
        rows.append(
            "".join(
                [
                    "<tr>",
                    format_classname(datarecord.true_class),
                    format_classname(
                        "{0} ({1:.2f})".format(
                            datarecord.pred_class, datarecord.pred_prob
                        )
                    ),
                    format_classname(datarecord.attr_class),
                    format_classname("{0:.2f}".format(datarecord.attr_score)),
                    format_word_importances(
                        datarecord.raw_input, datarecord.word_attributions
                    ),
                    # Close the row; a second "<tr>" here opened an empty row
                    # after every record.
                    "</tr>",
                ]
            )
        )

    dom.append('<table style="width: 100%">')
    dom.append("".join(rows))
    dom.append("</table>")
    html = HTML("".join(dom))
    display(html)

    return html

In [10]:
def model_explainability(model, vocab_stoi, vocab_itos, df, max_samples, field, device, class_names=["Neutral","Hate"]):
    """
    Compute word importances for the first rows of ``df``.

    Args:
        model: model to explain; its ``emb`` layer is the attribution target.
        vocab_stoi: mapping from token string to token id.
        vocab_itos: mapping from token id to token string.
        df: frame with a ``text`` column (space-separated tokens) and a
            ``true_label`` column (index into ``class_names``).
        max_samples: maximum number of rows to process; values larger than
            ``len(df)`` are clamped to ``len(df)``.
        field: object providing ``vocab.stoi`` and ``pad_token``.
        device: device used for the attribution inputs.
        class_names: display names indexed by class id.

    Returns:
        A list with one visualization record per processed row.
    """
    print("\n\n**MODEL EXPLAINABILITY**\n")
    print("Computing words importance for each sample... ")

    # NOTE(review): the baseline/padding id is the pad token's id minus one --
    # confirm that this offset is intended.
    PAD_IND = field.vocab.stoi[field.pad_token]-1 #vocab_stoi[field.pad_token]
    token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)

    lig = LayerIntegratedGradients(model, model.emb)

    # accumulate a couple of samples in this list for visualization purposes
    vis_data_records_ig = []

    # Never index past the end of the frame.
    n_samples = min(max_samples, len(df))

    # The attribution pass runs with the model in training mode, as before
    # (presumably needed for the backward pass through the recurrent layers --
    # TODO confirm). The previous mode is restored afterwards so the model is
    # not silently left in training mode.
    was_training = getattr(model, "training", True)
    model.train()
    try:
        for i in range(n_samples):
            row = df.iloc[i]
            sentence = row.text
            label = row.true_label
            input_tokens = sentence_to_input_tokens(sentence, vocab_stoi)
            with torch.set_grad_enabled(True):
                interpret_sentence(model, field, pad_ind = PAD_IND, input_data=input_tokens, sentence=sentence, vocab_stoi=vocab_stoi, \
                                   vocab_itos=vocab_itos, device=device, vis_data_records_ig=vis_data_records_ig,\
                                   token_reference=token_reference, lig=lig, min_len = 7, label = label, \
                                   class_names=class_names)
    finally:
        model.train(was_training)

    print("Computations completed.")
    return vis_data_records_ig

def sentence_to_input_tokens(sentence, vocab_stoi):
    """Convert a space-separated sentence into a column tensor of token ids.

    The sentence is split on single spaces and each piece is looked up in
    ``vocab_stoi``. The result has one row per token and a single column.
    """
    token_ids = [vocab_stoi[word] for word in sentence.split(" ")]
    return torch.tensor(token_ids).unsqueeze(0).permute(1, 0)

In [11]:
def dataset_visualization(
    model, vocab_stoi, vocab_itos, df, field, device,
    max_samples=10, partial_vis=False, class_names=["Neutral","Hate"],
):
    """Compute and display word importances for the rows of ``df``.

    With ``partial_vis`` set, at most ``max_samples`` rows are processed;
    otherwise every row of ``df`` is processed.
    """
    sample_count = min(len(df), max_samples) if partial_vis else len(df)

    records = model_explainability(
        model, vocab_stoi, vocab_itos, df, sample_count, field, device,
        class_names=class_names,
    )
    print("\n\n**LOADING VISUALIZATION**\n")
    visualize_text(records)
   

# Data Visualization

We are now going to visualize words' importances in the decision process. <br>
For each category (TP, FP, TN, FN), we retrieve the samples with the highest and lowest scores and visualize the most confident ones: the highest scores for TP and FP, and the lowest scores for TN and FN.

## True Positives

First we retrieve the highest and lowest scores.

In [12]:
# Rank the true positives by the 'prob' column (top-k of 10);
# presumably returns the lowest- and highest-scoring rows -- verify against the helper.
lowest_stats_df_tp, highest_stats_df_tp = get_highest_lowest_metric_indexes(
    df_tp,
    stats_metric='prob',
    stats_topk=10,
)

#### Highest Scores 

In [13]:
# Word importances for the highest-scoring true positives (at most 10 rows).
dataset_visualization(
    model, vocab_stoi, vocab_itos, highest_stats_df_tp,
    field, device,
    max_samples=10, partial_vis=True, class_names=["Neutral","Hate"],
)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Hate,Hate (0.98),Hate,0.84,2018 surprised starts top ! ! ! karma bitch watch ! ! ! ! url #unk
,,,,
Hate,Hate (0.98),Hate,-0.43,! ! ! ! bitch ’m fucking coming back url #unk
,,,,
Hate,Hate (0.99),Hate,0.97,alright let get right god bc mother nature like fuck humans url #unk
,,,,
Hate,Hate (0.98),Hate,0.04,@user nigga stupid trash nt play play bitch : face_with_tears_of_joy : #unk
,,,,
Hate,Hate (0.97),Hate,0.46,#unk retweet 30 would like share memes talk dumb shit . cuz baby ’m broke shit love rich url #unk
,,,,


## False Positives

In [14]:
# Rank the false positives by the 'prob' column (top-k of 10);
# presumably returns the lowest- and highest-scoring rows -- verify against the helper.
lowest_stats_df_fp, highest_stats_df_fp = get_highest_lowest_metric_indexes(
    df_fp,
    stats_metric='prob',
    stats_topk=10,
)

#### Highest Scores

In [15]:
# Word importances for the highest-scoring false positives (at most 10 rows).
dataset_visualization(
    model, vocab_stoi, vocab_itos, highest_stats_df_fp,
    field, device,
    max_samples=10, partial_vis=True, class_names=["Neutral","Hate"],
)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Neutral,Hate (0.98),Hate,2.53,@user @user @user got pretty deep debate friend told #unk trump blacks trump paid supporters : face_with_tears_of_joy : said mean antifa paid domestic terrorist said anti - fascist said fascist kidding ? ! #unk
,,,,
Neutral,Hate (0.97),Hate,2.58,"antifa resist .. trump trying bring world peace , obstruct like democrats .. good also ? haters gon na hate ? url #unk"
,,,,
Neutral,Hate (0.94),Hate,2.41,hiac damn matt hardy #unk orton put one hell cell match ! ! #unk ! ! ! hope okay ! ! #unk
,,,,
Neutral,Hate (0.92),Hate,1.56,democrats liberals accuse brettkavanaugh proof ! ! meanwhile supported #unk run @user called lion senate killed woman . wednesdaywisdom news trump url #unk
,,,,
Neutral,Hate (0.95),Hate,-0.7,fucking serious ? url #unk #unk #unk
,,,,


## True Negatives

In [16]:
# Rank the true negatives by the 'prob' column (top-k of 10);
# presumably returns the lowest- and highest-scoring rows -- verify against the helper.
lowest_stats_df_tn, highest_stats_df_tn = get_highest_lowest_metric_indexes(
    df_tn,
    stats_metric='prob',
    stats_topk=10,
)

#### Lowest Scores

In [17]:
# Word importances for the lowest-scoring true negatives (at most 10 rows).
dataset_visualization(
    model, vocab_stoi, vocab_itos, lowest_stats_df_tn,
    field, device,
    max_samples=10, partial_vis=True, class_names=["Neutral","Hate"],
)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Neutral,Neutral (0.00),Hate,1.39,"celebrities #unk #unk pregnant ! bachelor star expecting first child boyfriend #unk : #unk #unk whole new reality — going mom ! bachelor #unk , 23 , .. via url url url #unk"
,,,,
Neutral,Neutral (0.00),Hate,0.11,#unk security incidents rarely emerge fully formed #unk lights alert . see ’re prepared testing skills following scenario . url url #unk
,,,,
Neutral,Neutral (0.00),Hate,1.67,"art culture #unk #unk #unk : apology tour : #unk #unk #unk taken social issues . last season 's gun control episode stands shining example , show always ability .. via url url #unk"
,,,,
Neutral,Neutral (0.00),Hate,-0.44,#unk #unk #unk x #unk 2018 f / w @user : link : url : link : url : link : url : link : url url #unk
,,,,
Neutral,Neutral (0.00),Hate,0.06,"#unk definitely tad high puerto rico . democrats provide proof ? providing names , #unk , next #unk n’t hard . hope n’t mind , holding breath . url via @user url #unk"
,,,,


## False Negatives

In [18]:
# Rank the false negatives by the 'prob' column (top-k of 10);
# presumably returns the lowest- and highest-scoring rows -- verify against the helper.
lowest_stats_df_fn, highest_stats_df_fn = get_highest_lowest_metric_indexes(
    df_fn,
    stats_metric='prob',
    stats_topk=10,
)

#### Lowest Scores

In [19]:
# Word importances for the lowest-scoring false negatives (at most 10 rows).
dataset_visualization(
    model, vocab_stoi, vocab_itos, lowest_stats_df_fn,
    field, device,
    max_samples=10, partial_vis=True, class_names=["Neutral","Hate"],
)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Hate,Neutral (0.01),Hate,1.01,"california patriots ... especially #unk obama / soros attempting take house seat get control congress . say #unk #unk . shill candidate #unk #unk #unk . please get vote david #unk house , #unk 21 ! follow @user url #unk"
,,,,
Hate,Neutral (0.02),Hate,2.81,"christian america – go trump ’s example , liberals support open borders , guess conservatives support school shootings . please explain makes america great . #unk"
,,,,
Hate,Neutral (0.01),Hate,0.5,conservatives govt run debt #unk austerity cuts rich #unk wealth . #unk url via @user #unk
,,,,
Hate,Neutral (0.01),Hate,-0.33,"brexit deal reached - #unk special #unk november , @user sold uk eu ? ? ? better @user finished ! ! @user url #unk"
,,,,
Hate,Neutral (0.01),Hate,0.1,nigeria #unk #unk ' incompetent leader nigeria ’s history ' – #unk #unk #unk url #unk via url #unk
,,,,
