# Neuronpedia_followup

Here are two possibly-weird things I noticed when playing with the Bloom SAEs on Neuronpedia:

1. There are 'starts with letter k' and 'starts with letter z' features that propagate down through the layers, starting with layer 1 (0-indexed). This may have an obvious root cause, but I don't currently understand why it happens. Can we automatically identify these features and group them? That could shrink the space of features a human has to crawl through.
2. Lots of the layer-0 features have weird top positive logits: GoldMagikarp (a known tokenizer anomaly), but also Lumpur, CLASSIFIED, NetMessage. What's going on?

### Preamble

In [None]:
import json 
import os
import numpy as np
import torch
import plotly.express as px

from datasets import load_dataset
from functools import partial
from pathlib import Path
from transformer_lens import utils
from typing import Dict

from huggingface_hub import hf_hub_download
from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader
# from sae_lens.analysis.visualizer.data_fns import get_feature_data, FeatureData
from sae_vis.data_fetching_fns import get_feature_data, FeatureData

### Attempting to group and track common features across layers 

In [None]:
# Placeholder: automatic grouping/tracking of features across layers is not implemented yet.
pass

### Looking at why these strange tokens appear so frequently

Most notably, CLASSIFIED, NetMessage, largeDownload, GoldMagikarp, Lumpur.

Googling reveals known weirdness with some of these (https://www.lesswrong.com/posts/aPeJE8bSo6rAFoLqg/solidgoldmagikarp-plus-prompt-generation), notably with longer versions of these words.

Update: lots of these are indeed 'weird tokens' — natureconservancy, NetMessage, Streamer, GoldMagikarp, largeDownload. Regardless, understanding why these particular tokens keep coming up will be useful, and they should probably be excluded, if for no other reason than that they ruin the auto-interp explanations.

### Repro of the SAELens notebook

In [None]:
# Lifted from https://github.com/jbloomAus/SAELens/blob/main/tutorials/evaluating_your_sae.ipynb
#
# Device preference: Apple-silicon MPS first, then CUDA, otherwise CPU.
device = (
    "mps"
    if torch.backends.mps.is_available()
    else ("cuda" if torch.cuda.is_available() else "cpu")
)

# Everything in this notebook is inference only, so switch autograd off globally.
torch.set_grad_enabled(False)


In [None]:
# Fetch the pretrained SAE checkpoint for one residual-stream layer
# (layer 0 is the one whose features surface the Magikarp-style tokens).
REPO_ID = "jbloom/GPT2-Small-SAEs"
layer = 0  # any layer from 0 - 11 works here

FILENAME = (
    "final_sparse_autoencoder_gpt2-small_blocks."
    f"{layer}.hook_resid_pre_24576.pt"
)
path = hf_hub_download(filename=FILENAME, repo_id=REPO_ID)

In [None]:
# Load the model, the group of SAEs and an activation store from the checkpoint.
(
    model,
    sparse_autoencoders,
    activation_store,
) = LMSparseAutoencoderSessionloader.load_session_from_pretrained(path=path)

In [None]:
# One summary line per loaded SAE: index, hooked layer, Lp norm and L1 coefficient.
for i, sae in enumerate(sparse_autoencoders):
    cfg = sae.cfg
    summary = "{}: Layer {}, p_norm {}, alpha {}".format(
        i, cfg.hook_point_layer, cfg.lp_norm, cfg.l1_coefficient
    )
    print(summary)

In [None]:
# Pick which SAE you want to evaluate; index 0 is the default.
sparse_autoencoder = [*sparse_autoencoders][0]

In [None]:
# Test the autoencoder on one batch of activations from the activation store.
sparse_autoencoder.eval()  # prevents an error if a dead-neuron mask is expected for grads
with torch.no_grad():
    batch_tokens = activation_store.get_batch_tokens()
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)
    (
        sae_out,
        feature_acts,
        loss,
        mse_loss,
        l1_loss,
        _,
    ) = sparse_autoencoder(cache[sparse_autoencoder.cfg.hook_point])
    del cache  # only the SAE outputs are needed from here on

    # L0 per token = number of strictly positive feature activations.
    # Position 0 (the BOS token) is skipped; the mean is over batch and position.
    l0 = (feature_acts[:, 1:] > 0).float().sum(dim=-1).detach()
    print("average l0", l0.mean().item())
    px.histogram(l0.flatten().cpu().numpy()).show()

