In [1]:
%cd /bilinear-feature-circuits

/bilinear-feature-circuits


In [1]:
import os
os.path.abspath('.')

'/home/tim/bilinear-feature-circuits'

In [2]:
import os
import sys
parent_dir = os.path.abspath('.')
sys.path.append(parent_dir + '/bilinear_interp_tim')
sys.path.append(parent_dir + '/dictionary_learning')
sys.path.append(parent_dir)
import argparse
import gc
import json
import math
from collections import defaultdict
import torch as t
from einops import rearrange
from tqdm import tqdm

from activation_utils import SparseAct
from attribution import patching_effect, jvp
from circuit_plotting import plot_circuit, plot_circuit_posaligned
from dictionary_learning import AutoEncoder
from loading_utils import load_examples, load_examples_nopair
from nnsight import LanguageModel
from language import Transformer, Sight
from sae_adopter import DictionarySAE
from bilinear_circuits_v0 import initialize_model_and_dictionaries
import numpy as np
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from typing import Callable
import einops
def get_log_prob_from_resid(resid:t.Tensor, token_id:int):
    return t.nn.functional.log_softmax(model.lm_head(resid), dim = -1)[0,-1,token_id].cpu().item()

In [4]:
device, model, embed, attns, mlps, resids, dictionaries, save_basename, examples, batch_size, num_examples, n_batches, batches = initialize_model_and_dictionaries(
    device='cuda:0',
    model_name="tdooms/fw-nano",
    dict_id='10',  # Note: This was originally an int, but the function expects a string.
    d_model=1024,
    dict_path='tdooms/fw-nano-scope',
    dataset='simple_train',
    num_examples=20,
    example_length=None,
    batch_size=4,
    aggregation='sum',
    nopair=False,
)

Clean and patch inputs of different shapes.
Clean: 3 Patch: 4
Clean and patch inputs of different shapes.
Clean: 4 Patch: 3


In [5]:
input_tensor = t.tensor(model.tokenizer("A trigger is designed to activate a task of the virus, as display ing strange messages, deleting files, sending emails  begin the replicate process or whatever the programmer write in his malicious code.")['input_ids'], device=device).unsqueeze(0)
input_tensor


tensor([[    1,   330,  8366,   349,  5682,   298, 27854,   264,  3638,   302,
           272, 15022, 28725,   390,  4249,  4155,  8708,  8570, 28725, 21750,
           288,  5373, 28725, 10313, 19863, 28705,  2839,   272,   312, 13112,
          1759,   442,  5681,   272,  2007,   794,  3324,   297,   516,  6125,
         10573,  2696, 28723]], device='cuda:0')

In [6]:
with t.no_grad():
    basic_out = model._model.forward(input_ids = input_tensor, labels = input_tensor)
print(basic_out.loss)

tensor(5.3740, device='cuda:0')


In [8]:
#get the per token loss
per_tok_loss = t.log_softmax(basic_out.logits[0,:-1,:], dim = -1).gather(dim = -1, index = input_tensor[0,1:].unsqueeze(-1)).squeeze(-1)
#quick dataframe with what every token is and the next token that the loss is on
df = pd.DataFrame({'current_token': [model.tokenizer.decode(s) for s in input_tensor[0, :-1].cpu().numpy()], 
                   'current_token_id': input_tensor[0, :-1].cpu().numpy(),
                   'next_token': [model.tokenizer.decode(s) for s in input_tensor[0, 1:].cpu().numpy()],
                   'next_token_id': input_tensor[0, 1:].cpu().numpy(),
                     'loss': per_tok_loss.cpu().numpy()})
df


Unnamed: 0,current_token,current_token_id,next_token,next_token_id,loss
0,<s>,1,A,330,-3.425913
1,A,330,trigger,8366,-10.745538
2,trigger,8366,is,349,-1.523019
3,is,349,designed,5682,-7.466188
4,designed,5682,to,298,-0.119676
5,to,298,activate,27854,-3.481383
6,activate,27854,a,264,-1.152388
7,a,264,task,3638,-6.57092
8,task,3638,of,302,-5.88479
9,of,302,the,272,-1.444379


In [9]:
all_submods = [embed] + [submod for layer_submods in zip(mlps, attns, resids) for submod in layer_submods]
def single_tok_logit_metric(tok_ind:int) -> Callable[[LanguageModel],t.Tensor]:
    def metric_fn(model: LanguageModel):
        # Get the logits for the last token in the sequence
        logits = model.lm_head.output[:, -1, :]
        
        # Apply log-softmax to convert logits to log probabilities
        log_probs = t.nn.functional.log_softmax(logits, dim=-1)
        
        # Gather the log probability for the specified token index
        log_prob = t.gather(log_probs, dim=-1, index=t.tensor([tok_ind], device=model.device).view(-1, 1)).squeeze(-1)
        
        return log_prob
    return metric_fn
short_input = input_tensor[:, :5]

In [93]:
def logit_diff_metric(tok1:int, tok2:int) -> Callable[[LanguageModel],t.Tensor]:
    def metric_fn(model: LanguageModel):
        # Get the logits for the last token in the sequence
        logits1 = model.lm_head.output[:, -1, tok1]
        logits2 = model.lm_head.output[:, -1, tok2]
        diff = logits1 - logits2
        diff.save()
        return diff
    return metric_fn

In [10]:
with model.trace(short_input):
    log_loss = single_tok_logit_metric(298)(model)
    log_loss.save()
log_loss

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


tensor([-0.1197], device='cuda:0', grad_fn=<SqueezeBackward1>)

In [94]:
t.topk(model.forward(short_input).logits[0,-1], k = 8)

torch.return_types.topk(
values=tensor([13.8895, 10.3930,  9.5242,  9.4697,  9.3849,  9.2240,  8.6763,  8.6177],
       device='cuda:0', grad_fn=<TopkBackward0>),
indices=tensor([298, 354, 297, 390, 579, 486, 369, 304], device='cuda:0'))

In [95]:
model.tokenizer.decode([298, 354])

'to for'

In [96]:
effects2, deltas2, grads2, total_effect2 = patching_effect(
        clean = short_input,
        patch = None,
        model = model,
        submodules = all_submods,
        dictionaries = dictionaries,
        metric_fn = logit_diff_metric(298, 354),
        metric_kwargs=dict(),
        method='ig' # get better approximations for early layers by using ig
    )



Integrated Gradient estimation


Initial trace


Patching part



In [98]:
all_effects2 = t.empty((1)).to(device)
for i,m in enumerate(all_submods):
    print(f"Model: {i} \n {m}")
    print(t.topk(effects2[m].act[0,-1], k = 31))
    all_effects2 = t.cat((all_effects2, t.topk(effects2[m].act.flatten(), k = 31).values), dim = 0)
    print(effects2[m].resc)
print(t.topk(all_effects2, k = 31))

Model: 0 
 Embedding(32000, 1024)
torch.return_types.topk(
values=tensor([0.1885, 0.0785, 0.0083, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0'),
indices=tensor([6113, 2894, 2963,   50,   36,   42,   39,   38,   28,   27,   23,   26,
          30,   35,   37,   34,   15,   14,   10,    8,    0,    6,    3,    4,
          17,   16,   19,   12,   21,   29,   32], device='cuda:0'))
tensor([[[ 1.7470e-04],
         [-1.8090e-03],
         [-1.2685e-01],
         [ 3.0078e-05],
         [-2.6008e-02]]], device='cuda:0')
Model: 1 
 MLP(
  (w): Bilinear(
    in_features=1024, out_features=8192, bias=True
    (gate): Identity()
  )
  (p): Linear(in_features=4096, out_features=1024, bias=True)
)
torch.return_types.topk(
values=tensor([0.3261, 0.1892, 0.1673, 0.1521, 0.1413, 0.11

In [34]:
effects, deltas, grads, total_effect = patching_effect(
        clean = short_input,
        patch = None,
        model = model,
        submodules = all_submods,
        dictionaries = dictionaries,
        metric_fn = single_tok_logit_metric(298),
        metric_kwargs=dict(),
        method='exact' # get better approximations for early layers by using ig
    )


100%|██████████| 25/25 [00:02<00:00,  8.90it/s]
100%|██████████| 150/150 [00:16<00:00,  9.10it/s]
100%|██████████| 150/150 [00:16<00:00,  9.10it/s]
100%|██████████| 150/150 [00:16<00:00,  9.20it/s]
100%|██████████| 150/150 [00:16<00:00,  9.23it/s]
100%|██████████| 150/150 [00:16<00:00,  9.09it/s]
100%|██████████| 150/150 [00:16<00:00,  9.03it/s]
100%|██████████| 150/150 [00:16<00:00,  9.05it/s]
100%|██████████| 150/150 [00:16<00:00,  9.10it/s]
100%|██████████| 150/150 [00:16<00:00,  9.12it/s]
100%|██████████| 150/150 [00:16<00:00,  9.07it/s]
100%|██████████| 150/150 [00:16<00:00,  9.11it/s]
100%|██████████| 150/150 [00:16<00:00,  9.10it/s]


In [53]:
#save effects
t.save(effects[all_submods[0]].act, 'exact_effects_v0.pt')

In [36]:
t.topk(effects[all_submods[-3]].act[0,-1], k = 31)

torch.return_types.topk(
values=tensor([0.0117, 0.0027, 0.0022, 0.0015, 0.0013, 0.0005, 0.0005, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0'),
indices=tensor([5700, 6005, 2988, 5433, 1704, 2163, 3046,   23,   12,   22,   21,   20,
          18,   17,   19,   16,    4,    3,    1,    2,    6,    5,    7,    0,
           8,   11,    9,   10,   14,   13,   15], device='cuda:0'))

In [37]:
t.topk(-effects[all_submods[-3]].act[0,-1], k = 31)

torch.return_types.topk(
values=tensor([0.3629, 0.0642, 0.0229, 0.0196, 0.0140, 0.0063, 0.0061, 0.0061, 0.0054,
        0.0053, 0.0053, 0.0052, 0.0044, 0.0035, 0.0033, 0.0023, 0.0023, 0.0018,
        0.0013, 0.0009, 0.0008, 0.0007, 0.0006, -0.0000, -0.0000, -0.0000, -0.0000,
        -0.0000, -0.0000, -0.0000, -0.0000], device='cuda:0'),
indices=tensor([1049,  780, 2768, 5492,  917, 1655, 8036, 7470, 1815, 2898, 7387, 6890,
        6929, 3215, 7386, 3531, 6009, 5981, 1152, 2182, 3109, 1091, 5652,    7,
           0,    6,    5,    4,    2,    1,    3], device='cuda:0'))

In [None]:
all_effects = t.empty((1)).to(device)
for i,m in enumerate(all_submods):
    print(f"Model: {i} \n {m}")
    print(t.topk(effects[m].act[0,-1], k = 31))
    all_effects = t.cat((all_effects, t.topk(effects[m].act.flatten(), k = 31).values), dim = 0)
    print(effects[m].resc)
print(all_effects)

Model: 0 
 Embedding(32000, 1024)
torch.return_types.topk(
values=tensor([0.0200, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0'),
indices=tensor([2894,   29,   27,   28,   24,   23,   25,   26,   18,   17,   15,   16,
          20,   19,   21,   22,   10,    6,    3,    5,    1,    0,    2,    4,
          12,    9,    7,    8,   13,   11,   14], device='cuda:0'))
tensor([[ 1.1742e-05, -7.9274e-06, -1.7366e-02,  1.3538e-05, -3.0152e-03]],
       device='cuda:0')
Model: 1 
 MLP(
  (w): Bilinear(
    in_features=1024, out_features=8192, bias=True
    (gate): Identity()
  )
  (p): Linear(in_features=4096, out_features=1024, bias=True)
)
torch.return_types.topk(
values=tensor([2.3068e-02, 9.4175e-03, 3.0643e-03, 7.7426e-05, 0.0000e+00, 0.0000e+00,
        0.0

In [92]:
t.topk(all_effects, k = 31)

torch.return_types.topk(
values=tensor([0.0451, 0.0441, 0.0273, 0.0260, 0.0241, 0.0231, 0.0230, 0.0226, 0.0224,
        0.0200, 0.0192, 0.0167, 0.0154, 0.0154, 0.0145, 0.0145, 0.0135, 0.0123,
        0.0122, 0.0118, 0.0117, 0.0117, 0.0117, 0.0112, 0.0111, 0.0109, 0.0109,
        0.0108, 0.0107, 0.0106, 0.0106], device='cuda:0'),
indices=tensor([187,   1, 188, 189, 156,  32, 190, 125, 280,   2,  94,  95,  33,  96,
        249, 218, 157, 191,  97,  34, 311, 373,  98, 281,  99, 100, 250, 219,
        101,  63, 192], device='cuda:0'))

In [22]:
with model.trace(short_input):
    final_resid_mid = model._envoy.transformer.h[3].n2.input
    final_resid_mid.save()
    final_mlp_out = model._envoy.transformer.h[3].mlp.output
    final_mlp_out.save()
    final_resid_post = model._envoy.transformer.h[3].output
    final_resid_post.save()
    feature_acts = dictionaries[all_submods[-1]].encode(final_resid_post)
    feature_acts.save()

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [42]:
t.topk(feature_acts[0,-1], k = 31)

torch.return_types.topk(
values=tensor([30.5674, 26.3677, 22.0909, 13.3291, 11.0112,  9.4994,  6.2468,  6.1364,
         5.6926,  5.6152,  5.3664,  5.3422,  4.8915,  4.7654,  4.6877,  4.4282,
         4.3245,  4.2570,  4.1588,  4.1180,  3.9214,  3.8790,  3.8148,  3.6669,
         3.6364,  3.5652,  3.5551,  3.5463,  3.5067,  3.4497,  0.0000],
       device='cuda:0', grad_fn=<TopkBackward0>),
indices=tensor([4257, 8139, 5601, 7929, 5729, 1993, 3448, 4375, 2227, 5050, 5953, 3406,
        6962,  731, 4198, 8135, 5134, 1875, 2081, 4762, 8088,  512, 2775, 3865,
         708,  550, 7194, 2376, 3427, 2959,   11], device='cuda:0'))

In [49]:
a = t.nn.functional.log_softmax(model.lm_head(final_resid_post), dim = -1)[0,-1,298]
a = a.cpu().item()
b = t.nn.functional.log_softmax(model.lm_head(final_resid_post + dictionaries[all_submods[-1]].w_dec.weight.data[:,5729]*0.1), dim = -1)[0,-1,298]
b = b.cpu().item()
(b - a)/0.01


0.0032156705856323242

In [33]:
t.topk(all_effects, k = 31)

torch.return_types.topk(
values=tensor([0.1371, 0.1180, 0.1056, 0.1043, 0.1011, 0.0741, 0.0709, 0.0590, 0.0527,
        0.0515, 0.0509, 0.0488, 0.0475, 0.0470, 0.0466, 0.0447, 0.0427, 0.0389,
        0.0371, 0.0347, 0.0314, 0.0311, 0.0308, 0.0291, 0.0290, 0.0288, 0.0287,
        0.0259, 0.0256, 0.0252, 0.0252], device='cuda:0'),
indices=tensor([187, 373,   1,  32,  94, 188,   2,  33, 280,   3, 189,   4, 311,  63,
        190,  95,   5,  64,  65,   6, 281,  96, 191,   7,  97,  98, 192, 249,
        193,  99, 100], device='cuda:0'))

In [13]:
short_input

tensor([[   1,  330, 8366,  349, 5682]], device='cuda:0')

In [19]:
model.w_u.shape

torch.Size([32000, 1024])

In [24]:
def get_logit_value_at_resid_and_modules(model: LanguageModel, token_id: int, input_tensor: t.Tensor):
    resid_direction = model.w_u[token_id]

    attn_outs = []
    mlp_outs = []
    resid_outs = []
    with model.trace(input_tensor,validate = True, scan = True):
        for layer in range(4):
            attn_out = model._envoy.transformer.h[layer].attn.output
            attn_out = attn_out[0,-1,:]
            attn_out = t.dot(attn_out, resid_direction)
            attn_out.save() #should be shape (batch, seq, d_model)
            mlp_out = model._envoy.transformer.h[layer].mlp.output
            mlp_out = mlp_out[0,-1,:]
            mlp_out = t.dot(mlp_out, resid_direction)
            mlp_out.save()
            resid_out = model._envoy.transformer.h[layer].output
            resid_out = resid_out[0,-1,:]
            resid_out = t.dot(resid_out, resid_direction)
            resid_out.save()
            attn_outs.append(attn_out)
            mlp_outs.append(mlp_out)
            resid_outs.append(resid_out)
    return attn_outs, mlp_outs, resid_outs

attn_outs, mlp_outs, resid_outs = get_logit_value_at_resid_and_modules(model, 298, short_input)

In [25]:
attn_outs

[tensor(0.2589, device='cuda:0', grad_fn=<DotBackward0>),
 tensor(-0.3608, device='cuda:0', grad_fn=<DotBackward0>),
 tensor(0.0997, device='cuda:0', grad_fn=<DotBackward0>),
 tensor(8.5141, device='cuda:0', grad_fn=<DotBackward0>)]

In [26]:
mlp_outs

[tensor(0.6884, device='cuda:0', grad_fn=<DotBackward0>),
 tensor(0.7722, device='cuda:0', grad_fn=<DotBackward0>),
 tensor(0.8675, device='cuda:0', grad_fn=<DotBackward0>),
 tensor(8.4181, device='cuda:0', grad_fn=<DotBackward0>)]

In [27]:
resid_outs

[tensor(0.9138, device='cuda:0', grad_fn=<DotBackward0>),
 tensor(1.5584, device='cuda:0', grad_fn=<DotBackward0>),
 tensor(2.4611, device='cuda:0', grad_fn=<DotBackward0>),
 tensor(13.8895, device='cuda:0', grad_fn=<DotBackward0>)]

In [35]:
model.forward(short_input).logits[0,-1,298]

tensor(13.8895, device='cuda:0', grad_fn=<SelectBackward0>)

In [56]:
def get_log_prob_at_resid(model: LanguageModel, token_id: int, input_tensor: t.Tensor):
    resid_mids = []
    resid_outs = []
    with model.trace(input_tensor,validate = True, scan = True):
        for layer in range(4):
            resid_out = model._envoy.transformer.h[layer].output
            resid_out = t.nn.functional.log_softmax(model.lm_head(resid_out), dim = -1)[0,-1,token_id]
            resid_out.save()
            resid_mid = model._envoy.transformer.h[layer].n2.input
            resid_mid = t.nn.functional.log_softmax(model.lm_head(resid_mid), dim = -1)[0,-1,token_id]
            resid_mid.save()
            resid_mids.append(resid_mid)
            resid_outs.append(resid_out)
    return resid_mids, resid_outs

resid_mids, resid_outs = get_log_prob_at_resid(model, 298, short_input)

In [59]:
for mid, out in zip(resid_mids, resid_outs):
    print(f"Resid mid: {mid}, Resid out: {out}")

Resid mid: tensor(-10.2324, device='cuda:0', grad_fn=<SelectBackward0>), Resid out: tensor(-9.2326, device='cuda:0', grad_fn=<SelectBackward0>)
Resid mid: tensor(-9.4013, device='cuda:0', grad_fn=<SelectBackward0>), Resid out: tensor(-8.3498, device='cuda:0', grad_fn=<SelectBackward0>)
Resid mid: tensor(-8.3223, device='cuda:0', grad_fn=<SelectBackward0>), Resid out: tensor(-6.9268, device='cuda:0', grad_fn=<SelectBackward0>)
Resid mid: tensor(-3.0950, device='cuda:0', grad_fn=<SelectBackward0>), Resid out: tensor(-0.1197, device='cuda:0', grad_fn=<SelectBackward0>)


In [58]:
resid_outs

[tensor(-9.2326, device='cuda:0', grad_fn=<SelectBackward0>),
 tensor(-8.3498, device='cuda:0', grad_fn=<SelectBackward0>),
 tensor(-6.9268, device='cuda:0', grad_fn=<SelectBackward0>),
 tensor(-0.1197, device='cuda:0', grad_fn=<SelectBackward0>)]

In [10]:
fin_mlp_dict = dictionaries[all_submods[-3]]

In [67]:
t.topk(effects[all_submods[-3]].act[0,-1], k = 31)

torch.return_types.topk(
values=tensor([0.0117, 0.0027, 0.0022, 0.0015, 0.0013, 0.0005, 0.0005, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0'),
indices=tensor([5700, 6005, 2988, 5433, 1704, 2163, 3046,   23,   12,   22,   21,   20,
          18,   17,   19,   16,    4,    3,    1,    2,    6,    5,    7,    0,
           8,   11,    9,   10,   14,   13,   15], device='cuda:0'))

In [31]:
t.topk(fin_mlp_dict.encode(final_mlp_out)[0,-1], k = 31)

torch.return_types.topk(
values=tensor([29.1835, 22.2516, 15.9731, 14.5302, 13.7950,  8.8346,  7.3819,  5.1124,
         4.6354,  4.5241,  4.1688,  4.1099,  3.8820,  3.6814,  3.4158,  3.1766,
         3.1443,  3.1332,  3.0301,  2.9100,  2.9078,  2.9051,  2.8587,  2.8534,
         2.8001,  2.6894,  2.6483,  2.6281,  2.5836,  2.5582,  0.0000],
       device='cuda:0', grad_fn=<TopkBackward0>),
indices=tensor([2768,  917, 5700, 1815,  780, 6890, 6009, 5433, 6929, 2988, 5981, 7387,
        1049, 1091, 8036, 3046, 2163, 7386, 3215, 1704, 3531, 6005, 7470, 3109,
        1655, 2182, 2898, 5652, 5492, 1152,    4], device='cuda:0'))

In [32]:
t.topk(fin_mlp_dict.encode(final_mlp_out)[0,-1], k = 13)

torch.return_types.topk(
values=tensor([29.1835, 22.2516, 15.9731, 14.5302, 13.7950,  8.8346,  7.3819,  5.1124,
         4.6354,  4.5241,  4.1688,  4.1099,  3.8820], device='cuda:0',
       grad_fn=<TopkBackward0>),
indices=tensor([2768,  917, 5700, 1815,  780, 6890, 6009, 5433, 6929, 2988, 5981, 7387,
        1049], device='cuda:0'))

In [99]:
#Test cosine sim with to unembedding and for unembedding
print("To unembedding")
print(t.nn.functional.cosine_similarity(model._model.lm_head.weight.data[298], fin_mlp_dict.w_dec.weight.data.gather(dim = 1, index = t.tensor([2768,  917, 5700, 1815,  780, 6890, 6009, 5433, 6929, 2988, 5981, 7387,
        1049, 1091, 8036, 3046, 2163, 7386, 3215, 1704, 3531, 6005, 7470, 3109,
        1655, 2182, 2898, 5652, 5492, 1152], device = device).view(-1,1)), dim = -1))

To unembedding
tensor([-0.0074, -0.0074,  0.0074, -0.0074, -0.0074, -0.0074, -0.0074, -0.0074,
        -0.0074,  0.0074, -0.0074,  0.0074, -0.0074,  0.0074, -0.0074,  0.0074,
         0.0074, -0.0074, -0.0074, -0.0074,  0.0074,  0.0074,  0.0074,  0.0074,
        -0.0074, -0.0074, -0.0074,  0.0074,  0.0074,  0.0074], device='cuda:0')


In [11]:
fin_mlp_dict.w_dec.weight.data.shape

torch.Size([1024, 8192])

In [16]:
t.nn.functional.cosine_similarity(model._model.lm_head.weight.data[298].unsqueeze(0), fin_mlp_dict.w_dec.weight.data[:,[2768,  917, 5700, 1815,  780, 6890, 6009, 5433, 6929, 2988, 5981, 7387,
        1049, 1091, 8036, 3046, 2163, 7386, 3215, 1704, 3531, 6005, 7470, 3109,
        1655, 2182, 2898, 5652, 5492, 1152]].T, dim = -1)

tensor([ 0.0686,  0.0885, -0.0354, -0.0079,  0.1593,  0.1181,  0.0409,  0.1122,
         0.0206,  0.0194,  0.0635,  0.0430,  0.7374,  0.0481,  0.0128,  0.0382,
         0.0151,  0.0308,  0.0126, -0.0035,  0.0295,  0.0028,  0.1120,  0.0443,
         0.0932,  0.0081,  0.1009, -0.0703,  0.1457,  0.0227], device='cuda:0')

In [21]:
t.nn.functional.cosine_similarity(model._model.lm_head.weight.data[298].unsqueeze(0), fin_mlp_dict.w_dec.weight.data[:,1049], dim = -1)

tensor([0.7374], device='cuda:0')

In [19]:
[2768,  917, 5700, 1815,  780, 6890, 6009, 5433, 6929, 2988, 5981, 7387,
        1049, 1091, 8036, 3046, 2163, 7386, 3215, 1704, 3531, 6005, 7470, 3109,
        1655, 2182, 2898, 5652, 5492, 1152][12]

1049

In [17]:
t.nn.functional.cosine_similarity(model._model.lm_head.weight.data[354].unsqueeze(0), fin_mlp_dict.w_dec.weight.data[:,[2768,  917, 5700, 1815,  780, 6890, 6009, 5433, 6929, 2988, 5981, 7387,
        1049, 1091, 8036, 3046, 2163, 7386, 3215, 1704, 3531, 6005, 7470, 3109,
        1655, 2182, 2898, 5652, 5492, 1152]].T, dim = -1)

tensor([ 0.0611,  0.0848, -0.0342, -0.0031,  0.1825,  0.1306,  0.0373,  0.1209,
         0.0076,  0.0374,  0.0277,  0.0244,  0.1479,  0.0667, -0.0093,  0.0141,
         0.0241,  0.0358, -0.0075,  0.0114,  0.0193, -0.0329,  0.1507,  0.0468,
         0.0579,  0.0038,  0.1324, -0.0667,  0.0776,  0.0106], device='cuda:0')

In [13]:
fin_mlp_dict.w_dec.weight.data[:,2768]

tensor([-0.0164,  0.0111,  0.0240,  ..., -0.0193,  0.0025,  0.0007],
       device='cuda:0')

In [7]:
einops.repeat(t.tensor([2768,  917, 5700, 1815,  780, 6890, 6009, 5433, 6929, 2988, 5981, 7387,
        1049, 1091, 8036, 3046, 2163, 7386, 3215, 1704, 3531, 6005, 7470, 3109,
        1655, 2182, 2898, 5652, 5492, 1152], device = device), "d_active -> d_model d_active", d_model = 1024)

tensor([[2768,  917, 5700,  ..., 5652, 5492, 1152],
        [2768,  917, 5700,  ..., 5652, 5492, 1152],
        [2768,  917, 5700,  ..., 5652, 5492, 1152],
        ...,
        [2768,  917, 5700,  ..., 5652, 5492, 1152],
        [2768,  917, 5700,  ..., 5652, 5492, 1152],
        [2768,  917, 5700,  ..., 5652, 5492, 1152]], device='cuda:0')

In [23]:
mlp_out_error = fin_mlp_dict.forward(final_mlp_out) - final_mlp_out

In [24]:
model.lm_head(mlp_out_error)[0,-1,298]

tensor(-0.7723, device='cuda:0', grad_fn=<SelectBackward0>)

In [25]:
model.lm_head(final_mlp_out)[0,-1,298]

tensor(8.4181, device='cuda:0', grad_fn=<SelectBackward0>)

In [26]:
get_log_prob_from_resid(fin_mlp_dict.forward(final_mlp_out) + final_resid_mid, 298)

-0.27950629591941833

In [83]:
import math
math.exp(-0.2795)

0.7561617278151684

In [88]:
get_log_prob_from_resid(final_resid_post, 298)

-0.1196756660938263

In [33]:
get_log_prob_from_resid(final_resid_post - 3.8820*fin_mlp_dict.w_dec.weight.data[:,1049], 298)

-0.48257333040237427

In [38]:
0.48257333040237427-0.1196756660938263

0.362897664308548

In [90]:
-0.10794062912464142--0.1196756660938263

0.011735036969184875

In [91]:
1+1

2

In [None]:
with model.trace(short_input,validate = True, scan = True):
    attn_out = all_submods[2].output
    attn_out.save()

In [60]:
effects[all_submods[2]].act.shape

torch.Size([1, 5, 8192])

In [54]:
all_submods[2]

Attention(
  (rotary): Rotary()
  (qkv): Linear(in_features=1024, out_features=3072, bias=True)
  (o): Linear(in_features=1024, out_features=1024, bias=False)
  (softmax): Softmax(dim=-1)
)

In [56]:
all_submods[2]

Attention(
  (rotary): Rotary()
  (qkv): Linear(in_features=1024, out_features=3072, bias=True)
  (o): Linear(in_features=1024, out_features=1024, bias=False)
  (softmax): Softmax(dim=-1)
)

In [13]:
with model.trace(short_input):
    attn_out = all_submods[2].output
    attn_out.save()
    attn_features = dictionaries[all_submods[2]].encode(attn_out)
    attn_features.save()
    

In [63]:
t.topk(attn_features[0,-1], k=31)

torch.return_types.topk(
values=tensor([4.7195, 3.8420, 3.6028, 2.7154, 2.5271, 1.9471, 1.0818, 1.0308, 0.8742,
        0.7943, 0.6145, 0.5480, 0.4669, 0.4183, 0.4171, 0.4069, 0.3847, 0.3663,
        0.3576, 0.3503, 0.3502, 0.3388, 0.3180, 0.3176, 0.3152, 0.3129, 0.3127,
        0.3016, 0.3008, 0.2949, 0.0000], device='cuda:0',
       grad_fn=<TopkBackward0>),
indices=tensor([3514, 6153, 5042,   57,  993, 2764, 3472, 6342, 6874, 6467, 2177, 2452,
        5177, 2891, 5705, 3069, 6496,  693, 1688, 3583, 2440,   64, 7546, 6616,
        6597, 2188, 7805,  617, 8054,  507,   10], device='cuda:0'))

In [15]:
t.topk(effects[all_submods[2]].act[0,-1], k=31)

torch.return_types.topk(
values=tensor([0.0064, 0.0053, 0.0018, 0.0017, 0.0015, 0.0014, 0.0011, 0.0010, 0.0007,
        0.0006, 0.0005, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0'),
indices=tensor([6342, 6874,   64, 2188, 7805, 6597, 2440, 7546, 8054, 5705,  693,   34,
          28,   33,   31,   30,    3,    1,    4,    0,    5,   24,   29,   22,
          15,   14,    6,    8,   20,   17,   21], device='cuda:0'))

In [72]:
t.topk(effects[all_submods[2]].act[0,-1], k=31)

torch.return_types.topk(
values=tensor([0.9308, 0.3497, 0.3330, 0.0747, 0.0476, 0.0359, 0.0260, 0.0211, 0.0207,
        0.0183, 0.0147, 0.0147, 0.0107, 0.0100, 0.0093, 0.0040, 0.0039, 0.0027,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0'),
indices=tensor([3514, 5042, 2764, 2891, 6467, 6874, 3583, 6616, 2440, 2188, 3472,   64,
        6496, 8054, 3069,  617, 6342,  507,   25,   24,   20,   15,   23,   22,
           5,    4,    1,    0,    8,   12,   21], device='cuda:0'))

In [241]:
z_act = t.linalg.inv(all_submods[2].o.weight.data) @ dictionaries[all_submods[2]].w_dec.weight.data[:,3514]

In [75]:
dictionaries[all_submods[2]].w_dec.weight.data[:,3514] #Attn feature vector. 


tensor([-0.0168, -0.0135, -0.0028,  ...,  0.0112,  0.0228, -0.0876],
       device='cuda:0')

In [98]:
einops.rearrange(all_submods[2].qkv.weight.data[-1024:,:], 'd_model (n_head d_head) ->n_head d_model d_head', n_head=all_submods[2].config.n_head)

tensor([[[-0.0310, -0.0481, -0.0239,  ..., -0.0088,  0.0255,  0.0045],
         [-0.0375, -0.0211, -0.0199,  ...,  0.0019, -0.0275,  0.0232],
         [-0.0216, -0.0343, -0.0004,  ...,  0.0041, -0.0319, -0.0123],
         ...,
         [ 0.0213, -0.0161,  0.0003,  ..., -0.0366, -0.0450,  0.0189],
         [-0.0372, -0.0171, -0.0287,  ...,  0.0051, -0.0021, -0.0093],
         [-0.0122,  0.0034, -0.0295,  ...,  0.0287, -0.0288, -0.0323]],

        [[ 0.0294,  0.0199,  0.0405,  ..., -0.0482, -0.0402,  0.0106],
         [-0.0006, -0.0392, -0.0516,  ...,  0.0176,  0.0500,  0.0209],
         [-0.0497, -0.0257, -0.0701,  ...,  0.0084,  0.0166, -0.0136],
         ...,
         [ 0.0116,  0.0372, -0.0124,  ..., -0.0397, -0.0293,  0.0179],
         [-0.0245, -0.0056, -0.0351,  ..., -0.0306, -0.0576, -0.0090],
         [ 0.0249, -0.0059,  0.0240,  ..., -0.0290,  0.0144, -0.0305]],

        [[-0.0017,  0.0003, -0.0156,  ...,  0.0133,  0.0138, -0.0217],
         [-0.0744,  0.0037,  0.0246,  ...,  0

In [127]:
Vs = einops.rearrange(all_submods[2].qkv.weight.data[-1024:,:], '(n_head d_head) d_model -> n_head d_head d_model', n_head=all_submods[2].config.n_head) #v
Os = einops.rearrange(all_submods[2].o.weight.data, 'd_model  (n_head d_head) -> n_head d_model d_head', n_head=all_submods[2].config.n_head) #o

In [101]:
from transformer_lens import FactoredMatrix

In [128]:
OV = FactoredMatrix(Os,Vs)

In [129]:
U, S, V = OV.svd()

In [138]:
S

tensor([[3.8220, 3.0563, 2.8632,  ..., 0.3376, 0.3148, 0.2436],
        [4.3806, 2.4779, 2.4094,  ..., 0.3413, 0.2854, 0.2424],
        [2.8725, 2.0972, 2.0136,  ..., 0.2061, 0.1790, 0.1253],
        ...,
        [2.1487, 1.9775, 1.7535,  ..., 0.2125, 0.1810, 0.1493],
        [2.5054, 1.8842, 1.7610,  ..., 0.3061, 0.2746, 0.2129],
        [6.4218, 4.4102, 4.0468,  ..., 0.2739, 0.2599, 0.2520]],
       device='cuda:0')

In [141]:
V.shape

torch.Size([16, 1024, 64])

In [157]:
rand_vec =t.randn(1024, device=device)
rand_vec = rand_vec / t.norm(rand_vec)

In [158]:
Vs = einops.rearrange(V, 'n_head d_model d_head -> (n_head d_head) d_model')
res = einops.einsum(Vs, rand_vec, 'list_of_vs d_model, d_model -> list_of_vs')

In [159]:
print(t.max(res), t.min(res))

tensor(0.0785, device='cuda:0') tensor(-0.0905, device='cuda:0')


In [161]:
real_sae_fet = einops.einsum(Vs, dictionaries[all_submods[2]].w_dec.weight.data[:,3514],'list_of_vs d_model, d_model -> list_of_vs')

In [166]:
t.topk(real_sae_fet, k = 30)

torch.return_types.topk(
values=tensor([0.3156, 0.2092, 0.1717, 0.1622, 0.1616, 0.1595, 0.1456, 0.1349, 0.1293,
        0.1194, 0.1153, 0.1119, 0.1057, 0.1051, 0.1048, 0.1042, 0.1006, 0.0952,
        0.0943, 0.0942, 0.0938, 0.0922, 0.0910, 0.0900, 0.0895, 0.0892, 0.0892,
        0.0880, 0.0869, 0.0866], device='cuda:0'),
indices=tensor([ 833,  769,  708,  902,  836,  766,  843,  712,  838, 1012,  857,  710,
         960,  405,  899,   58,  891,  377,  856,  311,  303,   48,  380,  704,
         302,   18,  742,  810,  597,  835], device='cuda:0'))

In [167]:
t.topk(-real_sae_fet, k = 30)

torch.return_types.topk(
values=tensor([0.2275, 0.2246, 0.1811, 0.1663, 0.1552, 0.1537, 0.1494, 0.1480, 0.1474,
        0.1423, 0.1422, 0.1404, 0.1373, 0.1342, 0.1305, 0.1284, 0.1205, 0.1186,
        0.1168, 0.1126, 0.1116, 0.1112, 0.1102, 0.1086, 0.1070, 0.1048, 0.1046,
        0.1041, 0.1003, 0.0980], device='cuda:0'),
indices=tensor([ 997,  638,  711,   63,  707,  841,  706,  855, 1005,  998,  717,  759,
         995,  321,  850,  746,  849,  714,  983,   37,  767,  848,  770,   40,
         863,  892,  267,  771,  920,  737], device='cuda:0'))

In [163]:
t.norm(dictionaries[all_submods[2]].w_dec.weight.data[:,3514])

tensor(1., device='cuda:0')

In [123]:
next(iter(example.parameters())).shape

torch.Size([20, 10])

In [116]:
example.weight.data @ t.arange(10, dtype=t.float32)

tensor([-4.9988, -1.4216,  2.6791,  2.5656,  2.3612, -4.1847, -2.0328, -3.6314,
        -0.8202,  0.3459, -3.1903, -1.8296,  1.5329,  0.4042,  4.8260,  7.0481,
        -2.3664, -1.9119,  0.7636, -0.9323])

In [117]:
example(t.arange(10, dtype=t.float32))

tensor([-4.9988, -1.4216,  2.6791,  2.5656,  2.3612, -4.1847, -2.0328, -3.6314,
        -0.8202,  0.3459, -3.1903, -1.8296,  1.5329,  0.4042,  4.8260,  7.0481,
        -2.3664, -1.9119,  0.7636, -0.9323], grad_fn=<SqueezeBackward4>)

In [173]:
np.unravel_index(223,(100,10))

(np.int64(22), np.int64(3))

In [179]:
U.shape

torch.Size([16, 1024, 64])

In [171]:
V.shape

torch.Size([16, 1024, 64])

In [192]:
import torch as t
from transformer_lens import FactoredMatrix
import einops
from typing import List, Tuple

def get_top_n_svd_components(
    attention_module: t.nn.Module,
    input_vector: t.Tensor,
    n: int = 10
) -> List[Tuple[float, float, t.Tensor, t.Tensor]]:
    """
    Computes the top n SVD components based on the dot product between the input vector and
    the right singular vectors (V) of the attention module.

    Args:
        attention_module (t.nn.Module): The attention module (e.g., all_submods[2]).
        input_vector (t.Tensor): A tensor of shape (d_model,) representing the input vector.
        n (int, optional): Number of top components to return. Defaults to 10.

    Returns:
        List[Tuple[float, float, t.Tensor, t.Tensor]]: A list of tuples containing:
            - Dot product value (float)
            - Singular value (float)
            - Corresponding U vector (t.Tensor)
            - Corresponding V vector (t.Tensor)
    """
    # Ensure input_vector is normalized
    input_vector = input_vector / input_vector.norm()

    # Extract V from the attention module's qkv weights
    # Assuming the last d_model rows correspond to V
    qkv_weights = attention_module.qkv.weight.data  # Shape: (3*d_model, d_model)
    d_model = qkv_weights.shape[1]
    
    # Extract V (assuming it is the last d_model rows)
    V_weight = qkv_weights[-d_model:, :]  # Shape: (d_model, d_model)

    # Reshape V for multi-head if necessary
    n_head = attention_module.config.n_head

    # Rearrange V to [n_head * d_head, d_model]
    Vs = einops.rearrange(V_weight,  '(n_head d_head) d_model -> n_head d_head d_model', n_head=n_head)

    # Extract O from the attention module's output weights
    O_weight = attention_module.o.weight.data  # Shape: (d_model, d_model)
    
    # Rearrange O to [n_head, d_model, d_head]
    Os = einops.rearrange(O_weight, 'd_model (n_head d_head) -> n_head d_model d_head',
                         n_head=n_head)
    
    # Create FactoredMatrix and perform SVD
    OV = FactoredMatrix(Os, Vs)
    U, S, V = OV.svd()  # U and V are both [n_head, d_model, d_head, ...]
    
    # Compute dot products between input_vector and each V vector
    # V should be of shape [n_head*d_head, d_model]
    projections = einops.einsum(V, input_vector, 'n_head d_model d_head, d_model -> n_head d_head')
    
    # Get top n indices based on absolute dot product
    _, top_indices = t.topk(projections.flatten().abs(), n)
    
    results = []
    for i in range(n):
        loc = np.unravel_index(top_indices[i].cpu(), projections.shape)
        dot_product = projections[loc].item()
        singular_value = S[loc].item()
        U_vector = U[loc[0],:,loc[1]]  # 
        V_vector = V[loc[0],:,loc[1]]  # 
        results.append((loc,dot_product, singular_value, U_vector, V_vector))
    
    return results

In [193]:
top_3514_vecs = get_top_n_svd_components(all_submods[2], dictionaries[all_submods[2]].w_dec.weight.data[:,3514], 10)

  loc = np.unravel_index(top_indices[i].cpu(), projections.shape)


In [194]:
top_3514_vecs

[((np.int64(13), np.int64(1)),
  0.31556427478790283,
  1.9775331020355225,
  tensor([-0.0297,  0.0364,  0.0232,  ..., -0.0123,  0.0067,  0.0146],
         device='cuda:0'),
  tensor([-0.0239, -0.0160,  0.0030,  ...,  0.0086,  0.0003, -0.0415],
         device='cuda:0')),
 ((np.int64(15), np.int64(37)),
  -0.22751836478710175,
  0.6253384351730347,
  tensor([ 0.0114, -0.0282,  0.0504,  ...,  0.0200,  0.0034, -0.0978],
         device='cuda:0'),
  tensor([-0.0381,  0.0093,  0.0063,  ..., -0.0460, -0.0253,  0.0135],
         device='cuda:0')),
 ((np.int64(9), np.int64(62)),
  -0.22455906867980957,
  0.12267374247312546,
  tensor([ 0.0738, -0.0117, -0.0450,  ..., -0.0199, -0.0330,  0.0257],
         device='cuda:0'),
  tensor([-0.0338, -0.0256,  0.0120,  ..., -0.0170,  0.0336,  0.0331],
         device='cuda:0')),
 ((np.int64(12), np.int64(1)),
  0.20916713774204254,
  1.9326283931732178,
  tensor([ 0.0074, -0.0290, -0.0082,  ..., -0.0527,  0.0395,  0.0433],
         device='cuda:0'),
  t

In [232]:
V_no_head = all_submods[2].qkv.weight.data[-1024:,:]
O_no_head = all_submods[2].o.weight.data

In [257]:
A_token_embd = all_submods[0](t.tensor([330]).to(device = 'cuda'))
is_token_embd = all_submods[0](t.tensor([349]).to(device = 'cuda'))
designed_token_embd = all_submods[0](t.tensor([5682]).to(device = 'cuda'))


In [266]:
t.nn.functional.cosine_similarity(model._model.lm_head.weight.data[298], designed_token_embd + is_token_embd, dim = -1)

tensor([0.0238], device='cuda:0', grad_fn=<SumBackward1>)

In [274]:
head_13_out_no_bias = (designed_token_embd + is_token_embd) @ model._model.ov[0,13]

In [278]:
U1, S1, V1 = t.linalg.svd(model._model.ov[0,13])
V1.shape

torch.Size([1024, 1024])

In [286]:
V1[1] @ top_3514_vecs[0][4]

tensor(-1.0001, device='cuda:0', grad_fn=<DotBackward0>)

In [287]:
own_tok = designed_token_embd @ model._model.ov[0,13]

In [288]:
t.nn.functional.cosine_similarity(own_tok, top_3514_vecs[0][4], dim = -1)


tensor([-0.1383], device='cuda:0', grad_fn=<SumBackward1>)

In [289]:
t.nn.functional.cosine_similarity(own_tok, dictionaries[all_submods[2]].w_dec.weight.data[:,3514], dim = -1)

tensor([-0.1244], device='cuda:0', grad_fn=<SumBackward1>)

In [299]:
obv = all_submods[2].o.weight.data @ all_submods[2].qkv.bias[-1024:]
print(t.nn.functional.cosine_similarity(obv, dictionaries[all_submods[2]].w_dec.weight.data[:,3514], dim = -1))

tensor(-0.1116, device='cuda:0', grad_fn=<SumBackward1>)


In [298]:
all_submods[2].o.bias

In [275]:
t.nn.functional.cosine_similarity(dictionaries[all_submods[2]].w_dec.weight.data[:,3514], head_13_out_no_bias, dim = -1)


tensor([-0.0608], device='cuda:0', grad_fn=<SumBackward1>)

In [267]:
t.nn.functional.cosine_similarity(model._model.lm_head.weight.data[298], dictionaries[all_submods[2]].w_dec.weight.data[:,3514], dim = -1)

tensor(0.0147, device='cuda:0')

In [260]:
model._model.ov[0]

tensor([[[-1.1617e-02,  3.7645e-03, -6.8461e-03,  ...,  6.7766e-03,
           6.5290e-03, -7.2235e-03],
         [-1.2089e-02, -9.5461e-03, -1.0197e-02,  ..., -2.8483e-04,
           1.2634e-02,  9.2898e-03],
         [ 1.7853e-02, -6.6746e-03, -1.3770e-02,  ...,  4.4956e-04,
          -7.9831e-03, -4.6051e-03],
         ...,
         [ 3.2585e-03, -7.6406e-03,  6.0812e-03,  ..., -8.3066e-03,
          -1.0763e-02, -3.5473e-03],
         [ 7.3684e-04,  1.6678e-03,  1.0509e-02,  ..., -1.9866e-02,
          -1.3328e-02, -2.2345e-03],
         [-4.0602e-03, -1.6663e-02,  1.2854e-03,  ...,  1.4990e-02,
          -8.2366e-03, -2.1329e-02]],

        [[-3.4120e-03, -8.5110e-03, -5.3302e-04,  ...,  1.0945e-02,
           3.9195e-03,  1.0821e-03],
         [ 1.4866e-02,  1.6028e-03, -8.7989e-03,  ...,  4.8909e-03,
           3.3940e-03, -8.3213e-04],
         [-1.8998e-02, -1.0391e-02, -2.3331e-02,  ..., -1.9615e-04,
          -1.3560e-02, -6.5725e-03],
         ...,
         [ 4.8359e-03,  8

In [261]:
qkv_weights = all_submods[2].qkv.weight.data  # Shape: (3*d_model, d_model)
d_model = qkv_weights.shape[1]

# Extract V (assuming it is the last d_model rows)
V_weight = qkv_weights[-d_model:, :]  # Shape: (d_model, d_model)

# Reshape V for multi-head if necessary
n_head = all_submods[2].config.n_head

# Rearrange V to [n_head * d_head, d_model]
Vs = einops.rearrange(V_weight,  '(n_head d_head) d_model -> n_head d_head d_model', n_head=n_head)

# Extract O from the attention module's output weights
O_weight = all_submods[2].o.weight.data  # Shape: (d_model, d_model)

# Rearrange O to [n_head, d_model, d_head]
Os = einops.rearrange(O_weight, 'd_model (n_head d_head) -> n_head d_model d_head',
                        n_head=n_head)

In [238]:
up_projection = t.transpose(V_no_head @ O_no_head, 1,0) @ dictionaries[all_submods[2]].w_dec.weight.data[:,3514]
t.norm(up_projection)

tensor(1.0857, device='cuda:0')

In [231]:
t.nn.functional.cosine_similarity(attn_out[0,-1], dictionaries[all_submods[2]].w_dec.weight.data[:,3514], dim = -1)

tensor(0.4806, device='cuda:0', grad_fn=<SumBackward1>)

In [230]:
t.nn.functional.cosine_similarity(attn_out[0,-1], top_3514_vecs[3][4], dim = -1)

tensor(0.0814, device='cuda:0', grad_fn=<SumBackward1>)

In [248]:
for i in top_3514_vecs:
    print(f"Index {i[0]} with singular value {i[2]:4f} has cosine similarity {t.nn.functional.cosine_similarity(attn_out[0,-1], i[4], dim = -1):4f} with the attn output and a dot product of {einops.einsum(attn_out[0,-1], i[4], 'd_model, d_model -> '):4f}")

Index (np.int64(13), np.int64(1)) with singular value 1.977533 has cosine similarity 0.192065 with the attn output and a dot product of 1.880131
Index (np.int64(15), np.int64(37)) with singular value 0.625338 has cosine similarity -0.055736 with the attn output and a dot product of -0.545597
Index (np.int64(9), np.int64(62)) with singular value 0.122674 has cosine similarity -0.087071 with the attn output and a dot product of -0.852341
Index (np.int64(12), np.int64(1)) with singular value 1.932628 has cosine similarity 0.081432 with the attn output and a dot product of 0.797134
Index (np.int64(11), np.int64(7)) with singular value 1.366497 has cosine similarity -0.133414 with the attn output and a dot product of -1.305994
Index (np.int64(11), np.int64(4)) with singular value 1.480347 has cosine similarity 0.089717 with the attn output and a dot product of 0.878239
Index (np.int64(0), np.int64(63)) with singular value 0.243602 has cosine similarity -0.103852 with the attn output and a d

In [243]:
rand_vec = t.randn(1024, device=device)
rand_vec = rand_vec / t.norm(rand_vec)

In [244]:
for i in top_3514_vecs:
    print(f"Index {i[0]} with singular value {i[2]} has cosine similarity {t.nn.functional.cosine_similarity(rand_vec, i[4], dim = -1)} with the random vector")

Index (np.int64(13), np.int64(1)) with singular value 1.9775331020355225 has cosine similarity -0.03936593234539032 with the random vector
Index (np.int64(15), np.int64(37)) with singular value 0.6253384351730347 has cosine similarity -0.008046098053455353 with the random vector
Index (np.int64(9), np.int64(62)) with singular value 0.12267374247312546 has cosine similarity -0.021047594025731087 with the random vector
Index (np.int64(12), np.int64(1)) with singular value 1.9326283931732178 has cosine similarity -0.0118798753246665 with the random vector
Index (np.int64(11), np.int64(7)) with singular value 1.3664965629577637 has cosine similarity -0.03248085826635361 with the random vector
Index (np.int64(11), np.int64(4)) with singular value 1.4803467988967896 has cosine similarity -0.04305107146501541 with the random vector
Index (np.int64(0), np.int64(63)) with singular value 0.24360191822052002 has cosine similarity -0.0019873883575201035 with the random vector
Index (np.int64(14), 

In [225]:
t.norm(attn_out[0,-1])

tensor(9.7890, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)

In [196]:
#Looks like heads 13 and 12 are important since there's high dot product and it's second highest singular value. 
#Let's look at the U vector and see which space it reads in from
all_submods[0] top_3514_vecs[0][3] #head 13

tensor([-0.0297,  0.0364,  0.0232,  ..., -0.0123,  0.0067,  0.0146],
       device='cuda:0')

In [221]:
dot_products = einops.einsum(all_submods[0].weight, all_submods[0].weight, 'd_vocab d_model, d_vocab d_model -> d_vocab')
percentiles = t.tensor([
    t.quantile(dot_products, 0.1),
    t.quantile(dot_products, 0.25), 
    t.quantile(dot_products, 0.5),
    t.quantile(dot_products, 0.75),
    t.quantile(dot_products, 0.9),
    t.quantile(dot_products, 0.99)
])
print(f"10th percentile: {percentiles[0]:.2f}")
print(f"25th percentile: {percentiles[1]:.2f}")
print(f"50th percentile: {percentiles[2]:.2f}")
print(f"75th percentile: {percentiles[3]:.2f}") 
print(f"90th percentile: {percentiles[4]:.2f}")
print(f"99th percentile: {percentiles[5]:.2f}")
dot_products

10th percentile: 120.77
25th percentile: 132.34
50th percentile: 141.55
75th percentile: 148.77
90th percentile: 155.27
99th percentile: 167.96


tensor([144.7780,  78.3940, 140.8091,  ..., 147.2960, 137.4702, 143.4475],
       device='cuda:0', grad_fn=<ViewBackward0>)

In [239]:
vocab_dot_products_13 = einops.einsum(all_submods[0].weight, up_projection, 'd_vocab d_model, d_model -> d_vocab')

# Decode the indices and print in a table
token_indices = t.topk(vocab_dot_products_13, k=20).indices
token_values = t.topk(vocab_dot_products_13, k=20).values

# Create a table with token indices and their corresponding values
print("Index\tValue\tToken")
print("-" * 40)
for idx, val in zip(token_indices, token_values):
    print(f"{idx.item()}\t{val.item():.4f} {model.tokenizer.decode([idx.item()])}")



Index	Value	Token
----------------------------------------
1	2.0462 <s>
7914	1.6647 flag
11094	1.5519 ulator
28702	1.4964 FIFA
25771	1.4960 mania
25811	1.4917 ský
15715	1.4910 compare
3393	1.4857 Pr
25712	1.4655 inclu
20877	1.4581 volta
21050	1.4356 Malays
23548	1.4201 miner
21017	1.4116 hö
17837	1.4045 grim
18975	1.3905 projection
5333	1.3812 сле
30276	1.3704 父
31136	1.3326 佛
8143	1.3181 arse
19675	1.3056 Publish


In [195]:
#random vector svd
rand_vec = t.randn(1024, device=device)
rand_vec = rand_vec / t.norm(rand_vec)
top_rand_vec_vecs = get_top_n_svd_components(all_submods[2], rand_vec, 10)
top_rand_vec_vecs

  loc = np.unravel_index(top_indices[i].cpu(), projections.shape)


[((np.int64(3), np.int64(0)),
  0.08837911486625671,
  3.5056588649749756,
  tensor([ 0.0187,  0.0529,  0.0045,  ...,  0.0666, -0.0359,  0.0732],
         device='cuda:0'),
  tensor([ 0.0022, -0.0251,  0.0245,  ..., -0.0024, -0.0252,  0.0026],
         device='cuda:0')),
 ((np.int64(6), np.int64(54)),
  -0.08759850263595581,
  0.28654077649116516,
  tensor([-0.0402,  0.0298,  0.0779,  ...,  0.0210,  0.0125, -0.0350],
         device='cuda:0'),
  tensor([ 0.0266,  0.0147, -0.0592,  ...,  0.0133, -0.0505,  0.0121],
         device='cuda:0')),
 ((np.int64(12), np.int64(0)),
  -0.08639997243881226,
  2.1591038703918457,
  tensor([ 0.0039, -0.0041,  0.0229,  ...,  0.0118,  0.0566,  0.0127],
         device='cuda:0'),
  tensor([-0.0088,  0.0123, -0.0136,  ..., -0.0059, -0.0344, -0.0054],
         device='cuda:0')),
 ((np.int64(4), np.int64(5)),
  -0.08568445593118668,
  1.6612879037857056,
  tensor([-0.0394, -0.0330, -0.0050,  ..., -0.0161,  0.0721,  0.0026],
         device='cuda:0'),
  ten

In [10]:
t.cuda.empty_cache()


clean_inputs = t.cat([e['clean_prefix'] for e in batches[0]], dim=0).to(device)
clean_answer_idxs = t.tensor([e['clean_answer'] for e in batches[0]], dtype=t.long, device=device)

patch_inputs = t.cat([e['patch_prefix'] for e in batches[0]], dim=0).to(device)
patch_answer_idxs = t.tensor([e['patch_answer'] for e in batches[0]], dtype=t.long, device=device)
def metric_fn(model):
    return (
        t.gather(model.lm_head.output[:,-1,:], dim=-1, index=patch_answer_idxs.view(-1, 1)).squeeze(-1) - \
        t.gather(model.lm_head.output[:,-1,:], dim=-1, index=clean_answer_idxs.view(-1, 1)).squeeze(-1)
    ) #We're only looking at the logit difference between two answers which is a very limited subset of the model's behavior aya. 

In [11]:
patch_inputs

tensor([[    1,   415,  4531],
        [    1,   415,  3282],
        [    1,   415,  4649],
        [    1,   415, 13500]], device='cuda:0')

In [12]:
clean_inputs

tensor([[   1,  415, 8066],
        [   1,  415, 1832],
        [   1,  415, 6246],
        [   1,  415, 6676]], device='cuda:0')

In [13]:
%%time
effects, deltas, grads, total_effect = patching_effect(
        clean_inputs,
        patch_inputs,
        model,
        all_submods,
        dictionaries,
        metric_fn,
        metric_kwargs=dict(),
        method='ig' # get better approximations for early layers by using ig
    )


You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.



Integrated Gradient estimation


Initial trace


Patching part

CPU times: user 16.3 s, sys: 424 ms, total: 16.8 s
Wall time: 16.8 s


In [16]:
def unflatten(tensor): # will break if dictionaries vary in size between layers
        b, s, f = effects[resids[0]].act.shape
        unflattened = rearrange(tensor, '(b s x) -> b s x', b=b, s=s)
        return SparseAct(act=unflattened[...,:f], res=unflattened[...,f:])
    
features_by_submod = {
    submod : (effects[submod].to_tensor().flatten().abs() > 0.1).nonzero().flatten().tolist() for submod in all_submods
}

In [27]:
effects[all_submods[0]].act.nonzero().shape

torch.Size([40, 3])

In [31]:
len(features_by_submod[all_submods[0]])

31

In [15]:
effects.keys()

dict_keys([Embedding(32000, 1024), MLP(
  (w): Bilinear(
    in_features=1024, out_features=8192, bias=True
    (gate): Identity()
  )
  (p): Linear(in_features=4096, out_features=1024, bias=True)
), Attention(
  (rotary): Rotary()
  (qkv): Linear(in_features=1024, out_features=3072, bias=True)
  (o): Linear(in_features=1024, out_features=1024, bias=False)
  (softmax): Softmax(dim=-1)
), Layer(
  (attn): Attention(
    (rotary): Rotary()
    (qkv): Linear(in_features=1024, out_features=3072, bias=True)
    (o): Linear(in_features=1024, out_features=1024, bias=False)
    (softmax): Softmax(dim=-1)
  )
  (mlp): MLP(
    (w): Bilinear(
      in_features=1024, out_features=8192, bias=True
      (gate): Identity()
    )
    (p): Linear(in_features=4096, out_features=1024, bias=True)
  )
  (n1): Norm(
    (norm): Identity()
  )
  (n2): Norm(
    (norm): Identity()
  )
), MLP(
  (w): Bilinear(
    in_features=1024, out_features=8192, bias=True
    (gate): Identity()
  )
  (p): Linear(in_featu

In [18]:
t.save(effects[all_submods[9]].act, 'ef9.pt')

In [19]:
t.save(grads[all_submods[9]].act,'ef9.pt')

In [None]:
boing = grads[all_submods[9]] @ deltas[all_submods[9]]

In [48]:
boing.resc

tensor([[[ 0.0000],
         [ 0.0000],
         [-1.3478]],

        [[ 0.0000],
         [ 0.0000],
         [ 1.4131]],

        [[ 0.0000],
         [ 0.0000],
         [-0.0593]],

        [[ 0.0000],
         [ 0.0000],
         [-0.3202]]], device='cuda:0')

In [43]:
boing2 = grads[all_submods[9]].__matmul__(deltas[all_submods[9]])
boing2.res

In [47]:
t.min(boing2.act == boing.act)

tensor(True, device='cuda:0')

In [40]:
grads[all_submods[9]].act.shape

torch.Size([4, 3, 8192])

In [41]:
deltas[all_submods[9]].act.shape

torch.Size([4, 3, 8192])

In [42]:
boing.act.shape

torch.Size([4, 3, 8192])

In [23]:
total_effect

tensor([ 6.7520,  6.3497,  9.2779, 10.3250], device='cuda:0')

In [9]:
effects.keys().__len__()

13

In [18]:
effects[all_submods[9]].act.shape

torch.Size([4, 3, 8192])

In [51]:
effects[all_submods[11]].resc

tensor([[[ 0.0000],
         [ 0.0000],
         [-0.6774]],

        [[ 0.0000],
         [ 0.0000],
         [-0.1120]],

        [[ 0.0000],
         [ 0.0000],
         [-0.0361]],

        [[ 0.0000],
         [ 0.0000],
         [-0.1439]]], device='cuda:0')

In [19]:
effects[all_submods[10]].act.shape

torch.Size([4, 3, 8192])

In [20]:
model._model.config

Config {
  "_attn_implementation_autoset": true,
  "_name_or_path": "tdooms/fw-nano",
  "architectures": [
    "Transformer"
  ],
  "attention2": false,
  "bias": true,
  "bilinear": true,
  "d_hidden": 4096,
  "d_model": 1024,
  "gate": null,
  "n_ctx": 512,
  "n_head": 16,
  "n_layer": 4,
  "normalization": false,
  "repo": "tdooms/fw-nano",
  "scale_attn": true,
  "tokenizer": "mistral",
  "torch_dtype": "float32",
  "transformers_version": "4.47.1"
}

In [4]:
cir = t.load('/root/bilinear-feature-circuits/circuits/simple_train_dict10_node0.2_edge0.02_n20_aggsum.pt')

  cir = t.load('/root/bilinear-feature-circuits/circuits/simple_train_dict10_node0.2_edge0.02_n20_aggsum.pt')


In [10]:
cir['nodes']

{'embed': SparseAct(act=tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), resc=tensor([1.2417], device='cuda:0')),
 'attn_0': SparseAct(act=tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), resc=tensor([-0.2630], device='cuda:0')),
 'mlp_0': SparseAct(act=tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), resc=tensor([0.6300], device='cuda:0')),
 'resid_0': SparseAct(act=tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), resc=tensor([0.5510], device='cuda:0')),
 'attn_1': SparseAct(act=tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), resc=tensor([-0.1944], device='cuda:0')),
 'mlp_1': SparseAct(act=tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), resc=tensor([0.7722], device='cuda:0')),
 'resid_1': SparseAct(act=tensor([0.0035, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000], device='cuda:0'), resc=tensor([0.4621], device='cuda:0')),
 'attn_2': SparseAct(act=tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), resc=tensor([0.1127], device='cu

In [11]:
cir['edges']

{'resid_3': {'y': tensor(indices=tensor([[  40,   54,   74,   94,  154,  197,  211,  323,  398,
                           441,  475,  494,  502,  514,  615,  623,  677,  735,
                           792,  839,  866,  891,  931,  987, 1010, 1020, 1062,
                          1108, 1121, 1138, 1162, 1168, 1171, 1191, 1203, 1224,
                          1289, 1387, 1460, 1483, 1488, 1507, 1512, 1514, 1545,
                          1557, 1581, 1588, 1629, 1635, 1647, 1657, 1695, 1727,
                          1786, 1812, 1831, 1842, 1849, 1886, 1909, 1944, 1949,
                          1963, 1975, 2132, 2160, 2192, 2195, 2211, 2215, 2227,
                          2286, 2342, 2373, 2378, 2480, 2489, 2506, 2565, 2611,
                          2616, 2640, 2727, 2730, 2803, 2821, 2984, 3000, 3034,
                          3193, 3204, 3211, 3241, 3272, 3281, 3299, 3306, 3307,
                          3317, 3409, 3512, 3533, 3589, 3642, 3723, 3745, 3821,
                        