In [1]:
!uv add sae-lens transformer-lens sae-dashboard pandas plotly tqdm networkx matplotlib seaborn pyvis
import os
import torch
import networkx as nx
from pyvis.network import Network
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
from collections import defaultdict
import numpy as np
from tqdm import tqdm
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
%matplotlib inline

[2mResolved [1m181 packages[0m [2min 8ms[0m[0m
[2mAudited [1m165 packages[0m [2min 0.10ms[0m[0m
Device: cuda


In [2]:
# from transformer_lens import HookedTransformer
from sae_lens import SAE, HookedSAETransformer

model = HookedSAETransformer.from_pretrained_no_processing("gpt2-small", device=device)

# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)
# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict
# We also return the feature sparsities which are stored in HF for convenience.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",  # <- Release name
    sae_id="blocks.7.hook_resid_pre",  # <- SAE id (not always a hook point!)
    device=device,
)
print(sae.cfg.__dict__)

Loaded pretrained model gpt2-small into HookedTransformer
{'architecture': 'standard', 'd_in': 768, 'd_sae': 24576, 'activation_fn_str': 'relu', 'apply_b_dec_to_input': True, 'finetuning_scaling_factor': False, 'context_size': 128, 'model_name': 'gpt2-small', 'hook_name': 'blocks.7.hook_resid_pre', 'hook_layer': 7, 'hook_head_index': None, 'prepend_bos': True, 'dataset_path': 'Skylion007/openwebtext', 'dataset_trust_remote_code': True, 'normalize_activations': 'none', 'dtype': 'torch.float32', 'device': 'cuda', 'sae_lens_training_version': None, 'activation_fn_kwargs': {}, 'neuronpedia_id': 'gpt2-small/7-res-jb', 'model_from_pretrained_kwargs': {'center_writing_weights': True}, 'seqpos_slice': (None,)}


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [3]:
from datasets import load_dataset
from transformer_lens.utils import tokenize_and_concatenate

dataset = load_dataset(
    path="NeelNanda/pile-10k",
    split="train",
    streaming=False,
)

token_dataset = tokenize_and_concatenate(
    dataset=dataset,  # type: ignore
    tokenizer=model.tokenizer,  # type: ignore
    streaming=True,
    max_length=sae.cfg.context_size,
    add_bos_token=sae.cfg.prepend_bos,
)

token_dataset

Dataset({
    features: ['tokens'],
    num_rows: 136625
})

In [4]:
import requests

url = "https://www.neuronpedia.org/api/explanation/export?modelId=gpt2-small&saeId=7-res-jb"
headers = {"Content-Type": "application/json"}

response = requests.get(url, headers=headers)
# convert to pandas
data = response.json()
explanations_df = pd.DataFrame(data)
# rename index to "feature"
explanations_df.rename(columns={"index": "feature"}, inplace=True)
# explanations_df["feature"] = explanations_df["feature"].astype(int)
explanations_df["description"] = explanations_df["description"].apply(
    lambda x: x.lower()
)
explanations_df

Unnamed: 0,modelId,layer,feature,description,explanationModelName,typeName
0,gpt2-small,7-res-jb,218,stars and dashed for censoring expletives,,oai_token-act-pair
1,gpt2-small,7-res-jb,218,stars and dashes for censoring expletives,,oai_token-act-pair
2,gpt2-small,7-res-jb,218,offensive language and expletives,,oai_token-act-pair
3,gpt2-small,7-res-jb,2020,names of people,gpt-3.5-turbo,oai_token-act-pair
4,gpt2-small,7-res-jb,3493,references to nazism,gpt-3.5-turbo,oai_token-act-pair
...,...,...,...,...,...,...
24568,gpt2-small,7-res-jb,24571,locations and cities paired with information s...,gpt-3.5-turbo,oai_token-act-pair
24569,gpt2-small,7-res-jb,24572,"actions related to personal grooming, such as ...",gpt-3.5-turbo,oai_token-act-pair
24570,gpt2-small,7-res-jb,24573,"words containing the sequence ""lo""",gpt-3.5-turbo,oai_token-act-pair
24571,gpt2-small,7-res-jb,24574,instances of added or inserted text,gpt-3.5-turbo,oai_token-act-pair


In [5]:
# SAEs don't reconstruct activation perfectly, so if you attach an SAE and want the model to stay performant, you need to use the error term.
# This is because the SAE will be used to modify the forward pass, and if it doesn't reconstruct the activations well, the outputs may be effected.
# Good SAEs have small error terms but it's something to be mindful of.

sae.use_error_term  # If use error term is set to false, we will modify the forward pass by using the sae.

False

In [6]:
# instantiate an object to hold activations from a dataset
from sae_lens import ActivationsStore

# a convenient way to instantiate an activation store is to use the from_sae method
activation_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    # fairly conservative parameters here so can use same for larger
    # models without running out of memory.
    store_batch_size_prompts=8,
    train_batch_size_tokens=2048,
    n_batches_in_buffer=16,
    device=device,
)

def list_flatten(nested_list):
    return [x for y in nested_list for x in y]


# A very handy function Neel wrote to get context around a feature activation
def make_token_df(tokens, len_prefix=5, len_suffix=3, model=model):
    str_tokens = [model.to_str_tokens(t) for t in tokens]
    unique_token = [
        [f"{s}/{i}" for i, s in enumerate(str_tok)] for str_tok in str_tokens
    ]

    context = []
    prompt = []
    pos = []
    label = []
    for b in range(tokens.shape[0]):
        for p in range(tokens.shape[1]):
            prefix = "".join(str_tokens[b][max(0, p - len_prefix) : p])
            if p == tokens.shape[1] - 1:
                suffix = ""
            else:
                suffix = "".join(
                    str_tokens[b][p + 1 : min(tokens.shape[1] - 1, p + 1 + len_suffix)]
                )
            current = str_tokens[b][p]
            context.append(f"{prefix}|{current}|{suffix}")
            prompt.append(b)
            pos.append(p)
            label.append(f"{b}/{p}")
    # print(len(batch), len(pos), len(context), len(label))
    return pd.DataFrame(
        dict(
            str_tokens=list_flatten(str_tokens),
            unique_token=list_flatten(unique_token),
            context=context,
            prompt=prompt,
            pos=pos,
            label=label,
        )
    )



In [7]:
def sparse_cat(tensors, dim=0):
    """
    Concatenate a list of (dense or sparse) tensors along `dim`, returning
    a single sparse_coo_tensor.  Never builds a full dense intermediate.
    """
    # device / dtype from first
    device = tensors[0].device
    dtype  = tensors[0].dtype
    
    indices_list = []
    values_list  = []
    offset = 0
    
    for t in tensors:
        # 1) Make sparse COO
        sp = t.to_sparse()         # if t already sparse, this is essentially a no-op
        
        idx = sp.indices().clone() # shape: (ndim, nnz_i)
        vals = sp.values()         # shape: (nnz_i, ...)
        
        # 2) shift the cat‐axis
        idx[dim] += offset
        
        indices_list.append(idx)
        values_list.append(vals)
        
        # 3) bump offset for next chunk
        offset += t.size(dim)
    
    # 4) stitch them together
    all_indices = torch.cat(indices_list, dim=1)  # (ndim, total_nnz)
    all_values  = torch.cat(values_list,  dim=0)  # (total_nnz, ...)
    
    # 5) figure out new size
    out_size = list(tensors[0].shape)
    out_size[dim] = offset
    
    # 6) build final sparse tensor
    out = torch.sparse_coo_tensor(all_indices, all_values, out_size,
                                  dtype=dtype, device=device)
    return out.coalesce()  # optionally coalesce duplicates

In [8]:
len(sae.W_enc)

768

In [None]:
total_batches = 4096
feature_list = torch.randint(0, sae.cfg.d_sae, (total_batches,))
examples_found = 0
all_fired_tokens = []
all_feature_acts = []
all_reconstructions = []
all_token_dfs = []

# Original implementation from Nanda would balloon CPU usage, detaching fixes that.
# Updated implementation is only bounded by model + token overhead, I think 
# Only hits 3Gb of VRAM on GPT2 regardless of batch size.
# TODO - Scale on up to full SAE size
batch_size_prompts = activation_store.store_batch_size_prompts
batch_size_tokens = activation_store.context_size * batch_size_prompts
pbar = tqdm(range(total_batches))
for i in pbar:
    tokens = activation_store.get_batch_tokens()
    tokens_df = make_token_df(tokens)
    tokens_df["batch"] = i

    flat_tokens = tokens.flatten()

    _, cache = model.run_with_cache(
        tokens, stop_at_layer=sae.cfg.hook_layer + 1, names_filter=[sae.cfg.hook_name]
    )
    sae_in = cache[sae.cfg.hook_name]
    feature_acts = sae.encode(sae_in).squeeze()

    feature_acts = feature_acts.flatten(0, 1)
    fired_mask = (feature_acts[:, feature_list]).sum(dim=-1) > 0
    fired_tokens = model.to_str_tokens(flat_tokens[fired_mask])
    reconstruction = feature_acts[fired_mask][:, feature_list] @ sae.W_dec[feature_list]

    token_df = tokens_df.iloc[fired_mask.cpu().nonzero().flatten().numpy()]
    #all_token_dfs.append(token_df)
    all_feature_acts.append(feature_acts[fired_mask][:, feature_list].cpu().detach())
    #all_fired_tokens.append(fired_tokens)
    #all_reconstructions.append(reconstruction)

    examples_found += len(fired_tokens)
    # print(f"Examples found: {examples_found}")
    # update description
    pbar.set_description(f"Examples found: {examples_found}")

# flatten the list of lists
#all_token_dfs = pd.concat(all_token_dfs)
#all_fired_tokens = list_flatten(all_fired_tokens)
#all_reconstructions = torch.cat(all_reconstructions)
all_feature_acts = sparse_cat(all_feature_acts)

  0%|                               | 0/4096 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors
Examples found: 246521:   6%| | 243/4096 [00:55<15:05,  4.26i

In [14]:
feature_acts_df = pd.DataFrame(
    all_feature_acts.numpy(),
    columns=[f"feature_{i}" for i in feature_list],
)
feature_acts_df.shape
feature_acts_df

Unnamed: 0,feature_20622,feature_533,feature_4239,feature_22677,feature_3417,feature_5544,feature_1407,feature_15424,feature_9833,feature_11899,...,feature_2945,feature_1371,feature_5991,feature_14771,feature_12620,feature_7546,feature_24393,feature_16096,feature_15380,feature_10788
0,0.0,222.595535,253.000076,0.0,0.0,0.000000,110.669052,154.762756,229.317795,244.764374,...,213.096619,76.259781,224.75563,228.64592,254.42923,50.387886,229.89415,101.386383,0.0,124.366203
1,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.00000,0.00000,0.00000,0.000000,0.00000,0.000000,0.0,0.000000
2,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.931584,0.00000,0.00000,0.00000,0.000000,0.00000,0.000000,0.0,0.000000
3,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.00000,0.00000,0.00000,0.000000,0.00000,0.000000,0.0,0.000000
4,0.0,0.000000,0.000000,0.0,0.0,1.003434,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.00000,0.00000,0.00000,0.000000,0.00000,0.000000,0.0,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
31298,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.00000,0.00000,0.00000,0.000000,0.00000,0.000000,0.0,0.000000
31299,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.00000,0.00000,0.00000,0.000000,0.00000,0.000000,0.0,0.000000
31300,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.00000,0.00000,0.00000,0.000000,0.00000,0.000000,0.0,0.000000
31301,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.00000,0.00000,0.00000,0.000000,0.00000,0.000000,0.0,0.000000


In [15]:
from sklearn.covariance import GraphicalLassoCV
gl = GraphicalLassoCV(cv=3, max_iter=100).fit(all_feature_acts.detach().cpu().numpy())

  precision_[indices != idx, idx] = -precision_[idx, idx] * coefs
  precision_[idx, indices != idx] = -precision_[idx, idx] * coefs
  precision_[indices != idx, idx] = -precision_[idx, idx] * coefs
  precision_[idx, indices != idx] = -precision_[idx, idx] * coefs
  precision_[indices != idx, idx] = -precision_[idx, idx] * coefs
  precision_[idx, indices != idx] = -precision_[idx, idx] * coefs
  x = asanyarray(arr - arrmean)


In [16]:
gl.precision_

array([[ 2.81346212e+01,  0.00000000e+00,  0.00000000e+00, ...,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00,  9.88045507e-04, -0.00000000e+00, ...,
        -0.00000000e+00,  0.00000000e+00, -0.00000000e+00],
       [ 0.00000000e+00, -0.00000000e+00,  8.45122338e-04, ...,
        -0.00000000e+00,  0.00000000e+00, -0.00000000e+00],
       ...,
       [ 0.00000000e+00, -0.00000000e+00, -0.00000000e+00, ...,
         3.90858192e-03,  0.00000000e+00, -0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
         0.00000000e+00,  7.01820370e+00,  0.00000000e+00],
       [ 0.00000000e+00, -0.00000000e+00, -0.00000000e+00, ...,
        -0.00000000e+00,  0.00000000e+00,  2.59431815e-03]])

In [17]:
def precision_to_graph(precision_matrix, feature_names=None, threshold=0.1, trim_names_len=10):
    """Convert precision matrix to networkx graph with feature names"""
    
    # threshold to keep only strong connections
    adj_matrix = np.abs(precision_matrix) > threshold
    np.fill_diagonal(adj_matrix, False)  # remove self-loops
    
    # create networkx graph
    G = nx.from_numpy_array(adj_matrix)
    
    # add feature names as node attributes
    if feature_names is not None:
        print('Setting feature names')
        assert len(feature_names) == precision_matrix.shape[0], \
            f"Expected {precision_matrix.shape[0]} feature names, got {len(feature_names)}"
        
        # add names to nodes
        for i, name in enumerate(feature_names):
            if i in G.nodes():  # only add if node exists (has edges)
                G.nodes[i]['title'] = name[:trim_names_len] + '...' if len(name) > trim_names_len else name
                G.nodes[i]['label'] = name
    
    # add edge weights for visualization
    for i, j in G.edges():
        G[i][j]['weight'] = abs(precision_matrix[i, j]),
        G[i][j]['rel_weight'] = abs(precision_matrix[i, j]) / np.min(precision_matrix[~np.isclose(precision_matrix, 0)])    
    return G

In [18]:
# grab our labels
idxs = [int(x.lstrip('feature_').strip()) for x in feature_acts_df.columns]
labels = list(explanations_df.loc[idxs, 'description'])
G = precision_to_graph(np.abs(gl.precision_), feature_names=labels, threshold=1e-5)
G.remove_nodes_from(list(nx.isolates(G)))
#nx.draw(G)
#plt.show()

Setting feature names


In [21]:
# This graph is hard to read... but we can work with it still
nt = Network('750px', '750px', notebook=True)
nt.from_nx(G, show_edge_weights=False)
nt.toggle_physics(False)
#nt.show('nt.html')



In [22]:
# Let's look at the strongest edges
def print_edge_labels(idx, edge):
    """Print edge labels based on SAE features + weight. Utility printing function"""
    i, j, d = edge
    rel_weight = d['rel_weight']
    print(f"Edge {idx}, Relative Weigth {rel_weight:.2f}:\n\t- {labels[i]}\n\t- {labels[j]}")

ordered_edges=sorted(G.edges(data=True), key=lambda edge: -edge[2].get('rel_weight', 1))
for idx, edge in enumerate(ordered_edges):
    print_edge_labels(idx, edge)

Edge 0, Relative Weigth 1499.03:
	- names and terms related to a popular science fiction tv show
	- phrases related to fox news channel
Edge 1, Relative Weigth 1493.27:
	- phrases related to considering or predicting potential outcomes or results of actions
	- phrases related to fox news channel
Edge 2, Relative Weigth 1484.74:
	- questions and alternatives
	- phrases related to fox news channel
Edge 3, Relative Weigth 1483.95:
	- mentions of catastrophic events or overwhelming situations
	- phrases related to fox news channel
Edge 4, Relative Weigth 1483.57:
	- mentions of the word "pokémon"
	- phrases related to fox news channel
Edge 5, Relative Weigth 1482.30:
	- words and phrases that end in "ly."
	- phrases related to fox news channel
Edge 6, Relative Weigth 1474.62:
	- phrases related to fox news channel
	- specific paired characters that signify a specific programming construct or syntax
Edge 7, Relative Weigth 1474.20:
	- the word "si" in various contexts
	- phrases related to 