In [None]:
import sys
sys.path.insert(0,'/home/rick/transformers/src')
import os
from pathlib import Path

trulens_path = (Path(os.getcwd()).parent.parent / "trulens")

sys.path.insert(0, str(trulens_path))
%load_ext autoreload
%autoreload 2

In [None]:
from transformers import BloomForCausalLM
from transformers import BloomTokenizerFast
import torch

# Single source of truth for the checkpoint, so the model and its tokenizer
# cannot silently be loaded from different checkpoints.
MODEL_NAME = "bigscience/bloom-1b1"

model = BloomForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)


In [None]:
# Use the first GPU when one is available; otherwise fall back to CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)  # nn.Module.to moves the parameters in place
# NOTE(review): `tok` is a short alias for `tokenizer`; no visible cell uses it.
tok = tokenizer

In [None]:
from typing import Iterable, Union, Optional
from truera.nlp.general.aiq.nlp_coloring import attributions_to_rgb
from truera.nlp.general.aiq.nlp_coloring import generate_rgb_str
from truera.nlp.general.aiq.nlp_coloring import MAX_INTENSITY
from truera.nlp.general.aiq.nlp_coloring import rgb_str
from IPython.display import HTML
# copied from truera
def _influence_examples(
    tokens_list: Iterable[Iterable[str]],
    attributions_list: Iterable[Iterable[float]],
    *,
    qoi_class: Union[int, str] = 0,
    underline_list: Optional[Union[Iterable[Iterable[int]], Iterable[int]]] = None,
    prepends: Union[Iterable[str], str] = ''
) -> "Optional[HTML]":
    """Render token-level attributions as colored HTML, one example per line.

    A legend bar (negative -> neutral -> positive influence, labelled with
    ``qoi_class``) is emitted first. It is followed by one line per example, in
    which each token is colored by its attribution via ``generate_rgb_str``. All
    examples share a single normalization factor (the largest absolute attribution
    across every example), so colors are comparable between lines.

    Args:
        tokens_list: Per-example token strings.
        attributions_list: Per-example attribution scores, aligned with ``tokens_list``.
        qoi_class: Label for the quantity of interest shown in the legend. An int
            ``k`` is rendered as ``"Class: k"``. Underscores become spaces and the
            text is title-cased.
        underline_list: Optional per-example token index (an int) or indices (an
            iterable of ints) to underline, e.g. when plotting the influence of a
            sentence that contains a specific token.
        prepends: Heading shown before each example. A single string is reused for
            every example.

    Returns:
        An ``IPython.display.HTML`` object, or ``None`` when ``attributions_list``
        is empty.
    """
    import numbers

    # Local import: numpy is otherwise only imported in a later notebook cell.
    import numpy as np

    # Materialize the inputs so generators work and examples can be indexed by position.
    tokens_list = list(tokens_list)
    attributions_list = list(attributions_list)
    if underline_list is not None:
        underline_list = list(underline_list)
    if len(attributions_list) == 0:
        return None

    # Shared scale: the largest |attribution| over all examples.
    norm_factor = np.max(
        [np.max(np.abs(attributions)) for attributions in attributions_list]
    )

    # Legend
    neg_infl_color = rgb_str(256, 256 - MAX_INTENSITY, 256 - MAX_INTENSITY)
    neutral_color = rgb_str(256, 256, 256)
    pos_infl_color = rgb_str(256 - MAX_INTENSITY, 256, 256 - MAX_INTENSITY)
    if isinstance(qoi_class, int):
        qoi_class = f"Class: {qoi_class}"
    qoi_class = qoi_class.replace('_', ' ').title()
    html_str = [
        f'''
        <div style="margin:auto; width:50%; height:20px; display:flex; align-items:center;justify-content: space-between; background-image:linear-gradient(to right, {neg_infl_color}, {neutral_color}, {pos_infl_color});">
            <strong style=margin-left:4px>Negative Influence</strong>
            <strong style=text-align:center>{qoi_class}</strong>
            <strong style=margin-right:4px>Positive Influence</strong>
        </div>
        '''
    ]

    # One line per example
    if isinstance(prepends, str):
        prepends = [prepends] * len(tokens_list)
    for si, (attributions, tokens, prepend) in enumerate(
        zip(attributions_list, tokens_list, prepends)
    ):
        if underline_list is None:
            underline_idxs = None
        else:
            underline_idxs = underline_list[si]
            # A bare index (including numpy integers) means "underline this one
            # token". An iterable is already a collection of indices and must not be
            # wrapped again; the previous code turned [1] into [[1]].
            if isinstance(underline_idxs, numbers.Integral):
                underline_idxs = [underline_idxs]
            else:
                underline_idxs = list(underline_idxs)
        line_html = generate_rgb_str(
            tokens,
            attributions,
            underline_idxs,
            norm_factor=norm_factor,
            max_intensity=MAX_INTENSITY
        )
        html = f'<p style=padding-bottom:2px><h4 style=margin:0;>{prepend}</h4> {line_html}'
        html_str.append(html)

    return HTML("\n".join(html_str))

