[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tanikina/mi-tutorials/blob/main/notebooks/LogitLens.ipynb)

This notebook mostly follows [this tutorial](https://nnsight.net/notebooks/tutorials/logit_lens/) but uses a different (smaller) model.

In [None]:
!pip install nnsight
!pip install plotly

In [1]:
from IPython.display import clear_output
from nnsight import LanguageModel
from typing import List, Callable
import torch
import numpy as np

clear_output()

In [2]:
# Load GPT-2 small through nnsight.
# Optionally pass dispatch=True to load the weights immediately rather than lazily.
model = LanguageModel("ComCom/gpt2-small", device_map="auto")

In [3]:
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  (generator): Generator(
    (streamer): Streamer()
  )
)


To apply the logit lens, we collect the activations at each layer’s output, apply the final layer normalization (`model.transformer.ln_f`), and pass the result through the model’s head (`model.lm_head`) to get logits. Applying softmax to these logits then gives the output token probabilities at that layer.

By comparing these probabilities across layers, the logit lens shows how the model’s top prediction and its confidence evolve as the input passes through the network.
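The steps described above can be sketched on a toy model, without downloading GPT-2. This is a minimal illustration under stated assumptions: `blocks`, `ln_f`, and `lm_head` are small randomly initialized stand-ins for `model.transformer.h`, `model.transformer.ln_f`, and `model.lm_head`, and all dimensions are made up.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy dimensions (hypothetical, chosen only for illustration)
vocab_size, d_model, n_layers, seq_len = 50, 16, 4, 5

# Randomly initialized stand-ins for the GPT-2 components used by the logit lens
blocks = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(n_layers))
ln_f = nn.LayerNorm(d_model)                          # role of model.transformer.ln_f
lm_head = nn.Linear(d_model, vocab_size, bias=False)  # role of model.lm_head

hidden = torch.randn(1, seq_len, d_model)  # (batch, positions, d_model)

layer_probs = []
with torch.no_grad():
    for block in blocks:
        # Toy "layer": a residual update of the hidden states
        hidden = hidden + torch.tanh(block(hidden))
        # Logit lens: decode the intermediate state as if it were the final one
        logits = lm_head(ln_f(hidden))
        layer_probs.append(torch.softmax(logits, dim=-1))

probs = torch.cat(layer_probs)         # (n_layers, seq_len, vocab_size)
max_probs, tokens = probs.max(dim=-1)  # each of shape (n_layers, seq_len)
print(tuple(probs.shape), tuple(max_probs.shape))
```

Because the weights are random, the decoded tokens carry no meaning here; the point is only the shape of the computation, which is the same one applied to the real model in the next cell.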

In [4]:
prompt = "The Eiffel Tower is in the city of"
layers = model.transformer.h
probs_layers = []

with model.trace() as tracer:
    with tracer.invoke(prompt) as invoker:
        for layer_idx, layer in enumerate(layers):
            # Apply the final layer norm, then the LM head, to this layer's hidden states
            layer_output = model.lm_head(model.transformer.ln_f(layer.output[0]))

            # Apply softmax to obtain probabilities and save the result
            probs = torch.nn.functional.softmax(layer_output, dim=-1).save()
            probs_layers.append(probs)

# Stack the per-layer results into one tensor of shape (layers, positions, vocab)
probs = torch.cat([p.value for p in probs_layers])

# Find the maximum probability and corresponding tokens for each position
max_probs, tokens = probs.max(dim=-1)

# Decode token IDs to words for each layer
words = [[model.tokenizer.decode(t.cpu()).encode("unicode_escape").decode() for t in layer_tokens]
    for layer_tokens in tokens]

# Access the 'input_ids' attribute of the invoker object to get the input words
input_words = [model.tokenizer.decode(t) for t in invoker.inputs[0][0]["input_ids"][0]]

Now we visualize GPT-2’s predictions while it processes the string `"The Eiffel Tower is in the city of"`. For each of the 12 `GPT2Block` layers, the plot shows the token that layer considers most likely to come next after each input token, colored by its probability.
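Before plotting, it can help to look at the same information as plain text. The sketch below uses a hypothetical helper, `format_lens_table`, which is not part of nnsight or the original tutorial. It builds one row per layer with the top token and its probability at each input position. The toy values are made up and only stand in for the notebook's `input_words`, `words`, and `max_probs`.

```python
from typing import List, Sequence


def format_lens_table(
    input_words: Sequence[str],
    words: Sequence[Sequence[str]],
    max_probs: Sequence[Sequence[float]],
) -> List[str]:
    """Return a header row plus one text row per layer (top token and probability)."""
    header = "layer | " + " | ".join(f"{w!r:>14}" for w in input_words)
    rows = [header]
    for layer_idx, (layer_words, layer_probs) in enumerate(zip(words, max_probs)):
        # repr() keeps leading spaces in tokens visible, e.g. ' Tower'
        cells = [f"{w!r:>8} {p:.2f}" for w, p in zip(layer_words, layer_probs)]
        rows.append(f"{layer_idx:>5} | " + " | ".join(f"{c:>14}" for c in cells))
    return rows


# Toy values (made up, not real model output)
toy_inputs = ["The", " Eiffel"]
toy_words = [[" same", " Tower"], [" first", " Tower"]]
toy_probs = [[0.10, 0.35], [0.20, 0.90]]

table = format_lens_table(toy_inputs, toy_words, toy_probs)
print("\n".join(table))
```

With the notebook's variables, the call would be `format_lens_table(input_words, words, max_probs.detach().cpu().tolist())`.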

In [5]:
import plotly.express as px
import plotly.io as pio

pio.renderers.default = "colab"

fig = px.imshow(
    max_probs.detach().cpu().numpy(),
    x=input_words,
    y=list(range(len(words))),
    color_continuous_scale=px.colors.diverging.RdYlBu_r,
    color_continuous_midpoint=0.50,
    text_auto=True,
    labels=dict(x="Input Tokens", y="Layers", color="Probability")
)

fig.update_layout(
    title='Logit Lens Visualization',
    xaxis_tickangle=0
)

fig.update_traces(text=words, texttemplate="%{text}")
fig.show()

You can also try logit lens visualization with a different library, `tuned-lens`; see the official tutorial [here](https://colab.research.google.com/github/AlignmentResearch/tuned-lens/blob/main/notebooks/interactive.ipynb). Unfortunately, `tuned-lens` does not support the latest version of torch, so we skip it in this tutorial.