In [99]:
from transformer_lens import HookedTransformer
import torch as t
import einops
from transformer_lens import utils as tutils

from functools import partial
import random
from tqdm import tqdm


from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

In [176]:
import plotly.graph_objects as go

def display_lists_as_table(*lists, names=None, title=""):
    """Render equally long lists side by side as the columns of a Plotly table.

    Args:
        *lists: One sequence per table column; all must have the same length.
        names: Optional header label for each column. When omitted (or empty),
            the columns are labelled "Column 1", "Column 2", ...
        title: Figure title shown above the table.

    Raises:
        ValueError: If the lists differ in length, or if no list is given.
    """
    # Exactly one distinct length means every column has the same number of
    # rows; zero distinct lengths means nothing was passed at all.
    if len({len(lst) for lst in lists}) != 1:
        raise ValueError("All lists must have the same length.")

    # One header per column (i.e. per list), not per row.
    if names:
        header_values = list(names)
    else:
        header_values = [f"Column {i + 1}" for i in range(len(lists))]

    table = go.Table(
        header=dict(values=header_values),
        cells=dict(values=list(lists)),
    )

    # Create the figure and display the table
    fig = go.Figure(data=[table], layout=go.Layout(title=title))
    fig.show()

# Example usage: render a two-column and a three-column table from toy data.
list1 = list(range(1, 6))
list2 = list(range(6, 11))
list3 = list(range(11, 16))

example_headers = ["Token", "Value", "extra"]
display_lists_as_table(list1, list2, names=example_headers[:2])
display_lists_as_table(list1, list2, list3, names=example_headers)

In [2]:
# Weight-processing options requested when loading the pretrained checkpoint.
model_load_options = dict(
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)
model = HookedTransformer.from_pretrained("gpt2-small", **model_load_options)



Loaded pretrained model gpt2-small into HookedTransformer


In [59]:
layer = 7 # pick a layer you want.

# Name of the residual-stream hook feeding the chosen layer; it doubles as the
# key under which the matching SAE is returned.
hook_point = tutils.get_act_name("resid_pre", layer)
saes, sparsities = get_gpt2_res_jb_saes(hook_point)

print(saes.keys())

# Keep the SAE on the same device as the model's embedding matrix.
sae = saes[hook_point].to(model.W_E.device)

  0%|          | 0/1 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


blocks.7.hook_resid_pre/cfg.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/151M [00:00<?, ?B/s]

sparsity.safetensors:   0%|          | 0.00/98.4k [00:00<?, ?B/s]

100%|██████████| 1/1 [00:13<00:00, 14.00s/it]

dict_keys(['blocks.7.hook_resid_pre'])





In [170]:
def top_acts_at_pos(text, pos=-1, silent=True, prepend_bos=True, n_top=10):
    """Return the strongest SAE feature activations for `text`.

    Runs the notebook-level `model` on `text`, reads the activation stored
    under the notebook-level `hook_point`, passes it through the notebook-level
    `sae`, and keeps the `n_top` largest feature activations.

    Args:
        text: Prompt passed to `model.run_with_cache`.
        pos: Token position to inspect (negative values index from the end).
            Pass None to average the feature activations over all positions.
        silent: Currently unused; kept so existing calls keep working.
        prepend_bos: Forwarded to `model.run_with_cache`.
        n_top: Number of features to return.

    Returns:
        A `(values, indices)` tuple from `torch.topk` over the feature axis.
    """
    # Only the activation at `hook_point` is read below, so restrict the cache
    # to that single hook instead of storing every activation in the model.
    _, cache = model.run_with_cache(
        text, prepend_bos=prepend_bos, names_filter=hook_point
    )
    resid = cache[hook_point]
    if pos is None:
        hidden_state = resid[0, :, :]
    else:
        # Keep a leading axis so both branches hand the SAE a 2-D input.
        hidden_state = resid[0, pos, :].unsqueeze(0)
    feature_acts = sae(hidden_state).feature_acts
    # Averages over positions; with a single position this just drops the
    # leading axis.
    feature_acts = feature_acts.mean(dim=0)
    top_v, top_i = t.topk(feature_acts, n_top)
    return top_v, top_i

# Strongest SAE features on the final token of a simile-style prompt.
top_activations = top_acts_at_pos("This day was like", pos=-1)
print(top_activations)

top_values, top_indices = top_activations
strongest_value = top_values[0].item()
strongest_feature = top_indices[0].item()
print(strongest_value)
print(strongest_feature)

# Decoder row of the strongest feature, plus how hard it fired on the prompt;
# both are reused later as the steering direction and its scale.
metaphor_vector = sae.W_dec[strongest_feature, :]
metaphor_activation = strongest_value

print(metaphor_vector)
print(metaphor_activation)

(tensor([34.4923, 13.2345, 11.3497,  9.1794,  6.1575,  4.6734,  4.3486,  4.2683,
         3.7980,  3.7679], device='mps:0', grad_fn=<TopkBackward0>), tensor([14093,  6430, 22331, 19144, 24450, 23221,  6243,  3102,  4745,  5380],
       device='mps:0'))
34.49229049682617
14093
tensor([-4.9949e-03, -1.1252e-02, -2.5019e-02, -1.7989e-02, -2.3345e-02,
        -1.7728e-02,  1.4336e-02, -4.0011e-02,  2.1098e-02,  2.7976e-02,
         1.4074e-02, -1.2264e-02, -2.1630e-02, -9.8924e-03,  7.5899e-02,
        -9.5182e-02, -1.5417e-02,  1.3840e-02, -2.4799e-02,  1.8506e-02,
        -5.5362e-02, -3.7139e-02,  2.3834e-02, -5.1518e-02, -1.7197e-02,
         4.9247e-02,  2.6180e-02, -2.4630e-03,  1.5591e-02, -8.1095e-03,
        -1.6908e-02, -4.0562e-02,  5.1226e-02, -2.3033e-02, -1.0005e-02,
         2.0077e-02, -3.8991e-02,  1.4347e-02,  4.8115e-03, -5.5510e-02,
        -3.3198e-02,  1.9290e-02, -6.7268e-03,  2.3226e-02, -4.9797e-02,
        -4.7200e-03,  6.8628e-02,  2.3221e-02, -6.3930e-02,  5.595

In [177]:
def patch_hook(
    resid,
    hook,
    c,
    steering_vector,
    activation_norm,
    pos=3,
):
    """Forward hook that adds a scaled steering vector at one sequence position.

    Args:
        resid: Activation tensor handed to the hook; indexed as
            `resid[:, pos, :]`, so it must have at least three axes.
        hook: Hook object supplied by the caller of the hook; unused here.
        c: Scalar steering coefficient.
        steering_vector: Direction added to the activation; must broadcast
            against `resid[:, pos, :]`.
        activation_norm: Scalar multiplied with `c` to set the step size.
        pos: Index along the second axis that receives the steering vector.
            Defaults to 3, the position that was previously hard-coded.

    Returns:
        The same `resid` tensor, modified in place.
    """
    # Shape trace, printed once per hooked forward pass.
    print(resid.shape)
    # In-place add: every batch row is shifted at `pos`, other positions are untouched.
    resid[:, pos, :] += c * activation_norm * steering_vector
    return resid

# Tokens whose next-token rank is tracked while sampling.
target_tokens = [" like", " a"]
# Resolve every target on its own. Handing the whole list to
# `tokenizer.encode` only lines up one id per target while each target happens
# to be a single token; `to_single_token` fails loudly on a multi-token string.
target_vocab_positions = [model.to_single_token(tok) for tok in target_tokens]

print(target_vocab_positions)

n_samples = 15      # tokens generated per trial
n_steered = 10      # the first `n_steered` steps run with the steering hook attached
top_k_display = 20  # rows shown in the per-step logit table
top_k_sample = 11   # the next token is drawn uniformly from this many top logits

# NOTE(review): `metaphor_vector` is read from the SAE loaded for `hook_point`,
# while it is injected at block 5's resid_pre here — confirm the mismatch is intended.
steering_hook = (
    "blocks.5.hook_resid_pre",
    partial(patch_hook, c=40, steering_vector=metaphor_vector, activation_norm=metaphor_activation),
)

# The outer loop variable is deliberately not `i`, so the inner step counter
# no longer shadows it.
for _trial in tqdm(range(3)):
    text = "Seeing this view was"

    print(model.tokenizer.tokenize(text))

    new_samples = []

    for i in range(n_samples):
        if i < n_steered:
            logits = model.run_with_hooks(text, fwd_hooks=[steering_hook])
        else:
            # Unsteered continuation; `logits` and `cache` stay available to later cells.
            logits, cache = model.run_with_cache(text)

        last_logits = logits[0, -1]

        # 0-based rank = number of vocabulary entries with a strictly larger logit
        # (0 means the token is the top prediction).
        ranks = [int((last_logits > last_logits[v]).sum()) for v in target_vocab_positions]

        # The table's "Logit Position" column is 1-based, unlike the ranks printed below.
        top_logits = last_logits.topk(top_k_display)
        display_lists_as_table(
            list(range(1, top_k_display + 1)),
            model.tokenizer.batch_decode(top_logits.indices),
            top_logits.values.tolist(),
            names=["Logit Position", "Token", "Value"],
            title=text + "...",
        )

        for v, r in zip(target_tokens, ranks):
            print(f"{v}_rank: {r}")

        # Sample the next token uniformly from the top candidates (plain ints, not tensors).
        next_token_id = random.choice(last_logits.topk(top_k_sample).indices.tolist())
        text += model.tokenizer.decode(next_token_id)
        new_samples.append(text)
        print(text)
    print(text)
    

[588, 257]


  0%|          | 0/3 [00:00<?, ?it/s]

['Seeing', 'Ġthis', 'Ġview', 'Ġwas']
torch.Size([1, 5, 768])


 like_rank: 1
 a_rank: 7
Seeing this view was when
torch.Size([1, 6, 768])


 like_rank: 1604
 a_rank: 2
Seeing this view was when it
torch.Size([1, 7, 768])


 like_rank: 826
 a_rank: 200
Seeing this view was when it hit
torch.Size([1, 8, 768])


 like_rank: 79
 a_rank: 5
Seeing this view was when it hit:
torch.Size([1, 9, 768])


 like_rank: 277
 a_rank: 6
Seeing this view was when it hit: "
torch.Size([1, 10, 768])


 like_rank: 5479
 a_rank: 913
Seeing this view was when it hit: "what
torch.Size([1, 11, 768])


 like_rank: 670
 a_rank: 2
Seeing this view was when it hit: "what you
torch.Size([1, 12, 768])


 like_rank: 269
 a_rank: 51
Seeing this view was when it hit: "what you were
torch.Size([1, 13, 768])


 like_rank: 18
 a_rank: 36
Seeing this view was when it hit: "what you were saying
torch.Size([1, 14, 768])


 like_rank: 216
 a_rank: 19
Seeing this view was when it hit: "what you were saying,


 like_rank: 79
 a_rank: 27
Seeing this view was when it hit: "what you were saying, this


 like_rank: 379
 a_rank: 122
Seeing this view was when it hit: "what you were saying, this is


 like_rank: 16
 a_rank: 2
Seeing this view was when it hit: "what you were saying, this is not


 like_rank: 17
 a_rank: 2
Seeing this view was when it hit: "what you were saying, this is not what


 33%|███▎      | 1/3 [00:13<00:26, 13.46s/it]

 like_rank: 706
 a_rank: 12
Seeing this view was when it hit: "what you were saying, this is not what it
Seeing this view was when it hit: "what you were saying, this is not what it
['Seeing', 'Ġthis', 'Ġview', 'Ġwas']
torch.Size([1, 5, 768])


 like_rank: 1
 a_rank: 7
Seeing this view was a
torch.Size([1, 6, 768])


 like_rank: 2776
 a_rank: 1425
Seeing this view was a nightmare
torch.Size([1, 7, 768])


 like_rank: 103
 a_rank: 82
Seeing this view was a nightmare;
torch.Size([1, 8, 768])


 like_rank: 20
 a_rank: 2
Seeing this view was a nightmare; a
torch.Size([1, 9, 768])


 like_rank: 3494
 a_rank: 1775
Seeing this view was a nightmare; a nightmare
torch.Size([1, 10, 768])


 like_rank: 58
 a_rank: 71
Seeing this view was a nightmare; a nightmare;
torch.Size([1, 11, 768])


 like_rank: 31
 a_rank: 0
Seeing this view was a nightmare; a nightmare; you
torch.Size([1, 12, 768])


 like_rank: 245
 a_rank: 281
Seeing this view was a nightmare; a nightmare; you are
torch.Size([1, 13, 768])


 like_rank: 2
 a_rank: 5
Seeing this view was a nightmare; a nightmare; you are so
torch.Size([1, 14, 768])


 like_rank: 336
 a_rank: 241
Seeing this view was a nightmare; a nightmare; you are so far


 like_rank: 137
 a_rank: 29
Seeing this view was a nightmare; a nightmare; you are so far removed


 like_rank: 72
 a_rank: 55
Seeing this view was a nightmare; a nightmare; you are so far removed by


 like_rank: 1095
 a_rank: 9
Seeing this view was a nightmare; a nightmare; you are so far removed by the


 like_rank: 3517
 a_rank: 1041
Seeing this view was a nightmare; a nightmare; you are so far removed by the light


 67%|██████▋   | 2/3 [00:24<00:12, 12.25s/it]

 like_rank: 59
 a_rank: 69
Seeing this view was a nightmare; a nightmare; you are so far removed by the light.
Seeing this view was a nightmare; a nightmare; you are so far removed by the light.
['Seeing', 'Ġthis', 'Ġview', 'Ġwas']
torch.Size([1, 5, 768])


 like_rank: 1
 a_rank: 7
Seeing this view was just
torch.Size([1, 6, 768])


 like_rank: 23
 a_rank: 0
Seeing this view was just beginning
torch.Size([1, 7, 768])


 like_rank: 130
 a_rank: 6
Seeing this view was just beginning to
torch.Size([1, 8, 768])


 like_rank: 948
 a_rank: 12
Seeing this view was just beginning to come
torch.Size([1, 9, 768])


 like_rank: 94
 a_rank: 60
Seeing this view was just beginning to come out
torch.Size([1, 10, 768])


 like_rank: 80
 a_rank: 42
Seeing this view was just beginning to come out for
torch.Size([1, 11, 768])


 like_rank: 1685
 a_rank: 2
Seeing this view was just beginning to come out for what
torch.Size([1, 12, 768])


 like_rank: 464
 a_rank: 24
Seeing this view was just beginning to come out for what might
torch.Size([1, 13, 768])


 like_rank: 59
 a_rank: 37
Seeing this view was just beginning to come out for what might come
torch.Size([1, 14, 768])


 like_rank: 58
 a_rank: 26
Seeing this view was just beginning to come out for what might come to


 like_rank: 30
 a_rank: 2
Seeing this view was just beginning to come out for what might come to define


 like_rank: 1215
 a_rank: 1
Seeing this view was just beginning to come out for what might come to define the


 like_rank: 4072
 a_rank: 509
Seeing this view was just beginning to come out for what might come to define the future


 like_rank: 73
 a_rank: 75
Seeing this view was just beginning to come out for what might come to define the future of


100%|██████████| 3/3 [00:36<00:00, 12.28s/it]

 like_rank: 6356
 a_rank: 4
Seeing this view was just beginning to come out for what might come to define the future of gaming
Seeing this view was just beginning to come out for what might come to define the future of gaming





In [25]:
# Token ids of the two probe strings; only the first id of each encoding is used below.
anger = model.tokenizer.encode("anger")
you = model.tokenizer.encode(" you")

# Vocabulary ids sorted from highest to lowest logit at the final position of
# whatever `logits` currently holds; a token's index in this list is its rank.
vocab_by_logit = logits[0,-1].topk(logits[0,-1].shape[0]).indices.tolist()
anger_rank = vocab_by_logit.index(anger[0])
you_rank = vocab_by_logit.index(you[0])

print(anger_rank)
print(you_rank)

42166
0


In [102]:
# NOTE(review): this encodes "like" without a leading space, whereas the
# sampling cell tracks " like" — confirm which token is meant.
like = model.tokenizer.encode("like")

# Position of that token id when the final-position logits are sorted from
# highest to lowest.
final_position_logits = logits[0,-1]
like_rank = final_position_logits.topk(final_position_logits.shape[0]).indices.tolist().index(like[0])

print(like_rank)

20714


In [51]:
# Project the residual stream entering each block through the unembedding to
# see which tokens are favoured at that depth.
# NOTE(review): `cache` and `logits` are whatever an earlier cell left behind —
# make sure both come from the same forward pass.
W_U = model.W_U

# The loop variable is `layer_idx`, not `layer`, so this cell does not
# overwrite the notebook-level `layer` chosen when loading the SAE.
for layer_idx in range(model.cfg.n_layers):
    print(f"layer {layer_idx}")
    # Residual stream *before* block `layer_idx` (hook_resid_pre).
    residual_stream_layer = cache[f"blocks.{layer_idx}.hook_resid_pre"]

    ## multiplying the current residual stream by W_U to get predicted tokens at that point
    ## NOTE(review): only W_U is applied here; anything the model itself applies
    ## between the residual stream and the logits is skipped, so this is an
    ## approximation of "ablate all later layers" — confirm that is acceptable.
    unembedded_vocab = einops.einsum(
        residual_stream_layer,
        W_U,
        "b n d, d e -> b n e"
    )

    # Top 5 candidates at the last position of the first batch element.
    top_logits = unembedded_vocab[0,-1].topk(5)
    print(model.tokenizer.batch_decode(top_logits.indices))
print("final output")
print(model.tokenizer.batch_decode(logits[0,-1].topk(5).indices))

layer 0
[' destro', ' livest', ' challeng', ' mathemat', 'theless']
layer 1
['wolf', ' not', 'wolves', ' currently', ' still']
layer 2
[' currently', ' now', ' not', ' unlikely', 'wolf']
layer 3
[' supposed', ' currently', 'nt', 'wolf', 'wolves']
layer 4
['nt', ' supposed', ' currently', ' gonna', 'wolf']
layer 5
[' supposed', 'nt', ' gonna', ' suppose', ' currently']
layer 6
['nt', ' supposed', ' gonna', ' ya', ' suppose']
layer 7
['nt', ' gonna', ' supposed', ' ya', ' suppose']
layer 8
[' ya', ' gonna', 'nt', ' supposed', ' you']
layer 9
[' ya', ' guys', ' gonna', ' alot', 'nt']
layer 10
[' ya', ' you', ' guys', ' alot', ' things']
layer 11
[' you', ' ya', ' we', ' i', ' ye']
final output
[' you', ' i', ' we', ' your', ' u']


In [54]:
# I want to be able to track specific tokens
# I want to be able to see the top activations.

# Despite its name, `top_logit` holds the vocabulary id of the highest-scoring
# token at the last position, not the logit value itself.
last_position_logits = logits[0,-1]
top_logit = last_position_logits.topk(1).indices.item()
print(top_logit)


345


In [None]:
# Ok we want some kind of metric that we can use from this stuff.
# When we make an intervention, how does it change the downstream activations? The logits?
# And specifically, how does it change particular logits that we are interested in?

# We can do this by comparing the activations for the token we are interested in before and after the intervention.
