## Setup

In [1]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
from tqdm import tqdm

# Turn autograd off globally; the cells in this notebook only run forward passes.
# The trailing semicolon suppresses the cell's output repr.
torch.set_grad_enabled(False);

Collecting sae-lens
  Downloading sae_lens-3.11.0-py3-none-any.whl (84 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.4/84.4 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting transformer-lens
  Downloading transformer_lens-2.1.0-py3-none-any.whl (154 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting automated-interpretability<0.0.4,>=0.0.3 (from sae-lens)
  Downloading automated_interpretability-0.0.3-py3-none-any.whl (56 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.8/56.8 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting babe<0.0.8,>=0.0.7 (from sae-lens)
  Downloading babe-0.0.7-py3-none-any.whl (6.9 kB)
Collecting datasets<3.0.0,>=2.17.1 (from sae-lens)
  Downloading datasets-2.20.0-py3-none-any.whl (547 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.8/547.8 kB[0m [31m8.4 MB/s[0m eta [36m

In [2]:
# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

# Backend preference order: Apple MPS, then CUDA, then plain CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

Device: cuda


# Loading a pretrained Sparse Autoencoder

Below we load a TransformerLens model, a pretrained SAE, and a dataset from Hugging Face.

In [3]:
from datasets import load_dataset
from transformer_lens import HookedTransformer
from sae_lens import SAE

# GPT-2 small loaded through TransformerLens, placed on the selected device.
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# SAE.from_pretrained is unpacked into three values here:
#   * the SAE itself,
#   * a cfg dict: whatever was stored in the HF repo. It is not the SAE's own
#     config dict, but the SAE config can be extracted from it, and it may hold
#     information useful for analysis (eg: instantiating an activation store),
#   * the feature sparsities, which are stored in HF for convenience.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",  # see other options in sae_lens/pretrained_saes.yaml
    sae_id="blocks.8.hook_resid_pre",  # won't always be a hook point
    device=device,
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


blocks.8.hook_resid_pre/cfg.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/151M [00:00<?, ?B/s]

sparsity.safetensors:   0%|          | 0.00/98.4k [00:00<?, ?B/s]

In [4]:
# Display the SAE encoder weight shape (recorded output: [768, 24576]).
sae.W_enc.shape

torch.Size([768, 24576])

# Get steering vector (one layer)

In [5]:
# Tokenize both contrast prompts together as a single batch, BOS prepended.
prompts = ["love", "hate"]
tokens = model.to_tokens(prompts, prepend_bos=True)
tokens.shape

torch.Size([2, 2])

In [6]:
# Remove every hook (permanent ones included) so this forward pass is unmodified.
model.reset_hooks(including_permanent=True)
# Run the two contrast prompts; the first return value is discarded and only the
# activation cache is kept.
_, cache = model.run_with_cache(tokens)

In [7]:
layer = 6
hook_point = "blocks.6.hook_resid_pre"  # saelens only has pre, not post

# Steering vector: cached activations of prompt 0 minus those of prompt 1 at the
# hook point, with a leading axis of size 1 added back in front.
steering_vec = (cache[hook_point][0] - cache[hook_point][1]).unsqueeze(0)

# Steer by adding a hook using `functools.partial`

In [8]:
from functools import partial

def act_add(
    activation,
    hook,
    steering_vec,
    coeff=3,
    # initPromptLen
):
    """Forward hook that adds a scaled steering vector at the last position.

    Args:
        activation: Tensor indexed as [batch, position, feature]; it is modified
            in place.
        hook: Hook object supplied by the caller of the hook function (unused).
        steering_vec: Tensor indexed as [batch, position, feature]. Only its last
            position is used, and that slice must broadcast against
            ``activation[:, -1, :]``.
        coeff: Multiplier applied to the steering vector. Defaults to 3, the
            strength that was previously hard-coded.

    Returns:
        The same ``activation`` object, after the in-place addition.
    """
    # activation[:, initPromptLen:, :] += steering_vec[:, -1, :] * coeff
    activation[:, -1, :] += steering_vec[:, -1, :] * coeff
    return activation

# Bind the steering vector now, so the resulting callable only needs the
# activation and hook arguments.
hook_fn = partial(act_add, steering_vec=steering_vec)

# initPromptLen = len(model.tokenizer.encode("I think cats are "))

# Cache L6 to L11 activations for unsteered vs. steered runs

In [9]:
from transformer_lens.utils import tokenize_and_concatenate

# Load the train split of the pile-10k dataset (not streamed).
dataset = load_dataset(path="NeelNanda/pile-10k", split="train", streaming=False)

# Tokenize with the model's tokenizer; the maximum length and BOS handling are
# taken from the SAE config.
token_dataset = tokenize_and_concatenate(
    dataset=dataset,  # type: ignore
    tokenizer=model.tokenizer,  # type: ignore
    streaming=True,
    max_length=sae.cfg.context_size,
    add_bos_token=sae.cfg.prepend_bos,
)

Downloading readme:   0%|          | 0.00/373 [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/921 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/33.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (229134 > 1024). Running this sequence through the model will result in indexing errors


In [10]:
# First 32 rows of the tokenized dataset form the evaluation batch.
# No forward pass is run here on purpose: binding its result to `cache` would
# overwrite the prompt cache that `steering_vec` is computed from, and the
# activations for this batch are computed separately into `unst_cache` and
# `steered_cache`.
batch_tokens = token_dataset[:32]["tokens"]

In [11]:
# test_sentence = "I think cats are "
# tokens = model.to_tokens(test_sentence, prepend_bos=True)

In [12]:
# Unsteered baseline: clear all hooks first, then cache activations for the batch.
model.reset_hooks(including_permanent=True)
_, unst_cache = model.run_with_cache(batch_tokens)

In [13]:
# Steered run: start from a hook-free model, attach the steering hook, and cache
# activations for the same batch.
model.reset_hooks(including_permanent=True)
model.add_hook(hook_point, hook_fn)
try:
    _, steered_cache = model.run_with_cache(batch_tokens)
finally:
    # Always detach the steering hook afterwards, so that later forward passes
    # are not silently steered (also when the run above raises).
    model.reset_hooks(including_permanent=True)

In [14]:
# unst_cache_dict = {k: v for k, v in unst_cache.items()}
# steered_cache_dict = {k: v for k, v in steered_cache.items()}

In [19]:
# Display the shape of the cached unsteered activations at `hook_point`
# (recorded output: [32, 128, 768]).
unst_cache[hook_point].shape

torch.Size([32, 128, 768])

# Obtain SAE feature activations at each layer

In [15]:
%%capture
unst_feature_acts = {}
steered_feature_acts = {}

for layer_id in range(5, 11):
    hook_point = f'blocks.{layer_id}.hook_resid_pre'
    # print(hook_point)
    # sparse_autoencoder = load_sae(hook_point)
    sae, cfg_dict, sparsity = SAE.from_pretrained(
        release = "gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml
        sae_id = hook_point, # won't always be a hook point
        device = device
    )

    sae.eval()  # prevents error if we're expecting a dead neuron mask for who grads
    with torch.no_grad():
        # sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sparse_autoencoder(
        #     unst_cache[hook_point]
        # )
        feature_acts = sae.encode(unst_cache[sae.cfg.hook_name])
        # sae_out = sae.decode(feature_acts)
    unst_feature_acts[hook_point] = feature_acts

    # print('\n')
    sae.eval()
    with torch.no_grad():
        # sae_out, feat_acts_steered, loss, mse_loss, l1_loss, _ = sparse_autoencoder(
        #     steered_cache[hook_point]
        # )
        feature_acts = sae.encode(steered_cache[sae.cfg.hook_name])
    steered_feature_acts[hook_point] = feature_acts

# Save activations

In [16]:
import pickle
# Copy every tensor to CPU before pickling. A pickled tensor is restored onto the
# device it was saved from, so files written from GPU tensors cannot be loaded
# with plain pickle.load on a machine without that device.
with open('unst_feature_acts.pkl', 'wb') as f:
    pickle.dump({name: acts.cpu() for name, acts in unst_feature_acts.items()}, f)
with open('steered_feature_acts.pkl', 'wb') as f:
    pickle.dump({name: acts.cpu() for name, acts in steered_feature_acts.items()}, f)

In [20]:
from google.colab import drive
drive.mount('/content/drive')

!cp unst_feature_acts.pkl /content/drive/MyDrive/
!cp steered_feature_acts.pkl /content/drive/MyDrive/

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
