In [7]:
import torch
import transformer_lens
from transformer_lens import HookedTransformer
from transformer_lens.utils import tokenize_and_concatenate

from datasets import load_dataset
from torch.utils.data import DataLoader

from IPython.display import display
import circuitsvis as cv
from huggingface_hub import hf_hub_download, snapshot_download
from tqdm import tqdm

from model import SparseAutoencoder
from config import SAEConfig

from utils import imshow

import plotly.express as px
import plotly.graph_objects as go

import os


In [10]:
# Pick the best available accelerator: Apple-silicon MPS, then CUDA, else CPU.
# (Previously hard-coded to 'mps', which fails on machines without MPS.)
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Fetch the SAE checkpoint repo into ./checkpoints; `checkpoints_path` is the
# local directory returned by snapshot_download.
os.makedirs("./checkpoints", exist_ok=True)
checkpoints_path = snapshot_download("schalnev/jbloom_SAE_reupload", local_dir="checkpoints")
# Note: these are Joseph Bloom's checkpoints, but I had previously downloaded and modified them to be easier to use.


Fetching 26 files:   0%|          | 0/26 [00:00<?, ?it/s]

.gitattributes:   0%|          | 0.00/1.52k [00:00<?, ?B/s]

(…)l_blocks.0.hook_resid_pre_24576/cfg.json:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

(…)_blocks.10.hook_resid_pre_24576/cfg.json:   0%|          | 0.00/1.16k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/21.0 [00:00<?, ?B/s]

(…)l_blocks.1.hook_resid_pre_24576/cfg.json:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

(…)l_blocks.1.hook_resid_pre_24576/model.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)_blocks.10.hook_resid_pre_24576/model.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)l_blocks.0.hook_resid_pre_24576/model.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)l_blocks.2.hook_resid_pre_24576/cfg.json:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

(…)_blocks.11.hook_resid_pre_24576/cfg.json:   0%|          | 0.00/1.16k [00:00<?, ?B/s]

(…)_blocks.11.hook_resid_pre_24576/model.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)l_blocks.4.hook_resid_pre_24576/cfg.json:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

(…)l_blocks.3.hook_resid_pre_24576/cfg.json:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

(…)l_blocks.2.hook_resid_pre_24576/model.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)l_blocks.3.hook_resid_pre_24576/model.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)l_blocks.5.hook_resid_pre_24576/cfg.json:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

(…)l_blocks.4.hook_resid_pre_24576/model.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)l_blocks.5.hook_resid_pre_24576/model.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)l_blocks.6.hook_resid_pre_24576/cfg.json:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

(…)l_blocks.6.hook_resid_pre_24576/model.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)l_blocks.7.hook_resid_pre_24576/cfg.json:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

(…)l_blocks.7.hook_resid_pre_24576/model.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)l_blocks.8.hook_resid_pre_24576/cfg.json:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

(…)l_blocks.8.hook_resid_pre_24576/model.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)l_blocks.9.hook_resid_pre_24576/cfg.json:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

(…)l_blocks.9.hook_resid_pre_24576/model.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

In [11]:
# The feature being explained: column TARGET_FEATURE_ID of saes[TARGET_FEATURE_LAYER].
TARGET_FEATURE_LAYER, TARGET_FEATURE_ID = 11, 23531
# In the dataset sweep, only positions where that feature exceeds this value are attributed.
ACTIVATION_THRESHOLD = 10.0

In [14]:
# Load GPT-2 small on the selected device.
model = HookedTransformer.from_pretrained('gpt2-small', device=device)

# Load one SAE per transformer layer so that saes[i] belongs to layer i.
# The checkpoint directory names indicate these SAEs were trained on hook_resid_pre.
saes = []
for layer in range(model.cfg.n_layers):
    path = "{}/final_sparse_autoencoder_gpt2-small_blocks.{}.hook_resid_pre_24576".format(checkpoints_path, layer)
    sae = SparseAutoencoder.load_from_pretrained(path, silent=True)
    sae.to(device)
    saes.append(sae)

Loaded pretrained model gpt2-small into HookedTransformer


