In [2]:
from functools import partial
from typing import List, Optional, Union

import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
from jaxtyping import Float

import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer

# Switch autograd off globally; the cells visible below only run forward passes
# and read cached activations, so no gradients are needed.
torch.set_grad_enabled(False)
print("Disabled automatic differentiation")

def imshow(tensor, **kwargs):
    """Display ``tensor`` as a plotly heatmap.

    The tensor is passed through ``utils.to_numpy`` and handed to
    ``px.imshow``. By default the colour scale is the diverging "RdBu" scale
    centred on zero.

    Args:
        tensor: The data to plot; anything ``utils.to_numpy`` accepts.
        **kwargs: Extra keyword arguments forwarded to ``px.imshow`` (for
            example ``title`` or ``x``). ``color_continuous_midpoint`` and
            ``color_continuous_scale`` may be supplied here to override the
            defaults.

    Returns:
        None. The figure is shown as a side effect.
    """
    # Merge instead of passing the defaults as explicit keywords next to
    # **kwargs: that form raises "got multiple values for keyword argument"
    # as soon as a caller wants a different colour scale or midpoint.
    options = {
        "color_continuous_midpoint": 0.0,
        "color_continuous_scale": "RdBu",
        **kwargs,
    }
    px.imshow(utils.to_numpy(tensor), **options).show()


def line(tensor, **kwargs):
    """Display ``tensor`` as a plotly line chart.

    Args:
        tensor: The y-values to plot; passed through ``utils.to_numpy``.
        **kwargs: Extra keyword arguments forwarded to ``px.line`` (for
            example ``x`` or ``title``).

    Returns:
        None. The figure is shown as a side effect.
    """
    y_values = utils.to_numpy(tensor)
    figure = px.line(y=y_values, **kwargs)
    figure.show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    """Display a plotly scatter plot of ``y`` against ``x``.

    Args:
        x: Horizontal coordinates; passed through ``utils.to_numpy``.
        y: Vertical coordinates; passed through ``utils.to_numpy``.
        xaxis: Label used for the "x" entry of the ``labels`` mapping.
        yaxis: Label used for the "y" entry of the ``labels`` mapping.
        caxis: Label used for the "color" entry of the ``labels`` mapping.
        **kwargs: Extra keyword arguments forwarded to ``px.scatter``.

    Returns:
        None. The figure is shown as a side effect.
    """
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    figure = px.scatter(x=x_values, y=y_values, labels=axis_labels, **kwargs)
    figure.show()

  from .autonotebook import tqdm as notebook_tqdm


Disabled automatic differentiation


In [3]:
# Load the pretrained "gpt2-small" checkpoint into a HookedTransformer; every
# later cell reads weights and activations through this `model` object.
model = HookedTransformer.from_pretrained("gpt2-small")

# Get the default device used
# NOTE(review): `device` is not referenced by any later cell visible here —
# confirm whether it is still needed.
device: torch.device = utils.get_device()

Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
# Sanity check that the model knows the fact being probed. utils.test_prompt
# prints the tokenized prompt and answer followed by the ten highest-logit next
# tokens (see the output below).
# The answer is written without a leading space, yet the printed tokenization
# shows it being scored as ' France' — the space-prefixed token.
prompt = "Paris is the capital of"
answer = "France"
utils.test_prompt(prompt, answer, model)

Tokenized prompt: ['<|endoftext|>', 'Paris', ' is', ' the', ' capital', ' of']
Tokenized answer: [' France']


Top 0th token. Logit: 16.93 Prob: 33.42% Token: | France|
Top 1th token. Logit: 16.61 Prob: 24.46% Token: | the|
Top 2th token. Logit: 15.30 Prob:  6.59% Token: | a|
Top 3th token. Logit: 14.68 Prob:  3.52% Token: | Europe|
Top 4th token. Logit: 13.98 Prob:  1.75% Token: | French|
Top 5th token. Logit: 13.93 Prob:  1.67% Token: | Paris|
Top 6th token. Logit: 13.78 Prob:  1.43% Token: | Belgium|
Top 7th token. Logit: 13.75 Prob:  1.39% Token: | one|
Top 8th token. Logit: 13.22 Prob:  0.82% Token: | an|
Top 9th token. Logit: 12.86 Prob:  0.57% Token: | Germany|


In [5]:
# Run the single prompt once and keep both the logits and the activation cache.
# remove_batch_dim=True is requested because only one prompt is passed.
# NOTE: `logits` and `cache` are reassigned by the later cell that runs the
# batched `prompts`, so these values are only current until that cell executes.
logits, cache = model.run_with_cache(prompt, remove_batch_dim= True)

