In [None]:
from IPython import get_ipython

ipython = get_ipython()
# Automatically reload edited modules (e.g. the HookedTransformer code) as it's edited,
# without restarting the kernel.
# `run_line_magic` is the supported replacement for the deprecated `ipython.magic(...)`.
# get_ipython() returns None outside an IPython session, so skip the magics in that case.
if ipython is not None:
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [None]:
import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [None]:
def imshow(tensor, renderer=None, **kwargs):
    """Show `tensor` as a plotly heatmap on a red/blue scale centred at zero.

    Args:
        tensor: Values to plot; converted with `utils.to_numpy` first.
        renderer: Passed straight to the figure's `show` call.
        **kwargs: Extra keyword arguments forwarded to `px.imshow`.
    """
    array = utils.to_numpy(tensor)
    figure = px.imshow(
        array,
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, **kwargs):
    """Show `tensor` as a plotly line chart.

    Args:
        tensor: Values for the y axis; converted with `utils.to_numpy` first.
        renderer: Passed straight to the figure's `show` call.
        **kwargs: Extra keyword arguments forwarded to `px.line`.
    """
    y_values = utils.to_numpy(tensor)
    figure = px.line(y=y_values, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a plotly scatter plot of `y` against `x`.

    Args:
        x: Values for the x axis; converted with `utils.to_numpy` first.
        y: Values for the y axis; converted with `utils.to_numpy` first.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        caxis: Label for the colour axis.
        renderer: Passed straight to the figure's `show` call.
        **kwargs: Extra keyword arguments forwarded to `px.scatter`.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    figure.show(renderer)

### Indirect Object Identification


In [None]:
# Load pretrained GPT-2 small as a HookedTransformer, with all of the weight-processing
# options listed below switched on.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True
)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


In [None]:
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

In [None]:
example_answer = " Mary"

In [None]:
# Show the model's top next-token predictions for the example prompt.
# NOTE(review): the answer passed here is " John" (the subject), while `example_answer`
# (" Mary") defined above is never used — confirm which answer this check is meant to use.
utils.test_prompt(example_prompt, " John", model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' John']


Top 0th token. Logit: 18.09 Prob: 70.07% Token: | Mary|
Top 1th token. Logit: 15.38 Prob:  4.67% Token: | the|
Top 2th token. Logit: 15.35 Prob:  4.54% Token: | John|
Top 3th token. Logit: 15.25 Prob:  4.11% Token: | them|
Top 4th token. Logit: 14.84 Prob:  2.73% Token: | his|
Top 5th token. Logit: 14.06 Prob:  1.24% Token: | her|
Top 6th token. Logit: 13.54 Prob:  0.74% Token: | a|
Top 7th token. Logit: 13.52 Prob:  0.73% Token: | their|
Top 8th token. Logit: 13.13 Prob:  0.49% Token: | Jesus|
Top 9th token. Logit: 12.97 Prob:  0.42% Token: | him|


In [None]:
# Prompt templates; the `{}` slot is filled with the name of the subject (the giver).
prompt_format = [
    "When John and Mary went to the shops,{} gave the bag to",
    "When Tom and James went to the park,{} gave the ball to",
    "When Dan and Sid went to the shops,{} gave an apple to",
    "After Martin and Amy went to the park,{} gave a drink to",
]
# One pair of names per template, in the same order as `prompt_format`.
names = [
    (" Mary", " John"),
    (" Tom", " James"),
    (" Dan", " Sid"),
    (" Martin", " Amy"),
]
# Every prompt string, two per template (one for each ordering of the name pair).
prompts = []
# (correct, incorrect) answer strings, aligned with `prompts`.
answers = []
# (correct_token, incorrect_token) integer token ids, aligned with `prompts`.
answer_tokens = []
for i, template in enumerate(prompt_format):
    for j in range(2):
        correct_name, incorrect_name = names[i][j], names[i][1 - j]
        answers.append((correct_name, incorrect_name))
        answer_tokens.append(
            tuple(
                model.to_single_token(name)
                for name in (correct_name, incorrect_name)
            )
        )
        # The *incorrect* answer goes into the prompt as the subject, which makes the
        # correct answer the indirect object.
        prompts.append(template.format(incorrect_name))
answer_tokens = torch.tensor(answer_tokens)
print(prompts)
print(answers)

['When John and Mary went to the shops, John gave the bag to', 'When John and Mary went to the shops, Mary gave the bag to', 'When Tom and James went to the park, James gave the ball to', 'When Tom and James went to the park, Tom gave the ball to', 'When Dan and Sid went to the shops, Sid gave an apple to', 'When Dan and Sid went to the shops, Dan gave an apple to', 'After Martin and Amy went to the park, Amy gave a drink to', 'After Martin and Amy went to the park, Martin gave a drink to']
[(' Mary', ' John'), (' John', ' Mary'), (' Tom', ' James'), (' James', ' Tom'), (' Dan', ' Sid'), (' Sid', ' Dan'), (' Martin', ' Amy'), (' Amy', ' Martin')]


In [None]:
# Tokenize every prompt into one batch, with the BOS token prepended to each prompt.
tokens = model.to_tokens(prompts, prepend_bos=True)

