In [1]:
from nnsight import NNsight
import torch 
from transformers import GPTJForCausalLM, AutoTokenizer 
 
model_path = "/data/lmm/checkpoints/checkpoint-1953"

gptj = GPTJForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
gptj_sight = NNsight(gptj)

In [11]:
text = """| apple | banana | cherry | apple | banana |
| date | elderberry | apple | cherry | date |
| banana | cherry | apple | elderberry | banana |
| date | apple | banana | cherry | elderberry |
| cherry | date | elderberry | apple | banana |
The grid above is size 5 by 5. Each cell contains an object from ['apple', 'banana', 'cherry'].
What object is in row 0, column 0? A: apple
What object is in row 1, column 2? A: apple
What object is in row 2, column 4? A: banana
What object is in row 3, column 1? A: apple
What object is in row 4, column 3? A: apple"""

input_ids = tokenizer(text, return_tensors="pt").input_ids
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

gptj.to(device)

inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    ).to(device)

with torch.no_grad():
        outputs = gptj.generate(
            inputs.input_ids,
            max_new_tokens=50,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            pad_token_id=tokenizer.eos_token_id,
            early_stopping=True
        )



In [41]:
print(inputs["input_ids"])

tensor([[ 91, 282,  79,  79,  75,  68, 272, 288, 293, 293,  64, 272, 257,  71,
          68,  81,  81,  88, 272, 282,  79,  79,  75,  68, 272, 288, 293, 293,
          64, 272, 198,  91, 284, 260,  68, 272, 220, 335,  67,  68,  81,  65,
          68,  81,  81,  88, 272, 282,  79,  79,  75,  68, 272, 257,  71,  68,
          81,  81,  88, 272, 284, 260,  68, 272, 198,  91, 288, 293, 293,  64,
         272, 257,  71,  68,  81,  81,  88, 272, 282,  79,  79,  75,  68, 272,
         220, 335,  67,  68,  81,  65,  68,  81,  81,  88, 272, 288, 293, 293,
          64, 272, 198,  91, 284, 260,  68, 272, 282,  79,  79,  75,  68, 272,
         288, 293, 293,  64, 272, 257,  71,  68,  81,  81,  88, 272, 220, 335,
          67,  68,  81,  65,  68,  81,  81,  88, 272, 198,  91, 257,  71,  68,
          81,  81,  88, 272, 284, 260,  68, 272, 220, 335,  67,  68,  81,  65,
          68,  81,  81,  88, 272, 282,  79,  79,  75,  68, 272, 288, 293, 293,
          64, 272, 198, 353, 356, 359, 265, 360, 220

In [33]:
probs_layers = []

layers = gptj_sight.transformer.h

with gptj_sight.trace() as tracer:
    with tracer.invoke(outputs) as invoker:
        for layer_idx, layer in enumerate(layers):
            layer_output = gptj_sight.lm_head(gptj_sight.transformer.ln_f(layer.output[0]))

            probs = torch.nn.functional.softmax(layer_output, dim=-1).save()
            probs_layers.append(probs)

probs = torch.cat([probs.value for probs in probs_layers])

max_probs, tokens = probs.max(dim=-1)

In [None]:
words = [[tokenizer.decode(t.cpu()) for t in layer_tokens] for layer_tokens in tokens]

In [46]:
input_words = [tokenizer.decode(t) for t in inputs["input_ids"]]

In [71]:
input_words[0]

"| apple | banana | cherry | apple | banana |\n| date | elderberry | apple | cherry | date |\n| banana | cherry | apple | elderberry | banana |\n| date | apple | banana | cherry | elderberry |\n| cherry | date | elderberry | apple | banana |\nThe grid above is size 5 by 5. Each cell contains an object from ['apple', 'banana', 'cherry'].\nWhat object is in row 0, column 0? A: apple\nWhat object is in row 1, column 2? A: apple\nWhat object is in row 2, column 4? A: banana\nWhat object is in row 3, column 1? A: apple\nWhat object is in row 4, column 3? A: apple"

In [76]:
print(words)

[['What', ' airplane', ' airplane', ' airplane', ' bird', ' airplane', ' bird', ' dog', ' airplane', ' airplane', 'airplane', ' airplane', ' airplane', 'airplane', 'airplane', ' dog', ' dog', 'airplane', ' bird', ' airplane', 'bird', 'bird', 'deer', 'airplane', ' bird', ' row', ' airplane', ' airplane', 'airplane', ' bird', 'What', 'What', ' from', 'airplane', '?', ' airplane', ' airplane', '?', ',', 'airplane', ',', ' row', 'airplane', ' dog', ' dog', ' row', ' airplane', ' airplane', 'bird', 'bird', 'airplane', 'airplane', ' airplane', ' airplane', ' from', 'airplane', ',', ',', ' airplane', ' airplane', ' row', 'airplane', 'airplane', ' airplane', 'What', 'What', ' dog', ' dog', ' dog', ' from', ' airplane', ' airplane', ' row', 'airplane', ',', ',', ' airplane', ' airplane', ' row', ' row', ' row', 'airplane', 'airplane', ' airplane', ' airplane', ' Each', ',', 'airplane', ',', ' row', 'deer', ' dog', ' dog', ' row', ' airplane', ' row', ' dog', ' airplane', ' from', ' airplane', '

In [61]:
print(len(words[0]))

349


IndexError: list index out of range

In [55]:
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "plotly_mimetype+notebook_connected+colab+notebook"

fig = px.imshow(
    max_probs.detach().cpu().numpy(),
    x=input_words,
    y=list(range(len(words))),
    color_continuous_scale=px.colors.diverging.RdYlBu_r,
    color_continuous_midpoint=0.50,
    text_auto=True,
    labels=dict(x="Input Tokens", y="Layers", color="Probability")
)

fig.update_layout(
    title='Logit Lens Visualization',
    xaxis_tickangle=0
)

fig.update_traces(text=words, texttemplate="%{text}")
fig.show()


ValueError: The length of the x vector must match the length of the second dimension of the img matrix.