In [1]:
import html
import json
import os

import matplotlib
import matplotlib.pyplot as plt
import torch
import transformer_lens.utils as tutils
from datasets import load_dataset
from IPython.display import HTML, display
from tqdm import tqdm
from transformer_lens import HookedTransformer

from config import SAEConfig
from model import SparseAutoencoder

# Turn off autograd globally: the analysis in this notebook only needs forward passes.
torch.set_grad_enabled(False)


  from .autonotebook import tqdm as notebook_tqdm


<torch.autograd.grad_mode.set_grad_enabled at 0x3065a65b0>

In [2]:
# Checkpoint of the SAE under study; the commented line is the alternative run.
checkpoint_dir = "checkpoints/2024-04-13_12-30-52"  # expansion factor = 8
# checkpoint_dir = "checkpoints/2024-04-13_14-10-45"  # expansion factor = 16

config_path = os.path.join(checkpoint_dir, "config.json")
weights_path = os.path.join(checkpoint_dir, "final_model.pt")

# Rebuild the training config from its JSON dump, then override the device so
# the checkpoint can be loaded on whatever hardware is available here.
with open(config_path) as config_file:
    cfg = SAEConfig(**json.load(config_file))
cfg.device = "cuda" if torch.cuda.is_available() else "cpu"

print("expansion factor:", cfg.expansion_factor)

# Instantiate the autoencoder and restore its trained weights onto cfg.device.
sae = SparseAutoencoder(cfg)
sae.load_state_dict(torch.load(weights_path, map_location=torch.device(cfg.device)))

# The transformer whose activations the SAE was trained on.
model = HookedTransformer.from_pretrained(cfg.model_name, device=cfg.device)

# Fill the layer index into the hook-name template stored in the config.
hook_point = cfg.hook_point.format(layer=cfg.hook_point_layer)


Run name: 4096-L1-0.0002-LR-0.0003
expansion factor: 8
Loaded pretrained model gelu-2l into HookedTransformer


In [60]:
# Load the larger SAE (expansion factor 16) so its features can be compared with `sae`.
large_checkpoint = "checkpoints/2024-04-13_14-10-45"  # expansion factor = 16

with open(os.path.join(large_checkpoint, "config.json")) as config_file:
    l_cfg = SAEConfig(**json.load(config_file))
l_cfg.device = "cuda" if torch.cuda.is_available() else "cpu"

print("expansion factor:", l_cfg.expansion_factor)

large_sae = SparseAutoencoder(l_cfg)
large_weights_path = os.path.join(large_checkpoint, "final_model.pt")
# Kept as the cell's last expression so the result of load_state_dict is displayed.
large_sae.load_state_dict(torch.load(large_weights_path, map_location=torch.device(l_cfg.device)))


Run name: 8192-L1-0.0002-LR-0.0003
expansion factor: 16


<All keys matched successfully>

In [3]:
# Download the text/code dataset, tokenize it with the model's tokenizer into
# rows of 128 tokens, and shuffle the rows with a fixed seed for repeatability.
data = load_dataset("NeelNanda/c4-code-20k", split="train")

tokenized_data = tutils.tokenize_and_concatenate(
    data,
    model.tokenizer,
    max_length=128,
).shuffle(seed=42)
all_tokens = tokenized_data["tokens"]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [62]:
# Index of the SAE feature examined by the cells below.
feature_id = 8

def get_feature_acts(feature_id, batch_size, n_batches):
    """Run the model and SAE over a slice of the dataset and collect one feature's activations.

    Relies on the notebook globals ``all_tokens``, ``model``, ``sae``, ``cfg``
    and ``hook_point``.

    Args:
        feature_id: column of the SAE feature activations to keep.
        batch_size: number of token rows fed to the model per forward pass.
        n_batches: number of batches to process.

    Returns:
        A tuple ``(acts, tokens)``: ``acts`` is a 1-D CPU tensor with the
        feature's activation at every position of every processed row, in row
        order, and ``tokens`` is the slice of ``all_tokens`` that was processed.
    """
    tokens = all_tokens[:batch_size * n_batches]

    per_batch_acts = []
    for batch_idx in tqdm(range(n_batches)):
        start = batch_idx * batch_size
        batch = tokens[start:start + batch_size]

        # Only the hooked activation is cached, and the forward pass is cut
        # off right after the hooked layer since nothing later is needed.
        _, cache = model.run_with_cache(
            batch,
            stop_at_layer=cfg.hook_point_layer + 1,
            names_filter=[hook_point],
        )

        # Flatten to one row per token position before feeding the SAE.
        sae_input = cache[hook_point].reshape(-1, cfg.d_in)
        _sae_out, all_feature_acts, _loss, _mse_loss, _l1_loss = sae(sae_input)

        per_batch_acts.append(all_feature_acts[:, feature_id].detach().cpu())

    return torch.cat(per_batch_acts, dim=0), tokens