In [15]:
### Prompt used for activation patching ###
# Alternatives (expected continuation after the second '#'):
#   text = "the team traveled by"                        # bus
#   text = "the team succeeded by"                       # working together
#   text = "the cricket team traveled predominantly by"  # bus
text = "the cricket team succeeded predominantly by"  # working together
###

In [16]:
# Alternative prompt: text = "They raised awareness for the cause by"

# Encode the prompt, then put the BOS token id in front as the first column.
prompt_ids = model.tokenizer.encode(text, return_tensors="pt")
bos_ids = torch.tensor([[model.tokenizer.bos_token_id]])
tokens = torch.cat([bos_ids, prompt_ids], dim=1)

# Forward pass that also records the hooked activations in `cache`.
logits, cache = model.run_with_cache(tokens)

In [17]:
# display(cv.attention.attention_patterns(
#     tokens=model.to_str_tokens(tokens),
#     attention=cache['pattern', 11][0],
#     attention_head_names=[f"L0H{i}" for i in range(12)],
# ))

In [18]:
@torch.no_grad()
def get_feature_activations(cache, layer, id):
    """Return SAE feature `id`'s activation at every position of batch element 0.

    Encodes the cached ``resid_post`` activations of `layer` with ``saes[layer]``
    and selects column `id` of the feature activations.

    NOTE(review): the SAEs are loaded from ``hook_resid_pre`` checkpoints but are
    applied to ``resid_post`` here — confirm this hook offset is intended.
    """
    resid = cache['resid_post', layer][0]
    sae_outputs = saes[layer](resid)
    _, feature_acts, _, _, _ = sae_outputs
    return feature_acts[:, id]


get_feature_activations(cache, TARGET_FEATURE_LAYER, TARGET_FEATURE_ID)

tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 41.8160],
       device='mps:0')

In [19]:
def get_grads(layer, id, tokens):
    """Backpropagate one SAE feature (at the last position) to every layer's residual stream.

    Runs ``model`` on `tokens` with forward hooks that keep each block's
    ``hook_resid_post`` activation and backward hooks that record the gradient
    arriving there. The residual stream of `layer` is encoded with
    ``saes[layer]`` and ``backward()`` is called on feature `id` at the final
    position.

    NOTE(review): the SAEs are loaded from ``hook_resid_pre`` checkpoints but are
    applied to ``resid_post`` here — confirm this hook offset is intended.

    Args:
        layer: layer whose SAE / residual stream hosts the feature.
        id: index of the feature inside ``saes[layer]``.
        tokens: token ids passed to ``model``.

    Returns:
        dict mapping layer index -> gradient captured at
        ``blocks.{i}.hook_resid_post``, or ``None`` for layers the gradient never
        reached (layers after `layer` are not upstream of the feature).
    """
    grads = {k: None for k in range(model.cfg.n_layers)}
    resid_cache = {}  # layer index -> resid_post activation (still in the autograd graph)

    target_f_sae = saes[layer]
    target_f_sae.zero_grad()
    model.zero_grad()

    def back_hook(grad, hook):
        grads[hook.layer()] = grad

    def c_hook(act, hook):
        resid_cache[hook.layer()] = act

    hook_names = [f"blocks.{i}.hook_resid_post" for i in range(model.cfg.n_layers)]
    bwd_hooks = [(name, back_hook) for name in hook_names]
    cache_hooks = [(name, c_hook) for name in hook_names]

    with model.hooks(fwd_hooks=cache_hooks, bwd_hooks=bwd_hooks):
        model(tokens)
        # Bug fix: this used resid_cache[-1] (always the LAST layer's resid_post),
        # so any `layer` other than the final one encoded the wrong residual stream.
        resid_at_layer = resid_cache[layer][0]
        _, feature_acts, _, _, _ = target_f_sae(resid_at_layer)
        feature_acts[-1, id].backward()

    return grads

grad_dict = get_grads(TARGET_FEATURE_LAYER, TARGET_FEATURE_ID, tokens)

    

