In [1]:
import html
import json
import os

import matplotlib
import matplotlib.pyplot as plt
import torch
import transformer_lens.utils as tutils
from datasets import load_dataset
from IPython.display import HTML, display
from tqdm import tqdm
from transformer_lens import HookedTransformer

from config import SAEConfig
from model import SparseAutoencoder



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Directory of the training run to inspect; config.json and final_model.pt are read from it.
checkpoint_dir = "checkpoints/2024-04-13_12-30-52"

# Rebuild the SAEConfig straight from its JSON dump.
with open(os.path.join(checkpoint_dir, "config.json")) as f:
    cfg = SAEConfig(**json.load(f))
# Override the stored device with whatever is available on this machine.
cfg.device = "cuda" if torch.cuda.is_available() else "cpu"

sae = SparseAutoencoder(cfg)
# NOTE(review): depending on the installed PyTorch's `weights_only` default,
# torch.load may unpickle arbitrary objects -- only load trusted checkpoints.
sae.load_state_dict(
    torch.load(
        os.path.join(checkpoint_dir, "final_model.pt"),
        map_location=torch.device(cfg.device),
    )
)

model = HookedTransformer.from_pretrained(cfg.model_name, device=cfg.device)

# cfg.hook_point is a template with a `{layer}` placeholder; fill in the configured layer.
hook_point = cfg.hook_point.format(layer=cfg.hook_point_layer)


Run name: 4096-L1-0.0002-LR-0.0003
device: cpu
dtype: torch.float32
Loaded pretrained model gelu-2l into HookedTransformer


In [3]:
# Raw text used to probe the autoencoder.
data = load_dataset("NeelNanda/c4-code-20k", split="train")

# Tokenize with the model's own tokenizer (max_length=128), then shuffle with a
# fixed seed so the same examples are inspected on every run.
tokenized_data = tutils.tokenize_and_concatenate(
    data,
    model.tokenizer,
    max_length=128,
).shuffle(seed=42)
all_tokens = tokenized_data["tokens"]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [10]:
# Index of the SAE feature to inspect; used below to pick one column of the feature activations.
feature_id = 0

def get_feature_acts(feature_id, batch_size, n_batches):
    """Collect one SAE feature's activations over the first tokenized sequences.

    Runs `model` on `n_batches` consecutive batches of `batch_size` rows taken
    from the front of `all_tokens`, reads the activations cached at
    `hook_point`, passes them through `sae`, and keeps only column
    `feature_id` of the resulting feature activations.

    Relies on the notebook globals `all_tokens`, `model`, `sae`, `cfg` and
    `hook_point`.

    Args:
        feature_id: Column index into the SAE's feature activations.
        batch_size: Number of token rows per forward pass.
        n_batches: Number of forward passes to run.

    Returns:
        A tuple `(all_acts, tokens)`. `all_acts` is the per-batch activations
        of the chosen feature, moved to CPU and concatenated along dim 0 in
        batch order. `tokens` is the slice
        `all_tokens[:batch_size * n_batches]` that was fed to the model.
    """
    tokens = all_tokens[:batch_size*n_batches]
    all_acts = []
    # Pure inference: without no_grad every forward pass would track gradients
    # for results that are detached immediately afterwards.
    with torch.no_grad():
        for i in tqdm(range(n_batches)):
            batch = tokens[i*batch_size:(i+1)*batch_size]
            # Stop right after the hooked layer and cache only the hook we need.
            _, cache = model.run_with_cache(batch,
                                            stop_at_layer=cfg.hook_point_layer + 1,
                                            names_filter=[hook_point],
                                        )
            # Flatten every leading dimension so each row is one d_in-sized vector.
            in_acts = cache[hook_point].reshape(-1, cfg.d_in)
            # The SAE returns several values; only the feature activations
            # (second item) are needed here.
            _, feature_acts, *_unused = sae(in_acts)

            all_acts.append(feature_acts[:, feature_id].detach().cpu())

    all_acts = torch.cat(all_acts, dim=0)
    return all_acts, tokens


# Gather activations of the selected feature over 10 batches of 256 sequences.
feature_acts, tokens = get_feature_acts(
    feature_id=feature_id,
    batch_size=256,
    n_batches=10,
)
print(feature_acts.shape)

100%|██████████| 10/10 [00:08<00:00,  1.13it/s]

