In [10]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
from sae_lens import SparseAutoencoder, ActivationsStore

from steering.eval_utils import evaluate_completions
from steering.utils import text_to_sae_feats, top_activations, normalise_decoder, get_activation_steering
from steering.patch import generate, get_scores_and_losses, patch_resid, get_loss

from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go

# Turn autograd off globally: the cells below only evaluate losses and never call backward.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fda55afceb0>

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the pretrained "gemma-2b" weights into a HookedTransformer on that device.
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
# Name of the hook point used throughout the notebook (residual stream after block 6,
# going by the TransformerLens naming scheme).
hp6 = "blocks.6.hook_resid_post"

# Load the pretrained sparse autoencoder for that hook point from the "gemma-2b-res-jb" release.
sae6 = SparseAutoencoder.from_pretrained("gemma-2b-res-jb", hp6)
# NOTE(review): the return value is discarded, so this presumably rescales the decoder
# in place (unit-norm rows?) without touching the input scaling — confirm in steering.utils.
normalise_decoder(sae6, scale_input=False)
# Move the SAE to the same device as the model.
sae6 = sae6.to(device)

In [20]:
# Scales handed to get_loss (as `scales`) for each unit-normalised steering vector.
# Every run below reports the same loss at scale 0, so that entry acts as the unsteered baseline.
norms = [0, 10, 20, 40, 60, 80, 100, 120]

In [21]:
def losses_for_norms(norms, vector):
    """Evaluate the loss when steering along `vector` at each scale in `norms`.

    The vector is divided by its own norm first, so only its direction matters
    and any overall scale on the input is discarded. The unit vector and the
    scales are then passed to `get_loss` for the global `model` at hook point
    `hp6`, with `insertion_pos=None`.

    Args:
        norms: Scales forwarded to `get_loss` as its `scales` argument.
        vector: Steering tensor; it must have a finite, non-zero norm.

    Returns:
        Whatever `get_loss` returns (in this notebook, a list with one loss
        per entry of `norms`).

    Raises:
        ValueError: If `vector` has a zero or non-finite norm. Normalising it
            would produce NaN/inf entries and silently meaningless losses.
    """
    vector_norm = vector.norm()
    # Fail fast instead of spending a full evaluation run on a NaN steering vector.
    if vector_norm == 0 or not torch.isfinite(vector_norm):
        raise ValueError(
            "steering vector must have a finite, non-zero norm to be normalised"
        )
    normed_vector = vector / vector_norm
    losses = get_loss(model, hp6, normed_vector, scales=norms, insertion_pos=None)
    return losses

In [22]:
# Steering vector: decoder row of SAE feature 1062 ("anger") with two leading
# singleton axes added. The factor 56 has no effect on the losses below, because
# losses_for_norms rescales the vector to unit norm before evaluating.
anger_sae_vec = (sae6.W_dec[1062] * 56)[None, None, :]

# One loss per entry of `norms`; the bare expression displays the list.
anger_sae_losses = losses_for_norms(norms, anger_sae_vec)
anger_sae_losses

loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [00:59<00:00,  7.49s/it]


[2.037747383117676,
 2.0441823053359984,
 2.0720777678489686,
 2.2562749361991883,
 2.704720501899719,
 3.5302905654907226,
 4.782238540649414,
 6.302010145187378]

In [23]:
# Steering vector built from the "Anger" vs "Calm" activations at hp6.
# NOTE(review): the [0, 1] index presumably picks batch 0, token position 1
# (the first token after BOS) — confirm against get_activation_steering.
act_vec = get_activation_steering(
    model, hp6, pos_text="Anger", neg_text="Calm"
)[0, 1, :][None, None, :]

# One loss per entry of `norms`; the bare expression displays the list.
act_diff_losses = losses_for_norms(norms, act_vec)
act_diff_losses

pad tensor([], device='cuda:0', size=(1, 0), dtype=torch.int64)
loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.51s/it]


[2.037747383117676,
 2.0471405363082886,
 2.08109929561615,
 2.2730241632461547,
 2.8444892740249634,
 4.036533904075623,
 5.590423221588135,
 7.094773988723755]

In [25]:
# Weighted sum of two SAE decoder rows: feature 1062 ("anger", weight 56) and
# feature 14586 ("blog post", weight 96), with two leading singleton axes added.
# Only the ratio of the two weights matters, since losses_for_norms normalises the sum.
sae_sum_vec = (56 * sae6.W_dec[1062] + 96 * sae6.W_dec[14586])[None, None, :]

# One loss per entry of `norms`; the bare expression displays the list.
sae_sum_losses = losses_for_norms(norms, sae_sum_vec)
sae_sum_losses

loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.51s/it]


[2.037747383117676,
 2.0423549485206602,
 2.0587640690803526,
 2.1463862109184264,
 2.3766694688796997,
 2.8935400915145872,
 3.811956706047058,
 4.950247449874878]

In [33]:
# Baseline: steer along random Gaussian directions and average the losses over 10 draws.

# Seed the global RNG so the sampled directions, and therefore the averaged
# losses, are reproducible under Restart & Run All.
torch.manual_seed(0)

randn_all_losses = []
for _ in range(10):
    # Same shape/dtype/device as the SAE steering vector; only the direction
    # matters because losses_for_norms rescales the vector to unit norm.
    randn_vec = torch.randn_like(anger_sae_vec)
    randn_all_losses.append(losses_for_norms(norms, randn_vec))

# Rows are draws, columns are scales; average over draws to get one loss per scale.
all_losses_tensor = torch.tensor(randn_all_losses)
randn_losses = all_losses_tensor.mean(dim=0)

loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.50s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.52s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.53s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.54s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.54s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.54s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.54s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.55s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.55s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.56s/it]


In [34]:
# Baseline: steer along random uniform directions and average the losses over 10 draws.

# Seed the global RNG so the sampled directions, and therefore the averaged
# losses, are reproducible under Restart & Run All.
torch.manual_seed(1)

randu_all_losses = []
for _ in range(10):
    # torch.rand_like samples from [0, 1), so every component is non-negative;
    # the overall scale is irrelevant because losses_for_norms normalises the vector.
    randu_vec = torch.rand_like(anger_sae_vec)
    randu_all_losses.append(losses_for_norms(norms, randu_vec))

# Rows are draws, columns are scales; average over draws to get one loss per scale.
all_randu_losses_tensor = torch.tensor(randu_all_losses)
randu_losses = all_randu_losses_tensor.mean(dim=0)

loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.54s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.55s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.55s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.56s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.56s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.56s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.56s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.56s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.55s/it]


loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.56s/it]


In [36]:
# Steering vector: decoder row of SAE feature 14586 ("blog post") with two leading
# singleton axes added. The factor 96 has no effect on the losses below, because
# losses_for_norms rescales the vector to unit norm before evaluating.
sae_blog_vec = (sae6.W_dec[14586] * 96)[None, None, :]

# One loss per entry of `norms`; the bare expression displays the list.
sae_blog_losses = losses_for_norms(norms, sae_blog_vec)
sae_blog_losses

loading dataset: NeelNanda/c4-code-20k
dataset loaded


100%|██████████| 8/8 [01:00<00:00,  7.51s/it]


[2.037747383117676,
 2.043927245140076,
 2.0646593236923216,
 2.1813218903541567,
 2.4917372703552245,
 3.2092547607421875,
 4.323053164482117,
 5.482840032577514]

In [38]:

# Loss as a function of steering-vector norm, one curve per steering vector.
# The random baselines (randn_losses, randu_losses) are not included in this figure.
fig = go.Figure()
fig.add_traces([
    go.Scatter(x=norms, y=curve, mode='lines+markers', name=label)
    for label, curve in (
        ('sae anger ft', anger_sae_losses),
        ('act diff anger-calm', act_diff_losses),
        ('sae anger+blog ft', sae_sum_losses),
        ('sae blog ft', sae_blog_losses),
    )
])

# Title and axis labels so the figure stands on its own.
fig.update_layout(
    title='Loss vs norm for different steering vectors',
    xaxis_title='Norm',
    yaxis_title='Loss',
)

fig.show()