# ART-Organized SAE Feature Taxonomy

Cluster pre-trained SAE decoder weight vectors from **Pythia** models using Adaptive Resonance Theory (ART) modules to build hierarchical feature taxonomies.

**Pipeline:** SAE W_dec → optional PCA → ART clustering → hierarchical taxonomy via SMART

**Key idea:** ART's vigilance parameter provides a single, principled knob for controlling cluster granularity — something SAEs lack. SMART builds a hierarchy by stacking multiple vigilance levels.

**Supported models:** Pythia-70M (512-dim, 32k features), Pythia-410M (1024-dim, 65k features)

In [1]:
# --- Setup ---
import sys, os

# Colab bootstrap: when running under Google Colab, make sure the working
# directory is the repository root (so the local modules imported below are
# importable) and install the third-party dependencies.
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    # Clone or pull the repo, then cd into it
    if os.path.exists("art_utils.py"):
        # Already inside the repo (cell re-run)
        !git pull
    elif os.path.exists("ART-SAE-taxonomy"):
        # Repo was cloned by an earlier run but the cwd is still its parent.
        os.chdir("ART-SAE-taxonomy")
        !git pull
    else:
        # First run in this runtime: fetch the repo and enter it.
        !git clone https://github.com/syre-ai/ART-SAE-taxonomy.git
        os.chdir("ART-SAE-taxonomy")
    # Only artlib is version-pinned; the other packages resolve to whatever
    # pip selects at install time.
    !pip install -q artlib==0.1.7 eai-sparsify transformers torch umap-learn plotly tqdm pandas

import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import pandas as pd
from sklearn.decomposition import PCA
from tqdm.auto import tqdm

# Force-reload local modules so git pull changes take effect
import importlib
import gpu_fuzzy_art, art_utils
importlib.reload(gpu_fuzzy_art)
importlib.reload(art_utils)

from art_utils import (
    create_art_module,
    create_smart_model,
    extract_hierarchy_labels,
    list_available_modules,
)

print("Setup complete.")

Setup complete.


In [8]:
# ============================================================
# CONFIGURATION — edit this cell to change experiments
# ============================================================
# These constants are read by the cells below. Re-run the downstream cells
# after changing any of them.

# --- Model ---
# LM_MODEL and SAE_MODEL must share a hidden size: SAE decoder rows are passed
# through the LM's final layer norm and unembedding matrix (logit lens).
LM_MODEL = "EleutherAI/pythia-410m"     # LM for logit lens interpretation
SAE_MODEL = "EleutherAI/sae-pythia-410m-65k"  # SAE decoder weights to cluster
# Selects the SAE stored under "layers.{LAYER_IDX}" or "layers.{LAYER_IDX}.mlp".
LAYER_IDX = 12             # Middle layer (410M has 24 layers, 70M has 6)

# --- Clustering ---
MODULE_NAME = "GPUFuzzyART"  # GPU-accelerated FuzzyART (see list_available_modules())
# NOTE(review): no cell visible in this notebook reads PARAM_OVERRIDES — the
# PCA sweep passes only MODULE_NAME, dim and its own RHO_VALUES to
# create_smart_model. Confirm whether it should be forwarded.
PARAM_OVERRIDES = {"rho": 0.15}       # e.g. {"rho": 0.8}
# Used by the preprocessing cell and the checkpoint-restore cell; the PCA
# sweep cell uses its own PCA_DIMS_SWEEP list instead.
PCA_DIMS = None            # Set to int (e.g. 50, 100, 200) to reduce dims; None = use raw dims
# The sample size of 1000 is hard-coded in the cells that subsample.
QUICK_TEST = True          # Subsample to 1000 features for fast iteration
# Seeds the subsample/shuffle RNG and the PCA random_state.
RANDOM_SEED = 42

In [3]:
# --- Checkpoint Restore (skip Cells 4-7 if checkpoint exists) ---
# Set USE_CHECKPOINT = True to load saved pipeline state instead of recomputing.
# This saves ~15 minutes by skipping SAE loading, logit lens, and junk filtering.

import os
import torch

