In [1]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops
# Imports for displaying vis in Colab / notebook

torch.set_grad_enabled(False)

# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

# utility to clear variables out of the memory & and clearing cuda cache
import gc
def clear_cache():
    gc.collect()
    torch.cuda.empty_cache()

Collecting sae-lens
  Downloading sae_lens-3.19.3-py3-none-any.whl.metadata (5.1 kB)
Collecting transformer-lens
  Downloading transformer_lens-2.4.1-py3-none-any.whl.metadata (12 kB)
Collecting automated-interpretability<1.0.0,>=0.0.5 (from sae-lens)
  Downloading automated_interpretability-0.0.6-py3-none-any.whl.metadata (778 bytes)
Collecting babe<0.0.8,>=0.0.7 (from sae-lens)
  Downloading babe-0.0.7-py3-none-any.whl.metadata (10 kB)
Collecting datasets<3.0.0,>=2.17.1 (from sae-lens)
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting matplotlib<4.0.0,>=3.8.3 (from sae-lens)
  Downloading matplotlib-3.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting plotly<6.0.0,>=5.19.0 (from sae-lens)
  Downloading plotly-5.24.0-py3-none-any.whl.metadata (7.3 kB)
Collecting plotly-express<0.5.0,>=0.4.1 (from sae-lens)
  Downloading plotly_express-0.4.1-py2.py3-none-any.whl.metadata (1.7 kB)
Collecting pytest-profiling<2.0.0,>=1.7.0 

In [11]:
N_CONTEXT = 128 # don't change this: it's the max context length, used by the original Gemma-2 dataset
N_BATCHES = 256
TOTAL_BATCHES = 1

USE_BASE_TOKENIZER=True

# BASE_MODEL = "google/gemma-2b"
BASE_MODEL = "gpt2-small"
FINETUNE_MODEL = 'pierreguillou/gpt2-small-portuguese'

# this is a tokenized language dataset that the base model's SAE was originally trained on
DATASET_NAME = "Skylion007/openwebtext"

# DATASET_NAME = "ctigges/openwebtext-gemma-1024-cl"

# Or if we want to use Pile:
# DATASET_NAME = "monology/pile-uncopyrighted"

In [3]:
LAYER_NUM = 6
SAE_LAYER = f'blocks.{LAYER_NUM}.hook_resid_pre'

In [4]:
from datasets import load_dataset

# Load the dataset in streaming mode
dataset = load_dataset(DATASET_NAME, split="train", streaming=True)

Downloading builder script:   0%|          | 0.00/2.73k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.35k [00:00<?, ?B/s]

The repository for Skylion007/openwebtext contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/Skylion007/openwebtext.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] Y


In [5]:
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, dtype=torch.float16)

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



Loaded pretrained model gpt2-small into HookedTransformer


In [6]:
from sae_lens import LanguageModelSAERunnerConfig

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distibuion)
    model_name=BASE_MODEL,
    dataset_path=DATASET_NAME,

    # is_dataset_tokenized=True,
    streaming=True,

    # Activation Store Parameters
    store_batch_size_prompts=N_BATCHES,
    context_size=N_CONTEXT,

    # Misc
    device=device,
    seed=42,
)

Run name: 2048-L1-0.001-LR-0.0003-Tokens-2.000e+06
n_tokens_per_buffer (millions): 0.65536
Lower bound: n_contexts_per_buffer (millions): 0.00512
Total training steps: 488
Total wandb updates: 48
n_tokens_per_feature_sampling_window (millions): 1048.576
n_tokens_per_dead_feature_window (millions): 524.288
We will reset the sparsity calculation 0 times.
Number tokens in sparsity calculation window: 8.19e+06


In [8]:
from sae_lens import ActivationsStore

# Instantiate an activation store to easily sample tokenized batches from our dataset
activation_store = ActivationsStore.from_config(
    model=base_model,
    cfg=cfg
)



1. Assume there’s an input tensors of shape `[N_BATCH, N_CONTEXT]`, containing a sample from the dataset in Task 1.
2. Run the base model with [run_with_cache](https://transformerlensorg.github.io/TransformerLens/generated/code/transformer_lens.hook_points.html#transformer_lens.hook_points.HookedRootModule.run_with_cache) on this input tensor, storing the activations in the cache object as it’s done on the Github notebook.
3. Save the activations tensor using torch.save
4. Similarly, run the finetune on the input tensor, storing & saving the activations
5. Return (save) two tensors of the shape `[N_BATCH, N_CONTEXT, N_ACTIVATIONS]`, one for each model.

In [9]:
all_acts = []

for k in range(TOTAL_BATCHES):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens() # [N_BATCH, N_CONTEXT]

    _, cache = base_model.run_with_cache(tokens, stop_at_layer=LAYER_NUM + 1, \
                                         names_filter=[SAE_LAYER]) # [N_BATCH, N_CONTEXT, D_MODEL]
    all_acts.append(cache[SAE_LAYER])

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    torch.cuda.empty_cache()

# Concatenate all feature activations into a single tensor
all_acts = torch.cat(all_acts)
all_acts.shape

Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors


torch.Size([256, 128, 768])

In [10]:
torch.save(all_acts, "base_acts.pt")

# Offload the first model from memory
del base_model
torch.cuda.empty_cache()

In [16]:
# finetune_tokenizer = AutoTokenizer.from_pretrained(FINETUNE_MODEL)
finetune_model_hf = AutoModelForCausalLM.from_pretrained(FINETUNE_MODEL)
finetune_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, hf_model=finetune_model_hf)



Loaded pretrained model gpt2-small into HookedTransformer


In [17]:
all_acts = []

for k in range(TOTAL_BATCHES):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens() # [N_BATCH, N_CONTEXT]

    _, cache = finetune_model.run_with_cache(tokens, stop_at_layer=LAYER_NUM + 1, \
                                         names_filter=[SAE_LAYER]) # [N_BATCH, N_CONTEXT, D_MODEL]
    all_acts.append(cache[SAE_LAYER])

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    torch.cuda.empty_cache()

# Concatenate all feature activations into a single tensor
all_acts = torch.cat(all_acts)
all_acts.shape

torch.Size([256, 128, 768])

In [18]:
torch.save(all_acts, "finetune_acts.pt")