## Imports

In [None]:
import gc
import itertools
import math
import os
import random
import sys
from collections import Counter, defaultdict
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias
import numpy as np
import pandas as pd
import torch as t
from datasets import load_dataset

import sae_lens
import transformer_lens
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name, test_prompt, to_numpy

import einops
import circuitsvis as cv
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from openai import OpenAI
from rich import print as rprint
from rich.table import Table
from tabulate import tabulate
from tqdm import tqdm

from transformer_lens.utils import tokenize_and_concatenate

from transformer_lens import utils
from functools import partial
from sae_lens.evals import *
from torch.nn.utils.rnn import pad_sequence

device = t.device("cuda" if t.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# Later cells place the unpruned models/SAEs on `device1` and the pruned ones on `device2`.
# Use the dedicated GPUs when they exist; otherwise fall back to `device` so the notebook
# still runs on a single-GPU or CPU-only machine.
_n_gpus = t.cuda.device_count() if t.cuda.is_available() else 0
device1 = t.device("cuda:1") if _n_gpus > 1 else device
device2 = t.device("cuda:2") if _n_gpus > 2 else device
print(device1, device2)

# Credentials (HuggingFace, Weights & Biases) must come from the environment, e.g.
# os.environ["HF_TOKEN"] / os.environ["WANDB_API_KEY"], never be pasted into the notebook.
# NOTE(review): the tokens previously written here are exposed and should be revoked.

## Q. How different is an SAE trained on wanda-pruned gpt2-small from a wanda-pruned SAE trained on gpt2-small?

#### Reconstruction loss

In [None]:
# The rest of the notebook only runs inference, so switch autograd off for the session.
t.set_grad_enabled(False)


def reconstr_hook(activation, hook, sae_out):
    """Forward hook that replaces the hooked activation with a precomputed tensor.

    Bind `sae_out` with `functools.partial`. The hook ignores the real `activation`
    (and the `hook` point object) and hands `sae_out` back, so the forward pass
    continues from the SAE reconstruction instead.
    """
    del activation, hook  # deliberately unused: the real activation is discarded
    return sae_out

In [None]:
# Quick check for "SAE trained on gpt2-small, then wanda-pruned": load unpruned gpt2-small,
# load an SAE (used here only as an architecture/config shell), overwrite its weights with the
# pruned state dict, then measure the model's loss with the SAE reconstruction spliced in.
# NOTE(review): `device1` is only defined in commented-out code in the setup cell; define it first.
gpt2 = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device1)

# # MATS
# hf_repo_id = "gpt2-small-hook-z-kk"
# sae_id = "blocks.9.hook_z"
# pruned_sae_trained_on_gpt2 = sae_lens.SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device1))[0]
# pruned_sae_trained_on_gpt2.load_state_dict(t.load('pruned/pruned_gpt2_attn_sae_wanda.pth'))

# Trained by me: the downloaded weights are replaced by the wanda-pruned state dict below.
hf_repo_id = "suchitg/sae_wanda"
sae_id = 'blocks.9.attn.hook_z-v1'
pruned_sae_trained_on_gpt2 = sae_lens.SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device1))[0]
pruned_sae_trained_on_gpt2.load_state_dict(t.load('pruned/pruned_gpt2_sae_wanda.pth'))


dataset = load_dataset(
    path="NeelNanda/pile-10k",
    split="train",
    streaming=False,
)

# Tokenize into fixed-length rows matching the SAE's training context size / BOS convention.
token_dataset = tokenize_and_concatenate(
    dataset=dataset,  # type: ignore
    tokenizer=gpt2.tokenizer,  # type: ignore
    streaming=True,
    max_length=pruned_sae_trained_on_gpt2.cfg.context_size,
    add_bos_token=pruned_sae_trained_on_gpt2.cfg.prepend_bos,
)

pruned_sae_trained_on_gpt2.eval()  # switch the SAE to eval mode before inference

with t.no_grad():

    # Take the first 16 tokenized rows as the evaluation batch.
    # NOTE(review): the rows already carry BOS when cfg.prepend_bos is set; prepend_bos=True is
    # presumably ignored for token-tensor input — confirm with TransformerLens.
    batch_tokens = token_dataset[:16]["tokens"]
    _, cache = gpt2.run_with_cache(batch_tokens, prepend_bos=True)

    # Encode then decode the activation at the SAE's hook point to get its reconstruction.
    feature_acts = pruned_sae_trained_on_gpt2.encode(cache[pruned_sae_trained_on_gpt2.cfg.hook_name])
    sae_out = pruned_sae_trained_on_gpt2.decode(feature_acts)

    # Free the full activation cache before the second forward pass.
    del cache

    # Second forward pass with the reconstruction substituted at the hook point; return_type="loss"
    # makes this the model's LM loss (not an activation MSE).
    print(
    "Reconstuction loss:",
    gpt2.run_with_hooks(
        batch_tokens,
        fwd_hooks=[
            (
                pruned_sae_trained_on_gpt2.cfg.hook_name,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
    )

In [None]:
# SAE trained on wanda-pruned gpt2-small: load the gpt2-small architecture, overwrite it with the
# wanda-pruned weights, and measure the loss with this SAE's reconstruction spliced in.
# Reuses `token_dataset` from the previous cell.
# NOTE(review): `device2` is only defined in commented-out code in the setup cell; define it first.

gpt2_pruned = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device2)
gpt2_pruned.load_state_dict(t.load('pruned/pruned_gpt2_wanda.pth'))

# # MATS trained config
# hf_repo_id = "suchitg/sae_test"
# sae_id = 'blocks.9.attn.hook_z-attn-sae-v1'
# sae_trained_on_gpt2_pruned = sae_lens.SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device2))[0]

# Trained by me. NOTE(review): same release/id as the shell SAE in the previous cell; here its
# downloaded weights are used as-is.
hf_repo_id = "suchitg/sae_wanda"
sae_id = 'blocks.9.attn.hook_z-v1'
sae_trained_on_gpt2_pruned = sae_lens.SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device2))[0]

sae_trained_on_gpt2_pruned.eval()  # switch the SAE to eval mode before inference

with t.no_grad():

    # Take the first 16 tokenized rows as the evaluation batch.
    batch_tokens = token_dataset[:16]["tokens"]
    _, cache = gpt2_pruned.run_with_cache(batch_tokens, prepend_bos=True)

    # Encode then decode the activation at the SAE's hook point to get its reconstruction.
    feature_acts = sae_trained_on_gpt2_pruned.encode(cache[sae_trained_on_gpt2_pruned.cfg.hook_name])
    sae_out = sae_trained_on_gpt2_pruned.decode(feature_acts)

    # Free the full activation cache before the second forward pass.
    del cache


    # Second forward pass with the reconstruction substituted at the hook point (LM loss).
    print(
    "Reconstuction loss:",
    gpt2_pruned.run_with_hooks(
        batch_tokens,
        fwd_hooks=[
            (
                sae_trained_on_gpt2_pruned.cfg.hook_name,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
    )

#### Alternative evaluation: the local `evals` suite (`multiple_evals`)

In [None]:
# Wanda-pruned SAE trained on gpt2-small, evaluated with the project-local `evals` module
# (not `sae_lens.evals`, which is star-imported in the first cell).
from evals import multiple_evals as ME

hf_repo_id = "gpt2-small-hook-z-kk"
sae_id = "blocks.9.hook_z"
pruned_sae_trained_on_gpt2 = sae_lens.SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device1))[0]
pruned_sae_trained_on_gpt2.load_state_dict(t.load('pruned/pruned_gpt2_attn_sae_wanda.pth'))
# NOTE(review): the meaning of the positional args (20, 200, 16) is defined in `evals`; pass them by keyword.
eval_results = ME(pruned_sae_trained_on_gpt2, 20, 200, 16, datasets=['Skylion007/openwebtext'], output_dir="out/pruned_sae_trained_on_gpt2",verbose=True, load=False, path_to_load_from=None)


In [None]:
# Write the evaluation results of the previous cell to JSON/CSV under the same output dir.
from evals import process_results as PR

output_files = PR(eval_results, 'out/pruned_sae_trained_on_gpt2')
print("Evaluation complete. Output files:")
print(f"Individual JSONs: {len(output_files['individual_jsons'])}")  # type: ignore
print(f"Combined JSON: {output_files['combined_json']}")
print(f"CSV: {output_files['csv']}")

In [None]:
# Display the reconstruction MSE column (last expression = cell output).
res = pd.read_csv('out/pruned_sae_trained_on_gpt2/all_eval_results.csv')
res['metrics.reconstruction_quality.mse']

In [None]:
# SAE trained on wanda-pruned gpt2-small (MATS config), evaluated with the same `evals` suite.
# NOTE(review): `ME` receives only the SAE, not `gpt2_pruned`; confirm it evaluates on the pruned
# model's activations rather than reloading unpruned gpt2-small.
hf_repo_id = "suchitg/sae_test"
sae_id = 'blocks.9.attn.hook_z-attn-sae-v1'
sae_trained_on_gpt2_pruned = sae_lens.SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device2))[0]
eval_results = ME(sae_trained_on_gpt2_pruned, 20, 200, 16, datasets=['Skylion007/openwebtext'], output_dir="out/sae_trained_on_gpt2_pruned",verbose=True, load=False, path_to_load_from=None)

