# ART-Organized SAE Feature Taxonomy

Cluster pre-trained SAE decoder weight vectors from **Pythia** models using Adaptive Resonance Theory (ART) modules to build hierarchical feature taxonomies.

**Pipeline:** SAE W_dec → junk-feature filter (logit lens) → optional PCA → ART clustering → hierarchical taxonomy via SMART

**Key idea:** ART's vigilance parameter provides a single, principled knob for controlling cluster granularity — something SAEs lack. SMART builds a hierarchy by stacking multiple vigilance levels.

**Supported models:** Pythia-70M (512-dim, 32k features), Pythia-410M (1024-dim, 65k features)

In [1]:
# --- Setup ---
import sys, os

# Colab is detected by checking whether its package is already imported.
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    # Clone or pull the repo, then cd into it
    if os.path.exists("art_utils.py"):
        # Already inside the repo (cell re-run)
        !git pull
    elif os.path.exists("ART-SAE-taxonomy"):
        # The repo directory exists below the current directory: enter it and update.
        os.chdir("ART-SAE-taxonomy")
        !git pull
    else:
        # No checkout found: clone, then enter the new directory.
        !git clone https://github.com/syre-ai/ART-SAE-taxonomy.git
        os.chdir("ART-SAE-taxonomy")
    # Dependencies are installed only on Colab; artlib is pinned to 0.1.7.
    !pip install -q artlib==0.1.7 eai-sparsify transformers torch umap-learn plotly tqdm pandas

import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import pandas as pd
from sklearn.decomposition import PCA
from tqdm.auto import tqdm

# Force-reload local modules so git pull changes take effect
import importlib
import gpu_fuzzy_art, art_utils
importlib.reload(gpu_fuzzy_art)
importlib.reload(art_utils)

from art_utils import (
    create_art_module,
    create_smart_model,
    extract_hierarchy_labels,
    list_available_modules,
)

print("Setup complete.")

Setup complete.


In [8]:
# ============================================================
# CONFIGURATION — edit this cell to change experiments
# ============================================================

# --- Model ---
# LM whose tokenizer, final layer norm and unembedding are used for logit lens interpretation.
LM_MODEL = "EleutherAI/pythia-410m"     # LM for logit lens interpretation
# SAE repo whose decoder weights (W_dec) are clustered.
SAE_MODEL = "EleutherAI/sae-pythia-410m-65k"  # SAE decoder weights to cluster
# Selects the SAE via the key "layers.{LAYER_IDX}" or "layers.{LAYER_IDX}.mlp".
LAYER_IDX = 12             # Middle layer (410M has 24 layers, 70M has 6)

# --- Clustering ---
MODULE_NAME = "GPUFuzzyART"  # GPU-accelerated FuzzyART (see list_available_modules())
# Passed as `overrides=` to create_art_module / create_smart_model.
# The vigilance sweep cell replaces "rho" with each swept value.
PARAM_OVERRIDES = {"rho": 0.15}       # e.g. {"rho": 0.8}
# PCA is applied only when this is an int smaller than the decoder's hidden dimension.
PCA_DIMS = None            # Set to int (e.g. 50, 100, 200) to reduce dims; None = use raw dims
# When False, all non-junk features are kept and shuffled instead.
QUICK_TEST = True          # Subsample to 1000 features for fast iteration
# Seeds the subsample/shuffle RNG, PCA and UMAP.
RANDOM_SEED = 42

In [3]:
# --- Checkpoint Restore (skip Cells 4-6 if checkpoint exists) ---
# Set USE_CHECKPOINT = True to load saved pipeline state instead of recomputing.
# This saves ~15 minutes by skipping SAE loading, logit lens, and junk filtering.

import os, torch

USE_CHECKPOINT = False  # Set True to restore from checkpoint
CHECKPOINT_PATH = "checkpoints/pipeline_state.pt"