USE_CHECKPOINT = False  # Set True to restore from checkpoint
CHECKPOINT_PATH = "checkpoints/pipeline_state.pt"


def checkpoint_config_mismatches(saved_cfg, expected_cfg):
    """Compare a checkpoint's saved config against the notebook's settings.

    Args:
        saved_cfg: The ``"config"`` dict stored in the checkpoint.
        expected_cfg: Mapping of config key -> value currently set in the
            notebook.

    Returns:
        A list of human-readable problem descriptions, one per key that is
        missing from ``saved_cfg`` or holds a different value. The list is
        empty when the checkpoint matches.
    """
    problems = []
    for key, expected in expected_cfg.items():
        if key not in saved_cfg:
            problems.append(f"{key}: missing from checkpoint (notebook has {expected!r})")
        elif saved_cfg[key] != expected:
            problems.append(f"{key}: checkpoint has {saved_cfg[key]!r}, notebook has {expected!r}")
    return problems


if USE_CHECKPOINT and os.path.exists(CHECKPOINT_PATH):
    print(f"Loading checkpoint from {CHECKPOINT_PATH}...")
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoint files that this notebook wrote itself.
    # map_location="cpu" keeps the .numpy() conversions below valid even if
    # the tensors were saved from another device.
    ckpt = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=False)

    # Verify config matches. The layer index is checked as well: decoder
    # weights from a different layer would otherwise be restored silently.
    # A real exception is raised (not `assert`) so the check survives `-O`.
    ckpt_cfg = ckpt["config"]
    config_problems = checkpoint_config_mismatches(
        ckpt_cfg,
        {"lm_model": LM_MODEL, "sae_model": SAE_MODEL, "layer_idx": LAYER_IDX},
    )
    if config_problems:
        raise ValueError(
            "Checkpoint does not match the current configuration "
            "(delete it or set USE_CHECKPOINT = False):\n  "
            + "\n  ".join(config_problems)
        )

    # Restore tensors
    W_dec_full = ckpt["W_dec_full"]
    top5_ids = ckpt["top5_ids"]
    unembed_weights = ckpt["unembed_weights"]
    final_ln_weight = ckpt["final_ln_weight"]
    final_ln_bias = ckpt["final_ln_bias"]
    valid_indices = ckpt["valid_indices"]

    # Rebuild final_ln as a simple module for logit lens
    from transformers import AutoTokenizer, AutoModelForCausalLM
    tokenizer = AutoTokenizer.from_pretrained(LM_MODEL)
    # Create a minimal layer norm with saved weights.
    # NOTE(review): only weight and bias are checkpointed, so this uses
    # nn.LayerNorm's default eps — confirm it matches the LM's layer norm.
    import torch.nn as nn
    final_ln = nn.LayerNorm(final_ln_weight.shape[0])
    final_ln.weight.data = final_ln_weight
    final_ln.bias.data = final_ln_bias
    final_ln.eval()

    # Rebuild filtered/subsampled data. This mirrors the subsample/shuffle
    # logic of the junk-filter cell (same seed, same RNG calls) so the
    # selection is reproduced exactly.
    import numpy as np
    W_dec_raw = W_dec_full.numpy()
    kept_indices = valid_indices.numpy()
    W_dec_filtered = W_dec_raw[kept_indices]
    rng = np.random.default_rng(RANDOM_SEED)
    if QUICK_TEST:
        n_sample = min(1000, W_dec_filtered.shape[0])
        sub_idx = rng.choice(W_dec_filtered.shape[0], size=n_sample, replace=False)
        idx_original = kept_indices[sub_idx]
        W_dec = W_dec_filtered[sub_idx]
    else:
        perm = rng.permutation(W_dec_filtered.shape[0])
        idx_original = kept_indices[perm]
        W_dec = W_dec_filtered[perm]
    # idx holds original (unfiltered) feature indices for logit lens lookups.
    idx = idx_original

    # Run preprocessing (PCA) inline.
    # NOTE(review): the preprocessing cell fits max(PCA_DIMS, 200) components
    # and truncates, whereas this fits exactly PCA_DIMS; depending on the SVD
    # solver scikit-learn selects, the projections may differ slightly.
    if PCA_DIMS is not None and PCA_DIMS < W_dec.shape[1]:
        from sklearn.decomposition import PCA
        pca = PCA(n_components=PCA_DIMS, random_state=RANDOM_SEED)
        X_reduced = pca.fit_transform(W_dec)
    else:
        X_reduced = W_dec.copy()
    data_dim = X_reduced.shape[1]

    print(f"Checkpoint restored: {W_dec_full.shape[0]} features, "
          f"{len(kept_indices)} valid, {X_reduced.shape[0]} selected, "
          f"{data_dim}-dim")
    print(f"Config: {ckpt_cfg}")
    print(">>> Skip Cells 4-7 and proceed to Cell 8 (PCA dimension sweep)")
