In [1]:
import os
from typing import Any, Iterable, List, Tuple

import numpy as np
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

In [2]:
import sys
sys.path.append("../src")
from explainer import Archipelago
from application_utils.text_utils import TextXformer
from application_utils.text_utils_torch import BertWrapperTorch
from viz.text import viz_text

In [3]:
GPT2_PATH = "gpt_model"  # local directory holding the saved tokenizer + model


def download_models(
    path: str = GPT2_PATH,
    model_name: str = "gpt2-large",
    tokenizer_name: str = "gpt2",
) -> None:
    """Fetch a GPT-2 tokenizer and LM-head model from the hub and save both to ``path``.

    The defaults keep the original pairing: the tokenizer of the base ``"gpt2"``
    checkpoint with the ``"gpt2-large"`` model.

    Args:
        path: Directory both artifacts are written to.
        model_name: Hub id of the model checkpoint.
        tokenizer_name: Hub id of the tokenizer.
    """
    tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_name)
    model = GPT2LMHeadModel.from_pretrained(model_name)
    tokenizer.save_pretrained(path)
    model.save_pretrained(path)


def load_model_from_file(path: str = GPT2_PATH) -> Tuple[Any, Any]:
    """Load the tokenizer and model previously saved by ``download_models``.

    Args:
        path: Directory to load from. Defaults to ``GPT2_PATH`` so the save and
            load locations cannot drift apart.

    Returns:
        ``(tokenizer, model)``.
    """
    tokenizer = GPT2TokenizerFast.from_pretrained(path)
    model = GPT2LMHeadModel.from_pretrained(path)
    return tokenizer, model

In [4]:
# Do an example: top-10 next-token predictions for one sentence.

# Fetch the weights only on the first run; later runs load them from GPT2_PATH.
if not os.path.isdir(GPT2_PATH):
    download_models()
tokenizer, model = load_model_from_file()
sent = "yesterday afternoon me and my son went to the"

tokens = tokenizer(sent).input_ids  # list of token ids, length sent_len
with torch.no_grad():  # inference only: don't build an autograd graph
    # explicit batch dim of 1 in, then strip it from the logits
    preds = model(torch.LongTensor([tokens])).logits[0]  # (sent_len, vocab_size)
next_word_preds = preds[-1]  # scores for the token after the last input token
top_next = torch.topk(next_word_preds, k=10)  # compute once, reuse values and indices
next_word_tokens = top_next.indices
next_word_logits = top_next.values
next_words = tokenizer.decode(next_word_tokens)
print(next_words)

 store park local grocery beach mall gym library hospital school


In [5]:
# Size of the next-token score vector (one score per vocabulary entry).
next_word_preds.size()

torch.Size([50257])

In [6]:
# Input token ids next to the ids of the top-10 candidate next tokens.
(tokens, next_word_tokens)

([8505, 6432, 6672, 502, 290, 616, 3367, 1816, 284, 262],
 tensor([ 3650,  3952,  1957, 16918, 10481, 17374, 11550,  5888,  4436,  1524]))

In [7]:
# Token to explain: the model's top-1 next-token prediction. Derived from the
# top-k result rather than an id copied by hand from an earlier output, so it
# stays correct if the sentence or the model changes.
most_likely_next_token = int(next_word_tokens[0])  # topk sorts descending: [0] is the argmax

In [8]:
# Token id(s) the tokenizer assigns to "_" (candidate baseline token).
tokenizer('_')["input_ids"]

[62]

In [9]:
# "_" is the neutral baseline token; it must be a single id so a baseline built
# by repetition lines up position-for-position with the input tokens.
baseline_token = tokenizer('_').input_ids
assert len(baseline_token) == 1, "baseline string must map to exactly one token"
# Baseline sequence with the same length as the input (instead of a hard-coded 10).
baseline_token * len(tokens)

[62, 62, 62, 62, 62, 62, 62, 62, 62, 62]

