In [18]:
# Create the object that holds activations from a dataset.
from sae_lens import ActivationsStore

# `from_sae` is a convenient constructor: it is given the model and the SAE
# directly, together with the settings below.
activation_store = ActivationsStore.from_sae(
    sae=sae,
    model=model,
    device=device,
    streaming=True,
    # Fairly conservative sizes, chosen so the same values can be reused for
    # larger models without running out of memory.
    n_batches_in_buffer=32,
    train_batch_size_tokens=4096,
    store_batch_size_prompts=8,
)


Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://jbloomaus.github.io/SAELens/training_saes/#pretokenizing-datasets for more info.



In [19]:
def list_flatten(nested_list):
    """Flatten one level of nesting: return the items of every inner iterable, in order, as a single list."""
    flat = []
    for inner in nested_list:
        flat.extend(inner)
    return flat


# A very handy function Neel wrote to get context around a feature activation
def make_token_df(tokens, len_prefix=5, len_suffix=3, model=model):
    """Build a DataFrame with one row per token, including its surrounding context.

    Args:
        tokens: 2-D batch of token ids indexed as (prompt, position). Only
            ``tokens.shape`` and iteration over its rows are used here.
        len_prefix: maximum number of preceding tokens shown before the current token.
        len_suffix: maximum number of following tokens shown after the current token.
        model: object providing ``to_str_tokens``; defaults to the notebook-level
            ``model`` as bound when this cell is executed.

    Returns:
        pd.DataFrame with rows in (prompt, position) order and the columns
        ``str_tokens``, ``unique_token`` (``"<token>/<position>"``),
        ``context`` (``"<prefix>|<token>|<suffix>"``), ``prompt``, ``pos`` and
        ``label`` (``"<prompt>/<pos>"``).
    """
    str_tokens = [model.to_str_tokens(t) for t in tokens]
    unique_token = [
        [f"{s}/{i}" for i, s in enumerate(str_tok)] for str_tok in str_tokens
    ]

    n_prompts, n_positions = tokens.shape[0], tokens.shape[1]
    context = []
    prompt = []
    pos = []
    label = []
    for b in range(n_prompts):
        row = str_tokens[b]
        for p in range(n_positions):
            prefix = "".join(row[max(0, p - len_prefix) : p])
            # Slicing already clamps at the end of the row, so no explicit bound
            # is needed. The previous `min(tokens.shape[1] - 1, ...)` bound was
            # off by one (slice ends are exclusive) and dropped the final token
            # of the prompt from every suffix. At the last position the slice is
            # empty, giving "".
            suffix = "".join(row[p + 1 : p + 1 + len_suffix])
            context.append(f"{prefix}|{row[p]}|{suffix}")
            prompt.append(b)
            pos.append(p)
            label.append(f"{b}/{p}")
    return pd.DataFrame(
        dict(
            str_tokens=list_flatten(str_tokens),
            unique_token=list_flatten(unique_token),
            context=context,
            prompt=prompt,
            pos=pos,
            label=label,
        )
    )

In [20]:
# Finding max activating examples is a bit harder. To do this we need to
# calculate feature activations for a large number of tokens.

# Sample 100 *distinct* feature indices. `torch.randint` samples with
# replacement, and a repeated index would produce duplicate "feature_<i>"
# column labels in the DataFrames built from `feature_list` later on.
# A dedicated, seeded generator keeps the sample reproducible without touching
# the global RNG state; change the seed to look at a different set of features.
feature_rng = torch.Generator().manual_seed(0)
feature_list = torch.randperm(sae.cfg.d_sae, generator=feature_rng)[:100]

examples_found = 0
all_fired_tokens = []
all_feature_acts = []
all_reconstructions = []
all_token_dfs = []