def checkpoint_config_mismatches(ckpt_cfg, *, lm_model, sae_model, layer_idx):
    """Return one message per checkpoint setting that differs from the current config.

    Args:
        ckpt_cfg: The "config" dict stored in the checkpoint.
        lm_model: Currently configured LM name.
        sae_model: Currently configured SAE name.
        layer_idx: Currently configured layer index.

    Returns:
        A list of human-readable mismatch messages; empty when everything
        matches. A key missing from ``ckpt_cfg`` is reported as a mismatch,
        because the checkpoint cannot be shown to belong to this config.
    """
    expected = (
        ("LM", "lm_model", lm_model),
        ("SAE", "sae_model", sae_model),
        ("layer", "layer_idx", layer_idx),
    )
    problems = []
    for label, key, current in expected:
        saved = ckpt_cfg.get(key, "<missing>")
        if saved != current:
            problems.append(f"Checkpoint {label} mismatch: {saved} vs {current}")
    return problems


if USE_CHECKPOINT and os.path.exists(CHECKPOINT_PATH):
    print(f"Loading checkpoint from {CHECKPOINT_PATH}...")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only restore checkpoint files written by this notebook's save cell.
    ckpt = torch.load(CHECKPOINT_PATH, weights_only=False)

    # Verify config matches. The layer index is checked too: W_dec_full and
    # top5_ids are layer-specific, so restoring another layer's checkpoint
    # would silently cluster the wrong features. A real exception is raised
    # (not `assert`) so the check survives `python -O`.
    ckpt_cfg = ckpt["config"]
    config_problems = checkpoint_config_mismatches(
        ckpt_cfg, lm_model=LM_MODEL, sae_model=SAE_MODEL, layer_idx=LAYER_IDX
    )
    if config_problems:
        raise ValueError("; ".join(config_problems))

    # Restore tensors
    W_dec_full = ckpt["W_dec_full"]
    top5_ids = ckpt["top5_ids"]
    unembed_weights = ckpt["unembed_weights"]
    final_ln_weight = ckpt["final_ln_weight"]
    final_ln_bias = ckpt["final_ln_bias"]
    valid_indices = ckpt["valid_indices"]

    # Rebuild final_ln as a simple module for logit lens
    from transformers import AutoTokenizer, AutoModelForCausalLM
    tokenizer = AutoTokenizer.from_pretrained(LM_MODEL)
    # Create a minimal layer norm with saved weights
    # NOTE(review): uses nn.LayerNorm's default eps — confirm it matches the LM's final layer norm.
    import torch.nn as nn
    final_ln = nn.LayerNorm(final_ln_weight.shape[0])
    final_ln.weight.data = final_ln_weight
    final_ln.bias.data = final_ln_bias
    final_ln.eval()

    # Rebuild filtered/subsampled data with the same seeded RNG calls as the
    # filter cell, so the selected features are identical.
    import numpy as np
    W_dec_raw = W_dec_full.numpy()
    kept_indices = valid_indices.numpy()
    W_dec_filtered = W_dec_raw[kept_indices]
    rng = np.random.default_rng(RANDOM_SEED)
    if QUICK_TEST:
        n_sample = min(1000, W_dec_filtered.shape[0])
        sub_idx = rng.choice(W_dec_filtered.shape[0], size=n_sample, replace=False)
        idx_original = kept_indices[sub_idx]
        W_dec = W_dec_filtered[sub_idx]
    else:
        perm = rng.permutation(W_dec_filtered.shape[0])
        idx_original = kept_indices[perm]
        W_dec = W_dec_filtered[perm]
    # idx holds original feature indices (rows of W_dec_full / top5_ids)
    idx = idx_original

    # Run preprocessing (PCA) inline. This mirrors the preprocessing cell:
    # fit max(PCA_DIMS, 200) components (capped by the data size), then keep
    # the leading PCA_DIMS, so both paths produce the same X_reduced.
    if PCA_DIMS is not None and PCA_DIMS < W_dec.shape[1]:
        from sklearn.decomposition import PCA
        max_components = min(W_dec.shape[0], W_dec.shape[1])
        pca = PCA(n_components=min(max_components, max(PCA_DIMS, 200)), random_state=RANDOM_SEED)
        pca.fit(W_dec)
        X_reduced = pca.transform(W_dec)[:, :PCA_DIMS]
    else:
        X_reduced = W_dec.copy()
    data_dim = X_reduced.shape[1]

    print(f"Checkpoint restored: {W_dec_full.shape[0]} features, "
          f"{len(kept_indices)} valid, {X_reduced.shape[0]} selected, "
          f"{data_dim}-dim")
    print(f"Config: {ckpt_cfg}")
    print(">>> Skip Cells 4-7 and proceed to Cell 8 (Create ART module)")
