In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
import os

DEBUG_MODE = False
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
try:
    import google.colab

    IN_COLAB = True
    print("Running as a Colab notebook")
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

if IN_COLAB or IN_GITHUB:
    %pip install transformer_lens
    %pip install torchtyping
    # Install my janky personal plotting utils
    %pip install git+https://github.com/TransformerLensOrg/neel-plotly.git
    # Install another version of node that makes PySvelte work way faster
    %pip install circuitsvis
    # Needed for PySvelte to work, v3 came out and broke things...
    %pip install typeguard==2.13.3

Running as a Jupyter notebook - intended for development only!


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"

In [3]:
# Imports
import torch

from transformers import AutoTokenizer
from transformer_lens import HookedEncoderDecoder

model_name = "t5-small"
model = HookedEncoderDecoder.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

If using T5 for interpretability research, keep in mind that T5 has some significant architectural differences to GPT. The major one is that T5 is an Encoder-Decoder modelAlso, it uses relative positional embeddings, different types of Attention (without bias) and LayerNorm


model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development



generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

Moving model to device:  cuda
Loaded pretrained model t5-small into HookedTransformer


tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development



spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

In [4]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x286ea2c85c0>

## basic sanity check - model generates smth

In [5]:
prompt = "translate English to French: Hello, how are you? "
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
decoder_input_ids = torch.tensor([[model.cfg.decoder_start_token_id]]).to(input_ids.device)


while True:
    logits = model.forward(input=input_ids, one_zero_attention_mask=attention_mask, decoder_input=decoder_input_ids)
    # logits.shape == (batch_size (1), predicted_pos, vocab_size)

    token_idx = torch.argmax(logits[0, -1, :]).item()
    print("generated token: \"", tokenizer.decode(token_idx), "\", token id: ", token_idx, sep="")

    # append token to decoder_input_ids
    decoder_input_ids = torch.cat([decoder_input_ids, torch.tensor([[token_idx]]).to(input_ids.device)], dim=-1)

    # break if End-Of-Sequence token generated
    if token_idx == tokenizer.eos_token_id:
        break

print(prompt, "\n", tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True))

generated token: "Bonjour", token id: 21845
generated token: ",", token id: 6
generated token: "comment", token id: 1670
generated token: "", token id: 3
generated token: "êtes", token id: 6738
generated token: "-", token id: 18
generated token: "vous", token id: 3249
generated token: "?", token id: 58
generated token: "</s>", token id: 1
translate English to French: Hello, how are you?  
 Bonjour, comment êtes-vous?


### visualise encoder patterns

In [6]:
import circuitsvis as cv
# Testing that the library works
cv.examples.hello("Neel")

In [7]:
prompt = "translate English to French: Hello, how are you? "
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]


logits,cache = model.run_with_cache(input=input_ids, one_zero_attention_mask=attention_mask, decoder_input=decoder_input_ids, remove_batch_dim=True)

In [8]:
# the usual way of indexing cache via cache["pattetn",0,"attn"] does not work
# besause it uses cache["block.0....]  indexing
# t5 is implementes as separate stack of blocks for encoder and decoder
# so indexing is cache["encoder.0.."], cache["decoder.0.."] 
# lets see what is in cache and choose the right key for encoder attention pattern on layer 0
print("\n".join(cache.keys()))

hook_embed
encoder.0.hook_resid_pre
encoder.0.ln1.hook_scale
encoder.0.ln1.hook_normalized
encoder.0.attn.hook_q
encoder.0.attn.hook_k
encoder.0.attn.hook_v
encoder.0.attn.hook_attn_scores
encoder.0.attn.hook_pattern
encoder.0.attn.hook_z
encoder.0.hook_attn_out
encoder.0.hook_resid_mid
encoder.0.ln2.hook_scale
encoder.0.ln2.hook_normalized
encoder.0.mlp.hook_pre
encoder.0.mlp.hook_post
encoder.0.hook_mlp_out
encoder.0.hook_resid_post
encoder.1.hook_resid_pre
encoder.1.ln1.hook_scale
encoder.1.ln1.hook_normalized
encoder.1.attn.hook_q
encoder.1.attn.hook_k
encoder.1.attn.hook_v
encoder.1.attn.hook_attn_scores
encoder.1.attn.hook_pattern
encoder.1.attn.hook_z
encoder.1.hook_attn_out
encoder.1.hook_resid_mid
encoder.1.ln2.hook_scale
encoder.1.ln2.hook_normalized
encoder.1.mlp.hook_pre
encoder.1.mlp.hook_post
encoder.1.hook_mlp_out
encoder.1.hook_resid_post
encoder.2.hook_resid_pre
encoder.2.ln1.hook_scale
encoder.2.ln1.hook_normalized
encoder.2.attn.hook_q
encoder.2.attn.hook_k
encoder.2

In [9]:
encoder_attn_pattern = cache["encoder.0.attn.hook_pattern"]
input_str_tokens = [w.lstrip("▁") for w in tokenizer.convert_ids_to_tokens(input_ids[0])]

In [10]:

cv.attention.attention_patterns(tokens=input_str_tokens, attention=encoder_attn_pattern)

### visualise decoder pattern

In [11]:
decoder_str_tokens = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])
decoder_str_tokens

['<pad>', '▁Bonjour', ',', '▁comment', '▁', 'êtes', '-', 'vous', '?', '</s>']

In [12]:
decoder_attn_pattern = cache["decoder.0.attn.hook_pattern"]
cv.attention.attention_patterns(tokens=decoder_str_tokens, attention=decoder_attn_pattern)

## topk tokens visualisation

In [13]:
# list of samples of shape (n_layers, n_tokens, n_neurons) for each sample
# i take the activations after the mlp layer
# you can also pass the activations after the attention layer (hook_attn_out),
#  after the cross attention layer (hook_cross_attn_out) or after the mlp layer (hook_mlp_out)
activations = [
    torch.stack([cache[f"decoder.{layer}.hook_mlp_out"] for layer in range(model.cfg.n_layers)]).cpu().numpy()
    ]

# list of samples of shape (n_tokens)
tokens = [decoder_str_tokens]

# if we have an arbitrary selection of layers, when change the layer labels, now just pass the layer index
layer_labels = [i for i in range(model.cfg.n_layers)]


cv.topk_tokens.topk_tokens(
    tokens=tokens,
    activations=activations, 
    max_k=10, 
    first_dimension_name="Layer", 
    first_dimension_labels=layer_labels,
    third_dimension_name="Neuron",
)
