In [1]:
import torch
import transformer_lens
from transformer_lens import HookedTransformer
from transformer_lens.utils import tokenize_and_concatenate

from datasets import load_dataset
from torch.utils.data import DataLoader

from IPython.display import display
import circuitsvis as cv

from model import SparseAutoencoder
from config import SAEConfig

from utils import imshow

import plotly.express as px



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = 'cpu' # keep everything on cpu for now
checkpoints_path = "/Users/slava/fun/pos_sae/converted_checkpoints" # TODO: move checkpoints to model hub.

In [27]:
TARGET_FEATURE_LAYER = 11
TARGET_FEATURE_ID = 23531
ACTIVATION_THRESHOLD = 10.0

In [3]:
# load model
model = HookedTransformer.from_pretrained('gpt2-small', device=device)

# Load the SAEs
saes = [] # one for each layer
for layer in range(model.cfg.n_layers):
    path = f"{checkpoints_path}/final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576"
    sae = SparseAutoencoder.load_from_pretrained(path, silent=True)
    saes.append(sae)

Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
### For use with activation patching ###
# text = "the team traveled by" # bus
# text = "the team succeeded by" # working together

# text = "the cricket team traveled predominantly by" # bus
# text = "the cricket team succeeded predominantly by" # working together
###

In [5]:
text = "They raised awareness for the cause by"

tokens = model.tokenizer.encode(text, return_tensors="pt")
tokens = torch.cat([torch.tensor([[model.tokenizer.bos_token_id]]), tokens], dim=1) # prepend bos

logits, cache = model.run_with_cache(tokens)

In [6]:
# display(cv.attention.attention_patterns(
#     tokens=model.to_str_tokens(tokens),
#     attention=cache['pattern', 11][0],
#     attention_head_names=[f"L0H{i}" for i in range(12)],
# ))

In [7]:
@torch.no_grad()
def get_feature_activations(cache, layer, id):
    target_f_sae = saes[layer]
    _, feature_acts, _, _, _ = target_f_sae(cache['resid_post', layer][0])
    target_f_acts = feature_acts[:, id]
    return target_f_acts



get_feature_activations(cache, TARGET_FEATURE_LAYER, TARGET_FEATURE_ID)

tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 70.1081])

In [8]:
def get_grads(layer, id, tokens):
    grads = {k: None for k in range(model.cfg.n_layers)}
    resid_cache = []

    target_f_sae = saes[layer]
    target_f_sae.zero_grad()
    model.zero_grad()

    def back_hook(input, hook):
        grads[hook.layer()] = input
    
    def c_hook(input, hook):
        resid_cache.append(input)
    
    bwd_hooks = [(f"blocks.{i}.hook_resid_post", back_hook) for i in range(model.cfg.n_layers)]
    cache_hooks = [(f"blocks.{i}.hook_resid_post", c_hook) for i in range(model.cfg.n_layers)]

    with model.hooks(fwd_hooks=cache_hooks, bwd_hooks=bwd_hooks):
        logits = model(tokens)
        resid_at_final = resid_cache[-1][0]
        _, feature_acts, _, _, _ = saes[layer](resid_at_final)
        target_f_acts = feature_acts[:, id]
        target_f_acts[-1].backward()
    
    return grads

grad_dict = get_grads(TARGET_FEATURE_LAYER, TARGET_FEATURE_ID, tokens)

    

In [9]:
### Get the gradient of the target feature wrt residual stream. ###

# resid_cache = []
# grad_dict = {k: None for k in range(model.cfg.n_layers)}

# def back_hook(input, hook):
#     # print(hook.layer())
#     grad_dict[hook.layer()] = input

# def c_hook(input, hook):
#     resid_cache.append(input)

# bwd_hooks = [(f"blocks.{i}.hook_resid_post", back_hook) for i in range(model.cfg.n_layers)]
# cache_hooks = [(f"blocks.{i}.hook_resid_post", c_hook) for i in range(model.cfg.n_layers)]


# with model.hooks(fwd_hooks=cache_hooks, bwd_hooks=bwd_hooks):
#     logits = model(tokens)
#     resid_at_final = resid_cache[-1][0]
#     _, feature_acts, _, _, _ = saes[TARGET_FEATURE_LAYER](resid_at_final)
#     target_f_acts = feature_acts[:, TARGET_FEATURE_ID]
#     target_f_acts[-1].backward()


In [10]:
def plot_grads():
    all_grads = []
    for k, v in grad_dict.items():
        if v is not None:
            all_grads.append(v[0])


    all_grads = torch.stack(all_grads)
    grad_heatmap = all_grads.abs().sum(dim=-1)

    labels = [f"{tok}" for i, tok in enumerate(model.to_str_tokens(tokens[0]))]
    imshow(grad_heatmap, labels={"x": "Position", "y": "Layer"},x=labels, width=800, height=600)


