In [189]:
%reload_ext autoreload
%autoreload 2

import random
import torch

# Seed every RNG used here so sampling through stdlib `random` (presumably
# what NameSampler uses -- verify) and any torch randomness are reproducible
# under Restart & Run All.
SEED = 42
random.seed(SEED)
# torch.manual_seed seeds all devices (CPU and CUDA). The trailing ';' keeps
# the returned Generator object from being echoed as the cell output.
torch.manual_seed(SEED);


<torch._C.Generator at 0x721f105d44f0>

In [190]:
from learning.gpt2.model import GPT2, PretrainedName


# Prefer the GPU, but fall back to CPU so the notebook still runs on machines
# without CUDA (loading the weights onto "cuda" would presumably fail there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The second return value is not used in the cells below; the leading
# underscore marks it as intentionally unused.
model, _pretrained_model = GPT2.from_pretrained(
    PretrainedName.GPT2_SMALL, device=device
)


In [191]:
from data.names_data_source import NamesDataSource
from learning.gpt2.ioi_circuit_analyzer import NameSampler, PromptTemplate
import tiktoken

tokenizer = tiktoken.get_encoding("gpt2")  # GPT-2 BPE, shared by all cells below

names_data_source = NamesDataSource.load(data_folder="../datasets/names")

# Country index 4 is the English name list (per the original note); confirm
# against NamesDataSource if the dataset layout changes.
all_english_names = names_data_source.country_idx_to_names[4]


def filter_single_token_names(
    names: list[str], encoding: "tiktoken.Encoding | None" = None
) -> list[str]:
    """Return the names that encode to exactly one BPE token.

    Args:
        names: Candidate strings, encoded exactly as given. No space is added
            here, so callers decide whether the space-prefixed variant is
            the one being tested.
        encoding: Any object exposing ``encode_batch(list[str])``. Defaults to
            the notebook-level ``tokenizer``, so existing calls keep working.

    Returns:
        The original strings whose encoding is a single token, in input order.
        Returning the inputs directly avoids a decode round-trip.
    """
    if encoding is None:
        encoding = tokenizer
    token_lists = encoding.encode_batch(names)
    return [name for name, tokens in zip(names, token_lists) if len(tokens) == 1]


# Keep only the names that are a single token without a leading space, e.g. "John".
filtered_names = filter_single_token_names(all_english_names)

# In the prompts every name follows a space (e.g. "When {s1}"), so the
# space-prefixed form (" John") must also be a single token. Filter on that
# form, then strip the space back off.
padded_filtered_names = [" " + name for name in filtered_names]
single_token_names = [
    name.removeprefix(" ")
    for name in filter_single_token_names(padded_filtered_names)
]

# Inline invariant check: every surviving name is a single GPT-2 token both
# with and without a leading space.
assert all(
    len(tokenizer.encode(name)) == 1 and len(tokenizer.encode(" " + name)) == 1
    for name in single_token_names
), "expected every sampled name to be a single GPT-2 token"

print(len(single_token_names))
print(single_token_names)

name_sampler = NameSampler(single_token_names)
prompt_template = PromptTemplate(
    template="When {s1} and {s2} went to the store, {s3} gave a drink to",
    name_sampler=name_sampler,
)

# Show 10 samples of each name pattern. ABC uses three distinct names; in ABA
# and ABB the third name repeats the first or second. The sampling order
# matches the original three copy-pasted loops.
for sample_prompt in (
    prompt_template.sample_abc,
    prompt_template.sample_aba,
    prompt_template.sample_abb,
):
    print("-" * 80)
    for _ in range(10):
        print(sample_prompt())


