<a href="https://colab.research.google.com/github/tinuademargaret/Capital-Letters-Circuit/blob/main/Learning_Capital_Letters.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports


In [None]:
# Detect if we're running in Google Colab
import os

try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
except:
    IN_COLAB = False

# Install if in Colab
if IN_COLAB:
    %pip install transformer_lens
    %pip install circuitsvis
    # Install a faster Node version
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs  # noqa
    os.system("pip install git+https://github.com/ArthurConmy/Automatic-Circuit-Discovery.git@d89f7fa9cbd095202f3940c889cb7c6bf5a9b516")

# Hot reload in development mode & not running on the CD
if not IN_COLAB:
    from IPython import get_ipython
    ip = get_ipython()
    if not ip.extension_manager.loaded:
        ip.extension_manager.load('autoreload')
        %autoreload 2

In [None]:
from functools import partial
from typing import List, Optional, Union
from time import ctime
from subprocess import call

import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
from jaxtyping import Float

import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer
# from easy_transformer import EasyTransformer
# from easy_transformer.utils_circuit_discovery import (
#     evaluate_circuit,
#     patch_all,
#     direct_path_patching,
#     logit_diff_io_s,
#     Circuit,
#     logit_diff_from_logits,
#     get_datasets,
# )

In [None]:
!pip install torchtyping

Collecting torchtyping
  Using cached torchtyping-0.1.4-py3-none-any.whl (17 kB)
Installing collected packages: torchtyping
Successfully installed torchtyping-0.1.4


In [None]:
from torchtyping import TensorType as TT

In [None]:
# This notebook only runs forward passes, so turn off autograd globally.
torch.set_grad_enabled(False)
print("Disabled automatic differentiation")

Disabled automatic differentiation


In [None]:
# Path prefix pointing into a local "archive/" directory when one exists.
# NOTE(review): `file_prefix` is not used anywhere else in this notebook.
file_prefix = "archive/" if os.path.exists("archive") else ""

## Helper Functions

In [None]:
def imshow(tensor, renderer=None, **kwargs):
    """Show `tensor` as a heatmap with a diverging RdBu colour scale centred at 0."""
    fig = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    fig.show(renderer)

