# ART-Organized SAE Feature Taxonomy

Cluster pre-trained SAE decoder weight vectors from **Pythia** models using Adaptive Resonance Theory (ART) modules to build hierarchical feature taxonomies.

**Pipeline:** SAE W_dec → junk-feature filter (logit lens) → optional PCA → ART clustering → hierarchical taxonomy via SMART

**Key idea:** ART's vigilance parameter provides a single, principled knob for controlling cluster granularity — something SAEs lack. SMART builds a hierarchy by stacking multiple vigilance levels.

**Supported models:** Pythia-70M (512-dim, 32k features), Pythia-410M (1024-dim, 65k features)

In [None]:
# --- Setup ---
# Colab bootstrap: make the working directory the repo checkout so the local
# modules (art_utils, gpu_fuzzy_art) are importable, then install dependencies.
# NOTE: the `!` lines are IPython shell magics — this cell only runs inside
# Jupyter/Colab, not as a plain Python script.
import sys, os

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    # Clone or pull the repo, then cd into it
    if os.path.exists("art_utils.py"):
        # Already inside the repo (cell re-run)
        !git pull
    elif os.path.exists("ART-SAE-taxonomy"):
        # Repo already cloned; cwd is the directory that contains the clone
        os.chdir("ART-SAE-taxonomy")
        !git pull
    else:
        !git clone https://github.com/syre-ai/ART-SAE-taxonomy.git
        os.chdir("ART-SAE-taxonomy")
    # artlib is pinned to 0.1.7; the other packages install their latest versions
    !pip install -q artlib==0.1.7 eai-sparsify transformers torch umap-learn plotly tqdm pandas

import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import pandas as pd
from sklearn.decomposition import PCA
from tqdm.auto import tqdm

# Force-reload local modules so git pull changes take effect
# (gpu_fuzzy_art is reloaded before art_utils — presumably art_utils imports
# it, so the reloaded art_utils picks up the fresh version; confirm)
import importlib
import gpu_fuzzy_art, art_utils
importlib.reload(gpu_fuzzy_art)
importlib.reload(art_utils)

# Bind names AFTER the reload so they refer to the freshly loaded functions
from art_utils import (
    create_art_module,
    create_smart_model,
    extract_hierarchy_labels,
    list_available_modules,
)

print("Setup complete.")

In [None]:
# ============================================================
# CONFIGURATION — edit this cell to change experiments
# ============================================================

# --- Model ---
# LM_MODEL and SAE_MODEL must be the same Pythia size: SAE decoder rows are later
# pushed through this LM's final layer norm and unembedding, so hidden dims must match.
LM_MODEL = "EleutherAI/pythia-410m"     # LM for logit lens interpretation
SAE_MODEL = "EleutherAI/sae-pythia-410m-65k"  # SAE decoder weights to cluster
LAYER_IDX = 12             # Middle layer (410M has 24 layers, 70M has 6)
# NOTE: change LAYER_IDX when switching to 70M — 12 exceeds its layer count.

# --- Clustering ---
MODULE_NAME = "GPUFuzzyART"  # GPU-accelerated FuzzyART (see list_available_modules())
PARAM_OVERRIDES = {}       # e.g. {"rho": 0.8}; the vigilance sweep sets rho on top of these
PCA_DIMS = None            # Set to int (e.g. 50, 100, 200) to reduce dims; None = use raw dims
QUICK_TEST = True          # Subsample to 1000 (non-junk) features for fast iteration
RANDOM_SEED = 42           # seeds feature subsampling, PCA and UMAP

In [None]:
# --- Load SAE decoder weights ---
from sparsify import Sae

saes = Sae.load_many(SAE_MODEL)

# Detect key format: "layers.N" (70M) vs "layers.N.mlp" (410M)
candidate_keys = [f"layers.{LAYER_IDX}", f"layers.{LAYER_IDX}.mlp"]
layer_key = next((key for key in candidate_keys if key in saes), None)
if layer_key is None:
    # Fail with an actionable message instead of a bare KeyError on the fallback key
    raise KeyError(
        f"No SAE for layer {LAYER_IDX} in {SAE_MODEL}; tried {candidate_keys}. "
        f"Available keys: {sorted(saes.keys())}"
    )

