In [16]:
%reload_ext autoreload
%autoreload 2

import torch

# Seed torch's RNG for reproducible runs; the trailing `;` keeps the returned
# Generator object from being echoed as the cell's output.
torch.manual_seed(42);


<torch._C.Generator at 0x7996f81d4550>

In [17]:
from learning.gpt2.model import GPT2, PretrainedName


# Prefer the GPU, but fall back to CPU so the notebook still runs end to end
# on a machine without CUDA (GPT-2 small is small enough for CPU inference).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# `from_pretrained` returns a pair; only the first element (the GPT2 module
# used throughout this notebook) is used in the cells shown here.
model, _pretrained_model = GPT2.from_pretrained(
    PretrainedName.GPT2_SMALL, device=device
)


In [18]:
from learning.gpt2.ioi_circuit_analyzer import IoiCircuitAnalyzer
import tiktoken

tokenizer = tiktoken.get_encoding("gpt2")
# Sanity check: decoded top-k token ids are only meaningful if the tokenizer's
# vocabulary lines up with the model's token-embedding table.
assert tokenizer.n_vocab == model.embedding.num_embeddings, (
    f"tokenizer vocab size {tokenizer.n_vocab} does not match "
    f"model embedding size {model.embedding.num_embeddings}"
)
analyzer = IoiCircuitAnalyzer(model, tokenizer, device)


In [19]:
from pdb import run


def run_case(analyzer: IoiCircuitAnalyzer, model: GPT2, case: str, k: int):
    """Print a prompt followed by the model's top-``k`` next-token guesses.

    Each guess is printed on its own line as ``"<prob> <token>"``, with the
    probability rounded to two decimals, in the order the analyzer returns them.

    Args:
        analyzer: provides ``topk_logits(case, k)``, whose result exposes
            indexable ``top_probs`` and ``top_indices``.
        model: not used by this function; kept so the call signature matches
            the other helpers in this notebook.
        case: the prompt text to score.
        k: number of predictions to print.

    Note:
        Token ids are decoded with the notebook-global ``tokenizer``.
    """
    print(case)
    topk = analyzer.topk_logits(case, k)
    for rank in range(k):
        token_id = int(topk.top_indices[rank])
        prob = topk.top_probs[rank]
        print(f"{prob:.2f} {tokenizer.decode([token_id])}")


def run_test(analyzer: IoiCircuitAnalyzer, model: GPT2, k: int = 3, cases=None):
    """Print the top-``k`` next-token predictions for a suite of IOI prompts.

    Args:
        analyzer: passed to ``run_case`` to score each prompt.
        model: passed through to ``run_case``.
        k: number of predictions printed per prompt (default 3, the value
            that used to be hard-coded here).
        cases: iterable of prompt strings to run; when ``None`` the built-in
            suite below is used.
    """
    if cases is None:
        # The first prompt is the canonical example from the IOI paper
        # ("Interpretability in the Wild", https://arxiv.org/abs/2211.00593).
        # The rest vary the names, which name is repeated as the subject, and
        # the separator between the two clauses (",", ".", ";", "!").
        cases = [
            "When Mary and John went to the store, John gave a drink to",
            "When Vincent and Vanessa went to the park, Vincent gave a leaf to",
            "When Vincent and Vanessa went to the park, Vanessa gave a leaf to",
            "Mary and John went to the store. John gave a drink to",
            "Mary and John went to the store; John gave a drink to",
            "Mary and John went to the store! John gave a drink to",
            "Mary and John went to the store. Mary gave a drink to",
            "Mary and John went to the store; Mary gave a drink to",
            "Mary and John went to the store! Mary gave a drink to",
        ]

    for case in cases:
        run_case(analyzer, model, case, k)


# Baseline: score the IOI prompt suite before any head patching is done in the
# next cell.
print(80 * "-")
run_test(analyzer, model)

--------------------------------------------------------------------------------
When Mary and John went to the store, John gave a drink to
0.45  Mary
0.21  them
0.07  the
When Vincent and Vanessa went to the park, Vincent gave a leaf to
0.51  Vanessa
0.12  the
0.05  a
When Vincent and Vanessa went to the park, Vanessa gave a leaf to
0.54  Vincent
0.09  the
0.03  her
Mary and John went to the store. John gave a drink to
0.31  them
0.15  John
0.13  the
Mary and John went to the store; John gave a drink to
0.33  them
0.15  John
0.14  the
Mary and John went to the store! John gave a drink to
0.27  them
0.13  the
0.10  John
Mary and John went to the store. Mary gave a drink to
0.46  John
0.14  them
0.11  the
Mary and John went to the store; Mary gave a drink to
0.29  John
0.21  them
0.13  the
Mary and John went to the store! Mary gave a drink to
0.40  John
0.14  the
0.12  them


