In [1]:
import os

import circuitsvis as cv
import plotly.express as px
import torch
import transformer_lens
from IPython.display import display
from transformer_lens import HookedTransformer

from config import SAEConfig
from model import SparseAutoencoder
from utils import imshow



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = 'cpu' # keep everything on cpu for now

# Directory holding the converted SAE checkpoints. Set the POS_SAE_CHECKPOINTS
# environment variable to point at another location; when it is unset the
# original hard-coded location is used.
# TODO: move checkpoints to model hub.
checkpoints_path = os.environ.get(
    "POS_SAE_CHECKPOINTS",
    "/Users/slava/fun/pos_sae/converted_checkpoints",
)

In [3]:
# Load the pretrained GPT-2 small model onto `device`.
model = HookedTransformer.from_pretrained('gpt2-small', device=device)

# Load one sparse autoencoder (SAE) per transformer block, so saes[i] belongs to block i.
# NOTE(review): the checkpoint names reference `hook_resid_pre`, while other cells feed
# these SAEs `resid_post` activations — confirm that pairing is intended.
# The loop variables `layer`, `path` and `sae` stay defined at notebook level afterwards.
saes = [] # one for each layer
for layer in range(model.cfg.n_layers):
    # 24576 in the file name is presumably the SAE dictionary size — TODO confirm.
    path = f"{checkpoints_path}/final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576"
    sae = SparseAutoencoder.load_from_pretrained(path, silent=True)
    saes.append(sae)

Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
### For use with activation patching ###
# text = "the team traveled by" # bus
# text = "the team succeeded by" # working together

# text = "the cricket team traveled predominantly by" # bus
# text = "the cricket team succeeded predominantly by" # working together
###

In [5]:
# Prompt whose last-position feature activation is analysed in this notebook.
text = "They raised awareness for the cause by"

# Tokenize through the model's own helper instead of concatenating a BOS id by
# hand: the library then decides how the BOS token is added, which avoids a
# duplicated BOS if the underlying tokenizer is configured to add one itself.
# The result keeps the leading batch dimension that the `[0]` indexing in the
# other cells relies on.
tokens = model.to_tokens(text, prepend_bos=True)

# One forward pass; `cache` holds the intermediate activations read by other cells.
logits, cache = model.run_with_cache(tokens)

In [6]:
# display(cv.attention.attention_patterns(
#     tokens=model.to_str_tokens(tokens),
#     attention=cache['pattern', 11][0],
#     attention_head_names=[f"L0H{i}" for i in range(12)],
# ))

In [7]:
def get_feature_activations(layer, id):
    """Return the activations of a single SAE feature for the cached prompt.

    Reads the notebook-level `saes` and `cache`: the first (batch) entry of
    `cache['resid_post', layer]` is passed through `saes[layer]`, and column
    `id` of the second value the SAE returns (its feature activations) is
    selected.

    NOTE(review): the checkpoints are named `...hook_resid_pre...` while this
    reads `resid_post` — confirm that pairing is intended.

    Args:
        layer: Index used for both the SAE and the cached residual stream.
        id: Column (feature index) to pick out of the feature activations.

    Returns:
        `feature_acts[:, id]`, still attached to the autograd graph.
    """
    sae_for_layer = saes[layer]
    resid_for_layer = cache['resid_post', layer][0]
    # The SAE returns five values; only the second one is needed here.
    _first, acts, _third, _fourth, _fifth = sae_for_layer(resid_for_layer)
    return acts[:, id]

# Feature under investigation: the SAE (layer) it lives in and its column index
# in that SAE's feature activations.
target_feature_layer = 11
target_feature_id = 23531

# Last expression of the cell, so the notebook displays the returned tensor.
get_feature_activations(target_feature_layer, target_feature_id)

tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 70.1081],
       grad_fn=<SelectBackward0>)