In [20]:
### Get the gradient of the target feature wrt residual stream. ###

# resid_cache = []
# grad_dict = {k: None for k in range(model.cfg.n_layers)}

# def back_hook(input, hook):
#     # print(hook.layer())
#     grad_dict[hook.layer()] = input

# def c_hook(input, hook):
#     resid_cache.append(input)

# bwd_hooks = [(f"blocks.{i}.hook_resid_post", back_hook) for i in range(model.cfg.n_layers)]
# cache_hooks = [(f"blocks.{i}.hook_resid_post", c_hook) for i in range(model.cfg.n_layers)]


# with model.hooks(fwd_hooks=cache_hooks, bwd_hooks=bwd_hooks):
#     logits = model(tokens)
#     resid_at_final = resid_cache[-1][0]
#     _, feature_acts, _, _, _ = saes[TARGET_FEATURE_LAYER](resid_at_final)
#     target_f_acts = feature_acts[:, TARGET_FEATURE_ID]
#     target_f_acts[-1].backward()


In [21]:
def plot_grads():
    """Heatmap of summed |gradient| reaching each layer's residual stream, per position.

    Reads the global ``grad_dict`` (layer -> captured gradient, or ``None`` when
    no gradient reached that layer) and the global ``tokens`` for axis labels.
    Layers with ``None`` are skipped.
    """
    # Batch element 0 of every layer that received a gradient, in dict order.
    per_layer = [grad[0] for grad in grad_dict.values() if grad is not None]
    grad_heatmap = torch.stack(per_layer).abs().sum(dim=-1)

    position_labels = list(map(str, model.to_str_tokens(tokens[0])))
    imshow(
        grad_heatmap,
        labels={"x": "Position", "y": "Layer"},
        x=position_labels,
        title="Gradients wrt Residual Stream",
        width=800,
        height=600,
    )


plot_grads()

In [22]:
@torch.no_grad()
def attribution_for_position(layer, position, cache, grad_dict):
    """Linear attribution of every SAE feature at (`layer`, `position`) to the target.

    attribution[f] = (W_dec @ grad)[f] * act[f]: the residual-stream gradient at
    `position` projected through the SAE decoder, scaled by each feature's
    activation at that position.

    Args:
        layer: layer whose SAE (``saes[layer]``) and cached ``resid_post`` are used.
        position: token position to attribute.
        cache: activation cache containing ``('resid_post', layer)``.
        grad_dict: layer -> gradient wrt ``resid_post`` (batch element 0 is used).

    Returns:
        1-D tensor with one attribution value per SAE feature.
    """
    sae = saes[layer]
    resid_grad = grad_dict[layer][0][position]  # d_model vector
    grads_along_feats = sae.W_dec @ resid_grad  # one entry per SAE feature

    # Encode only the requested position (as a 1-row batch) instead of the whole
    # sequence: callers loop over positions, so re-encoding everything each call
    # made the work quadratic in sequence length.
    # NOTE(review): assumes the SAE encodes rows independently — confirm.
    # NOTE(review): SAEs were loaded from hook_resid_pre checkpoints but are
    # applied to resid_post here — confirm this hook offset is intended.
    resid = cache['resid_post', layer][0][position].unsqueeze(0)
    _, feature_acts, _, _, _ = sae(resid)
    return grads_along_feats * feature_acts[0]

# attribution = attribution_for_position(10, 7, cache, grad_dict)

# top = attribution.argsort(descending=True)[:5]
# print('pos attributions')
# print(top)
# print(attribution[top])
# print()

# min = attribution.argsort()[:5]
# print('neg attributions')
# print(min)
# print(attribution[min])


In [23]:
def max_attribution_for_layer(layer, min=False, k=5):
    """Top-k SAE feature attributions at every position of `layer`.

    Uses the global ``cache`` and ``grad_dict``. With ``min=True`` the k most
    negative attributions are returned instead of the k most positive.

    Returns:
        (values, indices): two lists with one entry per position; each entry is
        a tensor of the selected attribution values / feature indices.
    """
    n_positions = cache['embed'].shape[1]
    values_per_pos, indices_per_pos = [], []
    for pos in range(n_positions):
        scores = attribution_for_position(layer, pos, cache=cache, grad_dict=grad_dict)
        order = scores.argsort() if min else scores.argsort(descending=True)
        chosen = order[:k]
        values_per_pos.append(scores[chosen])
        indices_per_pos.append(chosen)
    return values_per_pos, indices_per_pos