# .float().cpu() makes .numpy() safe even if the weights were loaded in
# half precision or on a GPU; it is a no-op for float32 CPU tensors.
W_dec_raw = saes[layer_key].W_dec.detach().float().cpu().numpy()
print(f"Loaded W_dec from {SAE_MODEL}")
print(f"  Layer: {layer_key}, shape: {W_dec_raw.shape} (features × hidden_dim)")
print(f"  Available keys: {sorted(saes.keys())}")

In [None]:
# --- Filter junk features + subsample ---
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load LM (reused later by the cluster-interpretation helpers)
tokenizer = AutoTokenizer.from_pretrained(LM_MODEL)
lm = AutoModelForCausalLM.from_pretrained(LM_MODEL)
lm.eval()
lm.float()  # fp32, matching the float32 W_dec_full tensor multiplied against it below

# Pythia models have UNTIED embeddings — use embed_out + final_layer_norm
unembed_weights = lm.embed_out.weight.detach()  # (vocab_size, hidden_dim)
final_ln = lm.gpt_neox.final_layer_norm
print(f"LM: {LM_MODEL}, vocab={unembed_weights.shape[0]}, hidden={unembed_weights.shape[1]}")

# --- Batch logit lens to classify all features (chunked to avoid OOM) ---
# Each SAE decoder row is treated as a direction in the LM hidden space: it goes
# through the final layer norm, is projected onto the vocabulary, and the 5
# highest-scoring token ids are kept. Rows of top5_ids follow the ORIGINAL
# feature order of W_dec_raw.
W_dec_full = torch.from_numpy(W_dec_raw).float()
n_features = W_dec_full.shape[0]
top5_ids = torch.zeros(n_features, 5, dtype=torch.long)

# Each chunk materialises a (LOGIT_BATCH, vocab_size) float32 `logits` tensor;
# lower this if memory is tight.
LOGIT_BATCH = 4096
with torch.no_grad():
    for start in range(0, n_features, LOGIT_BATCH):
        end = min(start + LOGIT_BATCH, n_features)
        normed = final_ln(W_dec_full[start:end])
        logits = normed @ unembed_weights.T
        top5_ids[start:end] = logits.topk(5, dim=1).indices
        del normed, logits  # free memory between chunks
print(f"Logit lens: computed top-5 tokens for {n_features} features")

# --- Junk token classifier (minimal: only truly uninformative tokens) ---
def is_junk_token(token_id):
    """Return True when ``token_id`` decodes to text with no usable content.

    Two cases are flagged:
      * the decoded text contains U+FFFD, the replacement character the
        tokenizer emits for a byte-fallback piece it cannot render alone;
      * the decoded text contains ``<|``, the prefix of special tokens such
        as ``<|endoftext|>``.
    """
    text = tokenizer.decode([token_id])
    has_broken_bytes = '\ufffd' in text
    has_special_marker = '<|' in text
    return has_broken_bytes or has_special_marker

# Classify each feature: junk if >=3 of its top-5 tokens are junk.
# Each distinct token id is decoded only once; the per-id verdicts are then
# gathered for all (n_features × 5) slots in one vectorized lookup instead of
# calling the tokenizer five times per feature in a Python loop.
lookup_size = int(top5_ids.max().item()) + 1 if top5_ids.numel() > 0 else 0
junk_lookup = torch.zeros(lookup_size, dtype=torch.bool)
for tid in torch.unique(top5_ids).tolist():
    junk_lookup[tid] = is_junk_token(tid)
junk_counts = junk_lookup[top5_ids].sum(dim=1)  # int64: number of junk tokens (0-5) per feature

is_junk_feature = junk_counts >= 3  # >=60% junk
keep_mask = ~is_junk_feature

n_total = W_dec_raw.shape[0]
n_kept = keep_mask.sum().item()
n_filtered = n_total - n_kept
junk_frac = n_filtered / n_total if n_total else 0.0
print(f"Junk filter: {n_total} total → {n_kept} kept, {n_filtered} filtered ({junk_frac:.1%} junk)")

# Distribution of junk counts
for j in range(6):
    nj = (junk_counts == j).sum().item()
    print(f"  {j}/5 junk tokens: {nj} features")

