From:

https://colab.research.google.com/drive/1nD6tfM33StbAqXG5HnYPlC40hKSj8mzD#scrollTo=YzmdOdeJIAiY

# Setup
(No need to read)

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

# Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
# The "png" renderer is only chosen for local development with DEBUG_MODE switched on;
# every other combination uses the "colab" renderer.
pio.renderers.default = "png" if (DEBUG_MODE and not IN_COLAB) else "colab"

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [None]:
# Switch autograd off for the whole session: this notebook only runs inference.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f2489955c30>

Plotting helper functions:

In [None]:
def imshow(tensor, renderer=None, **kwargs):
    """Show `tensor` as a heatmap on a red-blue colour scale centred at zero.

    Extra keyword arguments are forwarded to `px.imshow`; `renderer` is passed
    to the figure's `show` call.
    """
    heatmap = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    heatmap.show(renderer)

def line(tensor, renderer=None, **kwargs):
    """Plot the values of `tensor` on the y axis of a line chart.

    Extra keyword arguments are forwarded to `px.line`; `renderer` is passed
    to the figure's `show` call.
    """
    values = utils.to_numpy(tensor)
    figure = px.line(y=values, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Scatter-plot `y` against `x`.

    `xaxis`, `yaxis` and `caxis` are the display labels for the x, y and
    colour axes. Extra keyword arguments are forwarded to `px.scatter`;
    `renderer` is passed to the figure's `show` call.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_titles = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(y=y_values, x=x_values, labels=axis_titles, **kwargs)
    figure.show(renderer)

In [None]:
# Load pretrained GPT-2 small as a HookedTransformer so activations can be cached and patched below.
# NOTE(review): the keyword flags are weight-processing options of from_pretrained; their exact
# effect is not shown in this notebook — confirm against the TransformerLens documentation.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


# test

In [None]:
# Prompt templates; "{}" is filled with a name (including its leading space).
prompt_format = [
    "When John and Mary went to the shops,{} gave the bag to",
    "When Tom and James went to the park,{} gave the ball to",
    "When Dan and Sid went to the shops,{} gave an apple to",
    "After Martin and Amy went to the park,{} gave a drink to",
]
# The pair of names used by each template, in the same order as prompt_format.
names = [
    (" Mary", " John"),
    (" Tom", " James"),
    (" Dan", " Sid"),
    (" Martin", " Amy"),
]

# Build two prompts per template, one for each choice of which name is the
# correct answer. Each entry of `answers` / `answer_tokens` is ordered
# (correct, incorrect).
prompts = []
answers = []
answer_tokens = []
for template, name_pair in zip(prompt_format, names):
    for correct, incorrect in (name_pair, name_pair[::-1]):
        answers.append((correct, incorrect))
        answer_tokens.append(
            (
                model.to_single_token(correct),
                model.to_single_token(incorrect),
            )
        )
        # Insert the *incorrect* answer to the prompt, making the correct answer the indirect object.
        prompts.append(template.format(incorrect))

# Place tensors on whichever device the model was loaded onto instead of
# assuming a GPU: a hard-coded .cuda() fails on CPU-only machines.
answer_tokens = torch.tensor(answer_tokens).to(model.cfg.device)

tokens = model.to_tokens(prompts, prepend_bos=True)
tokens = tokens.to(model.cfg.device)
# Clean run: keep both the logits and every cached activation for later patching.
original_logits, cache = model.run_with_cache(tokens)

In [None]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    """Logit of the first answer token minus logit of the second, at the last position.

    Args:
        logits: tensor indexed as [batch, position, vocab]; only the final
            position is read.
        answer_tokens: integer tensor indexed as [batch, 2] holding the two
            token ids to compare for each prompt (column 0 minus column 1).
        per_prompt: when True return one difference per prompt, otherwise
            return their mean.
    """
    # Only the final logits are relevant for the answer
    last_position_logits = logits[:, -1, :]
    picked = last_position_logits.gather(dim=-1, index=answer_tokens)
    first_logit, second_logit = picked[:, 0], picked[:, 1]
    per_prompt_diff = first_logit - second_logit
    return per_prompt_diff if per_prompt else per_prompt_diff.mean()

# Report the clean-run logit differences: first one value per prompt, then the mean.
print("Per prompt logit difference:", logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True))
# Kept in a variable because the normalisation of patched results reads it later.
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print("Average logit difference:", original_average_logit_diff.item())

Per prompt logit difference: tensor([3.3367, 3.2016, 2.7094, 3.7974, 1.7204, 5.2812, 2.6008, 5.7674],
       device='cuda:0')
Average logit difference: 3.5518715381622314


In [None]:
# Map every answer token to a residual-stream direction (one vector per answer token,
# as the printed shape shows).
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)
# Direction of the first answer token minus direction of the second, per prompt.
logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([8, 2, 768])
Logit difference directions shape: torch.Size([8, 768])


In [None]:
# Corrupted prompts: swap the two prompts inside every consecutive pair, so each
# slot holds the prompt whose inserted name is the other one of the pair.
# answer_tokens is left unchanged.
corrupted_prompts = []
for first_prompt, second_prompt in zip(prompts[0::2], prompts[1::2]):
    corrupted_prompts += [second_prompt, first_prompt]
corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)
# Corrupted run, scored against the *clean* answer tokens.
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens, return_type="logits")
corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)
print("Corrupted Average Logit Diff", corrupted_average_logit_diff)
print("Clean Average Logit Diff", original_average_logit_diff)

Corrupted Average Logit Diff tensor(-3.5519, device='cuda:0')
Clean Average Logit Diff tensor(3.5519, device='cuda:0')


In [None]:
def patch_residual_component(
    corrupted_residual_component: Float[torch.Tensor, "batch pos d_model"],
    hook,
    pos,
    clean_cache):
    """Hook function: overwrite one sequence position with the clean activation.

    The entry of `clean_cache` named `hook.name` supplies the replacement
    values at position `pos`. The activation is modified in place and also
    returned.
    """
    clean_activation = clean_cache[hook.name]
    corrupted_residual_component[:, pos, :] = clean_activation[:, pos, :]
    return corrupted_residual_component

def normalize_patched_logit_diff(patched_logit_diff):
    """Rescale a patched logit diff relative to the corrupted and clean baselines.

    Reads the notebook-level `corrupted_average_logit_diff` and
    `original_average_logit_diff`.
    """
    # Subtract corrupted logit diff to measure the improvement, divide by the total improvement from clean to corrupted to normalise
    # 0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance
    improvement = patched_logit_diff - corrupted_average_logit_diff
    clean_minus_corrupted = original_average_logit_diff - corrupted_average_logit_diff
    return improvement / clean_minus_corrupted

# Activation patching over the residual stream: for every (layer, position),
# run the corrupted prompts with the clean resid_pre activation patched in at
# that position, and record the normalised logit difference.
# The result buffer lives on the model's device rather than a hard-coded
# "cuda", so the cell also runs on CPU-only machines.
patched_residual_stream_diff = torch.zeros(model.cfg.n_layers, tokens.shape[1], device=model.cfg.device, dtype=torch.float32)
for layer in range(model.cfg.n_layers):
    for position in range(tokens.shape[1]):
        # Bind the position and the clean cache; the hook itself only receives (activation, hook).
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = [(utils.get_act_name("resid_pre", layer),
                hook_fn)],
            return_type="logits"
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(patched_logit_diff)

In [None]:
def patch_head_vector(
    corrupted_head_vector: Float[torch.Tensor, "batch pos head_index d_head"],
    hook,
    head_index,
    clean_cache):
    """Hook function: overwrite one attention head's vector with the clean activation.

    The entry of `clean_cache` named `hook.name` supplies the replacement
    values for head `head_index` at every position. The activation is modified
    in place and also returned.
    """
    clean_activation = clean_cache[hook.name]
    corrupted_head_vector[:, :, head_index, :] = clean_activation[:, :, head_index, :]
    return corrupted_head_vector


# Activation patching over attention heads: for every (layer, head), run the
# corrupted prompts with that head's clean "z" activation patched in, and
# record the normalised logit difference.
# The result buffer lives on the model's device rather than a hard-coded
# "cuda", so the cell also runs on CPU-only machines.
patched_head_z_diff = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device=model.cfg.device, dtype=torch.float32)
for layer in range(model.cfg.n_layers):
    for head_index in range(model.cfg.n_heads):
        # Bind the head and the clean cache; the hook itself only receives (activation, hook).
        hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = [(utils.get_act_name("z", layer, "attn"),
                hook_fn)],
            return_type="logits"
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_z_diff[layer, head_index] = normalize_patched_logit_diff(patched_logit_diff)

In [None]:
# Heatmap of the normalised logit difference per patched head: rows are layers, columns are heads.
imshow(patched_head_z_diff, title="Logit Difference From Patched Head Output", labels={"x":"Head", "y":"Layer"})

# Analyze parts of patching code

In [None]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    """Logit of the first answer token minus logit of the second, at the last position.

    This function is also defined earlier in the notebook; it is repeated here
    so each step can be examined in the cells that follow.

    Args:
        logits: tensor indexed as [batch, position, vocab]; only the final
            position is read.
        answer_tokens: integer tensor indexed as [batch, 2] holding the two
            token ids to compare for each prompt (column 0 minus column 1).
        per_prompt: when True return one difference per prompt, otherwise
            return their mean.
    """
    # Only the final logits are relevant for the answer
    last_position_logits = logits[:, -1, :]
    picked = last_position_logits.gather(dim=-1, index=answer_tokens)
    first_logit, second_logit = picked[:, 0], picked[:, 1]
    per_prompt_diff = first_logit - second_logit
    return per_prompt_diff if per_prompt else per_prompt_diff.mean()
# Prints the average that an earlier cell stored in original_average_logit_diff; nothing is recomputed here.
print("Average logit difference:", original_average_logit_diff.item())

## .gather()

In [None]:
# Small 3x2 example tensor for exploring torch.gather.
# NOTE(review): the name `input` shadows Python's builtin input() for the rest of the session.
input = torch.tensor([[1, 2], [3, 4], [5, 6]])
input

tensor([[1, 2],
        [3, 4],
        [5, 6]])

In [None]:
# 2x2 index tensor for the gather examples; its largest value is 2.
index = torch.tensor([[0, 2], [1, 0]])
index

tensor([[0, 2],
        [1, 0]])

In [None]:
# dim=0: out[i][j] = input[index[i][j]][j], i.e. each index value picks a row within column j.
torch.gather(input, dim=0, index=index)

tensor([[1, 6],
        [3, 2]])

https://stackoverflow.com/questions/50999977/what-does-the-gather-function-do-in-pytorch-in-layman-terms
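The first row of the index, [0,2], says to take index 0 of col 0, which is 1, and index 2 (the third entry) of col 1, which is 6. Likewise the second row, [1,0], takes index 1 of col 0, which is 3, and index 0 of col 1, which is 2.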

In [None]:
# dim=1: out[i][j] = input[i][index[i][j]], so every index value must be a valid
# column index (0 or 1 for this 3x2 tensor). The value 2 is out of range, so
# this call is expected to fail. The error is caught and printed so that
# "Run All" does not stop at this cell.
try:
    torch.gather(input, dim=1, index=index)
except RuntimeError as err:
    print("RuntimeError:", err)

RuntimeError: ignored

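With dim=1 the index values select columns within each row. The first row, [0,2], says to get index 0 of row 0, which is 1. The "2" says to get index 2 of row 0, which doesn't exist, because each row only has two entries (indices 0 and 1).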

In [None]:
# Confirm the example tensor is 3x2: along dim 1 only indices 0 and 1 are valid.
input.shape

torch.Size([3, 2])

In [None]:
# For a 2-D tensor dim=-1 is the same axis as dim=1, so the index value 2 is
# again out of range and this call is expected to fail. The error is caught
# and printed so that "Run All" does not stop at this cell.
try:
    torch.gather(input, dim=-1, index=index)
except RuntimeError as err:
    print("RuntimeError:", err)

RuntimeError: ignored

## final logits

In [None]:
# Shape of the clean-run logits; indexed below as [batch, position, vocab].
original_logits.shape

torch.Size([8, 15, 50257])

In [None]:
# Keep only the last sequence position of every prompt, dropping the position axis.
final_logits = original_logits[:, -1, :]
final_logits.shape

torch.Size([8, 50257])

```
original_logits[:, -1, :]
```
There are 15 tokens in each sequence. This takes the logits at the last token position, which are the ones used to predict the next token (the answer).


In [None]:
# One gathered logit per answer token: the result has the same shape as answer_tokens.
final_logits.gather(dim=-1, index=answer_tokens).shape

torch.Size([8, 2])

This gets the values of the logits for the answer_tokens. Each token ID is an index into the vocabulary dimension, which is the last dim, so .gather() picks out the logit values at those indices.

In [None]:
# The gathered values themselves: column 0 and column 1 follow the column order of answer_tokens.
final_logits.gather(dim=-1, index=answer_tokens)

tensor([[18.1932, 14.8565],
        [18.0346, 14.8330],
        [15.8823, 13.1728],
        [16.7980, 13.0005],
        [16.2288, 14.5083],
        [16.2888, 11.0076],
        [17.0133, 14.4125],
        [17.9510, 12.1836]], device='cuda:0')

## Head output

In [None]:
# Re-run the clean prompts to rebuild `original_logits` and `cache` for this section.
tokens = model.to_tokens(prompts, prepend_bos=True)
# Use the model's own device instead of a hard-coded .cuda(), which fails on CPU-only machines.
tokens = tokens.to(model.cfg.device)
original_logits, cache = model.run_with_cache(tokens)

In [None]:
# Shape of one cached activation: the resid_pre entry of block 1.
cache['blocks.1.hook_resid_pre'].shape

torch.Size([8, 15, 768])

In [None]:
# NOTE(review): despite the name, `head_output` holds the block-1 resid_pre cache entry,
# not the output of an individual attention head.
head_output = cache['blocks.1.hook_resid_pre']
head_output.shape

torch.Size([8, 15, 768])

768 is the model (residual stream) dimension, not the vocab size. So we cannot use .gather(index=answer_tokens) on it: doing so gives a CUDA out-of-bounds error, after which the runtime has to be restarted.

In [None]:
# Display the answer token ids; every value shown is far larger than 767, the last valid index of a 768-wide dim.
answer_tokens

tensor([[ 5335,  1757],
        [ 1757,  5335],
        [ 4186,  3700],
        [ 3700,  4186],
        [ 6035, 15686],
        [15686,  6035],
        [ 5780, 14235],
        [14235,  5780]], device='cuda:0')

In [None]:
# Last sequence position of the cached activation.
last_token_head_output = head_output[:, -1, :]
# Indexing with a vocabulary token id (5335) is expected to fail, because the
# last dim is far smaller than the vocab. The error is caught and printed so
# that "Run All" does not stop at this cell.
try:
    last_token_head_output[0, 5335]
except IndexError as err:
    print("IndexError:", err)

IndexError: ignored

In [None]:
# Print the first 20 cache entries with their shapes.
# NOTE(review): `layer` here is a cache key (an activation name string), not a layer index.
for layer in list(cache.keys())[0:20]:
    print(layer, cache[layer].shape)

hook_embed torch.Size([8, 15, 768])
hook_pos_embed torch.Size([8, 15, 768])
blocks.0.hook_resid_pre torch.Size([8, 15, 768])
blocks.0.ln1.hook_scale torch.Size([8, 15, 1])
blocks.0.ln1.hook_normalized torch.Size([8, 15, 768])
blocks.0.attn.hook_q torch.Size([8, 15, 12, 64])
blocks.0.attn.hook_k torch.Size([8, 15, 12, 64])
blocks.0.attn.hook_v torch.Size([8, 15, 12, 64])
blocks.0.attn.hook_attn_scores torch.Size([8, 12, 15, 15])
blocks.0.attn.hook_pattern torch.Size([8, 12, 15, 15])
blocks.0.attn.hook_z torch.Size([8, 15, 12, 64])
blocks.0.hook_attn_out torch.Size([8, 15, 768])
blocks.0.hook_resid_mid torch.Size([8, 15, 768])
blocks.0.ln2.hook_scale torch.Size([8, 15, 1])
blocks.0.ln2.hook_normalized torch.Size([8, 15, 768])
blocks.0.mlp.hook_pre torch.Size([8, 15, 3072])
blocks.0.mlp.hook_post torch.Size([8, 15, 3072])
blocks.0.hook_mlp_out torch.Size([8, 15, 768])
blocks.0.hook_resid_post torch.Size([8, 15, 768])
blocks.1.hook_resid_pre torch.Size([8, 15, 768])


We see that NONE of them have a vocab-sized dimension, because none of these cached activations have been passed through the unembedding. It seems that the unembedding has to be done outside of these cached activations. That is, we should apply it ourselves.

In [None]:
def head_output_to_ave(cache, answer_tokens, per_prompt=False):
    """Difference of two gathered entries of the last-position block-1 resid_pre activation.

    Follows the same steps as logits_to_ave_logit_diff, but reads
    cache['blocks.1.hook_resid_pre'] instead of logits.

    NOTE(review): the notebook observes that this activation has no
    vocab-sized dimension, so vocabulary token ids are not valid indices into
    it; an unembedding step would be needed before token ids are meaningful.
    This function does not apply one.

    Args:
        cache: mapping that contains the key 'blocks.1.hook_resid_pre'.
        answer_tokens: integer tensor indexed as [batch, 2]; its values are
            used as indices into the last dimension of the activation.
        per_prompt: when True return one difference per prompt, otherwise
            return their mean.

    Raises:
        IndexError: if any value in answer_tokens is not a valid index into
            the last dimension of the activation.
    """
    # Last sequence position only, as in the logit version.
    head_output = cache['blocks.1.hook_resid_pre'][:, -1, :]

    # Validate before gathering. The notebook reports that an out-of-range
    # gather here produces a CUDA out-of-bounds error that requires a runtime
    # restart; raising an ordinary Python exception keeps the session usable.
    width = head_output.shape[-1]
    if answer_tokens.numel() > 0:
        lowest = int(answer_tokens.min())
        highest = int(answer_tokens.max())
        if lowest < 0 or highest >= width:
            raise IndexError(
                f"answer_tokens values must lie in [0, {width - 1}] to index the last "
                f"dimension (size {width}), but got values in [{lowest}, {highest}]"
            )

    answer_head_output = head_output.gather(dim=-1, index=answer_tokens)

    answer_logit_diff = answer_head_output[:, 0] - answer_head_output[:, 1]
    if per_prompt:
        return answer_logit_diff
    else:
        return answer_logit_diff.mean()