In [None]:
from trulens.nn.models import get_model_wrapper
from trulens.nn.distributions import PointDoi
from trulens.nn.quantities import LambdaQoI
from trulens.nn.attribution import IntegratedGradients, Saliency
from trulens.nn.attribution import Cut, OutputCut,InputCut
from trulens.utils.typing import ModelInputs
import numpy as np
def explain(inputs, model_inputs, reply_ids, chosen_token, *, resolution=128, rebatch_size=4):
    """Display Integrated-Gradients attributions of one predicted token over its input.

    Uses the notebook-global ``model`` and ``tokenizer``. The quantity of interest is
    the logit of ``chosen_token`` at the last position of the last sequence in the
    batch (``out[-1][-1][chosen_token]`` on the model's ``'logits'`` output).
    Attributions are taken at the layer trulens names ``transformer_word_embeddings``
    and summed over the last axis, giving one score per input token. The scores are
    rendered with ``_influence_examples``.

    Args:
        inputs: Tokenizer-style mapping whose ``input_ids[0]`` supplies the token
            labels shown in the plot.
        model_inputs: Mapping forwarded as keyword arguments to the attribution call
            (callers in this notebook pass the same object as ``inputs``).
        reply_ids: Unused; kept so existing call sites keep working.
        chosen_token: Vocabulary id of the token whose logit is explained.
        resolution: Integrated-Gradients resolution passed to trulens (previously
            hard-coded to 128).
        rebatch_size: ``rebatch_size`` passed to trulens (previously hard-coded to 4).
    """
    wrapper = get_model_wrapper(model)
    qoi = LambdaQoI(lambda out: out[-1][-1][chosen_token])
    ig = IntegratedGradients(
        model=wrapper,
        doi_cut=Cut('transformer_word_embeddings'),
        qoi_cut=OutputCut(accessor=lambda o: o['logits']),
        qoi=qoi,
        resolution=resolution,
        rebatch_size=rebatch_size,
    )
    result = ig.attributions(**model_inputs)
    # NOTE(review): presumably selects the first (only) input's attributions for
    # batch item 0; confirm against trulens's return structure.
    influence_sums = np.sum(result.attributions[0][0][0], axis=-1)

    input_tokens = tokenizer.batch_decode(inputs['input_ids'][0])
    display(_influence_examples(
        [input_tokens],
        [influence_sums],
        prepends=f"next token:{tokenizer.decode(chosen_token)}",
    ))

