# Generating White-Box Heatmaps

This notebook shows how to generate the heatmaps that appear in the paper.

You will need to import a white-box network, an attribution method, and the function `html_heatmap`.

In [1]:
from models.whitebox import CounterRNN
from attribution import IGAttribution, LRPAttribution
from attribution.src.heatmap import html_heatmap
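# `display` and `HTML` are part of the public `IPython.display` module;
# importing them from `IPython.core.display` is deprecated.
from IPython.display import display, HTML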

Attribution scores are produced using attribution objects, which are initialized with a model.

In [2]:
model = CounterRNN()
ig = IGAttribution(model)
lrp = LRPAttribution(model)

Call an attribution object directly on an input string to compute its attribution scores, then pass the string and the scores to `html_heatmap` to render the heatmap.

In [3]:
ig_scores = ig("aaabb")
lrp_scores = lrp("aaabb")

display(HTML(html_heatmap("aaabb", ig_scores)))
display(HTML(html_heatmap("aaabb", lrp_scores)))
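
To make the idea of a score heatmap concrete, here is a minimal, self-contained sketch of how per-character scores could be turned into colored HTML. It is not the repository's `html_heatmap`: the function name `sketch_heatmap`, the color scheme, and the example scores are all hypothetical, and it assumes the scores form a flat sequence with one number per character.

```python
import html


def sketch_heatmap(tokens, scores):
    """Hypothetical stand-in for html_heatmap: one colored <span> per token."""
    if len(tokens) != len(scores):
        raise ValueError("need exactly one score per token")
    # Scale by the largest magnitude so that opacities fall in [0, 1];
    # fall back to 1.0 when every score is zero (or the input is empty).
    max_abs = max((abs(s) for s in scores), default=0.0) or 1.0
    spans = []
    for token, score in zip(tokens, scores):
        alpha = abs(score) / max_abs
        # Red for non-negative scores, blue for negative (an arbitrary choice).
        rgb = "255, 0, 0" if score >= 0 else "0, 0, 255"
        spans.append(
            f'<span style="background-color: rgba({rgb}, {alpha:.2f})">'
            f"{html.escape(token)}</span>"
        )
    return "".join(spans)


# Made-up scores for illustration only, not real attribution output.
example = sketch_heatmap("aaabb", [0.5, 0.5, 0.5, -1.0, -1.0])
```

In a notebook, `display(HTML(example))` would render the result the same way as the cells above. Scaling by the largest absolute score keeps the strongest contribution fully opaque, which makes heatmaps from different attribution methods easier to compare by eye.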

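To attribute with respect to a specific output class, pass its index as the `target` keyword argument. In the cell below, index 3 is the `True` class and index 2 is the `False` class of `model.y_stoi`.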

In [4]:
ig_scores = ig("aaabb", target=3)
lrp_scores = lrp("aaabb", target=2)

display(HTML(html_heatmap("aaabb", ig_scores)))
display(HTML(html_heatmap("aaabb", lrp_scores)))

Use `model.y_stoi` to see the index of each output class and `model.x_stoi` to see the one-hot vector index of each input symbol.

In [5]:
model.y_stoi

defaultdict(<bound method Vocab._default_unk_index of <torchtext.vocab.Vocab object at 0x7fa526be5a60>>,
            {'<unk>': 0, '<pad>': 1, 'False': 2, 'True': 3})
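
Looking up a class index by label avoids hard-coding numbers such as 2 and 3. The sketch below uses a plain dictionary copied from the output above so that it runs on its own; the names `y_itos`, `true_target`, and `false_target` are illustrative and not part of the repository.

```python
# Copied from the `model.y_stoi` output shown above (label -> class index).
y_stoi = {"<unk>": 0, "<pad>": 1, "False": 2, "True": 3}

# Invert the mapping to translate a class index back into its label.
y_itos = {index: label for label, index in y_stoi.items()}

# Look up target indices by label instead of hard-coding them.
true_target = y_stoi["True"]
false_target = y_stoi["False"]
```

With the real model, the same lookup could be written as `ig("aaabb", target=model.y_stoi["True"])`.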

The same steps work for other white-box networks. Here is an example using `BracketRNN`.

In [6]:
from models.whitebox import BracketRNN

In [7]:
bracket_model = BracketRNN(50)
bracket_lrp = LRPAttribution(bracket_model)

In [8]:
lrp_scores = bracket_lrp("[[(()")
display(HTML(html_heatmap("[[(()", lrp_scores)))