<a href="https://colab.research.google.com/github/tatsath/Interpretability/blob/main/logit_lens_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Logit Lens

## Introduction

🔍 Logit Lens is a technique that gives us a simplified, yet insightful, view of the inner workings of transformer models.

We can estimate the model's guess for the output after each computational step by decoding each layer's output as if it were the final one: we apply the model's final layer normalization and unembedding head, then a softmax. Unlike approaches that focus on *how* beliefs are updated within a step, Logit Lens gives us a glimpse into *what* output the model is predicting at each processing step.
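
As a compact summary, the idea can be written in symbols. The notation here is introduced for illustration only and is not taken from the original post: $h^{(\ell)}$ is the activation after layer $\ell$, $\mathrm{LN}_f$ is the model's final layer normalization, and $W_U$ is the unembedding (language-model head) matrix.

$$
p^{(\ell)} = \mathrm{softmax}\big(W_U \, \mathrm{LN}_f(h^{(\ell)})\big)
$$

For the last layer, $p^{(\ell)}$ coincides with the model's actual next-token distribution. For earlier layers it is only an estimate of what the model is "leaning toward" at that depth.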

📗 Read more about Logit Lens in nostalgebraist’s blog post on LessWrong, [here](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens).

💻 You can find a Colab version of our tutorial [here](https://colab.research.google.com/github/ndif-team/nnsight/blob/main/docs/source/notebooks/tutorials/logit_lens.ipynb), or nostalgebraist’s original code [here](https://colab.research.google.com/drive/1-nOE-Qyia3ElM17qrdoHAtGmLCPUZijg?usp=sharing).

## Setup

If using Colab, install NNsight:
```
!pip install -U nnsight
```

In [62]:
try:
    import google.colab
    is_colab = True
except ImportError:
    is_colab = False

if is_colab:
    !pip install -U nnsight



Import libraries and load GPT-2 model.

In [63]:
# Import libraries
from IPython.display import clear_output
from nnsight import LanguageModel
from typing import List, Callable
import torch
import numpy as np

# Clear any installation or import output from the cell
clear_output()

In [64]:
# Load gpt2
model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)

## GPT-2 Model Architecture

Let's take a look at GPT-2's architecture. GPT-2 has 12 layers, accessed as `model.transformer.h`.

In [65]:
# Print the module tree to see the names used below (transformer.h, ln_f, lm_head)
print(model)

## Apply Logit Lens



To apply logit lens, we collect activations at each layer's output, apply layer normalization (`model.transformer.ln_f`), and then process through the model's head (`model.lm_head`) to get the logits. Next, we apply the softmax to the logits to obtain output token probabilities.

By observing different layers' output token probabilities, logit lens provides insights into the model's confidence throughout processing steps.
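
To make the arithmetic of these three steps concrete, here is a minimal framework-free sketch. It is not the NNsight code used in this tutorial: the helper names (`layer_norm`, `unembed`, `softmax`, `logit_lens`), the four-token vocabulary, and all numbers are made up for illustration, and the layer normalization omits the learned scale and shift that GPT-2's `ln_f` has.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale/shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def unembed(x, weight):
    """Project a hidden vector onto the vocabulary: one logit per row of `weight`."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]

def softmax(logits):
    """Convert logits to probabilities, subtracting the max for numerical stability."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def logit_lens(hidden, weight):
    """Decode an intermediate hidden state as if it were the final one."""
    return softmax(unembed(layer_norm(hidden), weight))

# Toy values (hypothetical): hidden size 3, vocabulary of 4 tokens
toy_vocab = ["rise", "fall", "the", "of"]
toy_unembedding = [
    [1.0, 0.0, -1.0],
    [-1.0, 0.0, 1.0],
    [0.0, 1.0, 0.0],
    [0.5, 0.5, 0.5],
]
toy_hidden_states = [
    [0.1, 0.4, 0.2],   # stands in for an early layer
    [0.9, 0.1, -0.4],  # stands in for a later layer
]

for layer_idx, hidden in enumerate(toy_hidden_states):
    probs = logit_lens(hidden, toy_unembedding)
    best = max(range(len(probs)), key=probs.__getitem__)
    print(f"layer {layer_idx}: top guess = {toy_vocab[best]!r}, p = {probs[best]:.3f}")
```

The same unembedding matrix is reused for every layer; only the hidden state changes. That is what lets us compare guesses across depth.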

In [72]:
#prompt= "The Eiffel Tower is in the city of"
prompt= "With good earnings the stock price of company will likely"
#prompt= "With the Federal Reserve raising interest rates, the bond market is expected to"
layers = model.transformer.h
probs_layers = []
probs_attn_layers = []
probs_mlp_layers = []

with model.trace() as tracer:
    with tracer.invoke(prompt) as invoker:
        for layer_idx, layer in enumerate(layers):
            # Apply the final layer norm, then the model's head, to this layer's output
            layer_output = model.lm_head(model.transformer.ln_f(layer.output[0]))

            # Apply softmax to obtain probabilities and save the result
            probs = torch.nn.functional.softmax(layer_output, dim=-1).save()
            probs_layers.append(probs)

            # Get attention output and process
            attn_output = model.lm_head(model.transformer.ln_f(layer.attn.output[0]))
            probs_attn = torch.nn.functional.softmax(attn_output, dim=-1).save()
            probs_attn_layers.append(probs_attn)

            # Get MLP output and process (the MLP returns a plain tensor, not a tuple, so no [0])
            mlp_output = model.lm_head(model.transformer.ln_f(layer.mlp.output))
            probs_mlp = torch.nn.functional.softmax(mlp_output, dim=-1).save()
            probs_mlp_layers.append(probs_mlp)

probs = torch.cat([probs.value for probs in probs_layers])
probs_attn = torch.cat([probs.value for probs in probs_attn_layers])
probs_mlp = torch.cat([probs.value for probs in probs_mlp_layers])

# Find the maximum probability and corresponding tokens for each position
max_probs, tokens = probs.max(dim=-1)

# Decode token IDs to words for each layer
words = [[model.tokenizer.decode(t.cpu()).encode("unicode_escape").decode() for t in layer_tokens]
    for layer_tokens in tokens]

# Access the 'input_ids' attribute of the invoker object to get the input words
input_words = [model.tokenizer.decode(t) for t in invoker.inputs[0][0]["input_ids"][0]]


## Visualizing GPT-2 Layer Interpretations

Now we will visualize GPT-2's predictions layer by layer, showing what each GPT2Block's output suggests as the next token for every input token.



In [75]:
import plotly.express as px
import plotly.io as pio

if is_colab:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "plotly_mimetype+notebook_connected+notebook"

fig = px.imshow(
    max_probs.detach().cpu().numpy(),
    x=input_words,
    y=list(range(len(words))),
    color_continuous_scale=px.colors.diverging.RdYlBu_r,
    color_continuous_midpoint=0.50,
    text_auto=True,
    labels=dict(x="Input Tokens", y="Layers", color="Probability")
)

fig.update_layout(
    title='Logit Lens Visualization',
    xaxis_tickangle=0
)

fig.update_traces(text=words, texttemplate="%{text}")
fig.show()

The vertical axis indexes the layers, zero-indexed from 0 to 11. The top guess for each token, according to the model’s activations at a given layer, is printed in each cell. The colors show the probability associated with the top guess.
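
One way to read such a grid is to ask, for each input position, from which layer onward the top guess stops changing. The sketch below illustrates that on a made-up grid shaped like `words` (one row per layer, one column per input token). `first_stable_layer` is a hypothetical helper written for this illustration and the toy tokens are not real model output.

```python
def first_stable_layer(column):
    """Return the earliest layer index from which the top guess never changes again."""
    final_guess = column[-1]
    stable_from = len(column) - 1
    # Walk backwards from the last layer while the guess still matches the final one
    for layer_idx in range(len(column) - 1, -1, -1):
        if column[layer_idx] != final_guess:
            break
        stable_from = layer_idx
    return stable_from

# Toy grid (made-up values): 4 layers, 2 input positions
toy_words = [
    [" the", ","],
    [" the", " rise"],
    [" be", " rise"],
    [" be", " rise"],
]

for pos in range(len(toy_words[0])):
    column = [row[pos] for row in toy_words]
    print(f"position {pos}: final guess {column[-1]!r}, stable from layer {first_stable_layer(column)}")
```

With the real data, the same function could be applied to each column of `words`. A low index would suggest the prediction was settled early in the network.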

In [76]:
# import plotly.express as px
# import plotly.io as pio
# import torch
# from plotly.subplots import make_subplots
# import plotly.graph_objects as go

# if is_colab:
#     pio.renderers.default = "colab"
# else:
#     pio.renderers.default = "plotly_mimetype+notebook_connected+notebook"

# num_layers = len(probs_layers)
# num_input_words = len(input_words)

# fig = make_subplots(rows=num_layers, cols=1, subplot_titles=[f"Layer {i}" for i in range(num_layers)])


# for layer_idx, layer_probs in enumerate(probs_layers):
#     layer_probs = layer_probs.value.detach().cpu().numpy() # Use .value to get the tensor's value

#     # Get the token ids
#     all_tokens = list(range(layer_probs.shape[-1]))

#     # Reshape the layer probabilities to have input words as rows and token ids as columns.
#     layer_probs_reshaped = layer_probs.transpose(1,0,2).reshape(num_input_words, -1)


#     fig.add_trace(go.Heatmap(z=layer_probs_reshaped,
#                             x=[model.tokenizer.decode(t).encode("unicode_escape").decode() for t in all_tokens],
#                             y=input_words,
#                             colorscale=px.colors.diverging.RdYlBu_r,
#                             zmid=0.5,
#                             ),
#                   row=layer_idx+1, col=1)

# fig.update_layout(height=300*num_layers, title_text="Logit Lens Visualization - All Tokens per Layer",
#                    xaxis_tickangle=0)
# fig.update_xaxes(tickangle=45)
# fig.show()

In [83]:
# Define the financial prompt
prompt= "With good earnings the stock price of company will likely"
# define the words for which we will compare the logits
logit_words = ["rise", "fall"]
layers = model.transformer.h
probs_layers = []
logits_layers = []

with model.trace() as tracer:
    with tracer.invoke(prompt) as invoker:
        for layer_idx, layer in enumerate(layers):
            # Process layer output through the model's head and layer normalization
            layer_output = model.lm_head(model.transformer.ln_f(layer.output[0]))

            # Apply softmax to obtain probabilities and save the result
            probs = torch.nn.functional.softmax(layer_output, dim=-1).save()
            probs_layers.append(probs)
            logits_layers.append(layer_output.save())


probs = torch.cat([probs.value for probs in probs_layers])
logits = torch.cat([logits.value for logits in logits_layers])

# Find the maximum probability and corresponding tokens for each position
max_probs, tokens = probs.max(dim=-1)


# Decode token IDs to words for each layer
words = [[model.tokenizer.decode(t.cpu()).encode("unicode_escape").decode() for t in layer_tokens]
    for layer_tokens in tokens]

# Access the 'input_ids' attribute of the invoker object to get the input words
input_words = [model.tokenizer.decode(t) for t in invoker.inputs[0][0]["input_ids"][0]]

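# Get logits for the specified words. GPT-2's tokenizer treats a leading space as part of
# the token, and the word that follows the prompt starts with a space, so we encode
# " rise" and " fall" rather than "rise" and "fall".
logit_tokens = [model.tokenizer.encode(" " + word)[0] for word in logit_words]
logit_values = logits[:,:,logit_tokens]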

# Calculate logit difference between the two words
logit_diff = logit_values[:,:,0] - logit_values[:,:,1]

# Average the logit difference across all input tokens for each layer
avg_logit_diff_per_layer = torch.mean(logit_diff, dim=1)
# %%
import plotly.graph_objects as go
import plotly.io as pio
if is_colab:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "plotly_mimetype+notebook_connected+notebook"

fig = go.Figure(data=[go.Bar(x=list(range(len(avg_logit_diff_per_layer))), y=avg_logit_diff_per_layer.detach().cpu().numpy())])

fig.update_layout(
    title="Average Logit Difference Between 'rise' and 'fall' Per Layer",
    xaxis_title="Layer",
    yaxis_title="Average Logit Difference (rise - fall)",
)

fig.show()
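
A note on reading the bar chart: the difference between two logits equals the log of the ratio of the two tokens' softmax probabilities, because the softmax normalizer cancels. So a positive bar means that layer assigns "rise" a higher probability than "fall", and the bar's height is the log-odds between them. The sketch below checks this identity on made-up logits. The three-token vocabulary and the numbers are hypothetical.

```python
import math

def softmax(logits):
    """Convert logits to probabilities, subtracting the max for numerical stability."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for three tokens: "rise", "fall", and some other token
toy_logits = [2.0, 0.5, -1.0]
toy_probs = softmax(toy_logits)

logit_diff = toy_logits[0] - toy_logits[1]          # rise - fall, as in the plot
log_odds = math.log(toy_probs[0] / toy_probs[1])    # log(p_rise / p_fall)

print(f"logit difference: {logit_diff:.4f}")
print(f"log probability ratio: {log_odds:.4f}")
```

Both printed values agree up to floating-point error, whatever the logits of the remaining tokens are.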