# --- Apply filter, then subsample/shuffle ---
# kept_indices[i] is the original feature index of row i of W_dec_filtered.
kept_indices = torch.where(keep_mask)[0].numpy()
W_dec_filtered = W_dec_raw[kept_indices]

rng = np.random.default_rng(RANDOM_SEED)
if QUICK_TEST:
    n_sample = min(1000, W_dec_filtered.shape[0])
    pick = rng.choice(W_dec_filtered.shape[0], size=n_sample, replace=False)
    mode_msg = (f"Quick test mode: subsampled to {n_sample} features "
                f"(from {W_dec_filtered.shape[0]} non-junk)")
else:
    pick = rng.permutation(W_dec_filtered.shape[0])
    mode_msg = f"Shuffled {W_dec_filtered.shape[0]} non-junk features"

# Map the picked filtered rows back to original feature indices
idx_original = kept_indices[pick]
W_dec = W_dec_filtered[pick]
print(f"\n{mode_msg}")

# idx[i] = original SAE feature index of W_dec row i (used for logit lens
# lookups into W_dec_full / top5_ids in later cells)
idx = idx_original

In [None]:
# --- Preprocessing: optional PCA (no StandardScaler — W_dec rows are expected to be unit-normalized) ---

# Diagnostic: check that W_dec rows are unit-normalized (mean ≈ 1, std ≈ 0)
norms = np.linalg.norm(W_dec, axis=1)
print(f"W_dec L2 norms — min: {norms.min():.4f}, max: {norms.max():.4f}, "
      f"mean: {norms.mean():.4f}, std: {norms.std():.6f}")

if PCA_DIMS is not None and PCA_DIMS < W_dec.shape[1]:
    # PCA centers internally (subtracts mean), which is fine for unit-norm data
    pca = PCA(n_components=PCA_DIMS, random_state=RANDOM_SEED)
    X_reduced = pca.fit_transform(W_dec)
    explained = pca.explained_variance_ratio_.sum()
    print(f"PCA: {W_dec.shape[1]} → {PCA_DIMS} dims ({explained:.1%} variance explained)")

    # Variance diagnostic: dims needed for key thresholds.
    # Only the fitted components are known, so a threshold above their total
    # explained variance cannot be resolved; searchsorted would then return
    # len(cumvar), and reporting len(cumvar)+1 would wrongly suggest that many
    # dims suffice. Report such thresholds as ">len(cumvar)" instead.
    cumvar = np.cumsum(pca.explained_variance_ratio_)
    for thresh in [0.80, 0.90, 0.95, 0.99]:
        n_dims = int(np.searchsorted(cumvar, thresh)) + 1
        if n_dims <= len(cumvar):
            print(f"  Dims for {thresh:.0%} variance: {n_dims}")
        else:
            print(f"  Dims for {thresh:.0%} variance: >{len(cumvar)} "
                  f"(not reached by the fitted components)")
else:
    X_reduced = W_dec.copy()
    print(f"No PCA applied. Using {X_reduced.shape[1]} dims.")

data_dim = X_reduced.shape[1]
print(f"Final data shape: {X_reduced.shape}")

In [None]:
# --- Create ART module ---
# Show which ART modules art_utils can build, then construct the chosen one
# for the (possibly PCA-reduced) data dimensionality.
list_available_modules()
print()

model = create_art_module(MODULE_NAME, dim=data_dim, overrides=PARAM_OVERRIDES)
print(f"\nCreated {MODULE_NAME} for {data_dim}-dim data")
# Not every module exposes `params`; emit an empty line in that case.
if hasattr(model, 'params'):
    print(f"Parameters: {model.params}")
else:
    print("")

In [None]:
# --- Flat clustering ---
# Let the module apply its own input preparation, then fit a single flat clustering.
X_prepared = model.prepare_data(X_reduced)
print(f"Data after prepare_data: shape {X_prepared.shape}")

model.fit(X_prepared, verbose=True)

labels = model.labels_
n_clusters = model.n_clusters
print(f"\nClusters found: {n_clusters}\nSamples: {len(labels)}")

