In [None]:
#Set up stuff from notebook
import os
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px
from jaxtyping import Float
from functools import partial
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
from transformers import AutoTokenizer

DEVELOPMENT_MODE = False
# Detect if we're running in Google Colab
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
except:
    IN_COLAB = False

# Install if in Colab
if IN_COLAB:
    %pip install transformer_lens
    %pip install circuitsvis
    # Install a faster Node version
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs  # noqa

# Hot reload in development mode & not running on the CD
if not IN_COLAB:
    from IPython import get_ipython
    ip = get_ipython()
    if not ip.extension_manager.loaded:
        ip.extension_manager.load('autoreload')
        %autoreload 2

IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
print("works")

# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh

import plotly.io as pio

pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

import circuitsvis as cv
# Testing that the library works
cv.examples.hello("Neel")

torch.set_grad_enabled(False)

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

device = utils.get_device()

works
Using renderer: notebook_connected


In [3]:
def get_params(model, show_param = False):
    count = 0
    parameters = []
    names = []
    for name, param in model.named_parameters():
        if show_param == True:
            print(name, param.size())
        names.append([name, param.size()])
        parameters.append(param.data)
        count += 1
    print(f"Model has {count} named parameters")
    return parameters, names


def unpack_names(names, first = 0, last = None):
    if last == None:
        for item in names[first:len(names)-1]:
            print(f"Param: {item[0]:<40} Dimensions of tensor: {item[1]:<30}")
    else:
        for item in names[first: last]:
            print(f"Param: {item[0]:<40} Dimensions of tensor: {str(item[1]):<30}")


def compare(model_1, model_2):
    different = []
    model_1_params, model_1_names = get_params(model_1)
    model_2_params, model_2_names = get_params(model_2)

    if len(model_1_params) != len(model_2_params):
        print("There are a different number of parameters between the models. Please check model config")
        print(f"Model 1 has {len(model_1_params)} parameters. Model 2 has {len(model_2_params)}")

    for i in range(min(len(model_1_params),(len(model_2_params)))):
        if model_1_names[i][1] != model_2_names[i][1]:
            different.append([model_1_names[i], model_2_names[i], ("index", i)])

    print(f"There are {len(different)} parameters that are different from each other")
    return different 

In [35]:
from transformers import GPTJForCausalLM        
model_path = r"C:\Users\allan\ResearchStuff\checkpoint-1953"

gptj = GPTJForCausalLM.from_pretrained(model_path)

gptj_params, gptj_names = get_params(gptj)

Model has 125 named parameters


In [None]:
config = {
    "d_model" : 768,
    "d_head" : 64,
    "n_layers" : 12,
    "n_ctx" : 1024,
    "n_heads" : 12,
    "d_mlp" : 3072, 
    "d_vocab" : 50257,
    "act_fn" : "gelu_new", 
    "eps" : 1e-05,
    "normalization_type" : "LN", 
    "positional_embedding_type" : "rotary",
    "rotary_dim": 64,
    "post_embedding_ln": False, 
    "original_architecture" : "gptj",
    "use_normalization_before_and_after" : False
}


hooked = HookedTransformer(config, tokenizer = AutoTokenizer.from_pretrained(model_path))
hooked_params, hooked_names = get_params(hooked)

Model has 197 named parameters