In [10]:
class GPTWrapperTorch:
    """Adapt a causal LM to a callable mapping a batch of token ids to next-token scores.

    Args:
        model: A model whose call on a ``(batch_size, sent_len)`` LongTensor
            returns an object with a ``.logits`` tensor of shape
            ``(batch_size, sent_len, vocab_size)`` (e.g. ``GPT2LMHeadModel``).
        device: Torch device the model is moved to and inputs are placed on.
        merge_logits: Accepted and stored so existing callers keep working;
            not used by this wrapper.
    """

    def __init__(self, model: Any, device: str, merge_logits: bool = False) -> None:
        self.device = device
        self.merge_logits = merge_logits
        # nn.Module.to moves the module in place and returns it, so the caller's
        # model object ends up on `device` as well.
        self.model = model.to(device)

    def __call__(self, batch_ids: Iterable[Iterable[int]]) -> np.ndarray:
        """Score every vocabulary entry as the next token, for each example.

        Args:
            batch_ids: Token ids, shape ``(batch_size, sent_len)``; a nested list
                or an integer numpy array.

        Returns:
            Float numpy array of shape ``(batch_size, vocab_size)`` with the logits
            at the last position of each example (50257 columns for GPT-2).
        """
        # np.asarray first so nested lists and numpy arrays convert the same way.
        ids = torch.as_tensor(np.asarray(batch_ids), dtype=torch.long, device=self.device)
        with torch.no_grad():  # inference only; no autograd graph needed
            # Use the wrapped model, not a notebook-global `model`.
            preds = self.model(ids).logits  # (batch_size, sent_len, vocab_size)
        next_word_preds = preds[:, -1]  # (batch_size, vocab_size)
        # .cpu() so .numpy() also works when the model runs on a GPU.
        return next_word_preds.cpu().numpy()

In [11]:
# The input to explain, plus a same-length baseline made of the "_" token.
text_ids, baseline_ids = tokens, baseline_token * len(tokens)
assert len(baseline_ids) == len(text_ids), "baseline must align with the input tokens"
# Explain only the top-1 next token. An earlier variant passed all top-10 ids
# (next_word_tokens.detach().numpy()); presumably Archipelago also accepts an
# array of output indices here — TODO confirm before switching.
output_indices = most_likely_next_token

In [12]:
# Inspect what the explainer will receive: input ids, baseline ids, target token.
(text_ids, baseline_ids, output_indices)

([8505, 6432, 6672, 502, 290, 616, 3367, 1816, 284, 262],
 [62, 62, 62, 62, 62, 62, 62, 62, 62, 62],
 3650)

In [13]:
# Wrap the LM so the explainer can call it on batches of token ids (CPU only).
model_wrapper = GPTWrapperTorch(model, "cpu")

In [14]:
# Build the data transformer from the input ids and the same-length baseline ids.
# Presumably it turns feature masks into mixes of the two sequences — TODO confirm in ../src.
xf = TextXformer(text_ids, baseline_ids)

In [15]:
# Archipelago explainer targeting the chosen next-token output, one sample per batch.
apgo = Archipelago(
    model_wrapper,
    output_indices=output_indices,
    data_xformer=xf,
    batch_size=1,
)

In [16]:
# Run the explainer with top_k=8 (exact top_k semantics per Archipelago.explain — TODO confirm).
# NOTE(review): the saved output of this cell shows debug prints ("index_tuple", "preds",
# "c", "self.output_indices") and ends in `Exception: {(0,): ...}`, so `explanation` was
# never assigned in that run. Presumably this is leftover debugging code in
# ../src/explainer; remove it before relying on this cell.
explanation = apgo.explain(top_k=8)

index_tuple (0,)
preds [[ 2.369074   4.3052573  4.854746  ... -4.128131  -5.1946387  0.5706473]
 [ 2.5863144  4.1248903  3.4009182 ... -4.5984774 -5.177022   2.6482196]]
c 0
self.output_indices 3650


Exception: {(0,): -3.0886092}