def visualize_max_attributions(min=False, k=1):
    """Heatmap of the strongest SAE-feature attribution at each (layer, position).

    For every layer upstream of ``TARGET_FEATURE_LAYER``, each cell holds the sum
    of the top-k attributions at that position (most negative if ``min``), and
    is annotated with the single strongest feature index. Uses the global
    ``tokens`` for the x-axis labels.
    """
    # Only layers below the target can contribute to it. The old bound
    # `model.cfg.n_layers - 1` matched this only when the target sat in the last
    # layer.
    layers = range(TARGET_FEATURE_LAYER)

    heat_rows, text_rows = [], []
    for layer in layers:
        max_attributions, top_feature_indices = max_attribution_for_layer(layer, min=min, k=k)
        heat_rows.append([attrib.sum().item() for attrib in max_attributions])
        # Annotate with the strongest feature id; plain strings, no tensor round-trip.
        text_rows.append([str(idx[0].item()) for idx in top_feature_indices])

    labels = [str(tok) for tok in model.to_str_tokens(tokens[0])]
    title = "Most Negative Attributions" if min else "Most Positive Attributions"

    imshow(
        torch.tensor(heat_rows),
        labels={"x": "Position", "y": "Layer"},
        x=labels,
        title=title,
        text=text_rows,
        width=800,
        height=600,
    )


visualize_max_attributions()

In [24]:
# Same heatmap, restricted to the most negative attributions.
visualize_max_attributions(min=True, k=1)

In [25]:
# Spot-check one upstream feature (layer 0, feature 2996) on the prompt.
# For the target itself: get_feature_activations(cache, 11, 23531)
get_feature_activations(cache, layer=0, id=2996)

tensor([10.6147, 17.9311, 10.8336,  9.9957, 19.8247, 15.1831, 12.3742],
       device='mps:0')

In [26]:
# we will attempt to use Neel's tokenised dataset. This is not the same as the training distribution, but hopefully it doesn't matter.


In [27]:

all_attributions = dict() # {feature_id: list of attributions}

def update_attributions(tokens, k=5, scale_cutoff=1.0):
    """Add the strongest upstream feature attributions for one prompt to `all_attributions`.

    Re-runs the model on `tokens`, backpropagates the target feature
    (TARGET_FEATURE_LAYER, TARGET_FEATURE_ID) from the last position, and for
    every layer upstream of the target and every position records the k most
    negative and k most positive attributions whose magnitude exceeds
    `scale_cutoff`. Values are appended under the key ``f"{layer}_{feature}"``.

    Args:
        tokens: token ids; ``tokens.shape[1]`` is used as the sequence length.
            NOTE(review): assumes batch size 1 (batch element 0 is read downstream).
        k: how many of the most negative / most positive attributions to consider
            per (layer, position). Default matches the previous hard-coded 5.
        scale_cutoff: minimum |attribution| for a value to be recorded.
    """
    _, cache = model.run_with_cache(tokens)  # inefficient, but simple
    grad_dict = get_grads(TARGET_FEATURE_LAYER, TARGET_FEATURE_ID, tokens)

    # Only layers below the target can contribute; `model.cfg.n_layers - 1`
    # matched this only when the target was the final layer.
    for layer in range(TARGET_FEATURE_LAYER):
        for position in range(tokens.shape[1]):
            attribution = attribution_for_position(layer, position, cache=cache, grad_dict=grad_dict)
            k_eff = min(k, attribution.numel())

            # topk replaces two full argsorts over every SAE feature.
            neg_vals, neg_idx = torch.topk(attribution, k_eff, largest=False)
            pos_vals, pos_idx = torch.topk(attribution, k_eff, largest=True)

            # One host transfer per tensor instead of per-element .item() calls.
            for feat, val in zip(neg_idx.tolist(), neg_vals.tolist()):
                if val < -scale_cutoff:
                    all_attributions.setdefault(f"{layer}_{feat}", []).append(val)

            for feat, val in zip(pos_idx.tolist(), pos_vals.tolist()):
                if val > scale_cutoff:
                    all_attributions.setdefault(f"{layer}_{feat}", []).append(val)


