In [None]:
#Set up stuff from notebook
import os
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px
from jaxtyping import Float
from functools import partial
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix

DEVELOPMENT_MODE = False
# Detect if we're running in Google Colab
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
except:
    IN_COLAB = False

# Install if in Colab
if IN_COLAB:
    %pip install transformer_lens
    %pip install circuitsvis
    # Install a faster Node version
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs  # noqa

# Hot reload in development mode & not running on the CD
if not IN_COLAB:
    from IPython import get_ipython
    ip = get_ipython()
    if not ip.extension_manager.loaded:
        ip.extension_manager.load('autoreload')
        %autoreload 2

IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
print("works")

# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh

import plotly.io as pio

pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

import circuitsvis as cv
# Testing that the library works
cv.examples.hello("Neel")

torch.set_grad_enabled(False)

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

device = utils.get_device()

works
Using renderer: notebook_connected


In [221]:
def get_params(model, show_param = False):
    count = 0
    parameters = []
    names = []
    for name, param in model.named_parameters():
        if show_param == True:
            print(name, param.size())
        names.append([name, param.size()])
        parameters.append(param.data)
        count += 1
    print(f"Model has {count} named parameters")
    return parameters, names

def unpack_names(names, first = 0, last = None):
    if last == None:
        for item in names[first:len(names)-1]:
            print(f"Name of param: {item[0]:<30} Dimensions of tensor: {item[1]:<30}")
    else:
        for item in names[first: last]:
            print(f"Name of param: {item[0]:<30} Dimensions of tensor: {str(item[1]):<30}")

def compare(model_1, model_2):
    different = []
    model_1_params, model_1_names = get_params(model_1)
    model_2_params, model_2_names = get_params(model_2)

    if len(model_1_params) != len(model_2_params):
        print("There are a different number of parameters between the models. Please check model config")
        print(f"Model 1 has {len(model_1_params)} parameters. Model 2 has {len(model_2_params)}")

    for i in range(min(len(model_1_params),(len(model_2_params)))):
        if model_1_names[i][1] != model_2_names[i][1]:
            different.append([model_1_names[i], model_2_names[i], ("index", i)])

    print(f"There are {len(different)} parameters that are different from each other")
    return different 


In [200]:
hooked_model = HookedTransformer.from_pretrained("gpt2", device=device)

Loaded pretrained model gpt2 into HookedTransformer


In [201]:
hooked_params, hooked_names = get_params(hooked_model)

Model has 148 named parameters


In [226]:
unpack_names(hooked_names, first = 0, last = 20)

Name of param: embed.W_E                      Dimensions of tensor: torch.Size([50257, 768])      
Name of param: pos_embed.W_pos                Dimensions of tensor: torch.Size([1024, 768])       
Name of param: blocks.0.attn.W_Q              Dimensions of tensor: torch.Size([12, 768, 64])     
Name of param: blocks.0.attn.W_O              Dimensions of tensor: torch.Size([12, 64, 768])     
Name of param: blocks.0.attn.b_Q              Dimensions of tensor: torch.Size([12, 64])          
Name of param: blocks.0.attn.b_O              Dimensions of tensor: torch.Size([768])             
Name of param: blocks.0.attn.W_K              Dimensions of tensor: torch.Size([12, 768, 64])     
Name of param: blocks.0.attn.W_V              Dimensions of tensor: torch.Size([12, 768, 64])     
Name of param: blocks.0.attn.b_K              Dimensions of tensor: torch.Size([12, 64])          
Name of param: blocks.0.attn.b_V              Dimensions of tensor: torch.Size([12, 64])          
Name of pa

In [222]:
hooked_param, hooked_names = get_params(hooked_model)

Model has 148 named parameters


In [254]:
from transformers import GPT2Model

gpt2 = GPT2Model.from_pretrained('gpt2', device_map='auto')


In [255]:
gpt2_params, gpt2_names = get_params(gpt2)

Model has 148 named parameters


In [256]:
gpt2_unpacked = unpack_names(gpt2_names, first = 0, last = 20)

Name of param: wte.weight                     Dimensions of tensor: torch.Size([50257, 768])      
Name of param: wpe.weight                     Dimensions of tensor: torch.Size([1024, 768])       
Name of param: h.0.ln_1.weight                Dimensions of tensor: torch.Size([768])             
Name of param: h.0.ln_1.bias                  Dimensions of tensor: torch.Size([768])             
Name of param: h.0.attn.c_attn.weight         Dimensions of tensor: torch.Size([768, 2304])       
Name of param: h.0.attn.c_attn.bias           Dimensions of tensor: torch.Size([2304])            
Name of param: h.0.attn.c_proj.weight         Dimensions of tensor: torch.Size([768, 768])        
Name of param: h.0.attn.c_proj.bias           Dimensions of tensor: torch.Size([768])             
Name of param: h.0.ln_2.weight                Dimensions of tensor: torch.Size([768])             
Name of param: h.0.ln_2.bias                  Dimensions of tensor: torch.Size([768])             
Name of pa

In [257]:
diff = compare(hooked_model, gpt2)

diff[0:5]

Model has 148 named parameters
Model has 148 named parameters
There are 98 parameters that are different from each other


[[['blocks.0.attn.W_Q', torch.Size([12, 768, 64])],
  ['h.0.ln_1.weight', torch.Size([768])],
  ('index', 2)],
 [['blocks.0.attn.W_O', torch.Size([12, 64, 768])],
  ['h.0.ln_1.bias', torch.Size([768])],
  ('index', 3)],
 [['blocks.0.attn.b_Q', torch.Size([12, 64])],
  ['h.0.attn.c_attn.weight', torch.Size([768, 2304])],
  ('index', 4)],
 [['blocks.0.attn.b_O', torch.Size([768])],
  ['h.0.attn.c_attn.bias', torch.Size([2304])],
  ('index', 5)],
 [['blocks.0.attn.W_K', torch.Size([12, 768, 64])],
  ['h.0.attn.c_proj.weight', torch.Size([768, 768])],
  ('index', 6)]]

In [258]:
hooked_model.blocks[0].attn.W_Q.data

