TransformerLens is a tool to support mechanistic interpretability, by packaging a variety of model implementations in a consistent way, capturing internal state and supporting mutation of internal state.

It uses circuitsvis for visualisation. This provides several diagrams of the internal state of the model.


In [2]:
# Target google colab notebook
DEVELOPMENT_MODE = False
import google.colab
IN_COLAB = True
print("Running as a Colab notebook")
%pip install git+https://github.com/neelnanda-io/TransformerLens.git

Running as a Colab notebook
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-jsiyd0x5
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-jsiyd0x5
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 0ffcc8ad647d9e991f4c2596557a9d7475617773
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting einops>=0.6.0
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jaxtyping>=0.2.11
  Downloading jaxtyping-0.2.15-py3-none-any.whl (20 kB)
Colle

In [3]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [4]:
#load model from HF - to reuse local cache
from transformers import GPTNeoXForCausalLM, AutoTokenizer
huggingface_model = GPTNeoXForCausalLM.from_pretrained(
  "EleutherAI/pythia-410m-deduped",
  cache_dir="../../data/external/pythia-410m-deduped/default",
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/911M [00:00<?, ?B/s]

# HookedTransformer

Implements several varieties of transformer algorithms with hookpoints to capture internal state.

## from_pretrained

Loads a pretrained model from huggingface. See https://github.com/neelnanda-io/TransformerLens/blob/main/transformer_lens/loading_from_pretrained.py for available models


In [5]:
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("EleutherAI/pythia-410m", hf_model=huggingface_model)

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-410m into HookedTransformer


In [None]:
plaintext = "The capital of Ireland is"
tokens = model.to_tokens(plaintext)
logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)
logits.shape, cache

The cache returned is suitable for rendering using CircuitsVis

The lookup keys however appear to depend on the model type.


In [None]:
for key in cache.keys() :
  print(key, cache[key].shape)

In [10]:
!pip install circuitsvis

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting circuitsvis
  Using cached circuitsvis-1.39.1-py3-none-any.whl (1.8 MB)
Collecting importlib-metadata<6.0.0,>=5.1.0
  Downloading importlib_metadata-5.2.0-py3-none-any.whl (21 kB)
Collecting torch<2.0,>=1.10
  Downloading torch-1.13.1-cp310-cp310-manylinux1_x86_64.whl (887.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m887.5/887.5 MB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cudnn-cu11==8.5.0.96
  Downloading nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl (557.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m557.1/557.1 MB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-nvrtc-cu11==11.7.99
  Downloading nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl (21.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.0/21.0 MB[0m [31m50.0 MB/s[0m eta [3

In [23]:
import circuitsvis as cv
str_tokens = model.to_tokens(plaintext)
attention_pattern = cache["blocks.0.attn.hook_pattern"]

cv.attention.attention_heads(tokens=str_tokens, attention=attention_pattern)

CircuitsVis appears to require full Jupyter, not VSCode ones

In [15]:
rendered_html

In [14]:
print(rendered_html)

<div id="circuits-vis-b83180a2-d4a9" style="margin: 15px 0;"/>
    <script crossorigin type="module">
    import { render, AttentionHeads } from "https://unpkg.com/circuitsvis@1.39.1/dist/cdn/esm.js";
    render(
      "circuits-vis-b83180a2-d4a9",
      AttentionHeads,
      {"attention": [[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.49172213673591614, 0.5082778334617615, 0.0, 0.0, 0.0, 0.0], [0.15476034581661224, 0.395854115486145, 0.44938549399375916, 0.0, 0.0, 0.0], [0.051755476742982864, 0.14524032175540924, 0.13158756494522095, 0.6714166402816772, 0.0, 0.0], [0.08859915286302567, 0.10238335281610489, 0.46161767840385437, 0.17370080947875977, 0.17369899153709412, 0.0], [0.10145179182291031, 0.03753535822033882, 0.04686075448989868, 0.31801190972328186, 0.06729642301797867, 0.4288436770439148]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.16121850907802582, 0.838781476020813, 0.0, 0.0, 0.0, 0.0], [0.2535189092159271, 0.606117308139801, 0.14036379754543304, 0.0, 0.0, 0.0], [0.044216886162757874, 0.1

In [17]:
attention_pattern = cache["pattern", 0, "attn"]
print(attention_pattern.shape)

torch.Size([16, 6, 6])


In [19]:
str_tokens = model.to_str_tokens(plaintext)
str_tokens

['<|endoftext|>', 'The', ' capital', ' of', ' Ireland', ' is']

In [21]:
print("Layer 0 Head Attention Patterns:")
cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)

Layer 0 Head Attention Patterns:


In [26]:
cv.attention.attention_heads(tokens=str_tokens, attention=attention_pattern)

The line above works intermittently, but sometimes fails with JS exception

In [None]:
(str_tokens, attention_pattern)

Attempting to prepare a simple testcase

In [27]:
model2 = HookedTransformer.from_pretrained("EleutherAI/pythia-70m-deduped")

Downloading (…)lve/main/config.json:   0%|          | 0.00/567 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/166M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer


In [34]:
plaintext = "Word"
tokens = model2.to_tokens(plaintext)
logits, cache = model2.run_with_cache(tokens, remove_batch_dim=True)

In [35]:
attention_pattern = cache["pattern", 0, "attn"]
print(attention_pattern.shape)

torch.Size([8, 2, 2])


In [42]:
str_tokens = model.to_str_tokens(plaintext)
cv.attention.attention_heads(tokens=str_tokens, attention=attention_pattern)

In [37]:
attention_pattern

tensor([[[1.0000, 0.0000],
         [0.1486, 0.8514]],

        [[1.0000, 0.0000],
         [0.6513, 0.3487]],

        [[1.0000, 0.0000],
         [0.0083, 0.9917]],

        [[1.0000, 0.0000],
         [0.0013, 0.9987]],

        [[1.0000, 0.0000],
         [0.3633, 0.6367]],

        [[1.0000, 0.0000],
         [0.6671, 0.3329]],

        [[1.0000, 0.0000],
         [0.8286, 0.1714]],

        [[1.0000, 0.0000],
         [0.6023, 0.3977]]], device='cuda:0')

In [38]:
small_pattern = [[[1.0000, 0.0000],
         [0.1486, 0.8514]],

        [[1.0000, 0.0000],
         [0.6513, 0.3487]],

        [[1.0000, 0.0000],
         [0.0083, 0.9917]],

        [[1.0000, 0.0000],
         [0.0013, 0.9987]],

        [[1.0000, 0.0000],
         [0.3633, 0.6367]],

        [[1.0000, 0.0000],
         [0.6671, 0.3329]],

        [[1.0000, 0.0000],
         [0.8286, 0.1714]],

        [[1.0000, 0.0000],
         [0.6023, 0.3977]]]

In [40]:
small_tensor = torch.tensor(small_pattern)
small_tensor

NameError: ignored

In [43]:
str_tokens

['<|endoftext|>', 'Word']