else:
    if USE_CHECKPOINT:
        print(f"Checkpoint not found at {CHECKPOINT_PATH} — run full pipeline (Cells 4-7)")
    else:
        print("Checkpoint restore disabled (USE_CHECKPOINT=False) — run full pipeline")

Checkpoint restore disabled (USE_CHECKPOINT=False) — run full pipeline


In [4]:
# --- Load SAE decoder weights ---
from sparsify import Sae

saes = Sae.load_many(SAE_MODEL)

# Detect key format: "layers.N" (70M) vs "layers.N.mlp" (410M)
candidate_keys = (f"layers.{LAYER_IDX}", f"layers.{LAYER_IDX}.mlp")
layer_key = next((key for key in candidate_keys if key in saes), None)
if layer_key is None:
    # Fail with the list of valid keys instead of a bare KeyError on the
    # second candidate, so a wrong LAYER_IDX is easy to diagnose.
    raise KeyError(
        f"No SAE for layer {LAYER_IDX} in {SAE_MODEL}; tried {list(candidate_keys)}. "
        f"Available keys: {sorted(saes.keys())}"
    )
# .cpu() is a no-op for CPU tensors and keeps .numpy() working if the SAE
# was loaded onto another device.
W_dec_raw = saes[layer_key].W_dec.detach().cpu().numpy()
print(f"Loaded W_dec from {SAE_MODEL}")
print(f"  Layer: {layer_key}, shape: {W_dec_raw.shape} (features × hidden_dim)")
print(f"  Available keys: {sorted(saes.keys())}")

Fetching 49 files:   0%|          | 0/49 [00:00<?, ?it/s]

Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Loaded W_dec from EleutherAI/sae-pythia-410m-65k
  Layer: layers.12.mlp, shape: (65536, 1024) (features × hidden_dim)
  Available keys: ['layers.0.mlp', 'layers.1.mlp', 'layers.10.mlp', 'layers.11.mlp', 'layers.12.mlp', 'layers.13.mlp', 'layers.14.mlp', 'layers.15.mlp', 'layers.16.mlp', 'layers.17.mlp', 'layers.18.mlp', 'layers.19.mlp', 'layers.2.mlp', 'layers.20.mlp', 'layers.21.mlp', 'layers.22.mlp', 'layers.23.mlp', 'layers.3.mlp', 'layers.4.mlp', 'layers.5.mlp', 'layers.6.mlp', 'layers.7.mlp', 'layers.8.mlp', 'layers.9.mlp']


In [5]:
# --- Filter junk features + subsample ---
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer and LM; eval mode, float32 weights.
tokenizer = AutoTokenizer.from_pretrained(LM_MODEL)
lm = AutoModelForCausalLM.from_pretrained(LM_MODEL)
lm.eval()
lm.float()

# Pythia models have UNTIED embeddings — use embed_out + final_layer_norm
final_ln = lm.gpt_neox.final_layer_norm
unembed_weights = lm.embed_out.weight.detach()  # (vocab_size, hidden_dim)
print(f"LM: {LM_MODEL}, vocab={unembed_weights.shape[0]}, hidden={unembed_weights.shape[1]}")

# --- Batch logit lens over every decoder row, in fixed-size chunks to avoid OOM ---
LOGIT_BATCH = 4096
W_dec_full = torch.from_numpy(W_dec_raw).float()
n_features = W_dec_full.shape[0]
top5_ids = torch.zeros(n_features, 5, dtype=torch.long)

with torch.no_grad():
    for chunk_lo in range(0, n_features, LOGIT_BATCH):
        chunk_hi = min(chunk_lo + LOGIT_BATCH, n_features)
        # final layer norm, then project onto the unembedding matrix
        chunk_logits = final_ln(W_dec_full[chunk_lo:chunk_hi]) @ unembed_weights.T
        top5_ids[chunk_lo:chunk_hi] = chunk_logits.topk(5, dim=1).indices
        del chunk_logits  # free memory between chunks
print(f"Logit lens: computed top-5 tokens for {n_features} features")

# --- Junk token classifier (minimal: only truly uninformative tokens) ---
def is_junk_token(token_id):
    """Return True if the decoded token contains a junk marker.

    A token counts as junk when its decoded text contains the Unicode
    replacement character or the "<|" prefix of special tokens.
    """
    text = tokenizer.decode([token_id])
    junk_markers = (
        '\ufffd',  # byte-fallback (broken encoding)
        '<|',      # special tokens (<|endoftext|> etc.)
    )
    return any(marker in text for marker in junk_markers)

