<a href="https://colab.research.google.com/github/tatsath/Interpretability/blob/main/logit_lens.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Logit Lens

## Introduction

🔍 Logit Lens is a powerful tool that grants us a simplified (yet insightful) understanding of the inner workings of transformer models.

We can estimate the model's guess for the next token after each computational step by decoding each layer's output with the model's own final layer normalization and unembedding head, then applying a softmax. Many interpretability approaches focus on *how* beliefs are updated within a step; Logit Lens instead gives a glimpse of *what* output the model is predicting at each processing step.
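
In symbols (notation ours, a sketch of the standard formulation): write $h_\ell^{(t)}$ for the residual-stream vector at input position $t$ after block $\ell$, $\mathrm{LN}_f$ for the final layer norm (`ln_f` in GPT-2) and $W_U$ for the unembedding matrix (`lm_head`). The lens reads a next-token distribution off every layer:

$$
p_\ell\big(\cdot \mid x_{\le t}\big) = \mathrm{softmax}\!\Big(W_U\,\mathrm{LN}_f\big(h_\ell^{(t)}\big)\Big),
\qquad
\mathrm{LN}_f(h) = \gamma \odot \frac{h - \mu(h)}{\sqrt{\sigma^2(h) + \epsilon}} + \beta,
\qquad \ell = 0, \dots, L-1 .
$$

For the last block, $\ell = L-1$, this is exactly the model's own output distribution, because GPT-2 computes its logits as $W_U\,\mathrm{LN}_f(h_{L-1})$.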

📗 Read more about Logit Lens in nostalgebraist’s blog post on LessWrong, [here](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens).

💻 You can find a Colab version of our tutorial [here](https://colab.research.google.com/github/ndif-team/nnsight/blob/main/docs/source/notebooks/tutorials/logit_lens.ipynb), or nostalgebraist’s original code [here](https://colab.research.google.com/drive/1-nOE-Qyia3ElM17qrdoHAtGmLCPUZijg?usp=sharing).

## Setup

If using Colab, install NNsight:
```
!pip install -U nnsight
```

```python
try:
    import google.colab
    is_colab = True
except ImportError:
    is_colab = False

if is_colab:
    !pip install -U nnsight
```

Import libraries and load GPT-2 model.

```python
# Import libraries
from IPython.display import clear_output
from nnsight import LanguageModel
import torch

clear_output()
```

```python
# Load GPT-2
model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
```


## GPT-2 Model Architecture

Let's take a look at GPT-2's architecture. The checkpoint used here (`openai-community/gpt2`, the smallest GPT-2) has 12 transformer blocks, accessible as `model.transformer.h`.
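
Each `GPT2Block` is pre-LayerNorm: it reads the residual stream through `ln_1`/`ln_2` and adds the outputs of `attn` and `mlp` back into it. The numpy sketch below illustrates only that bookkeeping; `toy_attn` and `toy_mlp` are hypothetical random stand-ins, not GPT-2's real sublayers. It checks that the final residual stream equals the embeddings plus every sublayer's contribution, which is why the same final head can decode any layer's output.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, seq_len, n_layers = 16, 5, 3  # toy sizes, not GPT-2's (768 hidden, 12 layers)

def layer_norm(x, eps=1e-5):
    # LayerNorm over the hidden dimension (affine weights omitted for brevity)
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

# Hypothetical sublayer weights: fixed random linear maps per layer
W_attn = [rng.normal(scale=0.1, size=(d_model, d_model)) for _ in range(n_layers)]
W_mlp = [rng.normal(scale=0.1, size=(d_model, d_model)) for _ in range(n_layers)]

def toy_attn(x, i):
    # Stand-in only: real attention also mixes information across positions
    return x @ W_attn[i]

def toy_mlp(x, i):
    return np.tanh(x @ W_mlp[i])

def block(h, i):
    """Pre-LN residual update with the same shape as GPT2Block's forward pass."""
    attn_out = toy_attn(layer_norm(h), i)  # role of ln_1 -> attn
    h = h + attn_out
    mlp_out = toy_mlp(layer_norm(h), i)    # role of ln_2 -> mlp
    h = h + mlp_out
    return h, attn_out, mlp_out

h0 = rng.normal(size=(seq_len, d_model))   # stand-in for the wte + wpe embeddings
h = h0
contributions = []
for i in range(n_layers):
    h, attn_out, mlp_out = block(h, i)
    contributions.extend([attn_out, mlp_out])

# The final residual stream is the embedding plus every sublayer's write
reconstructed = h0 + sum(contributions)
print(np.allclose(h, reconstructed))  # True
```

Under this view, a block's `attn` and `mlp` outputs are the updates written into the residual stream, while the block output is the full running state.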

```python
print(model)
```

```
GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  (generator): Generator(
    (streamer): Streamer()
  )
)
```


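## Apply Logit Lens

To apply logit lens, we collect the output of each transformer block (the residual stream after that layer), apply the model's final layer normalization (`model.transformer.ln_f`), and pass the result through the model's unembedding head (`model.lm_head`) to get logits. A softmax over the vocabulary then turns the logits into next-token probabilities. The code also decodes each block's attention (`layer.attn`) and MLP (`layer.mlp`) outputs, which are the updates those sublayers add to the residual stream; these are saved for further inspection but not plotted here.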

By comparing the output token probabilities across layers, logit lens shows how the model's prediction, and its confidence in that prediction, evolves through the processing steps.
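
The three steps (final layer norm, unembedding, softmax) can be sketched without any framework. This is a minimal numpy illustration under toy assumptions: random `hidden_states` stand in for the block outputs (`layer.output[0]`), `ln_f_gamma`/`ln_f_beta` for the parameters of `ln_f`, and `W_U` for the `lm_head` weight (in GPT-2 this weight is tied to the `wte` embedding). It also checks that lensing the last layer reproduces the "model output" computed from the last block alone.

```python
import numpy as np

rng = np.random.default_rng(0)
n_layers, seq_len, d_model, vocab = 4, 6, 16, 50  # toy sizes

# Hypothetical residual-stream states, one (seq_len, d_model) slice per layer
hidden_states = rng.normal(size=(n_layers, seq_len, d_model))

# Toy parameters standing in for ln_f and lm_head
ln_f_gamma = rng.normal(1.0, 0.1, size=d_model)
ln_f_beta = rng.normal(0.0, 0.1, size=d_model)
W_U = rng.normal(size=(vocab, d_model))  # lm_head weight: (vocab, d_model), no bias

def ln_f(x, eps=1e-5):
    # Final LayerNorm over the hidden dimension, with learned scale and shift
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return ln_f_gamma * (x - mu) / np.sqrt(var + eps) + ln_f_beta

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def logit_lens(h):
    """Decode residual states h of shape (..., d_model) into next-token probabilities."""
    logits = ln_f(h) @ W_U.T  # (..., vocab)
    return softmax(logits)

probs = logit_lens(hidden_states)   # (n_layers, seq_len, vocab)
max_probs = probs.max(axis=-1)      # top-1 probability per layer and position
tokens = probs.argmax(axis=-1)      # top-1 token id per layer and position

# The model's real output applies ln_f + lm_head to the last block only,
# so the lens at the final layer matches it exactly
model_probs = softmax(ln_f(hidden_states[-1]) @ W_U.T)
print(probs.shape, np.allclose(probs[-1], model_probs))  # (4, 6, 50) True
```

The `max_probs` and `tokens` arrays have the same (layers × positions) layout as the tensors the notebook plots.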

```python
# prompt = "The Eiffel Tower is in the city of"
prompt = "The development of artificial intelligence is transforming"
layers = model.transformer.h
probs_layers = []
probs_attn_layers = []
probs_mlp_layers = []

with model.trace() as tracer:
    with tracer.invoke(prompt) as invoker:
        for layer_idx, layer in enumerate(layers):
            # Block output is a tuple; [0] is the residual stream after this layer.
            # Decode it with the final layer norm and the LM head.
            layer_output = model.lm_head(model.transformer.ln_f(layer.output[0]))

            # Apply softmax to obtain probabilities and save the result
            probs = torch.nn.functional.softmax(layer_output, dim=-1).save()
            probs_layers.append(probs)

            # Attention output is also a tuple; [0] is what attention adds to the residual stream
            attn_output = model.lm_head(model.transformer.ln_f(layer.attn.output[0]))
            probs_attn = torch.nn.functional.softmax(attn_output, dim=-1).save()
            probs_attn_layers.append(probs_attn)

            # The MLP returns a plain tensor, not a tuple, so no [0] here
            # (indexing it would silently drop the batch dimension)
            mlp_output = model.lm_head(model.transformer.ln_f(layer.mlp.output))
            probs_mlp = torch.nn.functional.softmax(mlp_output, dim=-1).save()
            probs_mlp_layers.append(probs_mlp)

# Stack the saved results; each tensor has shape (n_layers, seq_len, vocab_size)
probs = torch.cat([p.value for p in probs_layers])
probs_attn = torch.cat([p.value for p in probs_attn_layers])
probs_mlp = torch.cat([p.value for p in probs_mlp_layers])

# Find the maximum probability and corresponding token for each layer and position
max_probs, tokens = probs.max(dim=-1)

# Decode token IDs to words for each layer (escaping newlines etc. for display)
words = [[model.tokenizer.decode(t.cpu()).encode("unicode_escape").decode() for t in layer_tokens]
    for layer_tokens in tokens]

# Access the 'input_ids' attribute of the invoker object to get the input words
input_words = [model.tokenizer.decode(t) for t in invoker.inputs[0][0]["input_ids"][0]]
```

```python
max_probs
```

```
tensor([[0.0250, 0.9599, 0.7711, 0.9886, 0.9404, 0.2956, 0.8330],
        [0.0475, 0.9310, 0.8037, 0.9710, 0.4738, 0.3557, 0.8909],
        [0.0261, 0.6634, 0.7631, 0.9512, 0.2080, 0.2755, 0.8197],
        [0.0264, 0.5483, 0.4849, 0.9454, 0.3511, 0.4051, 0.8269],
        [0.0261, 0.5325, 0.5001, 0.8414, 0.5714, 0.3141, 0.4073],
        [0.0257, 0.3901, 0.5443, 0.8381, 0.7562, 0.3330, 0.2830],
        [0.0254, 0.3366, 0.3511, 0.4094, 0.7762, 0.3024, 0.2383],
        [0.0253, 0.4538, 0.4067, 0.4395, 0.7568, 0.1571, 0.2547],
        [0.0253, 0.5116, 0.4044, 0.5962, 0.8444, 0.1531, 0.2256],
        [0.0253, 0.6232, 0.3969, 0.8155, 0.7149, 0.0891, 0.2765],
        [0.0255, 0.8269, 0.3530, 0.8848, 0.4488, 0.0465, 0.2403],
        [0.0136, 0.4350, 0.1499, 0.8029, 0.1830, 0.1099, 0.3741]],
       grad_fn=<MaxBackward0>)
```

## Visualizing GPT-2 Layer Interpretations

Now we will visualize GPT-2's predictions while it processes the prompt *`'The development of artificial intelligence is transforming'`* (uncomment the Eiffel Tower prompt in the cell above to try the alternative example). For each layer, the plot shows what that layer's output decodes to as the most likely next token at every input position.

```python
import plotly.express as px
import plotly.io as pio

if is_colab:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "plotly_mimetype+notebook_connected+notebook"

fig = px.imshow(
    max_probs.detach().cpu().numpy(),
    x=input_words,
    y=list(range(len(words))),
    color_continuous_scale=px.colors.diverging.RdYlBu_r,
    color_continuous_midpoint=0.50,
    text_auto=True,
    labels=dict(x="Input Tokens", y="Layers", color="Probability")
)

fig.update_layout(
    title='Logit Lens Visualization',
    xaxis_tickangle=0
)

fig.update_traces(text=words, texttemplate="%{text}")
fig.show()
```

The vertical axis indexes the layers, from 0 (first block) to 11 (last block). Each cell shows the top-1 next-token guess decoded from that layer's output at that input position, and its color shows the probability assigned to that guess.
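
Where plotly is unavailable, the same grid can be printed as text. This is a minimal sketch; `format_lens_grid` is a hypothetical helper, and the demo inputs are synthetic placeholders rather than GPT-2 outputs.

```python
def format_lens_grid(max_probs, words, input_words, width=12):
    """Render a logit-lens grid as text: one row per layer, one column per input token.

    Each cell shows the top-1 decoded token (truncated) and its probability.
    """
    lines = ["layer".ljust(6) + "".join(w[:width - 1].ljust(width) for w in input_words)]
    for layer_idx, (row_probs, row_words) in enumerate(zip(max_probs, words)):
        cells = "".join(
            f"{w[:width - 6]}:{float(p):.2f}".ljust(width)
            for w, p in zip(row_words, row_probs)
        )
        lines.append(str(layer_idx).ljust(6) + cells)
    return "\n".join(lines)

# Synthetic placeholder inputs (not real model output)
demo_probs = [[0.10, 0.90], [0.25, 0.60]]
demo_words = [["A", "B"], ["C", "D"]]
grid = format_lens_grid(demo_probs, demo_words, ["x", "y"])
print(grid)
```

In the notebook this would be called as `format_lens_grid(max_probs.detach().cpu().numpy(), words, input_words)`.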

```python
# Alternative (disabled): plot the full vocabulary distribution per layer.
# Note: this draws one heatmap column per vocabulary token (50,257 for GPT-2)
# for every layer, so it is slow and memory-heavy.

# import plotly.express as px
# import plotly.io as pio
# import torch
# from plotly.subplots import make_subplots
# import plotly.graph_objects as go

# if is_colab:
#     pio.renderers.default = "colab"
# else:
#     pio.renderers.default = "plotly_mimetype+notebook_connected+notebook"

# num_layers = len(probs_layers)
# num_input_words = len(input_words)

# fig = make_subplots(rows=num_layers, cols=1, subplot_titles=[f"Layer {i}" for i in range(num_layers)])

# for layer_idx, layer_probs in enumerate(probs_layers):
#     layer_probs = layer_probs.value.detach().cpu().numpy()  # Use .value to get the tensor's value

#     # Get the token ids
#     all_tokens = list(range(layer_probs.shape[-1]))

#     # Reshape the layer probabilities to have input words as rows and token ids as columns.
#     layer_probs_reshaped = layer_probs.transpose(1,0,2).reshape(num_input_words, -1)

#     fig.add_trace(go.Heatmap(z=layer_probs_reshaped,
#                             x=[model.tokenizer.decode(t).encode("unicode_escape").decode() for t in all_tokens],
#                             y=input_words,
#                             colorscale=px.colors.diverging.RdYlBu_r,
#                             zmid=0.5,
#                             ),
#                   row=layer_idx+1, col=1)

# fig.update_layout(height=300*num_layers, title_text="Logit Lens Visualization - All Tokens per Layer",
#                    xaxis_tickangle=0)
# fig.update_xaxes(tickangle=45)
# fig.show()
```