# Attributions

## Model Wrappers

In order to support a wide variety of backends with different interfaces for their respective models, TruLens uses its own `ModelWrapper` class which provides a general model interface to simplify the implementation of the API functions.
To get the model wrapper, use the `get_model_wrapper` function in `trulens.nn.models`. A model wrapper class exists for each backend; it converts a model in that backend's format to the general TruLens `ModelWrapper` interface. The wrappers are found in the `models` module, and any model defined using Keras, PyTorch, or TensorFlow should be wrapped before being used with the other API functions that require a model -- all other TruLens functionality expects models to be instances of `trulens.nn.models.ModelWrapper`.

For more details on allowed parameters, see the [get_model_wrapper](https://truera.github.io/trulens/api/model_wrappers/) documentation.

For this demo, we will be using a PyTorch model pre-trained on ImageNet.

In [None]:
import os
import sys
import torch
import numpy as np

from trulens.nn.models import get_model_wrapper

import numpy as np
import matplotlib.pyplot as plt
import PIL

%matplotlib inline

In [None]:
from trulens.nn.attribution import InputAttribution, InternalInfluence
from trulens.nn.attribution import IntegratedGradients
from trulens.nn.attribution import Cut, InputCut, OutputCut
from trulens.nn.distributions import LinearDoi, PointDoi

from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer, DistilBertForSequenceClassification

# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the cell also runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Sequence-classification model and its matching tokenizer, both loaded from
# the "distilbert-base-uncased" checkpoint. The model is moved to `device`.
bert = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased").to(device)

bt = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# Two sentences of different length; padding=True makes the tokenizer return
# equally long rows, and return_tensors="pt" requests PyTorch tensors.
sentences = ["this is a sentence", "and so is this but this one is longer"]
inputs = bt(sentences, padding=True, return_tensors="pt")

# The model lives on `device`, so the tensors fed to it must be moved as well.
inputs['input_ids'] = inputs['input_ids'].to(device)
inputs['attention_mask'] = inputs['attention_mask'].to(device)

def as_tensor(thing):
    """Return an independent tensor holding the data of ``thing``.

    Args:
        thing: A ``torch.Tensor``, or any array-like that ``torch.tensor``
            accepts (nested lists, numpy arrays, scalars).

    Returns:
        A new ``torch.Tensor``. Tensor inputs are cloned, so dtype and device
        are preserved and later in-place edits do not touch the argument.
        Other inputs are copied with the dtype ``torch.tensor`` infers from
        the data.
    """
    if isinstance(thing, torch.Tensor):
        return torch.clone(thing)

    # torch.tensor always copies its input, matching the clone semantics of
    # the tensor branch above.
    return torch.tensor(thing)

def token_baseline(inputs, replace_map):
    """Build a baseline batch by substituting selected token ids.

    Args:
        inputs: Mapping with ``'input_ids'`` and ``'attention_mask'`` entries.
        replace_map: Maps a token id to a ``(replacement_id, replacement_mask)``
            pair. Every position whose id equals the key gets the replacement
            id, and its attention-mask entry is set to the replacement mask.

    Returns:
        A dict with keys ``input_ids`` and ``attention_mask`` holding copies
        placed on the module-level ``device``. Replacements are applied one
        after another in ``replace_map`` iteration order.
    """
    baseline_ids = as_tensor(inputs['input_ids']).to(device)
    baseline_mask = as_tensor(inputs['attention_mask']).to(device)

    for token_id, replacement in replace_map.items():
        new_id, new_mask = replacement
        # Boolean mask of the positions that currently hold this token id.
        hits = baseline_ids == token_id
        baseline_ids[hits] = new_id
        baseline_mask[hits] = new_mask

    return {"input_ids": baseline_ids, "attention_mask": baseline_mask}

# Baseline batch: token ids 101 and 102 are replaced by id 0 with attention
# mask 0. NOTE(review): presumably these are special-token and padding ids of
# the distilbert-base-uncased vocabulary -- confirm against the tokenizer.
inputs_baseline = token_baseline(inputs, dict.fromkeys((101, 102), (0, 0)))

# Wrap the PyTorch model for TruLens and list the layer names it exposes.
model = get_model_wrapper(
    bert,
    input_shape=(None, bt.model_max_length),
)
model.print_layer_names()


In [None]:
# Integrated-gradients attribution over `model`, constructed with
# resolution=10. `model` is whichever wrapper was assigned most recently
# (the wrapped DistilBERT model when the cells are run in order).
infl = IntegratedGradients(model, resolution=10)
# The tokenizer output is unpacked so that each of its entries is passed to
# `attributions` as a keyword argument.
attrs_input = infl.attributions(**inputs)

In [None]:
import os
from pathlib import Path
import random
import pandas as pd
import json

from typing import List

from pathlib import Path
import torch
import torch.nn as nn
from torchtext.data import get_tokenizer

class ToySentiment(nn.Module):
    """Hand-parameterised LSTM sentiment classifier over a four-word vocabulary.

    The pipeline is: identity embedding -> single-layer LSTM -> bias-free
    linear layer -> softmax over four classes. Output index 0 reads the
    neutral state, 1 the good state, 2 the bad state and 3 the two confused
    states. All weights are fixed by ``set_parameters``; nothing is learned.
    """

    @staticmethod
    def from_pretrained(model_code_path: Path) -> 'ToySentiment':
        """Create a model and load the state dict stored at ``model_code_path``.

        Args:
            model_code_path: File previously written by ``to_pretrained``.

        Returns:
            The model with the stored parameters loaded.
        """
        # Not strictly necessary to load or save given fixed parameters can be populated, but
        # implementing this for better integration testing.
        model = ToySentiment()
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        model.load_state_dict(torch.load(model_code_path))
        return model

    def to_pretrained(self, model_code_path: Path) -> None:
        """Write this model's state dict to ``model_code_path``."""
        torch.save(self.state_dict(), model_code_path)

    def set_parameters(self) -> None:
        """Set model parameters as per fixed specification.

        Replaces the LSTM weights and biases, the embedding matrix and the
        linear-layer weights with hand-written values.
        """

        # Start every LSTM weight and bias at a small constant; individual
        # entries are overwritten below.
        Wi = torch.zeros_like(self.lstm.weight_ih_l0) + 0.1
        bi = torch.zeros_like(self.lstm.bias_ih_l0) + 0.1
        Wh = torch.zeros_like(self.lstm.weight_hh_l0) + 0.1
        bh = torch.zeros_like(self.lstm.bias_hh_l0) + 0.1

        big = 20.0 # Multipliers to help dealing with LSTM sigmoids.
        half = 8.0 # Intention here is that sigmoid((x*big) - half) is ~0 if x is ~0; and
                    # ~1 when indicator is >~ 1.

        hs = self.hidden_size

        # Hidden-state indices. The second confused state is index 4
        # (row 3*hs+4 below).
        sneutral = 0
        sgood = 1
        sbad = 2
        sconfused = 3

        # Input columns: with the identity embedding, a word's vocab id is
        # also the input dimension in which it is encoded.
        wneutral = self.vocab['neutral']
        wgood = self.vocab['good']
        wbad = self.vocab['bad']

        # PyTorch stacks the LSTM gates row-wise as (input, forget, cell,
        # output), hs rows each. Give the first three blocks a large bias so
        # that they are always saturated; only the output gate (rows 3*hs
        # onwards) carries the hand-written logic.
        bi[0:hs*3] = 100.0
        bh[0:hs*3] = 100.0

        # o gate weights:
        Wi[3*hs,   wneutral] = big # read neutral word
        Wi[3*hs+1, wgood] = big # read good word
        Wi[3*hs+2, wbad] = big # read bad word
        Wh[3*hs,   sneutral] = big # keep prior neutral, good, bad states
        Wh[3*hs+1, sgood] = big #
        Wh[3*hs+2, sbad] = big #
        bi[3*hs:4*hs] = -half # sigmoid will be 0 unless one of the three words was read

        # set "good to bad" confused if prior was good, and input was bad
        Wh[3*hs+3, sgood] = big    # (prior state was good
        Wi[3*hs+3, wbad] = big    #  and input was bad)
        Wh[3*hs+3, sconfused] = 2*big  # or (was already in this confused state)
        bh[3*hs+3] = -(half*2) # Want at least 2 of first two to fire, or just the last one to fire.

        # set "bad to good" confused if prior was bad, and input was good
        Wh[3*hs+4, sbad] = big     # (prior state was bad
        Wi[3*hs+4, wgood] = big     #  and input was good)
        Wh[3*hs+4, sconfused] = 2*big   # or (was already confused)
        bh[3*hs+4] = -(half*2)  #

        self.lstm.weight_hh_l0 = nn.Parameter(Wh)
        self.lstm.bias_hh_l0 = nn.Parameter(bh)
        self.lstm.weight_ih_l0 = nn.Parameter(Wi)
        self.lstm.bias_ih_l0 = nn.Parameter(bi)

        # Identity embedding: word id k becomes the k-th unit vector.
        self.embedding.weight = nn.Parameter(torch.eye(self.emb_size))

        # One output row per class; the last row sums both confused states.
        self.lin.weight = nn.Parameter(torch.tensor(
            [
                [10.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 20.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 20.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 30.0, 30.0]
            ]
        ))

    def __init__(self):
        """Build the layers and immediately install the fixed parameters."""
        super().__init__()

        self.tokenizer = get_tokenizer("basic_english")

        # Id 0 doubles as the id for any out-of-vocabulary token.
        self.vocab = {"[UNKNOWN]": 0, "neutral": 1, "good": 2, "bad": 3}

        self.emb_size = len(self.vocab)

        self.hidden_size = 5
        # 5 states, one for neutral, one for positive, one for negative, and two for confused. Requiring two
        # confused states for simplicity of the model; it is easier to encode semantics of confusion based on
        # which initial positive/negative state was set first.

        # Identity embedding, each vocab word has its own dimension where its presence is encoded.
        self.embedding = nn.Embedding(
            embedding_dim=self.emb_size,
            num_embeddings=self.emb_size,
            padding_idx=None,
            max_norm=None,
            norm_type=None
        )

        self.lstm = nn.LSTM(
            input_size=self.emb_size,
            hidden_size=self.hidden_size,
            num_layers=1,
            batch_first=True
        )

        # Linear layer to combine the two types of confused state and weight things so that
        # confused outweighs positive and negative, while positive and negative outweigh neutral
        # if more than one of these states is set.
        self.lin = torch.nn.Linear(
            in_features=self.hidden_size,
            out_features=4,
            bias=False
        )

        # Finally add a softmax for classification.
        self.softmax = torch.nn.Softmax(dim=1)

        self.set_parameters()

    def forward(self, word_ids: torch.Tensor, embeds=None) -> torch.Tensor:
        """Classify a batch of sentences.

        Args:
            word_ids: Integer tensor of vocabulary ids, batch first. May be
                ``None`` when ``embeds`` is given.
            embeds: Optional pre-computed embeddings, batch first. When given
                they are used instead of embedding ``word_ids``.

        Returns:
            Tensor of shape ``(batch, 4)``: the softmax over the four classes
            computed from the final LSTM hidden state.

        Raises:
            ValueError: If both ``word_ids`` and ``embeds`` are ``None``.
        """
        if word_ids is None and embeds is None:
            raise ValueError("Either word_ids or embeds must be provided.")

        if embeds is None:
            embeds = self.embedding(word_ids)

        batch_size = embeds.shape[0]

        # Allocate the initial states next to the embeddings so the model
        # also works when it has been moved off the CPU.
        h0 = embeds.new_zeros(1, batch_size, self.hidden_size)
        h0[0, :, 0] = 1.0 # initial state is neutral, for every sentence in the batch
        c0 = embeds.new_zeros(1, batch_size, self.hidden_size)

        # Keep the gradient w.r.t. the embeddings available for attribution.
        # retain_grad() raises on tensors that do not require grad (e.g.
        # under torch.no_grad()), so only call it when it is applicable.
        if embeds.requires_grad:
            embeds.retain_grad()

        out, (hn, cn) = self.lstm(embeds, (h0, c0))
        # hn is (num_layers, batch, hidden); drop the single layer dimension.
        hn = hn[0,:,:]

        lin_out = self.lin(hn)

        probits = self.softmax(lin_out)

        return probits


    def input_of_text(self, texts: List[str]) -> torch.Tensor:
        """Tokenize ``texts`` and map every token to its vocabulary id.

        Tokens that are not in the vocabulary map to id 0. All texts must
        tokenize to the same length for the result to form a tensor.

        Returns:
            A 32-bit integer tensor with one row of ids per text.
        """
        word_ids = [
            [self.vocab.get(token, 0) for token in self.tokenizer(text)]
            for text in texts
        ]

        # Build the integer tensor directly rather than via a float tensor.
        return torch.tensor(word_ids, dtype=torch.int32)


def generate_dataset(n, l):
    """Generate random sentiment sentences and their labels.

    Args:
        n: Number of sentences to generate.
        l: Number of words per sentence.

    Returns:
        A ``pandas.DataFrame`` with a ``sentence`` column (space-joined words
        drawn from "neutral", "good" and "bad") and a ``sentiment`` column:
        0 neutral, 1 positive (contains "good"), 2 negative (contains "bad"),
        3 confused (contains both).
    """
    sentences = []
    labels = []

    for _ in range(n):
        words = []
        while len(words) < l:
            draw = random.random()

            # A draw in the top decile is discarded once the sentence has at
            # least one word; for the very first word it falls through and
            # yields "good".
            if draw > 0.9 and words:
                continue

            if draw > 0.8:
                words.append("good")
            elif draw > 0.7:
                words.append("bad")
            else:
                words.append("neutral")

        sentences.append(" ".join(words))

        has_good = "good" in words
        has_bad = "bad" in words
        if has_good and has_bad:
            label = 3  # confused
        elif has_good:
            label = 1  # positive
        elif has_bad:
            label = 2  # negative
        else:
            label = 0  # neutral
        labels.append(label)

    return pd.DataFrame({"sentence": sentences, "sentiment": labels})

In [None]:
# Fresh toy sentiment model, wrapped for TruLens on the CPU.
pt = ToySentiment()
device = 'cpu'
# NOTE(review): the original spelled this `input_shape=(4)`, which is the
# integer 4 rather than a one-element tuple -- confirm which one
# get_model_wrapper expects.
model = get_model_wrapper(pt, input_shape=4, device=device, input_dtype=int)

# 16 random sentences of 4 words each, with their ground-truth labels.
test_data = generate_dataset(16, 4)

x_sentences = test_data['sentence']
labels = test_data['sentiment']

# Token lists are kept for display; the id tensor is what the model consumes.
x_tokens = list(map(pt.tokenizer, x_sentences))
x_ids = pt.input_of_text(x_sentences)

# TODO: define "EmbeddedInfluence" for InputInfluence except with an initial embedding layer and initialize internal influence as below:
from trulens.nn.quantities import LambdaQoI, ClassQoI
from IPython.display import HTML, display
import tabulate

# Display names for the four output classes, keyed by class index.
cls_names = {0:"neutral", 1:"positive", 2:"negative", 3:"confused"}

# One attribution method and one attribution result per output class.
infl = {}
attrs = {}
for cls in range(4):
    # Influence measured between the "embedding" and "softmax" layers, with a
    # linear distribution of interest (100 steps) at the embedding cut and the
    # class `cls` output as the quantity of interest.
    infl[cls] = InternalInfluence(
        model,
        cuts=(Cut("embedding"), Cut("softmax")),
        doi=LinearDoi(cut=Cut("embedding"), resolution=100),
        qoi=ClassQoI(cls),
        multiply_activation=False,
    )
    attrs[cls] = infl[cls].attributions(x_ids)

# Plain forward pass, used for the predicted class shown in each header row.
preds = pt(x_ids)

data = []
for i, (words, label, pred) in enumerate(zip(x_tokens, labels, preds)):
    # Header row: ground-truth and predicted class next to the raw sentence.
    header = f"<b>GT={cls_names[label]} PRED={cls_names[pred.argmax().item()]}</b>"
    data.append([header, " ".join(words)])

    # One row per class, each word coloured by its summed attribution.
    for cls in range(4):
        spans = []
        for word, attr in zip(words, attrs[cls][i]):
            mag = attr.sum()
            if mag > 0:
                # Positive attribution: keep green full, reduce red.
                red, green = 1.0 - mag * 0.5, 1.0
            else:
                # Zero or negative attribution: keep red full, reduce green.
                red, green = 1.0, 1.0 + mag * 0.5
            blue = min(red, green)
            spans.append(f"<span style='color: rgb({red*255}, {green*255}, {blue*255});'>{word}</span> ")

        data.append([cls_names[cls], "".join(spans)])

# NOTE(review): some tabulate releases escape markup for tablefmt='html' --
# confirm that the <b>/<span> tags render with the installed version.
tab = tabulate.tabulate(data, tablefmt='html')
display(HTML(tab))

In [None]:
# Overwrite the first five sentences with hand-written probes: one
# all-neutral sentence, then "good" at each of the four positions.
# .loc is used instead of chained indexing (test_data[col][row] = ...), which
# triggers SettingWithCopyWarning and does not write through to the frame
# when pandas Copy-on-Write is active.
test_data.loc[0, 'sentence'] = "neutral neutral neutral neutral"
test_data.loc[1, 'sentence'] = "good neutral neutral neutral"
test_data.loc[2, 'sentence'] = "neutral good neutral neutral"
test_data.loc[3, 'sentence'] = "neutral neutral good neutral"
test_data.loc[4, 'sentence'] = "neutral neutral neutral good"

# Row 0 is labelled neutral (0); every later row is labelled positive (1).
# NOTE(review): this also relabels the rows after the fifth, whose sentences
# are left untouched -- confirm that is intended.
test_data.loc[0, 'sentiment'] = 0
test_data.loc[test_data.index[1:], 'sentiment'] = 1

test_data

In [None]:
# Small smoke test: four two-word sentences through a fresh model.
test_data = generate_dataset(4, 2)

# Fresh model with gradients enabled on all parameters, in training mode.
pm = ToySentiment()
pm.requires_grad_(True).train()

# Sentences -> vocabulary ids -> class probabilities.
word_ids = pm.input_of_text(test_data['sentence'])
probits = pm(word_ids=word_ids)

print(probits)