In [None]:
# To use: Move into root dir, run all.

In [None]:
from sae_lens.toolkit.pretrained_sae_loaders import (
    SAEConfigLoadOptions,
    get_sae_config,
)

from sae_bench.sae_bench_utils.sae_selection_utils import (
    select_saes_multiple_patterns,
)

In [None]:
import os

from huggingface_hub import snapshot_download

# Dataset repo to mirror, and the local folder the snapshot is written into.
hf_repo_id = "neuronpedia/sae-evals"
local_dir = "./temp_sae_evals"

# Create the destination up front; exist_ok makes this a no-op on re-runs.
os.makedirs(local_dir, exist_ok=True)

# No allow/ignore patterns are passed here, so nothing is filtered out.
# NOTE(review): the following cell downloads the same repo again with an
# ignore pattern — confirm whether both cells are meant to run.
snapshot_download(repo_type="dataset", repo_id=hf_repo_id, local_dir=local_dir)

In [None]:
import os

from huggingface_hub import snapshot_download

# Dataset repo to mirror, and the local folder the snapshot is written into.
hf_repo_id = "neuronpedia/sae-evals"
local_dir = "./temp_sae_evals"

# Create the destination up front; exist_ok makes this a no-op on re-runs.
os.makedirs(local_dir, exist_ok=True)

# Same download as above, except that anything whose path matches
# "*autointerp_with_generations*" is left out.
snapshot_download(
    repo_type="dataset",
    repo_id=hf_repo_id,
    local_dir=local_dir,
    ignore_patterns=["*autointerp_with_generations*"],
)

In [None]:
# One regex per SAE release that should be selected.
sae_regex_patterns = [
    r"sae_bench_gemma-2-2b_topk_width-2pow12_date-1109",
    r"sae_bench_gemma-2-2b_topk_width-2pow14_date-1109",
    r"sae_bench_gemma-2-2b_topk_width-2pow16_date-1109",
    r"sae_bench_gemma-2-2b_vanilla_width-2pow12_date-1109",
    r"sae_bench_gemma-2-2b_vanilla_width-2pow14_date-1109",
    r"sae_bench_gemma-2-2b_vanilla_width-2pow16_date-1109",
    r"(gemma-scope-2b-pt-res)",
    r"(gemma-scope-9b-pt-res)",
    r"(gemma-scope-2b-pt-res-canonical)",
    r"(gemma-scope-9b-pt-res-canonical)",
    r"sae_bench_pythia70m_sweep_gated_ctx128_0730",
    r"sae_bench_pythia70m_sweep_panneal_ctx128_0730",
    r"sae_bench_pythia70m_sweep_standard_ctx128_0712",
    r"sae_bench_pythia70m_sweep_topk_ctx128_0730",
]

# Give every release pattern a catch-all SAE-id pattern, which keeps
# checkpoints in the selection (not relevant to Gemma-Scope).
sae_block_pattern = [".*" for _ in sae_regex_patterns]

# Alternative that excludes checkpoints (needs a `layer` variable):
# sae_block_pattern = [
#     # rf".*blocks\.{layer}(?!.*step).*",
#     # rf".*blocks\.{layer}(?!.*step).*",
#     rf".*blocks\.{layer}(?!.*step).*",
#     rf".*blocks\.{layer}(?!.*step).*",
#     # rf".*layer_({layer}).*(16k).*", # For Gemma-Scope
# ]

# Both lists are handed to the selector together and must be the same length.
assert len(sae_regex_patterns) == len(sae_block_pattern)

# The loops below unpack each entry as a (sae_release, sae_id) pair.
selected_saes = select_saes_multiple_patterns(sae_regex_patterns, sae_block_pattern)

In [None]:
import json
import os

from tqdm import tqdm

# Only the "core" eval results are handled by this cell.
folders = [f"{local_dir}/core"]

# `total` counts every (folder, SAE) pair visited; `total_updated` counts the
# result files that were actually rewritten.
total = 0
total_updated = 0

for folder in tqdm(folders):
    for sae_release, sae_id in tqdm(selected_saes):
        total += 1

        # SAE ids that do not contain "blocks" are skipped by this cell.
        if "blocks" not in sae_id:
            continue

        # The existing file name has every "." of the SAE id replaced by "_";
        # the new file name keeps the SAE id verbatim.
        old_filename = f"{folder}/{sae_release}/{sae_release}_{sae_id.replace('.', '_')}_128_Skylion007_openwebtext.json"
        new_filename = f"{folder}/{sae_release}/{sae_release}_{sae_id}_128_Skylion007_openwebtext.json"

        if not os.path.exists(old_filename):
            continue

        sae_cfg = get_sae_config(
            sae_release,
            sae_id,
            options=SAEConfigLoadOptions(),
        )

        with open(old_filename) as f:
            eval_results = json.load(f)

        # Attach the SAE config to the stored results.
        eval_results["sae_cfg_dict"] = sae_cfg

        with open(new_filename, "w") as f:
            json.dump(eval_results, f, indent=4)

        # Remove the old file only when it is a different path. If the SAE id
        # contains no ".", both names are identical and an unconditional
        # remove would delete the file that was just written.
        if old_filename != new_filename:
            os.remove(old_filename)

        total_updated += 1

print(total, total_updated)

In [None]:
import json
import os

from tqdm import tqdm

# Every eval type whose results are stored as "<release>_<sae id>_eval_results.json".
folders = [
    f"{local_dir}/absorption",
    f"{local_dir}/autointerp",
    f"{local_dir}/scr",
    f"{local_dir}/sparse_probing",
    f"{local_dir}/tpp",
    f"{local_dir}/unlearning",
]

# `total` counts every (folder, SAE) pair visited; `total_updated` counts the
# result files that were rewritten.
total = 0
total_updated = 0

for folder in tqdm(folders):
    for sae_release, sae_id in tqdm(selected_saes):
        total += 1

        # Any "/" in the SAE id is flattened to "_" in the file name.
        old_filename = f"{folder}/{sae_release}/{sae_release}_{sae_id.replace('/', '_')}_eval_results.json"

        # Only SAEs that already have a result file in this folder are touched.
        if os.path.exists(old_filename):
            sae_cfg = get_sae_config(
                sae_release,
                sae_id,
                options=SAEConfigLoadOptions(),
            )

            with open(old_filename) as f:
                eval_results = json.load(f)

            # Attach the SAE config and write the file back in place.
            eval_results["sae_cfg_dict"] = sae_cfg

            with open(old_filename, "w") as f:
                json.dump(eval_results, f, indent=4)

            total_updated += 1

print(total, total_updated)

In [None]:
import os


def inspect_local_directory(directory, chunk_size=1024 * 1024):
    """Walk ``directory`` and report every file that cannot be read end to end.

    Each file found by ``os.walk`` is opened in binary mode and read through
    completely. The data is discarded; only failures are reported.

    Args:
        directory: Path of the directory to inspect.
        chunk_size: Number of bytes requested per read. Files are streamed in
            chunks of this size so large files are never held in memory whole.

    Returns:
        None. Findings are printed: a message when ``directory`` is missing or
        is not a directory, otherwise a header line followed by one
        "Error accessing file" line per file that failed.
    """
    if not os.path.exists(directory):
        print(f"Directory does not exist: {directory}")
        return

    if not os.path.isdir(directory):
        print(f"Path is not a directory: {directory}")
        return

    print(f"Inspecting files in directory: {directory}")
    for root, _dirs, files in os.walk(directory):
        for file in files:
            file_path = os.path.join(root, file)
            try:
                with open(file_path, "rb") as f:
                    # Read the whole file, one chunk at a time, so a read
                    # error anywhere in the file surfaces here.
                    while f.read(chunk_size):
                        pass
            except Exception as e:
                # Best effort: report the failure and carry on with the rest.
                print(f"Error accessing file {file_path}: {e}")


# Inspect the directory currently held in `local_dir`. To inspect a different
# directory, reassign it first, e.g.:
# local_dir = "your_directory_path_here"
inspect_local_directory(local_dir)

In [None]:
# from huggingface_hub import HfApi
# api = HfApi()

# api.upload_folder(
#     folder_path="temp_sae_evals_2",
#     path_in_repo="",
#     repo_id="adamkarvonen/sae_bench_results",
#     repo_type="dataset",
#     # allow_patterns="*eval_results.json"
#     ignore_patterns=[".DS_Store"]
# )

In [None]:
from huggingface_hub import HfApi

# Hub client used for the upload below.
api = HfApi()

# Local folder whose contents are uploaded.
test_dir = "eval_results"

# Upload `test_dir` into the dataset repo, skipping macOS metadata and any
# git bookkeeping found inside the folder.
api.upload_folder(
    repo_id="adamkarvonen/sae_bench_results",
    repo_type="dataset",
    folder_path=test_dir,
    path_in_repo="",
    # allow_patterns="*eval_results.json"
    ignore_patterns=[".DS_Store", ".git", ".git/**"],
)

In [None]:
# Purpose: the result files are all dumped into the same results folder;
# sort them into one sub-folder per SAE release name.

import os
import shutil


def organize_files(prefixes=None, directory="."):
    """Sort ``.json`` files in ``directory`` into one sub-folder per prefix.

    A sub-folder named after each prefix is created inside ``directory``
    (whether or not any file ends up in it). Every regular file directly in
    ``directory`` whose name ends in ".json" and starts with one of the
    prefixes is moved into the folder of the first matching prefix. Files
    matching no prefix stay where they are.

    Args:
        prefixes: File-name prefixes to sort by. Defaults to the two
            gemma-2-2b 2pow16 SAE Bench release names.
        directory: Folder to organize. Defaults to the current directory.

    Returns:
        None. One "Moved ..." line is printed per file that was moved.
    """
    if prefixes is None:
        prefixes = [
            "sae_bench_gemma-2-2b_topk_width-2pow16_date-1109",
            "sae_bench_gemma-2-2b_vanilla_width-2pow16_date-1109",
        ]

    # exist_ok avoids the check-then-create race and keeps re-runs harmless.
    for prefix in prefixes:
        os.makedirs(os.path.join(directory, prefix), exist_ok=True)

    # Snapshot the candidates first; only regular files are considered, so a
    # directory whose name happens to end in ".json" is never moved.
    files = [
        f
        for f in os.listdir(directory)
        if f.endswith(".json") and os.path.isfile(os.path.join(directory, f))
    ]

    for file in files:
        for prefix in prefixes:
            if file.startswith(prefix):
                shutil.move(
                    os.path.join(directory, file),
                    os.path.join(directory, prefix, file),
                )
                print(f"Moved {file} to {prefix}/")
                # First matching prefix wins.
                break


if __name__ == "__main__":
    organize_files()

In [None]:
# purpose: remove the llm generations, which is over 99% of the file size

import json
import os


def process_json_files(directory):
    """Strip the "eval_result_unstructured" key from JSON files under ``directory``.

    Walks ``directory`` recursively. Every file ending in ".json" is parsed;
    when its top-level value is an object containing the key
    "eval_result_unstructured", that key is deleted and the file is written
    back. Files that do not contain the key are left untouched on disk.

    Args:
        directory: Root folder to walk.

    Returns:
        The number of files from which the key was removed.
    """
    count = 0
    for root, _, files in os.walk(directory):
        for file in files:
            if not file.endswith(".json"):
                continue

            filepath = os.path.join(root, file)
            try:
                with open(filepath) as f:
                    data = json.load(f)

                # Only a JSON object can carry the key; lists and scalars are
                # left alone instead of being treated as dict-like.
                if isinstance(data, dict) and "eval_result_unstructured" in data:
                    del data["eval_result_unstructured"]

                    # Write back only when something actually changed, so
                    # unaffected files keep their content and formatting.
                    with open(filepath, "w") as f:
                        json.dump(data, f)
                    count += 1

                print(f"Processed: {filepath}")

            except Exception as e:
                # Best effort: report the file and continue with the rest.
                print(f"Error processing {filepath}: {str(e)}")

    return count


# Process files starting from the current directory; process_json_files walks
# it recursively, so JSON files in every sub-folder are included.
starting_dir = "."
# The returned count covers only files from which the key was removed.
files_modified = process_json_files(starting_dir)
print(f"\nCompleted! Modified {files_modified} files.")

In [None]:
# This cell is to add "training_tokens" to the sae_cfg_dict for all sae bench sae results

import json
import os
import re

from tqdm import tqdm

from sae_bench.sae_bench_utils.sae_selection_utils import select_saes_multiple_patterns

# Eval-result folders scanned by the loop at the end of this cell. They are
# built from whatever `local_dir` holds at this point.
folders = [
    f"{local_dir}/absorption",
    f"{local_dir}/autointerp",
    f"{local_dir}/core",
    f"{local_dir}/scr",
    f"{local_dir}/sparse_probing",
    f"{local_dir}/tpp",
    f"{local_dir}/unlearning",
]

