In [3]:
import torch

In [4]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f938b7a02b0>

In [5]:
# subject_names = ["John", " Mary"]

In [6]:
# subject_tokens = [model.to_single_token(x) for x in subject_names]

In [7]:
# subject_tokens

##### Example 1

In [8]:
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformer_lens import utils

In [9]:
prompt = "John told Mary: 'Persistence is all you need.' Mary replied back to "

In [10]:
receiver_hook_name = f"blocks.{12-1}.hook_resid_post"

In [11]:
model = HookedTransformer.from_pretrained("gpt2-small")

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


In [12]:
n_layers = model.cfg.n_layers
n_heads = model.cfg.n_heads

In [13]:
type(model)

transformer_lens.HookedTransformer.HookedTransformer

In [14]:
n_layers, n_heads, receiver_hook_name

(12, 12, 'blocks.11.hook_resid_post')

In [15]:
prompt

"John told Mary: 'Persistence is all you need.' Mary replied back to "

Calculate the importance of the computational path **from head 6 at layer 9 to the final layer (`receiver_hook_name`)** in predicting the indirect object token using activation patching and logit difference. Explain your code.

In [16]:
corrupted_prompt = "Laura told Katie: 'Persistence is all you need'. Bread replied back to "

In [17]:
clean_tokens = model.to_tokens(prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

In [18]:
_, clean_activations = model.run_with_cache(clean_tokens)
_, corrupted_activations = model.run_with_cache(corrupted_tokens)

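Here we choose the head to patch: head 6 in layer 9 (head 9.6). With the `"attn"` alias, `utils.get_act_name` returns the attention-pattern hook, so what gets patched below is where this head attends, not the head's output.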

In [19]:
head_idx, layer_idx = 6, 9
# "attn" is an alias for the attention pattern: "blocks.9.attn.hook_pattern"
hook_name = utils.get_act_name("attn", layer_idx)

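We then take this head's attention pattern from the corrupted run: an `[18, 18]` matrix of query positions by key positions. Patching it into the clean run requires the clean prompt to tokenize to the same length.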

In [20]:
# hook_pattern is [batch, head, query_pos, key_pos]; keep head 6 of the single prompt
corrupted_head_activations = corrupted_activations[hook_name][0, head_idx, :, :]
corrupted_head_activations.shape

torch.Size([18, 18])

This hook function overwrites head 9.6's attention pattern on the clean run with the pattern recorded on the corrupted run. Registering it with `model.add_hook` and re-running the clean input gives a run in which this head attends as it did on the corrupted prompt, while the values it reads, and every other component of the model, still come from the clean prompt. The hook is registered on the model itself rather than passed to a single call, which is why later cells call `model.reset_hooks()` before running again.

In [21]:
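def patch_head_activation(activations, hook):
    # activations: this layer's attention pattern, [batch, head, query_pos, key_pos]
    # Overwrite head `head_idx` with its pattern from the corrupted run
    activations[0, head_idx, :, :] = corrupted_head_activations
    return activations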
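The hook above reads `head_idx` and `corrupted_head_activations` from global scope. A common alternative, sketched here with toy tensors, binds them explicitly with `functools.partial`. The name `patch_head_pattern` is hypothetical; the `(activation, hook)` call signature is the one TransformerLens uses for hook functions.

```python
from functools import partial

import torch


def patch_head_pattern(pattern, hook, head_idx, new_pattern):
    """Hook body: overwrite one head's attention pattern.

    pattern has shape [batch, head, query_pos, key_pos]; `hook` is the
    HookPoint that TransformerLens passes in (unused here).
    """
    pattern[:, head_idx, :, :] = new_pattern
    return pattern


# Toy stand-ins for a cached hook_pattern tensor (hypothetical sizes).
n_heads, seq_len = 3, 4
pattern = torch.full((1, n_heads, seq_len, seq_len), 0.25)
new_pattern = torch.eye(seq_len)

# Bind the head index and replacement, leaving the (activation, hook) signature.
hook_fn = partial(patch_head_pattern, head_idx=1, new_pattern=new_pattern)
patched = hook_fn(pattern, hook=None)
```

In the notebook this could be passed as `model.run_with_hooks(clean_tokens, fwd_hooks=[(hook_name, partial(patch_head_pattern, head_idx=head_idx, new_pattern=corrupted_head_activations))])`, which returns logits and applies the hook for that call only.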

In [22]:
# Register the patching hook on the model, then run the clean prompt with it active
model.add_hook(hook_name, patch_head_activation)
_, patched_activations = model.run_with_cache(clean_tokens)
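Patching a head's pattern is not the same as patching its output. A toy single-head sketch in plain PyTorch makes the difference visible; the sizes and random tensors are made up for illustration and are not GPT-2's, and `z` follows the TransformerLens naming (`hook_z`) for a head's pattern-weighted values.

```python
import torch

torch.manual_seed(0)

# Toy dimensions (hypothetical, far smaller than GPT-2's).
seq_len, d_head = 4, 3


def attention_pattern(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Causal softmax attention pattern with shape [query_pos, key_pos]."""
    scores = q @ k.T / d_head**0.5
    # Mask out keys that come after the query position.
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))
    return scores.softmax(dim=-1)


# Queries, keys and values as they might be computed on each prompt.
q_clean, k_clean, v_clean = (torch.randn(seq_len, d_head) for _ in range(3))
q_corr, k_corr, v_corr = (torch.randn(seq_len, d_head) for _ in range(3))

pattern_clean = attention_pattern(q_clean, k_clean)
pattern_corr = attention_pattern(q_corr, k_corr)

# Clean head output: clean pattern applied to clean values.
z_clean = pattern_clean @ v_clean
# Pattern patching (what the notebook does): corrupted pattern, clean values.
z_pattern_patched = pattern_corr @ v_clean
# Output patching: the head's whole output comes from the corrupted run.
z_output_patched = pattern_corr @ v_corr
```

Position 0 can only attend to itself under the causal mask, so its pattern-patched output is unchanged; later positions mix clean values with corrupted attention weights, which differs from swapping in the corrupted output wholesale.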

In [23]:
def extract_receiver_activations(activations, hook_name):
    return activations[hook_name]

In [24]:
receiver_activations = extract_receiver_activations(patched_activations, receiver_hook_name)

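This part patches the receiver node, here the residual stream after the last block (`blocks.11.hook_resid_post`), with the value it took in the run where head 9.6 was patched, and reruns the clean prompt with only that substitution. Note that this does not isolate the head from the downstream computation: every path from head 9.6 to the logits runs through this final residual stream, and only `ln_final` and the unembedding follow it, so the result equals the patched run's logits. What is measured is therefore the total effect of patching head 9.6's attention pattern. Isolating the direct path would additionally require the components in between (the layer 9 MLP and layers 10 and 11) to see clean inputs.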
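A toy residual-stream sketch (plain PyTorch, made-up weights, no LayerNorm) illustrates the point. Replacing the receiver with its value from the patched run reproduces that run's logits exactly, whereas a direct-path patch adds only the change in the sender's output to the clean receiver, so the downstream computation never sees the corrupted sender. The function `forward` and its override arguments are hypothetical and not part of TransformerLens.

```python
import torch

torch.manual_seed(0)
d_model, d_vocab = 8, 5

# Hypothetical toy weights standing in for the real model (scaled to avoid tanh saturation).
W_sender = torch.randn(d_model, d_model) / d_model**0.5  # the patched "sender" component
W_later = torch.randn(d_model, d_model) / d_model**0.5   # downstream computation
W_U = torch.randn(d_model, d_vocab)                      # unembedding


def forward(x, sender_out=None, receiver=None):
    """Toy residual stream; returns (logits, final residual)."""
    if sender_out is None:
        sender_out = x @ W_sender
    resid = x + sender_out
    resid = resid + torch.tanh(resid @ W_later)  # later layers read the sender
    if receiver is not None:
        resid = receiver  # patch the receiver (final residual stream)
    return resid @ W_U, resid


x_clean, x_corr = torch.randn(d_model), torch.randn(d_model)
clean_logits, resid_clean = forward(x_clean)

# Step 1: clean input with the sender's corrupted output; record the receiver.
patched_logits, resid_patched = forward(x_clean, sender_out=x_corr @ W_sender)
# Step 2: clean run with only the receiver replaced (the two-step recipe above).
two_step_logits, _ = forward(x_clean, receiver=resid_patched)

# Direct path only: shift the clean receiver by the change in the sender's output,
# leaving the downstream computation on clean inputs.
delta = x_corr @ W_sender - x_clean @ W_sender
direct_logits, _ = forward(x_clean, receiver=resid_clean + delta)
```

`two_step_logits` matches `patched_logits`, while `direct_logits` differs because the nonlinear downstream term is computed from clean inputs. In the notebook's setting, the analogous direct-path patch would add only the change in head 9.6's output to the clean `blocks.11.hook_resid_post`.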

In [25]:
def patch_receiver_activations(activations, hook):
    # Replace the clean run's final residual stream with the value recorded in the patched run
    return receiver_activations

In [26]:
model.reset_hooks()
patched_logits = model.run_with_hooks(
    clean_tokens,
    fwd_hooks=[(receiver_hook_name, patch_receiver_activations)]
)

In [27]:
def compute_logit_difference(clean_logits, patched_logits, target_token):
    # Change in the target token's logit at the final position: patched - clean
    clean_target_logit = clean_logits[:, -1, target_token]
    patched_target_logit = patched_logits[:, -1, target_token]
    return patched_target_logit - clean_target_logit

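Next, we compare the target token's logit at the final position on the clean run with its logit on the clean run with the patched path. Strictly speaking, this is the change in a single logit (that of the indirect object, "John"), not the IOI logit difference `logit(IO) - logit(S)`.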
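The IOI-style metric compares the indirect object's logit with the subject's. A minimal sketch, assuming logits of shape `[batch, pos, d_vocab]` as returned by `model(...)`; `logit_diff` and `patching_effect` are hypothetical helper names, and the toy logits are made up.

```python
import torch


def logit_diff(logits: torch.Tensor, io_token: int, s_token: int) -> torch.Tensor:
    """logit(IO) - logit(S) at the final position; returns a [batch] tensor."""
    final = logits[:, -1, :]
    return final[:, io_token] - final[:, s_token]


def patching_effect(clean_logits, patched_logits, io_token, s_token):
    """Change in the IO-S logit difference caused by the patch (patched - clean)."""
    return logit_diff(patched_logits, io_token, s_token) - logit_diff(
        clean_logits, io_token, s_token
    )


# Toy logits: batch 1, 2 positions, vocabulary of 5 (IO = token 1, S = token 2).
clean = torch.zeros(1, 2, 5)
clean[0, -1, 1] = 3.0
clean[0, -1, 2] = 1.0
patched = torch.zeros(1, 2, 5)
patched[0, -1, 1] = 2.0
patched[0, -1, 2] = 1.5
effect = patching_effect(clean, patched, io_token=1, s_token=2)  # tensor([-1.5000])
```

With the notebook's variables this would be `patching_effect(clean_logits, patched_logits, model.to_single_token("John"), model.to_single_token("Mary"))`, assuming `"Mary"` is a single GPT-2 token. A negative value means the patch weakened the model's preference for the indirect object over the subject.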

In [28]:
model.reset_hooks()
clean_logits = model(clean_tokens)

In [29]:
target_token = model.to_single_token("John")

In [30]:
output = compute_logit_difference(clean_logits, patched_logits, target_token)

In [31]:
output

tensor([-0.0321])