In [7]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm
from collections import defaultdict

import einops

from sae_lens import SAE
# from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
# from sae_lens import SparseAutoencoder, ActivationsStore

from steering.evals_utils import evaluate_completions, multi_criterion_evaluation
from steering.utils import normalise_decoder, text_to_sae_feats, top_activations
from steering.patch import generate, scores_2d, patch_resid

# from sae_vis.data_config_classes import SaeVisConfig
# from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd

# Switch autograd off for the whole kernel session: none of the visible cells
# call backward(), so tracking gradients would only cost memory.
# As the last expression of the cell, this also echoes the returned object's repr.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f2508ca5ed0>

In [2]:
# Run on the GPU when torch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the pretrained "gemma-2b" weights into a HookedTransformer on that device.
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
# Hook point name; it is also passed as the SAE id below
# (an SAE id won't always be a hook point).
hp6 = "blocks.6.hook_resid_post"

# Fetch the pretrained SAE on the CPU first, then move it to the working device.
# Only the first element of the returned triple is kept.
# See other available releases in sae_lens/pretrained_saes.yaml.
sae6, _, _ = SAE.from_pretrained(release="gemma-2b-res-jb", sae_id=hp6, device="cpu")
sae6 = sae6.to(device)

# NOTE(review): normalise_decoder is a project helper from steering.utils;
# presumably it rescales the SAE decoder in place -- confirm against its source.
normalise_decoder(sae6)

In [4]:
# Load the four precomputed lookup tables used by the sibling searches below.
# NOTE(review): presumably the *_is files hold indices and the *_vs files hold
# the matching values, row-aligned -- confirm against the code that wrote them.
forward_is = torch.load("top_is.pt")
forward_vs = torch.load("top_vs.pt")
backward_is = torch.load("reverse_is.pt")
backward_vs = torch.load("reverse_vs.pt")

# A forward row made up entirely of zeros is replaced by a row of -1, the
# sentinel that the sibling searches skip; every other row is kept as it is.
sanitised_rows = []
for row in forward_is:
    sanitised_rows.append(-torch.ones_like(row) if (row == 0).all() else row)
forward_is = torch.stack(sanitised_rows)
del sanitised_rows

In [10]:
def find_act_siblings(act_id):
    """Score the features reachable from ``act_id`` by a forward-then-backward hop.

    The first hop walks the entries of ``forward_is[act_id]``; the second hop
    walks ``backward_is[entry]`` for each of them. Every index reached on the
    second hop accumulates ``forward_vs`` value of the first hop times the
    ``backward_vs`` value of the second hop. Entries equal to -1 are skipped
    on both hops.

    (Reading ``forward_*`` as "what a feature acts on" and ``backward_*`` as
    "what acts on a feature" is based on the names only -- TODO confirm.)

    Args:
        act_id: Row index into the module-level ``forward_is`` / ``forward_vs``.

    Returns:
        A list of ``(index, score)`` tuples ordered by descending score.
        ``index`` is a Python int; ``score`` is the accumulated product, which
        is a 0-d tensor when the value tables are tensors. The list is empty
        when nothing is reached.
    """
    totals = defaultdict(float)

    for target, weight in zip(forward_is[act_id], forward_vs[act_id]):
        if target == -1:
            continue

        # Second hop: everything listed for this target in the backward tables.
        for source, source_weight in zip(backward_is[target], backward_vs[target]):
            if source != -1:
                totals[source.item()] += weight * source_weight.item()

    return sorted(totals.items(), key=lambda pair: pair[1], reverse=True)


# Example lookup for index 10138, which the author annotates as "london"
# (presumably a London-related SAE feature -- TODO confirm).
find_act_siblings(10138) # london
    

