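# BERT Sentiment Test

Loads a BERT sentiment classifier, runs a normal forward pass on one example sentence, and then uses LRP to attribute the "positive" score to the individual input tokens.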

In [1]:
import sys
# Extend the import search path two directories up from the notebook's working
# directory. Presumably this is the repository root, so that the project-local
# `modules` and `visualization` packages imported below resolve -- TODO confirm.
sys.path.append("../..")

In [2]:
import numpy as np
import torch
from IPython.core.display import display, HTML
from transformers import BertTokenizer

from modules.lrp_bert_modules import LRPBertForSequenceClassification
from visualization.heatmap import html_heatmap

Load model

In [3]:
print("Loading model...")
config_path = "bert-sst-config.pt"
state_dict_path = "bert-sst.pt"

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# NOTE(review): torch.load unpickles arbitrary Python objects; only point these
# paths at checkpoint files you trust. The config file holds a pickled object
# rather than plain tensors, so on recent PyTorch releases (where
# weights_only=True is the default) it may need weights_only=False -- confirm
# against the installed version.
#
# map_location="cpu" lets checkpoints that were saved from a GPU load on
# CPU-only machines. The rest of the notebook feeds CPU tensors from the
# tokenizer straight into the model, so CPU is the device it is run on.
model = LRPBertForSequenceClassification(
    torch.load(config_path, map_location="cpu")
)
model.load_state_dict(torch.load(state_dict_path, map_location="cpu"))
model.eval()
print("Done.")

Loading model...
Done.


Set up example

In [4]:
# Sentence that is classified and then explained in the cells below.
test_example = (
    "It's a lovely film with wonderful performances by Buy and Accorsi."
)

Run normal forward pass

In [5]:
# Re-assert eval mode so this cell also works when re-run after a later cell
# has called model.attr() (presumably attr() switches the forward pass into
# attribution mode -- confirm in modules.lrp_bert_modules).
model.eval()
inputs = tokenizer(test_example, return_tensors="pt")

# Plain inference: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    logits = model(**inputs).logits.squeeze()

# Label order used to name the four logits printed below.
classes = ["<unk>", "positive", "negative", "neutral"]
print("Logit Scores:")
for c, score in zip(classes, logits):
    print("{}: {}".format(c, score))

Logit Scores:
<unk>: -1.5717359781265259
positive: 4.663201332092285
negative: -2.0962235927581787
neutral: 0.019476696848869324


Run attribution forward pass

In [6]:
# Call model.attr() before the forward pass; afterwards the model call returns
# the raw scores directly (see the printed output) rather than an object with
# a `.logits` attribute.
model.attr()
inputs = tokenizer(test_example, return_tensors="pt")
output = model(**inputs)

print("Attr Forward Pass Output:", output, sep="\n")

Attr Forward Pass Output:
[[-1.5717361   4.6632013  -2.0962234   0.01947674]]


Run LRP (layer-wise relevance propagation) for the "positive" class

In [7]:
# Index of the class whose score is explained; 1 is "positive" in `classes`.
TARGET_CLASS = 1
# Value passed as `eps` to attr_backward (presumably the LRP-epsilon
# stabiliser -- confirm in modules.lrp_bert_modules).
LRP_EPS = .1


def _token_relevance(rel):
    """Reduce one relevance array to a single score per word-piece token.

    Takes the first batch element, drops the first and last sequence
    positions (presumably the [CLS] and [SEP] tokens added by the tokenizer),
    and sums over the last axis.
    """
    return np.sum(rel[0, 1:-1], -1)


tokens = tokenizer.tokenize(test_example)

# Seed the backward pass with the target class score only; every other output
# unit starts with zero relevance.
rel_y = np.zeros(output.shape)
rel_y[:, TARGET_CLASS] = output[:, TARGET_CLASS]
rel_word, rel_pos, rel_type, rel_embed = model.attr_backward(rel_y,
                                                             eps=LRP_EPS)
rel_word = _token_relevance(rel_word)
rel_pos = _token_relevance(rel_pos)
rel_type = _token_relevance(rel_type)
rel_embed = _token_relevance(rel_embed)

# zip() below would silently truncate on a mismatch, so fail loudly instead if
# the tokens and the relevance scores are not aligned one-to-one.
assert len(tokens) == len(rel_embed), \
    "token / relevance length mismatch: {} vs {}".format(len(tokens),
                                                         len(rel_embed))

print("LRP Scores:")
for t, s in zip(tokens, rel_embed):
    print(t, s, sep=": ")

# One heatmap per embedding component, followed by the combined embedding.
for kind, scores in (("word", rel_word),
                     ("positional", rel_pos),
                     ("type", rel_type),
                     ("combined", rel_embed)):
    print("Relevance of {} embeddings:".format(kind))
    display(HTML(html_heatmap(tokens, list(scores))))

LRP Scores:
it: -0.009148062355514672
': 0.00985649027484873
s: -0.0747943180514573
a: 0.04013702449254106
lovely: 0.48535738153882185
film: 0.06023086501874836
with: -0.08558094113996871
wonderful: 0.6364255271844168
performances: -0.06974794971072315
by: 0.16045594305917155
buy: 0.05446581661034733
and: 0.010504063357492247
acc: -0.003813851857354434
##ors: 0.04295817040922448
##i: -0.030305667038154877
.: 0.06092527771290372
Relevance of word embeddings:


Relevance of positional embeddings:


Relevance of type embeddings:


Relevance of combined embeddings:
