In [2]:
import os
import sys

# Put the parent of the kernel's working directory on the import path
# (presumably the repo root, so that `models.autoencoder` in the next cell is
# importable — this depends on where the kernel was started; TODO confirm).
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
# Guard the append so re-running this cell does not pile up duplicate entries.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [3]:
from typing import List

import torch
from models.autoencoder import AutoEncoder, AutoEncoderConfig
from transformer_lens import HookedTransformer
from transformer_lens.utils import download_file_from_hf

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Maps a human-readable run name to the numeric checkpoint id used in the
# file names on the Hugging Face repo below.
VERSION_DICT = {"run1": 25, "run2": 47}
# Hugging Face repo that hosts the `<id>_cfg.json` / `<id>.pt` checkpoint files.
SAE_REPO_ID = "NeelNanda/sparse_autoencoder"


def stack_state_dicts(state_dicts: List[dict]) -> dict:
    """Stack same-keyed state dicts along a new leading (instance) dimension.

    Args:
        state_dicts: Non-empty list of state dicts sharing the same keys; the
            tensors under each key must have identical shapes.

    Returns:
        Dict mapping each key to ``torch.stack([sd[key] for sd in state_dicts])``,
        i.e. a tensor whose new dim 0 has size ``len(state_dicts)``.

    Raises:
        ValueError: If ``state_dicts`` is empty or the dicts' key sets differ.
    """
    if not state_dicts:
        raise ValueError("need at least one state dict to stack")
    keys = state_dicts[0].keys()
    for i, sd in enumerate(state_dicts[1:], start=1):
        if sd.keys() != keys:
            raise ValueError(
                f"state dict {i} has keys {sorted(sd)}, expected {sorted(keys)}"
            )
    return {k: torch.stack([sd[k] for sd in state_dicts]) for k in keys}


def load_autoencoder_from_huggingface(versions: List[str] = ["run1", "run2"]):
    """Download SAE checkpoints and load them into one multi-instance AutoEncoder.

    Each entry of ``versions`` is downloaded from ``SAE_REPO_ID`` (its JSON
    config and its torch state dict). The state dicts are stacked along a new
    leading dimension, one slot per version and in the order given, and loaded
    into an ``AutoEncoder`` configured with ``n_instances=len(versions)``.
    (Assumes AutoEncoder parameters carry a leading n_instances dim — TODO
    confirm against models.autoencoder.)

    Args:
        versions: Keys of ``VERSION_DICT`` to load. The list is only read,
            never mutated, so the list default is safe.

    Returns:
        The populated ``AutoEncoder``.

    Raises:
        ValueError: If ``versions`` is empty, or if the downloaded configs
            disagree on ``d_mlp`` / ``dict_mult``.
        KeyError: If any entry of ``versions`` is not in ``VERSION_DICT``.
    """
    if not versions:
        raise ValueError("versions must contain at least one entry")
    # Validate every name before downloading anything, so a typo fails fast.
    unknown = [v for v in versions if v not in VERSION_DICT]
    if unknown:
        raise KeyError(
            f"unknown SAE version(s) {unknown}; expected a subset of {sorted(VERSION_DICT)}"
        )

    sae_cfgs: List[dict] = []
    state_dicts: List[dict] = []
    for version in versions:
        version_id = VERSION_DICT[version]
        # Load both the metadata (JSON) and the weights (torch state dict).
        sae_cfgs.append(download_file_from_hf(SAE_REPO_ID, f"{version_id}_cfg.json"))
        state_dicts.append(
            download_file_from_hf(SAE_REPO_ID, f"{version_id}.pt", force_is_torch=True)
        )

    # All instances share one set of weight tensors, so their dims must agree.
    d_mlp = sae_cfgs[0]["d_mlp"]
    dict_mult = sae_cfgs[0]["dict_mult"]
    for version, sae_cfg in zip(versions, sae_cfgs):
        if (sae_cfg["d_mlp"], sae_cfg["dict_mult"]) != (d_mlp, dict_mult):
            raise ValueError(
                f"version {version!r} has d_mlp={sae_cfg['d_mlp']}, "
                f"dict_mult={sae_cfg['dict_mult']}; expected d_mlp={d_mlp}, "
                f"dict_mult={dict_mult}"
            )

    cfg = AutoEncoderConfig(
        n_instances=len(versions),
        n_input_ae=d_mlp,
        n_hidden_ae=d_mlp * dict_mult,
    )

    autoencoder = AutoEncoder(cfg)
    autoencoder.load_state_dict(stack_state_dicts(state_dicts))

    return autoencoder


# Load both published SAE runs into a single two-instance AutoEncoder and
# display its repr as the cell output.
autoencoder = load_autoencoder_from_huggingface(versions=["run1", "run2"])
autoencoder

AutoEncoder()