In [None]:
# Reconstruction test: compare the model's loss on `batch_tokens` with the SAE's
# hook point (a) untouched, (b) replaced by the SAE reconstruction `sae_out`,
# and (c) zero-ablated.
def reconstr_hook(activation, hook, sae_out):
    """Return the precomputed SAE reconstruction in place of the hooked activation."""
    return sae_out


def zero_abl_hook(activation, hook):
    """Return an all-zeros tensor shaped like the hooked activation."""
    return torch.zeros_like(activation)


# Splice at the hook point the loaded SAE was trained on. The earlier hard-coded
# `utils.get_act_name("resid_pre", 10)` came from a tutorial using a layer-10 SAE.
# The SAE loaded here is `blocks.{layer}.hook_resid_pre` (layer = 0), so that
# version injected layer-0 reconstructions into layer 10 and the reported
# "reconstr"/"Zero" losses were meaningless.
hook_point = sparse_autoencoder.cfg.hook_point

print("Orig", model(batch_tokens, return_type="loss").item())
print(
    "reconstr",
    model.run_with_hooks(
        batch_tokens,
        fwd_hooks=[
            (
                hook_point,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    model.run_with_hooks(
        batch_tokens,
        return_type="loss",
        fwd_hooks=[(hook_point, zero_abl_hook)],
    ).item(),
)

In [None]:
# Specific capability test: a two-name prompt whose expected completion is the
# other name (" Mary"). It is scored on the clean model, then with the SAE
# reconstruction spliced in at the SAE's hook point, and with that hook point zeroed.
example_prompt = "When John and Mary went to the shops, John gave the bag to"
example_answer = " Mary"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)
tokens = model.to_tokens(example_prompt)
sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sparse_autoencoder(
    cache[sparse_autoencoder.cfg.hook_point]
)


def reconstr_hook(activations, hook, sae_out):
    """Swap the hooked activations for the given SAE reconstruction."""
    return sae_out


def zero_abl_hook(mlp_out, hook):
    """Swap the hooked activations for zeros of the same shape."""
    return torch.zeros_like(mlp_out)


hook_point = sparse_autoencoder.cfg.hook_point
# One forward-hook spec reused for both the loss comparison and test_prompt below.
reconstr_fwd_hooks = [(hook_point, partial(reconstr_hook, sae_out=sae_out))]

print("Orig", model(tokens, return_type="loss").item())
for label, fwd_hooks in (
    ("reconstr", reconstr_fwd_hooks),
    ("Zero", [(hook_point, zero_abl_hook)]),
):
    hooked_loss = model.run_with_hooks(tokens, fwd_hooks=fwd_hooks, return_type="loss")
    print(label, hooked_loss.item())


with model.hooks(fwd_hooks=reconstr_fwd_hooks):
    utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

In [None]:
# Generating feature interfaces: the 10 most active SAE features at the last
# token of the first sequence in `feature_acts`, shown as a bar chart.
last_token_acts = feature_acts[0, -1].detach().cpu()
vals, inds = torch.topk(last_token_acts, k=10)
px.bar(x=list(map(str, inds)), y=vals).show()

In [None]:
from sae_vis.data_config_classes import (
    # ActsHistogramConfig,
    # Column,
    # FeatureTablesConfig,
    SaeVisConfig)

In [None]:
# Build sae_vis feature-centric dashboards for the top features selected above
# (`inds`) and write one HTML file per feature into DASHBOARD_FOLDER.
DASHBOARD_FOLDER = 'dashboards'
# The HTML files are written into this folder, so create it up front; otherwise
# the first save fails with FileNotFoundError on a fresh checkout.
os.makedirs(DASHBOARD_FOLDER, exist_ok=True)


# Token id -> display string. GPT-2's byte-level BPE vocab spells a leading
# space as "Ġ" and a newline as "Ċ" (never a literal "\n"), so "Ċ" is mapped as
# well; the original "\n" replacement is kept for non-GPT-2 vocabularies.
vocab_dict = model.tokenizer.vocab
vocab_dict = {
    v: k.replace("Ġ", " ").replace("\n", "\\n").replace("Ċ", "\\n")
    for k, v in vocab_dict.items()
}

# Cached next to the notebook; only written if no cached copy exists yet.
vocab_dict_filepath = Path.cwd() / "vocab_dict.json"
if not vocab_dict_filepath.exists():
    with open(vocab_dict_filepath, "w", encoding="utf-8") as f:
        json.dump(vocab_dict, f)


os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Loaded without streaming so tokenization does not have to happen while streaming.
data = load_dataset("NeelNanda/c4-code-20k", split="train")
tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=128)
tokenized_data = tokenized_data.shuffle(42)
all_tokens = tokenized_data["tokens"]


# NOTE: max_batch_size is not passed to sae_vis below; it is kept only for the
# legacy argument list at the end of this section.
max_batch_size = 512
total_batch_size = 4096 * 5
# .tolist() yields plain Python ints. list(tensor.numpy()) would give numpy
# int64s, which json cannot serialise. These indices become the keys iterated in
# the save loop below.
feature_idx = inds.flatten().cpu().tolist()
# max_batch_size = 512
# total_batch_size = 16384
# feature_idx = list(range(1000))

tokens = all_tokens[:total_batch_size]


feature_vis_params = SaeVisConfig(
    hook_point=sparse_autoencoder.cfg.hook_point,
    minibatch_size_features=256,
    minibatch_size_tokens=64,
    features=feature_idx,
    verbose=True,
    # feature_centric_layout=layout,
)

# Legacy (pre-SaeVisConfig) arguments, kept for reference:
# feature_idx=feature_idx,
#     max_batch_size=max_batch_size,
#     left_hand_k=3,
#     buffer=(5, 5),
#     n_groups=10,
#     first_group_size=20,
#     other_groups_size=5,


# Not annotated as a dict: the result is used below through .feature_data_dict
# and .save_feature_centric_vis.
feature_data = get_feature_data(
    encoder=sparse_autoencoder,
    model=model,
    tokens=tokens,
    cfg=feature_vis_params,
)

# NOTE(review): presumably needed by save_feature_centric_vis in the installed
# sae_vis version; confirm.
feature_data.model = model


for test_idx in feature_data.feature_data_dict:
    feature_data.save_feature_centric_vis(
        f"{DASHBOARD_FOLDER}/data_{test_idx:04}.html",
        feature_idx=test_idx,
    )
    # html_str = feature_data[test_idx].get_all_html()
    # with open(f"data_{test_idx:04}.html", "w") as f:
    #     f.write(html_str)

In [None]:

                # for _, feat_index in enumerate(feature_data.feature_data_dict.keys()):
                #     feature = feature_data.feature_data_dict[feat_index]

                #     feature_output = {}
                #     feature_output["featureIndex"] = feat_index

                #     top10_logits = self.round_list(feature.logits_table_data.top_logits)
                #     bottom10_logits = self.round_list(
                #         feature.logits_table_data.bottom_logits
                #     )

                #     # TODO: don't precompute/store these. should do it on the frontend
                #     max_value = max(
                #         np.absolute(bottom10_logits).max(),
                #         np.absolute(top10_logits).max(),
                #     )
                #     neg_bg_values = self.round_list(
                #         np.absolute(bottom10_logits) / max_value
                #     )
                #     pos_bg_values = self.round_list(
                #         np.absolute(top10_logits) / max_value
                #     )
                #     feature_output["neg_bg_values"] = neg_bg_values
                #     feature_output["pos_bg_values"] = pos_bg_values

                #     if feature.feature_tables_data:
                #         feature_output["neuron_alignment_indices"] = (
                #             feature.feature_tables_data.neuron_alignment_indices
                #         )
                #         feature_output["neuron_alignment_values"] = self.round_list(
                #             feature.feature_tables_data.neuron_alignment_values
                #         )
                #         feature_output["neuron_alignment_l1"] = self.round_list(
                #             feature.feature_tables_data.neuron_alignment_l1
                #         )
                #         feature_output["correlated_neurons_indices"] = (
                #             feature.feature_tables_data.correlated_neurons_indices
                #         )
                #         # TODO: this value doesn't exist in the new output type, commenting out for now
                #         # there is a cossim value though - is that what's needed?
                #         # feature_output["correlated_neurons_l1"] = self.round_list(
                #         #     feature.feature_tables_data.correlated_neurons_l1
                #         # )
                #         feature_output["correlated_neurons_pearson"] = self.round_list(
                #             feature.feature_tables_data.correlated_neurons_pearson
                #         )
                #         # feature_output["correlated_features_indices"] = (
                #         #     feature.feature_tables_data.correlated_features_indices
                #         # )
                #         # feature_output["correlated_features_l1"] = self.round_list(
                #         #     feature.feature_tables_data.correlated_features_l1
                #         # )
                #         # feature_output["correlated_features_pearson"] = self.round_list(
                #         #     feature.feature_tables_data.correlated_features_pearson
                #         # )

                #     feature_output["neg_str"] = self.to_str_tokens_safe(
                #         vocab_dict, feature.logits_table_data.bottom_token_ids
                #     )
                #     feature_output["neg_values"] = bottom10_logits
                #     feature_output["pos_str"] = self.to_str_tokens_safe(
                #         vocab_dict, feature.logits_table_data.top_token_ids
                #     )
                #     feature_output["pos_values"] = top10_logits

                #     # TODO: don't know what this should be in the new version
                #     # feature_output["frac_nonzero"] = (
                #     #     feature.middle_plots_data.frac_nonzero
                #     # )

                #     freq_hist_data = feature.acts_histogram_data
                #     freq_bar_values = self.round_list(freq_hist_data.bar_values)
                #     feature_output["freq_hist_data_bar_values"] = freq_bar_values
                #     feature_output["freq_hist_data_tick_vals"] = self.round_list(
                #         freq_hist_data.tick_vals
                #     )

                #     # TODO: don't precompute/store these. should do it on the frontend
                #     freq_bar_values_clipped = [
                #         (0.4 * max(freq_bar_values) + 0.6 * v) / max(freq_bar_values)
                #         for v in freq_bar_values
                #     ]
                #     freq_bar_colors = [
                #         colors.rgb2hex(BG_COLOR_MAP(v)) for v in freq_bar_values_clipped
                #     ]
                #     feature_output["freq_hist_data_bar_heights"] = self.round_list(
                #         freq_hist_data.bar_heights
                #     )
                #     feature_output["freq_bar_colors"] = freq_bar_colors

                #     logits_hist_data = feature.logits_histogram_data
                #     feature_output["logits_hist_data_bar_heights"] = self.round_list(
                #         logits_hist_data.bar_heights
                #     )
                #     feature_output["logits_hist_data_bar_values"] = self.round_list(
                #         logits_hist_data.bar_values
                #     )
                #     feature_output["logits_hist_data_tick_vals"] = self.round_list(
                #         logits_hist_data.tick_vals
                #     )

                #     # TODO: check this
                #     feature_output["num_tokens_for_dashboard"] = (
                #         self.n_prompts_to_select
                #     )

                #     activations = []
                #     sdbs = feature.sequence_data
                #     for sgd in sdbs.seq_group_data:
                #         for sd in sgd.seq_data:
                #             if (
                #                 sd.top_token_ids is not None
                #                 and sd.bottom_token_ids is not None
                #                 and sd.top_logits is not None
                #                 and sd.bottom_logits is not None
                #             ):
                #                 activation = {}
                #                 strs = []
                #                 posContribs = []
                #                 negContribs = []
                #                 for i in range(len(sd.token_ids)):
                #                     strs.append(
                #                         self.to_str_tokens_safe(
                #                             vocab_dict, sd.token_ids[i]
                #                         )
                #                     )
                #                     posContrib = {}
                #                     posTokens = [
                #                         self.to_str_tokens_safe(vocab_dict, j)
                #                         for j in sd.top_token_ids[i]
                #                     ]
                #                     if len(posTokens) > 0:
                #                         posContrib["t"] = posTokens
                #                         posContrib["v"] = self.round_list(
                #                             sd.top_logits[i]
                #                         )
                #                     posContribs.append(posContrib)
                #                     negContrib = {}
                #                     negTokens = [
                #                         self.to_str_tokens_safe(vocab_dict, j)  # type: ignore
                #                         for j in sd.bottom_token_ids[i]
                #                     ]
                #                     if len(negTokens) > 0:
                #                         negContrib["t"] = negTokens
                #                         negContrib["v"] = self.round_list(
                #                             sd.bottom_logits[i]
                #                         )
                #                     negContribs.append(negContrib)

                #                 activation["logitContributions"] = json.dumps(
                #                     {"pos": posContribs, "neg": negContribs}
                #                 )
                #                 activation["tokens"] = strs
                #                 activation["values"] = self.round_list(sd.feat_acts)
                #                 activation["maxValue"] = max(activation["values"])
                #                 activation["lossValues"] = self.round_list(
                #                     sd.loss_contribution
                #                 )

                #                 activations.append(activation)
                #     feature_output["activations"] = activations

                #     features_outputs.append(feature_output)

                # json_object = json.dumps(features_outputs, cls=NpEncoder)

                # with open(
                #     f"{self.neuronpedia_folder}/batch-{feature_batch_count}.json", "w"
                # ) as f:
                #     f.write(json_object)