In [None]:
import copy
def explain_utterance(utterance, reply_ids):
    """Explain, step by step, every token the model generated after ``utterance``.

    For each generated position ``i`` of ``reply_ids[0]`` (i.e. after the prompt
    tokens), this calls ``explain`` with the exact id prefix ``reply_ids[0][:i]`` as
    model input and ``reply_ids[0][i]`` as the token to explain.

    The prefix is sliced directly from ``reply_ids`` instead of decoding text and
    re-tokenizing it. A decode -> encode round trip is not guaranteed to reproduce
    the same ids (subword pieces can merge differently), which would explain a token
    against an input the model never actually saw.

    Args:
        utterance: The prompt string that was passed to ``model.generate``.
        reply_ids: The ``generate`` output for that prompt; only row 0 is used.
    """
    prompt_len = len(tokenizer([utterance], return_tensors="pt")["input_ids"][0])
    generated = reply_ids[:1]  # row 0, keeping the batch dimension
    for position in range(prompt_len, generated.shape[-1]):
        prefix_ids = generated[:, :position].to(device)
        step_inputs = {
            "input_ids": prefix_ids,
            # Single unpadded sequence: every prefix token is attended to.
            "attention_mask": torch.ones_like(prefix_ids),
        }
        chosen_token = int(generated[0, position])
        explain(inputs=step_inputs, model_inputs=step_inputs,
                reply_ids=reply_ids, chosen_token=chosen_token)



# Model Context
BLOOM is a document-completion (causal language) model: it generates text iteratively, predicting one next token at a time and appending it to the input.

In [None]:
# Prompt with a single unfinished name and let the model continue it.
UTTERANCE = "Complete the names of these NBA players: Michael"
inputs = tokenizer([UTTERANCE], return_tensors="pt").to(device)
reply_ids = model.generate(
    **inputs,
    max_length=50,            # total length, prompt tokens included
    no_repeat_ngram_size=2,   # no 2-gram may appear twice in the output
    early_stopping=True,      # NOTE(review): only consulted by beam search; num_beams is not set here
)
print(tokenizer.batch_decode(reply_ids))

In [None]:
# Same task, but the prompt already contains one completed example ("Michael Jordan").
UTTERANCE = "Complete the names of these NBA players: Michael Jordan, Kobe"
inputs = tokenizer([UTTERANCE], return_tensors="pt").to(device)
reply_ids = model.generate(
    **inputs,
    max_length=50,            # total length, prompt tokens included
    no_repeat_ngram_size=2,   # no 2-gram may appear twice in the output
    early_stopping=True,      # NOTE(review): only consulted by beam search; num_beams is not set here
)
print(tokenizer.batch_decode(reply_ids))

In [None]:
# Two completed NBA examples, then a switch of domain to soccer; allow a longer reply.
UTTERANCE = "Complete the names of these NBA players: Michael Jordan, Kobe Bryant. What about soccer players?"
inputs = tokenizer([UTTERANCE], return_tensors="pt").to(device)
reply_ids = model.generate(
    **inputs,
    max_length=100,           # total length, prompt tokens included
    no_repeat_ngram_size=2,   # no 2-gram may appear twice in the output
    early_stopping=True,      # NOTE(review): only consulted by beam search; num_beams is not set here
)
print(tokenizer.batch_decode(reply_ids))

# Explanations

In [None]:
# Generate with 2-beam search, then attribute every generated token back to its prefix.
UTTERANCE = "Complete the names of these NBA players: Michael"
inputs = tokenizer([UTTERANCE], return_tensors="pt").to(device)
reply_ids = model.generate(
    **inputs,
    max_length=200,
    num_beams=2,
    no_repeat_ngram_size=2,
    early_stopping=True,
)
# Each generated id next to its decoded text, then the whole decoded reply.
print([f"{token_id}: {tokenizer.decode(token_id)}" for token_id in reply_ids[0]])
print(tokenizer.batch_decode(reply_ids))
# NOTE(review): runs one Integrated-Gradients pass per generated token; with
# max_length=200 this can take a long time.
explain_utterance(UTTERANCE, reply_ids)