In [None]:
from IPython import get_ipython

ipython = get_ipython()
# Automatically reload edited modules (e.g. the HookedTransformer code) as
# they are edited, without restarting the kernel.
# `get_ipython()` returns None when this runs outside an IPython session, so
# only configure autoreload when a shell is actually available.
if ipython is not None:
    # `ipython.magic("...")` is deprecated; `run_line_magic(name, args)` is the
    # supported way to invoke a line magic programmatically.
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [None]:
def imshow(tensor, renderer=None, **kwargs):
    """Display `tensor` as a heatmap on an "RdBu" colour scale centred on 0.

    `tensor` is converted with `utils.to_numpy` first. Any extra keyword
    arguments are forwarded to `px.imshow`, and `renderer` is passed to the
    figure's `.show()` call. Nothing is returned.
    """
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, **kwargs):
    """Display a line plot whose y values are `tensor` (converted to numpy).

    Extra keyword arguments are forwarded to `px.line`; `renderer` is passed
    to the figure's `.show()` call. Nothing is returned.
    """
    y_values = utils.to_numpy(tensor)
    figure = px.line(y=y_values, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Display a scatter plot of `y` against `x` (both converted to numpy).

    `xaxis`, `yaxis` and `caxis` are the display labels for the x axis, the
    y axis and the colour axis. Extra keyword arguments are forwarded to
    `px.scatter`; `renderer` is passed to the figure's `.show()` call.
    Nothing is returned.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    figure.show(renderer)

### Indirect Object Identification


In [None]:
# Load the pretrained "gpt2-small" weights into a HookedTransformer.
# The keyword flags below switch on TransformerLens weight-processing options.
# NOTE(review): presumably these rewrite the weights into a form that is easier
# to interpret while leaving the model's behaviour unchanged — confirm against
# the `HookedTransformer.from_pretrained` documentation.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True
)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


In [None]:
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

In [None]:
example_answer = " Mary"

In [None]:
# Prompt templates; "{}" is filled with the name that acts as the subject of
# the second clause.
prompt_format = [
    "When John and Mary went to the shops,{} gave the bag to",
    "When Tom and James went to the park,{} gave the ball to",
    "When Dan and Sid went to the shops,{} gave an apple to",
    "After Martin and Amy went to the park,{} gave a drink to",
]
# The two names used by each template, index-aligned with `prompt_format`.
names = [
    (" Mary", " John"),
    (" Tom", " James"),
    (" Dan", " Sid"),
    (" Martin", " Amy"),
]
prompts = []  # one filled-in prompt per (template, name ordering)
answers = []  # (correct, incorrect) name strings, one pair per prompt
answer_tokens = []  # (correct_token, incorrect_token) integer ids, one pair per prompt
for i, template in enumerate(prompt_format):
    # Each template is used twice, once with each of its names as the correct answer.
    for j in range(2):
        correct, incorrect = names[i][j], names[i][1 - j]
        answers.append((correct, incorrect))
        answer_tokens.append(
            (model.to_single_token(correct), model.to_single_token(incorrect))
        )
        # The *incorrect* name is inserted into the prompt as the subject,
        # which makes the correct answer the indirect object.
        prompts.append(template.format(incorrect))
answer_tokens = torch.tensor(answer_tokens)
print(prompts)
print(answers)

['When John and Mary went to the shops, John gave the bag to', 'When John and Mary went to the shops, Mary gave the bag to', 'When Tom and James went to the park, James gave the ball to', 'When Tom and James went to the park, Tom gave the ball to', 'When Dan and Sid went to the shops, Sid gave an apple to', 'When Dan and Sid went to the shops, Dan gave an apple to', 'After Martin and Amy went to the park, Amy gave a drink to', 'After Martin and Amy went to the park, Martin gave a drink to']
[(' Mary', ' John'), (' John', ' Mary'), (' Tom', ' James'), (' James', ' Tom'), (' Dan', ' Sid'), (' Sid', ' Dan'), (' Martin', ' Amy'), (' Amy', ' Martin')]


In [None]:
tokens = model.to_tokens(prompts, prepend_bos=True)

In [None]:
tokens

tensor([[50256,  2215,  1757,   290,  5335,  1816,   284,   262, 12437,    11,
          1757,  2921,   262,  6131,   284],
        [50256,  2215,  1757,   290,  5335,  1816,   284,   262, 12437,    11,
          5335,  2921,   262,  6131,   284],
        [50256,  2215,  4186,   290,  3700,  1816,   284,   262,  3952,    11,
          3700,  2921,   262,  2613,   284],
        [50256,  2215,  4186,   290,  3700,  1816,   284,   262,  3952,    11,
          4186,  2921,   262,  2613,   284],
        [50256,  2215,  6035,   290, 15686,  1816,   284,   262, 12437,    11,
         15686,  2921,   281, 17180,   284],
        [50256,  2215,  6035,   290, 15686,  1816,   284,   262, 12437,    11,
          6035,  2921,   281, 17180,   284],
        [50256,  3260,  5780,   290, 14235,  1816,   284,   262,  3952,    11,
         14235,  2921,   257,  4144,   284],
        [50256,  3260,  5780,   290, 14235,  1816,   284,   262,  3952,    11,
          5780,  2921,   257,  4144,   284]])

In [None]:
original_logits, cache = model.run_with_cache(tokens)

### Direct Logit Attribution

##### Flashcard 1

In [None]:
W_U = model.W_U

In [None]:
tokens = torch.tensor([69, 420])

In [None]:
tokens

tensor([ 69, 420])

In [None]:
W_U.shape

torch.Size([768, 50257])

Given that `W_U` is the unembedding matrix:

Extract the residual directions of `tokens`

In [None]:
residual_directions = W_U[:, tokens]

In [None]:
residual_directions.shape

torch.Size([768, 2])

##### Example 2

In [None]:
correct_token = torch.tensor(5335)

In [None]:
incorrect_token = torch.tensor(1757)

In [None]:
type(model)

transformer_lens.HookedTransformer.HookedTransformer

In [None]:
correct_token, incorrect_token

(tensor(5335), tensor(1757))

Compute the logit-difference direction for the two tokens above, i.e. the residual direction of the correct token minus the residual direction of the incorrect token

In [None]:
correct_residual_direction = model.tokens_to_residual_directions(correct_token)

In [None]:
correct_residual_direction.shape

torch.Size([768])

In [None]:
incorrect_residual_direction = model.tokens_to_residual_directions(incorrect_token)

In [None]:
incorrect_residual_direction.shape

torch.Size([768])

In [None]:
logit_diff_direction = correct_residual_direction - incorrect_residual_direction

In [None]:
logit_diff_direction.shape

torch.Size([768])

##### Example 3

In [None]:
unembedding = model.W_U

`768` is the dimension of an embedding vector, and `50257` is the vocabulary size

In [None]:
unembedding.shape

torch.Size([768, 50257])

In [None]:
correct_token, incorrect_token

(tensor(5335), tensor(1757))

Compute the logit-difference direction for the two tokens above, i.e. the residual direction of the correct token minus the residual direction of the incorrect token

In [None]:
correct_residual_direction = unembedding[:, correct_token]

In [None]:
correct_residual_direction.shape

torch.Size([768])

In [None]:
incorrect_residual_direction = unembedding[:, incorrect_token]

In [None]:
incorrect_residual_direction.shape

torch.Size([768])

In [None]:
logit_diff_direction = correct_residual_direction - incorrect_residual_direction

In [None]:
logit_diff_direction.shape

torch.Size([768])

##### Example 4

In [None]:
tokens = model.to_tokens(prompts, prepend_bos=True)

In [None]:
original_logits, cache = model.run_with_cache(tokens)

In [None]:
cache.compute_head_results()

In [None]:
import torch
from einops import rearrange

In [None]:
type(cache)

transformer_lens.ActivationCache.ActivationCache

In [None]:
[key for key in cache.keys() if key.startswith("blocks.0.attn")]

['blocks.0.attn.hook_q',
 'blocks.0.attn.hook_k',
 'blocks.0.attn.hook_v',
 'blocks.0.attn.hook_attn_scores',
 'blocks.0.attn.hook_pattern',
 'blocks.0.attn.hook_z',
 'blocks.0.attn.hook_result']

Stack the outputs of all attention layers from layer `0` to layer `2` as below. Explain.

In [None]:
outputs = []

In [None]:
# Gather the cached per-head attention output ("hook_result") of layers 0, 1 and 2.
for layer in range(3):
    hook_name = f"blocks.{layer}.attn.hook_result"
    outputs.append(cache[hook_name])

The `outputs` list contains the attention outputs from three layers. The output shape for the first layer is `[8, 15, 12, 768]`, which corresponds to a batch size of 8, a sequence length of 15, 12 attention heads, and an embedding of 768.

In [None]:
len(outputs)

3

In [None]:
outputs[0].shape

torch.Size([8, 15, 12, 768])

The goal is to combine all the attention heads from the different layers. The heads lie along dimension 2 (equivalently dimension `-2`, the second-to-last dimension), so the `outputs` are concatenated along that dimension.

In [None]:
outputs = torch.cat(outputs, dim=-2)

After concatenation, the resulting output shape is `[8, 15, 36, 768]`, which corresponds to a batch size of 8, a sequence length of 15, 36 attention heads (12 heads from each of the 3 layers), and an embedding dimension of 768.

In [None]:
outputs.shape

torch.Size([8, 15, 36, 768])

##### Example 5

In [None]:
outputs = rearrange(outputs, "... concat_head_index d_model -> concat_head_index ... d_model")

In [None]:
outputs.shape

torch.Size([36, 8, 15, 768])

##### Example 6

In [None]:
type(cache)

transformer_lens.ActivationCache.ActivationCache

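Extract the per-head outputs of all heads in layers `0` and `1` using `transformer_lens`. Note that `layer=2` stacks only the heads in the layers *before* layer `2` (hence the 24 = 2 × 12 heads in the shape below); to reproduce the 36-head stack of layers `0` to `2` from Examples 4 and 5, pass `layer=3` instead.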

In [None]:
per_head_residual = cache.stack_head_results(layer=2)

In [None]:
per_head_residual.shape

torch.Size([24, 8, 15, 768])

##### Example 6.1

In [None]:
x = torch.arange(0, 8).view(2, 4)
y = torch.arange(0, 8).view(2, 4)

In [None]:
x.shape, y.shape

(torch.Size([2, 4]), torch.Size([2, 4]))

`x` and `y` have batch size `2`, and dimension `4`

In [None]:
x

tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]])

In [None]:
y

tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]])

In [None]:
output = (x * y).sum()

In [None]:
output

tensor(140)

Using `einops`, perform a single operation on `x` and `y` that multiplies them element-wise and then sums all of the resulting values

In [None]:
from einops import einsum

In [None]:
einops_output = einsum(x, y, "batch dim, batch dim ->")

In [None]:
einops_output

tensor(140)

In [None]:
output == einops_output

tensor(True)

##### Example 6.2

In [None]:
_logits, _cache = model.run_with_cache(tokens)

In [None]:
type(model)

transformer_lens.HookedTransformer.HookedTransformer

In [None]:
_logits.shape

torch.Size([8, 15, 50257])

Manually calculate the logits of `tokens` by going through each component of `model` individually

In [None]:
embeddings = model.pos_embed(tokens) + model.embed(tokens)

In [None]:
residual = embeddings

In [None]:
for block in model.blocks:
    residual = block(residual)

In [None]:
residual = model.ln_final(residual)

In [None]:
manual_logits = model.unembed(residual)

In [None]:
torch.allclose(manual_logits, _logits)

True

##### Example 6.3

In [None]:
type(model)

transformer_lens.HookedTransformer.HookedTransformer

In [None]:
_embeddings = model.embed(tokens)

In [None]:
type(model)

transformer_lens.HookedTransformer.HookedTransformer

In [None]:
_embeddings.shape

torch.Size([8, 15, 768])

Retrieve the text embedding matrix of the `model` and compute the text embedding of `tokens` using it

In [None]:
W_E = model.W_E

In [None]:
W_E.shape

torch.Size([50257, 768])

In [None]:
_manual_embeddings = model.W_E[tokens, :]

In [None]:
torch.allclose(_embeddings, _manual_embeddings)

True

##### Example 6.4

In [None]:
len(tokens)

8

In [None]:
model.W_pos[0:]

##### Example 7

In [None]:
per_head_residual, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)

`per_head_residual` is ____

In [None]:
per_head_residual.shape

torch.Size([144, 8, 768])

Calculate the logit difference and explain the function below:

```python
def residual_stack_to_logit_diff(residual_stack: Float[torch.Tensor, "components batch d_model"], cache: ActivationCache) -> float:
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer = -1, pos_slice=-1)
    return einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, logit_diff_directions)/len(prompts)
```

In [None]:
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)

In [None]:
logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]

In [None]:
per_head_residual, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)

In [None]:
def residual_stack_to_logit_diff(
    residual_stack: Float[torch.Tensor, "components batch d_model"],
    cache: ActivationCache,
) -> Float[torch.Tensor, "components"]:
    """Project each component of a residual stack onto the logit-difference directions.

    Args:
        residual_stack: one residual-stream vector per component and per prompt.
        cache: activation cache used to apply the final layer norm scaling to
            the stack (`layer=-1`, `pos_slice=-1`).

    Returns:
        A tensor with one entry per leading (component) index: the dot product
        of the scaled residual with the matching prompt's logit-difference
        direction, summed over the batch and divided by `len(prompts)`.

    Note:
        Reads the notebook globals `logit_diff_directions` (one direction per
        prompt) and `prompts`.
    """
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
    # Call einsum through the `einops` module with the pattern as the LAST
    # argument. The bare name `einsum` is first imported from `fancy_einsum`
    # (pattern first) and later re-imported from `einops` (pattern last), so a
    # pattern-first call through the bare name fails on a top-to-bottom run.
    summed_logit_diff = einops.einsum(
        scaled_residual_stack,
        logit_diff_directions,
        "... batch d_model, batch d_model -> ...",
    )
    return summed_logit_diff / len(prompts)

In [None]:
# Compute one logit-difference value per head, then unflatten the single
# (layer * head) axis into a [layer, head_index] grid.
per_head_logit_diffs = einops.rearrange(
    residual_stack_to_logit_diff(per_head_residual, cache),
    "(layer head_index) -> layer head_index",
    layer=model.cfg.n_layers,
    head_index=model.cfg.n_heads,
)