else:
    if USE_CHECKPOINT:
        print(f"Checkpoint not found at {CHECKPOINT_PATH} — run full pipeline (Cells 4-7)")
    else:
        print("Checkpoint restore disabled (USE_CHECKPOINT=False) — run full pipeline")

Checkpoint restore disabled (USE_CHECKPOINT=False) — run full pipeline


In [4]:
# --- Load SAE decoder weights ---
# Loads every per-layer SAE published under SAE_MODEL and extracts the decoder
# matrix of the configured layer as a NumPy array (W_dec_raw).
from sparsify import Sae

saes = Sae.load_many(SAE_MODEL)

# Detect key format: "layers.N" (70M) vs "layers.N.mlp" (410M)
layer_key = next(
    (key for key in (f"layers.{LAYER_IDX}", f"layers.{LAYER_IDX}.mlp") if key in saes),
    None,
)
if layer_key is None:
    # Fail with the list of valid keys instead of a bare KeyError on the
    # fallback name (e.g. when LAYER_IDX exceeds the model's layer count).
    raise KeyError(
        f"No SAE found for layer {LAYER_IDX} in {SAE_MODEL}. "
        f"Available keys: {sorted(saes.keys())}"
    )
# .cpu() is a no-op for CPU tensors and keeps .numpy() valid otherwise.
W_dec_raw = saes[layer_key].W_dec.detach().cpu().numpy()
print(f"Loaded W_dec from {SAE_MODEL}")
print(f"  Layer: {layer_key}, shape: {W_dec_raw.shape} (features × hidden_dim)")
print(f"  Available keys: {sorted(saes.keys())}")

Fetching 49 files:   0%|          | 0/49 [00:00<?, ?it/s]

Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Dropping extra args {'signed': False}


Loaded W_dec from EleutherAI/sae-pythia-410m-65k
  Layer: layers.12.mlp, shape: (65536, 1024) (features × hidden_dim)
  Available keys: ['layers.0.mlp', 'layers.1.mlp', 'layers.10.mlp', 'layers.11.mlp', 'layers.12.mlp', 'layers.13.mlp', 'layers.14.mlp', 'layers.15.mlp', 'layers.16.mlp', 'layers.17.mlp', 'layers.18.mlp', 'layers.19.mlp', 'layers.2.mlp', 'layers.20.mlp', 'layers.21.mlp', 'layers.22.mlp', 'layers.23.mlp', 'layers.3.mlp', 'layers.4.mlp', 'layers.5.mlp', 'layers.6.mlp', 'layers.7.mlp', 'layers.8.mlp', 'layers.9.mlp']


In [5]:
# --- Filter junk features + subsample ---
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer and LM. The tokenizer, final layer norm and unembedding
# matrix extracted below are reused by the PCA sweep cell for interpretation.
tokenizer = AutoTokenizer.from_pretrained(LM_MODEL)
lm = AutoModelForCausalLM.from_pretrained(LM_MODEL)
lm.float()
lm.eval()

# Pythia's input and output embeddings are UNTIED, so the logit lens must go
# through final_layer_norm and then embed_out.
final_ln = lm.gpt_neox.final_layer_norm
unembed_weights = lm.embed_out.weight.detach()  # (vocab_size, hidden_dim)
print(f"LM: {LM_MODEL}, vocab={unembed_weights.shape[0]}, hidden={unembed_weights.shape[1]}")

# --- Batch logit lens over all features, processed in chunks to avoid OOM ---
W_dec_full = torch.from_numpy(W_dec_raw).float()
n_features = W_dec_full.shape[0]
top5_ids = torch.zeros((n_features, 5), dtype=torch.long)

LOGIT_BATCH = 4096
with torch.no_grad():
    for chunk_start in range(0, n_features, LOGIT_BATCH):
        chunk_end = min(chunk_start + LOGIT_BATCH, n_features)
        # Each decoder row is treated as a residual direction:
        # layer norm -> unembed -> keep the ids of the 5 highest logits.
        chunk_logits = final_ln(W_dec_full[chunk_start:chunk_end]) @ unembed_weights.T
        top5_ids[chunk_start:chunk_end] = chunk_logits.topk(5, dim=1).indices
        del chunk_logits  # release the (chunk, vocab) matrix before the next chunk
print(f"Logit lens: computed top-5 tokens for {n_features} features")

# --- Junk token classifier (minimal: only truly uninformative tokens) ---
def is_junk_token(token_id):
    """Return True when the decoded token is considered uninformative.

    A token is junk if its decoded text contains the Unicode replacement
    character (byte-fallback / broken encoding) or the ``<|`` marker used by
    special tokens such as ``<|endoftext|>``.
    """
    text = tokenizer.decode([token_id])
    junk_markers = ('\ufffd', '<|')
    return any(marker in text for marker in junk_markers)

# Classify each feature: junk if >=3 of top-5 tokens are junk.
# Each distinct token id is decoded once and cached in a boolean lookup
# table; the per-feature counts then come from one vectorised gather instead
# of n_features * 5 separate decode calls. The resulting counts are identical.
# The table is sized by the vocabulary because top5_ids are indices into the
# unembedding logits, so every id is < unembed_weights.shape[0].
junk_token_lookup = torch.zeros(unembed_weights.shape[0], dtype=torch.bool)
for tid in torch.unique(top5_ids).tolist():
    junk_token_lookup[tid] = is_junk_token(tid)
junk_counts = junk_token_lookup[top5_ids].sum(dim=1)  # int64, shape (n_features,)

is_junk_feature = junk_counts >= 3  # >=60% junk
keep_mask = ~is_junk_feature

n_total = W_dec_raw.shape[0]
n_kept = keep_mask.sum().item()
n_filtered = n_total - n_kept
print(f"Junk filter: {n_total} total → {n_kept} kept, {n_filtered} filtered ({n_filtered/n_total:.1%} junk)")

# Distribution of junk counts (0 through 5 junk tokens among the top 5)
for j in range(6):
    nj = (junk_counts == j).sum().item()
    print(f"  {j}/5 junk tokens: {nj} features")

# --- Drop junk features, then subsample (quick test) or shuffle (full run) ---
kept_indices = keep_mask.nonzero(as_tuple=True)[0].numpy()
W_dec_filtered = W_dec_raw[kept_indices]

rng = np.random.default_rng(RANDOM_SEED)
if not QUICK_TEST:
    # Full run: keep every non-junk feature, in a seeded random order.
    perm = rng.permutation(W_dec_filtered.shape[0])
    idx_original = kept_indices[perm]
    W_dec = W_dec_filtered[perm]
    print(f"\nShuffled {W_dec.shape[0]} non-junk features")
else:
    # Quick test: draw at most 1000 distinct non-junk features.
    n_sample = min(1000, W_dec_filtered.shape[0])
    idx = rng.choice(W_dec_filtered.shape[0], size=n_sample, replace=False)
    # Positions within the filtered array -> original feature indices.
    idx_original = kept_indices[idx]
    W_dec = W_dec_filtered[idx]
    print(f"\nQuick test mode: subsampled to {W_dec.shape[0]} features (from {W_dec_filtered.shape[0]} non-junk)")

# From here on idx holds ORIGINAL feature indices (rows of W_dec_full), which
# is what the later logit lens lookups index with.
idx = idx_original

tokenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/911M [00:00<?, ?B/s]

LM: EleutherAI/pythia-410m, vocab=50304, hidden=1024


Logit lens: computed top-5 tokens for 65536 features


Junk filter: 65536 total → 65534 kept, 2 filtered (0.0% junk)
  0/5 junk tokens: 63560 features
  1/5 junk tokens: 1935 features
  2/5 junk tokens: 39 features
  3/5 junk tokens: 2 features
  4/5 junk tokens: 0 features
  5/5 junk tokens: 0 features

Quick test mode: subsampled to 1000 features (from 65534 non-junk)


In [None]:
# --- Preprocessing: optional PCA ---
# Reports the L2 norm statistics of the selected decoder rows, then either
# keeps W_dec unchanged or projects it onto its leading PCA_DIMS components.

norms = np.linalg.norm(W_dec, axis=1)
print(f"W_dec L2 norms — min: {norms.min():.4f}, max: {norms.max():.4f}, "
      f"mean: {norms.mean():.4f}, std: {norms.std():.6f}")

if PCA_DIMS is None or PCA_DIMS >= W_dec.shape[1]:
    # PCA disabled, or it would not reduce anything: work on a copy of the raw vectors.
    X_reduced = W_dec.copy()
    print(f"No PCA applied. Using {X_reduced.shape[1]} dims.")
else:
    # Fit at least 200 components (capped at min(n_samples, n_features)) so
    # the variance report is informative even for a small PCA_DIMS.
    max_components = min(W_dec.shape[0], W_dec.shape[1])
    n_fit = min(max_components, max(PCA_DIMS, 200))
    pca_full = PCA(n_components=n_fit, random_state=RANDOM_SEED)
    pca_full.fit(W_dec)
    cumvar = np.cumsum(pca_full.explained_variance_ratio_)

    print(f"Variance landscape ({len(cumvar)} components computed):")
    for thresh in (0.50, 0.80, 0.90, 0.95, 0.99):
        # Smallest component count whose cumulative variance reaches thresh.
        n_dims = np.searchsorted(cumvar, thresh) + 1
        if n_dims > len(cumvar):
            print(f"  {thresh:.0%} variance: >{len(cumvar)} dims (max computed: {cumvar[-1]:.1%})")
            continue
        print(f"  {thresh:.0%} variance: {n_dims} dims")

    # Keep only the requested number of leading components.
    X_reduced = pca_full.transform(W_dec)[:, :PCA_DIMS]
    explained = cumvar[PCA_DIMS - 1]
    print(f"\nPCA: {W_dec.shape[1]} → {PCA_DIMS} dims ({explained:.1%} variance retained)")

data_dim = X_reduced.shape[1]
print(f"Final data shape: {X_reduced.shape}")

In [7]:
# --- Checkpoint Save ---
# Saves pipeline state so future runs can skip Cells 4-6 (~15 min → ~3 sec).
# Run this after Cells 4-6 complete successfully.

import torch, os

CHECKPOINT_PATH = "checkpoints/pipeline_state.pt"
# Derive the directory from CHECKPOINT_PATH so the two cannot drift apart.
os.makedirs(os.path.dirname(CHECKPOINT_PATH) or ".", exist_ok=True)

# Everything the checkpoint-restore cell needs to rebuild the working state
# without reloading the SAE or the LM: the full decoder matrix, the logit
# lens results, the unembedding matrix, the final layer norm parameters and
# the indices of the non-junk features. "config" lets the restore cell reject
# a checkpoint produced under different settings.
checkpoint = {
    "W_dec_full": W_dec_full,
    "top5_ids": top5_ids,
    "unembed_weights": unembed_weights,
    "final_ln_weight": final_ln.weight.data.clone(),
    "final_ln_bias": final_ln.bias.data.clone(),
    "valid_indices": torch.from_numpy(kept_indices),
    "config": {
        "lm_model": LM_MODEL,
        "sae_model": SAE_MODEL,
        "layer_idx": LAYER_IDX,
        "n_features": W_dec_full.shape[0],
        "hidden_dim": W_dec_full.shape[1],
    }
}