# Classify each feature: junk if >=3 of top-5 tokens are junk.
# Each distinct token id is decoded once (at most vocab-size decodes) and the
# verdict is broadcast back through `inverse`, instead of calling the
# tokenizer for all n_features * 5 entries one at a time.
unique_token_ids, inverse = torch.unique(top5_ids, return_inverse=True)
unique_is_junk = torch.tensor(
    [is_junk_token(tid) for tid in unique_token_ids.tolist()],
    dtype=torch.bool,
)
# `inverse` has the shape of top5_ids; summing the per-token verdicts along
# dim 1 gives the number of junk tokens per feature (int64).
junk_counts = unique_is_junk[inverse].sum(dim=1)

is_junk_feature = junk_counts >= 3  # >=60% junk
keep_mask = ~is_junk_feature

n_total = W_dec_raw.shape[0]
n_kept = keep_mask.sum().item()
n_filtered = n_total - n_kept
print(f"Junk filter: {n_total} total → {n_kept} kept, {n_filtered} filtered ({n_filtered/n_total:.1%} junk)")

# Distribution of junk counts
for j in range(6):
    nj = (junk_counts == j).sum().item()
    print(f"  {j}/5 junk tokens: {nj} features")

# --- Drop junk features, then pick a seeded subsample (quick test) or a full shuffle ---
kept_indices = torch.where(keep_mask)[0].numpy()
W_dec_filtered = W_dec_raw[kept_indices]
n_non_junk = W_dec_filtered.shape[0]

rng = np.random.default_rng(RANDOM_SEED)
if QUICK_TEST:
    n_sample = min(1000, n_non_junk)
    order = rng.choice(n_non_junk, size=n_sample, replace=False)
else:
    order = rng.permutation(n_non_junk)

# `order` indexes rows of the filtered matrix; map back to original feature indices
idx_original = kept_indices[order]
W_dec = W_dec_filtered[order]

if QUICK_TEST:
    print(f"\nQuick test mode: subsampled to {W_dec.shape[0]} features (from {W_dec_filtered.shape[0]} non-junk)")
else:
    print(f"\nShuffled {W_dec.shape[0]} non-junk features")

# idx tracks the original feature indices (for logit lens lookups in Cell 11)
idx = idx_original

tokenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/911M [00:00<?, ?B/s]

LM: EleutherAI/pythia-410m, vocab=50304, hidden=1024


Logit lens: computed top-5 tokens for 65536 features


Junk filter: 65536 total → 65534 kept, 2 filtered (0.0% junk)
  0/5 junk tokens: 63560 features
  1/5 junk tokens: 1935 features
  2/5 junk tokens: 39 features
  3/5 junk tokens: 2 features
  4/5 junk tokens: 0 features
  5/5 junk tokens: 0 features

Quick test mode: subsampled to 1000 features (from 65534 non-junk)


In [None]:
# --- Preprocessing: optional PCA ---

# Row norms of the selected decoder vectors, reported for inspection only.
norms = np.linalg.norm(W_dec, axis=1)
print(f"W_dec L2 norms — min: {norms.min():.4f}, max: {norms.max():.4f}, "
      f"mean: {norms.mean():.4f}, std: {norms.std():.6f}")