# Members per cluster id; `counts` is reused by the size-distribution plots.
unique, counts = np.unique(labels, return_counts=True)
size_min, size_max = counts.min(), counts.max()
print(
    f"Cluster sizes — min: {size_min}, max: {size_max}, "
    f"median: {np.median(counts):.0f}, mean: {counts.mean():.1f}"
)

In [None]:
# --- UMAP visualization of flat clusters ---
import umap

# 2-D embedding of the clustered data (X_reduced, i.e. before prepare_data)
reducer = umap.UMAP(n_neighbors=15, n_components=2, random_state=RANDOM_SEED)
embedding = reducer.fit_transform(X_reduced)
umap_x, umap_y = embedding[:, 0], embedding[:, 1]

fig, ax = plt.subplots(figsize=(10, 8))
# NOTE: "tab20" has 20 colours, so cluster colours repeat beyond 20 clusters.
scatter = ax.scatter(umap_x, umap_y, c=labels, cmap="tab20", s=5, alpha=0.6)
ax.set_title(f"{MODULE_NAME} Clusters (n={n_clusters}) — UMAP projection")
ax.set_xlabel("UMAP 1")
ax.set_ylabel("UMAP 2")
plt.colorbar(scatter, ax=ax, label="Cluster ID")
plt.tight_layout()
plt.show()

In [None]:
# --- Cluster size distribution ---
# Left: sizes ranked largest-first; right: histogram of sizes.
fig, (ax_rank, ax_hist) = plt.subplots(1, 2, figsize=(12, 4))

ranked_sizes = sorted(counts, reverse=True)
ax_rank.bar(range(len(counts)), ranked_sizes, color="steelblue")
ax_rank.set_xlabel("Cluster rank")
ax_rank.set_ylabel("Size")
ax_rank.set_title("Cluster sizes (sorted)")

ax_hist.hist(counts, bins=30, color="steelblue", edgecolor="white")
ax_hist.set_xlabel("Cluster size")
ax_hist.set_ylabel("Count")
ax_hist.set_title("Cluster size histogram")

plt.tight_layout()
plt.show()

In [None]:
# --- Hierarchical taxonomy with SMART ---
# One vigilance value per level; level 0 (lowest rho) gives the broadest themes.
rho_values = [0.1, 0.4, 0.7]
smart = create_smart_model(
    MODULE_NAME, dim=data_dim, rho_values=rho_values, overrides=PARAM_OVERRIDES
)

X_smart = smart.prepare_data(X_reduced)
smart.fit(X_smart, verbose=True)

# hierarchy[:, level] holds each sample's cluster id at that level
hierarchy = extract_hierarchy_labels(smart)
print(f"Hierarchy shape: {hierarchy.shape}  (samples × levels)")

for level, rho in enumerate(rho_values):
    level_cluster_count = np.unique(hierarchy[:, level]).size
    print(f"  Level {level} (rho={rho}): {level_cluster_count} clusters")

In [None]:
# --- Cluster Analysis: specificity, confidence, quality, interpretation ---
# T1.1: Specificity (bounding box or radius-based)
# T1.2: Targeted logit lens (tight/distinctive dims)
# T1.3: Confidence scores (when available)
# T2.1: Token voting
# T2.2: Quality metrics (cohesion, silhouette)

from sklearn.metrics import silhouette_score
from sklearn.metrics.pairwise import cosine_similarity
from collections import Counter

# ---- Detect module type ----
# Modules exposing get_bounding_boxes() take the GPUFuzzyART code paths below
# (bounding boxes, predict_with_confidence). Only level 0's module is inspected,
# so this assumes every SMART level uses the same module class.
_is_gpu_fuzzy = hasattr(smart.modules[0], 'get_bounding_boxes')

# ---- Differential logit lens ----
# Baseline subtracted from every cluster mean. NOTE: W_dec_full holds ALL SAE
# features (including junk-filtered and non-sampled ones), so this baseline is
# not restricted to the clustered subset.
global_mean_dir = W_dec_full.mean(dim=0)

def get_cluster_tokens(feature_indices, top_k=10):
    """Differential logit lens: top tokens distinctive to this cluster.

    The mean decoder direction of the cluster's features minus
    ``global_mean_dir`` is passed through the LM's final layer norm and the
    unembedding; the ``top_k`` highest-scoring token ids are decoded.

    Args:
        feature_indices: original SAE feature indices (rows of ``W_dec_full``).
        top_k: number of tokens to return.

    Returns:
        list[str]: decoded tokens, highest logit first.
    """
    cluster_mean = W_dec_full[feature_indices].mean(dim=0)
    offset = cluster_mean - global_mean_dir
    with torch.no_grad():
        scores = (final_ln(offset.unsqueeze(0)) @ unembed_weights.T).squeeze(0)
    best_ids = scores.topk(top_k).indices.tolist()
    return [tokenizer.decode([token_id]) for token_id in best_ids]

# ---- Specificity per level ----
# specificities[(level, cluster_id)] -> float; larger means a tighter cluster.
specificities = {}
bbox_widths = {}  # Only populated for GPUFuzzyART

for lv, module in enumerate(smart.modules):
    if _is_gpu_fuzzy:
        # Specificity = 1 - mean bounding-box width over the module's input dims
        # (presumably the dims of X_smart, i.e. PCA components when PCA_DIMS is set).
        lower, upper, width = module.get_bounding_boxes()
        spec = 1.0 - width.mean(dim=1)
        for cid in sorted(np.unique(hierarchy[:, lv])):
            # assumes cluster ids index rows of the bounding-box tensors directly
            specificities[(lv, cid)] = spec[cid].item()
            bbox_widths[(lv, cid)] = width[cid]
    else:
        # HypersphereART / other artlib modules: specificity from radius
        # NOTE(review): treats the last entry of each weight vector as the
        # radius and r_hat as its upper bound — matches HypersphereART; confirm
        # for other modules, otherwise the value is not meaningful.
        r_hat = module.params.get('r_hat', 1.0)
        for cid in sorted(np.unique(hierarchy[:, lv])):
            if hasattr(module, 'W') and cid < len(module.W):
                radius = module.W[cid][-1]
                specificities[(lv, cid)] = max(0.0, 1.0 - radius / r_hat)
            else:
                # no weight vector available for this id — report 0 specificity
                specificities[(lv, cid)] = 0.0

print("Specificity:")
for lv in range(len(rho_values)):
    specs = [specificities[(lv, c)] for c in sorted(np.unique(hierarchy[:, lv]))]
    print(f"  L{lv}: min={min(specs):.3f}, max={max(specs):.3f}, mean={np.mean(specs):.3f}")

# ---- Confidence scores per level ----
# confidence[lv] = (labels, match, margin); match/margin are None for non-GPU modules.
confidence = {}
for lv, module in enumerate(smart.modules):
    if _is_gpu_fuzzy:
        # re-scores the prepared SMART input against this level's clusters
        labels_lv, match_lv, margin_lv = module.predict_with_confidence(X_smart)
        confidence[lv] = (labels_lv, match_lv, margin_lv)
        print(f"Confidence L{lv}: M=[{match_lv.min():.3f}, {match_lv.max():.3f}], "
              f"mean T-margin={margin_lv.mean():.4f}")
    else:
        labels_lv = module.labels_
        confidence[lv] = (labels_lv, None, None)
        print(f"Confidence L{lv}: {module.n_clusters} clusters (match/margin N/A for artlib modules)")

# ---- Cohesion + silhouette ----
# cohesions[(level, cluster_id)] = mean pairwise cosine similarity between
# distinct members (1.0 for singleton clusters). Clusters larger than 500 are
# subsampled with a generator seeded by RANDOM_SEED so results are reproducible.
cohesion_rng = np.random.default_rng(RANDOM_SEED)
cohesions = {}
for lv in range(len(rho_values)):
    for c in sorted(np.unique(hierarchy[:, lv])):
        mask = hierarchy[:, lv] == c
        n = mask.sum()
        if n < 2:
            cohesions[(lv, c)] = 1.0
            continue
        cluster_data = X_reduced[mask]
        if n > 500:
            sub = cohesion_rng.choice(n, 500, replace=False)
            cluster_data = cluster_data[sub]
        sim = cosine_similarity(cluster_data)
        k = len(cluster_data)
        # Average of the off-diagonal entries; subtracting the actual trace
        # (rather than k) stays correct for zero rows, whose self-similarity is 0.
        cohesions[(lv, c)] = float((sim.sum() - np.trace(sim)) / (k * (k - 1)))

# Silhouette is only defined for 2 <= n_labels <= n_samples - 1.
sil_scores = {}
n_points = len(X_reduced)
for lv in range(len(rho_values)):
    n_unique = len(np.unique(hierarchy[:, lv]))
    if not 1 < n_unique < n_points:
        print(f"Silhouette L{lv} (cosine): skipped ({n_unique} clusters for {n_points} samples)")
        continue
    try:
        sil = silhouette_score(
            X_reduced, hierarchy[:, lv], metric='cosine',
            sample_size=min(5000, n_points), random_state=RANDOM_SEED,
        )
    except ValueError as err:
        # The random subset can still violate the label-count constraint
        # (e.g. many singleton clusters); report and move on.
        print(f"Silhouette L{lv} (cosine): skipped ({err})")
        continue
    sil_scores[lv] = sil
    print(f"Silhouette L{lv} (cosine): {sil:.3f}")

# ---- Token voting (unfiltered — all tokens contribute) ----
def get_cluster_token_votes(feature_indices, top_k=10):
    """Tally the per-feature top-5 logit-lens tokens across a cluster.

    Every member feature casts one vote for each of its five token ids in
    ``top5_ids`` (junk tokens included).

    Returns:
        Up to ``top_k`` ``(decoded token with whitespace stripped, vote count)``
        pairs, most-voted first.
    """
    ballots = Counter(
        token_id
        for feature in feature_indices
        for token_id in top5_ids[feature].tolist()
    )
    return [
        (tokenizer.decode([token_id]).strip(), n_votes)
        for token_id, n_votes in ballots.most_common(top_k)
    ]

# ---- Targeted logit lens via selective dimensions ----
def get_cluster_tokens_targeted(feature_indices, level, cluster_id, top_k=10, n_tight=100):
    """Project the cluster's mean direction through its most selective dimensions.

    The mean SAE decoder direction of the cluster is zeroed everywhere except
    ``n_tight`` selected hidden dimensions, then pushed through the LM's final
    layer norm and unembedding.

    Dimension selection:
      * If a bounding box is stored for this cluster AND its width vector has
        one entry per LM hidden dim: the dims where the box is narrowest.
      * Otherwise (artlib modules, or boxes in a different space such as PCA
        components): the dims where the cluster mean deviates most from
        ``global_mean_dir``.

    Args:
        feature_indices: original SAE feature indices of the cluster members.
        level: hierarchy level of the cluster (key into ``bbox_widths``).
        cluster_id: cluster id at that level.
        top_k: number of tokens to return.
        n_tight: number of hidden dims to keep.

    Returns:
        list[str]: decoded tokens, highest logit first.
    """
    mean_dir = W_dec_full[feature_indices].mean(dim=0)

    # Box widths are per ART-input dim. They only line up with LM hidden dims
    # when the input was not PCA-reduced (or otherwise re-shaped); with PCA,
    # argsort(width) would yield component ids that are meaningless as
    # hidden-dim indices, so fall back to the deviation criterion.
    width = bbox_widths.get((level, cluster_id))
    if width is not None and width.shape[-1] == mean_dir.shape[0]:
        tight_dims = width.argsort()[:n_tight]
    else:
        deviation = (mean_dir - global_mean_dir).abs()
        tight_dims = deviation.argsort(descending=True)[:n_tight]

    masked_dir = torch.zeros_like(mean_dir)
    masked_dir[tight_dims] = mean_dir[tight_dims]

    with torch.no_grad():
        normed = final_ln(masked_dir.unsqueeze(0))
        logits = (normed @ unembed_weights.T).squeeze(0)
    topk = logits.topk(top_k)
    return [tokenizer.decode([tid]) for tid in topk.indices.tolist()]

