In [2]:
from functools import partial
from typing import List, Optional, Union

import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
from jaxtyping import Float

import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer

torch.set_grad_enabled(False)
print("Disabled automatic differentiation")

def imshow(tensor, **kwargs):
    px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    ).show()


def line(tensor, **kwargs):
    px.line(
        y=utils.to_numpy(tensor),
        **kwargs,
    ).show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(
        y=y,
        x=x,
        labels={"x": xaxis, "y": yaxis, "color": caxis},
        **kwargs,
    ).show()

  from .autonotebook import tqdm as notebook_tqdm


Disabled automatic differentiation


In [3]:
model = HookedTransformer.from_pretrained("gpt2-small")

# Get the default device used
device: torch.device = utils.get_device()

Loaded pretrained model gpt2-small into HookedTransformer


In [187]:

cities = ["Delhi", "Paris", "Tokyo", "Berlin", "Ottawa", "Madrid", "Mumbai", "Bangalore"]
countries = [" India", " France", " Japan", " Germany", " Canada", " Spain", " India", " India"]
prompts = []

for city, country in zip(cities, countries):
    prompts.append(f"The city {city} is the capital of")
answers = countries
wrong_answers = [" Delhi", " Paris", " Tokyo", " Berlin", " Rome", " Madrid", " Mumbai", " Bangalore"]
for prompt, answer in zip(prompts, answers):
    utils.test_prompt(prompt, answer, model)
# for prompt, answer in zip(prompts, wrong_answers):
#     utils.test_prompt(prompt, answer, model)

Tokenized prompt: ['<|endoftext|>', 'The', ' city', ' Delhi', ' is', ' the', ' capital', ' of']
Tokenized answer: [' India']


Top 0th token. Logit: 15.75 Prob: 30.58% Token: | India|
Top 1th token. Logit: 14.98 Prob: 14.18% Token: | the|
Top 2th token. Logit: 13.96 Prob:  5.08% Token: |,|
Top 3th token. Logit: 13.74 Prob:  4.09% Token: | a|
Top 4th token. Logit: 13.28 Prob:  2.59% Token: | and|
Top 5th token. Logit: 13.24 Prob:  2.48% Token: | Uttar|
Top 6th token. Logit: 13.24 Prob:  2.47% Token: |.|
Top 7th token. Logit: 12.98 Prob:  1.91% Token: | Delhi|
Top 8th token. Logit: 12.68 Prob:  1.42% Token: | one|
Top 9th token. Logit: 12.31 Prob:  0.98% Token: | which|


Tokenized prompt: ['<|endoftext|>', 'The', ' city', ' Paris', ' is', ' the', ' capital', ' of']
Tokenized answer: [' France']


Top 0th token. Logit: 15.94 Prob: 29.49% Token: | France|
Top 1th token. Logit: 15.53 Prob: 19.55% Token: | the|
Top 2th token. Logit: 14.83 Prob:  9.67% Token: | a|
Top 3th token. Logit: 13.41 Prob:  2.35% Token: | Paris|
Top 4th token. Logit: 13.40 Prob:  2.33% Token: | French|
Top 5th token. Logit: 13.08 Prob:  1.69% Token: | one|
Top 6th token. Logit: 12.90 Prob:  1.42% Token: | Belgium|
Top 7th token. Logit: 12.76 Prob:  1.22% Token: | an|
Top 8th token. Logit: 12.69 Prob:  1.15% Token: | Europe|
Top 9th token. Logit: 12.49 Prob:  0.94% Token: |,|


Tokenized prompt: ['<|endoftext|>', 'The', ' city', ' Tokyo', ' is', ' the', ' capital', ' of']
Tokenized answer: [' Japan']


Top 0th token. Logit: 15.06 Prob: 26.15% Token: | Japan|
Top 1th token. Logit: 14.61 Prob: 16.73% Token: | the|
Top 2th token. Logit: 14.12 Prob: 10.16% Token: |,|
Top 3th token. Logit: 13.68 Prob:  6.60% Token: | Tokyo|
Top 4th token. Logit: 13.31 Prob:  4.56% Token: |.|
Top 5th token. Logit: 13.19 Prob:  4.01% Token: | and|
Top 6th token. Logit: 13.11 Prob:  3.72% Token: | a|
Top 7th token. Logit: 11.81 Prob:  1.01% Token: | one|
Top 8th token. Logit: 11.52 Prob:  0.76% Token: | which|
Top 9th token. Logit: 11.39 Prob:  0.67% Token: | an|


Tokenized prompt: ['<|endoftext|>', 'The', ' city', ' Berlin', ' is', ' the', ' capital', ' of']
Tokenized answer: [' Germany']


Top 0th token. Logit: 16.89 Prob: 42.67% Token: | Germany|
Top 1th token. Logit: 16.32 Prob: 24.16% Token: | the|
Top 2th token. Logit: 14.38 Prob:  3.48% Token: | Berlin|
Top 3th token. Logit: 14.38 Prob:  3.47% Token: | a|
Top 4th token. Logit: 13.78 Prob:  1.91% Token: | German|
Top 5th token. Logit: 13.28 Prob:  1.16% Token: | Europe|
Top 6th token. Logit: 12.93 Prob:  0.82% Token: | Austria|
Top 7th token. Logit: 12.93 Prob:  0.81% Token: | Bav|
Top 8th token. Logit: 12.64 Prob:  0.61% Token: | one|
Top 9th token. Logit: 12.50 Prob:  0.53% Token: | Russia|


Tokenized prompt: ['<|endoftext|>', 'The', ' city', ' Ottawa', ' is', ' the', ' capital', ' of']
Tokenized answer: [' Canada']


Top 0th token. Logit: 15.88 Prob: 32.43% Token: |,|
Top 1th token. Logit: 15.31 Prob: 18.38% Token: |.|
Top 2th token. Logit: 14.87 Prob: 11.83% Token: | and|
Top 3th token. Logit: 13.95 Prob:  4.72% Token: | in|
Top 4th token. Logit: 13.33 Prob:  2.54% Token: | Canada|
Top 5th token. Logit: 12.70 Prob:  1.35% Token: | is|
Top 6th token. Logit: 12.41 Prob:  1.00% Token: | on|
Top 7th token. Logit: 12.39 Prob:  0.99% Token: |
|
Top 8th token. Logit: 12.34 Prob:  0.94% Token: | –|
Top 9th token. Logit: 12.34 Prob:  0.94% Token: | but|


Tokenized prompt: ['<|endoftext|>', 'The', ' city', ' Madrid', ' is', ' the', ' capital', ' of']
Tokenized answer: [' Spain']


Top 0th token. Logit: 17.44 Prob: 58.42% Token: | Spain|
Top 1th token. Logit: 15.64 Prob:  9.62% Token: | the|
Top 2th token. Logit: 14.96 Prob:  4.87% Token: | Catalonia|
Top 3th token. Logit: 14.61 Prob:  3.43% Token: | a|
Top 4th token. Logit: 14.03 Prob:  1.93% Token: | Madrid|
Top 5th token. Logit: 13.66 Prob:  1.33% Token: | Portugal|
Top 6th token. Logit: 13.10 Prob:  0.76% Token: | one|
Top 7th token. Logit: 12.98 Prob:  0.68% Token: | Barcelona|
Top 8th token. Logit: 12.97 Prob:  0.67% Token: | Argentina|
Top 9th token. Logit: 12.94 Prob:  0.65% Token: | Spanish|


Tokenized prompt: ['<|endoftext|>', 'The', ' city', ' Mumbai', ' is', ' the', ' capital', ' of']
Tokenized answer: [' India']


Top 0th token. Logit: 16.51 Prob: 34.79% Token: | India|
Top 1th token. Logit: 15.86 Prob: 18.24% Token: | the|
Top 2th token. Logit: 14.60 Prob:  5.14% Token: | Maharashtra|
Top 3th token. Logit: 14.59 Prob:  5.12% Token: | a|
Top 4th token. Logit: 13.66 Prob:  2.02% Token: | Gujarat|
Top 5th token. Logit: 13.59 Prob:  1.87% Token: | Mumbai|
Top 6th token. Logit: 13.52 Prob:  1.76% Token: | Pakistan|
Top 7th token. Logit: 13.21 Prob:  1.29% Token: | Uttar|
Top 8th token. Logit: 13.14 Prob:  1.20% Token: | one|
Top 9th token. Logit: 12.80 Prob:  0.86% Token: | an|


Tokenized prompt: ['<|endoftext|>', 'The', ' city', ' Bangalore', ' is', ' the', ' capital', ' of']
Tokenized answer: [' India']


Top 0th token. Logit: 16.53 Prob: 39.53% Token: | India|
Top 1th token. Logit: 15.84 Prob: 19.94% Token: | the|
Top 2th token. Logit: 14.49 Prob:  5.14% Token: | a|
Top 3th token. Logit: 14.26 Prob:  4.08% Token: | Karn|
Top 4th token. Logit: 13.26 Prob:  1.50% Token: | Bangalore|
Top 5th token. Logit: 13.13 Prob:  1.32% Token: | one|
Top 6th token. Logit: 12.96 Prob:  1.11% Token: | an|
Top 7th token. Logit: 12.91 Prob:  1.06% Token: | Tamil|
Top 8th token. Logit: 12.83 Prob:  0.98% Token: | Maharashtra|
Top 9th token. Logit: 12.65 Prob:  0.82% Token: | Gujarat|