def line(tensor, renderer=None, **kwargs):
    """Plot the values of `tensor` as a line chart against their index."""
    fig = px.line(y=utils.to_numpy(tensor), **kwargs)
    fig.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Scatter-plot `y` against `x`, with optional x / y / colour axis titles."""
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    fig = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    fig.show(renderer)

In [None]:
# Smoke test for the `line` helper: should draw the values 0..4.
line(np.arange(5))

## Load the Model

In [None]:
# Load GPT-2 small with TransformerLens' weight-processing options enabled:
# LayerNorm folded into the following weights, writing weights and the
# unembedding centred, and attention QK/OV matrices refactored.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=True,
)

# Default device chosen by TransformerLens; later cells move tensors here.
device: torch.device = utils.get_device()

Loaded pretrained model gpt2-small into HookedTransformer


## Task

In [None]:
# Quick probe: after a lone "." what does the model predict, and where does " He" rank?
example_prompt = "."
example_answer = "He"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

# Candidate prompt families for the "capital letter after a full stop" task.
# Kept as comments rather than a bare string literal, which Jupyter would echo
# as this cell's output.
#
# Pronouns
#   {name} is a really great friend.
#   {name} is such a good cook.
#   {name} is a very good athlete.
#   {name} is a really nice person.
#   {name} is such a funny person.
#
# Proper nouns
#   The capital of Canada is Ottawa while the capital of France is
#   The white house is in Washington, while the Louvre is in
#
# Other
#   P.S.
#   n.b.
#
# Exceptions (a "." that is not followed by a capital letter)
#   esther@gmail.com
#   google.com
#   .py, .txt
#   list.append()

Tokenized prompt: ['<|endoftext|>', '.']
Tokenized answer: [' He']


Top 0th token. Logit: 13.91 Prob: 47.72% Token: |
|
Top 1th token. Logit: 12.44 Prob: 11.03% Token: | .|
Top 2th token. Logit: 11.65 Prob:  4.99% Token: |@|
Top 3th token. Logit: 11.05 Prob:  2.75% Token: |<|endoftext|>|
Top 4th token. Logit: 10.08 Prob:  1.03% Token: | The|
Top 5th token. Logit: 10.06 Prob:  1.02% Token: |com|
Top 6th token. Logit:  9.66 Prob:  0.68% Token: |

|
Top 7th token. Logit:  9.27 Prob:  0.46% Token: |NET|
Top 8th token. Logit:  9.24 Prob:  0.45% Token: | (|
Top 9th token. Logit:  9.24 Prob:  0.45% Token: | This|


'\nPronouns\n{name} is a really great friend.,\n{name} is such a good cook.",\n{name} is a very good athlete.",\n{name} is a really nice person.",\n{name} is such a funny person."\n\nProper Nouns\nThe capital of Canada is Ottawa while the capital of France is\nThe white house is in Washington, while the Louvre is in\n\nOther\nP.S.\nn.b.\n\nExceptions\nesther@gmail.com\ngoogle.com\n.py, .txt\nlist.append()\n'

## Create Dataset

In [None]:
# Check tokenization of a name: it should appear as a single token after BOS.
model.to_str_tokens(["Sarah"])

[['<|endoftext|>', 'Sarah']]

In [None]:
templates = [
    "{name} is a really great friend.",
    "{name} is such a good cook.",
    "{name} is a very good athlete.",
    "{name} is a really nice person.",
    "{name} is such a funny person.",
]

male_names = [
    "John", "David", "Mark", "Paul", "Ryan",
    "Gary", "Jack", "Sean", "Carl", "Joe",
]
female_names = [
    "Mary", "Lisa", "Anna", "Sarah", "Amy",
    "Jane", "Joy", "Susan", "Victoria", "Laura",
]

# Each pair is (correct, wrong): after a full stop the next word should be
# capitalised, so the capitalised pronoun is the correct completion.
male_responses = [' He', ' he']
female_responses = [' She', ' she']

sentences = []
answers = []
answer_tokens = []
corrects = []
wrongs = []

# All male prompts first, then all female prompts (same order as before).
for names, (correct, wrong) in [(male_names, male_responses), (female_names, female_responses)]:
    # `to_single_token` asserts each response is exactly one token.
    answer_pair = [model.to_single_token(correct), model.to_single_token(wrong)]
    for name in names:
        for template in templates:
            sentences.append(template.format(name=name))
            answers.append((correct, wrong))
            corrects.append(correct)
            wrongs.append(wrong)
            answer_tokens.append(answer_pair)

# [batch, 2] token ids (correct, wrong). Put them on the model's device instead of
# hard-coding .cuda(), which fails on CPU/MPS machines.
answer_tokens = torch.tensor(answer_tokens).to(device)
batch_size = len(sentences)

In [None]:
tokens = model.to_tokens(sentences, prepend_bos=True)
# Keep tokens on the model's device (a hard-coded .cuda() breaks CPU/MPS runs)
tokens = tokens.to(device)
# Run the model and cache all activations
original_logits, cache = model.run_with_cache(tokens)

In [None]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    """Logit difference between each prompt's two answer tokens at the last position.

    Args:
        logits: model output indexed as [batch, pos, vocab].
        answer_tokens: [batch, 2] token ids; column 0 is the correct answer,
            column 1 the wrong one.
        per_prompt: if True, return the per-prompt differences ([batch]);
            otherwise return their mean (a 0-d tensor).
    """
    # Only the prediction after the final token matters for the answer
    last_position_logits = logits[:, -1, :]
    picked = torch.gather(last_position_logits, -1, answer_tokens)
    differences = picked[:, 0] - picked[:, 1]
    return differences if per_prompt else differences.mean()

print("Per prompt logit difference:", logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True))
# Clean baseline; later cells normalise patching results against it.
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print("Average logit difference:", original_average_logit_diff.item())

Per prompt logit difference: tensor([7.2793, 7.5363, 7.4461, 6.7807, 7.1888, 7.2196, 7.6630, 7.3835, 6.8099,
        7.2605, 7.1385, 7.3482, 7.2831, 6.6183, 6.9629, 7.7064, 7.7734, 7.7242,
        7.0695, 7.4893, 7.6223, 7.8174, 7.5551, 7.0627, 7.4077, 7.3439, 7.6237,
        7.5325, 6.8741, 7.1212, 7.3446, 7.4451, 7.3785, 6.9057, 7.0899, 7.3435,
        7.4247, 7.3849, 6.7439, 6.8749, 7.1745, 7.3281, 7.2700, 6.6596, 6.8771,
        7.3810, 7.4081, 7.4207, 6.8444, 7.1000, 7.3412, 7.6077, 7.4789, 7.1278,
        7.2714, 7.6040, 7.5639, 7.7057, 7.1620, 7.2124, 7.4694, 7.3601, 7.3582,
        6.9931, 6.8663, 7.5334, 7.5324, 7.5423, 7.1064, 7.2983, 7.4181, 7.6120,
        7.5499, 7.1311, 7.1974, 7.2970, 7.6310, 7.5044, 6.9606, 7.2871, 6.8809,
        7.0939, 7.3840, 6.6142, 6.7813, 7.2404, 7.5975, 7.5058, 6.8825, 7.1646,
        7.0581, 7.2620, 7.3115, 6.7243, 6.7993, 7.5016, 7.5519, 7.3426, 6.9844,
        7.2765], device='cuda:0')
Average logit difference: 7.267223834991455


## Hypotheses: what is really going on?

1. Copying? But how would the model copy the *concept* of a capital letter?
2. Reframe the question: why does the attention head attend to the capitalised pronoun rather than the lowercase one after a full stop?

In [None]:
model = HookedTransformer.from_pretrained(
    "attn-only-1l",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

# Get the default device used
device: torch.device = utils.get_device()

In [None]:
example_prompt = "Daniel is a great friend."
example_answer = "he"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

A one-layer attention-only model can already perform this task (average logit difference of 4.36 over 15 sentences).

## Direct Logit attribution

In [None]:
# Unembedding directions of each prompt's (correct, wrong) answer tokens: [batch, 2, d_model]
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)
# Direction whose dot product with the final residual stream gives logit(correct) - logit(wrong)
logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([100, 2, 768])
Logit difference directions shape: torch.Size([100, 768])


In [None]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
# Residual stream at the last token position only: [batch, d_model]
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(final_token_residual_stream, layer = -1, pos_slice=-1)

# Sanity check: dot the scaled final residual with each prompt's logit-diff
# direction and average over prompts; this is compared with
# `original_average_logit_diff` below and should match it.
# NOTE(review): the recorded output shows 8.22 vs 7.27, so they disagree;
# investigate before trusting the attributions below.
average_logit_diff = einsum("batch d_model, batch d_model -> ", scaled_final_token_residual_stream, logit_diff_directions)/len(sentences)
print("Calculated average logit diff:", average_logit_diff.item())
print("Original logit difference:",original_average_logit_diff.item())

Final residual stream shape: torch.Size([100, 8, 768])
Calculated average logit diff: 8.223687171936035
Original logit difference: 7.267223834991455


## Logit Lens

In [None]:
def residual_stack_to_logit_diff(residual_stack: TT["components", "batch", "d_model"], cache: ActivationCache) -> torch.Tensor:
    """Project each component of a residual stack onto the logit-diff direction.

    Applies the final LayerNorm scaling cached at the last position, dots each
    component with the module-level per-prompt `logit_diff_directions`, and
    averages over the batch.

    Returns a tensor with the leading (component) dims of `residual_stack`
    (the original annotation said `float`, but a tensor is returned).
    """
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
    # Average over the stack's own batch dimension rather than the global
    # `sentences` list, so the result stays correct for any batch.
    n_prompts = scaled_residual_stack.shape[-2]
    return einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, logit_diff_directions) / n_prompts

In [None]:
# Logit lens: logit-difference contribution of the accumulated residual stream
# before each layer and at each layer's midpoint (`incl_mid=True`), final position only.
accumulated_residual, labels = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)
# x-axis in layer units: 0, 0.5, 1, ..., n_layers (two points per layer plus the final one)
line(logit_lens_logit_diffs, x=np.arange(model.cfg.n_layers*2+1)/2, hover_name=labels, title="Logit Difference From Accumulate Residual Stream")

In [None]:
# Split the final-position residual stream into its per-layer component
# outputs and measure each component's direct logit-difference contribution.
per_layer_residual, labels = cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)
line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

In [None]:
# Direct logit attribution for every attention head at the final position.
per_head_residual, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
# Flat (layer * n_heads + head) vector -> [layer, head] grid for the heatmap
per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=model.cfg.n_layers, head_index=model.cfg.n_heads)
imshow(per_head_logit_diffs, labels={"x":"Head", "y":"Layer"}, title="Logit Difference From Each Head")

Tried to stack head results when they weren't cached. Computing head results now


## Attention Analysis

In [None]:
def visualize_attention_patterns(
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]],
    local_cache: ActivationCache,
    local_tokens: torch.Tensor,
    title: Optional[str] = "",
    max_width: Optional[int] = 700,
) -> str:
    """Render the attention patterns of the given heads as an HTML snippet.

    Heads are flat indices ``layer * n_heads + head_index`` (as produced by
    flattening a [layer, head] tensor). Only batch item 0 of ``local_cache`` is
    shown, with ``local_tokens`` converted to strings for the axis labels. Uses
    the global ``model`` for ``n_heads`` and tokenization.

    Returns raw HTML (title plus circuitsvis plot) to pass to ``HTML``. If
    ``heads`` is empty, returns the title and a short note instead of raising
    (``torch.stack`` cannot stack an empty list).
    """
    # If a single head is given, convert to a list
    if isinstance(heads, int):
        heads = [heads]

    title_html = f"<h2>{title}</h2><br/>"

    # Plain ints behave the same whether `heads` was a list or a tensor.
    head_ids = [int(head) for head in heads]
    if not head_ids:
        return f"<div style='max-width: {str(max_width)}px;'>{title_html}<p>No heads selected.</p></div>"

    labels: List[str] = []
    patterns: List[Float[torch.Tensor, "dest_pos src_pos"]] = []

    # Assume we have a single batch item
    batch_index = 0

    for head in head_ids:
        layer, head_index = divmod(head, model.cfg.n_heads)
        labels.append(f"L{layer}H{head_index}")
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        patterns.append(local_cache["attn", layer][batch_index, head_index])

    # Convert the tokens to strings (for the axis labels)
    str_tokens = model.to_str_tokens(local_tokens)

    stacked_patterns: Float[torch.Tensor, "head_index dest_pos src_pos"] = torch.stack(
        patterns, dim=0
    )

    # Circuitsvis plot (get the code version so it can be concatenated with the title)
    plot = attention_heads(
        attention=stacked_patterns, tokens=str_tokens, attention_head_names=labels
    ).show_code()

    return f"<div style='max-width: {str(max_width)}px;'>{title_html + plot}</div>"

In [None]:
# Attention patterns (first prompt) of the heads with the largest positive and
# largest negative direct logit attribution.
top_k = 3

top_positive_logit_attr_heads = torch.topk(
    per_head_logit_diffs.flatten(), k=top_k
).indices

positive_html = visualize_attention_patterns(
    top_positive_logit_attr_heads,
    cache,
    tokens[0],
    f"Top {top_k} Positive Logit Attribution Heads",
)

# topk of the negated values = the most negative attributions
top_negative_logit_attr_heads = torch.topk(
    -per_head_logit_diffs.flatten(), k=top_k
).indices

negative_html = visualize_attention_patterns(
    top_negative_logit_attr_heads,
    cache,
    tokens[0],
    title=f"Top {top_k} Negative Logit Attribution Heads",
)

HTML(positive_html + negative_html)

In [None]:
# Show attention patterns of the heads with the largest |activation-patching
# effect|, grouped into early / middle / late layers.
#
# `patched_head_z_diff` is only computed in the "Activation Patching -> Head"
# section further down. On a fresh kernel (Restart & Run All) it does not exist
# yet, so skip with a message instead of stopping the run with a NameError.
if "patched_head_z_diff" not in globals():
    print(
        "Skipping: `patched_head_z_diff` is not defined yet. Run the head "
        "activation-patching cell first, then re-run this cell."
    )
else:
    from IPython.display import display

    top_k = 10
    top_heads_by_output_patch = torch.topk(
        patched_head_z_diff.abs().flatten(), k=top_k
    ).indices
    # Layer groups: [0, 7) early, [7, 9) middle, [9, n_layers) late
    first_mid_layer = 7
    first_late_layer = 9
    early_heads = top_heads_by_output_patch[
        top_heads_by_output_patch < model.cfg.n_heads * first_mid_layer
    ]
    mid_heads = top_heads_by_output_patch[
        torch.logical_and(
            model.cfg.n_heads * first_mid_layer <= top_heads_by_output_patch,
            top_heads_by_output_patch < model.cfg.n_heads * first_late_layer,
        )
    ]
    late_heads = top_heads_by_output_patch[
        model.cfg.n_heads * first_late_layer <= top_heads_by_output_patch
    ]

    early = visualize_attention_patterns(
        early_heads, cache, tokens[0], title="Top Early Heads"
    )
    mid = visualize_attention_patterns(
        mid_heads, cache, tokens[0], title="Top Middle Heads"
    )
    late = visualize_attention_patterns(
        late_heads, cache, tokens[0], title="Top Late Heads"
    )

    # Inside an `if`, a bare expression isn't displayed, so display explicitly.
    display(HTML(early + mid + late))

## Activation Patching

### Make Position Labels

In [None]:
# List each token position of the first prompt (used to choose position labels below).
for i, token in enumerate(model.to_str_tokens(tokens[0])):
    print(i, token)

0 <|endoftext|>
1 John
2  is
3  a
4  really
5  great
6  friend
7 .


From the gendered-pronouns circuit work, the important tokens are the name, " is", and the final noun ("person"). In addition, for our task the "." token is also important. Hence the important positions are 1, 2, 6 and 7.

In [None]:
from collections import OrderedDict

# Index of each "important" position, repeated once per prompt ([batch] tensors).
# Indices follow the token listing above (0 is BOS): 1 = name, 2 = " is",
# 6 = the final noun (" friend", " cook", ...; labelled "person"), 7 = ".".
# Assumes every prompt tokenizes to the same 8-token layout.
# NOTE(review): `positions` is not used by any later cell in this notebook.
positions = OrderedDict()

ones = torch.ones(size = (batch_size,)).long()

positions["name"] = ones.clone() * 1
positions["is"] = ones.clone() * 2
positions["person"] = ones.clone() * 6
positions["."] = ones.clone() * 7

### Make corrupted dataset

In [None]:
# Corrupted prompts: replace the full stop with a semicolon. The recorded output
# shows this flips the preference to the lowercase pronoun.
corrupted_prompts = [sentence.replace('.', ';') for sentence in sentences]
corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)
# Activation patching below copies activations position by position, so the clean
# and corrupted prompts must line up token for token.
assert corrupted_tokens.shape == tokens.shape, (
    f"Corrupted tokens {tuple(corrupted_tokens.shape)} do not align with "
    f"clean tokens {tuple(tokens.shape)}"
)
corrupted_logits, corrupted_cache = model.run_with_cache(
    corrupted_tokens, return_type="logits"
)
corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)
print("Corrupted Average Logit Diff", round(corrupted_average_logit_diff.item(), 2))
print("Clean Average Logit Diff", round(original_average_logit_diff.item(), 2))

Corrupted Average Logit Diff -5.04
Clean Average Logit Diff 7.27


### Metric

In [None]:
# Token ids of each prompt's correct / wrong answer, aligned with `sentences`.
# `to_single_token` matches how `answer_tokens` was built and errors on
# multi-token strings. The isinstance guards make the cell safe to re-run: the
# original rebinding tried to tokenize the already-converted tensor a second time.
if not isinstance(corrects, torch.Tensor):
    corrects = torch.tensor([model.to_single_token(answer) for answer in corrects])
if not isinstance(wrongs, torch.Tensor):
    wrongs = torch.tensor([model.to_single_token(answer) for answer in wrongs])

In [None]:
def eval_metric(model, tokens = tokens):
    """Mean logit difference (correct - wrong) at the final position of `tokens`.

    Reads the module-level `corrects` / `wrongs` token-id tensors (one entry
    per prompt). Note: the default for `tokens` is bound once, when this cell
    runs, to whatever the global `tokens` was at that moment.

    Returns a Python float.
    """
    logits = model(tokens)
    # One row per prompt in *this* batch, not the global `batch_size`.
    rows = torch.arange(tokens.shape[0])
    logits_on_correct = logits[rows, -1, corrects]
    logits_on_wrong = logits[rows, -1, wrongs]
    result = torch.mean(logits_on_correct - logits_on_wrong)
    return result.item()

In [None]:
# Clean-run score under `eval_metric` (stored, not displayed).
model_performance = eval_metric(model, tokens)

### Residual stream

In [None]:
def patch_residual_component(
    corrupted_residual_component: Float[torch.Tensor, "batch pos d_model"],
    hook,
    pos,
    clean_cache,
):
    """Forward hook: overwrite one position's activation with the clean run's.

    Copies ``clean_cache[hook.name]`` at position ``pos`` (every batch item and
    feature) into the corrupted activation in place, then returns it.
    """
    clean_activation = clean_cache[hook.name]
    corrupted_residual_component[:, pos, :] = clean_activation[:, pos, :]
    return corrupted_residual_component


def normalize_patched_logit_diff(patched_logit_diff, corrupted_logit_diff=None, clean_logit_diff=None):
    """Rescale a patched logit diff so the corrupted baseline maps to 0 and clean to 1.

    0 means no change from the corrupted run, negative means patching made it
    worse, 1 means clean performance was fully recovered, and >1 means patching
    actively *improved* on clean performance.

    Args:
        patched_logit_diff: logit difference measured with the patch applied.
        corrupted_logit_diff: baseline for 0; defaults to the notebook's
            `corrupted_average_logit_diff`.
        clean_logit_diff: baseline for 1; defaults to the notebook's
            `original_average_logit_diff`.
    """
    if corrupted_logit_diff is None:
        corrupted_logit_diff = corrupted_average_logit_diff
    if clean_logit_diff is None:
        clean_logit_diff = original_average_logit_diff
    # Improvement over the corrupted run, divided by the total gap from corrupted to clean
    return (patched_logit_diff - corrupted_logit_diff) / (
        clean_logit_diff - corrupted_logit_diff
    )


# Residual-stream activation patching: for every (layer, position), run the
# corrupted prompts while copying in the clean `resid_pre` activation at that
# single position, then record the normalised logit difference.
patched_residual_stream_diff = torch.zeros(
    model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("resid_pre", layer), hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(
            patched_logit_diff
        )

In [None]:
# Axis labels of the form "<token>_<index>" for the first prompt's positions.
prompt_position_labels = [
    f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(tokens[0]))
]
imshow(
    patched_residual_stream_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched Residual Stream",
    labels={"x": "Position", "y": "Layer"},
)

No information moves between positions in the residual stream: all the relevant computation happens at the final token.

### Layers

In [None]:
# Same sweep as the residual stream, but patching the attention-layer output and
# the MLP-layer output separately at each (layer, position).
patched_attn_diff = torch.zeros(
    model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32
)
patched_mlp_diff = torch.zeros(
    model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    for position in range(tokens.shape[1]):
        # The same position-patching hook works for any [batch, pos, d_model] activation
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_attn_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("attn_out", layer), hook_fn)],
            return_type="logits",
        )
        patched_attn_logit_diff = logits_to_ave_logit_diff(
            patched_attn_logits, answer_tokens
        )
        patched_mlp_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("mlp_out", layer), hook_fn)],
            return_type="logits",
        )
        patched_mlp_logit_diff = logits_to_ave_logit_diff(
            patched_mlp_logits, answer_tokens
        )

        patched_attn_diff[layer, position] = normalize_patched_logit_diff(
            patched_attn_logit_diff
        )
        patched_mlp_diff[layer, position] = normalize_patched_logit_diff(
            patched_mlp_logit_diff
        )

In [None]:
# Heatmap of normalised logit difference after patching each attention layer's output
imshow(
    patched_attn_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched Attention Layer",
    labels={"x": "Position", "y": "Layer"},
)

In [None]:
# Heatmap of normalised logit difference after patching each MLP layer's output
imshow(
    patched_mlp_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched MLP Layer",
    labels={"x": "Position", "y": "Layer"},
)

### Head

In [None]:
def patch_head_vector(
    corrupted_head_vector: Float[torch.Tensor, "batch pos head_index d_head"],
    hook,
    head_index,
    clean_cache,
):
    """Forward hook: replace one head's activations with the clean run's.

    Copies ``clean_cache[hook.name]`` for head ``head_index`` (every batch item
    and position) into the corrupted activation in place, then returns it.
    """
    clean_head_vector = clean_cache[hook.name]
    corrupted_head_vector[:, :, head_index, :] = clean_head_vector[:, :, head_index, :]
    return corrupted_head_vector


# Per-head activation patching: patch one head's `z` activation
# ([batch, pos, head_index, d_head]) across all positions at a time, and record
# the normalised logit difference in a [layer, head] grid.
patched_head_z_diff = torch.zeros(
    model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("z", layer, "attn"), hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_z_diff[layer, head_index] = normalize_patched_logit_diff(
            patched_logit_diff
        )

In [None]:
# Heatmap of normalised logit difference after patching each head's output
imshow(
    patched_head_z_diff,
    title="Logit Difference From Patched Head Output",
    labels={"x": "Head", "y": "Layer"},
)

## Tests

In [None]:
# Scratch check of the logit-difference computation for one (" He", " he") pair.
# NOTE(review): this overwrites the global `answer_tokens` built in "Create
# Dataset" ([batch, 2], moved to the GPU there) with a [1, 2] CPU tensor.
# Re-run that cell before re-running any of the analysis cells above.
responses = [' He', ' he']
answer_tokens = []
answer_tokens.append([model.to_single_token(responses[0]), model.to_single_token(responses[1])])
answer_tokens = torch.tensor(answer_tokens)

In [None]:
# Expect the token ids of " He" and " he"
print(answer_tokens)

tensor([[679, 339]])


In [None]:
# Final-position logits for every prompt in `original_logits`: [batch, vocab].
# NOTE(review): the recorded output has a single row, so this was run after the
# last cell of the notebook reassigned `original_logits` (out-of-order execution).
final_logits = original_logits[:, -1, :]

In [None]:
# Inspect the final-position logits
print(final_logits)

tensor([[ 7.2207,  7.0948,  6.0607,  ..., -5.1576, -6.6794, 14.2832]])


In [None]:
# Pick out the (" He", " he") logits. `gather` returns a tensor shaped like the
# index, so a [1, 2] index yields only the first row's two logits.
# NOTE(review): `answer_tokens` here is a CPU tensor; gather fails if
# `final_logits` lives on the GPU.
final_logits = final_logits.gather(dim=-1, index=answer_tokens)

In [None]:
# logit(" He") - logit(" he") for the selected row
print(final_logits[:, 0] - final_logits[:, 1])

tensor([7.2195])


In [None]:
# Token id of the lowercase " he" (339 in the recorded output).
# Rebinds `male_responses` to the same values used in "Create Dataset".
male_responses = [' He', ' he']
print(model.to_single_token(male_responses[1]))

339


In [None]:
# Re-run the model on a single prompt.
# NOTE(review): this overwrites the globals `tokens`, `original_logits` and
# `cache` that every analysis cell above depends on. Re-running those cells
# afterwards would analyse only this prompt, and `tokens` would no longer line up
# with `corrupted_tokens` / `answer_tokens`.
prompts = ["David is a really great friend."]
tokens = model.to_tokens(prompts, prepend_bos=True)
# Move the tokens to the GPU
# tokens = tokens.cuda()
# Run the model and cache all activations
original_logits, cache = model.run_with_cache(tokens)