tensor([[[-0.4738, -0.2614, -0.0978,  ...,  0.0908,  0.2785,  0.2262],
         [-0.0604,  0.0430, -0.1627,  ..., -0.1296, -0.1096, -0.1044],
         [ 0.0300,  0.1680,  0.2397,  ...,  0.1302,  0.4627, -0.1095],
         ...,
         [-0.0008,  0.0477,  0.1314,  ...,  0.1046, -0.0898, -0.0920],
         [-0.1142, -0.0996,  0.0593,  ..., -0.1520, -0.1710,  0.1016],
         [-0.1236,  0.1004, -0.1369,  ...,  0.1196, -0.0738, -0.1380]],

        [[-0.1124,  0.1727, -0.0123,  ...,  0.0090, -0.0496,  0.0578],
         [ 0.0471, -0.2233,  0.0557,  ...,  0.1146,  0.4690,  0.1260],
         [-0.0707, -0.0334, -0.0715,  ..., -0.0823, -0.0244,  0.0656],
         ...,
         [ 0.2591, -0.0887,  0.2162,  ..., -0.0083,  0.1477, -0.0473],
         [ 0.1110,  0.0184,  0.0524,  ...,  0.0787, -0.0524,  0.0030],
         [ 0.1317,  0.0518,  0.0714,  ..., -0.2861,  0.4004, -0.2124]],

        [[ 0.1223, -0.1317,  0.0644,  ...,  0.3493,  0.0223,  0.0512],
         [ 0.2970, -0.3978, -0.0329,  ..., -0

In [259]:
replacement_weights = gpt2.h[0].attn.c_attn.weight 

In [260]:
replacement_weights

Parameter containing:
tensor([[-0.4738, -0.2614, -0.0978,  ...,  0.0513, -0.0584,  0.0250],
        [ 0.0874,  0.1473,  0.2387,  ..., -0.0525, -0.0113, -0.0156],
        [ 0.0039,  0.0695,  0.3668,  ...,  0.1143,  0.0363, -0.0318],
        ...,
        [-0.2592, -0.0164,  0.1991,  ...,  0.0095, -0.0516,  0.0319],
        [ 0.1517,  0.2170,  0.1043,  ...,  0.0293, -0.0429, -0.0475],
        [-0.4100, -0.1924, -0.2400,  ..., -0.0046,  0.0070,  0.0198]],
       device='cuda:0', requires_grad=True)

In [261]:
replacement_weights.shape

W_Q, W_K, W_V = replacement_weights.split(768, dim = 1)

print(W_Q.shape, W_K.shape, W_V.shape)

torch.Size([768, 768]) torch.Size([768, 768]) torch.Size([768, 768])


In [262]:
new_wq = W_Q.reshape(12, 768, 64)
new_wk = W_K.reshape(12, 768, 64)
new_wv = W_V.reshape(12, 768, 64)

In [263]:
hooked_model.blocks[0].attn.W_Q.data = new_wq
hooked_model.blocks[0].attn.W_K.data = new_wk
hooked_model.blocks[0].attn.W_V.data = new_wv

In [264]:
hooked_model.blocks[0].attn.W_Q.data

tensor([[[-0.4738, -0.2614, -0.0978,  ...,  0.0908,  0.2785,  0.2262],
         [-0.0604,  0.0430, -0.1627,  ..., -0.1296, -0.1096, -0.1044],
         [ 0.0300,  0.1680,  0.2397,  ...,  0.1302,  0.4627, -0.1095],
         ...,
         [-0.0008,  0.0477,  0.1314,  ...,  0.1046, -0.0898, -0.0920],
         [-0.1142, -0.0996,  0.0593,  ..., -0.1520, -0.1710,  0.1016],
         [-0.1236,  0.1004, -0.1369,  ...,  0.1196, -0.0738, -0.1380]],

        [[-0.1124,  0.1727, -0.0123,  ...,  0.0090, -0.0496,  0.0578],
         [ 0.0471, -0.2233,  0.0557,  ...,  0.1146,  0.4690,  0.1260],
         [-0.0707, -0.0334, -0.0715,  ..., -0.0823, -0.0244,  0.0656],
         ...,
         [ 0.2591, -0.0887,  0.2162,  ..., -0.0083,  0.1477, -0.0473],
         [ 0.1110,  0.0184,  0.0524,  ...,  0.0787, -0.0524,  0.0030],
         [ 0.1317,  0.0518,  0.0714,  ..., -0.2861,  0.4004, -0.2124]],

        [[ 0.1223, -0.1317,  0.0644,  ...,  0.3493,  0.0223,  0.0512],
         [ 0.2970, -0.3978, -0.0329,  ..., -0

In [265]:
model_description_text = """## Loading Models

HookedTransformer comes loaded with >40 open source GPT-style models. You can load any of them in with `HookedTransformer.from_pretrained(MODEL_NAME)`. See my explainer for documentation of all supported models, and this table for hyper-parameters and the name used to load them. Each model is loaded into the consistent HookedTransformer architecture, designed to be clean, consistent and interpretability-friendly.

For this demo notebook we'll look at GPT-2 Small, an 80M parameter model. To try the model the model out, let's find the loss on this paragraph!"""
loss = hooked_model(model_description_text, return_type="loss")
print("Model loss:", loss)

Model loss: tensor(9.6703, device='cuda:0')


In [267]:
gpt2_text = "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets."
gpt2_tokens = hooked_model.to_tokens(gpt2_text)
print(gpt2_tokens.device)
gpt2_logits, gpt2_cache = hooked_model.run_with_cache(gpt2_tokens, remove_batch_dim=True)

cuda:0


In [268]:
print(type(gpt2_cache))
attention_pattern = gpt2_cache["pattern", 0, "attn"]
print(attention_pattern.shape)
gpt2_str_tokens = hooked_model.to_str_tokens(gpt2_text)

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([12, 33, 33])


In [269]:
print("Layer 0 Head Attention Patterns:")
cv.attention.attention_patterns(tokens=gpt2_str_tokens, attention=attention_pattern)

Layer 0 Head Attention Patterns:
