In [1]:

from functools import partial
from typing import List, Optional, Union

import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
from jaxtyping import Float

import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer

torch.set_grad_enabled(False)
print("Disabled automatic differentiation")

def imshow(tensor, **kwargs):
    px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    ).show()


def line(tensor, **kwargs):
    px.line(
        y=utils.to_numpy(tensor),
        **kwargs,
    ).show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(
        y=y,
        x=x,
        labels={"x": xaxis, "y": yaxis, "color": caxis},
        **kwargs,
    ).show()

  from .autonotebook import tqdm as notebook_tqdm


Disabled automatic differentiation


In [175]:
model = HookedTransformer.from_pretrained("gpt2-small")
device = utils.get_device()

Loaded pretrained model gpt2-small into HookedTransformer


In [203]:
prompt = "Mary and John went to the store. Mary then gave a toy to"
answer = "John"
corrupted_answer = "Mary"
answer_token = model.to_single_token(f" {answer}")
wrong_token = model.to_single_token(f" {corrupted_answer}")
utils.test_prompt(prompt, answer, model)

Tokenized prompt: ['<|endoftext|>', 'Mary', ' and', ' John', ' went', ' to', ' the', ' store', '.', ' Mary', ' then', ' gave', ' a', ' toy', ' to']
Tokenized answer: [' John']


Top 0th token. Logit: 17.93 Prob: 62.24% Token: | John|
Top 1th token. Logit: 15.99 Prob:  8.92% Token: | the|
Top 2th token. Logit: 15.60 Prob:  6.01% Token: | Mary|
Top 3th token. Logit: 15.11 Prob:  3.71% Token: | her|
Top 4th token. Logit: 14.85 Prob:  2.86% Token: | them|
Top 5th token. Logit: 14.60 Prob:  2.23% Token: | a|
Top 6th token. Logit: 13.11 Prob:  0.50% Token: | one|
Top 7th token. Logit: 13.01 Prob:  0.45% Token: | their|
Top 8th token. Logit: 13.01 Prob:  0.45% Token: | Jesus|
Top 9th token. Logit: 12.97 Prob:  0.44% Token: | Joseph|


In [207]:
logits, cache = model.run_with_cache(prompt, remove_batch_dim=True)

def get_score(logits):
    return (logits[0][-1][answer_token] - logits[0][-1][wrong_token]).item()
base_difference = get_score(logits)

2.3382863998413086

In [214]:
accumulated_resid , labels = cache.accumulated_resid(apply_ln= True, return_labels=True)
print(accumulated_resid.shape, len(labels))

accumulated_resid = accumulated_resid[:,-1,:]@model.W_U

line(accumulated_resid[:, answer_token] - accumulated_resid[:, wrong_token], title="Accumulated Residuals", labels={"y": "Residuals", "x": "Layer"})

torch.Size([13, 15, 768]) 13


In [210]:


def layer_ablation_hook(value, **kwargs):
    value[:,:,:,:] = 0.
    return value


values = []
for layer in range(model.cfg.n_layers):
    layer_to_abalate = layer
    logits = model.run_with_hooks(prompt, fwd_hooks=[(
            utils.get_act_name("attn", layer_to_abalate), 
            layer_ablation_hook
            )])
    values.append(get_score(logits))

line(values, labels=range(model.cfg.n_layers), title="Layer ablation")

# %%
    

In [216]:
layer_data, labels = cache.decompose_resid(apply_ln=True, return_labels=True, incl_embeds=False)
print(layer_data.shape, len(labels))

torch.Size([24, 15, 768]) 24


In [217]:


def layer_ablation_hook(value, **kwargs):
    value[:,:,:,:] = 0.
    return value


values = []
x = []
for layer in range(model.cfg.n_layers):
    layer_to_abalate = f"blocks.{layer}.hook_attn_out"
    logits = model.run_with_hooks(prompt, fwd_hooks=[(
            layer_to_abalate, 
            layer_ablation_hook
            )])
    values.append(get_score(logits))
    x.append(layer_to_abalate)

    layer_to_abalate = f"blocks.{layer}.hook_mlp_out"

    logits = model.run_with_hooks(prompt, fwd_hooks=[(
            layer_to_abalate, 
            layer_ablation_hook
            )])
    values.append(get_score(logits))



line(values, labels=range(model.cfg.n_layers), title="Layer ablation")

# %%
    

['0_attn_out',
 '0_mlp_out',
 '1_attn_out',
 '1_mlp_out',
 '2_attn_out',
 '2_mlp_out',
 '3_attn_out',
 '3_mlp_out',
 '4_attn_out',
 '4_mlp_out',
 '5_attn_out',
 '5_mlp_out',
 '6_attn_out',
 '6_mlp_out',
 '7_attn_out',
 '7_mlp_out',
 '8_attn_out',
 '8_mlp_out',
 '9_attn_out',
 '9_mlp_out',
 '10_attn_out',
 '10_mlp_out',
 '11_attn_out',
 '11_mlp_out']

In [218]:
cache.keys()

dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'blocks.2.hook_resid_pre', 'block

In [222]:
utils.get_act_name("attn0", 'out')

'blocks.0.attn.hook_pattern'

In [172]:
blocks.0.hook_attn_out , blocks.0.hook_mlp_out

['embed',
 'pos_embed',
 '0_mlp_out',
 '1_mlp_out',
 '2_mlp_out',
 '3_mlp_out',
 '4_mlp_out',
 '5_mlp_out',
 '6_mlp_out',
 '7_mlp_out',
 '8_mlp_out',
 '9_mlp_out',
 '10_mlp_out',
 '11_mlp_out']