In [10]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
from sae_lens import SparseAutoencoder, ActivationsStore

from steering.eval_utils import evaluate_completions
from steering.utils import text_to_sae_feats, top_activations, normalise_decoder, get_activation_steering
from steering.patch import generate, get_scores_and_losses, patch_resid, get_loss

from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go

# Disable autograd globally: the visible cells only run forward passes
# (loss evaluation and sampling), so no gradient bookkeeping is needed.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fda55afceb0>

In [2]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the pretrained "gemma-2b" weights as a TransformerLens HookedTransformer placed on `device`.
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
# Hook-point name: used to select the SAE here and passed as the steering site in later cells.
hp6 = "blocks.6.hook_resid_post"

# Load the SAE published under "gemma-2b-res-jb" for that hook point.
sae6 = SparseAutoencoder.from_pretrained("gemma-2b-res-jb", hp6)
# Project helper from steering.utils; its return value is discarded, so it presumably
# modifies sae6 in place.
# NOTE(review): exact effect (e.g. unit-norm decoder rows, what scale_input=False skips) is not
# visible here — confirm against steering/utils.py.
normalise_decoder(sae6, scale_input=False)
sae6 = sae6.to(device)

In [20]:
# Steering magnitudes to sweep; losses_for_norms forwards this list to get_loss as `scales`.
# The 0 entry is presumably the unsteered baseline — every recorded sweep starts with the same loss.
norms = [0, 10, 20, 40, 60, 80, 100, 120]

In [21]:
def losses_for_norms(norms, vector):
    """Evaluate the model loss when steering along `vector` at each magnitude in `norms`.

    The vector is first rescaled to unit L2 norm, so only its direction matters;
    `norms` is then passed to `get_loss` as the `scales` to apply at hook point
    `hp6` of the notebook-level `model` (with `insertion_pos=None`).

    Args:
        norms: sequence of steering magnitudes, forwarded as `scales`.
        vector: tensor holding the steering direction; any overall scale is discarded.

    Returns:
        Whatever `get_loss` returns; the recorded runs show a list with one
        float loss per entry of `norms`.

    Raises:
        ValueError: if `vector` has a zero or non-finite norm. Dividing by such a
            norm would silently turn the steering vector into NaN/zeros and
            corrupt every reported loss.
    """
    vector_norm = vector.norm()
    if not torch.isfinite(vector_norm) or vector_norm == 0:
        raise ValueError("steering vector must have a finite, non-zero norm")
    normed_vector = vector / vector_norm
    return get_loss(model, hp6, normed_vector, scales=norms, insertion_pos=None)

In [22]:
# Decoder row of SAE feature 1062 ("anger"), given two leading singleton axes.
# The x56 factor has no effect on this sweep: losses_for_norms divides the vector
# by its own norm before steering.
anger_sae_vec = (sae6.W_dec[1062] * 56)[None, None, :]  # anger

# One loss per entry of `norms`; the bare expression displays the result.
anger_sae_losses = losses_for_norms(norms, anger_sae_vec)
anger_sae_losses

loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [00:59<00:00,  7.49s/it]


[2.037747383117676,
 2.0441823053359984,
 2.0720777678489686,
 2.2562749361991883,
 2.704720501899719,
 3.5302905654907226,
 4.782238540649414,
 6.302010145187378]

In [23]:
# Activation-difference steering vector built from the prompts "Anger" and "Calm".
# NOTE(review): [0, 1, :] presumably selects batch item 0 at token position 1
# (first token after BOS) — confirm against get_activation_steering.
act_vec = get_activation_steering(
    model, hp6, pos_text="Anger", neg_text="Calm"
)[0, 1, :][None, None, :]

# Same norm sweep as for the SAE direction, so the curves are directly comparable.
act_diff_losses = losses_for_norms(norms, act_vec)
act_diff_losses

pad tensor([], device='cuda:0', size=(1, 0), dtype=torch.int64)
loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.51s/it]


[2.037747383117676,
 2.0471405363082886,
 2.08109929561615,
 2.2730241632461547,
 2.8444892740249634,
 4.036533904075623,
 5.590423221588135,
 7.094773988723755]

In [25]:
# Weighted sum of two SAE decoder rows: feature 1062 ("anger") and feature 14586 ("blog post").
# losses_for_norms rescales the sum to unit norm, so only the 56:96 mixing ratio matters here.
sae_sum_vec = (sae6.W_dec[1062] * 56 + sae6.W_dec[14586] * 96)[None, None, :]  # anger + blog post

sae_sum_losses = losses_for_norms(norms, sae_sum_vec)
sae_sum_losses

loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.51s/it]


[2.037747383117676,
 2.0423549485206602,
 2.0587640690803526,
 2.1463862109184264,
 2.3766694688796997,
 2.8935400915145872,
 3.811956706047058,
 4.950247449874878]

In [33]:
# random normal vector
# Baseline: average the loss sweep over 10 random Gaussian directions with the same
# shape, dtype and device as anger_sae_vec.
# A locally seeded generator makes the baseline reproducible under Restart & Run All
# without touching the global RNG used by sampling elsewhere.
randn_gen = torch.Generator(device=anger_sae_vec.device).manual_seed(0)

randn_all_losses = []
for _ in range(10):
    randn_vec = torch.randn(
        anger_sae_vec.shape,
        generator=randn_gen,
        dtype=anger_sae_vec.dtype,
        device=anger_sae_vec.device,
    )
    randn_all_losses.append(losses_for_norms(norms, randn_vec))
# Rows are draws, columns are norms; average over draws.
all_losses_tensor = torch.tensor(randn_all_losses)
randn_losses = torch.mean(all_losses_tensor, dim=0)

loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.50s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.52s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.53s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.54s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.54s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.54s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.54s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.55s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.55s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.56s/it]


In [34]:
# random uniform vector
# Baseline: average the loss sweep over 10 random directions with entries drawn from U[0, 1),
# matching the shape, dtype and device of anger_sae_vec.
# NOTE(review): every component is non-negative, so after normalisation the draws share a large
# common component along the all-ones direction — confirm this is the intended baseline.
# A locally seeded generator makes the baseline reproducible under Restart & Run All
# without touching the global RNG used by sampling elsewhere.
randu_gen = torch.Generator(device=anger_sae_vec.device).manual_seed(0)

randu_all_losses = []
for _ in range(10):
    randu_vec = torch.rand(
        anger_sae_vec.shape,
        generator=randu_gen,
        dtype=anger_sae_vec.dtype,
        device=anger_sae_vec.device,
    )
    randu_all_losses.append(losses_for_norms(norms, randu_vec))
# Rows are draws, columns are norms; average over draws.
all_randu_losses_tensor = torch.tensor(randu_all_losses)
randu_losses = torch.mean(all_randu_losses_tensor, dim=0)

loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.54s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.55s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.55s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.56s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.56s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.56s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.56s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.56s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.55s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.56s/it]


In [36]:
# Decoder row of SAE feature 14586 ("blog post") with two leading singleton axes.
# As with the other single-feature sweep, the x96 factor is cancelled by the
# normalisation inside losses_for_norms.
sae_blog_vec = (sae6.W_dec[14586] * 96)[None, None, :]  # blog post

sae_blog_losses = losses_for_norms(norms, sae_blog_vec)
sae_blog_losses

loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.51s/it]


[2.037747383117676,
 2.043927245140076,
 2.0646593236923216,
 2.1813218903541567,
 2.4917372703552245,
 3.2092547607421875,
 4.323053164482117,
 5.482840032577514]

In [38]:

# Loss-vs-norm curves, one trace per steering direction evaluated above.
fig = go.Figure(
    data=[
        go.Scatter(x=norms, y=curve, mode='lines+markers', name=label)
        for label, curve in [
            ('sae anger ft', anger_sae_losses),
            ('act diff anger-calm', act_diff_losses),
            ('sae anger+blog ft', sae_sum_losses),
            ('sae blog ft', sae_blog_losses),
            # Random baselines are currently disabled:
            # ('randn vector', randn_losses),
            # ('randu vector', randu_losses),
        ]
    ]
)

# Title and axis labels so the figure stands on its own.
fig.update_layout(
    title='Loss vs norm for different steering vectors',
    xaxis_title='Norm',
    yaxis_title='Loss',
)

fig.show()

In [43]:
# random sae vector

# Before steering along arbitrary SAE features, look at each feature's "effective norm",
# defined here as (norm of its W_enc column) * (norm of its W_dec row).

# Sanity check: both reductions should give one value per SAE feature.
print(sae6.W_enc.norm(dim=0).shape, sae6.W_dec.norm(dim=1).shape, sep="\n")

sae_norms = sae6.W_enc.norm(dim=0) * sae6.W_dec.norm(dim=1)

px.histogram(
    sae_norms.cpu().numpy(),
    title="Effective norms of SAE vectors",
).show()


torch.Size([16384])
torch.Size([16384])


In [45]:
# select sae idxs where effective norm > 1
# torch.where called with only a condition returns a tuple of index tensors (one per dimension);
# sae_norms is 1-D, so [0] is the tensor of selected feature indices.
sae_idxs = torch.where(sae_norms > 1)[0]
print(sae_idxs)

tensor([    0,     1,     3,  ..., 16374, 16376, 16382], device='cuda:0')


In [53]:
# Loss at a single steering norm (100) for the first 50 features selected in sae_idxs.
# Each losses_for_norms call logs a fresh "loading dataset" line, so this loop is slow.
sae_all_losses = []
for idx in sae_idxs[:50]:
    # Decoder row for this feature, with two leading singleton axes added.
    sae_vec = sae6.W_dec[idx][None, None, :]
    # Only one norm is requested, so the result holds a single loss.
    sae_losses = losses_for_norms([100], sae_vec)
    sae_all_losses.append(sae_losses[0])


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.48s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.48s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.50s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.50s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.50s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.49s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.48s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.49s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.51s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.50s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.50s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.51s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.52s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.51s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.50s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.52s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.53s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.50s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.52s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.51s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.51s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.52s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.53s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.51s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.51s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.52s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.50s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.50s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.53s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.52s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.51s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.51s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.53s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.52s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.51s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.50s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.51s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.52s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.52s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.51s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.52s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.53s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.52s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.53s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.51s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.51s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.54s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.52s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.51s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 1/1 [00:07<00:00,  7.53s/it]


In [59]:
# Effective norms of the same first-50 selected features, moved to the CPU as a NumPy array
# so they line up index-for-index with sae_all_losses.
sae_all_norms = sae_norms[sae_idxs[:50]].cpu().numpy()
sae_all_norms

array([1.7750012, 3.0004005, 2.1533155, 1.8475288, 3.741389 , 4.1725993,
       2.3062062, 3.013856 , 1.7942193, 3.9857953, 2.0277243, 3.5320127,
       1.6898198, 1.7474426, 1.8993434, 2.913948 , 1.590908 , 4.4241714,
       1.9618442, 1.8986557, 3.7316065, 1.4529527, 1.9260013, 1.5114743,
       1.3631194, 1.9415872, 3.010469 , 3.9565852, 1.757123 , 3.7903998,
       3.142177 , 1.9863118, 2.9711437, 2.2604647, 3.1986165, 2.3528612,
       3.066518 , 4.2257867, 2.5939338, 2.516755 , 4.2358904, 1.669474 ,
       2.5128818, 3.9561095, 2.940383 , 3.518045 , 2.9661024, 1.4918803,
       2.1688523, 1.4764249], dtype=float32)

In [60]:
# Display the per-feature losses collected at steering norm 100 (one entry per selected feature).
sae_all_losses

[4.260148124694824,
 5.038935794830322,
 4.680075588226319,
 4.267725758552551,
 4.5472962760925295,
 4.308347868919372,
 4.51309377193451,
 4.255828490257263,
 11.663812427520751,
 4.349827694892883,
 6.100510425567627,
 4.071781392097473,
 8.749277772903442,
 5.850289335250855,
 6.910925855636597,
 6.323086729049683,
 4.288483939170837,
 4.052916049957275,
 5.183176898956299,
 4.7983832359313965,
 4.237877340316772,
 5.426720972061157,
 5.978160200119018,
 4.976030168533325,
 4.123554253578186,
 6.133593406677246,
 3.9421783542633055,
 4.496182246208191,
 4.317424373626709,
 4.589922571182251,
 4.001005959510803,
 8.739752407073974,
 5.045418310165405,
 4.750954914093017,
 4.29761604309082,
 7.065834188461304,
 6.096748895645142,
 16.19818546295166,
 5.392672328948975,
 6.851850700378418,
 3.9828504800796507,
 4.418928847312928,
 4.660826230049134,
 4.726483907699585,
 5.485279197692871,
 4.315899586677551,
 3.7457772970199583,
 4.049683766365051,
 6.436926937103271,
 4.4065652179718

In [61]:
# x is the SAE's effective norm (encoder norm * decoder norm), not the injected magnitude:
# every y value was measured at the same steering norm of 100.
px.scatter(x=sae_all_norms, y=sae_all_losses, title="Loss vs effective norm for SAE vectors").show()

In [62]:
# Qualitative check: sample some generations with the steering vectors applied

In [65]:
# blog post steering
# Decoder row of SAE feature 14586 ("blog post"), scaled by 96, with two leading singleton axes.
# NOTE(review): nothing in this cell normalises the vector, so unlike the loss sweeps the
# x96 magnitude presumably matters here — confirm against `generate`.
blog_post_vec = (sae6.W_dec[14586] * 96)[None, None, :]  # blog post

# Sample 10 continuations of "I think" with the vector applied at hp6 (scale 0.8).
gen_blog = generate(
    model,
    hp6,
    "I think",
    blog_post_vec,
    scale=0.8,
    insertion_pos=None,
    n_samples=10,
)
print(gen_blog)

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:10<00:00,  1.05s/it]

['I think I just came across this post on a blog. “The way I started about 2 years ago', 'I think the original blog that this author is referring to has been shut down.\n\nI’m still going', 'I think that my recent post <em><strong>“How to Blog About Personal Health”</strong></em>  has started', 'I think many people (who are not bloggers ) like to read, rather than to create a post  about', 'I think if you want to just write a blog about your photos on your old computer, or whatever you want', 'I think if we could put our <b>blog posts</b> under the <b>Blog /</b> tab.\n\n', "I think you're talking about:\n\nhttp://www.x390.org/archives/", "I think it is not too much of an issue any way. I's interesting to see what goes on", "I think the whole thing is just so incredibly stupid, don's your really think that it would stop you", 'I think this is the only place where on-line or on-radio, with, by, in,']





In [73]:
# combo steering
# Mix of SAE feature 1062 ("anger", weight 60) and feature 14586 ("blog post", weight 50),
# with two leading singleton axes.
combo_vec = (sae6.W_dec[1062] * 60 + sae6.W_dec[14586] * 50)[None, None, :]  # anger + blog post

# Sample 10 continuations of "I think" (max_length=30) with the mixed vector applied at hp6.
gen_combo = generate(
    model, hp6, "I think", combo_vec,
    scale=0.8, insertion_pos=None, n_samples=10, max_length=30,
)
print(gen_combo)

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:15<00:00,  1.58s/it]

['I think that in a recent conversation with a customer, they said "I\'m so mad I can\'t even think at the moment, I would like', "I think it's about time!  Today I'm going to go through all the things, and then I'm going to be.   :)", 'I think that you can make many in comments or emails, but I was just reading through an email about a couple of friends I had with about the "politics', 'I think that it would not be fair if i didn\'t comment about a new issue. Today I was listening news about 2 months, because "The', "I think people are very angry about the climate change, but that's not a good thing. When we were 6 years old, the same, when", 'I think i have written a fair bit lately, but not angry or upset because i was right but just ranting and that i was like… i hope i got', "I think you're just jealous that you didn't get the 2nd half instead of your girlfriend.\n\nA.B. - it's", 'I think you need the "s" word.\n\nBecause I feel very angry at being told to to that.\n\nI don t want




['I think this post is about more than just about about posts.\n\nThe word "blogger" is probably,', 'I think your anger at the way of you has been for years. About the way of angry, is just', 'I think I can write and a lot now and on twitter – that’ post, or my old blog (', 'I think this a really good blog, , you need to a about it post\n\nI was about to blog', 'I think I posted here in about last summer that about my experience with about the.\n\nMy angry with about', 'I think that is a fair comment because it made us feel like about something. about something. that’s', 'I think the fact about that, is how can you write my post when he is about about some , I', 'I think I am a hot-head,\nand I write about a lot of…\nwith. ,', 'I think a lot about not.posts.about and it doesn today after the about post. This is not', "I think this question' is too much . because there is really this and...\n\nblog... ................................"]
