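### This notebook loads the Gemma-2-2B crosscoder by Julian Minder and Clément Dumas and uploads a subset of its "IT only" / "Base only" features to Neuronpedia as steering vectors.
### It also uploads each feature's max-activating examples as the activations for that vector.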

In [32]:
from huggingface_hub import hf_hub_download
import pandas as pd

# How many features to upload, and the layer they are attached to on Neuronpedia.
# NOTE(review): the Hugging Face repo ids in this notebook encode "l13"; keep
# them in sync with LAYER_NUM if it is changed.
NUM_FEATURES_TO_UPLOAD = 100
LAYER_NUM = 13

print("downloading max activating examples from huggingface")
repo_id = "Butanium/max-activating-examples-gemma-2-2b-l13-mu4.1e-02-lr1e-04"
df_path = hf_hub_download(repo_id=repo_id, filename="feature_df.csv", repo_type="dataset")

df = pd.read_csv(df_path, index_col=0)
# Keep non-dead features tagged as specific to one of the two models.
# `.eq(False)` is the same comparison as `== False` (rows with a missing
# "dead" value are excluded), without the E712 lint.
available_features = df[df["tag"].isin(["IT only", "Base only"]) & df["dead"].eq(False)]

# Only the first NUM_FEATURES_TO_UPLOAD matching rows (in CSV order) are uploaded.
filtered_features = available_features.head(NUM_FEATURES_TO_UPLOAD)
available_features_idx = filtered_features.index.tolist()
print(len(available_features_idx))


downloading max activating examples from huggingface
100
[55, 60, 78, 82, 95, 112, 119, 130, 140, 221, 222, 231, 239, 263, 267, 286, 291, 312, 335, 377, 378, 383, 388, 418, 465, 478, 500, 548, 550, 551, 561, 591, 593, 603, 610, 621, 640, 652, 653, 692, 707, 711, 735, 743, 746, 751, 771, 785, 792, 793, 816, 834, 878, 883, 901, 923, 939, 946, 981, 1019, 1045, 1051, 1056, 1064, 1094, 1105, 1116, 1143, 1157, 1164, 1170, 1187, 1195, 1203, 1248, 1259, 1271, 1277, 1330, 1334, 1353, 1395, 1437, 1462, 1493, 1519, 1523, 1547, 1548, 1570, 1629, 1638, 1649, 1655, 1663, 1665, 1697, 1724, 1742, 1746]
True


In [16]:
# Build an offline dashboard over the max-activating-examples database.
# The dashboard is given the tokenizer of the IT model.
# NOTE(review): only `gemma_2_it.tokenizer` is used in the visible cells, yet the
# full model is loaded onto CUDA -- loading just the tokenizer may suffice.
from tiny_dashboard import OfflineFeatureCentricDashboard
from nnterp import load_model
import gc

gemma_2_it = load_model("google/gemma-2-2b-it", device_map="cuda")
db_path = hf_hub_download(
    repo_id=repo_id,
    filename="chat_base_examples_20.db",
    repo_type="dataset",
)
gc.collect()
it_tokenizer = gemma_2_it.tokenizer
db = OfflineFeatureCentricDashboard.from_db(db_path, it_tokenizer, column_name="entries")


In [None]:
# Load the pretrained crosscoder from the Hugging Face Hub.
from neuronpedia.butanium_dictionary_learning.dictionary_learning import CrossCoder

CROSSCODER_REPO_ID = "Butanium/gemma-2-2b-crosscoder-l13-mu4.1e-02-lr1e-04"

print("getting the crosscoders")
crosscoder = CrossCoder.from_pretrained(CROSSCODER_REPO_ID, from_hub=True)
print("got the crosscoders")

In [None]:
from neuronpedia.np_vector import NPVector
from neuronpedia.requests.activation_request import Activation

# For each selected feature: upload its crosscoder weight column as a Neuronpedia
# vector, then attach that feature's max-activating examples to the new vector.
# Re-running this cell uploads every vector again.
created_np_vectors = []
num_features = len(available_features_idx)
for position, feature_idx in enumerate(available_features_idx, start=1):
    print("Uploading vector for feature", feature_idx)
    print("Progress:", position, "/", num_features)

    # Column `feature_idx` of slice [1] of the encoder weight.
    # NOTE(review): assumes slice 1 is the IT (chat) model, matching model_id
    # below, and that the encoder (rather than decoder) direction is the
    # intended steering vector -- confirm against the CrossCoder weight layout.
    crosscoder_weight = crosscoder.encoder.weight[1][:, feature_idx].detach().tolist()

    np_vector = NPVector.new(
        label=f"Crosscoder L{LAYER_NUM} {feature_idx} Dumas/Minder",
        model_id="gemma-2-2b-it",
        layer_num=LAYER_NUM,
        hook_type="hook_resid_pre",
        vector=crosscoder_weight,
        default_steer_strength=20,
    )
    created_np_vectors.append(np_vector)

    # Each stored example unpacks as (max activation, tokens, per-token values);
    # only the tokens and values are sent to Neuronpedia.
    activations_to_upload: list[Activation] = [
        Activation(tokens=tokens, values=activation_values)
        for _max_activation, tokens, activation_values in db.max_activation_examples[feature_idx]
    ]
    print("Uploading activations for feature", feature_idx)
    np_vector.upload_activations(activations_to_upload)

In [None]:
from neuronpedia.np_list import NPList, NPListItem
from neuronpedia.np_vector import NPVector
import webbrowser

# Create a Neuronpedia list named after the crosscoder layer, and print its URL
# so the list can be opened directly.
new_list = NPList.new(f"Crosscoder L{LAYER_NUM} Dumas/Minder")
print(new_list)
new_list_url = f"https://neuronpedia.org/list/{new_list.id}"
print(new_list_url)
print(new_list_url)

In [None]:
# Turn each uploaded vector into a list item that points back at it.
list_items = [
    NPListItem(
        model_id=vector.model_id,
        source=vector.source,
        index=vector.index,
        description=vector.label,
    )
    for vector in created_np_vectors
]

# Add the items in chunks of at most LIST_BATCH_SIZE per request.
# NOTE(review): 100 is presumably the per-call limit of `add_items` -- confirm.
LIST_BATCH_SIZE = 100
for start in range(0, len(list_items), LIST_BATCH_SIZE):
    batch = list_items[start : start + LIST_BATCH_SIZE]
    print("Adding batch of", len(batch), "items to the list")
    new_list.add_items(batch)

# Print the URL as well: webbrowser.open silently does nothing (returns False)
# when no browser is available, e.g. on a remote or headless kernel.
print(new_list_url)
webbrowser.open(new_list_url)

In [None]:
# Re-fetch the newly created list from Neuronpedia and show its current state.
list_id = new_list.id
np_list = NPList.get(list_id)
print(np_list)
