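# Attention Head Visualization

Interactive comparison of the attention patterns of the trained attention-only transformer checkpoints on repeated random-token sequences.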

In [1]:
import sys
from pathlib import Path

# Add parent directory to path for imports
sys.path.insert(0, str(Path.cwd().parent))

import torch
import matplotlib.pyplot as plt

from model import AttentionOnlyTransformer
from generate_data import generate_rrt
from interactive_viewers import interactive_attention_viewer

# Prefer Apple's MPS backend, then CUDA, and fall back to CPU if neither is available.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {device}")

# Single seed for every stochastic step below (e.g. random sequences from generate_rrt).
SEED = 67
# manual_seed returns the default torch.Generator; the trailing ';' stops Jupyter
# from echoing its repr (e.g. "<torch._C.Generator at 0x...>") as cell output.
torch.manual_seed(SEED);

Using device: cuda


<torch._C.Generator at 0x27bd12054d0>

In [2]:
# Two quick sanity checks of generate_rrt: one random prefix and one user-supplied prefix.
# Each entry is (heading to print, generate_rrt arguments that differ between the demos).
rrt_demos = [
    ("Random sequence (seq_len=10):", dict(seq_len=10)),
    ("\nCustom sequence [1, 2, 3, 4, 5]:", dict(seq_len=5, custom_seq=[1, 2, 3, 4, 5])),
]

for heading, demo_kwargs in rrt_demos:
    print(heading)
    seq_len, data = generate_rrt(num_samples=1, length=80, vocab_size=256, **demo_kwargs)
    data = data.to(device)
    print(f"  Shape: {data.shape}, First 25 tokens: {data[0, :25].tolist()}")

Random sequence (seq_len=10):
  Shape: torch.Size([1, 80]), First 25 tokens: [120, 167, 140, 36, 246, 239, 245, 180, 149, 73, 120, 167, 140, 36, 246, 239, 245, 180, 149, 73, 120, 167, 140, 36, 246]

Custom sequence [1, 2, 3, 4, 5]:
  Shape: torch.Size([1, 80]), First 25 tokens: [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5]


In [3]:
def load_model(
    model_path: Path,
    n_layers: int,
    *,
    vocab_size: int = 256,
    d_model: int = 64,
    n_heads: int = 4,
    max_context_len: int = 80,
    map_device: "torch.device | None" = None,
) -> AttentionOnlyTransformer:
    """Load a trained AttentionOnlyTransformer checkpoint from disk, ready for inference.

    The architecture hyperparameters must match the ones the checkpoint was
    trained with; the defaults are the values used for every model in this notebook.

    Args:
        model_path: Path to a saved ``state_dict`` readable by ``torch.load``.
        n_layers: Number of attention layers in the saved model.
        vocab_size: Vocabulary size passed to ``AttentionOnlyTransformer``.
        d_model: Model width passed to ``AttentionOnlyTransformer``.
        n_heads: Attention heads per layer passed to ``AttentionOnlyTransformer``.
        max_context_len: Maximum context length passed to ``AttentionOnlyTransformer``.
        map_device: Device to build the model on and map the weights onto;
            defaults to the notebook-wide ``device``.

    Returns:
        The model with the checkpoint weights loaded and ``.eval()`` already called.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist (from ``torch.load``).
        RuntimeError: If the checkpoint's keys or tensor shapes do not match the
            constructed architecture (from ``load_state_dict``).
    """
    target_device = device if map_device is None else map_device

    model = AttentionOnlyTransformer(
        vocab_size=vocab_size,
        d_model=d_model,
        n_layers=n_layers,
        n_heads=n_heads,
        max_context_len=max_context_len,
    ).to(target_device)

    # weights_only=True restricts unpickling to tensors/plain containers, which is
    # safer than full pickle deserialization of the checkpoint file.
    state_dict = torch.load(model_path, map_location=target_device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # inference only: put the model in evaluation mode
    return model

# Load all 5 trained checkpoints compared in this notebook.
results_dir = Path.cwd().parent / "results"

# (display name, checkpoint filename, number of layers). The two "2L_varied" entries
# are the same architecture saved at different points; the 20000/60000 suffixes are
# presumably training steps — TODO confirm against the training script.
MODEL_SPECS = [
    ("1L_fixed", "1L_fixed_model.pt", 1),
    ("2L_fixed", "2L_fixed_model.pt", 2),
    ("1L_varied", "1L_varied_model.pt", 1),
    ("2L_varied_20000", "2L_varied_model_20000.pt", 2),
    ("2L_varied_60000", "2L_varied_model_60000.pt", 2),
]

models = {
    name: load_model(results_dir / filename, n_layers=n_layers)
    for name, filename, n_layers in MODEL_SPECS
}

print("Loaded models:")
for name, model in models.items():
    n_params = sum(p.numel() for p in model.parameters())
    print(f"  {name}: {model.n_layers} layer(s), {n_params:,} params")


Loaded models:
  1L_fixed: 1 layer(s), 49,152 params
  2L_fixed: 2 layer(s), 65,536 params
  1L_varied: 1 layer(s), 49,152 params
  2L_varied_20000: 2 layer(s), 65,536 params
  2L_varied_60000: 2 layer(s), 65,536 params


### Random sequence with specified prefix length

In [4]:
# Generate an 80-token sequence made by repeating a random 25-token prefix
seq_len, sequence = generate_rrt(num_samples=1, length=80, vocab_size=256, seq_len=25)
sequence = sequence.to(device)

print(f"Sequence shape: {sequence.shape}")
print(f"First {seq_len * 2} tokens (shows one full repeat):")
print(sequence[0, :seq_len * 2].tolist())

# One viewer comparing every loaded model, in the same order as `models`.
interactive_attention_viewer(
    list(models.values()),
    sequence,
    seq_len=seq_len,
    model_names=list(models.keys()),
    max_tokens=80,
)

Sequence shape: torch.Size([1, 80])
First 50 tokens (shows one full repeat):
[186, 102, 123, 81, 80, 154, 47, 228, 35, 202, 77, 228, 154, 120, 104, 123, 16, 74, 203, 253, 210, 106, 106, 219, 58, 186, 102, 123, 81, 80, 154, 47, 228, 35, 202, 77, 228, 154, 120, 104, 123, 16, 74, 203, 253, 210, 106, 106, 219, 58]


In [5]:
# Longer 35-token prefix: only two full copies (70 tokens) fit in the 80-token sequence.
seq_len, sequence = generate_rrt(num_samples=1, length=80, vocab_size=256, seq_len=35)
sequence = sequence.to(device)

print(
    f"Sequence shape: {sequence.shape}",
    f"First {seq_len * 2} tokens (shows one full repeat):",
    sequence[0, : 2 * seq_len].tolist(),
    sep="\n",
)

interactive_attention_viewer(
    list(models.values()),
    sequence,
    seq_len=seq_len,
    model_names=list(models),
    max_tokens=80,
)

Sequence shape: torch.Size([1, 80])
First 70 tokens (shows one full repeat):
[188, 65, 70, 236, 80, 55, 251, 233, 61, 4, 103, 206, 154, 227, 84, 44, 65, 170, 155, 136, 93, 53, 142, 99, 127, 228, 154, 197, 75, 211, 109, 90, 243, 95, 54, 188, 65, 70, 236, 80, 55, 251, 233, 61, 4, 103, 206, 154, 227, 84, 44, 65, 170, 155, 136, 93, 53, 142, 99, 127, 228, 154, 197, 75, 211, 109, 90, 243, 95, 54]


### Custom sequence with duplicated tokens in the prefix

In [7]:
# Prefix of tokens 1..28 with two positions overwritten by earlier values, so the
# prefix itself contains duplicates (index 14 -> 5, index 23 -> 2). Presumably this
# probes how heads behave when a token has more than one earlier occurrence.
my_tokens = list(range(1, 29))
my_tokens[14], my_tokens[23] = 5, 2

seq_len, sequence = generate_rrt(
    num_samples=1,
    length=80,
    vocab_size=256,
    seq_len=len(my_tokens),
    custom_seq=my_tokens,
)
sequence = sequence.to(device)

print(
    f"Sequence shape: {sequence.shape}",
    f"First {seq_len * 2} tokens:",
    sequence[0, : 2 * seq_len].tolist(),
    sep="\n",
)

interactive_attention_viewer(
    list(models.values()),
    sequence,
    seq_len=seq_len,
    model_names=list(models),
    max_tokens=80,
)

Sequence shape: torch.Size([1, 80])
First 56 tokens:
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 5, 16, 17, 18, 19, 20, 21, 22, 23, 2, 25, 26, 27, 28, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 5, 16, 17, 18, 19, 20, 21, 22, 23, 2, 25, 26, 27, 28]