[(10138, tensor(1.1502)),
 (4593, tensor(1.0046)),
 (10655, tensor(0.9132)),
 (10148, tensor(0.9097)),
 (1852, tensor(0.8904)),
 (15582, tensor(0.8652)),
 (13568, tensor(0.8604)),
 (8305, tensor(0.8113)),
 (1295, tensor(0.8106)),
 (8370, tensor(0.8097)),
 (10286, tensor(0.7876)),
 (2554, tensor(0.7197)),
 (5932, tensor(0.6843)),
 (16247, tensor(0.6818)),
 (3245, tensor(0.6674)),
 (1822, tensor(0.6376)),
 (4479, tensor(0.6095)),
 (13981, tensor(0.5894)),
 (9043, tensor(0.5766)),
 (10567, tensor(0.5431)),
 (10702, tensor(0.5418)),
 (6318, tensor(0.5360)),
 (1178, tensor(0.5339)),
 (5823, tensor(0.5169)),
 (12382, tensor(0.4915)),
 (7339, tensor(0.4712)),
 (8005, tensor(0.4693)),
 (11061, tensor(0.4691)),
 (11786, tensor(0.4606)),
 (13694, tensor(0.4365)),
 (15451, tensor(0.4357)),
 (9235, tensor(0.4338)),
 (4152, tensor(0.4184)),
 (15548, tensor(0.4096)),
 (543, tensor(0.4065)),
 (8516, tensor(0.4057)),
 (68, tensor(0.3763)),
 (536, tensor(0.3748)),
 (13291, tensor(0.3662)),
 (4616, tens

In [11]:
def find_effect_siblings(effect_id):
    """Score the features reachable from ``effect_id`` by a backward-then-forward hop.

    The first hop walks the entries of ``backward_is[effect_id]``; the second
    hop walks ``forward_is[entry]`` for each of them. Every index reached on
    the second hop accumulates the ``backward_vs`` value of the first hop
    times the ``forward_vs`` value of the second hop. Entries equal to -1 are
    skipped on both hops.

    (Reading ``backward_*`` as "what acts on a feature" and ``forward_*`` as
    "what a feature acts on" is based on the names only -- TODO confirm.)

    Args:
        effect_id: Row index into the module-level ``backward_is`` / ``backward_vs``.

    Returns:
        A list of ``(index, score)`` tuples ordered by descending score.
        ``index`` is a Python int; ``score`` is the accumulated product, which
        is a 0-d tensor when the value tables are tensors. The list is empty
        when nothing is reached.
    """
    totals = defaultdict(float)

    for parent, weight in zip(backward_is[effect_id], backward_vs[effect_id]):
        if parent == -1:
            continue

        # Second hop: everything listed for this parent in the forward tables.
        for child, child_weight in zip(forward_is[parent], forward_vs[parent]):
            if child != -1:
                totals[child.item()] += weight * child_weight.item()

    return sorted(totals.items(), key=lambda pair: pair[1], reverse=True)

# Same example index as the forward search (10138, annotated "london" by the
# author), this time starting from the backward tables.
find_effect_siblings(10138) # london

[(10138, tensor(2.1266)),
 (10148, tensor(2.0577)),
 (8370, tensor(1.8711)),
 (15505, tensor(1.4334)),
 (4152, tensor(1.0667)),
 (4343, tensor(0.8933)),
 (12775, tensor(0.8817)),
 (15231, tensor(0.8519)),
 (14640, tensor(0.8457)),
 (10655, tensor(0.8100)),
 (12090, tensor(0.8000)),
 (3174, tensor(0.6445)),
 (13841, tensor(0.4994)),
 (14112, tensor(0.4897)),
 (4235, tensor(0.4668)),
 (4323, tensor(0.4582)),
 (11057, tensor(0.4471)),
 (11832, tensor(0.4282)),
 (8706, tensor(0.4260)),
 (13282, tensor(0.4253)),
 (11444, tensor(0.4093)),
 (5096, tensor(0.3891)),
 (10207, tensor(0.3793)),
 (9104, tensor(0.3756)),
 (4542, tensor(0.3510)),
 (918, tensor(0.3440)),
 (5168, tensor(0.3272)),
 (2507, tensor(0.3252)),
 (4989, tensor(0.3006)),
 (3985, tensor(0.2933)),
 (7068, tensor(0.2840)),
 (7329, tensor(0.2796)),
 (1889, tensor(0.2507)),
 (13694, tensor(0.2466)),
 (2288, tensor(0.2307)),
 (9320, tensor(0.2277)),
 (11014, tensor(0.2217)),
 (10905, tensor(0.2217)),
 (11376, tensor(0.2113)),
 (13568