In [None]:
IN_COLAB = False
DEVELOPMENT_MODE = False

In [None]:
import plotly.io as pio
# The Colab renderer is used on Colab and for every non-development run; only
# a local development session gets the "notebook_connected" renderer.
pio.renderers.default = (
    "colab" if (IN_COLAB or not DEVELOPMENT_MODE) else "notebook_connected"
)
print(f"Using renderer: {pio.renderers.default}")

Using renderer: colab


In [None]:
import circuitsvis as cv
# Testing that the library works
cv.examples.hello("Neel")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [None]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled>

In [None]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` as a heatmap using a diverging colour scale.

    Args:
        tensor: Values to plot; converted with `utils.to_numpy` first.
        renderer: Renderer name forwarded to the figure's `show`; `None`
            leaves the choice to plotly's configured default.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Extra keyword arguments forwarded to `px.imshow`. They may
            override the defaults applied here (`color_continuous_midpoint`
            of 0.0 and the "RdBu" colour scale); a `labels` dict is merged
            over the axis labels rather than replacing them.
    """
    # Merge caller labels over the axis labels. Passing `labels` both
    # explicitly and through **kwargs would raise a duplicate-keyword TypeError.
    extra_labels = kwargs.pop("labels", None) or {}
    labels = {"x": xaxis, "y": yaxis, **extra_labels}
    # setdefault keeps the original defaults while letting callers override them.
    kwargs.setdefault("color_continuous_midpoint", 0.0)
    kwargs.setdefault("color_continuous_scale", "RdBu")
    fig = px.imshow(utils.to_numpy(tensor), labels=labels, **kwargs)
    fig.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Plot `tensor` as a line chart and show it with `renderer`.

    `tensor` is converted with `utils.to_numpy`; `xaxis` and `yaxis` are passed
    to `px.line` as the "x" and "y" entries of `labels`, and any remaining
    keyword arguments are forwarded unchanged.
    """
    values = utils.to_numpy(tensor)
    axis_labels = {"x": xaxis, "y": yaxis}
    figure = px.line(values, labels=axis_labels, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Scatter-plot `y` against `x` and show the figure with `renderer`.

    Both inputs are converted with `utils.to_numpy`. `xaxis`, `yaxis` and
    `caxis` are passed to `px.scatter` as the "x", "y" and "color" entries of
    `labels`; remaining keyword arguments are forwarded unchanged.
    """
    x_values, y_values = utils.to_numpy(x), utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(x=x_values, y=y_values, labels=axis_labels, **kwargs)
    figure.show(renderer)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


##### Example 1

In [None]:
prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"

In [None]:
correct_answer = " John"
incorrect_answer = " Mary"

In [None]:
type(model)

transformer_lens.HookedTransformer.HookedTransformer

In [None]:
prompt

'After John and Mary went to the store, Mary gave a bottle of milk to'

In [None]:
correct_answer, incorrect_answer

(' John', ' Mary')

Find the logit difference between `correct_answer` and `incorrect_answer`

In [None]:
tokens = model.to_tokens(prompt)

In [None]:
def logits_to_logit_diff(logits, correct_answer, incorrect_answer):
    """Return the logit of `correct_answer` minus that of `incorrect_answer`.

    Each answer string is mapped to a token id with the global
    `model.to_single_token`. Only batch element 0 and the last sequence
    position of `logits` are read.
    """
    answer_ids = [
        model.to_single_token(answer)
        for answer in (correct_answer, incorrect_answer)
    ]
    final_logits = logits[0, -1]
    return final_logits[answer_ids[0]] - final_logits[answer_ids[1]]

In [None]:
logits = model(tokens)

In [None]:
diff = logits_to_logit_diff(logits, correct_answer, incorrect_answer)

In [None]:
diff

tensor(4.2765)

##### Example 1.1

In [None]:
tokens = model.to_tokens("Persistence is all you need")

In [None]:
hook_name = "blocks.1.attn.hook_v"

In [None]:
tokens

tensor([[50256, 30946, 13274,   318,   477,   345,   761]])

In [None]:
type(model)

transformer_lens.HookedTransformer.HookedTransformer

`hook_name` is the name of the hook point

In [None]:
hook_name

'blocks.1.attn.hook_v'

Print the shape of activation at hook point `hook_name`

**Hints**:
- A hook function takes `activation` and `hook` as argument
- Ignore the type annotations

In [None]:
def hook_func(
    activation: Float[torch.Tensor, "batch seq_len head_index d_head"],
    hook: HookPoint
) -> Float[torch.Tensor, "batch seq_len head_index d_head"]:
    """Print the shape of the activation passed to this hook.

    `activation` is only inspected, never modified, and `hook` is unused.
    The function returns `None` implicitly, so the declared return annotation
    does not describe what is actually returned.
    """
    print(f"Shape of the activation tensor: {activation.shape}")

In [None]:
logits = model.run_with_hooks(
    tokens,
    fwd_hooks=[(
        hook_name,
        hook_func
    )]
)

Shape of the activation tensor: torch.Size([1, 7, 12, 64])


##### Example 1.2

In [None]:
original_loss = model.run_with_hooks(
    tokens,
    return_type="loss"
)

In [None]:
type(model)

transformer_lens.HookedTransformer.HookedTransformer

In [None]:
tokens

tensor([[50256, 30946, 13274,   318,   477,   345,   761]])

`original_loss` is the loss of `tokens` without any changes made to the activation

In [None]:
original_loss, hook_name

(tensor(4.0712), 'blocks.1.attn.hook_v')

Set the activation of **head_index 4 to zero** at hook point `hook_name` and return the loss

**Hints**:
- The shape of activation is `(batch_size, seq_len, head_index, d_head)`
- Ignore the type annotations

In [None]:
def hook_func(
    activation: Float[torch.Tensor, "batch seq_len head_index d_head"],
    hook: HookPoint
) -> Float[torch.Tensor, "batch seq_len head_index d_head"]:
    """Zero out head index 4 of the hooked activation.

    Args:
        activation: Activation at the hook point; axis 2 is indexed as the
            head axis. It is modified in place.
        hook: The hook point this function is attached to (unused).

    Returns:
        The same `activation` tensor with every value for head 4 set to 0.
    """
    activation[:, :, 4, :] = 0.
    # Return the tensor explicitly so the function matches its annotation and
    # does not rely solely on the in-place write taking effect.
    return activation

In [None]:
loss = model.run_with_hooks(
    tokens,
    return_type="loss",
    fwd_hooks=[(hook_name, hook_func)]
)

In [None]:
loss

tensor(4.0319)

##### Example 2

In [None]:
import torch.nn.functional as F

In [None]:
# Build the corrupted prompt in this cell so it does not depend on a later cell
# having been run first; on a fresh kernel `corrupted_tokens` would otherwise be
# undefined here. The same prompt string is assigned again further down.
corrupted_tokens = model.to_tokens("**Persistence **is **all **you **need")
logits = model(corrupted_tokens)

In [None]:
# hook_name = "blocks.3.attn.hook_attn_scores"
# hook_name = "blocks.3.hook_mlp_out"
hook_name = "blocks.1.hook_resid_pre"

In [None]:
clean_tokens = model.to_tokens("Persistence is all you need")

In [None]:
corrupted_tokens = model.to_tokens("**Persistence **is **all **you **need")

In [None]:
clean_logits, clean_acts = model.run_with_cache(clean_tokens)

In [None]:
len(clean_tokens[0]), len(corrupted_tokens[0])

In [None]:
hook_name, type(clean_acts), type(model)

`clean_acts` is the activations of the `model` when given `clean_tokens` as input

In [None]:
clean_acts[hook_name]

Perform activation patching at `hook_name`: while running the model on `corrupted_tokens`,
replace the activation at position 1 with the corresponding activation cached from the clean run (`clean_acts`)

**Hint**:
- The activation has shape `(batch_size, seq_len, d_model)`
- Ignore the type annotations

In [None]:
def act_patching_hook(
    acts: Float[torch.Tensor, "batch seq_len d_model"],
    hook: HookPoint,
):
    """Overwrite sequence position 1 of `acts` with the cached clean activation.

    The replacement values come from the global `clean_acts`, looked up by
    `hook.name`. `acts` is modified in place and also returned.
    """
    patch_position = 1
    acts[:, patch_position, :] = clean_acts[hook.name][:, patch_position, :]
    return acts

In [None]:
patched_logits = model.run_with_hooks(
    corrupted_tokens, fwd_hooks=[(hook_name, act_patching_hook)]
)

Compare the logits of `corrupted_tokens` before and after applying the activation patch at `hook_name`

In [None]:
# Summarise the effect of the patch as a single number (the largest absolute
# change in any logit) instead of displaying an element-wise boolean tensor
# over every position and vocabulary entry. Exact float equality is also a
# brittle way to compare logits. A result of 0 means the patch changed nothing.
(patched_logits - logits).abs().max()

In [None]:
n_positions = len(clean_tokens[0])

In [None]:
# TODO: unfinished sweep over every (layer, position) pair. The loop body was
# never written, which made this cell a SyntaxError.
# NOTE(review): `n_positions` is taken from the clean prompt; the corrupted
# prompt may tokenize to a different length, so confirm positions line up
# before patching per position.
for layer in tqdm.tqdm(range(model.cfg.n_layers)):
    for position in range(n_positions):
        pass  # placeholder so the cell parses; replace with the patching step

##### Example 3

In [None]:
x = torch.arange(0, 2*3*4*4).view(2, 3, 4, 4)

In [None]:
x

In [None]:
x.shape

Extract the diagonal as shown below and explain the code

In [None]:
diagonal = x.diagonal(dim1=-2, dim2=-1)

**Explain**
- `dim1=-2` means the rows of the diagonal are specified by the second-to-last dimension of `x`.
- `dim2=-1` means the columns of the diagonal are specified by the last dimension of `x`.

In summary, it means that you want to extract the diagonals along the last two dimensions of the tensor (the second-to-last dimension for the rows and the last dimension for the columns).

In [None]:
diagonal.shape

In [None]:
diagonal

##### Example 4 

In [None]:
attn_weights = torch.arange(0, 2*3*4*4).view(2, 3, 4, 4)

In [None]:
mask = torch.tril(torch.ones_like(attn_weights))

In [None]:
attn_weights = attn_weights * mask

In [None]:
attn_weights

The `attn_weights` tensor contains `2` batches, `3` attention heads, and `4` tokens within each sequence

In [None]:
attn_weights.shape

Extract the attention weights of each token looking back at the previous token. And explain the code.

In [None]:
weights = attn_weights.diagonal(dim1=-2, dim2=-1, offset=-1)

**Explain**
- `dim1=-2`: This means you want to consider the second-to-last dimension of the `attn_weights` tensor (which represents the destination positions) as the first dimension (rows) of the resulting diagonal tensor.
- `dim2=-1`: This means you want to consider the last dimension of the `attn_weights` tensor (which represents the source positions) as the second dimension (columns) of the resulting diagonal tensor.
- `offset=-1`: This means that you want to extract the diagonal below the main diagonal (i.e., one step to the left of the main diagonal). In the context of attention weights, this corresponds to the attention paid from each token to its immediately preceding token (one token look back).

In [None]:
weights.shape

In [None]:
weights

### Factored Matrix Class


In [None]:
A = torch.randn(5, 2)

In [None]:
B = torch.randn(2, 5)

In [None]:
AB = A @ B

In [None]:
AB_factor = FactoredMatrix(A, B)

In [None]:
AB_factor.norm()

AB.norm()

### Hook Points


##### Example 1

In [None]:
x = torch.randn(42)

In [None]:
model = nn.Sequential(
    nn.Linear(42, 69),
    nn.ReLU(),
    nn.Linear(69, 42)
)

In [None]:
model

Create a `Model` that is similar to `model` and records all intermediate activations

In [None]:
from transformer_lens.hook_points import HookedRootModule, HookPoint

In [None]:
class Model(HookedRootModule):
    """Small MLP (42 -> 69 -> 42) with a HookPoint after each layer.

    The layers live in the `net` Sequential; the HookPoints are its entries at
    indices 1, 3 and 5, directly after the first Linear, the ReLU and the
    second Linear respectively.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(42, 69),
            HookPoint(),
            nn.ReLU(),
            HookPoint(),
            nn.Linear(69, 42),
            HookPoint(),
        ]
        self.net = nn.Sequential(*layers)
        # Called once every submodule has been assigned.
        super().setup()

    def forward(self, x):
        """Pass `x` through `net` and return the result."""
        return self.net(x)

In [None]:
# NOTE(review): this rebinds `model`, replacing the HookedTransformer loaded
# earlier in the notebook. Later cells call `model.to_tokens(...)`, which needs
# the transformer, so a top-to-bottom run presumably fails there. Confirm, and
# consider a distinct name for this toy model.
model = Model()

In [None]:
x.shape

In [None]:
output, acts = model.run_with_cache(x)

In [None]:
for name in acts:
    print(f"activation name: {name}, shape {acts[name].shape}")

In [None]:
output.shape, acts.keys()

In [None]:
acts["net.1"]

##### Example 2

In [None]:
batch_size = 10

In [None]:
seq_len = 50

In [None]:
text = "I like ice cream. I like ice cream. I like ice cream. I like ice cream. I like ice"

In [None]:
import torch.nn.functional as F
from einops import rearrange, reduce

In [None]:
type(model)

transformer_lens.HookedTransformer.HookedTransformer

In [None]:
text

'I like ice cream. I like ice cream. I like ice cream. I like ice cream. I like ice'

Calculate the induction loss for each position in `text` from scratch

In [None]:
tokens = model.to_tokens(text)

In [None]:
repeated_logits = model(tokens)

In [None]:
log_probs = F.log_softmax(repeated_logits, dim=-1)

In [None]:
log_probs.shape

torch.Size([1, 24, 50257])

In [None]:
last_token_logits = log_probs[:, -1, :]

In [None]:
target_tokens = tokens[:, 1:]

In [None]:
target_tokens = rearrange(target_tokens, "... -> ... 1")

In [None]:
target_tokens.shape

torch.Size([1, 23, 1])

In [None]:
# Position i predicts token i + 1, so the final position has no target. Drop it
# explicitly before looking up each target token's log-probability. The sign
# flip makes this the per-position negative log-likelihood (the loss).
predicted_log_probs = -log_probs[:, :-1].gather(dim=-1, index=target_tokens)

In [None]:
predicted_log_probs.shape

torch.Size([1, 23, 1])

In [None]:
line(rearrange(predicted_log_probs, "1 ... -> ..."), xaxis="Position", yaxis="Loss", title="Loss by position on repeated tokens")

##### Example 3

In [None]:
from circuitsvis.tokens import colored_tokens

In [None]:
colored_tokens(["My", "tokens"], [0.123, -0.226])