TransformerLens is a tool to support mechanistic interpretability, by packaging a variety of model implementations in a consistent way, capturing internal state and supporting mutation of internal state.

It uses circuitsvis for visualisation. This provides several diagrams of the internal state of the model.


In [40]:
# Load the checkpoint through Hugging Face first so the local cache directory is reused.
from transformers import GPTNeoXForCausalLM, AutoTokenizer

hf_checkpoint = "EleutherAI/pythia-410m-deduped"
hf_cache_dir = "../../data/external/pythia-410m-deduped/default"
huggingface_model = GPTNeoXForCausalLM.from_pretrained(hf_checkpoint, cache_dir=hf_cache_dir)

# HookedTransformer

Implements several varieties of transformer algorithms with hookpoints to capture internal state.

## from_pretrained

Loads a pretrained model from Hugging Face. See https://github.com/neelnanda-io/TransformerLens/blob/main/transformer_lens/loading_from_pretrained.py for the list of available models.


In [41]:
from transformer_lens import HookedTransformer

# Wrap the already-loaded Hugging Face weights in a HookedTransformer.
# NOTE(review): the name passed here is "pythia-410m", while the weights in
# `huggingface_model` were loaded from "pythia-410m-deduped" - confirm that
# this mismatch is intentional.
hooked_model_name = "EleutherAI/pythia-410m"
model = HookedTransformer.from_pretrained(
    hooked_model_name,
    hf_model=huggingface_model,
)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-410m into HookedTransformer


In [42]:
# Tokenise a short prompt and run it through the model, recording every hook's activation.
plaintext = "The capital of Ireland is"
tokens = model.to_tokens(plaintext)

# remove_batch_dim=True strips the batch axis from the cached activations;
# the returned logits still carry it (see the shapes displayed below).
run_result = model.run_with_cache(tokens, remove_batch_dim=True)
logits, cache = run_result

(logits.shape, cache)

(torch.Size([1, 6, 50304]),
 ActivationCache with keys ['hook_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_rot_q', 'blocks.0.attn.hook_rot_k', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_rot_q', 'blocks.1.attn.hook_rot_k', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', '

The returned cache is suitable for rendering with CircuitsVis.

The lookup keys, however, appear to depend on the model type.


In [43]:
# List every cached activation alongside its tensor shape, one per line.
for hook_name in cache.keys():
    activation = cache[hook_name]
    print(hook_name, activation.shape)

hook_embed torch.Size([6, 1024])
blocks.0.hook_resid_pre torch.Size([6, 1024])
blocks.0.ln1.hook_scale torch.Size([6, 1])
blocks.0.ln1.hook_normalized torch.Size([6, 1024])
blocks.0.attn.hook_q torch.Size([6, 16, 64])
blocks.0.attn.hook_k torch.Size([6, 16, 64])
blocks.0.attn.hook_v torch.Size([6, 16, 64])
blocks.0.attn.hook_rot_q torch.Size([6, 16, 64])
blocks.0.attn.hook_rot_k torch.Size([6, 16, 64])
blocks.0.attn.hook_attn_scores torch.Size([16, 6, 6])
blocks.0.attn.hook_pattern torch.Size([16, 6, 6])
blocks.0.attn.hook_z torch.Size([6, 16, 64])
blocks.0.hook_attn_out torch.Size([6, 1024])
blocks.0.ln2.hook_scale torch.Size([6, 1])
blocks.0.ln2.hook_normalized torch.Size([6, 1024])
blocks.0.mlp.hook_pre torch.Size([6, 4096])
blocks.0.mlp.hook_post torch.Size([6, 4096])
blocks.0.hook_mlp_out torch.Size([6, 1024])
blocks.0.hook_resid_post torch.Size([6, 1024])
blocks.1.hook_resid_pre torch.Size([6, 1024])
blocks.1.ln1.hook_scale torch.Size([6, 1])
blocks.1.ln1.hook_normalized torch.Si

In [45]:
import circuitsvis as cv

# attention_heads labels the rows and columns of each pattern with token
# *strings*, so use to_str_tokens here; to_tokens returns token ids instead.
str_tokens = model.to_str_tokens(plaintext)

# Layer-0 attention pattern. Because the cache was built with
# remove_batch_dim=True, this has shape (16, 6, 6): one 6x6 pattern per head.
attention_pattern = cache["blocks.0.attn.hook_pattern"]

rendered_html = cv.attention.attention_heads(tokens=str_tokens, attention=attention_pattern)

CircuitsVis appears to require a full Jupyter environment rather than the VS Code notebook viewer.

In [55]:
# Dump the generated HTML/JS source as plain text instead of rendering it.
html_source = str(rendered_html)
print(html_source)

<div id="circuits-vis-3d172f80-326d" style="margin: 15px 0;"/>
    <script crossorigin type="module">
    import { render, AttentionHeads } from "https://unpkg.com/circuitsvis@1.39.1/dist/cdn/esm.js";
    render(
      "circuits-vis-3d172f80-326d",
      AttentionHeads,
      {"attention": [[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4517521262168884, 0.5482479333877563, 0.0, 0.0, 0.0, 0.0], [0.5024253129959106, 0.4261503517627716, 0.07142426073551178, 0.0, 0.0, 0.0], [0.11224543303251266, 0.23465727269649506, 0.44676706194877625, 0.2063302844762802, 0.0, 0.0], [0.1720249354839325, 0.11507517099380493, 0.04611247032880783, 0.5142410397529602, 0.15254639089107513, 0.0], [0.1090007945895195, 0.10949338227510452, 0.0036624169442802668, 0.20431481301784515, 0.03385785222053528, 0.5396707653999329]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5926687121391296, 0.407331258058548, 0.0, 0.0, 0.0, 0.0], [0.4600315988063812, 0.34202510118484497, 0.19794337451457977, 0.0, 0.0, 0.0], [0.15562257170677185, 0.3513