In [6]:
def map_attn(model1 = hooked, model2 = gptj, num_blocks = 12): 
    for i in range(num_blocks):
        # Access state dict
        gptj_state_dict = gptj.state_dict()
        hooked_state_dict = hooked.state_dict()

        gptj_query_weight = f"transformer.h.{i}.attn.q_proj.weight"
        gptj_value_weight = f"transformer.h.{i}.attn.v_proj.weight"
        gptj_key_weight = f"transformer.h.{i}.attn.k_proj.weight"
        gptj_output_weight = f"transformer.h.{i}.attn.out_proj.weight"

        hooked_query_weight = f"blocks.{i}.attn.W_Q"
        hooked_value_weight = f"blocks.{i}.attn.W_V"
        hooked_key_weight = f"blocks.{i}.attn.W_K"
        hooked_output_weight = f"blocks.{i}.attn.W_O"

        # Extract weights
        query_weight_tensor = gptj_state_dict[gptj_query_weight]
        value_weight_tensor = gptj_state_dict[gptj_value_weight]
        key_weight_tensor = gptj_state_dict[gptj_key_weight]
        output_weight_tensor = gptj_state_dict[gptj_output_weight]

        # Reshaping the query, key, and value weights
        new_shape_query = query_weight_tensor.size()[:-1] + (12, 64) 
        query_weight_tensor = query_weight_tensor.view(new_shape_query)

        new_shape_value = value_weight_tensor.size()[:-1] + (12, 64)  
        value_weight_tensor = value_weight_tensor.view(new_shape_value)

        new_shape_key = key_weight_tensor.size()[:-1] + (12, 64)  
        key_weight_tensor = key_weight_tensor.view(new_shape_key)

        # Output reshaping:
        new_shape_output = (12, 64, 768)  # [768, 768] -> [12, 64, 768]
        output_weight_tensor = output_weight_tensor.view(new_shape_output)

        # Loop through each head 
        for j in range(12):  
            hooked_state_dict[hooked_query_weight][j,:,:] = query_weight_tensor[:,j,:]
            hooked_state_dict[hooked_value_weight][j,:,:] = value_weight_tensor[:,j,:]
            hooked_state_dict[hooked_key_weight][j,:,:] = key_weight_tensor[:,j,:]

            # Output has slightly different dimensions
            hooked_state_dict[hooked_output_weight][j,:,:] = output_weight_tensor[j,:,:]

        hooked.load_state_dict(hooked_state_dict)

        print(f"Successfully mapped attention weights for block {i} from GPT-J to Hooked Transformer.")


def set_bias_zero(model1 = hooked, model2 = gptj, num_blocks = 12):
    hooked_state_dict = hooked.state_dict()

    for i in range(12):
        hooked_query_bias = f"blocks.{i}.attn.b_Q"
        hooked_value_bias = f"blocks.{i}.attn.b_V"
        hooked_key_bias = f"blocks.{i}.attn.b_K"
        hooked_output_bias = f"blocks.{i}.attn.b_O"

        hooked_state_dict[hooked_query_bias][i, :] = torch.zeros_like(hooked_state_dict[hooked_query_bias][i, :])
        hooked_state_dict[hooked_value_bias][i, :] = torch.zeros_like(hooked_state_dict[hooked_value_bias][i, :])
        hooked_state_dict[hooked_key_bias][i, :] = torch.zeros_like(hooked_state_dict[hooked_key_bias][i, :])
        hooked_state_dict[hooked_output_bias][:] = torch.zeros_like(hooked_state_dict[hooked_output_bias])

    hooked.load_state_dict(hooked_state_dict)

def map_mlp(model1 = hooked, model2 = gptj, num_blocks = 12):
    for i in range(num_blocks):
        gptj_state_dict = gptj.state_dict()
        hooked_state_dict = hooked.state_dict()

        gptj_mlp_fc_in_weight = f"transformer.h.{i}.mlp.fc_in.weight"
        gptj_mlp_fc_in_bias = f"transformer.h.{i}.mlp.fc_in.bias"
        gptj_mlp_fc_out_weight = f"transformer.h.{i}.mlp.fc_out.weight"
        gptj_mlp_fc_out_bias = f"transformer.h.{i}.mlp.fc_out.bias"

        hooked_mlp_W_in = f"blocks.{i}.mlp.W_in"
        hooked_mlp_b_in = f"blocks.{i}.mlp.b_in"
        hooked_mlp_W_out = f"blocks.{i}.mlp.W_out"
        hooked_mlp_b_out = f"blocks.{i}.mlp.b_out"

        mlp_W_in_tensor = gptj_state_dict[gptj_mlp_fc_in_weight]
        mlp_b_in_tensor = gptj_state_dict[gptj_mlp_fc_in_bias]
        mlp_W_out_tensor = gptj_state_dict[gptj_mlp_fc_out_weight]
        mlp_b_out_tensor = gptj_state_dict[gptj_mlp_fc_out_bias]

        mlp_W_in_tensor = mlp_W_in_tensor.T  
        mlp_W_out_tensor = mlp_W_out_tensor.T 

        hooked_state_dict[hooked_mlp_W_in] = mlp_W_in_tensor
        hooked_state_dict[hooked_mlp_b_in] = mlp_b_in_tensor
        hooked_state_dict[hooked_mlp_W_out] = mlp_W_out_tensor
        hooked_state_dict[hooked_mlp_b_out] = mlp_b_out_tensor
        
        hooked.load_state_dict(hooked_state_dict)
        print(f"Successfully mapped MLP weights and biases for block {i} from GPT-J to Hooked Transformer.")

def map_embed(model1=hooked, model2=gptj):
    gptj_state_dict = gptj.state_dict()
    hooked_state_dict = hooked.state_dict()

    gptj_wte_weight = "transformer.wte.weight"  
    gptj_lm_head_bias = "lm_head.bias"         
    
    hooked_embed_W_E = "embed.W_E"             
    hooked_unembed_W_U = "unembed.W_U"         
    hooked_unembed_b_U = "unembed.b_U"         

    # Extract the weights and biases from GPT-J
    wte_weight_tensor = gptj_state_dict[gptj_wte_weight]  # Shape: [50257, 768]
    lm_head_bias_tensor = gptj_state_dict[gptj_lm_head_bias]  # Shape: [50257]
    
    hooked_state_dict[hooked_embed_W_E] = wte_weight_tensor  # [50257, 768] -> [50257, 768]
    hooked_state_dict[hooked_unembed_W_U] = wte_weight_tensor.T  # [50257, 768] -> [768, 50257]

    hooked_state_dict[hooked_unembed_b_U] = lm_head_bias_tensor  # [50257] -> [50257]

    hooked.load_state_dict(hooked_state_dict)
    print("Successfully mapped embedding/unembedding weights and biases from GPT-J to Hooked Transformer.")


def map_ln_params(model1 = hooked, model2 = gptj, num_blocks = 12): # Model 2's ln params will be transferred to Model 1
    for i in range(num_blocks):
        gptj_state_dict = gptj.state_dict()
        hooked_state_dict = hooked.state_dict()

        gptj_ln_1_w = f"transformer.h.{i}.ln_1.weight"
        gptj_ln_1_b = f"transformer.h.{i}.ln_1.bias"

        hooked_ln1_w = f"blocks.{i}.ln1.w"
        hooked_ln1_b = f"blocks.{i}.ln1.b"
        hooked_ln2_w = f"blocks.{i}.ln2.w"
        hooked_ln2_b = f"blocks.{i}.ln2.b"

        if gptj_ln_1_w in gptj_state_dict and gptj_ln_1_b in gptj_state_dict:
            # Extract the weight and bias from GPT-J
            ln1_w_tensor = gptj_state_dict[gptj_ln_1_w]
            ln1_b_tensor = gptj_state_dict[gptj_ln_1_b]

            # Now map them to the Hooked Transformer model
            if hooked_ln1_w in hooked_state_dict:
                print(f"Copying {gptj_ln_1_w} to {hooked_ln1_w}")
                hooked_state_dict[hooked_ln1_w] = ln1_w_tensor
            else:
                print(f"{hooked_ln1_w} not found in Hooked Transformer.")

            if hooked_ln1_b in hooked_state_dict:
                print(f"Copying {gptj_ln_1_b} to {hooked_ln1_b}")
                hooked_state_dict[hooked_ln1_b] = ln1_b_tensor
            else:
                print(f"{hooked_ln1_b} not found in Hooked Transformer.")

            # Now copy the same weights to ln2.w and ln2.b in the Hooked Transformer
            if hooked_ln2_w in hooked_state_dict:
                print(f"Copying {gptj_ln_1_w} to {hooked_ln2_w}")
                hooked_state_dict[hooked_ln2_w] = ln1_w_tensor  # Same weight for ln2
            else:
                print(f"{hooked_ln2_w} not found in Hooked Transformer.")

            if hooked_ln2_b in hooked_state_dict:
                print(f"Copying {gptj_ln_1_b} to {hooked_ln2_b}")
                hooked_state_dict[hooked_ln2_b] = ln1_b_tensor  # Same bias for ln2
            else:
                print(f"{hooked_ln2_b} not found in Hooked Transformer.")
        else:
            print(f"Missing {gptj_ln_1_w} or {gptj_ln_1_b} in GPT-J state dict.")

        hooked.load_state_dict(hooked_state_dict)
        print(f"Successfully updated Hooked Transformer with GPT-J layer normalization parameters for block {i}.")

    

