In [93]:
import warnings
warnings.filterwarnings('ignore')

import torch
import transformer_lens

# 1. Inspection of the model (GPT-2)

In [None]:
model = transformer_lens.HookedTransformer.from_pretrained('gpt2-small')

Loaded pretrained model gpt2-small into HookedTransformer


In [9]:
n_layers = model.cfg.n_layers
n_heads = model.cfg.n_heads
n_ctx = model.cfg.n_ctx

print(f'Number of layers: {n_layers}')
print(f'Number of heads per layer: {n_heads}')
print(f'Maximum context window: {n_ctx}')

Number of layers: 12
Number of heads per layer: 12
Maximum context window: 1024


In [None]:
logits = model('Hello TransformerLens!', return_type='logits')
print('Shape of model logits: ', logits.shape)

Shape of model logits:  torch.Size([1, 6, 50257])


In [None]:
d_embeddings, d_model = model.W_E.shape
print(f'Shape of `embeddings` matrix: {d_embeddings} x {d_model}')

Shape of `embeddings` matrix: 50257 x 768


In [25]:
# W_pos has one row per position in the context window
n_ctx, d_model = model.W_pos.shape
print(f'Shape of `positional embeddings` matrix: {n_ctx} x {d_model}')

Shape of `positional embeddings` matrix: 1024 x 768


In [20]:
_, _, d_model, d_head = model.W_Q.shape
print(f'Shape of `query` matrix: {d_model} x {d_head}')

Shape of `query` matrix: 768 x 64


In [21]:
_, _, d_model, d_head = model.W_K.shape
print(f'Shape of `key` matrix: {d_model} x {d_head}')

Shape of `key` matrix: 768 x 64


In [22]:
_, _, d_model, d_head = model.W_V.shape
print(f'Shape of `value` matrix: {d_model} x {d_head}')

Shape of `value` matrix: 768 x 64


In [26]:
# W_U maps the residual stream back to vocabulary logits, so its shape is (d_model, d_vocab)
d_model, d_vocab = model.W_U.shape
print(f'Shape of `unembeddings` matrix: {d_model} x {d_vocab}')

Shape of `unembeddings` matrix: 768 x 50257
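
The shapes printed in the cells above all follow from a handful of config values. The sketch below rebuilds them in plain Python, without loading the model. It assumes that the head dimension is `d_model // n_heads` (true for GPT-2 small, where 768 / 12 = 64) and that `W_Q`, `W_K` and `W_V` are stacked across layers and heads; the dictionary `expected_shapes` is an illustrative name, not a TransformerLens object.

```python
# Config values printed earlier in this notebook for GPT-2 small.
n_layers, n_heads, n_ctx = 12, 12, 1024
d_model, d_vocab = 768, 50257

# Assumption: the residual stream width is split evenly across heads.
d_head = d_model // n_heads

# Expected weight shapes, keyed by the attribute names used above.
expected_shapes = {
    "W_E": (d_vocab, d_model),                    # token embeddings
    "W_pos": (n_ctx, d_model),                    # positional embeddings
    "W_Q": (n_layers, n_heads, d_model, d_head),  # stacked over layers and heads
    "W_K": (n_layers, n_heads, d_model, d_head),
    "W_V": (n_layers, n_heads, d_model, d_head),
    "W_U": (d_model, d_vocab),                    # unembeddings
}

for name, shape in expected_shapes.items():
    print(f"{name}: {shape}")
```

The two leading values discarded with `_` when unpacking `W_Q`, `W_K` and `W_V` are the layer and head axes.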


In [None]:
loss = model('Hello TransformerLens!', return_type='loss')
print('Loss: ', loss)

Loss:  tensor(7.2929, device='mps:0', grad_fn=<DivBackward0>)
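
As a rough guide to what this number means: with `return_type='loss'` the model reports the average next-token cross-entropy, i.e. the mean negative log-probability it assigned to each token that actually came next. The sketch below reproduces that computation in plain Python on made-up logits over a three-token vocabulary. The function names and toy values are illustrative assumptions, not TransformerLens code.

```python
import math

def log_softmax(row):
    """Numerically stable log-softmax over one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]

def mean_next_token_loss(logits, tokens):
    """Average negative log-probability of each token, given the logits at the previous position.

    logits[i] is the score vector produced after reading tokens[: i + 1],
    so it is scored against tokens[i + 1]; the final row has no target.
    """
    losses = []
    for position, target in enumerate(tokens[1:]):
        log_probs = log_softmax(logits[position])
        losses.append(-log_probs[target])
    return sum(losses) / len(losses)

# Toy example (hypothetical values): 3 positions, vocabulary of size 3.
toy_tokens = [0, 2, 1]
toy_logits = [
    [0.0, 0.0, 0.0],  # uniform guess for the token after position 0
    [0.0, 5.0, 0.0],  # confident, correct guess for the token after position 1
    [1.0, 1.0, 1.0],  # last position: no target, ignored
]
toy_loss = mean_next_token_loss(toy_logits, toy_tokens)
print(f"Toy loss: {toy_loss:.4f}")
```

For scale, a uniform guess over GPT-2's 50257-token vocabulary would score ln(50257) ≈ 10.82, so a loss of 7.29 on this short, unusual string is better than chance but far from confident.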


# 2. Tokenization

GPT-2 uses a single special token, `<|endoftext|>` (index 50256), as its Beginning of Sequence (BOS), End of Sequence (EOS) and Padding (PAD) token.

**TransformerLens** prepends this token by default, including in `model.forward`, which is what runs implicitly when you call `model("Hello World")`. To disable this behavior, pass `prepend_bos=False` to `to_tokens`, `to_str_tokens`, `model.forward` and any other function that converts strings to multi-token tensors.

In [70]:
model.to_str_tokens('Hello TransformerLens!')

['<|endoftext|>', 'Hello', ' Trans', 'former', 'Lens', '!']

In [77]:
model.to_tokens('Hello TransformerLens!')

tensor([[50256, 15496,  3602, 16354, 49479,     0]], device='mps:0')

In [78]:
model.to_string(model.to_tokens('Hello TransformerLens!'))

['<|endoftext|>Hello TransformerLens!']
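
To make the effect of the `prepend_bos` flag concrete without loading the model, here is a toy stand-in. The token ids are copied from the outputs above; the lookup table and the function `toy_to_tokens` are hypothetical and only mimic the flag's behaviour on strings that are already split into tokens.

```python
# Token ids copied from the notebook output above. This table is a toy
# stand-in for GPT-2's real tokenizer, not part of TransformerLens.
TOY_VOCAB = {
    "<|endoftext|>": 50256,
    "Hello": 15496,
    " Trans": 3602,
    "former": 16354,
    "Lens": 49479,
    "!": 0,
}
BOS = "<|endoftext|>"

def toy_to_tokens(str_tokens, prepend_bos=True):
    """Map already-split string tokens to ids, optionally prepending BOS."""
    if prepend_bos:
        str_tokens = [BOS] + list(str_tokens)
    return [TOY_VOCAB[t] for t in str_tokens]

pieces = ["Hello", " Trans", "former", "Lens", "!"]
with_bos = toy_to_tokens(pieces)
without_bos = toy_to_tokens(pieces, prepend_bos=False)
print(with_bos)     # [50256, 15496, 3602, 16354, 49479, 0]
print(without_bos)  # [15496, 3602, 16354, 49479, 0]
```

With the BOS token the sequence is one position longer, which is why the logits for this five-token string have six positions.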

In [73]:
text = """## Loading Models

HookedTransformer comes loaded with >40 open source GPT-style models. You can load any of them in with `HookedTransformer.from_pretrained(MODEL_NAME)`. Each model is loaded into the consistent HookedTransformer architecture, designed to be clean, consistent and interpretability-friendly.

For this demo notebook we'll look at GPT-2 Small, an 80M parameter model. To try the model the model out, let's find the loss on this paragraph!"""

In [74]:
logits = model(text, return_type='logits')
print('Shape of model logits: ', logits.shape)

Shape of model logits:  torch.Size([1, 112, 50257])


In [76]:
predictions = logits.argmax(dim=-1).squeeze()[:-1]
print('Shape of predictions: ', predictions.shape)

Shape of predictions:  torch.Size([111])


In [81]:
true_tokens = model.to_tokens(text).squeeze()[1:]
is_correct = predictions == true_tokens

print(f"Model accuracy: {is_correct.sum()}/{len(true_tokens)}")

Model accuracy: 33/111
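
The slicing in the two cells above is an off-by-one alignment: the logits at position `i` are the model's guess for token `i + 1`, so the last prediction has no target and the first token (BOS) is never predicted. The sketch below shows the same alignment on made-up token ids; `next_token_accuracy` is an illustrative helper, not a library function.

```python
def next_token_accuracy(predicted_next, tokens):
    """Count how often the prediction made at position i equals tokens[i + 1]."""
    predictions = predicted_next[:-1]  # the last position predicts past the end of the text
    targets = tokens[1:]               # the first token has nothing before it to predict from
    hits = [p == t for p, t in zip(predictions, targets)]
    return sum(hits), len(targets)

# Toy example (hypothetical ids): 5 tokens, one prediction per position.
toy_tokens = [50256, 11, 22, 33, 44]
toy_predicted_next = [11, 99, 33, 44, 7]
correct, total = next_token_accuracy(toy_predicted_next, toy_tokens)
print(f"Toy accuracy: {correct}/{total}")  # Toy accuracy: 3/4
```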


**Induction heads** are a special kind of attention head that allows a model to perform a specific form of in-context learning: having seen token B follow token A once, the model predicts that B will follow A the next time A occurs, even if the two tokens never appeared together in the training data.

The cell below shows behaviour consistent with induction heads: the model correctly predicted 'ooked', 'Trans', 'former' following the token 'H'. These predictions come from the second occurrence of `HookedTransformer` in the text; the model got the word right the second time but not the first. The model did predict `former` the first time, but we can reasonably assume that `Transformer` is a word it had already seen during training, so that prediction would not require the induction capability, unlike `HookedTransformer`.

In [87]:
print(f"Evidence of induction heads: {model.to_str_tokens(predictions[is_correct])[8:11]}")

Evidence of induction heads: ['ooked', 'Trans', 'former']
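
The rule that induction heads are thought to implement can be written as a few lines of ordinary code: find the most recent earlier occurrence of the current token and copy whatever followed it. The sketch below is a behavioural illustration only. It says nothing about how attention heads compute this, and the token sequence is a simplified, hypothetical stand-in for the real tokenization of the paragraph.

```python
def induction_guess(tokens):
    """Guess the next token by copying what followed the latest earlier occurrence of the last token.

    Returns None when the last token has not appeared earlier in the sequence.
    """
    current = tokens[-1]
    # Scan backwards over earlier positions, excluding the last one.
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == current:
            return tokens[i + 1]
    return None

# Simplified, hypothetical token sequence ending at the second occurrence of 'H'.
first_pass = ["H", "ooked", "Trans", "former", " comes", " loaded", " with", " `", "H"]
print(induction_guess(first_pass))                # ooked
print(induction_guess(first_pass + ["ooked"]))    # Trans
print(induction_guess(["The", " model", " H"]))   # None
```

On the first occurrence of a token there is nothing to copy, which matches the observation that the model only gets `HookedTransformer` right the second time.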


# 3. Caching activations

The first basic operation in mechanistic interpretability is to break open the black box and look at the model's internal activations.

Every activation inside the model begins with a batch dimension. Here we pass in a single sequence, so that dimension always has length 1; passing `remove_batch_dim=True` to `run_with_cache`, or calling `remove_batch_dim()` on the returned cache, removes it.

In [88]:
text = "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets."

In [89]:
tokens = model.to_tokens(text)
logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)

print(type(logits), type(cache))

<class 'torch.Tensor'> <class 'transformer_lens.ActivationCache.ActivationCache'>


In [None]:
# accessing attention patterns for layer 0 (two different ways)
# 
# the reason these are the same is that, under the hood, the first example actually
#  indexes by `utils.get_act_name("pattern", 0)`, which evaluates to "blocks.0.attn.hook_pattern"
# 
# the diagram from the Transformer Architecture section helps in finding activation names
# https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/full-merm.svg
attn_patterns_from_shorthand = cache["pattern", 0]
attn_patterns_from_full_name = cache["blocks.0.attn.hook_pattern"]

torch.testing.assert_close(attn_patterns_from_shorthand, attn_patterns_from_full_name)
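
The shorthand lookup described in the comments above can be pictured as a small name-building step in front of an ordinary dictionary. The sketch below is a deliberately simplified, hypothetical version: `toy_act_name` and `ToyCache` are illustrative names, and the real `utils.get_act_name` handles more activation names and aliases than this.

```python
# Activation names assumed to live under the attention sub-module of each block.
ATTN_ACTIVATIONS = {"q", "k", "v", "z", "attn_scores", "pattern"}

def toy_act_name(name, layer):
    """Build a full hook name from a short name and a layer index (simplified)."""
    if name in ATTN_ACTIVATIONS:
        return f"blocks.{layer}.attn.hook_{name}"
    return f"blocks.{layer}.hook_{name}"

class ToyCache:
    """Dictionary wrapper that accepts either a full name or a (name, layer) pair."""

    def __init__(self, activations):
        self._activations = activations

    def __getitem__(self, key):
        if isinstance(key, tuple):
            key = toy_act_name(*key)
        return self._activations[key]

toy_cache = ToyCache({"blocks.0.attn.hook_pattern": "layer-0 attention pattern"})
print(toy_act_name("pattern", 0))  # blocks.0.attn.hook_pattern
print(toy_cache["pattern", 0] is toy_cache["blocks.0.attn.hook_pattern"])  # True
```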

# Sources

1. [Ground truth - Intro to Mech Interp, by ARENA](https://arena-chapter1-transformer-interp.streamlit.app/[1.2]_Intro_to_Mech_Interp)