total_batches = 100
batch_size_prompts = activation_store.store_batch_size_prompts
batch_size_tokens = activation_store.context_size * batch_size_prompts
pbar = tqdm(range(total_batches))
# Inference only: without no_grad every tensor stored below would keep its
# autograd graph (and the intermediate activations behind it) alive.
with torch.no_grad():
    for i in pbar:
        tokens = activation_store.get_batch_tokens()
        tokens_df = make_token_df(tokens)
        tokens_df["batch"] = i

        flat_tokens = tokens.flatten()

        _, cache = model.run_with_cache(
            tokens,
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[sae.cfg.hook_name],
        )
        sae_in = cache[sae.cfg.hook_name]
        # Merge the two leading dims so that each row lines up with one entry of
        # `flat_tokens` (assumes sae.encode returns (prompt, position, feature)
        # -- TODO confirm). No `.squeeze()` here: it would also drop a size-1
        # prompt or position dim and make this flatten operate on the wrong axes.
        feature_acts = sae.encode(sae_in).flatten(0, 1)

        # Restrict to the sampled features once, then keep only the tokens on
        # which at least one of them fired.
        selected_acts = feature_acts[:, feature_list]
        fired_mask = selected_acts.sum(dim=-1) > 0
        fired_acts = selected_acts[fired_mask]
        fired_tokens = model.to_str_tokens(flat_tokens[fired_mask])
        reconstruction = fired_acts @ sae.W_dec[feature_list]

        token_df = tokens_df.iloc[fired_mask.cpu().nonzero().flatten().numpy()]
        all_token_dfs.append(token_df)
        all_feature_acts.append(fired_acts)
        all_fired_tokens.append(fired_tokens)
        all_reconstructions.append(reconstruction)

        examples_found += len(fired_tokens)
        # update the progress bar description
        pbar.set_description(f"Examples found: {examples_found}")

# Collapse the per-batch results; every collection below ends up with one
# entry per fired token, in the same order.
all_token_dfs = pd.concat(all_token_dfs)
all_fired_tokens = list_flatten(all_fired_tokens)
all_reconstructions = torch.cat(all_reconstructions)
all_feature_acts = torch.cat(all_feature_acts)

  0%|          | 0/100 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors
Examples found: 23865: 100%|██████████| 100/100 [00:03<00:00, 27.53it/s]


In [21]:
# One column per entry of feature_list, one row per fired token.
feature_column_names = [f"feature_{i}" for i in feature_list]
feature_acts_array = all_feature_acts.detach().cpu().numpy()
feature_acts_df = pd.DataFrame(feature_acts_array, columns=feature_column_names)
feature_acts_df.shape

(23865, 100)

In [22]:
feature_idx = 0

# Keep only the strictly positive activations of the chosen column.
feature_column = all_feature_acts[:, feature_idx]
all_positive_acts = feature_column[feature_column > 0].detach()

# Share of positive activations, as a percentage of every token processed
# by the collection loop.
tokens_processed = total_batches * batch_size_tokens
prop_positive_activations = 100 * len(all_positive_acts) / tokens_processed

px.histogram(
    all_positive_acts.cpu(),
    title=f"Histogram of positive activations - {prop_positive_activations:.3f}% of activations were positive",
    labels={"value": "Activation"},
    nbins=50,
    width=800,
)

In [23]:
# Ten fired tokens with the largest activation on the first feature in feature_list.
top_feature_column = f"feature_{feature_list[0]}"
top_10_activations = feature_acts_df.sort_values(
    by=top_feature_column, ascending=False
).head(10)
# NOTE(review): this lookup is positional. It lines up as long as
# feature_acts_df keeps the default 0..N-1 index it gets from being built out
# of a bare array, and all_token_dfs was concatenated in the same row order as
# all_feature_acts -- re-check if either construction changes.
all_token_dfs.iloc[top_10_activations.index]

Unnamed: 0,str_tokens,unique_token,context,prompt,pos,label,batch
561,7,7/49,<|endoftext|>This article is over| 7| years old\n,4,49,4/49,92
812,even,even/44,\nWhat makes the situation| even| more perplexing,6,44,6/44,65
665,Found,Found/25,David Kelly inquest\n\n|Found| in woods:,5,25,5/25,93
458,quality,quality/74,Criticisms of the| quality| of the study,3,74,3/74,37
220,time,time/92,"by his teammate at the| time|, Ian Del",1,92,1/92,49
703,New,New/63,)\n\nName Current| New| Kenichi Kom,5,63,5/63,8
155,s,s/27,ers.\n\nTurn|s| out The Game,1,27,1/27,17
961,all,all/65,UCLA. If any or| all| of these teams,7,65,7/65,60
463,all,all/79,said Porter expects to have| all| the needed ...,3,79,3/79,98
54,know,know/54,"\nTo those in the| know|, the campaign",0,54,0/54,9


