In [373]:
from functools import partial
from typing import List, Optional, Union

import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
from jaxtyping import Float

import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer

torch.set_grad_enabled(False)
print("Disabled automatic differentiation")

def imshow(tensor, **kwargs):
    px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    ).show()


def line(tensor, **kwargs):
    px.line(
        y=utils.to_numpy(tensor),
        **kwargs,
    ).show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(
        y=y,
        x=x,
        labels={"x": xaxis, "y": yaxis, "color": caxis},
        **kwargs,
    ).show()

Disabled automatic differentiation


In [375]:
# NBVAL_IGNORE_OUTPUT
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

# Get the default device used
device: torch.device = utils.get_device()

Loaded pretrained model gpt2-small into HookedTransformer


In [374]:
text = "Paris is the capital city of"
answer = "France"
model.to_str_tokens(text)
utils.test_prompt(text, answer, model)
tokens = model.to_tokens(text)

Tokenized prompt: ['<|endoftext|>', 'Paris', ' is', ' the', ' capital', ' city', ' of']
Tokenized answer: [' France']


Top 0th token. Logit: 17.61 Prob: 49.00% Token: | France|
Top 1th token. Logit: 16.66 Prob: 18.92% Token: | the|
Top 2th token. Logit: 14.85 Prob:  3.09% Token: | French|
Top 3th token. Logit: 14.69 Prob:  2.64% Token: | Paris|
Top 4th token. Logit: 14.58 Prob:  2.36% Token: | a|
Top 5th token. Logit: 14.55 Prob:  2.30% Token: | Europe|
Top 6th token. Logit: 14.17 Prob:  1.57% Token: | Belgium|
Top 7th token. Logit: 13.17 Prob:  0.58% Token: | Morocco|
Top 8th token. Logit: 12.85 Prob:  0.42% Token: | one|
Top 9th token. Logit: 12.61 Prob:  0.33% Token: | Spain|


In [396]:
logits, cache = model.run_with_cache(text, remove_batch_dim= True)


In [397]:
token_direction =model.tokens_to_residual_directions(answer)
line(token_direction, title="Token direction")

In [403]:
data, labels  = cache.accumulated_resid(apply_ln = True, return_labels= True)
print(data.shape, labels)

torch.Size([13, 7, 768]) ['0_pre', '1_pre', '2_pre', '3_pre', '4_pre', '5_pre', '6_pre', '7_pre', '8_pre', '9_pre', '10_pre', '11_pre', 'final_post']


In [405]:
imshow(data@token_direction, title="Residuals", x = model.to_str_tokens(text), y = labels)
line((data@token_direction)[:,-1], title="Residuals sum", x = labels)

In [406]:
resid_components, labels = cache.decompose_resid(apply_ln = True, return_labels = True)
print(resid_components.shape, labels)

torch.Size([26, 7, 768]) ['embed', 'pos_embed', '0_attn_out', '0_mlp_out', '1_attn_out', '1_mlp_out', '2_attn_out', '2_mlp_out', '3_attn_out', '3_mlp_out', '4_attn_out', '4_mlp_out', '5_attn_out', '5_mlp_out', '6_attn_out', '6_mlp_out', '7_attn_out', '7_mlp_out', '8_attn_out', '8_mlp_out', '9_attn_out', '9_mlp_out', '10_attn_out', '10_mlp_out', '11_attn_out', '11_mlp_out']


In [407]:

imshow(resid_components@token_direction, title="Layerwise", x = model.to_str_tokens(text))
line((resid_components@token_direction)[:,-1], title="Layerwise final token", x = labels)

In [414]:
for key in list(cache.keys()):
    if 'attn.hook_pattern' in key:
        for head_no in range(model.cfg.n_heads):
            imshow(cache[key][head_no], title=f"{key}-{head_no}", x = tokens, y = tokens)

torch.Size([12, 7, 7])


ValueError: The length of the y vector must match the length of the first dimension of the img matrix.