torch.save(checkpoint, CHECKPOINT_PATH)
file_size = os.path.getsize(CHECKPOINT_PATH) / (1024 * 1024)
print(f"Checkpoint saved to {CHECKPOINT_PATH} ({file_size:.1f} MB)")
print(f"  W_dec_full: {W_dec_full.shape}")
print(f"  top5_ids: {top5_ids.shape}")
print(f"  valid_indices: {len(kept_indices)} features")
# USE_CHECKPOINT is defined in the checkpoint-restore cell (Cell 3), not in
# the configuration cell (Cell 2).
print("To restore: set USE_CHECKPOINT=True in Cell 3 (Checkpoint Restore) and re-run that cell")

Checkpoint saved to checkpoints/pipeline_state.pt (455.5 MB)
  W_dec_full: torch.Size([65536, 1024])
  top5_ids: torch.Size([65536, 5])
  valid_indices: 65534 features
To restore: set USE_CHECKPOINT=True in Cell 2, then run Cell 3


In [None]:
# ============================================================
# PCA DIMENSION SWEEP
# ============================================================
# Tests different PCA dimensions to find the best setting for
# cluster interpretability. Runs SMART hierarchy + differential
# logit lens for each setting, then prints a comparison.
#
# Requires cells 1-5 to have run (W_dec, W_dec_full, idx, etc.)
# ============================================================

import json
import time

# --- Sweep configuration ---
# None = cluster the raw vectors. Each int is a number of leading PCA
# components and must not exceed the number of components fitted for the
# variance landscape below (at most 512, and at most min(W_dec.shape)),
# because the sweep indexes cumvar_full[pca_dims - 1].
PCA_DIMS_SWEEP = [None, 32, 64, 128, 256]
# Passed to create_smart_model as rho_values; the sweep reads one hierarchy
# column per entry. The cluster-count print and the summary table index
# levels 0, 1 and 2, so three values are expected here.
RHO_VALUES = [0.2, 0.5, 0.8]
N_DRILL = 3     # Drill into top N L0 clusters to show L1 children
TOP_K = 10      # Tokens per cluster interpretation

# --- Interpretation function ---
# Mean decoder direction over ALL features in W_dec_full (junk features
# included); _get_tokens subtracts it from each cluster's mean direction.
global_mean_dir = W_dec_full.mean(dim=0)

def _get_tokens(feature_indices, top_k=10):
    """Differential logit lens for a group of features.

    Averages the rows of ``W_dec_full`` selected by ``feature_indices``,
    subtracts ``global_mean_dir``, applies ``final_ln`` and the unembedding
    matrix, and returns the decoded text of the ``top_k`` highest-logit
    tokens, highest first.
    """
    cluster_mean = W_dec_full[feature_indices].mean(dim=0)
    direction = (cluster_mean - global_mean_dir).unsqueeze(0)
    with torch.no_grad():
        scores = (final_ln(direction) @ unembed_weights.T).squeeze(0)
    best_ids = scores.topk(top_k).indices.tolist()
    return [tokenizer.decode([token_id]) for token_id in best_ids]

# --- PCA variance landscape (fitted once, shared by every sweep setting) ---
print("Computing PCA variance landscape...")
max_pca = min(W_dec.shape[0], W_dec.shape[1])
# Fit up to 512 components, capped at min(n_samples, n_features).
pca_landscape = PCA(n_components=min(max_pca, 512), random_state=RANDOM_SEED)
pca_landscape.fit(W_dec)
cumvar_full = np.cumsum(pca_landscape.explained_variance_ratio_)

print(f"Variance landscape ({len(cumvar_full)} components):")
for thresh in (0.50, 0.80, 0.90, 0.95, 0.99):
    # Smallest component count whose cumulative variance reaches thresh.
    n_dims = np.searchsorted(cumvar_full, thresh) + 1
    if n_dims > len(cumvar_full):
        print(f"  {thresh:.0%}: >{len(cumvar_full)} dims (max: {cumvar_full[-1]:.1%})")
        continue
    print(f"  {thresh:.0%}: {n_dims} dims")