In [188]:
model.to_tokens(wrong_answers, prepend_bos= False)

tensor([[12517],
        [ 6342],
        [11790],
        [11307],
        [10598],
        [14708],
        [22917],
        [46216]], device='mps:0')

In [189]:
logits, cache = model.run_with_cache(prompts)


answer_tokens = model.to_tokens(answers, prepend_bos= False).squeeze(-1)
wrong_answer_tokens = model.to_tokens(wrong_answers, prepend_bos= False).squeeze(-1)
answer_token_direction = model.W_U.T[answer_tokens] - model.W_U.T[wrong_answer_tokens]
answer_token_direction = model.W_U.T[answer_tokens]
answer_token_direction.shape



torch.Size([8, 768])

In [183]:
model.to_str_tokens(answer_tokens), model.to_str_tokens(wrong_answer_tokens)

([' India',
  ' France',
  ' Japan',
  ' Germany',
  ' Canada',
  ' Spain',
  ' India',
  ' India'],
 [' Delhi',
  ' Paris',
  ' Tokyo',
  ' Berlin',
  ' Rome',
  ' Madrid',
  ' Mumbai',
  ' Bangalore'])

In [190]:
print(logits[5][-1][answer_tokens[5]], logits[5][-1][wrong_answer_tokens[5]])
plot_top_tokens_from_res_stream(logits[5][-1], False)

tensor(17.4420, device='mps:0') tensor(14.0314, device='mps:0')


In [179]:
from sympy import false


def plot_top_tokens_from_res_stream(idx, apply_unembed = True):
    # idx is a vector of model dim
    if apply_unembed:
        attrs = idx@model.W_U
    else:
        attrs = idx
    top_indices = torch.argsort(attrs, descending=True)[:20]
    return imshow(attrs[top_indices].unsqueeze(0), x = model.to_str_tokens(top_indices), title = "Top tokens")

In [7]:
# prompts = ["The city Delhi belongs to the country of", "The city Mumbai belongs to the country of", "The city Chennai belongs to the country of"
#            , "The city Bangalore belongs to the country of", "The city Bombay belongs to the country of"]
# answer = 'India'
# cities = ["Delhi", "Mumbai", "Chennai", "Bangalore", "Bombay"]
# wrong_answers = ["Delhi", "Mumbai", "Chennai", "Bangalore", "Bombay"]

torch.Size([6, 2, 768])

In [191]:
resid_accumlated, labels = cache.accumulated_resid(return_labels=True, apply_ln= True)
print(resid_accumlated.shape, labels)
resid_accumlated_in_answer_direction =  torch.einsum('abcd,bd->bac', resid_accumlated, answer_token_direction)

resid_accumlated_in_answer_direction.shape

torch.Size([13, 8, 8, 768]) ['0_pre', '1_pre', '2_pre', '3_pre', '4_pre', '5_pre', '6_pre', '7_pre', '8_pre', '9_pre', '10_pre', '11_pre', 'final_post']


torch.Size([8, 13, 8])

In [192]:
imshow(resid_accumlated_in_answer_direction.mean(dim=0).T, title = "Residuals in the direction of the answer token", x = labels, y = model.to_str_tokens(prompts[0]) ) 

In [138]:
print(resid_accumlated.shape)


plot_top_tokens_from_res_stream(resid_accumlated[-1, 2, -1])

torch.Size([13, 6, 8, 768])


In [195]:
line((resid_accumlated_in_answer_direction).mean(dim=0)[:,-1], x = labels, title = "Residuals in answer direction")

In [194]:
print(resid_accumlated_in_answer_direction.shape)

torch.Size([8, 13, 8])


In [21]:
line((heads_decomposed@answer_token_direction).mean(dim=1)[:,-1], x = labels, title = "Residuals in answer direction")

In [22]:
utils.get_act_name('attn', 9)

'blocks.9.attn.hook_pattern'

In [23]:
cache['blocks.9.attn.hook_pattern']

torch.Size([5, 12, 9, 9])

In [30]:
imshow(cache['blocks.9.attn.hook_pattern'].mean(dim=0)[8], title = "Attention pattern Layer 9 Head 8", x = model.to_str_tokens(prompts[0]), y = model.to_str_tokens(prompts[0]))

In [50]:
utils.get_act_name('post',0)

'blocks.0.mlp.hook_post'

