In [None]:
# Runtime flags; together they choose the plotly renderer in the next cell.
IN_COLAB, DEVELOPMENT_MODE = False, False

In [None]:
import plotly.io as pio
# Use the "colab" renderer unless we are in local development outside Colab.
pio.renderers.default = (
    "colab" if (IN_COLAB or not DEVELOPMENT_MODE) else "notebook_connected"
)
print(f"Using renderer: {pio.renderers.default}")

Using renderer: colab


In [None]:
import circuitsvis as cv
# Smoke test: render circuitsvis's `examples.hello` widget to confirm the
# library displays in this notebook environment.
cv.examples.hello("Neel")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [None]:
# Turn autograd off globally: this notebook only runs forward passes and
# inspects/patches activations, so no computation graphs need to be built.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled>

In [None]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` as a diverging RdBu heatmap whose colour scale is centred on 0."""
    fig = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x": xaxis, "y": yaxis},
        **kwargs,
    )
    fig.show(renderer)


def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Draw `tensor` as a plotly line chart with labelled axes."""
    fig = px.line(utils.to_numpy(tensor), labels={"x": xaxis, "y": yaxis}, **kwargs)
    fig.show(renderer)


def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Scatter-plot `y` against `x`; `caxis` labels the colour dimension if one is passed."""
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    fig = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    fig.show(renderer)

In [None]:
# Prefer the GPU whenever torch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Load the pretrained GPT-2 Small weights into TransformerLens's HookedTransformer on `device`.
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


In [None]:
model_description_text = """## Loading Models

HookedTransformer comes loaded with >40 open source GPT-style models. You can load any of them in with `HookedTransformer.from_pretrained(MODEL_NAME)`. See my explainer for documentation of all supported models, and this table for hyper-parameters and the name used to load them. Each model is loaded into the consistent HookedTransformer architecture, designed to be clean, consistent and interpretability-friendly. 

For this demo notebook we'll look at GPT-2 Small, an 80M parameter model. To try the model the model out, let's find the loss on this paragraph!"""
# Ask the model for its language-modelling loss on the paragraph above
# (return_type="loss" returns the loss instead of logits).
loss = model(input=model_description_text, return_type="loss")
print("Model loss:", loss)

Model loss: tensor(4.1758)


In [None]:
gpt2_text = "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets."
gpt2_tokens = model.to_tokens(gpt2_text)
print(gpt2_tokens.device)
# Run once and cache every activation. remove_batch_dim=True drops the leading batch
# axis from the cached tensors (see the 3-D pattern shape printed in the next cell).
gpt2_logits, gpt2_cache = model.run_with_cache(gpt2_tokens, remove_batch_dim=True)

cpu


In [None]:
print(type(gpt2_cache))
# Layer-0 attention patterns, indexed (head, dest_pos, source_pos) because the batch
# axis was removed; printed as [12, 33, 33] for this 33-token input.
attention_pattern = gpt2_cache["pattern", 0, "attn"]
print(attention_pattern.shape)
# String tokens used as axis labels by the visualization below.
gpt2_str_tokens = model.to_str_tokens(gpt2_text)

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([12, 33, 33])


In [None]:
print("Layer 0 Head Attention Patterns:")
# Interactive per-head view of the layer-0 attention patterns over the prompt tokens.
cv.attention.attention_patterns(tokens=gpt2_str_tokens, attention=attention_pattern)

Layer 0 Head Attention Patterns:


### Hooks: Intervening on Activations

##### Example 1

In [None]:
text_ids = model.to_tokens("Persistence is all you need.")

In [None]:
# Layer whose attention value vectors are ablated in this example.
layer_to_ablate = 0

In [None]:
# Head (within `layer_to_ablate`) whose value vectors get zeroed by head_ablation_hook.
head_index_to_ablate = 8

In [None]:
def head_ablation_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint,
    head_index: Optional[int] = None,
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Zero-ablate a single attention head's value vectors.

    Args:
        value: value activations at a ``v`` hook point, shaped
            (batch, pos, head_index, d_head) per the annotation. Modified in place.
        hook: the HookPoint this function is attached to (not used here).
        head_index: head to ablate. Defaults to the notebook-level
            ``head_index_to_ablate``, so existing ``fwd_hooks`` entries keep working.
            Pass it explicitly (e.g. with functools.partial) to ablate a different
            head without mutating global state.

    Returns:
        ``value``, with every entry of the chosen head set to 0.
    """
    if head_index is None:
        head_index = head_index_to_ablate
    print(f"Shape of the value tensor: {value.shape}")
    value[:, :, head_index, :] = 0.
    return value

In [None]:
original_loss = model(input=text_ids, return_type="loss")

In [None]:
original_loss

tensor(3.6956)

In [None]:
# Re-run the same tokens with head_ablation_hook attached to layer `layer_to_ablate`'s
# value activations, then compare the resulting loss with `original_loss`.
ablated_loss = model.run_with_hooks(
    input=text_ids,
    return_type="loss",
    fwd_hooks=[(
        utils.get_act_name("v", layer_to_ablate),
        head_ablation_hook
    )]
)

Shape of the value tensor: torch.Size([1, 8, 12, 64])


In [None]:
ablated_loss

tensor(4.0071)

##### Example 2

In [None]:
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"

In [None]:
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

In [None]:
# In the clean prompt Mary is the repeated name, so the indirect object " John" is the
# correct completion and " Mary" is the distractor.
correct_answer = " John"
incorrect_answer = " Mary"

In [None]:
corrupted_tokens = model.to_tokens(corrupted_prompt)

In [None]:
corrupted_prompt

'After John and Mary went to the store, John gave a bottle of milk to'

In [None]:
clean_prompt

'After John and Mary went to the store, Mary gave a bottle of milk to'

In [None]:
correct_answer, incorrect_answer

(' John', ' Mary')

Next, find the difference between the logits of `correct_answer` and `incorrect_answer` at the final position of the prompt.

In [None]:
clean_tokens = model.to_tokens(clean_prompt)

In [None]:
def logits_to_logit_diff(logits, correct_answer=" John", incorrect_answer=" Mary"):
    """Return logit(correct_answer) - logit(incorrect_answer) at the last position.

    Only batch item 0 is used. Each answer string is converted with the global
    `model.to_single_token`, so it must map to exactly one token.
    """
    final_position_logits = logits[0, -1]
    correct_logit = final_position_logits[model.to_single_token(correct_answer)]
    incorrect_logit = final_position_logits[model.to_single_token(incorrect_answer)]
    return correct_logit - incorrect_logit

In [None]:
clean_logits, clean_cache = model.run_with_cache(clean_tokens)

In [None]:
logits_to_logit_diff(clean_logits, correct_answer=" John", incorrect_answer=" Mary")

tensor(4.2765)

In [None]:
corrupted_logits = model(corrupted_tokens)

In [None]:
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits, correct_answer=" John", incorrect_answer=" Mary")

In [None]:
corrupted_logit_diff

tensor(-2.7376)

In [None]:
# Clean run: keep the full activation cache, because the patching hooks copy from it.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logits_to_logit_diff(clean_logits)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

# Corrupted run: only its logits are needed, so a plain forward pass suffices.
corrupted_logits = model(corrupted_tokens)
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits)
print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")

Clean logit difference: 4.276
Corrupted logit difference: -2.738


In [None]:
def residual_stream_patching_hook(
    resid_pre: Float[torch.Tensor, "batch seq_len d_model"],
    hook: HookPoint,
    position: int,
    cache: Optional[ActivationCache] = None,
) -> Float[torch.Tensor, "batch seq_len d_model"]:
    """Replace the residual stream at one position with the clean run's value.

    Args:
        resid_pre: residual-stream activation at this hook point; modified in place.
        hook: the HookPoint being run. Its ``name`` selects the matching activation
            from the cache.
        position: sequence index to patch.
        cache: activations to patch from. Defaults to the notebook-level
            ``clean_cache``, so existing ``partial(..., position=...)`` callers are
            unaffected.

    Returns:
        ``resid_pre`` with ``position`` overwritten by the cached value.

    NOTE: assumes the cached run and the current run have the same batch size and
    sequence length (true for the clean/corrupted IOI prompts here).
    """
    source_cache = clean_cache if cache is None else cache
    clean_resid_pre = source_cache[hook.name]
    resid_pre[:, position, :] = clean_resid_pre[:, position, :]
    return resid_pre

In [None]:
# Number of token positions in the clean prompt (batch item 0).
# NOTE(review): not used in the cells shown; the patching cell recomputes the same value as `num_positions`.
n_positions = len(clean_tokens[0])

In [None]:
# Activation patching on the residual stream. For every (layer, position) we run the
# corrupted prompt while residual_stream_patching_hook (defined earlier) copies in the
# clean run's resid_pre, i.e. the residual stream at the start of that layer, at that
# one position.
# NOTE(review): assumes clean_tokens and corrupted_tokens have the same length.

# We make a tensor to store the results for each patching run. We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.
num_positions = len(clean_tokens[0])
ioi_patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)

for layer in tqdm.tqdm(range(model.cfg.n_layers)):
    for position in range(num_positions):
        # Use functools.partial to create a temporary hook function with the position fixed
        temp_hook_fn = partial(residual_stream_patching_hook, position=position)
        # Run the model with the patching hook
        patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[
            (utils.get_act_name("resid_pre", layer), temp_hook_fn)
        ])
        # Calculate the logit difference
        patched_logit_diff = logits_to_logit_diff(patched_logits).detach()
        # Store the result, normalizing by the clean and corrupted logit difference so it's between 0 and 1 (ish):
        # 0 means the patch left the corrupted logit difference unchanged, 1 means it fully restored the clean one.
        ioi_patching_result[layer, position] = (patched_logit_diff - corrupted_logit_diff)/(clean_logit_diff - corrupted_logit_diff)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


  0%|          | 0/12 [00:00<?, ?it/s]

In [None]:
model.cfg.n_layers

12

In [None]:
# The heatmap is drawn with plotly (through the imshow helper), so no matplotlib magic is needed here.
# Add the index to the end of the label, because plotly doesn't like duplicate labels
token_labels = [f"{token}_{index}" for index, token in enumerate(model.to_str_tokens(clean_tokens))]
imshow(ioi_patching_result, x=token_labels, xaxis="Position", yaxis="Layer", title="Normalized Logit Difference After Patching Residual Stream on the IOI Task")

### Hooks: Accessing Activations


##### Example 1

In [None]:
# Shape of the random token batch: `batch_size` sequences, each `seq_len` tokens long before repetition.
batch_size = 10
seq_len = 50

In [None]:
# Seed torch's RNG so the random sequences (and every loss/induction score derived
# from them) are reproducible when the notebook is re-run.
torch.manual_seed(0)
# Token ids drawn uniformly from [1000, 10000).
random_tokens = torch.randint(1000, 10000, (batch_size, seq_len)).to(model.cfg.device)

In [None]:
# Tile each random sequence twice along the sequence axis: [r_0 .. r_{n-1}, r_0 .. r_{n-1}].
repeated_tokens = einops.repeat(random_tokens, "batch_size seq_len -> batch_size (2 seq_len)")

In [None]:
repeated_logits = model(repeated_tokens)

In [None]:
# Per-token loss on the repeated sequences.
# NOTE(review): despite its name this holds losses (presumably negative log-probs of
# the correct next token), not log-probs; the next cell treats it as a loss.
correct_log_probs = model.loss_fn(repeated_logits, repeated_tokens, per_token=True)

In [None]:
# Average the per-token loss over the batch, giving one value per position. If the model
# has induction heads, loss is expected to drop sharply in the repeated second half.
loss_by_position = einops.reduce(correct_log_probs, "batch position -> position", "mean")

In [None]:
line(loss_by_position, xaxis="Position", yaxis="Loss", title="Loss by position on random repeated tokens")


In [None]:
def induction_score_hook(
    acts: Float[torch.Tensor, "batch head_index dest_pos source_pos"],
    hook: HookPoint
):
    """Exploratory probe: print the name of each hook point this is attached to.

    Used only to check which activations the name filter below matches.
    NOTE: this name is redefined further down by the real induction-score hook.
    """
    print(hook.name)

In [None]:
def pattern_hook_names_filter(name):
    """Return True for hook names that end in "pattern" (the attention-pattern hooks)."""
    return name.endswith("pattern")

In [None]:
# Attach the printing probe to every attention-pattern hook to list their names;
# return_type=None skips computing logits because only the hook's output matters.
model.run_with_hooks(
    repeated_tokens,
    return_type=None,
    fwd_hooks=[(pattern_hook_names_filter, induction_score_hook)]
)

blocks.0.attn.hook_pattern
blocks.1.attn.hook_pattern
blocks.2.attn.hook_pattern
blocks.3.attn.hook_pattern
blocks.4.attn.hook_pattern
blocks.5.attn.hook_pattern
blocks.6.attn.hook_pattern
blocks.7.attn.hook_pattern
blocks.8.attn.hook_pattern
blocks.9.attn.hook_pattern
blocks.10.attn.hook_pattern
blocks.11.attn.hook_pattern


In [None]:
from IPython.core.debugger import set_trace

In [None]:
# (n_layers, n_heads) table of induction scores; each row is filled by the hook below.
induction_score_store = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)
def induction_score_hook(
    pattern: Float[torch.Tensor, "batch head_index dest_pos source_pos"],
    hook: HookPoint,
):
    """Store each head's mean attention along the induction stripe.

    The input is `seq_len` random tokens repeated twice. For a destination position
    p >= seq_len, the ideal induction source is p - seq_len + 1: the token right
    after the previous occurrence of the current token. Writes one score per head
    into row `hook.layer()` of `induction_score_store`.
    """
    # Diagonal of attention from each destination to the source seq_len-1 tokens back.
    # The diagonal starts at dest = seq_len-1 (source 0), whose token has no earlier
    # copy, so drop that first entry and keep only dest >= seq_len.
    induction_stripe = pattern.diagonal(dim1=-2, dim2=-1, offset=1-seq_len)[..., 1:]
    # Get an average score per head (mean over batch and stripe positions)
    induction_score = einops.reduce(induction_stripe, "batch head_index position -> head_index", "mean")
    # Store the result.
    induction_score_store[hook.layer(), :] = induction_score

# We make a boolean filter on activation names, that's true only on attention pattern names.
pattern_hook_names_filter = lambda name: name.endswith("pattern")

model.run_with_hooks(
    repeated_tokens,
    return_type=None, # For efficiency, we don't need to calculate the logits
    fwd_hooks=[(
        pattern_hook_names_filter,
        induction_score_hook
    )]
)

imshow(induction_score_store, xaxis="Head", yaxis="Layer", title="Induction Score by Head")

In [None]:
induction_score_store.shape

torch.Size([12, 12])

In [None]:
# Layer/head to visualize. NOTE(review): assumed to be an induction head; confirm
# against the induction-score heatmap above.
induction_head_layer = 5
induction_head_index = 5
# One random 20-token sequence, then the same tokens repeated back-to-back.
single_random_sequence = torch.randint(1000, 10000, (1, 20)).to(model.cfg.device)
repeated_random_sequence = einops.repeat(single_random_sequence, "batch seq_len -> batch (2 seq_len)")


def visualize_pattern_hook(
    pattern: Float[torch.Tensor, "batch head_index dest_pos source_pos"],
    hook: HookPoint,
):
    """Display the chosen head's attention pattern for the first (only) batch item."""
    head_pattern = pattern[0, induction_head_index]  # (dest_pos, source_pos)
    # CircuitsVis expects a 3-D stack of patterns, so add a leading axis of size 1.
    display(
        cv.attention.attention_patterns(
            tokens=model.to_str_tokens(repeated_random_sequence),
            attention=head_pattern.unsqueeze(0),
        )
    )


pattern_hook_name = None  # placeholder removed below
del pattern_hook_name
model.run_with_hooks(
    repeated_random_sequence,
    return_type=None,
    fwd_hooks=[(utils.get_act_name("pattern", induction_head_layer), visualize_pattern_hook)],
)

### Layer Norm

The last layer of the transformer is the unembedding, `model.unembed`. Below we inspect its bias `b_U`, sorted from largest to smallest.

In [None]:
# Bias of the unembedding layer (one entry per vocabulary token).
unembed_bias = model.unembed.b_U

In [None]:
# Sort the unembedding bias from largest to smallest; the indices are the token ids.
# NOTE(review): `bias_indicies` is a typo for `bias_indices`; left as-is in case later cells use it.
bias_values, bias_indicies = unembed_bias.sort(descending=True)

In [None]:
bias_values

tensor([ 7.0297,  6.9815,  6.6844,  ..., -3.8378, -3.8381, -3.8446])

In [None]:
model.get_token_position(" cat", "The cat sat on the mat")

2

### Gotcha: prepend_bos


In [None]:
prompt = "Claire and Mary went to the shops, then Mary gave a bottle of milk to"

In [None]:
ioi_logits_with_bos = model(prompt, prepend_bos=True)

In [None]:
mary_logit_with_bos = ioi_logits_with_bos[0, -1, model.to_single_token(" Mary")].item()

In [None]:
mary_logit_with_bos

12.493572235107422

In [None]:
# Compute Claire's logit here so this cell does not depend on the later cell that also
# defines it (otherwise it raises NameError on a fresh kernel).
claire_logit_with_bos = ioi_logits_with_bos[0, -1, model.to_single_token(" Claire")].item()
print(f"Logit difference with BOS: {(claire_logit_with_bos - mary_logit_with_bos):.3f}")

In [None]:
claire_logit_with_bos = ioi_logits_with_bos[0, -1, model.to_single_token(" Claire")].item()

In [None]:
claire_logit_with_bos

In [None]:
# Same IOI prompt as the with-BOS run (reusing `prompt` keeps the two runs from drifting
# apart), but without prepending the BOS token.
ioi_logits_without_bos = model(prompt, prepend_bos=False)

In [None]:
mary_logit_without_bos = ioi_logits_without_bos[0, -1, model.to_single_token(" Mary")].item()
claire_logit_without_bos = ioi_logits_without_bos[0, -1, model.to_single_token(" Claire")].item()

In [None]:
print(f"Logit difference without BOS: {(claire_logit_without_bos - mary_logit_without_bos):.3f}")

In [None]:
model.to_str_tokens(" Claire", prepend_bos=False)

In [None]:
model.to_str_tokens('Claire', prepend_bos=False)

### Hook Points


In [None]:
from transformer_lens.hook_points import HookedRootModule, HookPoint

In [None]:
class SquareThenAdd(nn.Module):
    """Toy module computing ``offset + x * x``, with a HookPoint on the squared value.

    Args:
        offset: initial value of the learnable offset. It is stored in torch's default
            floating dtype, so integer offsets such as ``SquareThenAdd(3)`` work;
            nn.Parameter rejects integer tensors.
    """

    def __init__(self, offset):
        super().__init__()
        self.offset = nn.Parameter(torch.tensor(offset, dtype=torch.get_default_dtype()))
        self.hook_square = HookPoint()

    def forward(self, x):
        # The hook_square doesn't change the value, but lets us access it
        square = self.hook_square(x * x)
        return self.offset + square


# Cells below construct this module as ``SquareAndAdd``; keep that name resolvable.
SquareAndAdd = SquareThenAdd

In [None]:
# Instantiate the toy hooked module (the class defined above is SquareThenAdd).
# NOTE: this rebinds `model`, replacing the GPT-2 HookedTransformer loaded earlier;
# re-run the loading cell before using GPT-2 again.
model = SquareThenAdd(10.)

In [None]:
model

In [None]:
class TwoLayerModel(HookedRootModule):
    """Toy two-layer hooked model: x -> 3 + x**2 -> -4 + (3 + x**2)**2.

    The input, the intermediate value and the output are each wrapped in a
    HookPoint (hook_in, hook_mid, hook_out), and each layer also exposes its own
    hook_square.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = SquareThenAdd(3.0)
        self.layer2 = SquareThenAdd(-4.0)
        self.hook_in = HookPoint()
        self.hook_mid = HookPoint()
        self.hook_out = HookPoint()
        # Must run after every HookPoint is assigned: per the TransformerLens docs,
        # setup() indexes the hook points so run_with_cache/run_with_hooks can find them.
        super().setup()

    def forward(self, x):
        x_in = self.hook_in(x)
        x_mid = self.hook_mid(self.layer1(x_in))
        x_out = self.hook_out(self.layer2(x_mid))
        return x_out

In [None]:
model = TwoLayerModel()

In [None]:
out, act = model.run_with_cache(torch.tensor(5.0))

In [None]:
act