# ---- Scatter plot: size vs specificity + size vs cohesion ----
l0_ids_plot = sorted(np.unique(hierarchy[:, 0]))
sizes = [np.sum(hierarchy[:, 0] == c) for c in l0_ids_plot]
specs_plot = [specificities[(0, c)] for c in l0_ids_plot]
cohs_plot = [cohesions[(0, c)] for c in l0_ids_plot]

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
# (axis, y-values, marker colour, y-label, title) for each panel, left to right
panels = [
    (axes[0], specs_plot, 'steelblue', 'Specificity', 'L0: Size vs Specificity'),
    (axes[1], cohs_plot, 'coral', 'Cohesion (mean cosine sim)', 'L0: Size vs Cohesion'),
]
for panel_ax, y_vals, colour, y_label, title in panels:
    panel_ax.scatter(sizes, y_vals, c=colour, s=60, edgecolors='k', linewidths=0.5)
    for cid, size, y in zip(l0_ids_plot, sizes, y_vals):
        panel_ax.annotate(f'L0:{cid}', (size, y), fontsize=8, ha='left', va='bottom')
    panel_ax.set_xlabel('Cluster size')
    panel_ax.set_ylabel(y_label)
    panel_ax.set_title(title)
    panel_ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("\nAnalysis setup complete.")

In [None]:
# --- L0 Taxonomy with enriched metrics ---
# For every level-0 cluster: size, child count, quality metrics, and three
# token-level interpretations (differential, voting, targeted logit lens).
n_levels = hierarchy.shape[1]
l0_ids = sorted(np.unique(hierarchy[:, 0]))
_, l0_match, l0_margin = confidence[0]

rule = "=" * 85
header = [rule, f"LEVEL 0 — {len(l0_ids)} broad themes (rho={rho_values[0]})"]
if 0 in sil_scores:
    header.append(f"Global silhouette (cosine): {sil_scores[0]:.3f}")
header.append(rule)
print("\n".join(header))

for c in l0_ids:
    mask = hierarchy[:, 0] == c
    n_feat = mask.sum()
    feature_indices = idx[mask]

    # Number of distinct L1 sub-clusters under this theme, if a finer level exists
    child_info = (
        f", {len(np.unique(hierarchy[mask, 1]))} L1 children" if n_levels > 1 else ""
    )

    spec_txt = f"Spec: {specificities[(0, c)]:.3f}"
    coh_txt = f"Cohesion: {cohesions[(0, c)]:.3f}"

    diff_tokens = get_cluster_tokens(feature_indices, top_k=15)
    votes = get_cluster_token_votes(feature_indices, top_k=10)
    targeted = get_cluster_tokens_targeted(feature_indices, level=0, cluster_id=c, top_k=10)

    print(f"\nL0:{c} ({n_feat} features{child_info}):")
    metric_parts = [spec_txt, coh_txt]
    if l0_match is not None:
        m_vals, t_vals = l0_match[mask], l0_margin[mask]
        metric_parts += [
            f"Match: {m_vals.mean():.3f} (min {m_vals.min():.3f})",
            f"T-margin: {t_vals.mean():.4f}",
        ]
    vote_text = ', '.join(f'{t}({cnt})' for t, cnt in votes)
    print(f"  {'  |  '.join(metric_parts)}")
    print(f"  Differential: {diff_tokens}")
    print(f"  Votes:        {vote_text}")
    print(f"  Targeted:     {targeted}")

In [None]:
# --- Explore hierarchy: drill into any cluster's children ---

