# Loading and Analysing Pre-Trained Sparse Autoencoders

In [11]:
# Mount Google Drive so that later cells can copy pickled artifacts into it.
import shutil

from google.colab import drive

drive.mount("/content/drive")

Mounted at /content/drive


In [13]:
import pickle

In [1]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
# from tqdm import tqdm
# import plotly.express as px

# # Imports for displaying vis in Colab / notebook
# import webbrowser
# import http.server
# import socketserver
# import threading
# PORT = 8000

# Turn autograd off globally; the trailing `;` suppresses the cell's output repr.
torch.set_grad_enabled(False);

Collecting sae-lens
  Downloading sae_lens-3.17.1-py3-none-any.whl.metadata (4.8 kB)
Collecting transformer-lens
  Downloading transformer_lens-2.4.0-py3-none-any.whl.metadata (12 kB)
Collecting automated-interpretability<1.0.0,>=0.0.5 (from sae-lens)
  Downloading automated_interpretability-0.0.5-py3-none-any.whl.metadata (778 bytes)
Collecting babe<0.0.8,>=0.0.7 (from sae-lens)
  Downloading babe-0.0.7-py3-none-any.whl.metadata (10 kB)
Collecting datasets<3.0.0,>=2.17.1 (from sae-lens)
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting matplotlib<4.0.0,>=3.8.3 (from sae-lens)
  Downloading matplotlib-3.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting plotly<6.0.0,>=5.19.0 (from sae-lens)
  Downloading plotly-5.23.0-py3-none-any.whl.metadata (7.3 kB)
Collecting plotly-express<0.5.0,>=0.4.1 (from sae-lens)
  Downloading plotly_express-0.4.1-py2.py3-none-any.whl.metadata (1.7 kB)
Collecting pytest-profiling<2.0.0,>=1.7.0 

In [2]:
# Imports are mostly placed next to the code that uses them, so that it is
# obvious where each function or class comes from.

# Pick the compute device: Apple MPS first, then CUDA, then plain CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

Device: cuda


In [3]:
def display_vis_inline(filename: str, height: int = 850):
    '''
    Displays an HTML visualisation file, either in the default web browser (outside Colab) or
    inline as an iframe (inside Colab).

    In Colab, a background HTTP server rooted at `/content` is started on the current value of
    the global `PORT`, and `PORT` is then incremented so that every call gets its own port. If
    `PORT` has not been defined yet, it starts at 8000.

    Args:
        filename: Path of the HTML file. In Colab it is requested as `/{filename}` from the
            server, i.e. relative to `/content`.
        height: Height of the iframe used in Colab (passed to
            `output.serve_kernel_port_as_iframe`).
    '''
    # Imported here because the notebook's top-level import cell has these commented out.
    import functools
    import http.server
    import socketserver
    import threading
    import webbrowser

    if not COLAB:
        webbrowser.open(filename)
        return

    global PORT
    # Fall back to a default when no earlier cell has defined PORT.
    PORT = globals().get("PORT", 8000)
    # Snapshot the port: the server thread must not read the global after it is bumped below.
    port = PORT

    def serve(directory):
        # Root the handler at `directory` through its `directory` argument instead of calling
        # os.chdir(), so the working directory of the whole kernel is left untouched.
        handler = functools.partial(http.server.SimpleHTTPRequestHandler, directory=directory)

        # Create a socket server with the handler
        with socketserver.TCPServer(("", port), handler) as httpd:
            print(f"Serving files from {directory} on port {port}")
            httpd.serve_forever()

    # Daemon thread: the never-ending server must not keep the process alive on shutdown.
    thread = threading.Thread(target=serve, args=("/content",), daemon=True)
    thread.start()

    output.serve_kernel_port_as_iframe(port, path=f"/{filename}", height=height, cache_in_notebook=True)

    PORT += 1

# hf login

In [4]:
# Interactive Hugging Face login: the token is typed at the prompt, not stored in the notebook.
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


# load pretrained SAEs

In [5]:
# Hook point used twice below: as the SAE id when loading the SAE, and as the
# forward-hook location when capturing model activations.
layer_name = "blocks.0.hook_mlp_out"

In [6]:
from transformer_lens import HookedTransformer
# Load the pretrained "gpt2-small" model as a HookedTransformer on the selected device.
model = HookedTransformer.from_pretrained("gpt2-small", device = device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


In [7]:
from datasets import load_dataset
from sae_lens import SAE

# `SAE.from_pretrained` hands back three things here:
#   * the SAE itself,
#   * the config dict as stored in the HF repo (not the SAE's own config dict, but the SAE
#     config can be extracted from it; useful e.g. for instantiating an activation store),
#   * the feature sparsities, which are kept on HF for convenience.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-mlp-out-v5-32k",  # other options: sae_lens/pretrained_saes.yaml
    sae_id=layer_name,  # not always a hook point
    device=device,
)

v5_32k_layer_0/cfg.json:   0%|          | 0.00/532 [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/201M [00:00<?, ?B/s]

sparsity.safetensors:   0%|          | 0.00/131k [00:00<?, ?B/s]

## save decoder weights

In [8]:
# Decoder weights moved to host memory. Despite the `_np` suffix no NumPy conversion
# happens here: `.cpu()` only changes the device.
weight_matrix_np = sae.W_dec.cpu()

In [10]:
Wdec_filename = 'gpt2sm_mlp0_Wdec.pkl'
# Pickle the decoder weights into the local working directory first.
with open(Wdec_filename, 'wb') as f:
    pickle.dump(weight_matrix_np, f)

source_path = Wdec_filename
destination_path = f'/content/drive/MyDrive/sae_files/{Wdec_filename}'

# shutil.copy() does not create missing parent folders, so make sure the Drive folder exists.
os.makedirs(os.path.dirname(destination_path), exist_ok=True)
shutil.copy(source_path, destination_path) # Copy the file

'/content/drive/MyDrive/sae_files/gpt2sm_mlp0_Wdec.pkl'

# load dataset

## get data

The model (and therefore its tokenizer) must be loaded before the dataset can be tokenized.

In [None]:
from transformer_lens.utils import tokenize_and_concatenate

# Load TinyStories (non-streaming) and keep only its validation split.
dataset = load_dataset("roneneldan/TinyStories", streaming=False)
test_dataset = dataset["validation"]

# Tokenize with the model's tokenizer. A fixed max_length of 128 and no BOS token are used
# here instead of the SAE's own settings (sae.cfg.context_size / sae.cfg.prepend_bos).
token_dataset = tokenize_and_concatenate(
    dataset=test_dataset,
    tokenizer=model.tokenizer,  # type: ignore
    streaming=True,
    max_length=128,
    add_bos_token=False,
)

Downloading readme:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


Downloading data:   0%|          | 0.00/249M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/248M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/246M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/248M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]

Map:   0%|          | 0/21990 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (10434 > 2048). Running this sequence through the model will result in indexing errors


In [None]:
# Take the first 500 tokenized rows as the working batch and display its shape.
batch_tokens = token_dataset[:500]["tokens"]
batch_tokens.shape

torch.Size([500, 128])

In [9]:
# File name under which the selected token batch is pickled and later reloaded.
save_data_fn = 'batch_tokens_anySamps_v1.pkl'

## save selected data

In [None]:
# Pickle the selected token batch into the local working directory first.
with open(save_data_fn, 'wb') as f:
    pickle.dump(batch_tokens, f)

source_path = save_data_fn
# Store the file in the `sae_files` folder: that is where the "load selected data" cell reads
# it back from, and where every other artifact of this notebook is kept.
destination_path = f'/content/drive/MyDrive/sae_files/{save_data_fn}'

# shutil.copy() does not create missing parent folders, so make sure the Drive folder exists.
os.makedirs(os.path.dirname(destination_path), exist_ok=True)
shutil.copy(source_path, destination_path) # Copy the file

## load selected data

In [14]:
# Reload the token batch from Drive; this doubles as a check that the save worked.
file_path = f'/content/drive/MyDrive/sae_files/{save_data_fn}'
with open(file_path, 'rb') as f:
    batch_tokens = pickle.load(f)

In [15]:
# Sanity check: the reloaded batch should have the same shape as when it was saved.
batch_tokens.shape

torch.Size([500, 128])

# save sae actvs

## get LLM actvs

In [16]:
# Pre-allocate the buffer that the forward hook fills: the two dimensions of the token batch
# followed by the model width, placed on the model's device.
h_store = torch.zeros(
    (*batch_tokens.shape[:2], model.cfg.d_model),
    device=model.cfg.device,
)
h_store.shape

torch.Size([500, 128, 768])

In [17]:
from torch import nn, Tensor
# import torch.nn.functional as F
from jaxtyping import Float, Int
from typing import Optional, Callable, Union, List, Tuple

def store_h_hook(
    pattern: Float[Tensor, "batch seqlen d_model"],
    hook
) -> None:
    """Forward hook that copies the hooked activation into the global `h_store` buffer.

    Args:
        pattern: Activation passed to the hook; it is written into `h_store` and must
            therefore be assignable to `h_store[:]`.
        hook: Second hook argument; not used here.

    Returns:
        None. NOTE(review): presumably the model then leaves the activation unchanged;
        confirm against the TransformerLens hook documentation.
    """
    # Slice assignment writes into the existing tensor, so the module-level `h_store`
    # is updated without needing a `global` declaration.
    h_store[:] = pattern  # this works b/c changes values, not replaces entire thing

In [18]:
# Run the model once purely to fire the hook at `layer_name`; with return_type=None no
# model output is requested.
model.run_with_hooks(
    batch_tokens,
    return_type=None,
    fwd_hooks=[(layer_name, store_h_hook)],
)

## get SAE actvs

In [19]:
# Encode the captured model activations into SAE feature activations.
sae.eval()  # eval mode; per the original note, this avoids an error where a dead-neuron mask would otherwise be expected
with torch.no_grad():
    feature_acts = sae.encode(h_store)

The activations now have to be saved to disk, because SAELens is not compatible with the UMAP library.

In [20]:
import pickle
# Move the activations to host memory before pickling: a tensor pickled while on a CUDA device
# cannot be restored with plain pickle.load() on a machine without CUDA. `.cpu()` is a no-op
# when the tensor already lives on the CPU, and it mirrors how the decoder weights are saved.
with open('fActs_GPT2sm_MLP0.pkl', 'wb') as f:
    pickle.dump(feature_acts.cpu(), f)

In [22]:
# Smoke test of the save-and-copy path using a tiny object.
# Note that this rebinds `fActs_filename` to the test file.
test = 1
with open('test.pkl', 'wb') as f:
    pickle.dump(test, f)

fActs_filename = 'test.pkl'
source_path = fActs_filename
destination_path = '/content/drive/MyDrive/sae_files/' + fActs_filename

shutil.copy(source_path, destination_path)  # copy into Drive

'/content/drive/MyDrive/sae_files/test.pkl'

In [21]:
# Copy the pickled feature activations from the working directory into Drive.
fActs_filename = 'fActs_GPT2sm_MLP0.pkl'
source_path = fActs_filename
destination_path = f'/content/drive/MyDrive/sae_files/{fActs_filename}'

# shutil.copy() does not create missing parent folders, so make sure the Drive folder exists.
os.makedirs(os.path.dirname(destination_path), exist_ok=True)
shutil.copy(source_path, destination_path) # Copy the file

'/content/drive/MyDrive/sae_files/fActs_GPT2sm_MLP0.pkl'

In [22]:
# !cp fActs_GPT2sm_MLP0.pkl /content/drive/MyDrive/sae_files/

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [23]:
# Reload the feature activations from Drive; this doubles as a check that the copy worked.
file_path = '/content/drive/MyDrive/sae_files/fActs_GPT2sm_MLP0.pkl'
with open(file_path, 'rb') as f:
    feature_acts = pickle.load(f)

In [24]:
# Sanity check on the reloaded activations: display their shape.
feature_acts.shape

torch.Size([500, 128, 32768])