In [None]:
!pip install transformers==4.37.2 transformer_lens==2.11.0 --quiet


In [None]:
from transformer_lens import (
    ActivationCache,
    HookedTransformer,
    HookedTransformerConfig,
)
import numpy as np
import torch as t

In [None]:
# Fix torch's global RNG seed for reproducibility.
t.manual_seed(0)

# Device preference order: Apple MPS, then CUDA, then CPU.
if t.backends.mps.is_available():
    device = t.device("mps")
elif t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")

# Load pretrained GPT-2 small as a HookedTransformer and move it to the chosen device.
model = HookedTransformer.from_pretrained("gpt2-small")
model = model.to(device)



Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda


In [None]:
# Prompt that is tokenized and run through the model below; it contains the target word " fox".
reference_text = "The quick brown fox jumps over the lazy dog"

In [None]:
# Token ids for the prompt; to_tokens prepends a BOS token (shown as '<|endoftext|>').
tokens = model.to_tokens(reference_text).to(device)
print("Tokens:", list(map(model.to_string, tokens[0])))

# Vocabulary id of the word tracked through the layers (must be a single token).
fox_token_id = model.to_single_token(" fox")
print(f"Target ' fox' token ID: {fox_token_id}")

Tokens: ['<|endoftext|>', 'The', ' quick', ' brown', ' fox', ' jumps', ' over', ' the', ' lazy', ' dog']
Target ' fox' token ID: 21831


In [None]:
# Forward pass over exactly the token ids printed above, so cache positions
# line up with that token list. `cache` holds every hook point's activation,
# indexed by hook name (e.g. "blocks.{layer}.hook_resid_post").
logits, cache = model.run_with_cache(tokens)

In [None]:
# Unembedding matrix [d_model, d_vocab]: projects a residual-stream vector onto vocabulary logits.
W_U = model.unembed.W_U
print("W_U shape:", W_U.shape)

W_U shape: torch.Size([768, 50257])


In [None]:
# check when fox token starts to appear in the models predictions
n_layers = model.cfg.n_layers
layer_probs = []

for layer in range(n_layers):
  resid_post = cache[f"blocks.{layer}.hook_resid_post"].to(device)
  hidden_state = resid_post[0,1,:] # second token at 1st pos

  logits = hidden_state@ W_U # [d_vocab], we decode at each layer
  probs = t.softmax(logits, dim =-1)

  fox_prob = probs[fox_token_id].item()  # the chance ' fox' is next, given what the model knows so far

  layer_probs.append(fox_prob)

  print(f"Layer {layer:2d}: P(' fox') = {fox_prob}")


Layer  0: P(' fox') = 4.227344030383051e-10
Layer  1: P(' fox') = 3.8766975740678333e-11
Layer  2: P(' fox') = 2.829986462185574e-12
Layer  3: P(' fox') = 2.7017364040426983e-13
Layer  4: P(' fox') = 7.391197851038e-16
Layer  5: P(' fox') = 6.825884075029927e-20
Layer  6: P(' fox') = 6.2337102183971546e-24
Layer  7: P(' fox') = 3.872848673308404e-28
Layer  8: P(' fox') = 1.5845348816598415e-31
Layer  9: P(' fox') = 1.1578290039250028e-35
Layer 10: P(' fox') = 8.898455443232237e-40
Layer 11: P(' fox') = 2.9665488489756377e-40


In [None]:
# Layer at which the decoded probability of ' fox' peaks
# (ties resolve to the earliest such layer).
max_layer = max(range(len(layer_probs)), key=layer_probs.__getitem__)
max_prob = layer_probs[max_layer]
print(f"Max P(' fox'): {max_prob}, at layer {max_layer}")


Max P(' fox'): 4.227344030383051e-10, at layer 0
