##### Example 1 

In [9]:
import torch

In [10]:
# 3x3 int64 tensor holding the values 1..9 in row-major order.
tensor = torch.arange(1, 10).reshape(3, 3)

In [23]:
# Show the 3x3 integer tensor built above.
tensor

tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

Split the tensor along dimension 0 into a tuple of 1-D tensors, one per row

In [24]:
# Remove dimension 0, yielding one row tensor per element of that dimension.
slices = torch.unbind(tensor, dim=0)

In [25]:
# unbind returns a tuple (not a list) of tensors.
type(slices)

tuple

In [26]:
# Print each row tensor on its own line.
print(*slices, sep="\n")

tensor([1, 2, 3])
tensor([4, 5, 6])
tensor([7, 8, 9])


##### Example 2

In [28]:
from transformer_lens import HookedTransformer

In [29]:
# Load GPT-2 ("gpt2") into a TransformerLens HookedTransformer; weights may be
# downloaded from the Hugging Face Hub on first use.
model = HookedTransformer.from_pretrained("gpt2")

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer


In [30]:
# Prompt that is tokenized and run through the model below.
text = "Persistence is all you need."

In [31]:
# Tokenize the prompt into a batch of token ids. The displayed tokens below start
# with 50256 (<|endoftext|>), which TransformerLens prepends as a BOS token.
tokens = model.to_tokens(text)

In [41]:
# Confirm the class of the loaded model.
type(model)

transformer_lens.HookedTransformer.HookedTransformer

In [42]:
# Shape (1, 8): a batch of one sequence of 8 token ids.
tokens

tensor([[50256, 30946, 13274,   318,   477,   345,   761,    13]])

Run the model and cache only the activations whose hook name ends with "mlp_out" (one MLP-output hook per block)

In [43]:
def is_mlp_out_hook(hook_name: str) -> bool:
    """Return True for hook points whose name ends with "mlp_out"."""
    return hook_name.endswith("mlp_out")


# Cache only the MLP-output activations; the returned logits are discarded.
_, cache = model.run_with_cache(tokens, names_filter=is_mlp_out_hook)

In [44]:
# Only the 12 `blocks.{i}.hook_mlp_out` entries were kept by the names_filter.
cache

ActivationCache with keys ['blocks.0.hook_mlp_out', 'blocks.1.hook_mlp_out', 'blocks.2.hook_mlp_out', 'blocks.3.hook_mlp_out', 'blocks.4.hook_mlp_out', 'blocks.5.hook_mlp_out', 'blocks.6.hook_mlp_out', 'blocks.7.hook_mlp_out', 'blocks.8.hook_mlp_out', 'blocks.9.hook_mlp_out', 'blocks.10.hook_mlp_out', 'blocks.11.hook_mlp_out']

##### Example 3

In [50]:
# Read the layer and head counts from the model config.
n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads

In [62]:
# Both are 12 for gpt2, as shown below.
n_layers, n_heads

(12, 12)

Generate every possible `(head_idx, layer_idx)` pair in the model, i.e. the Cartesian product of head indices and layer indices (the head index varies slowest)

In [63]:
from itertools import product

In [64]:
# Every (head_idx, layer_idx) pair; the head index varies slowest.
combinations = [*product(range(n_heads), range(n_layers))]

In [65]:
# Expect n_heads * n_layers = 12 * 12 = 144 pairs.
len(combinations)

144

In [66]:
# NOTE(review): this displays all 144 pairs; a slice such as combinations[:12]
# would keep the rendered output short.
combinations

[(0, 0),
 (0, 1),
 (0, 2),
 (0, 3),
 (0, 4),
 (0, 5),
 (0, 6),
 (0, 7),
 (0, 8),
 (0, 9),
 (0, 10),
 (0, 11),
 (1, 0),
 (1, 1),
 (1, 2),
 (1, 3),
 (1, 4),
 (1, 5),
 (1, 6),
 (1, 7),
 (1, 8),
 (1, 9),
 (1, 10),
 (1, 11),
 (2, 0),
 (2, 1),
 (2, 2),
 (2, 3),
 (2, 4),
 (2, 5),
 (2, 6),
 (2, 7),
 (2, 8),
 (2, 9),
 (2, 10),
 (2, 11),
 (3, 0),
 (3, 1),
 (3, 2),
 (3, 3),
 (3, 4),
 (3, 5),
 (3, 6),
 (3, 7),
 (3, 8),
 (3, 9),
 (3, 10),
 (3, 11),
 (4, 0),
 (4, 1),
 (4, 2),
 (4, 3),
 (4, 4),
 (4, 5),
 (4, 6),
 (4, 7),
 (4, 8),
 (4, 9),
 (4, 10),
 (4, 11),
 (5, 0),
 (5, 1),
 (5, 2),
 (5, 3),
 (5, 4),
 (5, 5),
 (5, 6),
 (5, 7),
 (5, 8),
 (5, 9),
 (5, 10),
 (5, 11),
 (6, 0),
 (6, 1),
 (6, 2),
 (6, 3),
 (6, 4),
 (6, 5),
 (6, 6),
 (6, 7),
 (6, 8),
 (6, 9),
 (6, 10),
 (6, 11),
 (7, 0),
 (7, 1),
 (7, 2),
 (7, 3),
 (7, 4),
 (7, 5),
 (7, 6),
 (7, 7),
 (7, 8),
 (7, 9),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 2),
 (8, 3),
 (8, 4),
 (8, 5),
 (8, 6),
 (8, 7),
 (8, 8),
 (8, 9),
 (8, 10),
 (8, 11),
 (9, 0),
 