if PCA_DIMS is not None and PCA_DIMS < W_dec.shape[1]:
    if PCA_DIMS < 1:
        raise ValueError(f"PCA_DIMS must be a positive integer or None, got {PCA_DIMS}")

    # Fit max(PCA_DIMS, 200) components (capped at min(n_samples, n_dims), the
    # most PCA can compute) to report the variance landscape, then truncate.
    max_components = min(W_dec.shape[0], W_dec.shape[1])
    pca_full = PCA(n_components=min(max_components, max(PCA_DIMS, 200)), random_state=RANDOM_SEED)
    pca_full.fit(W_dec)
    cumvar = np.cumsum(pca_full.explained_variance_ratio_)

    print(f"Variance landscape ({len(cumvar)} components computed):")
    for thresh in [0.50, 0.80, 0.90, 0.95, 0.99]:
        # First component count whose cumulative variance reaches the threshold
        n_dims = np.searchsorted(cumvar, thresh) + 1
        if n_dims <= len(cumvar):
            print(f"  {thresh:.0%} variance: {n_dims} dims")
        else:
            print(f"  {thresh:.0%} variance: >{len(cumvar)} dims (max computed: {cumvar[-1]:.1%})")

    # Use the requested number of dims. With fewer samples than PCA_DIMS, PCA
    # yields fewer components than requested; clamp so that the variance
    # lookup cannot go out of range and the reported dim count is accurate.
    n_keep = min(PCA_DIMS, len(cumvar))
    if n_keep < PCA_DIMS:
        print(f"Warning: only {n_keep} components available (requested {PCA_DIMS})")
    X_reduced = pca_full.transform(W_dec)[:, :n_keep]
    explained = cumvar[n_keep - 1]
    print(f"\nPCA: {W_dec.shape[1]} → {n_keep} dims ({explained:.1%} variance retained)")
else:
    X_reduced = W_dec.copy()
    print(f"No PCA applied. Using {X_reduced.shape[1]} dims.")

data_dim = X_reduced.shape[1]
print(f"Final data shape: {X_reduced.shape}")

In [7]:
# --- Checkpoint Save ---
# Saves pipeline state so future runs can skip Cells 4-6 (~15 min → ~3 sec).
# Run this after Cells 4-6 complete successfully.

import torch, os

CHECKPOINT_PATH = "checkpoints/pipeline_state.pt"
# Derive the directory from the path so the two cannot drift apart; skip
# makedirs for a bare filename, where dirname is "" and makedirs would fail.
checkpoint_dir = os.path.dirname(CHECKPOINT_PATH)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

# Everything the restore cell needs to rebuild the pipeline state without
# reloading the SAE or the LM weights; "config" lets restore detect a
# checkpoint written under different settings.
checkpoint = {
    "W_dec_full": W_dec_full,
    "top5_ids": top5_ids,
    "unembed_weights": unembed_weights,
    "final_ln_weight": final_ln.weight.data.clone(),
    "final_ln_bias": final_ln.bias.data.clone(),
    "valid_indices": torch.from_numpy(kept_indices),
    "config": {
        "lm_model": LM_MODEL,
        "sae_model": SAE_MODEL,
        "layer_idx": LAYER_IDX,
        "n_features": W_dec_full.shape[0],
        "hidden_dim": W_dec_full.shape[1],
    },
}

torch.save(checkpoint, CHECKPOINT_PATH)
file_size = os.path.getsize(CHECKPOINT_PATH) / (1024 * 1024)
print(f"Checkpoint saved to {CHECKPOINT_PATH} ({file_size:.1f} MB)")
print(f"  W_dec_full: {W_dec_full.shape}")
print(f"  top5_ids: {top5_ids.shape}")
print(f"  valid_indices: {len(kept_indices)} features")
# USE_CHECKPOINT is defined in the restore cell (Cell 3), not the config cell.
print("To restore: set USE_CHECKPOINT=True in Cell 3, then run Cell 3")

Checkpoint saved to checkpoints/pipeline_state.pt (455.5 MB)
  W_dec_full: torch.Size([65536, 1024])
  top5_ids: torch.Size([65536, 5])
  valid_indices: 65534 features
To restore: set USE_CHECKPOINT=True in Cell 2, then run Cell 3


In [9]:
# --- Build the flat-clustering ART module from the configured name and overrides ---
list_available_modules()
print()

model = create_art_module(MODULE_NAME, dim=data_dim, overrides=PARAM_OVERRIDES)
print(f"\nCreated {MODULE_NAME} for {data_dim}-dim data")

# Show the module's parameters when it exposes them; otherwise print an empty line
params_line = ""
if hasattr(model, 'params'):
    params_line = f"Parameters: {model.params}"
print(params_line)

Module             Complement Codes   Default Params
--------------------------------------------------------------------------------
FuzzyART           Yes                {'rho': 0.7, 'alpha': 0.01, 'beta': 1.0}
GaussianART        No                 {'rho': 0.5, 'alpha': 1e-10} (+dim-dependent)
HypersphereART     No                 {'rho': 0.7, 'alpha': 0.01, 'beta': 1.0, 'r_hat': 1.0}
BayesianART        No                 {'rho': 0.01} (+dim-dependent)
EllipsoidART       No                 {'rho': 0.7, 'alpha': 1e-07, 'beta': 1.0, 'mu': 0.8, 'r_hat': 1.0}
GPUFuzzyART        No                 {'rho': 0.7, 'alpha': 0.01, 'beta': 1.0}


