In [1]:
from transformer_lens import HookedTransformer, ActivationCache
import torch
import matplotlib.pyplot as plt
import os
import plotly.express as px
import transformer_lens.utils as utils
from einops import einsum
from typing import List, Optional, Union
from jaxtyping import Float
from circuitsvis.attention import attention_heads


torch.set_grad_enabled(False)

model = HookedTransformer.from_pretrained("meta-llama/Meta-Llama-3-8B", device="cuda")

def imshow(tensor, **kwargs):
    px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    ).show()



Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loaded pretrained model meta-llama/Meta-Llama-3-8B into HookedTransformer


In [2]:
A = """
(28, 59)
(86, 175)
(13, 29)
(55, 113)
(84, 171)
(66, 135)
(85, 173)
(27, 57)
(15, 33)
(94, 191)
(37, 77)
(14, 31)
(42, """

correct_answer = "87"
correct_answer_token = model.to_tokens(correct_answer, prepend_bos=False).squeeze(1)
# this gives us a particular index. i don't know if i have to convert it into a distribution.
# still ez though

tokens = model.to_tokens(A, prepend_bos=True)
text = model.tokenizer.convert_ids_to_tokens(tokens.squeeze(0))
print(text)

['<|begin_of_text|>', 'Ċ', '(', '28', ',', 'Ġ', '59', ')Ċ', '(', '86', ',', 'Ġ', '175', ')Ċ', '(', '13', ',', 'Ġ', '29', ')Ċ', '(', '37', ',', 'Ġ', '77', ')Ċ', '(', '55', ',', 'Ġ', '113', ')Ċ', '(', '84', ',', 'Ġ', '171', ')Ċ', '(', '66', ',', 'Ġ', '135', ')Ċ', '(', '85', ',', 'Ġ', '173', ')Ċ', '(', '27', ',', 'Ġ', '57', ')Ċ', '(', '15', ',', 'Ġ', '33', ')Ċ', '(', '94', ',', 'Ġ', '191', ')Ċ', '(', '37', ',', 'Ġ', '77', ')Ċ', '(', '14', ',', 'Ġ', '31', ')Ċ', '(', '42', ',', 'Ġ']


In [3]:
logits, cache = model.run_with_cache(A)

In [4]:
most_recent = cache["resid_pre", 15][:, -1, :]
print(most_recent.shape) # (1, 4096)

heads, labels = cache.stack_head_results(pos_slice=-1, apply_ln=True, return_labels=True)
print(heads.shape) # (1024, 1, 4096)

# take the dot product between the most recent and the heads
# this will give us a score for each head
# Reshape most_recent to (4096, 1) for matrix multiplication
most_recent_reshaped = most_recent.transpose(0, 1)

# Reshape heads to (1024, 4096) for matrix multiplication
heads_reshaped = heads.squeeze(1)

head_scores = torch.matmul(heads_reshaped, most_recent_reshaped)

print(head_scores.shape)
print(labels)

# labels are off the form "LNHM", where N and M are layer and the head number
# plot on a 32 x 32 grid

head_scores = head_scores.squeeze(1)[:512]
copy_of_head_scores = head_scores.reshape(16, 32)

imshow(copy_of_head_scores, title="Head scores", labels={"x": "Head", "y": "Layer"})