In [8]:
def get_grads(layer, id, tokens):
    """Backpropagate one SAE feature's activation into the residual stream.

    Runs the notebook-level `model` on `tokens` while caching the output of
    every `blocks.{i}.hook_resid_post` hook, passes the output cached for
    block `layer` through `saes[layer]`, and calls `.backward()` on feature
    `id` at the last row (last sequence position) of the feature activations.

    The residual stream is taken from block `layer` itself, so the SAE and the
    activations it sees always belong to the same block. Previously the last
    block's output was used regardless of `layer`.

    NOTE(review): the checkpoints are named `...hook_resid_pre...` while this
    hooks `hook_resid_post` — confirm that pairing is intended.

    Args:
        layer: Index of the block whose cached residual stream and SAE are
            used. Negative indices are accepted, as for list indexing.
        id: Index of the SAE feature to differentiate.
        tokens: Input handed straight to `model(...)`; the `[0]` indexing
            below assumes a leading batch dimension.

    Returns:
        dict mapping every layer index to whatever the backward hook at
        `blocks.{i}.hook_resid_post` received. An entry stays `None` when that
        hook never fires during the backward pass (for example, blocks the
        gradient does not flow through).
    """
    n_layers = model.cfg.n_layers
    grads = {k: None for k in range(n_layers)}
    # One slot per block, filled by hook layer index rather than call order.
    resid_cache = [None] * n_layers

    target_f_sae = saes[layer]
    target_f_sae.zero_grad()
    model.zero_grad()

    def back_hook(input, hook):
        grads[hook.layer()] = input

    def c_hook(input, hook):
        resid_cache[hook.layer()] = input

    bwd_hooks = [(f"blocks.{i}.hook_resid_post", back_hook) for i in range(n_layers)]
    cache_hooks = [(f"blocks.{i}.hook_resid_post", c_hook) for i in range(n_layers)]

    with model.hooks(fwd_hooks=cache_hooks, bwd_hooks=bwd_hooks):
        # Only the hooks' side effects are needed; the logits are discarded.
        model(tokens)
        resid_at_layer = resid_cache[layer][0]
        _, feature_acts, _, _, _ = target_f_sae(resid_at_layer)
        target_f_acts = feature_acts[:, id]
        # Backward must run inside the context so the backward hooks are still attached.
        target_f_acts[-1].backward()

    return grads

# Gradients for the chosen target feature; stored at notebook level because
# plot_grads and the attribution helpers read `grad_dict`.
grad_dict = get_grads(target_feature_layer, target_feature_id, tokens)

    

In [9]:
### Get the gradient of the target feature wrt residual stream. ###

# resid_cache = []
# grad_dict = {k: None for k in range(model.cfg.n_layers)}

# def back_hook(input, hook):
#     # print(hook.layer())
#     grad_dict[hook.layer()] = input

# def c_hook(input, hook):
#     resid_cache.append(input)

# bwd_hooks = [(f"blocks.{i}.hook_resid_post", back_hook) for i in range(model.cfg.n_layers)]
# cache_hooks = [(f"blocks.{i}.hook_resid_post", c_hook) for i in range(model.cfg.n_layers)]


# with model.hooks(fwd_hooks=cache_hooks, bwd_hooks=bwd_hooks):
#     logits = model(tokens)
#     resid_at_final = resid_cache[-1][0]
#     _, feature_acts, _, _, _ = saes[target_feature_layer](resid_at_final)
#     target_f_acts = feature_acts[:, target_feature_id]
#     target_f_acts[-1].backward()


In [10]:
def plot_grads():
    """Draw a layer-by-position heatmap of gradient magnitude.

    Reads the notebook-level `grad_dict`, `model` and `tokens`. Layers whose
    `grad_dict` entry is None are skipped, so the heatmap rows are the
    remaining layers in dict order. Each cell is the sum of absolute gradient
    values over the last dimension. Nothing is returned; the figure is
    produced by `imshow`.
    """
    # First (batch) entry of every recorded gradient.
    recorded = [grad[0] for grad in grad_dict.values() if grad is not None]

    stacked = torch.stack(recorded)
    heat = stacked.abs().sum(dim=-1)

    tok_labels = [f"{tok}" for tok in model.to_str_tokens(tokens[0])]
    imshow(heat, labels={"x": "Position", "y": "Layer"}, x=tok_labels, width=800, height=600)


# Render the gradient-magnitude heatmap from the current notebook-level `grad_dict`.
plot_grads()

In [25]:
@torch.no_grad()
def attribution_for_position(layer, position, cache, grad_dict):
    """Attribute the target gradient to each SAE feature at one position.

    The residual-stream gradient stored for `layer` at `position` is projected
    onto the SAE decoder (`W_dec @ grad`), then multiplied elementwise by the
    SAE's feature activations at that position (gradient times activation).
    Runs without autograd tracking. Uses the notebook-level `saes`.

    Args:
        layer: Index into `saes`, `grad_dict` and the cached residual stream.
        position: Row (token position) within the batch-0 slice.
        cache: Activation cache indexed as `cache['resid_post', layer]`.
        grad_dict: Mapping from layer to the recorded residual gradient; the
            `[0][position]` indexing takes the first (batch) entry, then the
            position.

    Returns:
        A vector with one attribution value per SAE feature.
    """
    layer_sae = saes[layer]
    # Gradient at this layer/position, then its component along every decoder direction.
    grad_here = grad_dict[layer][0][position]
    grad_per_feature = layer_sae.W_dec @ grad_here

    # Feature activations for the whole sequence; keep only the requested position.
    _first, acts, _third, _fourth, _fifth = layer_sae(cache['resid_post', layer][0])
    return grad_per_feature * acts[position]

# attribution = attribution_for_position(10, 7, cache, grad_dict)

# top = attribution.argsort(descending=True)[:5]
# print('pos attributions')
# print(top)
# print(attribution[top])
# print()

# min = attribution.argsort()[:5]
# print('neg attributions')
# print(min)
# print(attribution[min])


In [26]:
def max_attribution_for_layer(layer, min=False, k=5):
    """Collect the k most extreme feature attributions at every position.

    Uses the notebook-level `cache` and `grad_dict`; the number of positions
    is read from `cache['embed'].shape[1]`.

    Args:
        layer: Layer passed through to `attribution_for_position`.
        min: When truthy, rank ascending (most negative first); otherwise
            rank descending (most positive first).
        k: How many features to keep per position.

    Returns:
        `(values, indices)`: two lists with one tensor per position, holding
        the selected attribution values and the matching feature indices.
    """
    seq_len = cache['embed'].shape[1]
    picked_values, picked_ids = [], []

    for pos in range(seq_len):
        scores = attribution_for_position(layer, pos, cache=cache, grad_dict=grad_dict)

        # Ascending order surfaces the most negative scores, descending the most positive.
        ranking = scores.argsort() if min else scores.argsort(descending=True)
        keep = ranking[:k]

        picked_values.append(scores[keep])
        picked_ids.append(keep)

    return picked_values, picked_ids


def visualize_max_attributions(min=False, k=1):
    """Plot, per layer and position, the summed top-k feature attribution.

    For every layer except the last (the loop stops at `n_layers - 1`;
    presumably because no gradient is recorded for the final layer — confirm),
    the top-k attributions at each position are summed into one heatmap cell.
    The text drawn in a cell is the index of the first-ranked feature only,
    even when `k > 1`. Uses the notebook-level `model` and `tokens`.

    Args:
        min: When truthy, use the most negative attributions instead of the
            most positive ones.
        k: Number of features whose attributions are summed per cell.
    """
    summed_rows = []
    leading_feature_rows = []
    for layer_idx in range(model.cfg.n_layers - 1):
        per_pos_values, per_pos_ids = max_attribution_for_layer(layer_idx, min=min, k=k)
        summed_rows.append([values.sum() for values in per_pos_values])
        leading_feature_rows.append([ids[0].item() for ids in per_pos_ids])

    heat = torch.tensor(summed_rows)
    leading_features = torch.tensor(leading_feature_rows)

    # Cell annotations: the first-ranked feature index, rendered as text.
    cell_text = [[str(feature) for feature in row] for row in leading_features.tolist()]
    tok_labels = [f"{tok}" for tok in model.to_str_tokens(tokens[0])]

    imshow(heat, labels={"x": "Position", "y": "Layer"}, x=tok_labels, text=cell_text, width=800, height=600)


# Defaults (min=False, k=1): the single most positive attribution per layer and position.
visualize_max_attributions()

In [27]:
# Most negative attributions: same heatmap, with features ranked ascending (min=True).
visualize_max_attributions(min=True)

In [14]:
# Alternative call kept for reference (the target feature itself):
# get_feature_activations(11, 23531)
# Activations of feature 2996 in the layer-0 SAE for the cached prompt.
get_feature_activations(0, 2996)

tensor([10.6147, 19.4824, 18.4511,  9.3586, 11.7011, 12.1088, 10.9227, 11.2819],
       grad_fn=<SelectBackward0>)

In [15]:
# we will attempt to use Neel's tokenised dataset. This is not the same as the training distribution, but hopefully it doesn't matter.