torch.Size([327680])





In [11]:
def single_neuron(tokens: torch.Tensor, acts: torch.Tensor, ft_id, original_model, top_n=10):
    """Summarise the examples on which a feature reaches its highest activation.

    Adapted from the author's earlier sparse-distillation code.

    Args:
        tokens: Token ids, shape (n_examples, len_example).
        acts: Activations aligned with `tokens`, shape (n_examples, len_example).
        ft_id: Identifier of the feature; stored as a string under "neuron_id".
        original_model: Object whose `to_string` decodes both a whole example
            and a single token.
        top_n: Maximum number of examples to keep.

    Returns:
        A dict with "neuron_id" and "snippets". Snippets are ordered by
        descending peak activation and each holds the decoded "text", the
        "max_activation", and "token_activation_pairs" (a list of
        `[token_string, activation]` with whitespace made visible).
    """
    # Visible stand-ins for whitespace; a newline keeps its real line break
    # after the marker.
    whitespace_markers = (
        (" ", "·"),
        ("\n", "↩" + "\n"),
        ("\t", "→"),
    )

    def make_visible(text):
        for raw, shown in whitespace_markers:
            text = text.replace(raw, shown)
        return text

    # Rank examples by their single strongest activation and keep the best top_n.
    peak_per_example = acts.max(dim=1).values
    ranking = torch.argsort(peak_per_example, descending=True)[:top_n]

    snippets = []
    for row in ranking:
        example = tokens[row]
        example_acts = acts[row]
        decoded_text = original_model.to_string(example)
        pairs = []
        for pos, token in enumerate(example):
            token_text = make_visible(original_model.to_string(token))
            pairs.append([token_text, float(example_acts[pos])])
        snippets.append({
            "text": decoded_text,
            "max_activation": float(example_acts.max()),
            "token_activation_pairs": pairs,
        })

    return {"neuron_id": str(ft_id), "snippets": snippets}

In [12]:
# Regroup the flat activation vector into one row per token sequence. The row
# width is read from `tokens` instead of being hard-coded, so it cannot drift
# out of sync with the tokenization length and silently misalign activations.
act_data = single_neuron(
    tokens=tokens,
    acts=feature_acts.reshape(-1, tokens.shape[1]),
    ft_id=feature_id,
    original_model=model,
)

def style_snippet(snippet_idx):
    """Render one snippet of `act_data` as HTML with activation-shaded tokens.

    Reads the notebook global `act_data` and wraps every token of snippet
    `snippet_idx` in a `<span>` whose background colour comes from the `Reds`
    colormap, scaled by the snippet's own maximum activation.

    Args:
        snippet_idx: Index into `act_data['snippets']`.

    Returns:
        A string of concatenated `<span>` elements, one per token.
    """
    tokens_with_activations = act_data['snippets'][snippet_idx]["token_activation_pairs"]
    max_act = act_data['snippets'][snippet_idx]["max_activation"]

    # Function to map activation to color
    def activation_to_color(activation):
        if activation < 0:
            return '#FFFFFF'
        # A snippet whose peak activation is 0 (e.g. a feature that never
        # fires) would otherwise divide by zero; shade it like zero activation.
        if max_act > 0:
            # Scale to [0, 0.6] so only the lighter part of the colormap is used.
            normalized_activation = activation / max_act*0.6
        else:
            normalized_activation = 0.0
        return plt.cm.Reds(normalized_activation)

    # Escape token text: raw '<', '>' or '&' would otherwise be parsed as
    # markup and corrupt the rendered snippet.
    styled_text = ''.join(f'<span style="background-color: {matplotlib.colors.rgb2hex(activation_to_color(activation))}; margin-right: 0px;">{html.escape(token)}</span>'
                          for token, activation in tokens_with_activations)
    return styled_text

print("feature id: ", feature_id)
print()
# Show one HTML card per top-activating snippet: its rank, its peak
# activation, and the activation-shaded text.
for i, snippet in enumerate(act_data['snippets']):
    styled_text = style_snippet(i)
    snippet_info = (
        '<div style="word-wrap: break-word; margin-bottom: 10px;">'
        f'<strong>Snippet number:</strong> {i}<br>'
        f'<strong>Max activation:</strong> {snippet["max_activation"]}<br>'
        f'{styled_text}</div>'
    )
    display(HTML(snippet_info))

feature id:  0

