In [6]:
import torch

import os
import json

from pos_sae.model import SparseAutoencoder
from pos_sae.compute_dead import get_freq_single_sae

from transformer_lens import HookedTransformer
from transformer_lens.utils import get_act_name

import plotly.express as px
from plotly.subplots import make_subplots

import pandas as pd

from datasets import load_dataset
from torch.utils.data import DataLoader

In [2]:
# Directory holding the converted SAE checkpoint. The POS_SAE_CHECKPOINT_DIR
# environment variable overrides it; when unset, the original location is used,
# so behaviour on the original machine is unchanged.
dir_path = os.environ.get(
    "POS_SAE_CHECKPOINT_DIR",
    "/Users/slava/fun/pos_sae/converted_checkpoints/final_sparse_autoencoder_gpt2-small_blocks.5.hook_resid_pre_24576",
)

# Prefer MPS (the device this notebook originally targeted), then CUDA, then
# CPU, so the cell no longer fails on machines without an MPS backend.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = SparseAutoencoder.load_from_pretrained(dir_path=dir_path)

# Kept as the last expression so the cell still displays the module repr.
model.to(device)


Run name: 24576-L1-8e-05-LR-0.0004-Tokens-3.000e+08
ignored keys: ['hook_point_head_index', 'dataset_path', 'is_dataset_tokenized', 'use_cached_activations', 'cached_activations_path', 'n_batches_in_buffer', 'total_training_tokens', 'store_batch_size', 'seed', 'b_dec_init_method', 'use_ghost_grads', 'feature_sampling_window', 'dead_feature_window', 'wandb_entity', 'n_checkpoints']


SparseAutoencoder(
  (hook_sae_in): HookPoint()
  (hook_hidden_pre): HookPoint()
  (hook_hidden_post): HookPoint()
  (hook_sae_out): HookPoint()
)

In [3]:

# Load the pretrained transformer that is later passed to get_freq_single_sae
# alongside the SAE.
gpt2_model_name = "gpt2-small"
gpt2 = HookedTransformer.from_pretrained(gpt2_model_name)


Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
# Frequencies returned by get_freq_single_sae over 100 batches, moved to the CPU
# (presumably one firing frequency per SAE feature -- confirm against the helper).
freqs = get_freq_single_sae(model, gpt2, n_batches=100).cpu()
# The 1e-6 offset keeps zero frequencies finite: they land at log10(1e-6) = -6.
log_freq = (freqs + 1e-6).log10()

# Summary statistics of the raw (non-log) frequencies.
mean_freq, std_freq = freqs.mean(), freqs.std()
max_freq, min_freq = freqs.max(), freqs.min()

print("Frequency Statistics:")
print(f"  Mean: {mean_freq:.4f}")
print(f"  Std: {std_freq:.4f}")
print(f"  Maximum: {max_freq:.4f}")
print(f"  Minimum: {min_freq:.4f}")

# Histogram of the log10 frequencies, 100 bins.
histogram_layout = dict(
    title='SAE Feature Frequency Histogram',
    xaxis_title='Frequency (Log Scale)',
    yaxis_title='Count',
)
fig = px.histogram(x=log_freq, nbins=100)
fig.update_layout(**histogram_layout)
fig.show()

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 100/100 [00:07<00:00, 12.81it/s]


Frequency Statistics:
  Mean: 0.0019
  Std: 0.0080
  Maximum: 0.3967
  Minimum: 0.0000


In [10]:
# Same helper as before, now over 1000 batches and with per_step=True;
# the result is moved to the CPU for plotting.
f_per_step = get_freq_single_sae(
    model,
    gpt2,
    n_batches=1000,
    per_step=True,
).cpu()
# The 1e-6 offset keeps zero frequencies finite under log10.
log_f_per_step = (f_per_step + 1e-6).log10()
print(f_per_step.shape)  # recorded run printed torch.Size([32, 24576])


The repository for Skylion007/openwebtext contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/Skylion007/openwebtext
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.

100%|██████████| 1000/1000 [01:05<00:00, 15.16it/s]

torch.Size([32, 24576])





In [11]:
# Steps to plot, 1-indexed: step s reads row s - 1 of log_f_per_step.
# Defined once so the subplot titles and the plotted rows cannot drift apart.
steps = [1, 2, 4, 8, 16, 32]

# Grid shape: fixed column count, with enough rows to hold every step
# (ceiling division, so a step count not divisible by cols still fits).
cols = 2
rows = -(-len(steps) // cols)

fig = make_subplots(
    rows=rows,
    cols=cols,
    subplot_titles=[f"Step {step}" for step in steps],
)

# One 100-bin histogram per step, filling the grid row by row.
for i, step in enumerate(steps):
    grid_row, grid_col = divmod(i, cols)
    fig.add_trace(
        px.histogram(x=log_f_per_step[step - 1, :], nbins=100).data[0],
        row=grid_row + 1,  # plotly grid positions are 1-indexed
        col=grid_col + 1,
    )

# update_layout(xaxis_title=...) would label only the first subplot's axes;
# update_xaxes / update_yaxes apply the titles to every subplot.
fig.update_xaxes(title_text='Frequency (Log Scale)')
fig.update_yaxes(title_text='Count')

# Figure-level settings; the height scales with the number of rows.
fig.update_layout(
    title='SAE Feature Frequency Histogram',
    showlegend=False,
    height=400 * rows,
)

fig.show()