torch.Size([1, 4096])
Tried to stack head results when they weren't cached. Computing head results now
torch.Size([1024, 1, 4096])
torch.Size([1024, 1])
['L0H0', 'L0H1', 'L0H2', 'L0H3', 'L0H4', 'L0H5', 'L0H6', 'L0H7', 'L0H8', 'L0H9', 'L0H10', 'L0H11', 'L0H12', 'L0H13', 'L0H14', 'L0H15', 'L0H16', 'L0H17', 'L0H18', 'L0H19', 'L0H20', 'L0H21', 'L0H22', 'L0H23', 'L0H24', 'L0H25', 'L0H26', 'L0H27', 'L0H28', 'L0H29', 'L0H30', 'L0H31', 'L1H0', 'L1H1', 'L1H2', 'L1H3', 'L1H4', 'L1H5', 'L1H6', 'L1H7', 'L1H8', 'L1H9', 'L1H10', 'L1H11', 'L1H12', 'L1H13', 'L1H14', 'L1H15', 'L1H16', 'L1H17', 'L1H18', 'L1H19', 'L1H20', 'L1H21', 'L1H22', 'L1H23', 'L1H24', 'L1H25', 'L1H26', 'L1H27', 'L1H28', 'L1H29', 'L1H30', 'L1H31', 'L2H0', 'L2H1', 'L2H2', 'L2H3', 'L2H4', 'L2H5', 'L2H6', 'L2H7', 'L2H8', 'L2H9', 'L2H10', 'L2H11', 'L2H12', 'L2H13', 'L2H14', 'L2H15', 'L2H16', 'L2H17', 'L2H18', 'L2H19', 'L2H20', 'L2H21', 'L2H22', 'L2H23', 'L2H24', 'L2H25', 'L2H26', 'L2H27', 'L2H28', 'L2H29', 'L2H30', 'L2H31', 'L3H0', 'L3H

In [5]:
# layer 10 head 5
# layer 10 head 7
# layer 13 head 6
# there are other good heads here but we first see some nice behavior with patching in layer 10
# then more nice behavior on layer 13 apparently

def visualize_attention_patterns(
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]],
    local_cache: ActivationCache,
    local_tokens: torch.Tensor,
    title: Optional[str] = "",
    max_width: Optional[int] = 700,
) -> str:
    # If a single head is given, convert to a list
    if isinstance(heads, int):
        heads = [heads]

    # Create the plotting data
    labels: List[str] = []
    patterns: List[Float[torch.Tensor, "dest_pos src_pos"]] = []

    # Assume we have a single batch item
    batch_index = 0

    for head in heads:
        # Set the label
        layer = head // model.cfg.n_heads
        head_index = head % model.cfg.n_heads
        labels.append(f"L{layer}H{head_index}")

        # Get the attention patterns for the head
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        patterns.append(local_cache["attn", layer][batch_index, head_index])

    # Convert the tokens to strings (for the axis labels)
    str_tokens = model.to_str_tokens(local_tokens)

    # Combine the patterns into a single tensor
    patterns: Float[torch.Tensor, "head_index dest_pos src_pos"] = torch.stack(
        patterns, dim=0
    )

    # Circuitsvis Plot (note we get the code version so we can concatenate with the title)
    plot = attention_heads(
        attention=patterns, tokens=str_tokens, attention_head_names=labels
    ).show_code()

    # Display the title
    title_html = f"<h2>{title}</h2><br/>"

    # Return the visualisation as raw code
    return f"<div style='max-width: {str(max_width)}px;'>{title_html + plot}</div>"

In [6]:
from IPython.display import HTML, IFrame

top_k = 5
copy_of_head_scores = copy_of_head_scores.flatten()
top_k_indices = torch.topk(copy_of_head_scores, top_k).indices

positive_html = visualize_attention_patterns(
    top_k_indices,
    cache,
    tokens[0],
    f"Top {top_k} Positive Logit Attribution Heads",
)

HTML(positive_html)

In [7]:
# now we need to display all activations at L9 (this is going to suck isn't it)
# we need to first create a range between [32 * 9, 32 * 10]

l9_head_indexes = list(range(32 * 9, 32 * 10))
l9_html = visualize_attention_patterns(
    l9_head_indexes,
    cache,
    tokens[0],
    f"Layer 9 Attention Patterns",
)

HTML(l9_html)


In [None]:
# L9H21 is promising as the first part of the circuit
# L9H23
# L9H25
# L9H27
# L9H30
# L9H11