In [5]:
import os
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple


def get_model_configs() -> List[Dict[str, Any]]:
    """Define all model configurations based on the complete grid.

    Returns:
        One dict per model, each with the keys:
          - "model_name":  full hub identifier, e.g. "google/gemma-2-2b"
          - "model_short": the identifier without its organisation prefix
          - "layers":      layer indices, as strings
          - "widths":      width labels, as strings
          - "datasets":    list of (dataset identifier, split) tuples

        Every call builds fresh lists, so callers may mutate the result
        without affecting other configs or later calls.
    """
    # Every model is checked against the same (dataset, split) pairs; keeping
    # a single definition prevents the per-model copies from drifting apart.
    datasets = [
        ("Anthropic/election_questions", "test"),
        ("textdetox/multilingual_toxicity_dataset", "en"),
        ("AIM-Harvard/reject_prompts", "train"),
        ("jackhhao/jailbreak-classification", "test"),
        ("sorry-bench/sorry-bench-202406", "train"),
    ]

    # (model_name, layers, widths)
    grid = [
        ("google/gemma-2b", ["6", "12", "17"], ["16k"]),  # 2^14
        ("google/gemma-2-2b", ["5", "12", "19"], ["16k", "65k", "1m"]),  # 2^14, 2^16, 2^20
        ("google/gemma-2-9b", ["9", "20", "31"], ["16k", "131k", "1m"]),  # 2^14, 2^17, 2^20
        ("google/gemma-2-9b-it", ["9", "20", "31"], ["16k", "131k", "1m"]),  # 2^14, 2^17, 2^20
    ]

    return [
        {
            "model_name": model_name,
            # Short name is the part after the organisation prefix.
            "model_short": model_name.split("/")[-1],
            "layers": layers,
            "widths": widths,
            # Independent copy per config so mutating one does not leak into others.
            "datasets": list(datasets),
        }
        for model_name, layers, widths in grid
    ]


def format_missing_files(missing_files: List[str]) -> str:
    """Format missing files in a structured hierarchy.

    Each path is split into its components and classified by pattern:
      - model:   component starting with "gemma-"
      - width:   component starting with "width_" (prefix removed)
      - dataset: component containing one of the known dataset name markers
      - split:   component equal to "train", "test" or "en"
      - layer:   component starting with "layer_" (prefix removed)
    When several components match the same pattern, the last one wins.
    A component that cannot be found in a path is reported as "unknown".

    Args:
        missing_files: Paths of the files that were expected but not found.

    Returns:
        An indented Model / Width / Task / Layer listing joined by newlines,
        or an empty string when ``missing_files`` is empty. Models, widths and
        tasks are sorted alphabetically; layers are sorted numerically.
    """
    dataset_markers = (
        "election_questions",
        "multilingual_toxicity_dataset",
        "reject_prompts",
        "jailbreak-classification",
        "sorry-bench",
    )
    split_names = ("train", "test", "en")
    placeholder = "unknown"

    def _layer_order(name: str):
        # Numeric layers sort by value (5 < 12 < 19) rather than as text;
        # anything non-numeric is listed after them, alphabetically.
        if name.isdecimal():
            return (0, int(name), name)
        return (1, 0, name)

    # model -> width -> "dataset/split" -> set of layer names
    structure = defaultdict(lambda: defaultdict(lambda: defaultdict(set)))

    for path in missing_files:
        # Reset per path so a component missing here cannot silently inherit
        # the value parsed from the previous path (or be unbound on the first).
        model = width = dataset = split = layer = placeholder

        for part in Path(path).parts:
            if part.startswith("gemma-"):
                model = part
            elif part.startswith("width_"):
                width = part.replace("width_", "")
            elif any(marker in part for marker in dataset_markers):
                dataset = part
            elif part in split_names:
                split = part
            elif part.startswith("layer_"):
                layer = part.replace("layer_", "")

        structure[model][width][f"{dataset}/{split}"].add(layer)

    # Format the output
    output = []
    for model in sorted(structure):
        output.append(f"Model: {model}")
        for width in sorted(structure[model]):
            output.append(f"  Width: {width}")
            for task in sorted(structure[model][width]):
                output.append(f"    Task: {task}")
                for layer in sorted(structure[model][width][task], key=_layer_order):
                    output.append(f"      Layer: {layer}")

    return "\n".join(output)


def verify_outputs(base_dir: str = "./outputs") -> Tuple[List[str], List[str]]:
    """
    Check the output tree for one ``sample_0.npz`` per grid combination.

    The grid comes from ``get_model_configs()``; for every
    (model, width, layer, dataset/split) combination the expected file is

        <base_dir>/<model_short>/width_<width>/<model_name with "/" -> "_">/
        <dataset with "/" -> "_">/<split>/layer_<layer>/<width>/sample_0.npz

    A summary is printed, followed by a structured listing of the missing
    combinations when there are any.

    Args:
        base_dir: Directory that holds the outputs; made absolute before use.

    Returns:
        Tuple of (found_files, missing_files), each a list of absolute paths
        in grid iteration order.
    """
    grid = get_model_configs()

    root = os.path.abspath(base_dir)
    print(f"Checking files in: {root}")
    print("=" * 80)

    # Enumerate every expected file: model -> width -> layer -> (dataset, split).
    candidates = [
        os.path.join(
            root,
            cfg["model_short"],
            f"width_{width}",
            cfg["model_name"].replace("/", "_"),
            dataset.replace("/", "_"),  # filesystem-friendly dataset name
            split,
            f"layer_{layer}",
            width,
            "sample_0.npz",
        )
        for cfg in grid
        for width in cfg["widths"]
        for layer in cfg["layers"]
        for dataset, split in cfg["datasets"]
    ]

    # Partition the candidates by whether they are present on disk.
    found_files = []
    missing_files = []
    for candidate in candidates:
        bucket = found_files if os.path.exists(candidate) else missing_files
        bucket.append(candidate)

    # Print summary
    print("\nSummary:")
    print(f"Total configurations to check: {len(candidates)}")
    print(f"Found: {len(found_files)}")
    print(f"Missing: {len(missing_files)}")

    # Only show the hierarchy when something is actually missing.
    if missing_files:
        print("\nMissing configurations:")
        print(format_missing_files(missing_files))

    return found_files, missing_files

In [6]:
# Check the expected grid under ./outputs (made absolute against the current working directory).
found, missing = verify_outputs("./outputs")

Checking files in: /home/jg1223/sae_llava/src/outputs

Summary:
Total configurations to check: 150
Found: 91
Missing: 59

Missing configurations:
Model: gemma-2-2b
  Width: 1m
    Task: AIM-Harvard_reject_prompts/train
      Layer: 12
      Layer: 19
      Layer: 5
    Task: Anthropic_election_questions/test
      Layer: 12
      Layer: 19
      Layer: 5
    Task: jackhhao_jailbreak-classification/test
      Layer: 12
      Layer: 19
      Layer: 5
    Task: sorry-bench_sorry-bench-202406/train
      Layer: 12
      Layer: 19
      Layer: 5
    Task: textdetox_multilingual_toxicity_dataset/en
      Layer: 12
      Layer: 19
      Layer: 5
Model: gemma-2-9b
  Width: 131k
    Task: jackhhao_jailbreak-classification/test
      Layer: 31
    Task: sorry-bench_sorry-bench-202406/train
      Layer: 20
      Layer: 31
  Width: 1m
    Task: AIM-Harvard_reject_prompts/train
      Layer: 20
      Layer: 31
      Layer: 9
    Task: Anthropic_election_questions/test
      Layer: 20
      Layer: 31