This notebook computes LRP (Layer-wise Relevance Propagation), SA (Sensitivity Analysis), and GI (GradientxInput) relevances for an example test sentence and a chosen relevance *target* class, using a bidirectional LSTM trained on the Stanford Sentiment Treebank (SST) dataset.

The LRP implementation is based on the following papers:
- [https://doi.org/10.1371/journal.pone.0130140](https://doi.org/10.1371/journal.pone.0130140)
- [https://doi.org/10.18653/v1/W17-5221](https://doi.org/10.18653/v1/W17-5221)

In [45]:
from LSTM.LSTM_bidi import LSTM_bidi
from heatmap import html_heatmap

import codecs
import numpy as np
from IPython.display import display, HTML

# Define input sequence and relevance target class

The sentiment classes are encoded the following way:  
**0=very negative, 1=negative, 2=neutral, 3=positive, 4=very positive**
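
For more readable output, the integer class indices can be mapped back to their names. The following is a minimal sketch; `CLASS_NAMES` and `class_name` are hypothetical helpers introduced here for illustration and are not part of the notebook's code.

```python
# Hypothetical lookup table, following the class encoding stated above.
CLASS_NAMES = ["very negative", "negative", "neutral", "positive", "very positive"]

def class_name(class_idx):
    """Return the sentiment name for a class index in [0, 4]."""
    return CLASS_NAMES[class_idx]

print(class_name(0))  # very negative
```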

In [46]:
def get_test_sentence(sent_idx):
    """Return an SST test set sentence and its true label. sent_idx must be an integer in [1, 2210]."""
    idx = 1
    with codecs.open("sequence_test.txt", 'r', encoding='utf8') as f:
        for line in f:
            line          = line.rstrip('\n')
            line          = line.rstrip('\r')
            line          = line.split('\t')
            true_class    = int(line[0])-1         # true class (file label minus 1)
            words         = line[1].split(' | ')   # sentence as list of words
            if idx == sent_idx:
                return words, true_class
            idx += 1
    raise ValueError("sent_idx {} is out of range".format(sent_idx))  # no such sentence in the file

def predict(words):
    """Returns the classifier's predicted class"""
    net                 = LSTM_bidi()                                   # load trained LSTM model
    w_indices           = [net.voc.index(w) for w in words]             # convert input sentence to word IDs
    net.set_input(w_indices)                                            # set LSTM input sequence
    scores              = net.forward()                                 # classification prediction scores
    return np.argmax(scores)            
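
The parsing in `get_test_sentence` implies that each line of `sequence_test.txt` holds a label, a tab, and the words separated by ` | `, with the label shifted down by one to match the 0 to 4 encoding. The sketch below illustrates that on a single line; `parse_line` is a hypothetical helper, and the example line is made up rather than copied from the SST file.

```python
def parse_line(line):
    """Split one 'label<TAB>word | word | ...' line into (words, zero-based class)."""
    label, sentence = line.rstrip('\r\n').split('\t', 1)
    return sentence.split(' | '), int(label) - 1

# Made-up example line (not taken from the SST file); label 1 maps to class 0.
words_demo, class_demo = parse_line("1\tnot | worth | the | time\n")
print(words_demo, class_demo)
```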

As an input sequence, either select a sentence from the Stanford Sentiment Treebank (SST) test set, or define your own sequence.

In [47]:
words, _ = get_test_sentence(291)                                       # SST test set sentence number 291

In [48]:
# Alternatively, uncomment one of the following sentences, or define your own sequence (only words contained in the vocabulary are supported!)
#words = ['this','movie','was','actually','neither','that','funny',',','nor','super','witty','.']
#words = ['this', 'film', 'does', 'n\'t', 'care', 'about', 'cleverness', ',', 'wit', 'or', 'any', 'other', 'kind', 'of', 'intelligent', 'humor', '.']
#words = ['i','hate','the','movie','though','the','plot','is','interesting','.']
#words = ['used', 'to', 'be', 'my', 'favorite']
#words = ['not', 'worth', 'the', 'time']
#words = ['is', 'n\'t', 'a', 'bad', 'film'] # Note: misclassified sample!
#words = ['is', 'n\'t', 'very', 'interesting'] 
#words = ['it', '\'s', 'easy' ,'to' ,'love' ,'robin' ,'tunney' ,'--' ,'she' ,'\'s' ,'pretty' ,'and' ,'she' ,'can' ,'act' ,'--' ,'but' ,'it' ,'gets' ,'harder' ,'and' ,'harder' ,'to' ,'understand' ,'her' ,'choices', '.']
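
Because `predict` looks words up with `net.voc.index(w)`, a word outside the vocabulary would presumably raise a `ValueError` (assuming `net.voc` is a plain Python list). A custom sequence can be checked beforehand with something like the sketch below; `find_unknown_words` and `toy_voc` are hypothetical names used only for illustration.

```python
def find_unknown_words(words, voc):
    """Return the words that are missing from the vocabulary, in their original order."""
    known = set(voc)                      # set lookup avoids repeated list scans
    return [w for w in words if w not in known]

# Toy vocabulary for illustration only; in the notebook the real list would be `net.voc`.
toy_voc = ['not', 'worth', 'the', 'time', '.']
print(find_unknown_words(['not', 'worth', 'teh', 'time'], toy_voc))  # ['teh']
```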

To understand why a single sample is classified or misclassified, we **recommend using the classifier's *predicted* class as the relevance *target* class**: it is the class the model is most confident about, so the resulting relevances reflect the classifier's "point of view" on the test sample more accurately.
(More generally, any class can be chosen as the relevance *target* class.)
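
As a small illustration of this choice, the sketch below takes the prediction scores printed further down in this notebook for the example sentence and derives the predicted class, the ranking of classes, and the margin between the top two. The variable names are introduced here for illustration only.

```python
import numpy as np

# Prediction scores copied from this notebook's output for the example sentence.
scores_demo = np.array([2.73149687, 2.7249559, 0.80547211, -1.5359282, -4.6083298])

predicted_demo = int(np.argmax(scores_demo))        # class with the highest score
ranking = np.argsort(scores_demo)[::-1]             # classes from highest to lowest score
margin = scores_demo[ranking[0]] - scores_demo[ranking[1]]

print(predicted_demo, ranking[:2], round(float(margin), 4))
```

When the margin between the two best classes is small, as it is here between classes 0 and 1, it may also be informative to rerun the analysis with the runner-up as the target class.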

In [49]:
predicted_class = predict(words)                                        # get predicted class
target_class    = predicted_class                                       # define relevance target class 

In [50]:
print (words)
print ("\npredicted class:          ",   predicted_class)

['neither', 'funny', 'nor', 'suspenseful', 'nor', 'particularly', 'well-drawn', '.']

predicted class:           0


# Compute LRP relevances

In [51]:
# LRP hyperparameters:
eps                 = 0.001                                             # small positive number
bias_factor         = 0.0                                               # recommended value
 
net                 = LSTM_bidi()                                       # load trained LSTM model

w_indices           = [net.voc.index(w) for w in words]                 # convert input sentence to word IDs
Rx, Rx_rev, R_rest  = net.lrp(w_indices, target_class, eps, bias_factor) # perform LRP
R_words             = np.sum(Rx + Rx_rev, axis=1)                       # compute word-level LRP relevances

scores              = net.s.copy()                                      # classification prediction scores

In [52]:
print ("prediction scores:        ",   scores)
print ("\nLRP target class:         ", target_class)
print ("\nLRP relevances:")
for idx, w in enumerate(words):
    print ("\t\t\t" + "{:8.2f}".format(R_words[idx]) + "\t" + w)
print ("\nLRP heatmap:")    
display(HTML(html_heatmap(words, R_words)))

prediction scores:         [ 2.73149687  2.7249559   0.80547211 -1.5359282  -4.6083298 ]

LRP target class:          0

LRP relevances:
			    1.86	neither
			   -1.58	funny
			    1.50	nor
			   -1.54	suspenseful
			    2.00	nor
			   -0.04	particularly
			   -0.06	well-drawn
			   -0.12	.

LRP heatmap:


In [53]:
# How to sanity check global relevance conservation:
bias_factor        = 1.0                                             # value to use for sanity check
Rx, Rx_rev, R_rest = net.lrp(w_indices, target_class, eps, bias_factor)
R_tot              = Rx.sum() + Rx_rev.sum() + R_rest.sum()          # sum of all "input" relevances

print(R_tot)
print("Sanity check passed? ", np.allclose(R_tot, net.s[target_class]))

2.731496867795989
Sanity check passed?  True
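
To see why setting `bias_factor` to 1.0 yields exact conservation, it can help to look at a single linear layer in isolation. The following is a minimal sketch of the epsilon-LRP rule under the assumption that the bias and the stabilizer are split evenly across the inputs; `lrp_linear_sketch` is a hypothetical function written for this illustration and is not claimed to match the internals of `LSTM_bidi`.

```python
import numpy as np

def lrp_linear_sketch(h_in, w, b, h_out, r_out, eps=0.001, bias_factor=1.0):
    """Epsilon-LRP through one linear layer h_out = w.T @ h_in + b (illustrative sketch)."""
    n_in = h_in.shape[0]
    sign_out = np.where(h_out >= 0, 1.0, -1.0)              # sign of each output, shape (M,)
    # Each input's share: its own contribution plus an equal slice of bias and stabilizer.
    numer = w * h_in[:, None] + bias_factor * (b + eps * sign_out)[None, :] / n_in
    denom = (h_out + eps * sign_out)[None, :]               # stabilized output, shape (1, M)
    return (numer / denom * r_out[None, :]).sum(axis=1)     # input relevances, shape (D,)

rng = np.random.default_rng(0)
h_in = rng.normal(size=4)                                   # toy input activations
w = rng.normal(size=(4, 3))                                 # toy weights, shape (D, M)
b = rng.normal(size=3)                                      # toy biases
h_out = w.T @ h_in + b

r_out = np.zeros(3)
r_out[1] = h_out[1]                                         # start from the target unit's score

r_in = lrp_linear_sketch(h_in, w, b, h_out, r_out, bias_factor=1.0)
print(np.allclose(r_in.sum(), r_out.sum()))
```

With `bias_factor=1.0` the numerators sum to the denominator for every output unit, so the input relevances add up to the output relevance. With `bias_factor=0.0` the share attributed to the bias and the stabilizer is dropped, and the sums generally differ.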


# Compute SA/GI relevances

In [54]:
net              = LSTM_bidi()                                       # load trained LSTM model

w_indices        = [net.voc.index(w) for w in words]                 # convert input sentence to word IDs
Gx, Gx_rev       = net.backward(w_indices, target_class)             # perform gradient backpropagation
R_words_SA       = (np.linalg.norm(Gx + Gx_rev, ord=2, axis=1))**2   # compute word-level Sensitivity Analysis relevances
R_words_GI       = ((Gx + Gx_rev)*net.x).sum(axis=1)                 # compute word-level GradientxInput relevances

scores           = net.s.copy()                                      # classification prediction scores 
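
The two formulas above can be checked by hand on a toy example. In this sketch, `G` stands in for the summed gradients `Gx + Gx_rev` and `X` for the word embeddings `net.x`, each with one row per word; both arrays are made-up values chosen for easy arithmetic.

```python
import numpy as np

# Made-up gradients and embeddings for two "words" with two dimensions each.
G = np.array([[1.0,  2.0],
              [0.0, -1.0]])
X = np.array([[0.5, 0.5],
              [2.0, 3.0]])

R_SA = np.linalg.norm(G, ord=2, axis=1) ** 2   # squared L2 norm per word: about [5, 1]
R_GI = (G * X).sum(axis=1)                     # gradient-input dot product per word: [1.5, -3]

print(R_SA, R_GI)
```

SA relevances are squared norms and therefore never negative, whereas GI relevances are signed, which matches the outputs printed below.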

In [55]:
print ("prediction scores:       ",   scores)
print ("\nSA/GI target class:      ", target_class)
print ("\nSA relevances:")
for idx, w in enumerate(words):
    print ("\t\t\t" + "{:8.2f}".format(R_words_SA[idx]) + "\t" + w)
print ("\nSA heatmap:")    
display(HTML(html_heatmap(words, R_words_SA)))
print ("\nGI relevances:")
for idx, w in enumerate(words):
    print ("\t\t\t" + "{:8.2f}".format(R_words_GI[idx]) + "\t" + w)
print ("\nGI heatmap:")    
display(HTML(html_heatmap(words, R_words_GI)))

prediction scores:        [ 2.73149687  2.7249559   0.80547211 -1.5359282  -4.6083298 ]

SA/GI target class:       0

SA relevances:
			    5.01	neither
			    0.35	funny
			    0.73	nor
			    0.92	suspenseful
			    1.66	nor
			    0.13	particularly
			    0.66	well-drawn
			    0.32	.

SA heatmap:



GI relevances:
			    0.03	neither
			    0.06	funny
			   -0.11	nor
			   -0.19	suspenseful
			   -0.19	nor
			   -0.07	particularly
			   -0.06	well-drawn
			    0.03	.

GI heatmap:
