In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
from sae_lens import SparseAutoencoder, ActivationsStore

from steering.eval_utils import evaluate_completions
from steering.utils import text_to_sae_feats, top_activations, normalise_decoder, get_activation_steering
from steering.patch import generate, get_scores_and_losses

from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

# Turn autograd off for the rest of the session; the trailing ';' stops
# Jupyter from echoing the returned set_grad_enabled object.
torch.set_grad_enabled(mode=False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7fa9e41eb4c0>

In [13]:
# Prefer the GPU whenever one is visible to torch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# NOTE(review): `device` (computed above) is not used here -- the model is
# loaded on CPU. Presumably this keeps it on the same device as `sae12` for the
# probe cells below; confirm before loading straight onto `device`.
model = HookedTransformer.from_pretrained(
    "gemma-2b",
    device="cpu",
)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [11]:
# Residual-stream hook after block 12, and the SAE loaded for it from the
# "gemma-2b-res-jb" release.
hp12 = "blocks.12.hook_resid_post"
sae12 = SparseAutoencoder.from_pretrained(
    "gemma-2b-res-jb",
    hp12,
)
# The return value is discarded, so this presumably rescales `sae12` in place
# (without touching input scaling) -- TODO confirm normalise_decoder semantics.
normalise_decoder(sae12, scale_input=False)

In [14]:
top_activations(text_to_sae_feats(model, sae12, hook_point=hp12, text="Anger"))
# top anger is 12312 with act ~61.24 (row 1 of the output below). Row 0 is
# identical to the "Hate" probe's row 0 -- presumably the BOS position, confirm.

(tensor([[[119.2641, 112.2287, 110.4388, 103.3638,  65.7501,  48.8222,  38.8005,
            31.0149,  23.6835,  22.8484],
          [ 61.2354,  15.6910,  15.5455,  15.1585,  14.0710,   8.7777,   8.5807,
             7.8577,   7.7233,   7.5660]]]),
 tensor([[[ 8731, 10841,  5233,  3048,  1645, 11597,  6099,  7650,  6802,  8618],
          [12312, 14039, 10323,  8521,  3081,   883,  8315,  4383,  5503, 11778]]]))

In [15]:
top_activations(text_to_sae_feats(model, sae12, hook_point=hp12, text="Hate"))
# top hate is 4823 with act ~71.46 (row 1 of the output below). Row 0 is
# identical to the "Anger" probe's row 0 -- presumably the BOS position, confirm.

(tensor([[[119.2641, 112.2287, 110.4388, 103.3638,  65.7501,  48.8222,  38.8005,
            31.0149,  23.6835,  22.8484],
          [ 71.4620,  13.4814,  11.1678,  10.1321,   8.4601,   7.9648,   7.3423,
             6.8351,   6.1185,   5.8158]]]),
 tensor([[[ 8731, 10841,  5233,  3048,  1645, 11597,  6099,  7650,  6802,  8618],
          [ 4823, 10323,  8315,  5611,   883,  3081,  1997, 14039, 13485,  2061]]]))

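## Feature visualisation

Build sae_vis feature-centric dashboards for the top features found for "Anger" and "Hate" above.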

In [6]:
# Token source for the dashboards, configured from the SAE's own cfg.
activation_store = ActivationsStore.from_config(
    model,
    sae12.cfg,
)

def get_tokens(
    activation_store,
    n_batches_to_sample_from: int = 2**14,
    n_prompts_to_select: "int | None" = None,
):
    """Sample token batches from ``activation_store`` and return them shuffled.

    Args:
        activation_store: object with a ``get_batch_tokens()`` method returning a
            tensor whose dim 0 indexes prompts (here an ``ActivationsStore``;
            presumably shape (prompts, context_size) -- confirm).
        n_batches_to_sample_from: how many times to call ``get_batch_tokens()``.
            Must be >= 1.
        n_prompts_to_select: if not None, return only this many rows, chosen
            uniformly at random. ``None`` (default) returns every sampled row.

    Returns:
        The sampled rows concatenated along dim 0, in a random order.

    Raises:
        ValueError: if ``n_batches_to_sample_from < 1`` (``torch.cat`` cannot
            concatenate an empty list) or ``n_prompts_to_select`` is negative.
    """
    if n_batches_to_sample_from < 1:
        raise ValueError(
            f"n_batches_to_sample_from must be >= 1, got {n_batches_to_sample_from}"
        )
    if n_prompts_to_select is not None and n_prompts_to_select < 0:
        raise ValueError(
            f"n_prompts_to_select must be >= 0 or None, got {n_prompts_to_select}"
        )

    all_tokens_list = [
        activation_store.get_batch_tokens()
        for _ in tqdm(range(n_batches_to_sample_from), desc="Sampling token batches")
    ]
    all_tokens = torch.cat(all_tokens_list, dim=0)
    # Drop the per-batch tensors before the gather below; this may lower peak memory.
    del all_tokens_list

    # One global permutation already yields a uniformly random order, so no
    # per-batch shuffle is needed. Truncating the permutation (rather than the
    # permuted tensor) avoids building a full shuffled copy when only a subset
    # is requested.
    perm = torch.randperm(all_tokens.shape[0])
    if n_prompts_to_select is not None:
        perm = perm[:n_prompts_to_select]
    return all_tokens[perm]

# NOTE(review): this keeps every sampled row (2**14 batches by default) resident
# on `device`; when that is CUDA it competes with the model for GPU memory, and
# the later model.to(device) cell hit CUDA OOM. Consider sub-sampling or keeping
# the tokens on CPU -- confirm sae_vis accepts CPU tokens first.
all_tokens = get_tokens(activation_store=activation_store).to(device=device)


100%|██████████| 16384/16384 [13:16<00:00, 20.56it/s] 


In [16]:

# Put the model and the SAE on the same device before building sae_vis data.
# NOTE(review): in the saved run this raised CUDA OOM while moving the model.
model, sae12 = model.to(device), sae12.to(device)

Moving model to device:  cuda


OutOfMemoryError: CUDA out of memory. Tried to allocate 128.00 MiB. GPU 

In [8]:

# NOTE(review): not referenced by the cells shown here.
n_features = sae12.cfg.d_sae

# Top-2 features at the word position for "Anger" (12312, 14039) and for
# "Hate" (4823, 10323), read off the probe outputs above.
# NOTE(review): the `_gpt` suffixes are leftovers -- the model is gemma-2b.
test_feature_idx_gpt = [12312, 14039, 4823, 10323]
bs = 2

feature_vis_config_gpt = SaeVisConfig(
    verbose=True,
    hook_point=hp12,
    minibatch_size_tokens=128,
    batch_size=bs,
    features=test_feature_idx_gpt,
)

# inference_mode on top of the global no-grad setting from the import cell.
with torch.inference_mode():
    sae_vis_data_gpt = SaeVisData.create(
        cfg=feature_vis_config_gpt,
        tokens=all_tokens,  # type: ignore
        model=model,
        encoder=sae12,
    )

Forward passes to cache data for vis:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/4 [00:00<?, ?it/s]

In [9]:
# Write one feature-centric HTML dashboard per selected feature into vis_dir.
vis_dir = "feature_vis"
# exist_ok avoids the exists()-then-makedirs() check-then-create race.
os.makedirs(vis_dir, exist_ok=True)

for idx, feature in enumerate(test_feature_idx_gpt):
    # NOTE(review): assumes feature_stats.max is ordered like
    # test_feature_idx_gpt (the cfg `features` list) -- confirm against sae_vis.
    if sae_vis_data_gpt.feature_stats.max[idx] == 0:
        # Report the skip so a missing dashboard is explained.
        print(f"Skipped feature {feature}: max activation is 0.")
        continue
    filename = os.path.join(vis_dir, f"{feature}_feature_vis.html")
    try:
        sae_vis_data_gpt.save_feature_centric_vis(filename, feature)
    except ZeroDivisionError:
        print(f"Skipped feature {feature} due to ZeroDivisionError.")


Saving feature-centric vis:   0%|          | 0/4 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/4 [00:00<?, ?it/s]