In [24]:
# Project every SAE decoder direction through the model's unembedding to see
# which vocabulary tokens each feature promotes most strongly.
print(f"Shape of the decoder weights {sae.W_dec.shape}")  # stray ")" removed from the message
print(f"Shape of the model unembed {model.W_U.shape}")
projection_matrix = sae.W_dec @ model.W_U
print(f"Shape of the projection matrix {projection_matrix.shape}")

# then we take the top_k tokens per feature and decode them
top_k = 10
# let's do this for the features currently in feature_list
_, top_k_tokens = torch.topk(projection_matrix[feature_list], top_k, dim=1)


feature_df = pd.DataFrame(
    top_k_tokens.cpu().numpy(), index=[f"feature_{i}" for i in feature_list]
).T
feature_df.index = [f"token_{i}" for i in range(top_k)]
# DataFrame.applymap is deprecated in favour of DataFrame.map; fall back to
# applymap only on pandas versions that do not provide DataFrame.map yet.
decode_cells = feature_df.map if hasattr(feature_df, "map") else feature_df.applymap
decode_cells(lambda token_id: model.tokenizer.decode(token_id))

Shape of the decoder weights torch.Size([24576, 768]))
Shape of the model unembed torch.Size([768, 50257])
Shape of the projection matrix torch.Size([24576, 50257])



DataFrame.applymap has been deprecated. Use DataFrame.map instead.



Unnamed: 0,feature_5269,feature_24224,feature_16715,feature_14262,feature_21566,feature_19003,feature_11907,feature_4366,feature_22045,feature_3438,...,feature_23475,feature_10831,feature_1336,feature_12483,feature_9142,feature_21585,feature_17427,feature_430,feature_13040,feature_9197
token_0,atform,puberty,Stores,altogether,Technologies,bags,than,raining,Chem,uate,...,namely,resident,sg,LLP,bag,alian,ller,*/,cause,pointer
token_1,TAG,phases,malls,%.,Corporation,happens,efficient,beh,Games,imize,...,viz,natives,adish,subsidiary,bags,enegger,llers,${,IFIED,accuracy
token_2,hiro,inventoryQuantity,stores,'.,CEO,hots,stringent,happen,Gold,pell,...,including,businessman,emonic,LTD,iest,VIDEOS,ptic,</,ifiable,antry
token_3,raq,attrition,shelves,}.,Corp,happening,Than,iner,swers,omething,...,plus,entrepreneur,haw,strives,bike,Jr,vich,"...""",子,hetically
token_4,lement,stages,shoppers,.).,Stores,bag,meaningful,etsk,Site,livion,...,albeit,Times,ionage,founder,sie,Jr,lling,"...""",luck,(%)
token_5,rik,enhagen,store,due,Holdings,iness,aggressive,dawn,Textures,balance,...,totaling,Symphony,shire,partnered,iness,ellen,ls,{},soDeliveryDate,percentage
token_6,hra,odan,Shopping,].,subsidiaries,ious,expensive,impossible,ventory,explanations,...,excluding,native,""":""""},{""",ONSORED,dirt,iciary,letal,%,OTHER,coefficient
token_7,neau,wrench,Walmart,lest,subsidiary,EMS,sophisticated,happ,Fa,urate,...,respectively,bureau,isters,Inc,brush,enhagen,jriwal,"""}",�,pointers
token_8,aml,Perkins,shopping,anymore,executives,uits,HUD,happened,DX,clarification,...,excluding,billionaire,bright,founders,shed,eller,rette,"""))",ification,point
token_9,liam,hoops,Store,"'.""",Company,y,drastic,unclear,elaide,TEXTURE,...,notably,residents,unction,Limited,oxide,uchin,rer,)</,whiff,ttle


