In [1]:
import torch
from transformer_lens import HookedTransformer

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f9faa014520>

In [2]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


In [3]:
from transformer_lens.utils import get_act_name

##### Example 1

In [7]:
clean_prompts = [
    "When John and Mary went to the shops, John gave the bag to",
    "When Tom and James went to the park, James gave the ball to"
]

In [33]:
type(model)

transformer_lens.HookedTransformer.HookedTransformer

In [34]:
clean_prompts

['When John and Mary went to the shops, John gave the bag to',
 'When Tom and James went to the park, James gave the ball to']

Create an IOI metric that works as bellow

In [35]:
corrupted_prompts = [
    "When John and Mary went to the shops, Mary gave the bag to",
    "When Tom and James went to the part, Tom gave the ball to"
]

In [36]:
clean_tokens = model.to_tokens(clean_prompts, prepend_bos=True)
corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)
clean_tokens.shape, corrupted_tokens.shape

(torch.Size([2, 15]), torch.Size([2, 15]))

In [37]:
correct_tokens = model.to_tokens(["Mary Tom"], prepend_bos=False)
incorrect_tokens = model.to_tokens(["John James"], prepend_bos=False)
correct_tokens, incorrect_tokens = correct_tokens[0], incorrect_tokens[0] # ignore
correct_tokens, incorrect_tokens

(tensor([24119,  4186]), tensor([7554, 3700]))

In [38]:
clean_logits, _ = model.run_with_cache(clean_tokens)
corrupted_logits, _ = model.run_with_cache(corrupted_tokens)
clean_logits.shape, corrupted_logits.shape

(torch.Size([2, 15, 50257]), torch.Size([2, 15, 50257]))

In [39]:
def compute_average_logit_difference(logits, correct_tokens, incorrect_tokens):
    final_logits = logits[:, -1, :]
    correct_logits = final_logits[:, correct_tokens]
    incorrect_logits = final_logits[:, incorrect_tokens]
    logit_diff = correct_logits - incorrect_logits
    return logit_diff.mean()

In [40]:
clean_logit_difference = compute_average_logit_difference(clean_logits, correct_tokens, incorrect_tokens)
corrupted_logit_difference = compute_average_logit_difference(corrupted_logits, correct_tokens, incorrect_tokens)
clean_logit_difference, corrupted_logit_difference

(tensor(0.8275), tensor(-2.3553))

In [41]:
def compute_ioi_metric(logits):
    patched_logit_difference = compute_average_logit_difference(logits, correct_tokens, incorrect_tokens)
    return (patched_logit_difference - corrupted_logit_difference) / (clean_logit_difference - corrupted_logit_difference)

In [42]:
compute_ioi_metric(clean_logits)

tensor(1.)

In [43]:
compute_ioi_metric(corrupted_logits)

tensor(0.)