This notebook computes LRP (Layer-wise Relevance Propagation), SA (Sensitivity Analysis) and GI (GradientxInput) relevances for an example test sentence and a chosen relevance *target* class, using a bidirectional LSTM trained on the Stanford Sentiment Treebank (SST) dataset.

The LRP implementation is based on the following papers:
- [https://doi.org/10.1371/journal.pone.0130140](https://doi.org/10.1371/journal.pone.0130140)
- [https://doi.org/10.18653/v1/W17-5221](https://doi.org/10.18653/v1/W17-5221)

In [1]:
# -*- coding: utf-8 -*-
from code.LSTM.LSTM_bidi import * 
from code.util.heatmap import html_heatmap

import codecs
import numpy as np
from IPython.display import display, HTML

# Define input sequence and relevance target class

The sentiment classes are encoded the following way:  
**0=very negative, 1=negative, 2=neutral, 3=positive, 4=very positive**
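
For readable printouts, the encoding above can be wrapped in a small lookup. This is only a convenience sketch: the names `SENTIMENT_LABELS` and `class_name` are hypothetical and not part of the notebook's code.

```python
# Hypothetical helper (not part of the notebook's code): map a class index to its label.
SENTIMENT_LABELS = ["very negative", "negative", "neutral", "positive", "very positive"]

def class_name(class_idx):
    """Return the sentiment label for a class index in [0, 4]."""
    if not 0 <= class_idx < len(SENTIMENT_LABELS):
        raise ValueError("class index must be in [0, 4], got %r" % (class_idx,))
    return SENTIMENT_LABELS[class_idx]

print(class_name(0))
print(class_name(4))
```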

In [2]:
def get_test_sentence(sent_idx):
    """Return the words and true class of an SST test sentence; sent_idx must be an integer in [1, 2210] (the test set has 2210 sentences)."""
    idx = 1
    with codecs.open("./data/sequence_test.txt", 'r', encoding='utf8') as f:
        for line in f:
            line          = line.rstrip('\n')
            line          = line.split('\t')
            true_class    = int(line[0])-1         # true class (shifted by one so that classes start at 0)
            words         = line[1].split(' | ')   # sentence as list of words
            if idx == sent_idx:                    # reached the requested sentence
                return words, true_class
            idx +=1
    raise ValueError("no test sentence with index %r" % (sent_idx,))

def predict(words):
    """Returns the classifier's predicted class"""
    net                 = LSTM_bidi()                                   # load trained LSTM model (defined in LSTM_bidi.py)
    w_indices           = [net.voc.index(w) for w in words]             # convert input sentence to word IDs
    printvar(w_indices, True)                                           # debug print of the word IDs
    net.set_input(w_indices)                                            # set LSTM input sequence
    scores              = net.forward()                                 # classification prediction scores, one per sentiment class
    print("hz- len(scores):%s scores: %s" % (len(scores), scores))
    printvar(net.get_para(), True)
    return np.argmax(scores)                                            # index of the highest score among the five sentiment classes

As an input sequence, either select a sentence from the Stanford Sentiment Treebank (SST) test set, or define your own sequence.
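
A custom sequence may only contain vocabulary words, so it can help to check it before running the model. The sketch below assumes that `net.voc` is a plain Python list, as the `net.voc.index(w)` calls in this notebook suggest; in that case `.index` raises a `ValueError` for an unknown word. The helper `split_by_vocabulary` and the toy vocabulary are hypothetical and only stand in for the real model vocabulary.

```python
def split_by_vocabulary(words, vocabulary):
    """Split words into (known, unknown) with respect to a vocabulary list."""
    vocab_set = set(vocabulary)  # set lookup avoids a linear scan per word
    known = [w for w in words if w in vocab_set]
    unknown = [w for w in words if w not in vocab_set]
    return known, unknown

# Toy stand-in for the model vocabulary (hypothetical, for illustration only).
toy_vocabulary = ["neither", "funny", "nor", "."]
known, unknown = split_by_vocabulary(["neither", "funny", "nor", "witty", "."], toy_vocabulary)
print("known words:  ", known)
print("unknown words:", unknown)
```

With the real model, passing `net.voc` as the second argument would list the words that need to be replaced before calling `predict(words)`.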

In [3]:
# Sentence 291 in "./data/sequence_test.txt" (sentence number, label, words):
#  291 1       neither | funny | nor | suspenseful | nor | particularly | well-drawn | . 
words, _ = get_test_sentence(291)                                       # SST test set sentence number 291
printvar(words, True)

hz- varname:default_varname varshape:(8,)  varvalue:['neither', 'funny', 'nor', 'suspenseful', 'nor', 'particularly', 'well-drawn', '.']


In [4]:
# Alternatively, uncomment one of the following sentences, or define your own sequence (only words contained in the vocabulary are supported!)
#words = ['this','movie','was','actually','neither','that','funny',',','nor','super','witty','.']
#words = ['this', 'film', 'does', 'n\'t', 'care', 'about', 'cleverness', ',', 'wit', 'or', 'any', 'other', 'kind', 'of', 'intelligent', 'humor', '.']
#words = ['i','hate','the','movie','though','the','plot','is','interesting','.']
#words = ['used', 'to', 'be', 'my', 'favorite']
#words = ['not', 'worth', 'the', 'time']
#words = ['is', 'n\'t', 'a', 'bad', 'film'] # Note: misclassified sample!
#words = ['is', 'n\'t', 'very', 'interesting'] 
#words = ['it', '\'s', 'easy' ,'to' ,'love' ,'robin' ,'tunney' ,'--' ,'she' ,'\'s' ,'pretty' ,'and' ,'she' ,'can' ,'act' ,'--' ,'but' ,'it' ,'gets' ,'harder' ,'and' ,'harder' ,'to' ,'understand' ,'her' ,'choices', '.']

To understand why a single sample is classified or misclassified, we strongly **recommend using the classifier's *predicted* class as the relevance *target* class**: it is the class the model is most confident about, so this setup reflects the classifier's "point of view" on the test sample most accurately.
(More generally, it is possible to choose any class as the relevance *target* class.)
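
As a small illustration of this choice, the predicted class is the index of the highest prediction score. The sketch below uses the five scores that this notebook prints for the example sentence; the variable names `ranked` and `margin` are illustrative additions.

```python
import numpy as np

# Prediction scores copied from this notebook's output for the example sentence.
scores = np.array([2.73149687, 2.7249559, 0.80547211, -1.5359282, -4.6083298])

predicted_class = int(np.argmax(scores))         # index of the highest score
ranked = np.argsort(scores)[::-1]                # class indices, best first
margin = scores[ranked[0]] - scores[ranked[1]]   # gap between the two top classes

target_class = predicted_class                   # recommended relevance target
print("predicted class:", predicted_class)
print("runner-up class:", int(ranked[1]))
print("margin: %.4f" % margin)
```

A small margin, as here between classes 0 and 1, is a hint that the relevances for the runner-up class may also be worth inspecting.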

In [5]:
predicted_class = predict(words)   # get predicted class
target_class    = predicted_class  # define relevance target class 

hz- varname:default_varname varshape:(8,)  varvalue:[7931, 984, 4481, 3890, 4481, 5120, 14520, 32]
set_input   self.E.shape: (19538, 60) 
hz- varname:default_varname varshape:(8, 60)  varvalue:[[ 0.43077913 -0.18169117  0.01742873  0.20541596 -0.03969318  0.10358565
   0.05842264  0.15177514 -0.05993763  0.03649869  0.45181909 -0.13671359
  -0.5562824  -0.11577738  0.01810146 -0.10536947  0.07183728  0.05874278
   0.43677098 -0.16668737  0.01021575  0.05963009 -0.03090432 -0.13102242
   0.18614751  0.13063452  0.04894596  0.04302438  0.33930963  0.36994925
  -0.01409349  0.01236203 -0.01782743  0.04562677  0.08322706 -0.09145668
   0.03557287  0.05164042  0.02314351  0.08243123  0.12775126  0.06966085
   0.12123352  0.36728483  0.48453709 -0.1516519   0.03316482 -0.0123687
  -0.32789484  0.02922312 -0.29082173  0.09756328  0.16902158 -0.19181284
  -0.19844265 -0.19888473  0.05569024  0.18603702 -0.48650596 -0.17591265]
 [-0.05923684  0.17109026 -0.27417892  0.20940925 -0.09679801 -0.03

In [6]:
print (words)
print ("\npredicted class:          ",   predicted_class)

['neither', 'funny', 'nor', 'suspenseful', 'nor', 'particularly', 'well-drawn', '.']

predicted class:           0


# Compute LRP relevances
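
To give an intuition for the `eps` and `bias_factor` hyperparameters, here is a minimal sketch of an epsilon-stabilized relevance redistribution through a single linear layer, in the spirit of the papers cited at the top. It is a toy illustration under assumptions: the function `lrp_linear_sketch` and the random inputs are hypothetical, and the model's `net.lrp` may differ in detail.

```python
import numpy as np

def lrp_linear_sketch(x, w, b, z_out, r_out, eps, bias_factor):
    """Redistribute the relevance r_out of a linear layer z_out = x @ w + b onto x.

    x: (D,) inputs, w: (D, M) weights, b: (M,) biases,
    z_out: (M,) pre-activations, r_out: (M,) output relevances.
    """
    n_inputs = x.shape[0]
    sign_out = np.where(z_out >= 0, 1.0, -1.0)        # stabilizer sign, +1 at zero
    # Each input's contribution, plus an equal share of bias and stabilizer if bias_factor > 0.
    numer = w * x[:, None] + bias_factor * (b + eps * sign_out)[None, :] / n_inputs  # (D, M)
    denom = (z_out + eps * sign_out)[None, :]         # (1, M), never exactly zero
    return (numer / denom * r_out[None, :]).sum(axis=1)  # (D,)

rng = np.random.default_rng(0)
x = rng.normal(size=4)
w = rng.normal(size=(4, 3))
b = rng.normal(size=3)
z_out = x @ w + b
r_out = z_out.copy()                                  # toy choice: relevance = activation

r_no_bias = lrp_linear_sketch(x, w, b, z_out, r_out, eps=0.001, bias_factor=0.0)
r_with_bias = lrp_linear_sketch(x, w, b, z_out, r_out, eps=0.001, bias_factor=1.0)
print("sum of output relevances:    ", r_out.sum())
print("input sum, bias_factor = 0.0:", r_no_bias.sum())
print("input sum, bias_factor = 1.0:", r_with_bias.sum())
```

In this sketch, `bias_factor = 1.0` hands the bias and stabilizer shares back to the inputs, so the input relevances add up to the output relevances; with `bias_factor = 0.0` that share is dropped and the sums generally differ.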

In [7]:
# LRP hyperparameters:
eps                 = 0.001                                             # small positive number
bias_factor         = 0.0                                               # recommended value
 
net                 = LSTM_bidi()                                       # load trained LSTM model

#convert input sentence to word IDs..words=['neither','funny','nor','suspenseful','nor','particularly','well-drawn','.']
w_indices           = [net.voc.index(w) for w in words]
Rx, Rx_rev, R_rest  = net.lrp(w_indices, target_class, eps, bias_factor)# perform LRP
R_words             = np.sum(Rx + Rx_rev, axis=1)                       # compute word-level LRP relevances (R stands for relevance)
scores              = net.s.copy()                                      # classification prediction scores


In [8]:
print ("prediction scores:        ",   scores)
print ("\nLRP target class:         ", target_class)
print ("\nLRP relevances:")
for idx, w in enumerate(words):
    print ("\t\t\t" + "{:8.2f}".format(R_words[idx]) + "\t" + w)
print ("\nLRP heatmap below:")    
display(HTML(html_heatmap(words, R_words)))  # display and HTML come from IPython.display, html_heatmap() from heatmap.py

prediction scores:         [ 2.73149687  2.7249559   0.80547211 -1.5359282  -4.6083298 ]

LRP target class:          0

LRP relevances:
			    1.86	neither
			   -1.58	funny
			    1.50	nor
			   -1.54	suspenseful
			    2.00	nor
			   -0.04	particularly
			   -0.06	well-drawn
			   -0.12	.

LRP heatmap below:


In [9]:
# How to sanity check global relevance conservation:
bias_factor        = 1.0                                             # value to use for sanity check
Rx, Rx_rev, R_rest = net.lrp(w_indices, target_class, eps, bias_factor)
R_tot              = Rx.sum() + Rx_rev.sum() + R_rest.sum()          # sum of all "input" relevances

printvar(Rx,False)
printvar(Rx_rev, False)
printvar(R_rest, False)
print(R_tot)
print("Sanity check passed? ", np.allclose(R_tot, net.s[target_class]))
# np.allclose() checks whether two arrays are element-wise equal within a tolerance (defaults: rtol=1e-05, atol=1e-08)

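
The sanity check compares two floating-point numbers with `np.allclose` rather than `==`. The short sketch below shows why; it is independent of the model, and the perturbed values are made up for illustration.

```python
import numpy as np

# Floating-point sums rarely match exactly, even when they are equal on paper.
exact_equal = (0.1 + 0.2 == 0.3)
close_equal = bool(np.allclose(0.1 + 0.2, 0.3))
print(exact_equal, close_equal)

# Default tolerance: |a - b| <= atol + rtol * |b|, with rtol=1e-05 and atol=1e-08.
score = 2.73149687                                   # target-class score from this notebook's output
within_tol = bool(np.allclose(score + 1e-7, score))  # difference well below the tolerance
outside_tol = bool(np.allclose(score + 1e-3, score)) # difference well above the tolerance
print(within_tol, outside_tol)
```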

# Compute SA/GI relevances
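
Both scores are derived from the gradient of the target-class score with respect to each word's embedding: SA takes the squared L2 norm of that gradient, GI takes its dot product with the embedding. The sketch below reproduces the two formulas from the next cell on random toy arrays; `x` and `grad` are hypothetical stand-ins for `net.x` and `Gx + Gx_rev`.

```python
import numpy as np

rng = np.random.default_rng(0)
T, D = 3, 4                       # toy sizes: 3 words, 4 embedding dimensions
x = rng.normal(size=(T, D))       # stand-in for the word embeddings
grad = rng.normal(size=(T, D))    # stand-in for d(score)/d(embedding)

r_sa = np.linalg.norm(grad, ord=2, axis=1) ** 2   # squared L2 norm per word
r_gi = (grad * x).sum(axis=1)                     # gradient-embedding dot product per word

print("SA:", r_sa)
print("GI:", r_gi)
```

By construction SA is never negative, so it indicates how sensitive the score is to a word but not whether the word speaks for or against the target class, whereas GI is signed.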

In [10]:
net              = LSTM_bidi()                                       # load trained LSTM model

w_indices        = [net.voc.index(w) for w in words]                 # convert input sentence to word IDs
Gx, Gx_rev       = net.backward(w_indices, target_class)             # perform gradient backpropagation
R_words_SA       = (np.linalg.norm(Gx + Gx_rev, ord=2, axis=1))**2   # compute word-level Sensitivity Analysis relevances
R_words_GI       = ((Gx + Gx_rev)*net.x).sum(axis=1)                 # compute word-level GradientxInput relevances

scores           = net.s.copy()                                      # classification prediction scores 


In [11]:
print ("prediction scores:       ",   scores)
print ("\nSA/GI target class:      ", target_class)
print ("\nSA relevances:")
for idx, w in enumerate(words):
    print ("\t\t\t" + "{:8.2f}".format(R_words_SA[idx]) + "\t" + w)
print ("\nSA heatmap:")    
display(HTML(html_heatmap(words, R_words_SA)))
print ("\nGI relevances:")
for idx, w in enumerate(words):
    print ("\t\t\t" + "{:8.2f}".format(R_words_GI[idx]) + "\t" + w)
print ("\nGI heatmap:")    
display(HTML(html_heatmap(words, R_words_GI)))

prediction scores:        [ 2.73149687  2.7249559   0.80547211 -1.5359282  -4.6083298 ]

SA/GI target class:       0

SA relevances:
			    5.01	neither
			    0.35	funny
			    0.73	nor
			    0.92	suspenseful
			    1.66	nor
			    0.13	particularly
			    0.66	well-drawn
			    0.32	.

SA heatmap:



GI relevances:
			    0.03	neither
			    0.06	funny
			   -0.11	nor
			   -0.19	suspenseful
			   -0.19	nor
			   -0.07	particularly
			   -0.06	well-drawn
			    0.03	.

GI heatmap:
