In [1]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops
# Disable autograd globally: nothing in this notebook backpropagates, and it avoids
# keeping autograd graphs alive during forward passes.
torch.set_grad_enabled(False)

# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

# Pick the accelerator: Apple-silicon MPS first, then CUDA, otherwise CPU.
device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(f"Device: {device}")

# Utility to free unreferenced Python objects and release cached accelerator memory.
import gc


def clear_cache() -> None:
    """Run the garbage collector, then empty the CUDA and/or MPS caching allocators.

    The notebook selects ``"mps"`` before ``"cuda"`` when it is available, so clearing
    only the CUDA cache would leave MPS memory untouched.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # torch.mps (and its empty_cache) only exists on newer PyTorch, hence the hasattr guard.
    if torch.backends.mps.is_available() and hasattr(torch, "mps"):
        torch.mps.empty_cache()

Device: cuda


In [8]:
# Must match the context length the pre-tokenized DATASET_NAME was built with
# (the "1024" in its name) -- it is not the model's maximum context. Don't change it.
N_CONTEXT = 1024
# Prompts (sequences) per batch; passed to the activation store as store_batch_size_prompts.
N_BATCHES = 128

# When True, the model-loading cell skips loading the fine-tuned model entirely.
USE_BASE_TOKENIZER = True

BASE_MODEL = "google/gemma-2b"
FINETUNE_MODEL = 'shahdishank/gemma-2b-it-finetune-python-codes'

# Tokenized language dataset that the base model's SAE was originally trained on
# (judging by its name: OpenWebText, Gemma tokenizer, 1024-token contexts).
DATASET_NAME = "ctigges/openwebtext-gemma-1024-cl"
# Alternatively the Pile. It is raw text, so the runner config would also need
# is_dataset_tokenized=False:
# DATASET_NAME = "monology/pile-uncopyrighted"

In [3]:
from datasets import load_dataset

# Open the corpus lazily: with streaming=True records are fetched on iteration
# rather than downloading every shard up front.
dataset = load_dataset(
    path=DATASET_NAME,
    split="train",
    streaming=True,
)

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

In [4]:
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, dtype=torch.float16)

# NOTE(review): despite its name, USE_BASE_TOKENIZER=True skips loading the fine-tuned
# model altogether, so `finetune_model` is undefined in that case -- confirm intended.
if not USE_BASE_TOKENIZER:
    finetune_tokenizer = AutoTokenizer.from_pretrained(FINETUNE_MODEL)
    finetune_model_hf = AutoModelForCausalLM.from_pretrained(FINETUNE_MODEL)

    # Wrap the fine-tuned HF weights in the same architecture config as base_model
    # (BASE_MODEL rather than a separately hard-coded alias). Pass the fine-tuned
    # tokenizer explicitly: otherwise the base tokenizer is loaded and
    # finetune_tokenizer goes unused in exactly the branch meant to avoid that.
    # Kept on CPU, presumably so two 2B models don't compete for accelerator memory.
    finetune_model = HookedSAETransformer.from_pretrained(
        BASE_MODEL,
        device='cpu',
        hf_model=finetune_model_hf,
        tokenizer=finetune_tokenizer,
    )

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2b into HookedTransformer


In [9]:
from sae_lens import LanguageModelSAERunnerConfig

# Runner config; in the cells shown here it only feeds ActivationsStore.from_config,
# so every SAE-training option is left at its default.
cfg = LanguageModelSAERunnerConfig(
    # Model + training distribution
    model_name=BASE_MODEL,
    dataset_path=DATASET_NAME,
    streaming=True,
    is_dataset_tokenized=True,

    # Activation store: N_BATCHES prompts of N_CONTEXT tokens each
    context_size=N_CONTEXT,
    store_batch_size_prompts=N_BATCHES,

    # Misc
    seed=42,
    device=device,
)

Run name: 2048-L1-0.001-LR-0.0003-Tokens-2.000e+06
n_tokens_per_buffer (millions): 2.62144
Lower bound: n_contexts_per_buffer (millions): 0.00256
Total training steps: 488
Total wandb updates: 48
n_tokens_per_feature_sampling_window (millions): 8388.608
n_tokens_per_dead_feature_window (millions): 4194.304
We will reset the sparsity calculation 0 times.
Number tokens in sparsity calculation window: 8.19e+06


In [10]:
from sae_lens import ActivationsStore

# Build an activation store from the base model + runner config; here it serves as a
# convenient sampler of tokenized batches from the dataset.
activation_store = ActivationsStore.from_config(cfg=cfg, model=base_model)

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

sae_lens.json:   0%|          | 0.00/340 [00:00<?, ?B/s]

In [11]:
# Draw one batch of token ids from the activation store's (streamed) dataset; the
# shape output further down is (128, 1024), i.e. (N_BATCHES, N_CONTEXT).
# NOTE(review): a descriptive name such as `batch_tokens` would read better, but later
# cells may reference `juicy_data`, so it is left unchanged.
juicy_data = activation_store.get_batch_tokens()

In [12]:
# Sanity check: one batch is (prompts per batch, context length) as configured.
assert juicy_data.shape == (N_BATCHES, N_CONTEXT), f"unexpected batch shape {tuple(juicy_data.shape)}"
juicy_data.shape

torch.Size([128, 1024])