# Collect the activations of `feature_id` over 10 batches of 256 token rows.
feature_acts, tokens = get_feature_acts(feature_id=feature_id, batch_size=256, n_batches=10)
# Sanity check: a flat 1-D tensor with one activation per token position.
print(feature_acts.shape)

100%|██████████| 10/10 [00:08<00:00,  1.15it/s]

torch.Size([327680])





In [63]:
def single_neuron(tokens: torch.Tensor, acts: torch.Tensor, ft_id, model, top_n=10):
    """Collect the examples on which a feature activates most strongly.

    Adapted from the author's sparse distillation code.

    Args:
        tokens: token ids, shape (n_examples, len_example).
        acts: feature activations aligned with ``tokens``, same shape.
        ft_id: identifier of the feature; stored stringified in the result.
        model: object exposing ``to_string`` for decoding token ids.
        top_n: number of examples to keep, ranked by their maximum activation.

    Returns:
        ``{"neuron_id": str, "snippets": [...]}`` where each snippet holds the
        decoded ``"text"``, its ``"max_activation"`` and a list of
        ``[token_string, activation]`` pairs under ``"token_activation_pairs"``.
    """
    def visible(token_id):
        # Make whitespace visible: space -> "·", newline -> "↩" + newline, tab -> "→".
        text = model.to_string(token_id)
        return text.replace(" ", "·").replace("\n", "↩" + "\n").replace("\t", "→")

    # Rank examples by their peak activation and keep the strongest top_n.
    peak_per_example = acts.max(dim=1).values
    ranked_examples = torch.argsort(peak_per_example, descending=True)[:top_n]

    snippets = []
    for example_idx in ranked_examples:
        example = tokens[example_idx]
        snippets.append({
            "text": model.to_string(example),
            "max_activation": float(acts[example_idx].max()),
            "token_activation_pairs": [
                [visible(example[pos]), float(acts[example_idx, pos])]
                for pos in range(example.shape[0])
            ],
        })

    return {"neuron_id": str(ft_id), "snippets": snippets}

In [64]:
# Un-flatten the per-position activations to (n_examples, len_example) so they line up
# with `tokens`. The row length is read from `tokens` instead of repeating the literal
# sequence length, so the two cannot drift apart if the tokenization length changes.
# Assumes `tokens` is 2-D (n_examples, len_example), as single_neuron expects.
act_data = single_neuron(tokens=tokens, acts=feature_acts.reshape(-1, tokens.shape[-1]), ft_id=feature_id, model=model)

def style_snippet(snippet_idx):
    """Render one snippet of ``act_data`` as HTML with tokens shaded by activation.

    Args:
        snippet_idx: index into ``act_data['snippets']``.

    Returns:
        A string of ``<span>`` elements, one per token, whose background colour
        comes from the ``Reds`` colormap scaled by the token's activation
        relative to the snippet's maximum activation.
    """
    snippet = act_data['snippets'][snippet_idx]
    tokens_with_activations = snippet["token_activation_pairs"]
    max_act = snippet["max_activation"]

    def activation_to_hex(activation):
        # Negative activations are left unshaded. The same goes for every token
        # of a snippet on which the feature never fires (max_act <= 0), which
        # would otherwise divide by zero.
        if activation < 0 or max_act <= 0:
            return '#ffffff'
        # Scale into [0, 0.6] so the strongest token does not reach the darkest
        # end of the colormap and the text stays readable.
        normalized_activation = (activation / max_act) * 0.6
        return matplotlib.colors.rgb2hex(plt.cm.Reds(normalized_activation))

    # Token text comes from arbitrary web/code data, so it is HTML-escaped;
    # otherwise tokens containing "<", ">" or "&" would be parsed as markup.
    return ''.join(
        f'<span style="background-color: {activation_to_hex(activation)}; margin-right: 0px;">{html.escape(token)}</span>'
        for token, activation in tokens_with_activations
    )