# --- Run sweep ---
# For every PCA setting: fit a SMART hierarchy, report cluster counts per
# level, interpret each L0 cluster with the differential logit lens, drill
# into the N_DRILL largest L0 clusters, and record everything in all_results.
# NOTE(review): PARAM_OVERRIDES from the configuration cell is not forwarded
# to create_smart_model here — confirm whether that is intended.
all_results = {}

# Project W_dec onto the fitted components once; each PCA setting below is a
# leading-column slice of this matrix.
if any(d is not None for d in PCA_DIMS_SWEEP):
    W_dec_projected = pca_landscape.transform(W_dec)
else:
    W_dec_projected = None

# Child (L1) statistics only exist when the hierarchy has a second level.
has_children = len(RHO_VALUES) > 1

for pca_dims in PCA_DIMS_SWEEP:
    print("\n" + "=" * 70)
    if pca_dims is None:
        label = f"None (raw {W_dec.shape[1]}d)"
        X_sweep = W_dec.copy()
        var_explained = 1.0
    elif not 1 <= pca_dims <= len(cumvar_full):
        # More dims requested than components were fitted (e.g. a small
        # sample): skip this setting instead of failing with an IndexError.
        print(f"PCA = {pca_dims}d — SKIPPED (only {len(cumvar_full)} components were fitted)")
        print("=" * 70)
        continue
    else:
        var_explained = float(cumvar_full[pca_dims - 1])
        label = f"{pca_dims}d ({var_explained:.1%} var)"
        # Copy so nothing downstream can modify the shared projection.
        X_sweep = W_dec_projected[:, :pca_dims].copy()
    print(f"PCA = {label}")
    print("=" * 70)

    dim = X_sweep.shape[1]
    t0 = time.time()

    # Create and fit SMART
    smart_sweep = create_smart_model(MODULE_NAME, dim=dim, rho_values=RHO_VALUES)
    X_prepared = smart_sweep.prepare_data(X_sweep)
    smart_sweep.fit(X_prepared, verbose=True)

    hierarchy_sweep = extract_hierarchy_labels(smart_sweep)
    fit_time = time.time() - t0

    # Cluster counts: one hierarchy column per vigilance value.
    cluster_counts = [
        len(np.unique(hierarchy_sweep[:, level])) for level in range(len(RHO_VALUES))
    ]
    counts_str = ", ".join(f"L{level}={n}" for level, n in enumerate(cluster_counts))
    print(f"  Clusters: {counts_str}  ({fit_time:.1f}s)")

    # L0 taxonomy
    l0_ids = sorted(np.unique(hierarchy_sweep[:, 0]))
    l0_results = {}

    print(f"\n  L0 CLUSTERS ({len(l0_ids)}):")
    for c in l0_ids:
        mask = hierarchy_sweep[:, 0] == c
        n_feat = int(mask.sum())
        # idx maps rows of W_dec to original feature indices in W_dec_full.
        feature_indices = idx[mask]
        n_l1 = len(np.unique(hierarchy_sweep[mask, 1])) if has_children else 0
        tokens = _get_tokens(feature_indices, top_k=TOP_K)
        l0_results[int(c)] = {"size": n_feat, "n_l1": n_l1, "tokens": tokens}
        print(f"    L0:{c} ({n_feat} feat, {n_l1} L1): {tokens}")

    # Drill into top N L0 clusters by size. Sizes and tokens were already
    # computed above, so they are reused; sorted() is stable, so ties keep
    # ascending cluster-id order.
    l0_by_size = sorted(l0_ids, key=lambda c: -l0_results[int(c)]["size"])

    for c in l0_by_size[:N_DRILL]:
        l0_summary = l0_results[int(c)]
        print(f"\n  DRILL L0:{c} ({l0_summary['size']} feat): {l0_summary['tokens']}")
        if not has_children:
            continue

        mask = hierarchy_sweep[:, 0] == c
        child_ids, child_counts = np.unique(hierarchy_sweep[mask, 1], return_counts=True)
        order = np.argsort(-child_counts)  # largest child first
        max_show = min(10, len(child_ids))
        print(f"  Top {max_show}/{len(child_ids)} L1 children:")

        for ci in order[:max_show]:
            cid = child_ids[ci]
            child_mask = mask & (hierarchy_sweep[:, 1] == cid)
            child_features = idx[child_mask]
            child_tokens = _get_tokens(child_features, top_k=TOP_K)
            print(f"    L1:{cid} ({child_counts[ci]} feat): {child_tokens}")

    # Store results
    pca_key = str(pca_dims) if pca_dims is not None else "None"
    all_results[pca_key] = {
        "pca_dims": pca_dims,
        "var_explained": round(float(var_explained), 4),
        "cluster_counts": cluster_counts,
        "fit_time": round(fit_time, 1),
        "l0_clusters": l0_results,
    }

# --- Summary comparison ---
# One table row per PCA setting: variance retained, cluster counts at the
# three hierarchy levels, fit time, and the top tokens of the largest L0
# cluster together with the share of features it absorbs.
rule = "=" * 70
print("\n\n" + rule)
print("SWEEP SUMMARY")
print(rule)
header = f"{'PCA':<10} {'Var%':<8} {'L0':<5} {'L1':<5} {'L2':<6} {'Time':<7} {'Largest L0 cluster tokens'}"
print(header)
print("-" * len(header))
for pca_key, res in all_results.items():
    cc = res["cluster_counts"]
    l0_clusters = list(res["l0_clusters"].values())
    largest = max(l0_clusters, key=lambda cluster: cluster["size"])
    total_features = sum(cluster["size"] for cluster in l0_clusters)
    pct = largest["size"] / total_features
    tok_str = ", ".join(largest["tokens"][:5])
    print(f"{pca_key:<10} {res['var_explained']:.1%}   {cc[0]:<5} {cc[1]:<5} {cc[2]:<6} {res['fit_time']:.1f}s   [{pct:.0%}] {tok_str}")

print("\n[%] = fraction of all features in the largest L0 cluster (lower = less catch-all)")

# Persist the full results. default=str stringifies any value the JSON
# encoder cannot serialise natively.
with open("pca_sweep_results.json", "w") as out_file:
    json.dump(all_results, out_file, indent=2, default=str)
print("Results saved to pca_sweep_results.json")
print("\nDone! Review the L0 cluster tokens above to identify which PCA setting")
print("produces the most semantically distinct and interpretable clusters.")

In [None]:
# (Skipped — flat clustering now handled by PCA sweep in Cell 8)
pass

In [None]:
# (Skipped — UMAP visualization now handled by PCA sweep in Cell 8)
pass

In [None]:
# (Skipped — cluster size distribution now handled by PCA sweep in Cell 8)
pass

In [None]:
# (Skipped — SMART hierarchy now handled by PCA sweep in Cell 8)
pass

In [None]:
# (Skipped — interpretation function now defined in PCA sweep Cell 8)
pass

In [None]:
# (Skipped — results JSON now saved by PCA sweep in Cell 8)
pass

In [None]:
# (Skipped — L0 taxonomy now displayed by PCA sweep in Cell 8)
pass

In [None]:
# (Skipped — explore_cluster now handled by PCA sweep drill-down in Cell 8)
pass

In [None]:
# (Skipped — vigilance sweep subsumed by PCA sweep in Cell 8)
pass

## Phase 2: Checkpoint Comparison (Stub)

**Goal:** Track how SAE feature categories evolve across Pythia-70M training checkpoints.

**Approach:**
1. Load Pythia-70M at multiple checkpoints (e.g., steps 1000, 10000, 50000, 143000)
2. Extract SAE decoder weights at each checkpoint
3. Use ART's `partial_fit()` to incrementally update the taxonomy
4. Track: resonant features (stable), new categories (plastic), dormant categories (lost)

This leverages ART's core strength — stability-plasticity balance — which SAEs lack entirely (they must retrain from scratch).

*Not yet implemented — this section is a placeholder for the planned Phase 2 work.*