In [49]:
cache['blocks.0.mlp.hook_post'].shape

torch.Size([5, 9, 3072])

In [65]:
(cache['blocks.0.mlp.hook_post'] > 0).reshape(-1).sum()/ (cache['blocks.0.mlp.hook_post'] > 0).reshape(-1).shape[0]

tensor(0.1281, device='mps:0')

In [73]:
def get_avg_activation_percentage(cache, layer_name, threshold):
    act = cache[layer_name]
    return (act > threshold).reshape(-1).sum()/ (act > threshold).reshape(-1).shape[0]

for layer in cache.keys():
    if 'mlp' in layer:
        print(layer, get_avg_activation_percentage(cache, layer, 0))
get_avg_activation_percentage(cache, 'blocks.0.mlp.hook_post', 1)*100

blocks.0.mlp.hook_pre tensor(0.1281, device='mps:0')
blocks.0.mlp.hook_post tensor(0.1281, device='mps:0')
blocks.0.hook_mlp_out tensor(0.4878, device='mps:0')
blocks.1.mlp.hook_pre tensor(0.2706, device='mps:0')
blocks.1.mlp.hook_post tensor(0.2706, device='mps:0')
blocks.1.hook_mlp_out tensor(0.4411, device='mps:0')
blocks.2.mlp.hook_pre tensor(0.1875, device='mps:0')
blocks.2.mlp.hook_post tensor(0.1875, device='mps:0')
blocks.2.hook_mlp_out tensor(0.4460, device='mps:0')
blocks.3.mlp.hook_pre tensor(0.1327, device='mps:0')
blocks.3.mlp.hook_post tensor(0.1327, device='mps:0')
blocks.3.hook_mlp_out tensor(0.4464, device='mps:0')
blocks.4.mlp.hook_pre tensor(0.1804, device='mps:0')
blocks.4.mlp.hook_post tensor(0.1804, device='mps:0')
blocks.4.hook_mlp_out tensor(0.4442, device='mps:0')
blocks.5.mlp.hook_pre tensor(0.1758, device='mps:0')
blocks.5.mlp.hook_post tensor(0.1758, device='mps:0')
blocks.5.hook_mlp_out tensor(0.4485, device='mps:0')
blocks.6.mlp.hook_pre tensor(0.1708, dev

tensor(0.6771, device='mps:0')

In [103]:
num_layers = sum('hook_mlp_out' in layer for layer in cache.keys())
ans = torch.empty(num_layers)


index = 0
for layer in cache.keys():
    if 'hook_mlp_out' in layer:
        print(cache[layer].shape)
        data =cache[layer]
        print(layer, (data@answer_token_direction).mean(dim=0)[-1])
        ans[index] = (data@answer_token_direction).mean(dim=0)[-1]
        index += 1
line(ans, x = range(num_layers), title = "Residuals in answer direction")

torch.Size([5, 9, 768])
blocks.0.hook_mlp_out tensor(1.2059, device='mps:0')
torch.Size([5, 9, 768])
blocks.1.hook_mlp_out tensor(-0.5043, device='mps:0')
torch.Size([5, 9, 768])
blocks.2.hook_mlp_out tensor(-2.6819, device='mps:0')
torch.Size([5, 9, 768])
blocks.3.hook_mlp_out tensor(1.0177, device='mps:0')
torch.Size([5, 9, 768])
blocks.4.hook_mlp_out tensor(-0.9052, device='mps:0')
torch.Size([5, 9, 768])
blocks.5.hook_mlp_out tensor(-4.4674, device='mps:0')
torch.Size([5, 9, 768])
blocks.6.hook_mlp_out tensor(4.5373, device='mps:0')
torch.Size([5, 9, 768])
blocks.7.hook_mlp_out tensor(0.2651, device='mps:0')
torch.Size([5, 9, 768])
blocks.8.hook_mlp_out tensor(3.7097, device='mps:0')
torch.Size([5, 9, 768])
blocks.9.hook_mlp_out tensor(0.9886, device='mps:0')
torch.Size([5, 9, 768])
blocks.10.hook_mlp_out tensor(20.0983, device='mps:0')
torch.Size([5, 9, 768])
blocks.11.hook_mlp_out tensor(-28.3539, device='mps:0')


In [104]:
mlp_contrib , labels  = cache.decompose_resid(return_labels=True, apply_ln= False, mode = 'mlp')
print(mlp_contrib.shape, len(labels))

torch.Size([14, 5, 9, 768]) 14


In [105]:
#Just scaled down of the previous plot.
line((mlp_contrib@answer_token_direction).mean(dim =1)[:,-1], title = "MLP contribution Layer 0", x = labels)