In [20]:
# Show the module tree (embeddings, 12 blocks, 12 heads per block) as a
# reference for the block/head indices patched below.
print(model)


def test_head(
    analyzer: IoiCircuitAnalyzer,
    model: GPT2,
    block_idx,
    head_idx,
    new_prompt: str = "When Mike and Tom went to the store, Rise gave a drink to",
    ori_prompt: str = "When Mary and John went to the store, John gave a drink to",
    k: int = 3,
):
    """Patch one attention head's output from ``new_prompt`` into ``ori_prompt``.

    Runs three passes through ``run_case``, each printing the top-``k``
    predictions:

    A. ``new_prompt`` with output capture on: every head records its output
       (x_new). The target head's recording is kept aside.
    B. ``ori_prompt`` with output capture on: every head records its output
       (x_ori), overwriting what pass A recorded.
    C. ``ori_prompt`` replaying the recorded (frozen) outputs, with the
       target head's recording replaced by the x_new output kept from A.

    Comparing the predictions of B and C shows how much the target head alone
    moves the prediction. When the function returns, and also if it raises,
    output capture and frozen-output replay are both switched off again.
    This stops later cells from silently running a patched model.

    Args:
        analyzer: passed to ``run_case`` for scoring.
        model: model whose head outputs are captured and replayed.
        block_idx: index into ``model.blocks`` of the block to patch.
        head_idx: index of the attention head within that block.
        new_prompt: prompt whose head output is patched in (x_new).
        ori_prompt: prompt being evaluated (x_ori).
        k: number of predictions printed per pass.

    Raises:
        ValueError: if the target head's recorded outputs for the two prompts
            have different shapes (e.g. the prompts tokenize to different
            lengths), in which case the patch would not line up.
    """
    # NOTE(review): assumes forward passes are deterministic (model/analyzer in
    # eval mode). The module tree contains Dropout(p=0.1), and with dropout
    # active the A/B/C passes would not be comparable. TODO confirm.
    print("-" * 80)
    print(f"Testing head [{block_idx}][{head_idx}]")
    head = model.blocks[block_idx].attention.heads[head_idx]

    try:
        # A: x_new forward pass. Every head records its output on x_new.
        model.set_capture_output(True)
        model.set_use_frozen_output(False)
        run_case(analyzer, model, new_prompt, k)
        captured_output = head.frozen_output

        # B: x_ori forward pass. Every head records its output on x_ori. This
        # overwrites the target head's recording, which is why A's output was
        # kept in `captured_output`.
        model.set_capture_output(True)
        model.set_use_frozen_output(False)
        run_case(analyzer, model, ori_prompt, k)
        ori_output = head.frozen_output

        # The recordings are presumably tensors. If so, they must have the
        # same shape for the patch in C to be meaningful.
        new_shape = getattr(captured_output, "shape", None)
        ori_shape = getattr(ori_output, "shape", None)
        if new_shape != ori_shape:
            raise ValueError(
                f"head [{block_idx}][{head_idx}] output shape differs between "
                f"prompts: {new_shape} (new) vs {ori_shape} (ori)"
            )

        # C: x_ori forward pass with the target head patched from A.
        # All other heads replay their x_ori recordings; the target head
        # replays its x_new recording.
        model.set_capture_output(False)
        model.set_use_frozen_output(True)
        head.frozen_output = captured_output
        run_case(analyzer, model, ori_prompt, k)
    finally:
        # Always leave the model in normal (non-capturing, non-replaying) mode.
        model.set_capture_output(False)
        model.set_use_frozen_output(False)


# Patch each head of block 11 (the last of GPT-2 small's 12 blocks) in turn.
for head_idx in range(model.config.num_heads):
    test_head(analyzer, model, block_idx=11, head_idx=head_idx)

GPT2(
  (embedding): Embedding(50257, 768)
  (positional_embedding): Embedding(1024, 768)
  (dropout): Dropout(p=0.1, inplace=False)
  (blocks_module): ModuleList(
    (0-11): 12 x Block(
      (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attention): MultiHeadAttention(
        (heads_module): ModuleList(
          (0-11): 12 x AttentionHead(
            (query): Linear(in_features=768, out_features=64, bias=True)
            (key): Linear(in_features=768, out_features=64, bias=True)
            (value): Linear(in_features=768, out_features=64, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (projection): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (feed_forward): FeedForward(
        (linear): Linear(in_features=768, out_features=3072, bias=True)
        (gelu): GE