In [28]:
# Pile-10k text, tokenized with the model's tokenizer into 32-token rows.
data = load_dataset(path="NeelNanda/pile-10k", split="train")
tokenized = tokenize_and_concatenate(
    dataset=data,
    tokenizer=model.tokenizer,
    max_length=32,
)

In [None]:
all_attributions = dict() # {feature_id: list of attributions}
num_steps = 2000

# Sweep the dataset. Wherever the target feature fires above
# ACTIVATION_THRESHOLD, attribute it using the prefix that ends at that
# position. Loop-local names (batch_tokens / batch_cache) keep the prompt-level
# `tokens` / `cache` used by earlier cells from being silently overwritten.
for i, batch in enumerate(tqdm(tokenized, total=num_steps)):
    # Bug fix: the old `if i > num_steps: break` at the END of the body
    # processed num_steps + 2 batches.
    if i >= num_steps:
        break

    batch_tokens = batch["tokens"].unsqueeze(0)
    _, batch_cache = model.run_with_cache(batch_tokens)

    acts = get_feature_activations(batch_cache, TARGET_FEATURE_LAYER, TARGET_FEATURE_ID)
    # A single threshold test replaces the old `> 0` filter followed by `> ACTIVATION_THRESHOLD`.
    strong_positions = (acts > ACTIVATION_THRESHOLD).nonzero().flatten().tolist()

    for pos in strong_positions:
        update_attributions(batch_tokens[:, :pos + 1])

2001it [08:45,  3.81it/s]                          


In [None]:
def get_top_pairs(attribution_dict, top_n=5):
    """Return the `top_n` (feature_id, values) items with the largest total |value|, largest first."""
    def total_magnitude(item):
        _, values = item
        return sum(abs(v) for v in values)

    ranked = sorted(attribution_dict.items(), key=total_magnitude, reverse=True)
    return ranked[:top_n]

top_pairs = get_top_pairs(attribution_dict=all_attributions, top_n=10)  # ten largest by summed |attribution|

In [None]:
# for k, v in top_pairs:
#     print(k, v)

In [None]:
## thanks Claude!

import plotly.graph_objects as go
from plotly.subplots import make_subplots

# One histogram per top feature: every attribution value it received during the
# dataset sweep, with a dashed reference line at x=0.
if not top_pairs:
    # make_subplots needs at least one row; report instead of crashing.
    print("No attributions passed the cutoff, so there is nothing to plot.")
else:
    fig = make_subplots(rows=len(top_pairs), cols=1, subplot_titles=[pair[0] for pair in top_pairs])

    for i, (feature_id, attribution_values) in enumerate(top_pairs, start=1):
        fig.add_trace(
            go.Histogram(x=attribution_values, name=feature_id, opacity=0.75),
            row=i,
            col=1,
        )
        fig.add_vline(x=0, row=i, col=1, line_width=1, line_dash="dash", line_color="black")

    fig.update_layout(
        title='Histograms of Attribution Values',
        height=300 * len(top_pairs),  # fixed height per subplot
        showlegend=False,
    )

    # Shared x-range derived from the data, always including 0 so the reference
    # line is visible. The previous hard-coded [-40, 75] silently cut off any
    # attribution outside that window.
    all_values = [v for _, values in top_pairs for v in values] or [0.0]
    lo = min(min(all_values), 0.0)
    hi = max(max(all_values), 0.0)
    pad = 0.05 * ((hi - lo) or 1.0)
    for i in range(1, len(top_pairs) + 1):
        fig.update_xaxes(range=[lo - pad, hi + pad], row=i, col=1)

    fig.show()