# Loading and Analysing Pre-Trained Sparse Autoencoders

## Imports & Installs

In [1]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading
PORT = 8000

torch.set_grad_enabled(False);

Collecting sae-lens
  Downloading sae_lens-3.13.0-py3-none-any.whl.metadata (4.5 kB)
Collecting transformer-lens
  Downloading transformer_lens-2.2.2-py3-none-any.whl.metadata (12 kB)
Collecting automated-interpretability<0.0.4,>=0.0.3 (from sae-lens)
  Downloading automated_interpretability-0.0.3-py3-none-any.whl.metadata (817 bytes)
Collecting babe<0.0.8,>=0.0.7 (from sae-lens)
  Downloading babe-0.0.7-py3-none-any.whl.metadata (10 kB)
Collecting datasets<3.0.0,>=2.17.1 (from sae-lens)
  Downloading datasets-2.20.0-py3-none-any.whl.metadata (19 kB)
Collecting matplotlib<4.0.0,>=3.8.3 (from sae-lens)
  Downloading matplotlib-3.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting plotly<6.0.0,>=5.19.0 (from sae-lens)
  Downloading plotly-5.23.0-py3-none-any.whl.metadata (7.3 kB)
Collecting plotly-express<0.5.0,>=0.4.1 (from sae-lens)
  Downloading plotly_express-0.4.1-py2.py3-none-any.whl.metadata (1.7 kB)
Collecting pytest-profiling<2.0.0,>=1.7.0 

## Set Up

In [2]:
# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

Device: cuda


In [3]:
def display_vis_inline(filename: str, height: int = 850):
    '''
    Displays the HTML files in Colab. Uses global `PORT` variable defined in prev cell, so that each
    vis has a unique port without having to define a port within the function.
    '''
    if not(COLAB):
        webbrowser.open(filename);

    else:
        global PORT

        def serve(directory):
            os.chdir(directory)

            # Create a handler for serving files
            handler = http.server.SimpleHTTPRequestHandler

            # Create a socket server with the handler
            with socketserver.TCPServer(("", PORT), handler) as httpd:
                print(f"Serving files from {directory} on port {PORT}")
                httpd.serve_forever()

        thread = threading.Thread(target=serve, args=("/content",))
        thread.start()

        output.serve_kernel_port_as_iframe(PORT, path=f"/{filename}", height=height, cache_in_notebook=True)

        PORT += 1

# hf login

In [None]:
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


# load pretrained SAEs

In [4]:
from datasets import load_dataset
from transformer_lens import HookedTransformer
from sae_lens import SAE

# model = HookedTransformer.from_pretrained("gpt2-small", device = device)

# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)
# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict
# We also return the feature sparsities which are stored in HF for convenience.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = "blocks.8.hook_resid_pre", # won't always be a hook point
    device = device
)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


blocks.8.hook_resid_pre/cfg.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/151M [00:00<?, ?B/s]

sparsity.safetensors:   0%|          | 0.00/98.4k [00:00<?, ?B/s]

In [None]:
# sae_2, cfg_dict_2, sparsity_2 = SAE.from_pretrained(
#     release = "gpt2-small-resid-post-v5-32k", # see other options in sae_lens/pretrained_saes.yaml
#     sae_id = "blocks.7.hook_resid_post", # won't always be a hook point
#     device = device
# )

cfg.json:   0%|          | 0.00/565 [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/201M [00:00<?, ?B/s]

sparsity.safetensors:   0%|          | 0.00/131k [00:00<?, ?B/s]

In [None]:
model_2_layer = 'blocks.12.hook_resid_post'

In [None]:
model_2 = HookedTransformer.from_pretrained("gemma-2b", device = device)
sae_2, cfg_dict_2, sparsity_2 = SAE.from_pretrained(
    release = "gemma-2b-res-jb", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = "blocks.12.hook_resid_post", # won't always be a hook point
    device = device
)

config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/33.6k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]



Loaded pretrained model gemma-2b into HookedTransformer


(…)blocks.12.hook_resid_post_16384/cfg.json:   0%|          | 0.00/2.19k [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/269M [00:00<?, ?B/s]

sparsity.safetensors:   0%|          | 0.00/65.6k [00:00<?, ?B/s]

## save decoder weights

In [5]:
from google.colab import drive
import shutil

drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
weight_matrix_np = sae.W_dec.cpu()

In [8]:
import pickle

In [9]:
Wdec_filename = 'gpt2-small-8-res-jb_Wdec.pkl'
with open(Wdec_filename, 'wb') as f:
    pickle.dump(weight_matrix_np, f)

# source_path = f'/path/to/your/file/{file_name}'
source_path = Wdec_filename
# dest_folder = ''
destination_path = f'/content/drive/MyDrive/{Wdec_filename}'

shutil.copy(source_path, destination_path) # Copy the file

'/content/drive/MyDrive/gpt2-small-8-res-jb_Wdec.pkl'

# load dataset

In [None]:
from transformer_lens.utils import tokenize_and_concatenate

dataset = load_dataset(
    path = "NeelNanda/pile-10k",
    split="train",
    streaming=False,
)

token_dataset = tokenize_and_concatenate(
    dataset= dataset,# type: ignore
    tokenizer = model_2.tokenizer, # type: ignore
    streaming=True,
    max_length=sae_2.cfg.context_size,
    add_bos_token=sae.cfg.prepend_bos,
)

In [None]:
batch_tokens = token_dataset[:32]["tokens"]
batch_tokens.shape

# model 1- save sae actvs

## get LLM actvs

In [None]:
layer_name = 'blocks.8.hook_resid_pre'

In [None]:
h_store = torch.zeros((batch_tokens.shape[0], batch_tokens.shape[1], model.cfg.d_model), device=model.cfg.device)
h_store.shape

torch.Size([32, 128, 768])

In [None]:
from torch import nn, Tensor
# import torch.nn.functional as F
from jaxtyping import Float, Int
from typing import Optional, Callable, Union, List, Tuple

def store_h_hook(
    pattern: Float[Tensor, "batch seqlen d_model"],
    hook
):
    h_store[:] = pattern  # this works b/c changes values, not replaces entire thing

In [None]:
model.run_with_hooks(
    batch_tokens,
    return_type = None,
    fwd_hooks=[
        (layer_name, store_h_hook),
    ]
)

## get SAE actvs

In [None]:
sae.eval()  # prevents error if we're expecting a dead neuron mask for who grads
with torch.no_grad():
    feature_acts = sae.encode(h_store)

Now you have to save actvs, bc saelens not compatible with umap lib

In [None]:
import pickle
with open('feature_acts_model_A.pkl', 'wb') as f:
    pickle.dump(feature_acts, f)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

!cp feature_acts_model_A.pkl /content/drive/MyDrive/

Mounted at /content/drive


# model 2- save sae actvs

## get LLM actvs

In [None]:
layer_name = model_2_layer

In [None]:
h_store_2 = torch.zeros((batch_tokens.shape[0], batch_tokens.shape[1], model_2.cfg.d_model), device=model_2.cfg.device)
h_store_2.shape

torch.Size([32, 128, 2048])

In [None]:
def store_h_hook_2(
    pattern: Float[Tensor, "batch seqlen d_model"],
    hook
):
    h_store_2[:] = pattern  # this works b/c changes values, not replaces entire thing

In [None]:
model_2.run_with_hooks(
    batch_tokens,
    return_type = None,
    fwd_hooks=[
        (layer_name, store_h_hook_2),
    ]
)

## get SAE actvs

In [None]:
sae_2.eval()  # prevents error if we're expecting a dead neuron mask for who grads
with torch.no_grad():
    feature_acts_2 = sae_2.encode(h_store_2)

Now you have to save actvs, bc saelens not compatible with umap lib

In [None]:
with open('feature_acts_model_B.pkl', 'wb') as f:
    pickle.dump(feature_acts_2, f)

In [None]:
!cp feature_acts_model_B.pkl /content/drive/MyDrive/

# gemma 12b, L6- save sae actvs

## setup

In [None]:
from datasets import load_dataset
from transformer_lens import HookedTransformer
from sae_lens import SAE

In [None]:
from torch import nn, Tensor
from jaxtyping import Float, Int
from typing import Optional, Callable, Union, List, Tuple

In [None]:
import pickle

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## laod model

In [None]:
model_2_layer = 'blocks.6.hook_resid_post'

In [None]:
model_2 = HookedTransformer.from_pretrained("gemma-2b", device = device)
sae_2, cfg_dict_2, sparsity_2 = SAE.from_pretrained(
    release = "gemma-2b-res-jb", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = "blocks.12.hook_resid_post", # won't always be a hook point
    device = device
)

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:  26%|##6       | 1.30G/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/33.6k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]



Loaded pretrained model gemma-2b into HookedTransformer


(…)blocks.12.hook_resid_post_16384/cfg.json:   0%|          | 0.00/2.19k [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/269M [00:00<?, ?B/s]

sparsity.safetensors:   0%|          | 0.00/65.6k [00:00<?, ?B/s]

## get data

In [None]:
from transformer_lens.utils import tokenize_and_concatenate

dataset = load_dataset(
    path = "NeelNanda/pile-10k",
    split="train",
    streaming=False,
)

token_dataset = tokenize_and_concatenate(
    dataset= dataset,# type: ignore
    tokenizer = model_2.tokenizer, # type: ignore
    streaming=True,
    max_length=sae_2.cfg.context_size,
    add_bos_token=sae_2.cfg.prepend_bos,
)

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [None]:
batch_tokens = token_dataset[:32]["tokens"]
batch_tokens.shape

torch.Size([32, 1024])

## get LLM actvs

In [None]:
layer_name = model_2_layer

In [None]:
h_store_2 = torch.zeros((batch_tokens.shape[0], batch_tokens.shape[1], model_2.cfg.d_model), device=model_2.cfg.device)
h_store_2.shape

torch.Size([32, 1024, 2048])

In [None]:
def store_h_hook_2(
    pattern: Float[Tensor, "batch seqlen d_model"],
    hook
):
    h_store_2[:] = pattern  # this works b/c changes values, not replaces entire thing

In [None]:
model_2.run_with_hooks(
    batch_tokens,
    return_type = None,
    fwd_hooks=[
        (layer_name, store_h_hook_2),
    ]
)

## get SAE actvs

In [None]:
sae_2.eval()  # prevents error if we're expecting a dead neuron mask for who grads
with torch.no_grad():
    feature_acts_2 = sae_2.encode(h_store_2)

Now you have to save actvs, bc saelens not compatible with umap lib

In [None]:
with open('feature_acts_model_B_L6.pkl', 'wb') as f:
    pickle.dump(feature_acts_2, f)

In [None]:
!cp feature_acts_model_B_L6.pkl /content/drive/MyDrive/

In [None]:
file_path = '/content/drive/MyDrive/feature_acts_model_B_L6.pkl'
with open(file_path, 'rb') as f:
    feature_acts_model_B = pickle.load(f)