254
['Adam', 'Adams', 'Albert', 'Alexander', 'Ali', 'Allen', 'Anderson', 'Andrew', 'Anthony', 'Arthur', 'Austin', 'Ball', 'Bear', 'Beck', 'Beer', 'Bell', 'Berry', 'Best', 'Bill', 'Bird', 'Black', 'Blake', 'Bloom', 'Bob', 'Bone', 'Brain', 'Bright', 'Brook', 'Brown', 'Bruce', 'Bull', 'Burn', 'Bush', 'Button', 'Carter', 'Chan', 'Chance', 'Charge', 'Charles', 'Child', 'Church', 'Clark', 'Close', 'Cole', 'Coll', 'Collins', 'Connor', 'Cook', 'Core', 'Court', 'Craig', 'Crew', 'Cross', 'Dallas', 'Daniel', 'David', 'Davis', 'Day', 'Dean', 'Diamond', 'Dick', 'Donald', 'Down', 'Driver', 'East', 'Edge', 'England', 'English', 'Fall', 'Field', 'Fish', 'Ford', 'Fox', 'Frame', 'France', 'French', 'Friend', 'Gall', 'Gene', 'George', 'Given', 'Glass', 'Gold', 'Good', 'Gordon', 'Graham', 'Grant', 'Gray', 'Green', 'Grey', 'Guest', 'Hall', 'Hamilton', 'Hand', 'Harris', 'Harry', 'Hart', 'Head', 'Henry', 'Hill', 'Hope', 'Houston', 'Howard', 'Hunt', 'Hunter', 'Ireland', 'Islam', 'Jackson', 'Jacob', 'James', '

In [192]:
from learning.gpt2.ioi_circuit_analyzer import IoiCircuitAnalyzer

# Bundle the model, tokenizer, the "store" prompt template and the device into
# one analyzer. IOI here is presumably indirect-object identification, judging
# by the "...gave a drink to" prompts. All later cells go through this object.
analyzer = IoiCircuitAnalyzer(model, tokenizer, prompt_template, device)


In [193]:
def run_case(analyzer: IoiCircuitAnalyzer, model: GPT2, case: str, k: int):
    """Print ``case`` followed by its top-``k`` next-token predictions.

    Each prediction is printed as ``<probability, 2 d.p.> <decoded token>``.
    ``model`` is accepted for signature symmetry with ``run_test`` but is not
    used here; the predictions come from ``analyzer.topk_probs``.
    """
    print(case)
    topk = analyzer.topk_probs(case, k)
    for rank in range(k):
        prob = topk.top_probs[rank]
        token_text = tokenizer.decode([int(topk.top_indices[rank])])
        print(f"{prob:.2f} {token_text}")


def run_test(
    analyzer: IoiCircuitAnalyzer,
    model: GPT2,
    k: int = 3,
    samples_per_template: int = 3,
):
    """Print top-``k`` predictions for sampled ABC, ABA and ABB prompts.

    For each name pattern, every template is sampled ``samples_per_template``
    times and each sampled prompt is passed to ``run_case``.

    Args:
        analyzer: Analyzer used to score each prompt.
        model: Forwarded to ``run_case``.
        k: Number of top next-token predictions printed per prompt.
        samples_per_template: Prompts sampled per template for each pattern.
    """
    templates: list[PromptTemplate] = [
        PromptTemplate(template=text, name_sampler=name_sampler)
        for text in (
            "When {s1} and {s2} went to the store, {s3} gave a drink to",
            "When {s1} and {s2} went to the park, {s3} gave a leaf to",
            "Yesterday {s1} and {s2} went to the store. {s3} gave a drink to",
        )
    ]

    # Same pattern -> template -> sample order as the original copy-pasted
    # loops, so the sequence of sampler calls (and hence the output for a
    # given seed) is unchanged.
    patterns = (
        ("ABC", lambda template: template.sample_abc()),
        ("ABA", lambda template: template.sample_aba()),
        ("ABB", lambda template: template.sample_abb()),
    )
    for label, sample in patterns:
        print(label + " " + "-" * 80)
        for template in templates:
            for _ in range(samples_per_template):
                run_case(analyzer, model, sample(template), k)


# Spot-check the model's predictions on each prompt pattern.
run_test(analyzer=analyzer, model=model)

ABC --------------------------------------------------------------------------------
When Mac and French went to the store, Stack gave a drink to
0.19  the
0.10  them
0.06  a
When Pack and Number went to the store, Thom gave a drink to
0.16  the
0.08  a
0.07  his
When Strong and Hamilton went to the store, Michael gave a drink to
0.16  the
0.08  a
0.07  them
When Court and Reader went to the park, Best gave a leaf to
0.30  the
0.05  a
0.04  Judge
When Ball and Pierre went to the park, Dean gave a leaf to
0.13  the
0.07  Pierre
0.06  them
When Said and France went to the park, York gave a leaf to
0.26  the
0.06  his
0.05  a
Yesterday Black and Strange went to the store. Diamond gave a drink to
0.20  the
0.08  me
0.08  them
Yesterday Sullivan and Brain went to the store. Head gave a drink to
0.16  the
0.10  Sullivan
0.09  them
Yesterday Ford and Judge went to the store. Park gave a drink to
0.22  the
0.13  them
0.09  Judge
ABA -------------------------------------------------------------

In [194]:
# Show the loaded configuration, then the module tree, on separate lines.
print(model.config, model, sep="\n")

# Phase A: baseline forward pass -- capture the reference outputs with all
# heads frozen to their baseline computations.
batch_size = 1 << 7  # == 128
analyzer.capture_baseline_output(batch_size)


ModelConfig(embedding_size=768, num_heads=12, num_blocks=12, vocab_size=50257, sequence_length=1024, feed_forward_expansion_factor=4, dropout=0.1, device=device(type='cuda'))
GPT2(
  (embedding): Embedding(50257, 768)
  (positional_embedding): Embedding(1024, 768)
  (dropout): Dropout(p=0.1, inplace=False)
  (blocks_module): ModuleList(
    (0-11): 12 x Block(
      (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attention): MultiHeadAttention(
        (heads_module): ModuleList(
          (0-11): 12 x AttentionHead(
            (query): Linear(in_features=768, out_features=64, bias=True)
            (key): Linear(in_features=768, out_features=64, bias=True)
            (value): Linear(in_features=768, out_features=64, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (projection): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (layer_norm2): Laye

In [195]:
from learning.gpt2.ioi_circuit_analyzer import HeadAnalysisResult


# Run the analyzer's per-head analysis for every head of one layer and rank
# the heads by KL divergence. The layer and names are defined once, so the
# analysis call, the printed prompt and the column labels cannot drift apart.
LAYER_IDX = 11  # last block of GPT-2 small (num_blocks == 12)
S1_NAME, S2_NAME = "Mary", "John"

# ABB setup: the third name repeats the second one.
results: list[HeadAnalysisResult] = [
    analyzer.analyze_head(LAYER_IDX, head_idx, S1_NAME, S2_NAME, S2_NAME)
    for head_idx in range(model.config.num_heads)
]

print(prompt_template.from_abb(S1_NAME, S2_NAME))

print("\n" + "=" * 80)
print("SUMMARY: Most impactful heads (by KL divergence)")
print("=" * 80)
sorted_results = sorted(
    enumerate(results), key=lambda item: item[1].kl_divergence, reverse=True
)
for rank, (head_idx, result) in enumerate(sorted_results, start=1):
    print(
        f"#{rank} \tHead [{LAYER_IDX}][{head_idx}] \t"
        f"KL: {result.kl_divergence:.4f}\t"
        f"TV: {result.total_variation:.4f}\t"
        f"L2: {result.l2_distance:.4f}\t"
        f"{S1_NAME} prob: {result.s1_prob_change:.4f}\t"
        f"{S2_NAME} prob: {result.s2_prob_change:.4f}\t"
        f"Logit diff: {result.logit_diff_change:.4f}\t"
    )

When Mary and John went to the store, John gave a drink to

SUMMARY: Most impactful heads (by KL divergence)
#1 	Head [11][10] 	KL: 0.1395	TV: 0.2426	L2: 0.2691	Mary prob: -0.2424	John prob: 0.0054	Logit diff: -0.5275	
#2 	Head [11][1] 	KL: 0.0753	TV: 0.1808	L2: 0.1767	Mary prob: 0.1469	John prob: 0.0324	Logit diff: -0.3763	
#3 	Head [11][2] 	KL: 0.0360	TV: 0.1259	L2: 0.1393	Mary prob: 0.1255	John prob: -0.0022	Logit diff: 0.3666	
#4 	Head [11][3] 	KL: 0.0185	TV: 0.0921	L2: 0.0999	Mary prob: 0.0895	John prob: -0.0087	Logit diff: 0.3589	
#5 	Head [11][6] 	KL: 0.0062	TV: 0.0533	L2: 0.0588	Mary prob: 0.0530	John prob: -0.0028	Logit diff: 0.1713	
#6 	Head [11][8] 	KL: 0.0046	TV: 0.0326	L2: 0.0243	Mary prob: -0.0149	John prob: 0.0057	Logit diff: -0.1336	
#7 	Head [11][11] 	KL: 0.0032	TV: 0.0388	L2: 0.0398	Mary prob: 0.0348	John prob: 0.0033	Logit diff: 0.0242	
#8 	Head [11][0] 	KL: 0.0008	TV: 0.0160	L2: 0.0122	Mary prob: -0.0084	John prob: -0.0030	Logit diff: 0.0299	
#9 	Head [11][9] 	KL: 0