### Direct Logit Attribution

In [None]:
# Run the model on the prompt batch, keeping both the output logits and the activation cache.
original_logits, cache = model.run_with_cache(tokens)

In [None]:
from einops import rearrange

In [None]:
answer_tokens

tensor([[ 5335,  1757],
        [ 1757,  5335],
        [ 4186,  3700],
        [ 3700,  4186],
        [ 6035, 15686],
        [15686,  6035],
        [ 5780, 14235],
        [14235,  5780]])

In [None]:
# Map each (correct, incorrect) answer token id to its residual-stream direction.
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)

In [None]:
answer_residual_directions.shape

torch.Size([8, 2, 768])

In [None]:
# Per prompt: direction of the correct answer (index 0) minus direction of the incorrect answer (index 1).
logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]

In [None]:
logit_diff_directions.shape

torch.Size([8, 768])

In [None]:
answer_tokens

tensor([[ 5335,  1757],
        [ 1757,  5335],
        [ 4186,  3700],
        [ 3700,  4186],
        [ 6035, 15686],
        [15686,  6035],
        [ 5780, 14235],
        [14235,  5780]])

In [None]:
answer_tokens.shape

torch.Size([8, 2])

In [None]:
# Look up the cached "resid_post" activation at layer index -1 (presumably the last block's output — confirm).
final_residual_stream = cache["resid_post", -1]

In [None]:
final_residual_stream.shape

torch.Size([8, 15, 768])

In [None]:
# Keep only the last sequence position of each prompt.
final_token_residual_stream = final_residual_stream[:, -1, :]

In [None]:
final_token_residual_stream.shape

torch.Size([8, 768])

In [None]:
# scaled_final_token_residual_stream = cache.apply_ln_to_stack(fin)

In [None]:
final_residual_stream