def explore_cluster(level, cluster_id, top_k=10, max_children=50):
    """Print metrics and interpretations for one cluster, then summarize its children.

    Args:
        level: hierarchy level (column of ``hierarchy``) the cluster lives at.
        cluster_id: cluster id at that level.
        top_k: number of tokens shown per interpretation method.
        max_children: at most this many children (largest first) are listed.

    Prints a message and returns early when the cluster is empty or when
    ``level`` is the finest level of the hierarchy.
    """
    mask = hierarchy[:, level] == cluster_id
    n = mask.sum()
    if n == 0:
        print(f"No samples in L{level} cluster {cluster_id}")
        return

    members = idx[mask]
    diff_tokens = get_cluster_tokens(members, top_k=top_k)
    votes = get_cluster_token_votes(members, top_k=top_k)
    targeted = get_cluster_tokens_targeted(
        members, level=level, cluster_id=cluster_id, top_k=top_k
    )

    spec = specificities.get((level, cluster_id), 0)
    coh = cohesions.get((level, cluster_id), 0)
    _, match_all, margin_all = confidence[level]

    print(f"L{level}:{cluster_id} ({n} features, rho={rho_values[level]}):")
    metrics = [f"Spec: {spec:.3f}", f"Cohesion: {coh:.3f}"]
    if match_all is not None:
        match_sel, margin_sel = match_all[mask], margin_all[mask]
        metrics += [
            f"Match: {match_sel.mean():.3f} (min {match_sel.min():.3f})",
            f"T-margin: {margin_sel.mean():.4f}",
        ]
    vote_text = ', '.join(f'{t}({cnt})' for t, cnt in votes)
    print(f"  {'  |  '.join(metrics)}")
    print(f"  Differential: {diff_tokens}")
    print(f"  Votes:        {vote_text}")
    print(f"  Targeted:     {targeted}")

    next_level = level + 1
    if next_level >= hierarchy.shape[1]:
        print("  (finest level — no children)")
        return

    child_ids, child_counts = np.unique(hierarchy[mask, next_level], return_counts=True)
    largest_first = np.argsort(-child_counts)
    print(f"\n  -> {len(child_ids)} children at L{next_level} (rho={rho_values[next_level]}):\n")

    for pos in largest_first[:max_children]:
        cid = child_ids[pos]
        child_members = idx[mask & (hierarchy[:, next_level] == cid)]
        child_diff = get_cluster_tokens(child_members, top_k=top_k)
        child_votes = get_cluster_token_votes(child_members, top_k=5)
        child_spec = specificities.get((next_level, cid), 0)
        child_coh = cohesions.get((next_level, cid), 0)
        child_vote_text = ', '.join(f'{t}({cnt})' for t, cnt in child_votes[:5])
        print(f"    L{next_level}:{cid} ({child_counts[pos]} feat, "
              f"spec={child_spec:.3f}, coh={child_coh:.3f}): {child_diff}")
        print(f"      Votes: {child_vote_text}")

    n_hidden = len(child_ids) - max_children
    if n_hidden > 0:
        print(f"    ... and {n_hidden} more")


# Example: explore Level 0, Cluster 0 — change these to explore different parts
# (valid ids at each level are the values in hierarchy[:, level]; the L0
# taxonomy listing prints the level-0 ones)
explore_cluster(level=0, cluster_id=0)

# To drill deeper, call again with a child:
# explore_cluster(level=1, cluster_id=<child_id_from_above>)

In [None]:
# --- Vigilance parameter sweep (optional/advanced) ---
# Sweep rho to see how cluster count changes with vigilance.
# Higher rho → more clusters (tighter match required, except BayesianART).

rho_sweep = np.linspace(0.1, 0.95, 10)


def _clusters_at_vigilance(rho):
    """Fit a fresh module with vigilance ``rho`` on X_reduced; return its cluster count."""
    sweep_module = create_art_module(
        MODULE_NAME, dim=data_dim, overrides=dict(PARAM_OVERRIDES, rho=rho)
    )
    sweep_module.fit(sweep_module.prepare_data(X_reduced))
    return sweep_module.n_clusters


cluster_counts = [
    _clusters_at_vigilance(rho) for rho in tqdm(rho_sweep, desc="Vigilance sweep")
]

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(rho_sweep, cluster_counts, "o-", color="steelblue")
ax.set_xlabel("Vigilance (rho)")
ax.set_ylabel("Number of clusters")
ax.set_title(f"Vigilance sweep — {MODULE_NAME}")
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("\n".join(
    f"  rho={rho:.2f} → {nc} clusters" for rho, nc in zip(rho_sweep, cluster_counts)
))

## Phase 2: Checkpoint Comparison (Stub)

**Goal:** Track how SAE feature categories evolve across Pythia-70M training checkpoints.

**Approach:**
1. Load Pythia-70M at multiple checkpoints (e.g., steps 1000, 10000, 50000, 143000)
2. Extract SAE decoder weights at each checkpoint (this requires an SAE trained on that checkpoint's activations)
3. Use ART's `partial_fit()` to incrementally update the taxonomy
4. Track: resonant features (stable), new categories (plastic), dormant categories (lost)

This leverages ART's core strength — stability-plasticity balance — which SAEs lack entirely (they must retrain from scratch).

*Implementation in Phase 2.*