In [None]:
# Write the evaluation results of the previous cell to JSON/CSV (re-import is harmless).
from evals import process_results as PR

output_files = PR(eval_results, 'out/sae_trained_on_gpt2_pruned')
print("Evaluation complete. Output files:")
print(f"Individual JSONs: {len(output_files['individual_jsons'])}")  # type: ignore
print(f"Combined JSON: {output_files['combined_json']}")
print(f"CSV: {output_files['csv']}")

In [None]:
# Display the reconstruction MSE column (last expression = cell output).
res = pd.read_csv('out/sae_trained_on_gpt2_pruned/all_eval_results.csv')
res['metrics.reconstruction_quality.mse']

#### Computing the reconstruction loss from scratch over the whole dataset

In [None]:
@t.no_grad()
def get_reconstruction_loss(sae, model, dataset, batch_size=8):
    """Average LM loss of `model` when `sae`'s reconstruction replaces the activation at its hook.

    For each batch: run `model` once to cache the activation at `sae.cfg.hook_name`, pass it
    through `sae.encode` / `sae.decode`, then rerun `model` with that reconstruction spliced in
    and record `return_type="loss"`.

    Args:
        sae: SAE exposing `.eval()`, `.encode()`, `.decode()` and `.cfg.hook_name`.
        model: TransformerLens-style model with `run_with_cache` / `run_with_hooks`.
        dataset: tokenized dataset whose slices have a "tokens" column
            (e.g. the output of `tokenize_and_concatenate`).
        batch_size: number of rows per forward pass.

    Returns:
        Row-weighted mean of the per-batch losses over the whole dataset, including a
        final partial batch.

    Raises:
        ValueError: if `batch_size` is not positive or `dataset` is empty.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    n_rows = len(dataset)
    if n_rows == 0:
        raise ValueError("dataset is empty; cannot compute a reconstruction loss")

    sae.eval()
    hook_name = sae.cfg.hook_name

    def reconstr_hook(activation, hook, sae_out):
        # Discard the real activation and continue from the SAE reconstruction.
        return sae_out

    weighted_loss = 0.0
    for start in tqdm(range(0, n_rows, batch_size)):
        # Iterate the `dataset` argument (not a notebook global) so callers control the data.
        batch_tokens = dataset[start : start + batch_size]["tokens"]
        # Only the SAE's hook point is needed, so don't cache every activation.
        _, cache = model.run_with_cache(batch_tokens, prepend_bos=True, names_filter=hook_name)
        sae_out = sae.decode(sae.encode(cache[hook_name]))
        del cache  # free activation memory before the second forward pass

        batch_loss = model.run_with_hooks(
            batch_tokens,
            fwd_hooks=[(hook_name, partial(reconstr_hook, sae_out=sae_out))],
            return_type="loss",
        ).item()
        # Weight by rows so a short final batch does not skew the mean.
        weighted_loss += batch_loss * len(batch_tokens)

    return weighted_loss / n_rows

In [None]:
# Rebuild `token_dataset` with 128-token rows and a BOS token for the full-dataset loss runs
# below (the earlier quick checks used the SAE's cfg.context_size instead).
# NOTE(review): uses `gpt2.tokenizer`, so the gpt2-small cell above must have been run.
dataset = load_dataset(
    path="NeelNanda/pile-10k",
    split="train",
    streaming=False,
)

token_dataset = tokenize_and_concatenate(
    dataset=dataset,  # type: ignore
    tokenizer=gpt2.tokenizer,  # type: ignore
    streaming=True,
    max_length=128,
    add_bos_token=True,
)

In [None]:
# Full-dataset loss of unpruned gpt2-small with the wanda-pruned SAE (custom-trained shell +
# pruned weights) spliced in. The MATS variant is kept below as commented-out code.
# # MATS
# hf_repo_id = "gpt2-small-hook-z-kk"
# sae_id = "blocks.9.hook_z"
# pruned_sae_trained_on_gpt2 = sae_lens.SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device1))[0]
# pruned_sae_trained_on_gpt2.load_state_dict(t.load('pruned/pruned_gpt2_attn_sae_wanda.pth'))

# Trained by me
hf_repo_id = "suchitg/sae_wanda"
sae_id = 'blocks.9.attn.hook_z-v1'
pruned_sae_trained_on_gpt2 = sae_lens.SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device1))[0]
pruned_sae_trained_on_gpt2.load_state_dict(t.load('pruned/pruned_gpt2_sae_wanda.pth'))

get_reconstruction_loss(pruned_sae_trained_on_gpt2, gpt2, token_dataset, batch_size=32)

In [None]:
# Full-dataset loss of unpruned gpt2-small with the wanda-pruned MATS SAE spliced in.
# MATS
hf_repo_id = "gpt2-small-hook-z-kk"
sae_id = "blocks.9.hook_z"
pruned_sae_trained_on_gpt2 = sae_lens.SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device1))[0]
pruned_sae_trained_on_gpt2.load_state_dict(t.load('pruned/pruned_gpt2_attn_sae_wanda.pth'))

get_reconstruction_loss(pruned_sae_trained_on_gpt2, gpt2, token_dataset, batch_size=32)

In [None]:
# Full-dataset loss of wanda-pruned gpt2-small with the custom SAE trained on it spliced in.
gpt2_pruned = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device2)
gpt2_pruned.load_state_dict(t.load('pruned/pruned_gpt2_wanda.pth'))

# Trained by me
hf_repo_id = "suchitg/sae_wanda"
sae_id = 'blocks.9.attn.hook_z-v1'
sae_trained_on_gpt2_pruned = sae_lens.SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device2))[0]

get_reconstruction_loss(sae_trained_on_gpt2_pruned, gpt2_pruned, token_dataset, batch_size=32)

In [None]:
# Full-dataset loss of wanda-pruned gpt2-small (from the previous cell) with the SAE trained on it
# using the MATS config.
# MATS trained config
hf_repo_id = "suchitg/sae_test"
sae_id = 'blocks.9.attn.hook_z-attn-sae-v1'
sae_trained_on_gpt2_pruned = sae_lens.SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device2))[0]

get_reconstruction_loss(sae_trained_on_gpt2_pruned, gpt2_pruned, token_dataset, batch_size=32)

### Check the reconstruction loss of the original pre-trained SAE on the activations of both the unpruned and the wanda-pruned model

In [None]:
# Original (unpruned) pre-trained layer-9 attention SAE. The device is passed as a string,
# as in every other `SAE.from_pretrained` call in this notebook.
sae = sae_lens.SAE.from_pretrained("gpt2-small-hook-z-kk", "blocks.9.hook_z", device=str(device1))[0]

In [None]:
# Baseline: pre-trained SAE on unpruned gpt2-small.
# NOTE(review): this loads a second copy of gpt2-small identical to `gpt2`; reusing `gpt2` would save memory.
model = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device1)
get_reconstruction_loss(sae, model, token_dataset, batch_size=32)

In [None]:
# Same pre-trained SAE, now on the wanda-pruned gpt2-small weights.
pruned_model = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device1)
pruned_model.load_state_dict(t.load('pruned/pruned_gpt2_wanda.pth'))
get_reconstruction_loss(sae, pruned_model, token_dataset, batch_size=32)

## Making sense of the results?

In [None]:
import pandas as pd
import os

In [None]:
# Per-dataset evaluation logs: each dataset has a `results_<name>.csv` and a
# `losses_<name>.csv` under `logs/`, read in that order.
c4, c4_losses = (pd.read_csv(f'logs/{kind}_c4.csv') for kind in ('results', 'losses'))

openwebtext, openwebtext_losses = (
    pd.read_csv(f'logs/{kind}_openwebtext.csv') for kind in ('results', 'losses')
)

pile, pile_losses = (pd.read_csv(f'logs/{kind}_pile.csv') for kind in ('results', 'losses'))

wiki, wiki_losses = (pd.read_csv(f'logs/{kind}_wiki.csv') for kind in ('results', 'losses'))

##### Unused functions

In [None]:
# c4[(c4['Model'] == 'Pruned SAE') & (c4['Layer'] == 10) & (c4['Sparse Ratio'] >= 0.9)].sort_values(by='Validation Loss').head(20)
# openwebtext[(openwebtext['Model'] == 'Pruned SAE') & (openwebtext['Layer'] == 11) & (openwebtext['Sparse Ratio'] >= 0.9)].sort_values(by='Validation Loss').head(20)
# pile[(pile['Model'] == 'Pruned SAE') & (pile['Layer'] == 9) & (pile['Sparse Ratio'] >= 0.24)].sort_values(by='Validation Loss').head(20)
# wiki[(wiki['Model'] == 'Pruned SAE') & (wiki['Layer'] == 10) & (wiki['Sparse Ratio'] >= 0.9)].sort_values(by='Validation Loss').head(20)

# c4[c4['Model'] == 'SAE trained on pruned gpt2-small ']['Validation Loss']
# openwebtext[openwebtext['Model'] == 'SAE trained on pruned gpt2-small ']['Validation Loss']
# pile[pile['Model'] == 'SAE trained on pruned gpt2-small ']['Validation Loss']
# wiki[wiki['Model'] == 'SAE trained on pruned gpt2-small ']['Validation Loss']

# c4[c4['Model'] == 'Best sparse ratio pruned SAE']['Validation Loss']
# openwebtext[openwebtext['Model'] == 'Best sparse ratio pruned SAE']['Validation Loss']
# pile[pile['Model'] == 'Best sparse ratio pruned SAE']['Validation Loss']
# wiki[wiki['Model'] == 'Best sparse ratio pruned SAE']['Validation Loss']

# c4[c4['Model'] == 'Best sparse ratio pruned SAE']['Sparse Ratio']
# openwebtext[openwebtext['Model'] == 'Best sparse ratio pruned SAE']['Sparse Ratio']
# pile[pile['Model'] == 'Best sparse ratio pruned SAE']['Sparse Ratio']
# wiki[wiki['Model'] == 'Best sparse ratio pruned SAE']['Sparse Ratio']

# c4[c4['Model'] == 'Pretrained SAE']['Validation Loss']
# openwebtext[openwebtext['Model'] == 'Pretrained SAE']['Validation Loss']
# pile[pile['Model'] == 'Pretrained SAE']['Validation Loss']
# wiki[wiki['Model'] == 'Pretrained SAE']['Validation Loss']

# def plot(dataset):
#     # Create subplots for each layer
#     fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(15, 12), sharex=True, sharey=True)
#     axes = axes.flatten()
#     df = dataset[(dataset['Model'] == 'Pruned SAE')]

#     for i in range(12):
#         val1 = dataset[(dataset['Model'] == 'SAE trained on pruned gpt2-small ') & (dataset['Layer'] == i)]['Validation Loss'].values[0]
#         val2 = dataset[(dataset['Model'] == 'Pretrained SAE') & (dataset['Layer'] == i)]['Validation Loss'].values[0]
        
#         layer_df = df[df["Layer"] == i]
#         ax = axes[i]
#         ax.plot(layer_df["Sparse Ratio"], layer_df["Validation Loss"], label="Validation Loss")
#         ax.axhline(y=val1, color='r', linestyle='--', label="SAE trained on pruned gpt2-small ")
#         ax.axhline(y=val2, color='b', linestyle='--', label="Pretrained SAE")


#         ax.set_title(f"Layer {i}")
#         ax.legend()
#         ax.set_xlabel("Sparse Ratio")
#         ax.set_ylabel("Loss")

#     plt.tight_layout()
#     plt.show()


##### Interactive plots

In [None]:
import plotly.graph_objects as go
import plotly.subplots as sp

def plot(dataset, losses):
    # Create a 4x3 subplot grid
    fig = sp.make_subplots(rows=4, cols=3, subplot_titles=[f"Layer {i}" for i in range(12)],
                           shared_xaxes=True, shared_yaxes=True)

    df = dataset[dataset['Model'] == 'Pruned SAE']

    for i in range(12):
        val1 = dataset[(dataset['Model'] == 'SAE trained on pruned gpt2-small ') & (dataset['Layer'] == i)]['Validation Loss'].values[0]
        val2 = dataset[(dataset['Model'] == 'Pretrained SAE') & (dataset['Layer'] == i)]['Validation Loss'].values[0]

        val3 = losses[(losses['Model'] == 'SAE trained on pruned gpt2-small') & 
                      (losses['Layer'] == i) & 
                      (losses['Config'] == 'MATS') & 
                      (losses['Architecture'] == 'standard') & 
                      (losses['Epochs'] == '30K')]['Validation Loss'].values[0]
        
        val4 = losses[(losses['Model'] == 'SAE trained on pruned gpt2-small') & 
                      (losses['Layer'] == i) & 
                      (losses['Config'] == 'Custom') & 
                      (losses['Architecture'] == 'standard') & 
                      (losses['Epochs'] == '30K')]['Validation Loss'].values[0]
        
        # val5 = losses[(losses['Model'] == 'SAE trained on pruned gpt2-small') & 
        #               (losses['Layer'] == i) & 
        #               (losses['Config'] == 'Custom') & 
        #               (losses['Architecture'] == 'standard') & 
        #               (losses['Epochs'] == '50K')]['Validation Loss'].values[0]
        
        # assert val1 == val3
        # print(val1, val3)


        
        layer_df = df[df["Layer"] == i]
        row, col = (i // 3) + 1, (i % 3) + 1  # Determine subplot position
        
        # Add main line plot
        fig.add_trace(go.Scatter(x=layer_df["Sparse Ratio"], y=layer_df["Validation Loss"],
                                 mode='lines+markers', name="Pruned SAE",
                                 marker=dict(size=2, color="black"), line=dict(width=1, color="green"),
                                 showlegend=True if i == 0 else False), row=row, col=col)

        # Add horizontal reference lines
        fig.add_trace(go.Scatter(x=[layer_df["Sparse Ratio"].min(), layer_df["Sparse Ratio"].max()],
                                 y=[val1, val1], mode="lines", name="SAE trained on pruned (MATS)",
                                 line=dict(dash="dash", color="red"), showlegend=True if i == 0 else False),
                      row=row, col=col)

        fig.add_trace(go.Scatter(x=[layer_df["Sparse Ratio"].min(), layer_df["Sparse Ratio"].max()],
                                 y=[val4, val4], mode="lines", name="SAE trained on pruned (Custom)",
                                 line=dict(dash="dot", color="blue"), showlegend=True if i == 0 else False),
                      row=row, col=col)
        
        # fig.add_trace(go.Scatter(x=[layer_df["Sparse Ratio"].min(), layer_df["Sparse Ratio"].max()],
        #                          y=[val5, val5], mode="lines", name="SAE trained on pruned (Custom-50K)",
        #                          line=dict(dash="dashdot", color="purple"), showlegend=True if i == 0 else False),
        #               row=row, col=col)
        


        fig.add_trace(go.Scatter(x=[layer_df["Sparse Ratio"].min(), layer_df["Sparse Ratio"].max()],
                                 y=[val2, val2], mode="lines", name="Pretrained SAE",
                                 line=dict(dash="longdash", color="orange"), showlegend=True if i == 0 else False),
                      row=row, col=col)

    # Layout improvements
    fig.update_layout(height=800, width=1200,
                      title_text="Validation Loss vs Sparse Ratio across Layers",
                      title_x=0.5, showlegend=True, template="plotly_white")

    fig.update_xaxes(title_text="Sparse Ratio")
    fig.update_yaxes(title_text="Validation Loss")

    fig.show()


In [None]:
# C4: pruned-SAE validation loss vs. sparse ratio per layer, with reference SAEs.
plot(c4, c4_losses)

In [None]:
# OpenWebText: pruned-SAE validation loss vs. sparse ratio per layer, with reference SAEs.
plot(openwebtext, openwebtext_losses)

In [None]:
# Pile: pruned-SAE validation loss vs. sparse ratio per layer, with reference SAEs.
plot(pile, pile_losses)

In [None]:
# Wiki: pruned-SAE validation loss vs. sparse ratio per layer, with reference SAEs.
plot(wiki, wiki_losses)