Created GPUFuzzyART for 1024-dim data
Parameters: {'rho': 0.15, 'alpha': 0.01, 'beta': 1.0}


In [10]:
# --- Flat clustering: prepare the data for the module, fit, then report cluster sizes ---
X_prepared = model.prepare_data(X_reduced)
print(f"Data after prepare_data: shape {X_prepared.shape}")

model.fit(X_prepared, verbose=True)
labels, n_clusters = model.labels_, model.n_clusters

print(f"\nClusters found: {n_clusters}")
print(f"Samples: {len(labels)}")

# Per-cluster sample counts; `counts` is reused by the size-distribution plots
unique, counts = np.unique(labels, return_counts=True)
size_summary = (
    f"Cluster sizes — min: {counts.min()}, max: {counts.max()}, "
    f"median: {np.median(counts):.0f}, mean: {counts.mean():.1f}"
)
print(size_summary)

Data after prepare_data: shape (1000, 2048)


Clustering:   0%|          | 0/1000 [00:00<?, ?it/s]


Clusters found: 6
Samples: 1000
Cluster sizes — min: 20, max: 202, median: 196, mean: 166.7


In [None]:
# --- UMAP visualization of flat clusters ---
import umap

# 2-D UMAP embedding of the clustered vectors, coloured by flat cluster label
reducer = umap.UMAP(n_neighbors=15, n_components=2, random_state=RANDOM_SEED)
embedding = reducer.fit_transform(X_reduced)
umap_x, umap_y = embedding[:, 0], embedding[:, 1]

fig, ax = plt.subplots(figsize=(10, 8))
scatter = ax.scatter(umap_x, umap_y, c=labels, cmap="tab20", s=5, alpha=0.6)
ax.set_xlabel("UMAP 1")
ax.set_ylabel("UMAP 2")
ax.set_title(f"{MODULE_NAME} Clusters (n={n_clusters}) — UMAP projection")
plt.colorbar(scatter, ax=ax, label="Cluster ID")
plt.tight_layout()
plt.show()

In [None]:
# --- Cluster sizes: rank-ordered bar chart (left) and histogram (right) ---
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
rank_ax, hist_ax = axes

sizes_desc = sorted(counts, reverse=True)
rank_ax.bar(range(len(counts)), sizes_desc, color="steelblue")
rank_ax.set_title("Cluster sizes (sorted)")
rank_ax.set_xlabel("Cluster rank")
rank_ax.set_ylabel("Size")

hist_ax.hist(counts, bins=30, color="steelblue", edgecolor="white")
hist_ax.set_title("Cluster size histogram")
hist_ax.set_xlabel("Cluster size")
hist_ax.set_ylabel("Count")

plt.tight_layout()
plt.show()

In [11]:
# --- Hierarchical taxonomy with SMART: one level per vigilance value ---
rho_values = [0.1, 0.4, 0.7]
smart = create_smart_model(
    MODULE_NAME,
    dim=data_dim,
    rho_values=rho_values,
    overrides=PARAM_OVERRIDES,
)

X_smart = smart.prepare_data(X_reduced)
smart.fit(X_smart, verbose=True)

# One column of cluster labels per level
hierarchy = extract_hierarchy_labels(smart)
print(f"Hierarchy shape: {hierarchy.shape}  (samples × levels)")

for level, level_rho in enumerate(rho_values):
    level_labels = hierarchy[:, level]
    n_level_clusters = len(np.unique(level_labels))
    print(f"  Level {level} (rho={level_rho}): {n_level_clusters} clusters")

Level 0 (rho=0.1):   0%|          | 0/1000 [00:00<?, ?it/s]

Level 1 (rho=0.4):   0%|          | 0/1000 [00:00<?, ?it/s]

Level 2 (rho=0.7):   0%|          | 0/1000 [00:00<?, ?it/s]

Hierarchy shape: (1000, 3)  (samples × levels)
  Level 0 (rho=0.1): 4 clusters
  Level 1 (rho=0.4): 46 clusters
  Level 2 (rho=0.7): 351 clusters


In [None]:
# --- Interpretation functions ---
from collections import Counter

# Mean decoder direction over every row of W_dec_full (all SAE features, not
# only the clustered subset). get_cluster_tokens subtracts it from a cluster's
# mean direction before applying the logit lens.
global_mean_dir = W_dec_full.mean(dim=0)

def get_cluster_tokens(feature_indices, top_k=10):
    """Differential logit lens: top tokens distinctive to this cluster.

    The cluster's mean decoder direction minus the global mean direction is
    passed through the final layer norm and the unembedding matrix; the
    ``top_k`` highest-logit tokens are decoded.

    Args:
        feature_indices: Row indices into ``W_dec_full`` for the cluster's
            features.
        top_k: Number of tokens to return.

    Returns:
        A list of decoded token strings, highest logit first. Empty when
        ``feature_indices`` is empty.
    """
    # An empty cluster has no mean direction (the mean of zero rows is NaN),
    # so the ranking would be meaningless; report no tokens instead.
    if len(feature_indices) == 0:
        return []

    dirs = W_dec_full[feature_indices]
    diff_dir = dirs.mean(dim=0) - global_mean_dir
    with torch.no_grad():
        # final_ln expects a batch dimension; add it, then drop it again
        normed = final_ln(diff_dir.unsqueeze(0))
        logits = (normed @ unembed_weights.T).squeeze(0)
    topk = logits.topk(top_k)
    return [tokenizer.decode([tid]) for tid in topk.indices.tolist()]

def get_cluster_token_votes(feature_indices, top_k=10):
    """Tally the per-feature top-5 logit lens tokens over a cluster.

    Returns the ``top_k`` most frequent tokens as ``(stripped_text, count)``
    pairs, most common first.
    """
    votes = Counter()
    for feature in feature_indices:
        # each feature votes once for each of its stored top-5 token ids
        votes.update(top5_ids[feature].tolist())

    ranked = votes.most_common(top_k)
    return [(tokenizer.decode([token_id]).strip(), n_votes) for token_id, n_votes in ranked]

print("Interpretation functions ready.")

In [None]:
# --- Results Summary: run settings plus a per-cluster digest of level 0, printed as JSON ---
import json as _json

_level0 = hierarchy[:, 0]
_l0_ids = sorted(np.unique(_level0))

_cluster_summary = {}
for c in _l0_ids:
    mask = _level0 == c
    feature_indices = idx[mask]  # original feature indices of this cluster
    diff = get_cluster_tokens(feature_indices, top_k=5)
    votes = get_cluster_token_votes(feature_indices, top_k=5)
    _cluster_summary[int(c)] = {
        "size": int(mask.sum()),
        "top_diff_tokens": diff,
        "top_vote_tokens": [tok for tok, _ in votes],
    }

_n_clusters_per_level = [
    int(len(np.unique(hierarchy[:, level]))) for level, _ in enumerate(rho_values)
]

results_summary = {
    "model": SAE_MODEL,
    "layer": LAYER_IDX,
    "module": MODULE_NAME,
    "pca_dims": PCA_DIMS,
    "quick_test": QUICK_TEST,
    "n_features_clustered": int(X_reduced.shape[0]),
    "rho_values": rho_values,
    "n_clusters_per_level": _n_clusters_per_level,
    "clusters_L0": _cluster_summary,
}

# Sentinel lines delimit the JSON payload in the cell output
print("===RESULTS_JSON===")
print(_json.dumps(results_summary, indent=2))
print("===END_RESULTS===")

In [None]:
# --- L0 Taxonomy: token summary for each coarsest-level cluster ---
n_levels = hierarchy.shape[1]
l0_ids = sorted(np.unique(hierarchy[:, 0]))

banner = "=" * 70
print(banner)
print(f"LEVEL 0 — {len(l0_ids)} broad themes (rho={rho_values[0]})")
print(banner)

for c in l0_ids:
    mask = hierarchy[:, 0] == c
    feature_indices = idx[mask]
    n_feat = mask.sum()

    # Number of distinct level-1 clusters among this cluster's members
    if n_levels > 1:
        n_l1 = len(np.unique(hierarchy[mask, 1]))
        child_info = f", {n_l1} L1 children"
    else:
        child_info = ""

    diff_tokens = get_cluster_tokens(feature_indices, top_k=15)
    votes = get_cluster_token_votes(feature_indices, top_k=10)
    vote_text = ', '.join(f'{t}({cnt})' for t, cnt in votes)

    print(f"\nL0:{c} ({n_feat} features{child_info}):")
    print(f"  Differential: {diff_tokens}")
    print(f"  Votes:        {vote_text}")