plot_grads()

In [11]:
@torch.no_grad()
def attribution_for_position(layer, position, cache, grad_dict):
    sae = saes[layer]
    resid_grad = grad_dict[layer][0][position] # d_model vector
    grads_along_feats = sae.W_dec @ resid_grad # 24576 dim vector

    # compute attribution by multiplying grads_along_feats by feature activations.
    _, feature_acts, _, _, _ = saes[layer](cache['resid_post', layer][0])
    feature_acts = feature_acts[position] # 24576 dim vector
    attribution = grads_along_feats * feature_acts
    return attribution # 24576 dim vector

# attribution = attribution_for_position(10, 7, cache, grad_dict)

# top = attribution.argsort(descending=True)[:5]
# print('pos attributions')
# print(top)
# print(attribution[top])
# print()

# min = attribution.argsort()[:5]
# print('neg attributions')
# print(min)
# print(attribution[min])


In [12]:
def max_attribution_for_layer(layer, min=False, k=5):
    max_attributions = []
    top_features = []
    n_toks = cache['embed'].shape[1]
    for position in range(n_toks):
        attribution = attribution_for_position(layer, position, cache=cache, grad_dict=grad_dict)

        if min:
            top_idx = attribution.argsort()[:k]
        else:
            top_idx = attribution.argsort(descending=True)[:k]

        top_attrib = attribution[top_idx]

        max_attributions.append(top_attrib)
        top_features.append(top_idx)

    return max_attributions, top_features


def visualize_max_attributions(min=False, k=1):
    ### most positive attributions ###
    max_attributions_by_layer = []
    top_features_by_layer = []
    for layer in range(model.cfg.n_layers - 1):
        max_attributions, top_feature_indices = max_attribution_for_layer(layer, min=min, k=k)
        max_attributions = [attrib.sum() for attrib in max_attributions]
        top_feature_indices = [idx[0].item() for idx in top_feature_indices]
        # break
        max_attributions_by_layer.append(max_attributions)
        top_features_by_layer.append(top_feature_indices)

    max_attributions_by_layer = torch.tensor(max_attributions_by_layer)
    top_features_by_layer = torch.tensor(top_features_by_layer)

    top_idxs_text = [[str(idx.item()) for idx in indices] for indices in top_features_by_layer]
    labels = [f"{tok}" for i, tok in enumerate(model.to_str_tokens(tokens[0]))]

    imshow(max_attributions_by_layer, labels={"x": "Position", "y": "Layer"}, x=labels, text=top_idxs_text, width=800, height=600)


visualize_max_attributions()

In [13]:
# ### most negative attributions ###
visualize_max_attributions(min=True)

In [14]:
# get_feature_activations(cache, 11, 23531)
get_feature_activations(cache, 0, 2996)

tensor([10.6147, 19.4824, 18.4511,  9.3586, 11.7011, 12.1088, 10.9227, 11.2819])

In [15]:
# we will attempt to use Neel's tokenised dataset. This is not the same as the training distribution, but hopefully it doesn't matter.


In [33]:

all_attributions = dict() # {feature_id: list of attributions}

def update_attributions(tokens):
    logits, cache = model.run_with_cache(tokens) # inefficient, but I don't care.
    grad_dict = get_grads(TARGET_FEATURE_LAYER, TARGET_FEATURE_ID, tokens)
    scale_cutoff = 1.0

    for layer in range(model.cfg.n_layers - 1):
        for position in range(tokens.shape[1]):
            attribution = attribution_for_position(layer, position, cache=cache, grad_dict=grad_dict)

            min_idx = attribution.argsort()[:5]
            max_idx = attribution.argsort(descending=True)[:5]
            min_attrib = attribution[min_idx]
            max_attrib = attribution[max_idx]

            for i, val in zip(min_idx, min_attrib):
                ft_id = f"{layer}_{i.item()}"
                if val < -scale_cutoff:
                    all_attributions.setdefault(ft_id, []).append(val.item())
            
            for i, val in zip(max_idx, max_attrib):
                ft_id = f"{layer}_{i.item()}"
                if val > scale_cutoff:
                    all_attributions.setdefault(ft_id, []).append(val.item())


In [20]:
data = load_dataset("NeelNanda/pile-10k", split="train")
tokenized = tokenize_and_concatenate(data, model.tokenizer, max_length=32)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [34]:
all_attributions = dict() # {feature_id: list of attributions}