# Show every collected snippet with its rank, peak activation and shaded tokens.
print("feature id: ", feature_id)
print()
for snippet_idx, snippet in enumerate(act_data['snippets']):
    styled_text = style_snippet(snippet_idx)
    header = (
        f'<strong>Snippet number:</strong> {snippet_idx}<br>'
        f'<strong>Max activation:</strong> {snippet["max_activation"]}<br>'
    )
    snippet_info = f'<div style="word-wrap: break-word; margin-bottom: 10px;">{header}{styled_text}</div>'
    display(HTML(snippet_info))

feature id:  8



In [65]:
# Direct logit attribution
logit_attributions = sae.W_dec[feature_id] @ model.W_U

top_idxs = torch.argsort(logit_attributions, descending=True)[:10]
bottom_idxs = torch.argsort(logit_attributions, descending=False)[:10]
top_tokens = [model.to_string([idx.item()]) for idx in top_idxs]
bottom_tokens = [model.to_string([idx.item()]) for idx in bottom_idxs]

print(f"feature id: {feature_id} \n")
print('top')
print(top_tokens)
print(logit_attributions[top_idxs])
print('bottom')
print(bottom_tokens)
print(logit_attributions[bottom_idxs])

feature id: 8 

top
['ellt', ' possible', 'inyl', ' offender', ' conceivable', ' feasible', 'oside', 'anos', ' eigenvalue', ' ever']
tensor([0.7993, 0.6731, 0.6686, 0.6465, 0.6358, 0.6242, 0.6115, 0.6061, 0.5860,
        0.5767])
bottom
['ferential', 'station', 'pone', 'imental', 'alty', 'ery', 'apache', 'our', 'ferencing', 'aving']
tensor([-0.7037, -0.6754, -0.6661, -0.6542, -0.6505, -0.6458, -0.6368, -0.6343,
        -0.6175, -0.6100])


In [66]:
def find_cosine_sim(enc_vec, W_enc, top_n=10):
    """Print the encoder directions in ``W_enc`` most similar to ``enc_vec``.

    Args:
        enc_vec: 1-D tensor holding a single encoder direction.
        W_enc: 2-D tensor whose columns are encoder directions; it is
            transposed so that each column is compared with ``enc_vec``.
        top_n: how many of the most similar columns to report. Defaults to 10,
            the count that used to be hard-coded, mirroring ``single_neuron``.

    Prints the indices of the ``top_n`` most similar columns and their cosine
    similarities. Returns None, so calling it as the last expression of a cell
    adds no extra cell output.
    """
    # Broadcast the single vector against every column of W_enc.
    sims = torch.cosine_similarity(enc_vec.unsqueeze(0), W_enc.T, dim=1)

    # Slicing clamps naturally when W_enc has fewer than top_n columns.
    top_idxs = torch.argsort(sims, descending=True)[:top_n]
    print('top_idxs:', top_idxs)
    print('sims:', sims[top_idxs])


# Nearest neighbours of this feature's encoder direction within the same SAE.
# The feature's own column is among those compared, so it is matched against itself too.
find_cosine_sim(sae.W_enc[:, feature_id], sae.W_enc)


top_idxs: tensor([   8, 3701, 1284, 2459,  487, 2172, 3227,  724, 3151, 3759])
sims: tensor([1.0000, 0.6181, 0.6180, 0.6176, 0.6154, 0.6096, 0.6092, 0.6090, 0.6076,
        0.6046])


In [67]:
# Same query against the larger SAE's encoder, to look for counterparts of this feature there.
find_cosine_sim(sae.W_enc[:, feature_id], large_sae.W_enc)

top_idxs: tensor([6532, 5369, 3945,  997,  545, 6111, 2006, 1477, 6590, 2312])
sims: tensor([0.8124, 0.6751, 0.6404, 0.6340, 0.6333, 0.6315, 0.6296, 0.6277, 0.6265,
        0.6253])