In [None]:
# --- Explore hierarchy: drill into any cluster's children ---

def explore_cluster(level, cluster_id, top_k=10, max_children=50):
    """Print a cluster's token summary, then one summary per child cluster.

    Children are the distinct labels at ``level + 1`` among this cluster's
    members, listed largest first and capped at ``max_children``. Nothing is
    returned; all results are printed.
    """
    def _fmt_votes(pairs):
        return ', '.join(f'{t}({cnt})' for t, cnt in pairs)

    members = hierarchy[:, level] == cluster_id
    n_members = members.sum()
    if n_members == 0:
        print(f"No samples in L{level} cluster {cluster_id}")
        return

    member_features = idx[members]
    parent_diff = get_cluster_tokens(member_features, top_k=top_k)
    parent_votes = get_cluster_token_votes(member_features, top_k=top_k)

    print(f"L{level}:{cluster_id} ({n_members} features, rho={rho_values[level]}):")
    print(f"  Differential: {parent_diff}")
    print(f"  Votes:        {_fmt_votes(parent_votes)}")

    next_level = level + 1
    if next_level >= hierarchy.shape[1]:
        print("  (finest level — no children)")
        return

    child_ids, child_counts = np.unique(hierarchy[members, next_level], return_counts=True)
    print(f"\n  -> {len(child_ids)} children at L{next_level} (rho={rho_values[next_level]}):\n")

    # Largest children first
    by_size_desc = np.argsort(-child_counts)
    for pos in by_size_desc[:max_children]:
        cid, n_child = child_ids[pos], child_counts[pos]
        in_child = members & (hierarchy[:, next_level] == cid)
        child_features = idx[in_child]
        child_diff = get_cluster_tokens(child_features, top_k=top_k)
        child_votes = get_cluster_token_votes(child_features, top_k=5)
        print(f"    L{next_level}:{cid} ({n_child} feat): {child_diff}")
        print(f"      Votes: {_fmt_votes(child_votes)}")

    n_hidden = len(child_ids) - max_children
    if n_hidden > 0:
        print(f"    ... and {n_hidden} more")

# Example: explore Level 0, Cluster 0
explore_cluster(level=0, cluster_id=0)

In [None]:
# --- Vigilance parameter sweep (optional/advanced) ---
# Fits a fresh module at each rho and records how many clusters it finds.
# Higher rho → more clusters (tighter match required, except BayesianART).

rho_sweep = np.linspace(0.1, 0.95, 10)

cluster_counts = []
for sweep_rho in tqdm(rho_sweep, desc="Vigilance sweep"):
    # Same overrides as the main run, with only rho replaced
    sweep_overrides = {**PARAM_OVERRIDES, "rho": sweep_rho}
    sweep_model = create_art_module(MODULE_NAME, dim=data_dim, overrides=sweep_overrides)
    sweep_model.fit(sweep_model.prepare_data(X_reduced))
    cluster_counts.append(sweep_model.n_clusters)

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(rho_sweep, cluster_counts, "o-", color="steelblue")
ax.set_title(f"Vigilance sweep — {MODULE_NAME}")
ax.set_xlabel("Vigilance (rho)")
ax.set_ylabel("Number of clusters")
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

for sweep_rho, n_found in zip(rho_sweep, cluster_counts):
    print(f"  rho={sweep_rho:.2f} → {n_found} clusters")

## Phase 2: Checkpoint Comparison (Stub)

**Goal:** Track how SAE feature categories evolve across Pythia-70M training checkpoints.

**Approach:**
1. Load Pythia-70M at multiple checkpoints (e.g., steps 1000, 10000, 50000, 143000)
2. Extract SAE decoder weights at each checkpoint
3. Use ART's `partial_fit()` to incrementally update the taxonomy
4. Track: resonant features (stable), new categories (plastic), dormant categories (lost)

This leverages ART's core strength — stability-plasticity balance — which SAEs lack entirely (they must retrain from scratch).

*Implementation in Phase 2.*