# Attributions

## Model Wrappers

In order to support a wide variety of backends with different interfaces for their respective models, TruLens uses its own `ModelWrapper` class, which provides a general model interface that simplifies the implementation of the API functions.
To get a model wrapper, use the `get_model_wrapper` function in `trulens.nn.models`. A model wrapper class exists for each backend; it converts a model in that backend's format to the general TruLens `ModelWrapper` interface. The wrappers are found in the `models` module, and any model defined using Keras, PyTorch, or TensorFlow should be wrapped before being passed to API functions that require a model: all other TruLens functionality expects models to be instances of `trulens.nn.models.ModelWrapper`.

For more details on allowed parameters, see the [get_model_wrapper](https://truera.github.io/trulens/api/model_wrappers/) documentation.

For this demo, we will be using a PyTorch model pre-trained on ImageNet.

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import PIL

from trulens.nn.models import get_model_wrapper

%matplotlib inline

In [263]:
import random
from pathlib import Path
from typing import List

import pandas as pd
import torch
import torch.nn as nn
from torchtext.data import get_tokenizer

class ToySentiment(nn.Module):
    @staticmethod
    def from_pretrained(model_code_path: Path) -> 'ToySentiment':
        # Loading and saving are not strictly necessary since set_parameters() fixes every
        # weight, but they are implemented for better integration testing.
        model = ToySentiment()
        model.load_state_dict(torch.load(model_code_path))
        return model

    def to_pretrained(self, model_code_path: Path) -> None:
        torch.save(self.state_dict(), model_code_path)

    def set_parameters(self) -> None:
        """Set model parameters as per fixed specification."""

        Wi = torch.zeros_like(self.lstm.weight_ih_l0) + 0.1
        bi = torch.zeros_like(self.lstm.bias_ih_l0) + 0.1
        Wh = torch.zeros_like(self.lstm.weight_hh_l0) + 0.1
        bh = torch.zeros_like(self.lstm.bias_hh_l0) + 0.1

        big = 20.0 # Multipliers to help dealing with LSTM sigmoids.
        half = 8.0 # Intention here is that sigmoid((x*big) - half) is ~0 if x is ~0; and
                    # ~1 when indicator is >~ 1.

        hs = self.hidden_size

        sneutral = 0
        sgood = 1
        sbad = 2
        sconfused = 3

        wneutral = self.vocab['neutral']
        wgood = self.vocab['good']
        wbad = self.vocab['bad']
        
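        # PyTorch packs the LSTM gates as (input, forget, cell, output). Saturate the first three
        # so the cell state keeps growing; h = o * tanh(c) is then governed mainly by the output gate.
        bi[0:hs*3] = 100.0
        bh[0:hs*3] = 100.0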

        # o gate weights:
        Wi[3*hs,   wneutral] = big # read neutral word
        Wi[3*hs+1, wgood] = big # read good word
        Wi[3*hs+2, wbad] = big # read bad word
        Wh[3*hs,   sneutral] = big # keep prior neutral, good, bad states
        Wh[3*hs+1, sgood] = big # 
        Wh[3*hs+2, sbad] = big #
        bi[3*hs:4*hs] = -half # sigmoid will be 0 unless one of the three words was read

        # set "good to bad" confused if prior was good, and input was bad
        Wh[3*hs+3, sgood] = big    # (prior state was good
        Wi[3*hs+3, wbad] = big    #  and input was bad)
        Wh[3*hs+3, sconfused] = 2*big  # or (was already in this confused state)
        bh[3*hs+3] = -(half*2) # Want at least 2 of first two to fire, or just the last one to fire.

        # set "bad to good" confused if prior was bad, and input was good
        Wh[3*hs+4, sbad] = big     # (prior state was bad
        Wi[3*hs+4, wgood] = big     #  and input was good)
        Wh[3*hs+4, sconfused] = 2*big   # or (was already confused)
        bh[3*hs+4] = -(half*2)  #

        self.lstm.weight_hh_l0 = nn.Parameter(Wh)
        self.lstm.bias_hh_l0 = nn.Parameter(bh)
        self.lstm.weight_ih_l0 = nn.Parameter(Wi)
        self.lstm.bias_ih_l0 = nn.Parameter(bi)

        self.embedding.weight = nn.Parameter(torch.eye(self.emb_size))

        self.lin.weight = nn.Parameter(torch.tensor(
            [
                [10.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 20.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 20.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 30.0, 30.0]
            ]
        ))

    def __init__(self):
        super().__init__()
        
        self.tokenizer = get_tokenizer("basic_english")

        self.vocab = {"[UNKNOWN]": 0, "neutral": 1, "good": 2, "bad": 3}

        self.emb_size = len(self.vocab)

        self.hidden_size = 5
        # 5 states, one for neutral, one for positive, one for negative, and two for confused. Requiring two
        # confused states for simplicity of the model; it is easier to encode semantics of confusion based on 
        # which initial positive/negative state was set first.

        # Identity embedding, each vocab word has its own dimension where its presence is encoded.
        self.embedding = nn.Embedding(
            # padding_idx=0, 
            embedding_dim=self.emb_size, 
            num_embeddings=self.emb_size,
            padding_idx=None,
            max_norm=None,
            norm_type=None
        )

        self.lstm = nn.LSTM(
            input_size=self.emb_size,
            hidden_size=self.hidden_size,
            num_layers=1,
            batch_first=True
        )

        # Linear layer to combine the two types of confused state and weight things so that
        # confused outweighs positive and negative, while positive and negative outweigh neutral 
        # if more than one of these states is set.
        self.lin = torch.nn.Linear(
            in_features=self.hidden_size,
            out_features=4,
            bias=False
        )

        # self.sigmoid = torch.nn.Sigmoid()

        # Finally add a softmax for classification.
        self.softmax = torch.nn.Softmax(dim=1)
        # self.softmax = torch.nn.LogSoftmax(dim=1)

        self.set_parameters()

    def forward(self, word_ids: torch.Tensor, embeds=None) -> torch.Tensor:
        if word_ids is not None:
            batch_size = word_ids.shape[0]
        else:
            batch_size = embeds.shape[0]

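        h0 = torch.zeros(1, batch_size, self.hidden_size)
        h0[0, :, 0] = 1.0 # initial state is neutral for every sequence in the batch
        c0 = torch.zeros(1, batch_size, self.hidden_size)

        if embeds is None:
            embeds = self.embedding(word_ids)

        # Keep the gradient w.r.t. the embeddings for manual inspection (only possible when
        # the embeddings are part of the autograd graph).
        if embeds.requires_grad:
            embeds.retain_grad()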

        out, (hn, cn) = self.lstm(embeds, (h0, c0))
        hn = hn[0,:,:]

        lin_out = self.lin(hn)

        probits = self.softmax(lin_out)
        # preds = torch.argmax(probits, axis=1)
        # pred_prob = torch.gather(probits, dim=1, index=preds[:,None])
        # pred_prob.backward(torch.ones_like(pred_prob))
        # attr = embeds.grad
        # return preds, probits, attr

        return probits


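    def input_of_text(self, texts: List[str]) -> torch.Tensor:
        # All texts must have the same number of tokens; no padding is performed.
        tokens = [self.tokenizer(text) for text in texts]

        # Out-of-vocabulary tokens map to the [UNKNOWN] id 0.
        word_ids = [[self.vocab.get(t, 0) for t in token] for token in tokens]

        return torch.tensor(word_ids, dtype=torch.int32)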


def generate_dataset(n, l):
    """Generate random sentiment sentences and their labels."""

    ret = []
    cls = []
    for i in range(n):
        sent = []
        while len(sent) < l:
            r = random.random()

            word = "neutral"

            if r > 0.9 and len(sent) > 0:
                continue
            elif r > 0.8:
                word = "good"
            elif r > 0.7:
                word = "bad"

            sent.append(word)

        ret.append(" ".join(sent))

        gt = 0 # neutral
        if "good" in sent and "bad" in sent:
            gt = 3 # confused
        elif "good" in sent:
            gt = 1 # positive
        elif "bad" in sent:
            gt = 2 # negative

        cls.append(gt)

    return pd.DataFrame(dict(sentence=ret, sentiment=cls))
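The hand-set LSTM weights rely on the saturating sigmoid of the output gate. The following idealised sketch evaluates the output-gate pre-activations built from `big` and `half` for a single-word unit and for the "good to bad" confused unit, whose combined bias is `-half` (from `bias_ih`) plus `-2*half` (from `bias_hh`). It is an assumption-laden simplification: it ignores the 0.1 background weights and biases that `set_parameters` leaves elsewhere, and treats the previous hidden state as an exact 0/1 indicator.

```python
import math

BIG = 20.0   # `big` in set_parameters
HALF = 8.0   # `half` in set_parameters


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def word_state_gate(indicator: float) -> float:
    """Output gate of a neutral/good/bad unit: sigmoid(big * x - half)."""
    return sigmoid(BIG * indicator - HALF)


def confused_gate(prior_good: float, input_bad: float, already_confused: float) -> float:
    """Output gate of the "good to bad" confused unit.

    Pre-activation = big*prior_good + big*input_bad + 2*big*already_confused
    plus the combined bias -half (bias_ih) - 2*half (bias_hh).
    """
    z = BIG * prior_good + BIG * input_bad + 2 * BIG * already_confused - 3 * HALF
    return sigmoid(z)


for x in (0.0, 1.0):
    print(f"word unit, indicator={x}: {word_state_gate(x):.6f}")

for args in [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0), (0, 0, 1)]:
    print(f"confused unit, inputs={args}: {confused_gate(*args):.4f}")
```

Under these assumptions, only the conjunction of a prior "good" state and a "bad" input, or an already-confused state, pushes the confused gate past 1/2; a single indicator on its own leaves it near 0.02.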

In [342]:
pt = ToySentiment()
device = 'cpu'
# Produce a wrapped model from the pytorch model.
model = get_model_wrapper(pt, input_shape=(4,), device=device, input_dtype=int)

# `test_data` must already exist; uncomment the next line to generate a fresh random set.
# test_data = generate_dataset(5, 4)

x_sentences = test_data['sentence']
x_tokens = [pt.tokenizer(s) for s in x_sentences]
x_ids = pt.input_of_text(x_sentences)
labels = test_data['sentiment']

# TODO: define "EmbeddedInfluence" for InputInfluence except with an initial embedding layer and initialize internal influence as below:
from trulens.nn.attribution import InternalInfluence
from trulens.nn.distributions import LinearDoi
from trulens.nn.quantities import ClassQoI
from trulens.nn.slices import Cut
from IPython.display import HTML, display
import tabulate

cls_names = {0:"neutral", 1:"positive", 2:"negative", 3:"confused"}

infl = {}
attrs = {}
for cls in [0,1,2,3]:
    infl[cls] = InternalInfluence(
        model,
        cuts=(Cut("embedding"), Cut("softmax")),
        #doi=PointDoi(cut=Cut("embedding")),
        doi=LinearDoi(cut=Cut("embedding"), resolution=100),
        qoi=ClassQoI(cls),
        multiply_activation=False
    )
    attrs[cls] = infl[cls].attributions(x_ids)


preds = pt(x_ids)

data = []

for i, (words, label, pred) in enumerate(zip(x_tokens, labels, preds)):
    data.append([f"<b>GT={cls_names[label]} PRED={cls_names[pred.argmax().item()]}</b>"] + [" ".join(words)])

    for cls in [0,1,2,3]:
        # print(f"{cls_names[cls]} | ", end='')

        row = [cls_names[cls]]

        sent = ""

        for word, attr in zip(words, attrs[cls][i]):
            # Sum over the embedding dimension and clamp so every colour channel stays in [0, 1].
            mag = float(np.clip(attr.sum(), -2.0, 2.0))
            if mag > 0:
                # Positive attribution: shade towards green.
                red = 1.0 - mag * 0.5
                green = 1.0
            else:
                # Negative (or zero) attribution: shade towards red.
                red = 1.0
                green = 1.0 + mag * 0.5

            blue = min(red, green)

            sent += f"<span style='color: rgb({red*255}, {green*255}, {blue*255});'>{word}</span> "

        row.append(sent)
        data.append(row)

        # print()

tab = tabulate.tabulate(data, tablefmt='html')
display(HTML(tab))

INFO: Detected pytorch backend for <class '__main__.ToySentiment'>.
INFO: Using backend Backend.PYTORCH.
INFO: If this seems incorrect, you can force the correct backend by passing the `backend` parameter directly into your get_model_wrapper call.


| Class | Tokens |
|---|---|
| **GT=neutral PRED=neutral** | neutral neutral neutral neutral |
| neutral | neutral neutral neutral neutral |
| positive | neutral neutral neutral neutral |
| negative | neutral neutral neutral neutral |
| confused | neutral neutral neutral neutral |
| **GT=positive PRED=positive** | good neutral neutral neutral |
| neutral | good neutral neutral neutral |
| positive | good neutral neutral neutral |
| negative | good neutral neutral neutral |
| confused | good neutral neutral neutral |
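The colouring logic in the attribution loop can be factored into a small helper so that the mapping is easier to test. This is a sketch that follows the loop's arithmetic; the names `attribution_to_rgb` and `colour_span` are hypothetical. It clamps the summed attribution to [-2, 2] so every channel stays within [0, 255]: positive values shade towards green, negative towards red, and zero stays white.

```python
from typing import Tuple


def attribution_to_rgb(mag: float) -> Tuple[int, int, int]:
    """Map a summed token attribution to an (r, g, b) colour in 0..255."""
    mag = max(-2.0, min(2.0, mag))  # keep channels within [0, 1] before scaling
    if mag > 0:
        red, green = 1.0 - mag * 0.5, 1.0
    else:
        red, green = 1.0, 1.0 + mag * 0.5
    blue = min(red, green)
    return round(red * 255), round(green * 255), round(blue * 255)


def colour_span(word: str, mag: float) -> str:
    """Wrap a token in an HTML span coloured by its attribution."""
    r, g, b = attribution_to_rgb(mag)
    return f"<span style='color: rgb({r}, {g}, {b});'>{word}</span>"


print(colour_span("good", 0.81))
```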


In [295]:
# Five hand-written probe sentences: all neutral, then "good" in each position.
# Building the frame directly avoids pandas' chained-assignment (SettingWithCopy) pitfalls.
test_data = pd.DataFrame(dict(
    sentence=[
        "neutral neutral neutral neutral",
        "good neutral neutral neutral",
        "neutral good neutral neutral",
        "neutral neutral good neutral",
        "neutral neutral neutral good",
    ],
    sentiment=[0, 1, 1, 1, 1],
))

test_data


| | sentence | sentiment |
|---|---|---|
| 0 | neutral neutral neutral neutral | 0 |
| 1 | good neutral neutral neutral | 1 |
| 2 | neutral good neutral neutral | 1 |
| 3 | neutral neutral good neutral | 1 |
| 4 | neutral neutral neutral good | 1 |
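The `sentiment` column follows the labelling rule in `generate_dataset`. A standalone version of that rule (the name `label_of` is ours, not part of the notebook) makes it easy to check hand-written probe sentences such as these five before attributing them.

```python
NEUTRAL, POSITIVE, NEGATIVE, CONFUSED = 0, 1, 2, 3


def label_of(sentence: str) -> int:
    """Ground-truth class using the same rule as `generate_dataset`."""
    words = sentence.split()
    if "good" in words and "bad" in words:
        return CONFUSED
    if "good" in words:
        return POSITIVE
    if "bad" in words:
        return NEGATIVE
    return NEUTRAL


probes = [
    "neutral neutral neutral neutral",
    "good neutral neutral neutral",
    "neutral good neutral neutral",
    "neutral neutral good neutral",
    "neutral neutral neutral good",
]
print([label_of(s) for s in probes])
```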


In [240]:
test_data = generate_dataset(4, 2)
pm = ToySentiment()
pm.requires_grad_(True)
pm.train()

word_ids = pm.input_of_text(test_data['sentence'])
# embeds = pm.embedding(word_ids)
# embeds.retain_grad()

probits = pm(word_ids=word_ids)

print(probits)

tensor([[9.9980e-01, 6.5042e-05, 6.5042e-05, 6.5042e-05],
        [9.9980e-01, 6.5042e-05, 6.5042e-05, 6.5042e-05],
        [4.2322e-09, 1.0000e+00, 4.2322e-09, 4.2322e-09],
        [6.5051e-05, 4.2319e-09, 9.9993e-01, 4.2319e-09]],
       grad_fn=<SoftmaxBackward0>)


In [180]:
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer, DistilBertForSequenceClassification

# Initializing a DistilBERT configuration
# configuration = DistilBertConfig()

# Initializing a model from the configuration
#bert = DistilBertModel(configuration)
bert = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")

# Accessing the model configuration
# configuration = bert.config

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias', 'pre_classifier

In [181]:
bt = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

In [182]:
sentences = [["this is a sentence"], ["and so is this"]]
#ins = [bt(s) for s in sentences]
#input_ids = torch.Tensor([i['input_ids'] for i in ins]).long()
#input_ids
inputs = bt("this is a sentence", return_tensors="pt")
# bt(sentences)
inputs

{'input_ids': tensor([[ 101, 2023, 2003, 1037, 6251,  102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [183]:
bert.distilbert.embeddings.word_embeddings.weight

Parameter containing:
tensor([[-0.0166, -0.0666, -0.0163,  ..., -0.0200, -0.0514, -0.0264],
        [-0.0132, -0.0673, -0.0161,  ..., -0.0227, -0.0554, -0.0260],
        [-0.0176, -0.0709, -0.0144,  ..., -0.0246, -0.0596, -0.0232],
        ...,
        [-0.0231, -0.0588, -0.0105,  ..., -0.0195, -0.0262, -0.0212],
        [-0.0490, -0.0561, -0.0047,  ..., -0.0107, -0.0180, -0.0219],
        [-0.0065, -0.0915, -0.0025,  ..., -0.0151, -0.0504,  0.0460]],
       requires_grad=True)

In [184]:
# inputs
bert.eval()
bert.requires_grad_(True)
# bert.retain_grad()
embeds = bert.distilbert.embeddings.word_embeddings(inputs['input_ids'])
embeds.retain_grad()
# embeds
outs = bert(inputs_embeds=embeds)
outs.logits.abs().sum().backward()
attr = embeds.grad

In [185]:
attr

tensor([[[ 2.7252e-05,  4.6484e-03, -3.0422e-02,  ...,  5.2301e-03,
           3.6593e-02,  5.0404e-02],
         [-6.2519e-03,  1.0962e-02, -4.2026e-03,  ...,  1.3291e-03,
           1.5315e-02,  1.0347e-02],
         [-2.8572e-03,  1.0138e-02, -3.9614e-03,  ...,  3.5016e-03,
           1.5171e-02,  9.4054e-03],
         [-5.5624e-03,  9.6869e-03, -3.6593e-03,  ...,  2.0852e-03,
           1.4146e-02,  1.0402e-02],
         [-7.1661e-03,  1.0525e-02, -1.9546e-03,  ...,  7.5274e-03,
           2.0054e-02,  1.2984e-02],
         [ 3.3965e-02,  1.6431e-02, -1.2048e-03,  ...,  5.7613e-02,
           8.7031e-03,  9.5926e-03]]])
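`attr` holds one gradient vector per token, with shape `(batch, tokens, embedding_dim)`. To show a single score per token it must be reduced over the embedding dimension. The sketch below illustrates two common reductions: the signed sum (the `ToySentiment` colouring loop uses `attr.sum()`) and the unsigned L2 norm. It runs on a random NumPy stand-in; for the real tensor, convert first with `attr.detach().numpy()`. The helper name `per_token_scores` is hypothetical.

```python
import numpy as np


def per_token_scores(grads: np.ndarray, reduction: str = "sum") -> np.ndarray:
    """Collapse (batch, tokens, dim) embedding gradients to (batch, tokens)."""
    if reduction == "sum":
        return grads.sum(axis=-1)  # signed: keeps the direction of influence
    if reduction == "l2":
        return np.linalg.norm(grads, axis=-1)  # unsigned: magnitude only
    raise ValueError(f"unknown reduction: {reduction!r}")


rng = np.random.default_rng(0)
fake_grads = rng.normal(size=(1, 6, 768))  # stand-in for attr.detach().numpy()
print(per_token_scores(fake_grads, "sum").shape)
print(per_token_scores(fake_grads, "l2").shape)
```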



In [200]:
model.print_layer_names()

'embedding':	Embedding(4, 4, norm_type=None)
'lstm':	LSTM(4, 5, batch_first=True)
'lin':	Linear(in_features=5, out_features=4, bias=False)
'softmax':	Softmax(dim=1)


In [201]:
m = ToySentiment()
m.requires_grad_(True)

test_data = generate_dataset(10, 4)

for sentence, sentiment in zip(test_data['sentence'], test_data['sentiment']):
    word_ids = m.input_of_text([sentence])
    pred = m(word_ids)
    print(sentence, "\t", word_ids, "\t", "prediction=", pred, "\t", "label=", sentiment)
    # print(attr)

neutral good bad neutral 	 tensor([[1, 2, 3, 1]], dtype=torch.int32) 	 prediction= tensor([[1.9945e-22, 4.3440e-18, 4.3441e-18, 1.0000e+00]],
       grad_fn=<SoftmaxBackward0>) 	 label= 3
neutral neutral neutral neutral 	 tensor([[1, 1, 1, 1]], dtype=torch.int32) 	 prediction= tensor([[9.9986e-01, 4.5707e-05, 4.5707e-05, 4.5699e-05]],
       grad_fn=<SoftmaxBackward0>) 	 label= 0
neutral neutral neutral neutral 	 tensor([[1, 1, 1, 1]], dtype=torch.int32) 	 prediction= tensor([[9.9986e-01, 4.5707e-05, 4.5707e-05, 4.5699e-05]],
       grad_fn=<SoftmaxBackward0>) 	 label= 0
neutral neutral neutral neutral 	 tensor([[1, 1, 1, 1]], dtype=torch.int32) 	 prediction= tensor([[9.9986e-01, 4.5707e-05, 4.5707e-05, 4.5699e-05]],
       grad_fn=<SoftmaxBackward0>) 	 label= 0
neutral neutral good good 	 tensor([[1, 1, 2, 2]], dtype=torch.int32) 	 prediction= tensor([[4.5577e-05, 9.9995e-01, 2.0893e-09, 2.0889e-09]],
       grad_fn=<SoftmaxBackward0>) 	 label= 1
neutral neutral neutral neutral 	 tens

In [202]:
from trulens.nn.attribution import InputAttribution, InternalInfluence
from trulens.nn.attribution import IntegratedGradients
from trulens.nn.distributions import LinearDoi, PointDoi
from trulens.nn.slices import Cut, InputCut, OutputCut

Saliency maps are implemented by the `InputAttribution` class. Its constructor takes several optional arguments, whose meaning we discuss later in this notebook; the provided defaults instantiate an `AttributionMethod` consistent with the original saliency-map method.

The only required argument to the constructor is a `ModelWrapper`. After constructing the attribution method, we call its `attributions` method on our data point and receive an array containing the attributions.

0->0 neutral[-0.003716949839144945] neutral[-0.00554325757548213] neutral[-0.004302475601434708] neutral[-0.009433472529053688] 
0->0 neutral[-0.003716949839144945] neutral[-0.00554325757548213] neutral[-0.004302475601434708] neutral[-0.009433472529053688] 
2->2 neutral[-0.004279387649148703] neutral[-0.006282697431743145] neutral[-0.003931532613933086] bad[-0.16649414598941803] 
1->1 neutral[0.004782388918101788] neutral[0.007022105157375336] neutral[0.004406385123729706] good[0.8117256164550781] 
1->1 neutral[0.004782388918101788] neutral[0.007022105157375336] neutral[0.004406385123729706] good[0.8117256164550781] 
2->2 neutral[-6.288810254773125e-05] neutral[-9.242133819498122e-05] bad[-0.0003415872051846236] neutral[-0.07553716003894806] 
0->0 neutral[-0.003716949839144945] neutral[-0.00554325757548213] neutral[-0.004302475601434708] neutral[-0.009433472529053688] 
0->0 neutral[-0.003716949839144945] neutral[-0.00554325757548213] neutral[-0.004302475601434708] neutral[-0.0094334725
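At its core, a saliency map is the gradient of the explained quantity with respect to the input. The following NumPy sketch is independent of TruLens: it uses a hypothetical softmax-linear model `p = softmax(W @ x)`, computes the analytic gradient of the class-`c` probability with respect to `x`, and checks it against central finite differences.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()


def saliency(W: np.ndarray, x: np.ndarray, c: int) -> np.ndarray:
    """Gradient of softmax(W @ x)[c] with respect to x.

    d p_c / d z_j = p_c * (delta_cj - p_j) and z = W @ x, so by the chain rule
    the gradient is W.T @ (p_c * (onehot_c - p)).
    """
    p = softmax(W @ x)
    onehot = np.eye(len(p))[c]
    return W.T @ (p[c] * (onehot - p))


def finite_difference(W: np.ndarray, x: np.ndarray, c: int, eps: float = 1e-6) -> np.ndarray:
    """Central-difference estimate of the same gradient, for checking."""
    grad = np.zeros_like(x)
    for i in range(len(x)):
        step = np.zeros_like(x)
        step[i] = eps
        grad[i] = (softmax(W @ (x + step))[c] - softmax(W @ (x - step))[c]) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
W = rng.normal(size=(4, 5))  # 4 classes, 5 input features
x = rng.normal(size=5)
print(saliency(W, x, c=1))
```

Because the class probabilities always sum to 1, the saliency maps of all classes sum to zero, which is a quick sanity check on any implementation.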

To visualize the attributions, we can use `MaskVisualizer` from the `visualizations` module. This class takes `blur` and `threshold` arguments and overlays a partially opaque mask on a given image: the attributions are first smoothed with a Gaussian blur of radius `blur`, and only the pixels whose attribution is at or above the `threshold` percentile are revealed (so `threshold=0.95` shows the top 5%).
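The percentile rule can be sketched in NumPy as follows. This illustrates the idea only and is not `MaskVisualizer`'s implementation: it sums attributions over channels, keeps the pixels at or above the `threshold` quantile, and omits the Gaussian blur (which would be applied to the heat map first, e.g. with `scipy.ndimage.gaussian_filter`). The helper name `threshold_mask` is hypothetical.

```python
import numpy as np


def threshold_mask(attrs: np.ndarray, threshold: float = 0.95) -> np.ndarray:
    """Boolean (H, W) mask keeping the top (1 - threshold) fraction of pixels.

    `attrs` holds one image's attributions in (C, H, W) layout; channels are
    summed before the quantile is taken. No blur is applied here.
    """
    heat = attrs.sum(axis=0)
    cutoff = np.quantile(heat, threshold)
    return heat >= cutoff


toy_attrs = np.arange(100, dtype=float).reshape(1, 10, 10)  # toy single-channel map
mask = threshold_mask(toy_attrs, threshold=0.95)
print(mask.sum())
```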

In [None]:
from trulens.visualizations import MaskVisualizer

In [None]:
masked_image = MaskVisualizer(blur=5, threshold=0.95)(attrs_input, x)

Turning to Integrated Gradients, the workflow for obtaining attributions is nearly identical. The only difference is that the `AttributionMethod` instance we construct is an `IntegratedGradients` rather than an `InputAttribution`.

The `resolution` argument, which is optional and defaults to 50, specifies the number of points sampled along the path from the baseline to the input. Larger values approximate the true integral more closely and, in practice, tend to give more stable attributions.

Another optional argument is `baseline`, which specifies the starting point of the path over which gradients are aggregated. By default it is `None`, which is interpreted as an appropriately sized zero tensor; the path is then the straight line segment from this zero tensor to the point for which the attributions are computed.
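To make `resolution` and `baseline` concrete, here is a NumPy sketch of the Riemann-sum approximation behind Integrated Gradients, applied to a toy function f(x) = sum(x**2) with known gradient 2x. It samples `resolution` points at the midpoints of the segment from the baseline to the input (an assumption for illustration; TruLens's own sampling scheme may differ) and checks the completeness property: the attributions sum to f(x) - f(baseline).

```python
from typing import Callable, Optional

import numpy as np


def integrated_gradients(
    grad_fn: Callable[[np.ndarray], np.ndarray],
    x: np.ndarray,
    baseline: Optional[np.ndarray] = None,
    resolution: int = 50,
) -> np.ndarray:
    """Approximate IG with a midpoint Riemann sum over `resolution` points."""
    if baseline is None:
        baseline = np.zeros_like(x)  # default: zero tensor of the same shape
    alphas = (np.arange(resolution) + 0.5) / resolution
    grads = np.stack([grad_fn(baseline + a * (x - baseline)) for a in alphas])
    return (x - baseline) * grads.mean(axis=0)


def f(v: np.ndarray) -> float:
    return float(np.sum(v ** 2))


def grad_f(v: np.ndarray) -> np.ndarray:
    return 2 * v


x = np.array([1.0, -2.0, 3.0])
ig_attr = integrated_gradients(grad_f, x, resolution=10)
print(ig_attr, ig_attr.sum(), f(x) - f(np.zeros_like(x)))
```

For a quadratic the midpoint rule happens to be exact; for a real network the gap between the attribution sum and f(x) - f(baseline) shrinks as `resolution` grows.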

In [None]:
infl = IntegratedGradients(model, resolution=10)
attrs_input = infl.attributions(x_pp)

In this case, the visualization shows Integrated Gradients focusing more closely on the pixels corresponding to the beagle, which is consistent with the model's top predicted class.

In [None]:
masked_image = MaskVisualizer(blur=5, threshold=0.95)(attrs_input, x)

### Discovering Important Internal Neurons 

Next we examine *Internal Influence* (Leino et al.), a powerful and general attribution method that can calculate attributions for internal neurons of a network as well as for its inputs. Internal Influence is implemented by the `InternalInfluence` class in the `attribution` module.

The `InternalInfluence` constructor takes a TruLens `ModelWrapper` and three special parameters: a *slice*, a *quantity of interest* (QoI), and a *distribution of interest* (DoI), which are instances of the `Slice` (in the `slices` module), `QoI` (in the `quantities` module), and `DoI` (in the `distributions` module) classes, respectively.

The slice essentially defines a layer to use for internal attributions. A `Slice` object specifies two `Cut`s corresponding to two layers: (1) the layer of the variables that we are calculating attribution *for* (e.g., the input layer), and (2) the layer whose output defines our quantity of interest (e.g., the output layer; see below for more on quantities of interest).

The shape of the attributions always matches the shape of the first cut. In the case of `InputAttribution`, this is the shape of the input. For neuron explanations, the attributions take the shape of the output or the input of a specific network layer. By default, attributions are computed for a layer's output; the `anchor` argument of `Cut` selects the layer's input instead (`anchor='in'`). See the [Slice](https://truera.github.io/trulens/api/slices/) documentation for more detail.

The quantity of interest (QoI) essentially defines the model behavior we would like to explain using attributions. The QoI is a function of the model's output at some layer; for example, it may select the confidence score for a particular class. In its most general form, the QoI can be specified by an implementation of the `QoI` class in the `quantities` module. Several common default implementations are provided in this module as well.
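Conceptually, a QoI maps a batch of outputs at the chosen layer to one scalar per record. The NumPy functions below are hypothetical stand-ins, not the TruLens classes; they sketch what `ClassQoI(c)` and `MaxClassQoI()` select from a `(batch, classes)` output. TruLens's own classes operate on the backend's tensors so that the selected quantity can be differentiated.

```python
import numpy as np


def class_qoi(outputs: np.ndarray, cls: int) -> np.ndarray:
    """Score of a fixed class for every record (analogous to ClassQoI(cls))."""
    return outputs[:, cls]


def max_class_qoi(outputs: np.ndarray) -> np.ndarray:
    """Score of each record's top class (analogous to MaxClassQoI())."""
    return outputs.max(axis=1)


probits = np.array([
    [0.90, 0.05, 0.03, 0.02],
    [0.10, 0.20, 0.60, 0.10],
])
print(class_qoi(probits, 1), max_class_qoi(probits))
```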

The distribution of interest (DoI) essentially specifies points surrounding each record for which the calculated attribution should be faithful. The distribution can be specified via an implementation of the `DoI` class in the `distributions` module, which is a function taking an input record and producing a list of input points to aggregate attribution over. A few common default distributions implementing the `DoI` class can be found in the `distributions` module.
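Likewise, a DoI can be pictured as a function from one record to the list of points at which gradients are taken and averaged. In the sketch below (hypothetical helpers, not the TruLens API), `point_doi` roughly corresponds to `PointDoi` and `linear_doi` to the straight-line distribution that `LinearDoi` samples; the aggregated attribution is the mean gradient over the returned points.

```python
from typing import Callable, List

import numpy as np


def point_doi(x: np.ndarray) -> List[np.ndarray]:
    """Only the record itself: a purely local explanation."""
    return [x]


def linear_doi(x: np.ndarray, baseline: np.ndarray, resolution: int) -> List[np.ndarray]:
    """Evenly spaced points on the segment from `baseline` to `x`, endpoints included."""
    return [baseline + a * (x - baseline) for a in np.linspace(0.0, 1.0, resolution)]


def average_gradient(grad_fn: Callable[[np.ndarray], np.ndarray],
                     points: List[np.ndarray]) -> np.ndarray:
    """Aggregate over a DoI by averaging the gradients at its points."""
    return np.mean([grad_fn(p) for p in points], axis=0)


def grad_cubic(v: np.ndarray) -> np.ndarray:
    return 3 * v ** 2  # gradient of sum(v**3)


x = np.array([2.0, -1.0])
print(average_gradient(grad_cubic, point_doi(x)))
print(average_gradient(grad_cubic, linear_doi(x, np.zeros_like(x), resolution=5)))
```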

---

* Klas Leino, Shayak Sen, Anupam Datta, Matt Fredrikson, and Linyi Li. *Influence-Directed Explanations for Deep Convolutional Networks*. IEEE ITC 2018. [ArXiv](https://arxiv.org/pdf/1802.03788.pdf)

In [None]:
from trulens.nn.attribution import InternalInfluence
from trulens.nn.distributions import PointDoi
from trulens.nn.quantities import ClassQoI, InternalChannelQoI, MaxClassQoI
from trulens.nn.slices import Cut, InputCut, OutputCut, Slice

We will be calculating attributions for the feature maps in the layer labeled `'features_28'` (specified via the slice below). In our first example, we are interested in explaining the model's *predicted class* for our record. We specify this by using a `MaxClassQoI`, which sets the attributions to explain the model's output for its highest-confidence class. We will initially use the `PointDoi`, which specifies that we are only concerned with the model's behavior on one particular point, i.e., we want a very *local* explanation.

In [None]:
# Define the influence measure.
infl = InternalInfluence(
    model, 
    Slice(Cut('features_28'), OutputCut()), 
    MaxClassQoI(),
    PointDoi())

# Get the attributions for the internal neurons at layer 'features_28'. Because
# this layer contains 2D feature maps, we take the sum over the width and height
# of the feature maps to obtain a single attribution for each feature map.
attrs_internal = infl.attributions(x_pp).sum(axis=(2,3))

Note that above we used the `Slice`, `MaxClassQoI`, and `PointDoi` classes to define the slice, QoI, and DoI. The TruLens API also offers shorthands for these parameters. For example, the code above can be written more succinctly as

In [None]:
# Define the influence measure.
infl = InternalInfluence(model, 'features_28', 'max', 'point')

# Get the attributions for the internal neurons at layer 'features_28'. Because 
# layer 'features_28' contains 2D feature maps, we take the sum over the width 
# and height of the feature maps to obtain a single attribution for each feature 
# map.
attrs_internal = infl.attributions(x_pp).sum(axis=(2,3))

Now we can calculate the most important feature map towards the model's top prediction, by taking the argmax over the internal attributions for this record. The most important feature map represents some type of *learned feature* that was the *most important* in the network's decision to label this point as `'beagle'`.

In [None]:
top_feature_map = int(attrs_internal[0].argmax())

print('Top feature map:', top_feature_map)
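The channel reduction and argmax used above are plain array operations. The NumPy sketch below repeats them on random stand-in attributions in the `(batch, channels, height, width)` layout that `.sum(axis=(2,3))` assumes, and also lists the top-k feature maps, which can help when no single map clearly dominates. The helper name `top_feature_maps` and the stand-in shape are illustrative only.

```python
import numpy as np


def top_feature_maps(internal_attrs: np.ndarray, k: int = 5) -> np.ndarray:
    """Indices of the k channels with the largest summed attribution for record 0.

    `internal_attrs` has shape (batch, channels, height, width).
    """
    per_channel = internal_attrs.sum(axis=(2, 3))  # one score per feature map
    order = np.argsort(per_channel[0])[::-1]  # descending by attribution
    return order[:k]


rng = np.random.default_rng(0)
fake_attrs = rng.normal(size=(1, 8, 4, 4))  # stand-in for infl.attributions(x_pp)
top = top_feature_maps(fake_attrs, k=3)
print(top, int(fake_attrs.sum(axis=(2, 3))[0].argmax()))
```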

### Visualizing Important Internal Neurons

We would now like to visualize our identified feature map in a meaningful way. Since the feature map represents a learned feature, which is not readily interpretable, we will use a second set of attributions to identify the input features that are most important in defining this particular feature map. We will then use a *visualizer*, found in the `visualizations` module, to visualize these input features that relate to our identified important feature map.

In [None]:
from trulens.visualizations import MaskVisualizer

First, we create another attributer, again using `InternalInfluence`. This time, the slice begins at the input of the model and ends at layer `'features_28'`, the layer of our identified feature map. We select an `InternalChannelQoI` as our quantity of interest, which lets us specify a particular channel to calculate attributions towards (here, our identified feature map). We again use the `PointDoi`.

Note that if we simply give the top feature map to `InternalInfluence`, it will automatically wrap it in an instance of `InternalChannelQoI` for us; additionally, the `Slice` object is inferred from the tuple of cuts.

In [None]:
infl_input = InternalInfluence(
    model, 
    Slice(InputCut(), Cut('features_28')),
    InternalChannelQoI(top_feature_map), 
    PointDoi())

Again, the above code can be simplified to

In [None]:
infl_input = InternalInfluence(
    model, 
    (InputCut(), Cut('features_28')), 
    top_feature_map, 
    'point')

Now we can calculate the input attributions and visualize the top feature map by using the input attributions as a mask over the original image, using the `MaskVisualizer` (found in the `visualizations` module).

The `MaskVisualizer` takes two fine-tuning parameters, `blur` and `threshold`, that affect the quality of the visualization. The attributions are first blurred using a Gaussian blur with radius `blur`, and then only the pixels whose blurred attribution value is at or above the percentile given by `threshold` are highlighted. Depending on the particular record and application, different `blur` and `threshold` parameters may be appropriate.

Increasing `blur` gives a more abstract, region-focused explanation, while a smaller blur gives a noisier, but more precise explanation.

Increasing `threshold` selects a smaller portion of the image to highlight, showing only the most important regions, while a smaller threshold will highlight a larger portion of the image.

In [None]:
attrs_input = infl_input.attributions(x_pp)

masked_image = MaskVisualizer(blur=10, threshold=0.95)(attrs_input, x)

The above procedure (using a second set of attributions to identify the input features that are most important in defining a particular feature map, then running a visualizer on the resulting input attributions) is a common use case when dealing with internal attributions. It can instead be done in a single step with the `ChannelMaskVisualizer`, also found in the `visualizations` module, as demonstrated below.

In [None]:
from trulens.visualizations import ChannelMaskVisualizer

masked_image = ChannelMaskVisualizer(
    model,
    'features_28',
    top_feature_map,
    blur=10,
    threshold=0.95)(x, x_pp)

plt.axis('off')
plt.imshow(masked_image[0].transpose((1,2,0)))

#### Other Quantities of Interest

We can also change the quantity that we want the attributions to explain. For example, our example image contains both a bike and a dog. Recall that while the top class predicted by our model was `'beagle'`, ImageNet also contains bike-related classes, e.g., `'mountain bike, all-terrain bike, off-roader'` (class index 671). We will use the `ClassQoI` to view the attributions towards this class.

In [None]:
# Define the influence measure.
infl_bike = InternalInfluence(model, 'features_28', 671, 'point')

# The above is shorthand for
#
# infl_bike = InternalInfluence(
#     model, 
#     Slice(Cut('features_28'), OutputCut()),
#     ClassQoI(671), 
#     PointDoi())

# Get the attributions for each feature map.
attrs_bike_internal = infl_bike.attributions(x_pp).sum(axis=(2,3))

# Find the index of the top feature map.
top_feature_map_bike = int(attrs_bike_internal[0].argmax())

print('Top feature map:', top_feature_map_bike)

# Visualize the top feature map in the input space.
masked_image = ChannelMaskVisualizer(
    model,
    'features_28',
    top_feature_map_bike,
    blur=10, 
    threshold=0.95)(x, x_pp)

plt.axis('off')
plt.imshow(masked_image[0].transpose((1,2,0)))