In [26]:
# only valid for res-jb resid_pre 7.
# Josh Engels emailed us these lists.
day_of_the_week_features = [2592, 4445, 4663, 4733, 6531, 8179, 9566, 20927, 24185]
# months_of_the_year = [3977, 4140, 5993, 7299, 9104, 9401, 10449, 11196, 12661, 14715, 17068, 17528, 19589, 21033, 22043, 23304]
# years_of_10th_century = [1052, 2753, 4427, 6382, 8314, 9576, 9606, 13551, 19734, 20349]

feature_list = day_of_the_week_features

examples_found = 0
all_fired_tokens = []
all_feature_acts = []
all_reconstructions = []
all_token_dfs = []

total_batches = 100
batch_size_prompts = activation_store.store_batch_size_prompts
batch_size_tokens = activation_store.context_size * batch_size_prompts
pbar = tqdm(range(total_batches))
# Inference only: without no_grad every tensor stored below would keep its
# autograd graph (and the intermediate activations behind it) alive.
with torch.no_grad():
    for i in pbar:
        tokens = activation_store.get_batch_tokens()
        tokens_df = make_token_df(tokens)
        tokens_df["batch"] = i

        flat_tokens = tokens.flatten()

        _, cache = model.run_with_cache(
            tokens,
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[sae.cfg.hook_name],
        )
        sae_in = cache[sae.cfg.hook_name]
        # Merge the two leading dims so that each row lines up with one entry of
        # `flat_tokens` (assumes sae.encode returns (prompt, position, feature)
        # -- TODO confirm). No `.squeeze()` here: it would also drop a size-1
        # prompt or position dim and make this flatten operate on the wrong axes.
        feature_acts = sae.encode(sae_in).flatten(0, 1)

        # Restrict to the listed features once, then keep only the tokens on
        # which at least one of them fired.
        selected_acts = feature_acts[:, feature_list]
        fired_mask = selected_acts.sum(dim=-1) > 0
        fired_acts = selected_acts[fired_mask]
        fired_tokens = model.to_str_tokens(flat_tokens[fired_mask])
        reconstruction = fired_acts @ sae.W_dec[feature_list]

        token_df = tokens_df.iloc[fired_mask.cpu().nonzero().flatten().numpy()]
        all_token_dfs.append(token_df)
        all_feature_acts.append(fired_acts)
        all_fired_tokens.append(fired_tokens)
        all_reconstructions.append(reconstruction)

        examples_found += len(fired_tokens)
        # update the progress bar description
        pbar.set_description(f"Examples found: {examples_found}")

# Collapse the per-batch results; every collection below ends up with one
# entry per fired token, in the same order.
all_token_dfs = pd.concat(all_token_dfs)
all_fired_tokens = list_flatten(all_fired_tokens)
all_reconstructions = torch.cat(all_reconstructions)
all_feature_acts = torch.cat(all_feature_acts)

Examples found: 164: 100%|██████████| 100/100 [00:03<00:00, 28.76it/s]


In [27]:
# Fit a 3-component PCA on the reconstructions and plot PC2 against PC3.
from sklearn.decomposition import PCA
import plotly.express as px

reconstructions_array = all_reconstructions.detach().cpu().numpy()
pca = PCA(n_components=3)
pca_embedding = pca.fit_transform(reconstructions_array)

# One row per fired token: its PCA coordinates plus the token string and the
# context it appeared in (both used for hover text in the plot).
pca_df = pd.DataFrame(pca_embedding, columns=["PC1", "PC2", "PC3"]).assign(
    tokens=all_fired_tokens,
    context=all_token_dfs.context.values,
)

pca_scatter = px.scatter(
    pca_df,
    x="PC2",
    y="PC3",
    color="tokens",
    hover_name="tokens",
    hover_data=["context"],
    title="PCA Subspace Reconstructions",
    width=1200,
    height=800,
)
pca_scatter.show()