In [7]:
def swap_param():
    #Swaps hooked model's linear nomalization parameters with custom
    map_ln_params()

    # Sets bias of Q, V, K, O to 0 
    set_bias_zero()
    
    #Swaps hooked model's Q, V, K, O weights with custom 
    map_attn()

    #Swaps hooked model's MLP weights+biases with custom 
    map_mlp()

    # Swaps embedding, unembed 
    map_embed()

In [8]:
swap_param()

Copying transformer.h.0.ln_1.weight to blocks.0.ln1.w
Copying transformer.h.0.ln_1.bias to blocks.0.ln1.b
Copying transformer.h.0.ln_1.weight to blocks.0.ln2.w
Copying transformer.h.0.ln_1.bias to blocks.0.ln2.b
Successfully updated Hooked Transformer with GPT-J layer normalization parameters for block 0.
Copying transformer.h.1.ln_1.weight to blocks.1.ln1.w
Copying transformer.h.1.ln_1.bias to blocks.1.ln1.b
Copying transformer.h.1.ln_1.weight to blocks.1.ln2.w
Copying transformer.h.1.ln_1.bias to blocks.1.ln2.b
Successfully updated Hooked Transformer with GPT-J layer normalization parameters for block 1.
Copying transformer.h.2.ln_1.weight to blocks.2.ln1.w
Copying transformer.h.2.ln_1.bias to blocks.2.ln1.b
Copying transformer.h.2.ln_1.weight to blocks.2.ln2.w
Copying transformer.h.2.ln_1.bias to blocks.2.ln2.b
Successfully updated Hooked Transformer with GPT-J layer normalization parameters for block 2.
Copying transformer.h.3.ln_1.weight to blocks.3.ln1.w
Copying transformer.h.3.l

In [10]:
model_description_text = """## Loading Models

HookedTransformer comes loaded with >40 open source GPT-style models. You can load any of them in with `HookedTransformer.from_pretrained(MODEL_NAME)`. See my explainer for documentation of all supported models, and this table for hyper-parameters and the name used to load them. Each model is loaded into the consistent HookedTransformer architecture, designed to be clean, consistent and interpretability-friendly.

For this demo notebook we'll look at GPT-2 Small, an 80M parameter model. To try the model the model out, let's find the loss on this paragraph!"""
loss = hooked(model_description_text, return_type="loss")
print("Model loss:", loss)

Model loss: tensor(10.9698, device='cuda:0')


In [13]:
text = "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets."
tokens = hooked.to_tokens(text)
print(tokens.device)
logits, cache = hooked.run_with_cache(tokens, remove_batch_dim=True)

cuda:0


In [16]:
print(type(cache))
attention_pattern = cache["pattern", 0, "attn"]
print(attention_pattern.shape)
str_tokens = hooked.to_str_tokens(text)

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([12, 175, 175])


In [17]:
print("Layer 0 Head Attention Patterns:")
cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)

Layer 0 Head Attention Patterns:


In [38]:
params, names = get_params(hooked)

Model has 197 named parameters


In [55]:
unpack_names(names, 0, 200)

