In [20]:
import warnings
warnings.filterwarnings('ignore')

from typing import Callable

import circuitsvis as cv
import torch
from huggingface_hub import hf_hub_download
from transformer_lens import HookedTransformerConfig, HookedTransformer, ActivationCache

In [9]:
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'Using device: {device}')

Using device: mps


# 1. Toy attention-only model

A two-layer transformer simplified to make it easier to interpret:
- it has only attention blocks
- there are separate embed and unembed matrices, so the weights are not tied.
- positional embeddings are added to the residual stream only when computing the keys and queries in each attention layer, rather than being added to the token embeddings. Queries are computed as `Q = (resid + pos_embed) @ W_Q + b_Q` (and likewise for keys), while values are computed as `V = resid @ W_V + b_V`. This means that the **residual stream can't directly encode positional information**.
  - this change makes it easier for induction heads to form. In the plot below, the bump in the green curve marks the formation of induction heads.
  - the argument that implements this is `positional_embedding_type="shortformer"`.
  <div align="center">
    <img src="../assets/induction-heads.jpg" width="600"/>
  </div>

  This [diagram](https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/small-merm.svg) contains all relevant hook names.
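The query, key and value formulas above can be sketched in plain PyTorch. This is a minimal illustration rather than TransformerLens's actual implementation: the tensor sizes, random weights and the helper name `shortformer_qkv` are all assumptions made for the example.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy sizes, far smaller than the model's d_model=768, d_head=64.
seq_len, d_model, d_head = 5, 8, 4

resid = torch.randn(seq_len, d_model)      # residual stream (carries no positional info)
pos_embed = torch.randn(seq_len, d_model)  # one positional vector per position

W_Q, W_K, W_V = (torch.randn(d_model, d_head) for _ in range(3))
b_Q, b_K, b_V = (torch.zeros(d_head) for _ in range(3))


def shortformer_qkv(resid, pos_embed):
    """Queries and keys see positions; values do not."""
    q = (resid + pos_embed) @ W_Q + b_Q
    k = (resid + pos_embed) @ W_K + b_K
    v = resid @ W_V + b_V
    return q, k, v


q, k, v = shortformer_qkv(resid, pos_embed)

# Swapping in different positional embeddings changes q and k but not v.
q2, k2, v2 = shortformer_qkv(resid, torch.randn(seq_len, d_model))
print(torch.equal(v, v2))     # True
print(torch.allclose(q, q2))  # False
```

Because only `v` is written back to the residual stream (weighted by the attention pattern), positional information can influence where a head attends but is never copied into the stream directly.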

In [None]:
cfg = HookedTransformerConfig(
    d_model=768,
    d_head=64,
    n_heads=12,
    n_layers=2,
    n_ctx=2048,
    d_vocab=50278,
    attention_dir="causal",
    attn_only=True,
    tokenizer_name="EleutherAI/gpt-neox-20b",
    use_attn_result=True,
    normalization_type=None,  # no layernorm; the default is "LN", i.e. layernorm with weights & biases
    positional_embedding_type="shortformer",
)

In [4]:
REPO_ID = "callummcdougall/attn_only_2L_half"
FILENAME = "attn_only_2L_half.pth"

weights_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

In [10]:
model = HookedTransformer(cfg)
pretrained_weights = torch.load(weights_path, map_location=device, weights_only=True)
model.load_state_dict(pretrained_weights)

<All keys matched successfully>

### 1.1. Basic attention patterns

Using notation `layer.head`, there are three basic patterns which repeat quite frequently:

- `prev_token_heads`, which attend mainly to the previous token (e.g. head `0.7`)
- `current_token_heads`, which attend mainly to the current token (e.g. head `1.6`)
- `first_token_heads`, which attend mainly to the first token (e.g. heads `0.3` or `1.4`)

The `prev_token_heads` and `current_token_heads` are perhaps unsurprising: tokens that are close together in a sequence probably share much more mutual information, which is what bigram or trigram prediction relies on.

The `first_token_heads` are a bit more surprising. The basic intuition here is that the first token in a sequence is often used as a resting or null position for heads that only sometimes activate.

In [12]:
text = "We think that powerful, significantly superhuman machine intelligence is more likely than not to be created this century. If current machine learning techniques were scaled up to this level, we think they would by default produce systems that are deceptive or manipulative, and that no solid plans are known for how to avoid this."

logits, cache = model.run_with_cache(text, remove_batch_dim=True)

In [15]:
str_tokens = model.to_str_tokens(text)
for layer in range(model.cfg.n_layers):
    attention_pattern = cache["pattern", layer]
    display(cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern))

### 1.2. Attention detectors

In [None]:
def generic_attn_detector(cache: ActivationCache, score_fn: Callable[[torch.Tensor], torch.Tensor]) -> list[str]:
    # return "layer.head" labels for heads whose score exceeds 0.4 (relies on the global `model`)
    attn_heads = []
    for layer in range(model.cfg.n_layers):
        for head in range(model.cfg.n_heads):
            # attention pattern of a single head, shape [dest_pos, src_pos]
            attention_pattern = cache["pattern", layer][head]
            score = score_fn(attention_pattern)
            if score > 0.4:
                attn_heads.append(f"{layer}.{head}")
    return attn_heads

In [None]:
def current_attn_detector(cache: ActivationCache) -> list[str]:
    # take avg of diagonal elements
    return generic_attn_detector(cache, lambda pattern: pattern.diagonal().mean())

current_attn_detector(cache=cache)

['0.9']

In [None]:
def prev_attn_detector(cache: ActivationCache) -> list[str]:
    # take avg of sub-diagonal elements
    return generic_attn_detector(cache, lambda pattern: pattern.diagonal(-1).mean())

prev_attn_detector(cache=cache)

['0.7']

In [None]:
def first_attn_detector(cache: ActivationCache) -> list[str]:
    # take avg of 0th elements
    return generic_attn_detector(cache, lambda pattern: pattern[:, 0].mean())

first_attn_detector(cache=cache)

['0.3', '1.4', '1.10']
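As a sanity check, the three score functions above can be applied to hand-built attention patterns where the right answer is known in advance. The patterns below are synthetic (no model involved), and the names `score_fns` and `results` are introduced only for this sketch.

```python
import torch

seq_len = 6

# Hand-built causal patterns; each row (destination token) sums to 1.
current_pattern = torch.eye(seq_len)  # every token attends to itself

prev_pattern = torch.zeros(seq_len, seq_len)
prev_pattern[0, 0] = 1.0  # the first token has no predecessor, so it attends to itself
prev_pattern[torch.arange(1, seq_len), torch.arange(seq_len - 1)] = 1.0

first_pattern = torch.zeros(seq_len, seq_len)
first_pattern[:, 0] = 1.0  # every token attends to position 0

# Same scoring rules as the detectors above.
score_fns = {
    "current": lambda p: p.diagonal().mean().item(),
    "prev": lambda p: p.diagonal(-1).mean().item(),
    "first": lambda p: p[:, 0].mean().item(),
}

patterns = {"current": current_pattern, "prev": prev_pattern, "first": first_pattern}
results = {}
for name, pattern in patterns.items():
    results[name] = {fn_name: round(fn(pattern), 3) for fn_name, fn in score_fns.items()}
    print(name, results[name])
```

Each pattern scores 1.0 under its own detector. The off-target scores are small but not zero, because position 0 is simultaneously the first token, its own current token, and the previous token of position 1; on a sequence this short they still stay well below the 0.4 threshold.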

# 2. Induction heads and circuits

Induction circuits develop fairly suddenly, in a phase change confined to a narrow window early in training (roughly 2.5 to 5 billion tokens). In that window the network goes from having no induction heads to having well-developed ones, which then persist for the rest of training. They are responsible for a significant decrease in loss, large enough to produce a visible bump in the loss curve when they form.

Induction heads seem to be responsible for the vast majority of in-context learning - the ability to use tokens from far back in the context to predict the next token. This is a significant way in which transformers outperform older architectures.

**DEFINITIONS**:
- induction head: a head that, when the current token has appeared earlier in the context, attends to the token immediately after that earlier occurrence. It does so via K-composition with a previous token head.
- induction circuit: the composition of a previous token head (in an earlier layer) and an induction head (in a later layer)
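The definition of an induction head suggests a detector in the same spirit as the ones in section 1.2: feed the model a sequence that repeats, `[BOS, t_1..t_n, t_1..t_n]`, and measure how much each token in the second repeat attends to the token that followed its earlier copy. The sketch below applies that idea to synthetic patterns only; the function name `induction_score` and the scoring rule are assumptions for illustration, not a reference implementation.

```python
import torch


def induction_score(pattern: torch.Tensor, seq_len: int) -> float:
    """Mean attention from each token in the second repeat to the token
    that followed its earlier copy.

    Assumes `pattern` is indexed [dest, src] for tokens
    [BOS, t_1..t_n, t_1..t_n], i.e. shape (2 * seq_len + 1, 2 * seq_len + 1).
    """
    dest = torch.arange(seq_len + 1, 2 * seq_len + 1)  # positions in the second repeat
    src = dest - (seq_len - 1)                         # one step past the earlier copy
    return pattern[dest, src].mean().item()


seq_len = 4
total = 2 * seq_len + 1

# Idealised induction head: rests on BOS during the first pass, then attends
# one step past the earlier copy of the current token.
ideal = torch.zeros(total, total)
ideal[: seq_len + 1, 0] = 1.0
dest = torch.arange(seq_len + 1, total)
ideal[dest, dest - (seq_len - 1)] = 1.0

# Baseline: uniform causal attention, which carries no induction signal.
uniform = torch.tril(torch.ones(total, total))
uniform = uniform / uniform.sum(dim=-1, keepdim=True)

print(induction_score(ideal, seq_len))    # 1.0
print(induction_score(uniform, seq_len))  # well below 1.0
```

With the real model, one would presumably build the repeated sequence from random token IDs, call `model.run_with_cache` on it, and apply the score to each head's slice of `cache["pattern", layer]`, in the same way as `generic_attn_detector`.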

Question - Why couldn't an induction head form in a 1L model?

# Sources

1. [Ground truth - Finding induction heads, by ARENA](https://arena-chapter1-transformer-interp.streamlit.app/[1.2]_Intro_to_Mech_Interp)
2. [In-context Learning and Induction Heads, by Olsson et al. (Anthropic)](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html)
3. [Induction circuits - glossary definition, by Neel Nanda](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=_Jzi6YHRHKP1JziwdE02qdYZ)
4. [Induction heads - illustrated, by Callum McDougall](https://www.lesswrong.com/posts/TvrfY4c9eaGLeyDkE/induction-heads-illustrated)