In [None]:
import torch
import numpy as np
from sae_lens import SAE, HookedSAETransformer
import pandas as pd
#from sae_lens.toolkit.neuronpedia_integration import get_feature_from_neuronpedia

In [None]:
# Load GPT-2 small wrapped in HookedSAETransformer so SAEs can be spliced into the forward pass.
model = HookedSAETransformer.from_pretrained("gpt2-small", device="cpu")

# Load the pretrained residual-stream SAE for block 7.
# Fix: the loader class is `SAE` (imported at the top of the notebook); the previous
# `FreivaldsVerificationSAE` name is not defined anywhere and raised NameError on a fresh kernel.
#
# The cfg dict is returned alongside the SAE since it may contain useful information for
# analysing the SAE (e.g. instantiating an activation store). Note that it is not the same as
# the SAE's own config: it is whatever was stored in the HF repo, from which the SAE config
# can be extracted. The feature sparsities stored on HF are returned as well, for convenience.
# NOTE(review): the 3-tuple return is assumed from the unpacking below; confirm it against the
# installed sae_lens version.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",  # <- Release name
    sae_id="blocks.7.hook_resid_pre",  # <- SAE id (not always a hook point!)
    device="cpu",
)

In [None]:
# Run the prompt through the model with the SAE attached and cache the activations.
prompt = "Today is Sunday, tomorrow is"
tokens = model.to_tokens(prompt, prepend_bos=True)

# Feed the explicit token ids rather than the raw string, so the `prepend_bos` choice made
# above is the one the model actually sees (previously `tokens` was computed but never used).
_, cache = model.run_with_cache_with_saes(tokens, saes=[sae])

# Feature activations recorded under the attached SAE's `hook_sae_acts_post` cache entry.
# The prefix is the same string as the `sae_id` used when the SAE was loaded.
sae_in_encoded = cache["blocks.7.hook_resid_pre.hook_sae_acts_post"]
feature_acts = sae_in_encoded.squeeze(0)  # expected [seq_len, d_sae]

# SAE feature indices whose activation on this prompt we want to check.
feature_list = [2592, 4445, 4663, 4733, 6531, 8179, 9566, 20927, 24185]

# Keep only the columns of interest.
feature_acts_sub = feature_acts[:, feature_list]  # expected [seq_len, len(feature_list)]

def fast_verify(feature_acts_sub, num_trials=10, tol=1e-6, generator=None, verbose=False):
    """Randomised (Freivalds-style) check for whether any selected feature is active.

    Each trial draws a random 0/1 column vector ``r`` over the feature axis and
    computes ``feature_acts_sub @ r``. If the L2 norm of that product exceeds
    ``tol`` the function reports an activation and stops early.

    The test is one-sided: a ``True`` result means some product was non-zero,
    whereas ``False`` only means no trial produced a product above ``tol``.
    A single trial misses a lone active feature whenever ``r`` happens to hold
    0 at that feature's position (probability 1/2), so the miss probability for
    that case is ``2 ** -num_trials``.

    Args:
        feature_acts_sub: 2-D floating-point tensor of activations, indexed as
            ``[seq_len, num_features]``.
        num_trials: Number of random projections to attempt. ``0`` always
            yields ``False``.
        tol: Threshold on the L2 norm of the projected activations.
        generator: Optional ``torch.Generator`` used for the random draws so a
            run can be reproduced; ``None`` uses torch's global RNG.
        verbose: When true, print the random vector and the product for every
            trial (debug output that used to be printed unconditionally).

    Returns:
        bool: ``True`` as soon as one trial detects activation, else ``False``.

    Raises:
        ValueError: If ``feature_acts_sub`` is not 2-dimensional.
    """
    if feature_acts_sub.dim() != 2:
        raise ValueError(
            f"feature_acts_sub must be 2-D [seq_len, num_features], got shape {tuple(feature_acts_sub.shape)}"
        )
    num_features = feature_acts_sub.shape[1]

    for _ in range(num_trials):
        # Random binary vector (num_features x 1), matching the input's dtype and device
        # so the matrix product below needs no conversion.
        r = torch.randint(
            0,
            2,
            (num_features, 1),
            generator=generator,
            dtype=feature_acts_sub.dtype,
            device=feature_acts_sub.device,
        )
        # Matrix-vector product: shape [seq_len, 1].
        act_prod = feature_acts_sub @ r
        if verbose:
            print(r)
            print(act_prod)
        # A non-zero product means at least one sampled feature fired on some token.
        if torch.norm(act_prod).item() > tol:
            return True
    return False

# Run the randomised check on the selected feature columns and report the verdict.
activated = fast_verify(feature_acts_sub)
print(
    "At least one feature in the list is activated by the prompt!!!"
    if activated
    else "None of the features in the list are activated by the prompt."
)



In [None]:
# Sanity check: show the decoder shape, the largest requested feature index and the
# SAE's feature count side by side so they can be compared by eye.
for label, value in (
    ("W_dec shape:", sae.W_dec.shape),
    ("Max feature index:", max(feature_list)),
    ("d_sae (number of features):", sae.cfg.d_sae),
):
    print(label, value)