# SAE Bench releases only (no Gemma-Scope entries in this list).
sae_regex_patterns = [
    r"sae_bench_gemma-2-2b_topk_width-2pow12_date-1109",
    r"sae_bench_gemma-2-2b_topk_width-2pow14_date-1109",
    r"sae_bench_gemma-2-2b_topk_width-2pow16_date-1109",
    r"sae_bench_gemma-2-2b_vanilla_width-2pow12_date-1109",
    r"sae_bench_gemma-2-2b_vanilla_width-2pow14_date-1109",
    r"sae_bench_gemma-2-2b_vanilla_width-2pow16_date-1109",
    r"sae_bench_pythia70m_sweep_gated_ctx128_0730",
    r"sae_bench_pythia70m_sweep_panneal_ctx128_0730",
    r"sae_bench_pythia70m_sweep_standard_ctx128_0712",
    r"sae_bench_pythia70m_sweep_topk_ctx128_0730",
]

# Give every release pattern a catch-all SAE-id pattern, which keeps
# checkpoints in the selection (not relevant to Gemma-Scope).
sae_block_pattern = [".*" for _ in sae_regex_patterns]

# Alternative that excludes checkpoints (needs a `layer` variable):
# sae_block_pattern = [
#     # rf".*blocks\.{layer}(?!.*step).*",
#     # rf".*blocks\.{layer}(?!.*step).*",
#     rf".*blocks\.{layer}(?!.*step).*",
#     rf".*blocks\.{layer}(?!.*step).*",
#     # rf".*layer_({layer}).*(16k).*", # For Gemma-Scope
# ]

# Both lists are handed to the selector together and must be the same length.
assert len(sae_regex_patterns) == len(sae_block_pattern)

selected_saes = select_saes_multiple_patterns(sae_regex_patterns, sae_block_pattern)

# NOTE(review): `folders` above was already built from the previous value of
# `local_dir`, so this reassignment only takes effect for code that runs
# afterwards (or when this cell is re-run) — confirm the intended order.
local_dir = "."


def get_sae_bench_train_tokens(sae_release: str, sae_id: str) -> int:
    """Return the hardcoded number of training tokens for an SAE Bench SAE.

    This is for SAE Bench internal use. The SAE cfg does not contain the
    number of training tokens, so it is derived here as ``steps * batch_size``.

    Args:
        sae_release: Release name; must contain "sae_bench".
        sae_id: SAE id within the release. When it contains "step" the SAE is
            treated as a checkpoint and the step count is parsed from
            ``step_<number>`` in the id.

    Returns:
        The step count multiplied by the batch size (4096 for releases
        containing "pythia", 2048 otherwise).

    Raises:
        ValueError: If the release is not an SAE Bench release, if a final
            (non-checkpoint) SAE belongs to a release that is not recognized,
            or if a checkpoint id has no ``step_<number>`` part.
    """
    if "sae_bench" not in sae_release:
        raise ValueError("This function is only for SAE Bench releases")

    is_pythia = "pythia" in sae_release
    batch_size = 4096 if is_pythia else 2048

    # Checkpoints carry their own step count in the id.
    if "step" in sae_id:
        match = re.search(r"step_(\d+)", sae_id)
        if match is None:
            raise ValueError("No step match found")
        return int(match.group(1)) * batch_size

    # Final SAEs: total number of steps is fixed per release family.
    if is_pythia:
        steps = 48828
    elif "2pow14" in sae_release:
        steps = 146484
    # Each width needs its own membership test: `"2pow12" or "2pow16" in x`
    # is always truthy, which made the error branch below unreachable.
    elif "2pow12" in sae_release or "2pow16" in sae_release:
        steps = 97656
    else:
        raise ValueError(f"sae release {sae_release} not recognized")

    return steps * batch_size


# `total` counts every (folder, SAE) pair visited; `total_updated` counts the
# result files that were rewritten.
total = 0
total_updated = 0

for folder in tqdm(folders):
    for sae_release, sae_id in tqdm(selected_saes):
        total += 1

        # Any "/" in the SAE id is flattened to "_" in the file name.
        # NOTE(review): the earlier core cell writes files ending in
        # "_128_Skylion007_openwebtext.json", so nothing under "core" matches
        # the name built here — confirm whether core results should be covered.
        old_filename = f"{folder}/{sae_release}/{sae_release}_{sae_id.replace('/', '_')}_eval_results.json"

        # Only SAEs that already have a result file in this folder are touched.
        if os.path.exists(old_filename):
            with open(old_filename) as f:
                eval_results = json.load(f)

            # Requires an existing "sae_cfg_dict" entry in the stored results.
            eval_results["sae_cfg_dict"]["training_tokens"] = get_sae_bench_train_tokens(
                sae_release, sae_id
            )

            # Write the augmented results back in place.
            with open(old_filename, "w") as f:
                json.dump(eval_results, f, indent=4)

            total_updated += 1

print(total, total_updated)