tensor([[[-7.3107e-01,  5.6513e-01, -1.9252e+00,  ...,  5.5697e-01,
           1.0524e+00, -4.5406e-01],
         [ 1.0534e+00, -4.7922e+00,  8.2390e+00,  ...,  4.0869e-01,
          -1.5729e+00, -9.4340e+00],
         [ 2.2844e-01,  7.0241e+00,  3.8416e+00,  ...,  2.3653e+00,
          -3.0928e+00,  4.4728e-01],
         ...,
         [ 8.3668e+00, -4.1369e+00,  5.3449e+00,  ...,  1.7024e+00,
          -2.4825e+00, -1.7049e+01],
         [ 5.7748e+00, -4.0720e-01, -1.0594e+01,  ...,  2.9123e-01,
          -3.0663e+00, -2.6462e+00],
         [ 9.2831e+00, -5.0196e+00,  8.3061e-01,  ..., -1.3695e+01,
           2.4261e+00, -1.2054e+00]],

        [[-7.3107e-01,  5.6513e-01, -1.9252e+00,  ...,  5.5697e-01,
           1.0524e+00, -4.5406e-01],
         [ 1.0534e+00, -4.7922e+00,  8.2390e+00,  ...,  4.0869e-01,
          -1.5729e+00, -9.4340e+00],
         [ 2.2844e-01,  7.0241e+00,  3.8416e+00,  ...,  2.3653e+00,
          -3.0928e+00,  4.4728e-01],
         ...,
         [ 7.6015e+00, -2

##### Draft 3

In [None]:
tokens.shape

torch.Size([8, 15])

In [None]:
final_residual_stream = cache["resid_post", -1]

In [None]:
final_residual_stream.shape

torch.Size([8, 15, 768])

In [None]:
final_token_residual_stream = final_residual_stream[:, -1, :]

In [None]:
# Apply the cached layer-norm scaling to the final-token residual stream.
# NOTE(review): the input is already sliced to the last position, so pos_slice=-1 presumably
# selects the matching layer-norm scale for that position — confirm against the TransformerLens docs.
scaled_final_token_residual_stream = cache.apply_ln_to_stack(final_token_residual_stream, layer = -1, pos_slice=-1)

In [None]:
final_residual_stream.shape

torch.Size([8, 15, 768])

`scaled_final_token_residual_stream` represents the final-layer residual stream at the last token position of `tokens`, after the layer-norm scaling applied by `cache.apply_ln_to_stack`.

In [None]:
scaled_final_token_residual_stream.shape

torch.Size([8, 768])

In [None]:
tokens.shape

torch.Size([8, 15])

`logit_diff_directions` is the difference between the residual-stream directions of the correct and incorrect tokens in `answer_tokens` (correct minus incorrect).

In [None]:
answer_tokens.shape

torch.Size([8, 2])

In [None]:
logit_diff_directions.shape

torch.Size([8, 768])

In [None]:
from einops import einsum

In [None]:
# Multiply the scaled residual stream with the logit-difference directions and sum over
# both `batch` and `d_model` (nothing is kept on the right of "->"), giving a single scalar:
# the logit difference summed over all prompts.
logit_diff = einsum(
    scaled_final_token_residual_stream,
    logit_diff_directions,
    "batch d_model, batch d_model ->"
)

In [None]:
logit_diff

tensor(28.4150, grad_fn=<ViewBackward0>)

In [None]:
# Turn the summed logit difference into a per-prompt average.
average_logit_diff = logit_diff / len(prompts)

In [None]:
average_logit_diff

tensor(3.5519, grad_fn=<DivBackward0>)

In [None]:
def residual_stack_to_logit_diff(residual_stack: Float[torch.Tensor, "components batch d_model"], cache: ActivationCache) -> Float[torch.Tensor, "..."]:
    """Average logit difference contributed by each component of a residual stack.

    Each component is layer-norm scaled with `cache.apply_ln_to_stack`, projected onto the
    notebook-level `logit_diff_directions`, summed over `batch` and `d_model`, and divided
    by the number of prompts.

    Args:
        residual_stack: Residual-stream vectors with trailing dimensions (batch, d_model)
            and any number of leading component dimensions.
        cache: Activation cache that supplies the layer-norm scaling.

    Returns:
        A tensor with one average logit difference per leading component (a scalar tensor
        when there are no leading dimensions).
    """
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer = -1, pos_slice=-1)
    # Call einops.einsum explicitly with the pattern LAST. The bare name `einsum` is first
    # imported from fancy_einsum (pattern first) and later rebound to einops.einsum
    # (pattern last), so the original pattern-first call fails once that rebinding has run.
    summed_logit_diff = einops.einsum(
        scaled_residual_stack,
        logit_diff_directions,
        "... batch d_model, batch d_model -> ...",
    )
    return summed_logit_diff / len(prompts)

`logits_to_ave_logit_diff`
- run the model on `tokens` to obtain the logits (together with `cache`)
- extract the logits at the final token position
- extract the logits of the two `answer_tokens` for each prompt
- subtract the incorrect answer's logit from the correct answer's logit, then average over prompts

`other`
- obtain `cache` using `tokens`
- compute `answer_residual_directions` from `answer_tokens` with `model.tokens_to_residual_directions`
- compute `logit_diff_directions` as the difference of the two residual directions, just like the logit subtraction

In [None]:
original_logits.shape

torch.Size([8, 15, 50257])

In [None]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    """Logit difference between the correct and incorrect answer at the final position.

    Args:
        logits: Tensor indexed as (batch, position, vocab).
        answer_tokens: Integer tensor of shape (batch, 2) holding
            (correct_token, incorrect_token) ids for each prompt.
        per_prompt: If True, return one logit difference per prompt; otherwise
            (the default) return their mean.

    Returns:
        A (batch,) tensor when `per_prompt` is True, else a scalar tensor.
    """
    # Only the final logits are relevant for the answer
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
    # `per_prompt` used to be accepted but ignored; honour it here.
    if per_prompt:
        return answer_logit_diff
    return answer_logit_diff.mean()

# Baseline: mean logit difference of the unmodified model over all prompts.
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)

In [None]:
original_average_logit_diff

tensor(3.5519, grad_fn=<MeanBackward0>)

In [None]:
# Residual-stream direction of each answer token, then correct-minus-incorrect per prompt.
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
)

# Cached "resid_post" activation at layer index -1, restricted to the last position.
final_residual_stream = cache["resid_post", -1]
final_token_residual_stream = final_residual_stream[:, -1, :]

# Apply the cached layer-norm scaling to the final-token residual stream.
scaled_final_token_residual_stream = cache.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

# Sum the projection over batch and d_model, then average over the prompts.
average_logit_diff = (
    einsum(
        scaled_final_token_residual_stream,
        logit_diff_directions,
        "batch d_model, batch d_model -> ",
    )
    / len(prompts)
)

In [None]:
average_logit_diff

tensor(3.5519, grad_fn=<DivBackward0>)

##### Tokens -> logits

In [None]:
model.embed

Embed()

In [None]:
model.pos_embed

PosEmbed()

In [None]:
# Manual walk-through of the forward pass (tokens -> logits). The steps follow the sketch
# taken from inside the model class, but are written against the notebook's `model`
# object so that the cell actually runs (`self` does not exist at notebook top level).
pos_offset = 0  # no key/value cache is used here, so positions start at 0

embed = model.hook_embed(model.embed(tokens))  # [batch, pos, d_model]

if model.cfg.positional_embedding_type == "standard":
    pos_embed = model.hook_pos_embed(
        model.pos_embed(tokens, pos_offset)
    )  # [batch, pos, d_model]
    residual = embed + pos_embed
else:
    # Only the standard positional embedding is covered by this walk-through; fail
    # loudly instead of hitting an undefined `residual` further down.
    raise NotImplementedError(
        f"Unsupported positional_embedding_type: {model.cfg.positional_embedding_type}"
    )

transformer_block_list = model.blocks

# Pass the residual stream through every transformer block in order.
for block in transformer_block_list:
    residual = block(
        residual,
        past_kv_cache_entry=None,
        shortformer_pos_embed=None,
    )  # [batch, pos, d_model]

residual = model.ln_final(residual)  # [batch, pos, d_model]
logits = model.unembed(residual)  # [batch, pos, d_vocab]