Param: embed.W_E                                Dimensions of tensor: torch.Size([50257, 768])      
Param: blocks.0.ln1.w                           Dimensions of tensor: torch.Size([768])             
Param: blocks.0.ln1.b                           Dimensions of tensor: torch.Size([768])             
Param: blocks.0.ln2.w                           Dimensions of tensor: torch.Size([768])             
Param: blocks.0.ln2.b                           Dimensions of tensor: torch.Size([768])             
Param: blocks.0.attn.W_Q                        Dimensions of tensor: torch.Size([12, 768, 64])     
Param: blocks.0.attn.W_O                        Dimensions of tensor: torch.Size([12, 64, 768])     
Param: blocks.0.attn.b_Q                        Dimensions of tensor: torch.Size([12, 64])          
Param: blocks.0.attn.b_O                        Dimensions of tensor: torch.Size([768])             
Param: blocks.0.attn.W_K                        Dimensions of tensor: torch.Size([12, 768, 

In [57]:
unpack_names(gptj_names, 0, 200)

Param: transformer.wte.weight                   Dimensions of tensor: torch.Size([50257, 768])      
Param: transformer.h.0.ln_1.weight              Dimensions of tensor: torch.Size([768])             
Param: transformer.h.0.ln_1.bias                Dimensions of tensor: torch.Size([768])             
Param: transformer.h.0.attn.k_proj.weight       Dimensions of tensor: torch.Size([768, 768])        
Param: transformer.h.0.attn.v_proj.weight       Dimensions of tensor: torch.Size([768, 768])        
Param: transformer.h.0.attn.q_proj.weight       Dimensions of tensor: torch.Size([768, 768])        
Param: transformer.h.0.attn.out_proj.weight     Dimensions of tensor: torch.Size([768, 768])        
Param: transformer.h.0.mlp.fc_in.weight         Dimensions of tensor: torch.Size([3072, 768])       
Param: transformer.h.0.mlp.fc_in.bias           Dimensions of tensor: torch.Size([3072])            
Param: transformer.h.0.mlp.fc_out.weight        Dimensions of tensor: torch.Size([768, 3072

In [69]:
-1.7344 * 0.0001

-0.00017344

In [92]:
hooked.blocks[0].attn.W_K.data[0]


tensor([[ 0.0411, -0.0688, -0.0819,  ..., -0.0928, -0.0172, -0.0982],
        [-0.0947, -0.0464,  0.0336,  ..., -0.0491, -0.0585,  0.0484],
        [-0.0185, -0.0393, -0.0173,  ...,  0.1301,  0.0051, -0.0490],
        ...,
        [ 0.0035, -0.0087, -0.0043,  ...,  0.0111,  0.0089, -0.0068],
        [-0.0285, -0.0108, -0.0139,  ...,  0.0233, -0.0123,  0.0299],
        [-0.0060,  0.0138, -0.0060,  ...,  0.0194, -0.0118,  0.0152]],
       device='cuda:0')

In [93]:
gptj.transformer.h[0].attn.k_proj.weight.data[0][0:64]

tensor([ 0.0411, -0.0688, -0.0819,  0.0669, -0.1228, -0.0290, -0.0004,  0.0142,
        -0.0698,  0.0224,  0.0902,  0.0137, -0.0868,  0.0609, -0.0042,  0.0543,
         0.0081, -0.1628, -0.0107, -0.1196, -0.0034,  0.0392,  0.0255,  0.0321,
        -0.1227, -0.0702, -0.0450,  0.0052,  0.0031,  0.0773,  0.0385,  0.0593,
        -0.0033, -0.0443,  0.0430, -0.0192, -0.0492,  0.1235,  0.0199,  0.0798,
         0.1092, -0.1013,  0.0565,  0.1020,  0.0730, -0.0544,  0.1441,  0.1504,
        -0.0346,  0.1451, -0.0089, -0.0121, -0.0126,  0.0952, -0.1000,  0.0031,
         0.0482, -0.1172, -0.1256, -0.0604, -0.0775, -0.0928, -0.0172, -0.0982])

In [43]:
gptj_names

[['transformer.wte.weight', torch.Size([50257, 768])],
 ['transformer.h.0.ln_1.weight', torch.Size([768])],
 ['transformer.h.0.ln_1.bias', torch.Size([768])],
 ['transformer.h.0.attn.k_proj.weight', torch.Size([768, 768])],
 ['transformer.h.0.attn.v_proj.weight', torch.Size([768, 768])],
 ['transformer.h.0.attn.q_proj.weight', torch.Size([768, 768])],
 ['transformer.h.0.attn.out_proj.weight', torch.Size([768, 768])],
 ['transformer.h.0.mlp.fc_in.weight', torch.Size([3072, 768])],
 ['transformer.h.0.mlp.fc_in.bias', torch.Size([3072])],
 ['transformer.h.0.mlp.fc_out.weight', torch.Size([768, 3072])],
 ['transformer.h.0.mlp.fc_out.bias', torch.Size([768])],
 ['transformer.h.1.ln_1.weight', torch.Size([768])],
 ['transformer.h.1.ln_1.bias', torch.Size([768])],
 ['transformer.h.1.attn.k_proj.weight', torch.Size([768, 768])],
 ['transformer.h.1.attn.v_proj.weight', torch.Size([768, 768])],
 ['transformer.h.1.attn.q_proj.weight', torch.Size([768, 768])],
 ['transformer.h.1.attn.out_proj.weig