
This notebook describes the Pythia model, its implementation in Transformer Lens, and how the Transformer Lens hooks map onto the model.


# Setup

Preinstall numpy 1.23 to avoid a warning that appears when the current default version, 1.24, is installed. The kernel must be restarted after this step.


In [1]:
%pip install "numpy == 1.23.*"

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
%pip install git+https://github.com/neelnanda-io/TransformerLens.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-0abysqaw
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-0abysqaw
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit c268a7159a6f8d5c78236a3f958f2d704fbc940f
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting datasets>=2.7.1 (from transformer-lens==0.0.0)
  Downloading datasets-2.12.0-py3-none-any.whl (474 kB)
Collecting einops>=0.6.0 (from transformer-lens==0.0.0)
  Downloading einops-0.6

# Pythia Model

EleutherAI's Pythia project released a series of transformer models that support interpretability research by using standardised training data and a standardised training process, and by releasing checkpoints saved during training. See https://github.com/EleutherAI/pythia and https://arxiv.org/pdf/2304.01373.pdf.

# HookedTransformer


## Summary

The keys available in each transformer block are listed below. The example dimensions come from the two-token prompt run later in this notebook with `remove_batch_dim=True`, so the batch dimension is absent from the example column.

|Name|Description|Dimensions (example)|Dimensions (general)|
|---|---|---|---|
|hook_resid_pre|The residual input to the transformer block|2 x 512|batch, pos, d_model|
|ln1.hook_scale|The normalisation scale factor|2 x 1|batch, pos, 1|
|ln1.hook_normalized|Result of normalization|2 x 512|batch, pos, d_model|
|attn.hook_q|Internal attn head vector - query|2 x 8 x 64|batch, pos, head_index, d_head|
|attn.hook_k|Internal attn head vector - key|2 x 8 x 64|batch, pos, head_index, d_head|
|attn.hook_v|Internal attn head vector - value|2 x 8 x 64|batch, pos, head_index, d_head|
|attn.hook_rot_q|Query vector after the rotary position embedding is applied|2 x 8 x 64|batch, pos, head_index, d_head|
|attn.hook_rot_k|Key vector after the rotary position embedding is applied|2 x 8 x 64|batch, pos, head_index, d_head|
|attn.hook_attn_scores|Query-key dot product immediately before the attention softmax|8 x 2 x 2|batch, head_index, query_pos, key_pos|
|attn.hook_pattern|The attention pattern calculated from the query-key dot product and softmax|8 x 2 x 2|batch, head_index, query_pos, key_pos|
|attn.hook_z|Weighted sum of values after applying the attention pattern|2 x 8 x 64|batch, query_pos, head_index, d_head|
|attn.hook_result|Separate per-head result of attention projected back to model dimensions; disabled by default due to memory usage| |batch, pos, head_index, d_model|
|hook_attn_out|Result of the attention operation, to be added to residual|2 x 512|batch, pos, d_model|
|resid_mid|Result of adding attention output to residual, before normalization or MLP| | |
|ln2.hook_scale|The normalisation scale factor|2 x 1|batch, pos, 1|
|ln2.hook_normalized|Result of normalization|2 x 512|batch, pos, d_model|
|normalized_resid_mid|Result of ln2(resid_mid)| | |
|mlp.hook_pre|MLP hidden layer before the activation function|2 x 2048|batch, pos, d_mlp|
|mlp.hook_post|MLP hidden layer after the activation function|2 x 2048|batch, pos, d_mlp|
|hook_mlp_out|Result of the MLP operation, to be added to residual|2 x 512|batch, pos, d_model|
|hook_resid_post|The residual resulting from the transformer block|2 x 512|batch, pos, d_model|
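
The example dimensions in the table follow from a handful of model sizes. The sketch below is not part of Transformer Lens: `expected_hook_shapes` is a hypothetical helper, and the values passed to it (2 positions, `d_model` 512, 8 heads of size 64, `d_mlp` 2048) are assumed from the shapes printed later in this notebook for `pythia-70m-deduped` with the batch dimension removed.

```python
def expected_hook_shapes(pos, d_model, n_heads, d_head, d_mlp):
    """Return per-block hook shapes, batch dimension removed (hypothetical helper)."""
    per_position = (pos, d_model)          # anything living in the residual stream
    per_head = (pos, n_heads, d_head)      # per-head vectors
    head_by_pos = (n_heads, pos, pos)      # one query x key matrix per head
    return {
        "hook_resid_pre": per_position,
        "ln1.hook_scale": (pos, 1),
        "ln1.hook_normalized": per_position,
        "attn.hook_q": per_head,
        "attn.hook_k": per_head,
        "attn.hook_v": per_head,
        "attn.hook_rot_q": per_head,
        "attn.hook_rot_k": per_head,
        "attn.hook_attn_scores": head_by_pos,
        "attn.hook_pattern": head_by_pos,
        "attn.hook_z": per_head,
        "hook_attn_out": per_position,
        "ln2.hook_scale": (pos, 1),
        "ln2.hook_normalized": per_position,
        "mlp.hook_pre": (pos, d_mlp),
        "mlp.hook_post": (pos, d_mlp),
        "hook_mlp_out": per_position,
        "hook_resid_post": per_position,
    }


# Sizes assumed from the shapes printed for pythia-70m-deduped in this notebook.
shapes = expected_hook_shapes(pos=2, d_model=512, n_heads=8, d_head=64, d_mlp=2048)
for name, shape in shapes.items():
    print(name, shape)
```

Changing `pos` to the token count of a different prompt should give the shapes to expect for that prompt, provided the same hooks are present.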



## Details

### Pythia Model taken from HuggingFace

Show structure of the model from HuggingFace 

In [None]:
# Load the Pythia model from the HuggingFace Hub
from transformers import GPTNeoXForCausalLM, AutoTokenizer

hf_pythia_model = GPTNeoXForCausalLM.from_pretrained(
    "EleutherAI/pythia-70m-deduped"
)


In [4]:
hf_pythia_model

GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(50304, 512)
    (layers): ModuleList(
      (0-5): 6 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attention): GPTNeoXAttention(
          (rotary_emb): RotaryEmbedding()
          (query_key_value): Linear(in_features=512, out_features=1536, bias=True)
          (dense): Linear(in_features=512, out_features=512, bias=True)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=512, out_features=2048, bias=True)
          (dense_4h_to_h): Linear(in_features=2048, out_features=512, bias=True)
          (act): GELUActivation()
        )
      )
    )
    (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (embed_out): Linear(in_features=512, out_features=50304, bias=False)
)
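
The layer sizes in this printout are enough to tally the trainable parameters by hand. Note that `query_key_value` has 1536 = 3 x 512 output features, since it holds the query, key and value projections in one matrix. The sketch below is plain arithmetic on the printed sizes and does not load the model; it counts only the weights and biases visible above, and the helper names are made up for illustration.

```python
def linear_params(in_features, out_features, bias=True):
    """Weights plus optional bias of a Linear layer."""
    return in_features * out_features + (out_features if bias else 0)


def layernorm_params(width):
    """Elementwise-affine LayerNorm: one weight and one bias per feature."""
    return 2 * width


# Sizes read off the printed GPTNeoXForCausalLM structure.
vocab, d_model, d_mlp, n_layers = 50304, 512, 2048, 6

per_layer = (
    2 * layernorm_params(d_model)              # input_layernorm, post_attention_layernorm
    + linear_params(d_model, 3 * d_model)      # query_key_value
    + linear_params(d_model, d_model)          # attention dense
    + linear_params(d_model, d_mlp)            # dense_h_to_4h
    + linear_params(d_mlp, d_model)            # dense_4h_to_h
)

total = (
    vocab * d_model                                # embed_in
    + n_layers * per_layer                         # six identical layers
    + layernorm_params(d_model)                    # final_layer_norm
    + linear_params(d_model, vocab, bias=False)    # embed_out
)

print(f"{per_layer:,} parameters per layer, {total:,} in total")
```

The total comes to roughly 70 million, consistent with the `70m` in the model name; about three quarters of it sits in the two embedding matrices.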

### Transformer Lens version of Pythia model


Show equivalent HookedTransformer from Transformer Lens


In [5]:
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m-deduped")


Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer


In [6]:
model

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-5): 6 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPo
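
Each block in the printout uses `LayerNormPre`, which exposes `hook_scale` and `hook_normalized`. As a rough illustration of what those two hooks would hold for a single position, here is a pure-Python sketch. It assumes the vector is centred and then divided by its root mean square, with `eps=1e-5` taken from the HuggingFace printout, and that no learned weight or bias is applied at this point. The function name is made up for illustration and is not the library's code.

```python
import math


def layer_norm_pre(x, eps=1e-5):
    """Centre x, then divide by its root mean square (assumed behaviour)."""
    mean = sum(x) / len(x)
    centered = [v - mean for v in x]
    # One scalar per position: this is what a scale hook of shape [pos, 1] would store.
    scale = math.sqrt(sum(v * v for v in centered) / len(centered) + eps)
    # Same width as the input: this is what a normalized hook would store.
    normalized = [v / scale for v in centered]
    return scale, normalized


# A toy 4-dimensional "residual" vector; the real model uses 512 dimensions.
scale, normalized = layer_norm_pre([1.0, 2.0, 3.0, 4.0])
print(scale, normalized)
```

This matches the cached shapes: one scale value per position, and a normalized vector with the same width as the residual stream.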

### Hooks available from Transformer Lens HookedTransformer cache

Show values from hooks available in the cache after inference

In [7]:
plaintext = "Two tokens"
tokens = model.to_tokens(plaintext, prepend_bos=False)
logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)
cache

ActivationCache with keys ['hook_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_rot_q', 'blocks.0.attn.hook_rot_k', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_rot_q', 'blocks.1.attn.hook_rot_k', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'bloc
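
The key names follow a regular pattern: `hook_embed` first, then the same per-block hook names prefixed with `blocks.<layer>.`. The sketch below rebuilds that part of the list without loading the model. `BLOCK_HOOKS` and `expected_cache_keys` are hypothetical names, the hook order is copied from the output above, and any keys that come after the last block are cut off in that output and are deliberately not reproduced here.

```python
# Per-block hook names, in the order they appear in the cache output above.
BLOCK_HOOKS = [
    "hook_resid_pre",
    "ln1.hook_scale",
    "ln1.hook_normalized",
    "attn.hook_q",
    "attn.hook_k",
    "attn.hook_v",
    "attn.hook_rot_q",
    "attn.hook_rot_k",
    "attn.hook_attn_scores",
    "attn.hook_pattern",
    "attn.hook_z",
    "hook_attn_out",
    "ln2.hook_scale",
    "ln2.hook_normalized",
    "mlp.hook_pre",
    "mlp.hook_post",
    "hook_mlp_out",
    "hook_resid_post",
]


def expected_cache_keys(n_layers):
    """Embedding key plus per-block keys (hypothetical helper)."""
    keys = ["hook_embed"]
    for layer in range(n_layers):
        keys.extend(f"blocks.{layer}.{hook}" for hook in BLOCK_HOOKS)
    return keys


keys = expected_cache_keys(n_layers=6)  # the printed model has 6 blocks
print(keys[:3])
```

Building names this way is handy for looping over one hook across all layers, for example `[f"blocks.{i}.attn.hook_pattern" for i in range(6)]`.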

In [8]:
for key in cache.keys():
    print(key, cache[key].shape)

hook_embed torch.Size([2, 512])
blocks.0.hook_resid_pre torch.Size([2, 512])
blocks.0.ln1.hook_scale torch.Size([2, 1])
blocks.0.ln1.hook_normalized torch.Size([2, 512])
blocks.0.attn.hook_q torch.Size([2, 8, 64])
blocks.0.attn.hook_k torch.Size([2, 8, 64])
blocks.0.attn.hook_v torch.Size([2, 8, 64])
blocks.0.attn.hook_rot_q torch.Size([2, 8, 64])
blocks.0.attn.hook_rot_k torch.Size([2, 8, 64])
blocks.0.attn.hook_attn_scores torch.Size([8, 2, 2])
blocks.0.attn.hook_pattern torch.Size([8, 2, 2])
blocks.0.attn.hook_z torch.Size([2, 8, 64])
blocks.0.hook_attn_out torch.Size([2, 512])
blocks.0.ln2.hook_scale torch.Size([2, 1])
blocks.0.ln2.hook_normalized torch.Size([2, 512])
blocks.0.mlp.hook_pre torch.Size([2, 2048])
blocks.0.mlp.hook_post torch.Size([2, 2048])
blocks.0.hook_mlp_out torch.Size([2, 512])
blocks.0.hook_resid_post torch.Size([2, 512])
blocks.1.hook_resid_pre torch.Size([2, 512])
blocks.1.ln1.hook_scale torch.Size([2, 1])
blocks.1.ln1.hook_normalized torch.Size([2, 512])
blo
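
In the shapes above, `hook_attn_scores` and `hook_pattern` are both `[8, 2, 2]`: one 2 x 2 query-by-key matrix per head. The step from scores to pattern can be sketched for a single head in pure Python. This is an illustration, not the library's code: it assumes a causal mask (each query position sees only itself and earlier positions) followed by a row-wise softmax, and the input numbers are made up.

```python
import math


def causal_softmax(scores):
    """Turn scores[query][key] into a pattern whose rows sum to 1 (assumed causal)."""
    pattern = []
    for q, row in enumerate(scores):
        visible = row[: q + 1]                    # keys at or before the query position
        peak = max(visible)                       # subtract the max for numerical stability
        exps = [math.exp(s - peak) for s in visible]
        total = sum(exps)
        masked = [0.0] * (len(row) - q - 1)       # later keys get zero weight
        pattern.append([e / total for e in exps] + masked)
    return pattern


# Made-up scores for one head over 2 query x 2 key positions.
scores = [[0.3, 9.9],
          [1.0, 1.0]]
pattern = causal_softmax(scores)
print(pattern)
```

The first query position can only attend to itself, so its row is `[1.0, 0.0]` whatever the scores are; the second row splits its weight between both positions.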