In [6]:
def plot_top_tokens(idx, k=10):
    """Plot the vocabulary tokens that a vector scores highest under the unembedding.

    ``idx`` is multiplied by ``model.W_U`` and the ``k`` largest entries of
    the result are shown as a one-row heatmap labelled with their string
    tokens.

    Args:
        idx: A 1-D tensor that can be right-multiplied by ``model.W_U``.
            Despite the name it is used as a vector, not as an index
            (presumably a residual-stream direction — confirm against callers).
        k: How many top-scoring tokens to show. Defaults to 10, the value that
            was previously hard-coded; clamped to the vocabulary size.

    Returns:
        Whatever ``imshow`` returns (currently ``None``); the figure is shown
        as a side effect.

    Raises:
        ValueError: If ``idx @ model.W_U`` is not 1-D. Slicing the first ten
            rows of a batched result would not select top tokens.
    """
    attrs = idx @ model.W_U
    if attrs.ndim != 1:
        raise ValueError(
            f"plot_top_tokens expects a single vector; idx @ W_U has shape {tuple(attrs.shape)}"
        )

    # topk returns values and indices already sorted in descending order, so a
    # full-vocabulary argsort is unnecessary. Clamp k so a small vocabulary
    # behaves like the old `[:10]` slice instead of raising.
    k = min(k, attrs.shape[0])
    top_values, top_indices = torch.topk(attrs, k)

    return imshow(
        top_values.unsqueeze(0),
        x=model.to_str_tokens(top_indices),
        title="Top tokens",
    )

In [33]:
# The city names are the single source of truth: every prompt is generated from
# this list, so adding or changing a city cannot leave `prompts` and `cities`
# out of sync (they were previously hand-copied into three separate literals).
cities = ["Delhi", "Mumbai", "Chennai", "Bangalore", "Bombay"]

PROMPT_TEMPLATE = "The city {city} belongs to the country of"
prompts = [PROMPT_TEMPLATE.format(city=city) for city in cities]

# Shared expected completion for every prompt. Stored without a leading space.
answer = 'India'

# The wrong answers are the city names themselves; kept as an independent list
# object so mutating one does not affect the other.
wrong_answers = list(cities)

In [63]:
# Run all prompts as one batch; this replaces the single-prompt `logits` and
# `cache` from the earlier cell.
# NOTE(review): later cells read position -1 for every prompt, which is only
# the last real token if all prompts tokenize to the same length — confirm no
# padding is involved when changing the prompt list.
logits,cache = model.run_with_cache(prompts)

In [64]:
# The prompts end in "... of", so the token the model has to predict is the
# space-prefixed " India", not the bare string "India" (compare the
# `Tokenized answer: [' France']` line printed by utils.test_prompt earlier in
# this notebook). Tokenizing `answer` as-is would select the unembedding column
# of a different token. Prepend the space unless it is already there.
answer_token = model.to_single_token(answer if answer.startswith(" ") else " " + answer)

# Column of the unembedding matrix for that token, i.e. the direction whose dot
# product with a residual vector gives that token's contribution.
answer_token_direction = model.W_U[:, answer_token]

In [66]:
# Pull the accumulated residual stream out of the batched cache, together with
# the labels that are later used as the x-axis of the line plot.
# The printed shape is [13, 5, 9, 768]; the 5 matches the number of prompts.
# NOTE(review): presumably (component, prompt, position, d_model) — confirm.
resid_accumlated, labels = cache.accumulated_resid(return_labels=True)
resid_accumlated.shape

torch.Size([13, 5, 9, 768])

In [69]:
# Dot every residual vector with the answer direction. The matmul contracts the
# trailing 768-sized axis, leaving one scalar per remaining index (printed
# shape: [13, 5, 9]).
resid_accumlated_in_answer_direction = resid_accumlated@answer_token_direction
resid_accumlated_in_answer_direction.shape

torch.Size([13, 5, 9])

In [68]:
# Average over axis 1, the size-5 axis in the shape printed above (presumably
# the prompt axis — confirm), leaving a [13, 9] tensor.
resid_accumlated_in_answer_direction_averaged = resid_accumlated_in_answer_direction.mean(dim=1)

In [70]:
# Plot the value at the last position (index -1 of the trailing axis) for each
# of the 13 accumulated-residual entries, using `labels` as the x-axis.
line(resid_accumlated_in_answer_direction_averaged[:,-1], x = labels, title = "Residuals in answer direction")