for i, batch in enumerate(tokenized):
    tokens = batch["tokens"]
    tokens = tokens.unsqueeze(0)
    logits, cache = model.run_with_cache(tokens)

    acts = get_feature_activations(cache, TARGET_FEATURE_LAYER, TARGET_FEATURE_ID)
    active_positions = (acts > 0.0).nonzero()

    for pos in active_positions:
        activation = acts[pos]
        if activation > ACTIVATION_THRESHOLD:  # a hack to remove low activations
            print('activation', acts[pos].item())
            print(model.to_str_tokens(tokens[:, :pos+1]))
            update_attributions(tokens[:, :pos+1])

    if i > 100:
        break

activation 10.066123008728027
['<|endoftext|>', ' of', ' foe', ' trying', ' to', ' stop', ' them', ' from']
activation 11.956418991088867
['<|endoftext|>', 'org', ',', ' or', ' fighting', ' the', ' Borg', '?', '\n', '\n', 'The', ' third', ' and', ' final', ' idea', ' came', ' to', ' me', ' through', ' my', ' girlfriend', ',', ' who', ' somehow', ' gave', ' me', ' the', ' idea', ' of']
activation 56.776222229003906
['<|endoftext|>', ' A', ' secondary', ' idea', ' here', ' was', ' that', ' the', ' game', ' would', ' work', ' to', ' explain', ' how', ' the', ' Flying', ' Sp', 'aghetti', ' Monster', ' came', ' to', ' exist', ' –', ' by']
activation 38.34526443481445
['<|endoftext|>', '.', ' There', ' are', ' 5', ' other', ' guests', ' at', ' the', ' table', ',', ' each', ' with', ' their', ' own', ' plate', '.', '\n', '\n', 'Your', ' plate', ' can', ' spawn', ' little', ' pieces', ' of', ' pasta', '.', ' You', ' do', ' so', ' by']
activation 10.194314002990723
['<|endoftext|>', ' with', '.

In [29]:
all_attributions.keys()

dict_keys([7829, 14199, 4384, 16175, 20088, 15648, 21080, 1451, 8241, 7268, 18295, 5732, 22892, 8148, 21622, 18764, 277, 21025, 18710, 10325, 9890, 8414, 16473, 9043, 3945, 789, 24149, 21958, 17773, 519, 19738, 641, 22415, 9787, 13379, 15531, 502, 8966, 5468, 16725, 13293, 11910, 6340, 1697, 13280, 9715, 13661, 8274, 18150, 8521, 11876, 8614, 20755, 16703, 7043, 22244, 2440, 18578, 20976, 7006, 19299, 12760, 11379, 9133, 20326, 6978, 11804, 9289, 2762, 22759, 5203, 4050, 18044, 11953, 657, 11234, 10408, 18593, 14392, 10696, 828, 6985, 4596, 4893, 8724, 2146, 15757, 14468, 24096, 2815, 13796, 17252, 16815, 6183, 16920, 15503, 9960, 12878, 15536, 7375, 12871, 5584, 21000, 17043, 20236, 7620, 3217, 6574, 23392, 9426, 14822, 10499, 8326, 5742, 8865, 22843, 23347, 22877, 19182, 14374, 6725, 18918, 24553, 10328, 5457, 21374, 344, 5276, 7365, 6923, 17840, 4907, 13013, 5363, 15264, 6051, 13283, 1836, 339, 20135, 3357, 13939, 13974, 20998, 8812, 9769, 9293, 20133, 12144, 17688, 5910, 5085, 9510

In [35]:
def get_top_pairs(attribution_dict, top_n=5):
    sorted_pairs = sorted(attribution_dict.items(), key=lambda x: sum(map(abs, x[1])), reverse=True)
    return sorted_pairs[:top_n]

top_pairs = get_top_pairs(all_attributions, top_n=10)

In [36]:
top_pairs

[('10_9960',
  [11.45576000213623,
   13.393526077270508,
   43.786319732666016,
   38.045379638671875,
   1.035262107849121,
   10.815940856933594,
   1.083121418952942,
   16.694101333618164]),
 ('2_19738',
  [-14.797445297241211,
   -8.890697479248047,
   -11.655292510986328,
   -15.297065734863281,
   9.61153507232666,
   11.995420455932617,
   -10.891196250915527,
   -33.62974166870117]),
 ('9_24096',
  [8.310940742492676,
   30.75067138671875,
   33.67822265625,
   11.906428337097168,
   10.26396369934082]),
 ('8_19235',
  [30.46546745300293,
   28.063631057739258,
   14.582470893859863,
   15.583444595336914]),
 ('2_15531',
  [13.86876106262207,
   17.999923706054688,
   16.448911666870117,
   25.240646362304688]),
 ('1_277',
  [-11.019274711608887,
   -10.985179901123047,
   -15.131011009216309,
   12.89258861541748,
   -22.12629508972168]),
 ('2_8750', [14.311951637268066, 31.45849609375, 18.344303131103516]),
 ('3_4271',
  [17.298139572143555,
   9.5475435256958,
   18.202781