In [None]:
# Turn on IPython's autoreload extension. Mode 2 lets edits to imported
# modules be picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import torch
from huggingface_hub import snapshot_download

import sae_bench.custom_saes.batch_topk_sae as batch_topk_sae
import sae_bench.custom_saes.gated_sae as gated_sae
import sae_bench.custom_saes.jumprelu_sae as jumprelu_sae
import sae_bench.custom_saes.relu_sae as relu_sae
import sae_bench.custom_saes.topk_sae as topk_sae

# Registry mapping a trainer name (as it appears in a checkpoint path) to the
# function that loads an SAE produced by that trainer.
# Insertion order is kept identical to the explicit literal it replaces.
TRAINER_LOADERS = {
    # These three trainer names all resolve to the same ReLU loader.
    **dict.fromkeys(
        ("StandardTrainer", "StandardTrainerAprilUpdate", "PAnnealTrainer"),
        relu_sae.load_dictionary_learning_relu_sae,
    ),
    "TopKTrainer": topk_sae.load_dictionary_learning_topk_sae,
    "JumpReluTrainer": jumprelu_sae.load_dictionary_learning_jump_relu_sae,
    # NOTE: "TopKTrainer" is a substring of "BatchTopKTrainer".
    "BatchTopKTrainer": batch_topk_sae.load_dictionary_learning_batch_topk_sae,
    "GatedSAETrainer": gated_sae.load_dictionary_learning_gated_sae,
}


def get_all_hf_repo_autoencoders(
    repo_id: str, download_location: str = "downloaded_saes"
) -> list[str]:
    """List every directory of a Hugging Face repo that holds a ``config.json``.

    Only files matching ``*config.json`` are downloaded; they are placed under
    ``download_location/<repo_id with "/" replaced by "_">``.

    Args:
        repo_id: Hugging Face repository id, e.g. ``"user/repo"``.
        download_location: Local parent directory for the downloaded configs.

    Returns:
        Sorted list of directory paths relative to the repo root. Each entry
        uses ``/`` separators and ends with ``/`` so a filename can be appended
        directly (``f"{location}ae.pt"``); a ``config.json`` at the repo root
        yields ``""``.
    """
    download_location = os.path.join(download_location, repo_id.replace("/", "_"))
    config_dir = snapshot_download(
        repo_id,
        allow_patterns=["*config.json"],
        local_dir=download_location,
        force_download=False,
    )

    repo_locations = []

    for root, _, files in os.walk(config_dir):
        if "config.json" not in files:
            continue

        # Compute the path relative to the directory we actually walked rather
        # than string-splitting on `download_location`: the returned directory
        # may be a resolved/absolute form that no longer contains the
        # caller-supplied path verbatim (e.g. "./x", symlinks, "\\" separators).
        relative_dir = os.path.relpath(root, config_dir)

        if relative_dir == os.curdir:
            # config.json sits directly at the repo root.
            repo_locations.append("")
        else:
            # Always emit "/"-separated paths, independent of the local OS.
            repo_locations.append(relative_dir.replace(os.sep, "/") + "/")

    # os.walk order depends on the filesystem; sort so results (and in
    # particular the first entry) are reproducible across machines.
    return sorted(repo_locations)


# Discover every SAE directory in the test repository (config files only).
repo_id = "adamkarvonen/sae_test"
locations = get_all_hf_repo_autoencoders(repo_id=repo_id)

print(locations)

In [None]:
import sae_bench.custom_saes.base_sae as base_sae

# Settings used below when loading and testing a single SAE.
model_name = "EleutherAI/pythia-70m-deduped"
layer = 3  # layer index handed to the SAE loader

dtype = torch.float32

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


def load_dictionary_learning_sae(
    repo_id: str, location: str, layer: int, model_name, device: str, dtype: torch.dtype
) -> "base_sae.BaseSAE":
    """Load an SAE checkpoint with the loader matching its trainer name.

    The trainer is inferred from ``location``: the loader registered in
    ``TRAINER_LOADERS`` under the longest trainer name contained in
    ``location`` is used.

    Args:
        repo_id: Hugging Face repository id that holds the checkpoint.
        location: Path of the checkpoint file inside the repo; must contain
            one of the trainer names registered in ``TRAINER_LOADERS``.
        layer: Layer index forwarded to the loader.
        model_name: Model name forwarded to the loader.
        device: Device string forwarded to the loader.
        dtype: Torch dtype forwarded to the loader.

    Returns:
        Whatever the selected loader returns (the loaded SAE).

    Raises:
        ValueError: If no registered trainer name occurs in ``location``.
    """
    # Try the most specific (longest) names first. Plain dict order is not
    # safe here: "TopKTrainer" is a substring of "BatchTopKTrainer", so a
    # BatchTopK checkpoint would otherwise be handed to the TopK loader.
    # `sorted` is stable, so equal-length names keep their registry order.
    for trainer_name in sorted(TRAINER_LOADERS, key=len, reverse=True):
        if trainer_name in location:
            loader = TRAINER_LOADERS[trainer_name]
            return loader(
                repo_id=repo_id,
                filename=location,
                layer=layer,
                model_name=model_name,
                device=device,
                dtype=dtype,
            )

    raise ValueError(f"Could not find a loader for {location}")


# Load the checkpoint ("ae.pt") from the first discovered location and run
# the SAE's built-in self-test against the model.
sae = load_dictionary_learning_sae(
    repo_id=repo_id,
    location=f"{locations[0]}ae.pt",
    layer=layer,
    model_name=model_name,
    device=device,
    dtype=dtype,
)
sae.test_sae(model_name)