In [1]:
import os

import torch
import transformer_lens
from transformer_lens import HookedTransformer

from IPython.display import display
import circuitsvis as cv
import plotly.express as px

from model import SparseAutoencoder
from config import SAEConfig
from utils import imshow



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Keep everything on CPU for now.
device = 'cpu'

# Directory holding the converted SAE checkpoints (one per layer).
# Set the POS_SAE_CHECKPOINTS environment variable to point elsewhere instead of
# editing this absolute local path.
# TODO: move checkpoints to model hub.
checkpoints_path = os.environ.get(
    "POS_SAE_CHECKPOINTS", "/Users/slava/fun/pos_sae/converted_checkpoints"
)

In [3]:
# Base model whose residual stream the SAEs decompose.
model = HookedTransformer.from_pretrained('gpt2-small', device=device)

# One SAE per transformer block; saes[i] is the checkpoint for block i.
# NOTE(review): the checkpoint names say `hook_resid_pre`, but later cells feed
# saes[i] the `resid_post` of block i. Confirm which hook they were trained on.
saes = [
    SparseAutoencoder.load_from_pretrained(
        f"{checkpoints_path}/final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576",
        silent=True,
    )
    for layer in range(model.cfg.n_layers)
]

Loaded pretrained model gpt2-small into HookedTransformer


In [17]:
### For use with activation patching ###
# Each "traveled" / "succeeded" pair differs by one word. The trailing comment is
# the continuation the author expects the model to favour.
ACTIVATION_PATCHING_PROMPTS = {
    "team_traveled": "the team traveled by",  # bus
    "team_succeeded": "the team succeeded by",  # working together
    "cricket_traveled": "the cricket team traveled predominantly by",  # bus
    "cricket_succeeded": "the cricket team succeeded predominantly by",  # working together
}
text = ACTIVATION_PATCHING_PROMPTS["cricket_succeeded"]
###

In [18]:
# Alternative prompt: text = "They raised awareness for the cause by"

# Tokenize the prompt, then put GPT-2's BOS token in front so position 0 is BOS.
bos_column = torch.tensor([[model.tokenizer.bos_token_id]])
prompt_ids = model.tokenizer.encode(text, return_tensors="pt")
tokens = torch.cat((bos_column, prompt_ids), dim=1)

# Forward pass that also records every hooked activation in `cache`.
logits, cache = model.run_with_cache(tokens)

In [19]:
# display(cv.attention.attention_patterns(
#     tokens=model.to_str_tokens(tokens),
#     attention=cache['pattern', 11][0],
#     attention_head_names=[f"L0H{i}" for i in range(12)],
# ))

In [20]:
# Plan: run a backward pass from the target feature's activation.
# Then, for every feature in a layer, take the gradient w.r.t. that layer's residual
# stream projected onto the feature's direction (i.e. onto its decoder vector).

In [21]:
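# # backward pass on target feature activation on final position
# target_f_acts[-1].backward()
# cache['resid_post', 10][0].grad ## this is None: run_with_cache stores detached activations,
# so no gradient ever reaches them. That is why the next cells record gradients with
# backward hooks instead.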


In [22]:
# The feature whose activation we attribute: SAE latent `target_feature_id` in
# block `target_feature_layer`.
target_feature_layer = 11
target_feature_id = 23531
target_f_sae = saes[target_feature_layer]

# Target feature activation at every position of the prompt (batch index 0).
# Use `target_feature_layer` rather than a hard-coded 11 so changing the target
# layer also changes which residual stream is encoded.
# NOTE(review): the checkpoints are named `...hook_resid_pre...`, but this encodes
# `resid_post` of the same block. In TransformerLens, resid_post of block L equals
# resid_pre of block L+1. Confirm which hook the SAEs were trained on.
_, feature_acts, _, _, _ = target_f_sae(cache['resid_post', target_feature_layer][0])
target_f_acts = feature_acts[:, target_feature_id]
print(target_f_acts)

tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 41.8160],
       grad_fn=<SelectBackward0>)


In [23]:
# Re-run the model with hooks on every `hook_resid_post`:
#   * forward hooks record each block's residual stream so the target feature can be
#     recomputed with autograd enabled (the run_with_cache activations are detached);
#   * backward hooks record the gradient of the target feature w.r.t. each block's
#     residual stream.
resid_cache = {}  # layer -> resid_post activation, (batch, pos, d_model)
grad_dict = {k: None for k in range(model.cfg.n_layers)}  # layer -> gradient w.r.t. resid_post


def back_hook(grad, hook):
    """Backward hook: store the gradient arriving at this block's resid_post."""
    grad_dict[hook.layer()] = grad


def c_hook(activation, hook):
    """Forward hook: record this block's resid_post.

    Returns None, so the activation passes through unchanged.
    """
    resid_cache[hook.layer()] = activation


bwd_hooks = [(f"blocks.{i}.hook_resid_post", back_hook) for i in range(model.cfg.n_layers)]
cache_hooks = [(f"blocks.{i}.hook_resid_post", c_hook) for i in range(model.cfg.n_layers)]


with model.hooks(fwd_hooks=cache_hooks, bwd_hooks=bwd_hooks):
    logits = model(tokens)
    print('len of resid cache:', len(resid_cache))
    # Index by the target layer instead of assuming it is the last block.
    resid_at_final = resid_cache[target_feature_layer][0]
    print(resid_at_final.shape)
    _, feature_acts, _, _, _ = target_f_sae(resid_at_final)
    target_f_acts = feature_acts[:, target_feature_id]
    # Backprop from the target feature at the final position. Backward hooks must
    # fire while still inside the `model.hooks` context.
    target_f_acts[-1].backward()



len of resid cache: 12
torch.Size([7, 768])


In [30]:
# One (pos, d_model) gradient per layer that received gradient, stacked into
# (n_layers_with_grad, pos, d_model). The L1 norm over d_model shows how sensitive
# the target feature is to each (layer, position) slot of the residual stream.
# NOTE(review): layers whose gradient is None are skipped, so the y index then no
# longer equals the layer number.
all_grads = torch.stack([grad[0] for grad in grad_dict.values() if grad is not None])
grad_heatmap = all_grads.abs().sum(dim=-1)

labels = list(model.to_str_tokens(tokens[0]))
imshow(grad_heatmap, labels={"x": "Position", "y": "Layer"}, x=labels, width=800, height=600)


In [25]:
# # now we compute the attribution for each active feature in each layer.

# # compute active features in each layer
# active_features: list[torch.Tensor] = []
# for layer in range(model.cfg.n_layers):
#     _, feature_acts, _, _, _ = saes[layer](cache['resid_post', layer][0])
#     # feature_acts is (n_toks x n_features)
#     sums = feature_acts.sum(dim=0) # features are non-negative.
#     active = torch.nonzero(sums > 0).squeeze(dim=-1)
#     active_features.append(active)


In [26]:
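# For a given layer there are 24576 SAE features and n_toks positions.

# For a given (layer, position), the 24576-dim attribution vector is computed as follows:
# - we have a d_model vector of gradients w.r.t. the residual stream, and
# - a 24576 x d_model matrix of decoder vectors (W_dec).
# Projecting the gradient onto every decoder vector (W_dec @ grad) gives a 24576-dim
# vector. That vector is multiplied elementwise by the feature activations at that position.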

In [27]:
def attribution_for_position(layer, position):
    """Gradient x activation attribution of every SAE feature at one (layer, position).

    attribution[f] = (W_dec[f] . grad) * feature_act[f]

    Here `grad` is the gradient of the target feature w.r.t. the layer's resid_post,
    as recorded in `grad_dict` by the backward hooks.

    Args:
        layer: index into `saes`, `grad_dict` and `cache['resid_post', layer]`.
        position: token position; negative indices count from the end.

    Returns:
        1-D tensor with one attribution per SAE feature (one per row of W_dec).
        It is computed under `torch.no_grad()`, so it carries no autograd history.
    """
    sae = saes[layer]
    # No gradients are needed here. Without no_grad, the 24576-wide SAE forward
    # would build an autograd graph through the SAE parameters on every call.
    with torch.no_grad():
        resid_grad = grad_dict[layer][0][position]  # d_model vector (batch 0)
        grads_along_feats = sae.W_dec @ resid_grad  # gradient projected onto each decoder vector
        # NOTE(review): checkpoints are named `...hook_resid_pre...` but are applied to
        # resid_post here. Confirm which hook they were trained on.
        _, feature_acts, _, _, _ = sae(cache['resid_post', layer][0])  # (n_pos, n_features)
        attribution = grads_along_feats * feature_acts[position]
    return attribution

# Example: inspect the final position of layer 10. Positions run 0..n_pos-1 with BOS
# at 0, so the last valid index is tokens.shape[1] - 1.
# attribution = attribution_for_position(10, tokens.shape[1] - 1)

# top = attribution.argsort(descending=True)[:5]
# print('pos attributions')
# print(top)
# print(attribution[top])
# print()

# bottom = attribution.argsort()[:5]
# print('neg attributions')
# print(bottom)
# print(attribution[bottom])


In [28]:
def max_attribution_for_layer(layer):
    """For every position, find the single SAE feature in `layer` with the largest attribution.

    Calls `attribution_for_position` (gradient x activation) once per position of the
    cached prompt.

    Returns:
        (max_attributions, top_features): two lists with one 0-d tensor per position.
        They hold the maximum attribution value and the index of the feature that
        achieves it.
    """
    max_attributions = []
    top_features = []
    n_positions = cache['resid_post', layer][0].shape[0]
    for position in range(n_positions):
        attribution = attribution_for_position(layer, position)

        # argmax is O(n_features); the previous full descending argsort was only
        # used for its first element.
        top_idx = attribution.argmax()
        max_attributions.append(attribution[top_idx])
        top_features.append(top_idx)

    return max_attributions, top_features


# Only layers strictly below the target feature's layer are shown. No gradient flows
# downstream of the target layer, so later layers would have None in grad_dict, and the
# target layer itself is excluded as before. With target_feature_layer == n_layers - 1
# this is the same range as `range(model.cfg.n_layers - 1)`.
max_attributions_by_layer = []
top_features_by_layer = []
for layer in range(target_feature_layer):
    max_attributions, top_feature_indices = max_attribution_for_layer(layer)
    max_attributions_by_layer.append(torch.stack(max_attributions))
    top_features_by_layer.append(torch.stack(top_feature_indices))

# Both are (n_layers_shown, n_pos). Detach because the values may carry autograd history
# from the SAE parameters and are only needed for plotting.
max_attributions_by_layer = torch.stack(max_attributions_by_layer).detach()
top_features_by_layer = torch.stack(top_features_by_layer)

top_idxs_text = [[str(idx.item()) for idx in indices] for indices in top_features_by_layer]
labels = model.to_str_tokens(tokens[0])
imshow(max_attributions_by_layer, labels={"x": "Position", "y": "Layer"}, x=labels, text=top_idxs_text, width=800, height=600)