# ART-Organized SAE Feature Taxonomy

Cluster pre-trained SAE decoder weight vectors from **Pythia-70M** using Adaptive Resonance Theory (ART) modules to build hierarchical feature taxonomies.

**Pipeline:** SAE W_dec (32K × 512) → junk-feature filter → optional PCA → ART clustering → hierarchical taxonomy via SMART

**Key idea:** ART's vigilance parameter provides a single, principled knob for controlling cluster granularity — something SAEs lack. SMART builds a hierarchy by stacking multiple vigilance levels.

In [1]:
# --- Setup ---
import sys, os

# Detect Google Colab by checking whether its package has already been imported.
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    # Clone or pull the repo, then cd into it.
    # Three cases, so that re-running this cell is safe:
    if os.path.exists("art_utils.py"):
        # 1) Already inside the repo (cell re-run): just update it.
        !git pull
    elif os.path.exists("ART-SAE-taxonomy"):
        # 2) Repo was cloned earlier but we are still in its parent directory.
        os.chdir("ART-SAE-taxonomy")
        !git pull
    else:
        # 3) First run: clone, then enter the repo so `art_utils` is importable.
        !git clone https://github.com/syre-ai/ART-SAE-taxonomy.git
        os.chdir("ART-SAE-taxonomy")
    # Install dependencies on every Colab run (artlib is pinned to 0.1.7).
    !pip install -q artlib==0.1.7 eai-sparsify transformers torch umap-learn plotly tqdm pandas

import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import pandas as pd
from sklearn.decomposition import PCA
from tqdm.auto import tqdm

from art_utils import (
    create_art_module,
    create_smart_model,
    extract_hierarchy_labels,
    list_available_modules,
)

print("Setup complete.")

Cloning into 'ART-SAE-taxonomy'...
remote: Enumerating objects: 40, done.
remote: Counting objects: 100% (40/40), done.
remote: Compressing objects: 100% (28/28), done.
remote: Total 40 (delta 23), reused 29 (delta 12), pack-reused 0 (from 0)
Receiving objects: 100% (40/40), 21.02 KiB | 4.20 MiB/s, done.
Resolving deltas: 100% (23/23), done.
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.2/6.2 MB 454.4 kB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 55.1/55.1 kB 2.9 MB/s eta 0:00:00
Setup complete.


In [2]:
# ============================================================
# CONFIGURATION — edit this cell to change experiments
# ============================================================

LAYER_IDX = 3              # Pythia-70M layer (0-5); selects the "layers.<LAYER_IDX>" SAE
MODULE_NAME = "GPUFuzzyART"  # ART module name passed to create_art_module / create_smart_model (see list_available_modules())
PARAM_OVERRIDES = {}       # Forwarded as `overrides=` to the ART factories, e.g. {"rho": 0.8}
PCA_DIMS = None            # Set to int (e.g. 50, 100, 200) to reduce dims; None = use raw 512-dim (PCA is skipped)
QUICK_TEST = True          # Subsample to at most 1000 non-junk features for fast iteration
RANDOM_SEED = 42           # Seeds the subsample/shuffle RNG, PCA and UMAP

In [None]:
# --- Load SAE decoder weights ---
from sparsify import Sae

# Load the SAEs for every hookpoint, then pick the one for the configured layer.
saes = Sae.load_many("EleutherAI/sae-pythia-70m-32k")
layer_key = f"layers.{LAYER_IDX}"
if layer_key not in saes:
    # Fail with the list of valid hookpoints instead of a bare KeyError.
    raise KeyError(
        f"{layer_key!r} not found in loaded SAEs; available: {sorted(saes)}"
    )
# .cpu() keeps the NumPy conversion valid even if the SAEs live on an accelerator.
W_dec_raw = saes[layer_key].W_dec.detach().cpu().numpy()  # (32768, 512)
print(f"Loaded W_dec for {layer_key}: shape {W_dec_raw.shape}")

In [None]:
# --- Filter junk features + subsample ---
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the LM once. `tokenizer`, `final_ln`, `unembed_weights` and `W_dec_full`
# defined here are reused later by the differential logit-lens cell.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
lm = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")
lm.eval()
lm.float()

# Pythia-70M has UNTIED embeddings — use embed_out + final_layer_norm
unembed_weights = lm.embed_out.weight.detach()  # (50304, 512)
final_ln = lm.gpt_neox.final_layer_norm

# --- Batch logit lens to classify all features ---
W_dec_full = torch.from_numpy(W_dec_raw).float()  # (32768, 512)

# Only the top-5 token ids per feature are needed, so process rows in chunks:
# a single (32768, 50304) float32 logits matrix would take ~6.6 GB of RAM.
# The layer norm and the matmul act on each row independently, so chunking
# does not change which tokens are selected.
LOGIT_LENS_CHUNK = 2048
top5_chunks = []
with torch.no_grad():
    for rows in W_dec_full.split(LOGIT_LENS_CHUNK):
        chunk_logits = final_ln(rows) @ unembed_weights.T   # (chunk, 50304)
        top5_chunks.append(chunk_logits.topk(5, dim=1).indices)
top5_ids = torch.cat(top5_chunks)                            # (32768, 5)

# --- Junk token classifier ---
# Lower-cased code/markup words that count as junk when they are the whole token.
_JUNK_TOKENS = {
    'dir', 'skip', 'color', 'display', 'hidden', 'font',
    'heading', 'end', 'max', 'left', 'right', 'top',
    'begin', 'center', 'title', 'view', 'path', 'track',
    'hash', 'bg', 'pad', 'stroke', 'figure', 'cite',
    'star', 'screen', 'paper',
}


def is_junk_token(token_id):
    """Return True if the decoded token is uninformative for interpretation.

    A token is junk when its decoded text contains the Unicode replacement
    character (byte-fallback), is empty or whitespace-only, contains ``<|``
    (special-token marker), has no alphabetic character at all, or matches
    (case-insensitively, after stripping) an entry of ``_JUNK_TOKENS``.
    """
    text = tokenizer.decode([token_id])
    core = text.strip()

    # Byte-fallback, blank, or special-token text.
    if '\ufffd' in text or not core or '<|' in text:
        return True

    # Pure punctuation / symbols / numbers.
    if not any(ch.isalpha() for ch in core):
        return True

    # Known code/markup tokens.
    return core.lower() in _JUNK_TOKENS

# Classify each feature: junk if >=3 of top-5 tokens are junk.
# There are n_features x 5 token slots but far fewer distinct token ids, so
# decode and classify each distinct id once, then map the verdicts back onto
# the (n_features, 5) grid instead of decoding slot by slot in a Python loop.
unique_ids, inverse = torch.unique(top5_ids, return_inverse=True)
unique_is_junk = torch.tensor(
    [is_junk_token(tid) for tid in unique_ids.tolist()], dtype=torch.bool
)
junk_counts = unique_is_junk[inverse].sum(dim=1)  # (n_features,), int64 in 0..5

is_junk_feature = junk_counts >= 3  # >=60% junk
keep_mask = ~is_junk_feature

n_total = W_dec_raw.shape[0]
n_kept = keep_mask.sum().item()
n_filtered = n_total - n_kept
print(f"Junk filter: {n_total} total → {n_kept} kept, {n_filtered} filtered ({n_filtered/n_total:.1%} junk)")

# Distribution of junk counts (0..5 junk tokens among the top 5)
for j in range(6):
    nj = (junk_counts == j).sum().item()
    print(f"  {j}/5 junk tokens: {nj} features")

# --- Apply filter, then subsample/shuffle ---
kept_indices = torch.where(keep_mask)[0].numpy()   # original ids of non-junk features
W_dec_filtered = W_dec_raw[kept_indices]
n_available = W_dec_filtered.shape[0]

# Pick row positions within the filtered matrix: a random subset in quick-test
# mode, otherwise a full random permutation.
rng = np.random.default_rng(RANDOM_SEED)
row_order = (
    rng.choice(n_available, size=min(1000, n_available), replace=False)
    if QUICK_TEST
    else rng.permutation(n_available)
)

# Map the chosen positions back to original feature indices and gather rows.
idx_original = kept_indices[row_order]
W_dec = W_dec_filtered[row_order]

if QUICK_TEST:
    print(f"\nQuick test mode: subsampled to {W_dec.shape[0]} features (from {W_dec_filtered.shape[0]} non-junk)")
else:
    print(f"\nShuffled {W_dec.shape[0]} non-junk features")

# idx holds the ORIGINAL feature indices of the rows in W_dec; the logit-lens
# and hierarchy-exploration cells use it to index W_dec_full.
idx = idx_original

In [None]:
# --- Preprocessing: optional PCA (no StandardScaler — W_dec rows are already unit-normalized) ---

# Diagnostic: confirm W_dec rows are unit-normalized
norms = np.linalg.norm(W_dec, axis=1)
print(f"W_dec L2 norms — min: {norms.min():.4f}, max: {norms.max():.4f}, "
      f"mean: {norms.mean():.4f}, std: {norms.std():.6f}")

if PCA_DIMS is not None and PCA_DIMS < W_dec.shape[1]:
    # PCA centers internally (subtracts mean), which is fine for unit-norm data
    pca = PCA(n_components=PCA_DIMS, random_state=RANDOM_SEED)
    X_reduced = pca.fit_transform(W_dec)
    explained = pca.explained_variance_ratio_.sum()
    print(f"PCA: {W_dec.shape[1]} → {PCA_DIMS} dims ({explained:.1%} variance explained)")

    # Variance diagnostic: dims needed for key thresholds.
    # cumvar only covers the PCA_DIMS retained components, so a threshold above
    # cumvar[-1] is NOT reachable here; report that instead of the misleading
    # "PCA_DIMS + 1" that searchsorted would otherwise produce.
    cumvar = np.cumsum(pca.explained_variance_ratio_)
    for thresh in [0.80, 0.90, 0.95, 0.99]:
        if cumvar[-1] < thresh:
            print(f"  Dims for {thresh:.0%} variance: >{len(cumvar)} "
                  f"(not reached with {PCA_DIMS} components)")
        else:
            n_dims = int(np.searchsorted(cumvar, thresh)) + 1
            print(f"  Dims for {thresh:.0%} variance: {n_dims}")
else:
    # No reduction requested (or PCA_DIMS is not smaller than the input dim).
    X_reduced = W_dec.copy()
    print(f"No PCA applied. Using {X_reduced.shape[1]} dims.")

data_dim = X_reduced.shape[1]
print(f"Final data shape: {X_reduced.shape}")

In [None]:
# --- Create ART module ---
# Show which ART modules can be selected via MODULE_NAME.
list_available_modules()
print()

# Build the configured module for the (possibly PCA-reduced) dimensionality.
model = create_art_module(MODULE_NAME, dim=data_dim, overrides=PARAM_OVERRIDES)
print(f"\nCreated {MODULE_NAME} for {data_dim}-dim data")

# Report hyper-parameters when the module exposes them; otherwise emit a blank line.
if hasattr(model, 'params'):
    print(f"Parameters: {model.params}")
else:
    print("")

In [None]:
# --- Flat clustering ---
# Let the module transform the data into the form it expects, then fit once.
X_prepared = model.prepare_data(X_reduced)
print(f"Data after prepare_data: shape {X_prepared.shape}")

model.fit(X_prepared, verbose=True)

# Per-sample cluster assignments and the number of clusters the module reports.
labels, n_clusters = model.labels_, model.n_clusters
print(f"\nClusters found: {n_clusters}")
print(f"Samples: {len(labels)}")

# Cluster size stats (counts is reused by the size-distribution plots)
unique, counts = np.unique(labels, return_counts=True)
print(f"Cluster sizes — min: {counts.min()}, max: {counts.max()}, "
      f"median: {np.median(counts):.0f}, mean: {counts.mean():.1f}")

In [None]:
# --- UMAP visualization of flat clusters ---
import umap

# Project the clustering input to 2-D purely for visualization.
reducer = umap.UMAP(n_components=2, random_state=RANDOM_SEED, n_neighbors=15)
embedding = reducer.fit_transform(X_reduced)

fig, ax = plt.subplots(figsize=(10, 8))
# One point per feature, coloured by its flat-cluster label.
scatter = ax.scatter(
    embedding[:, 0],
    embedding[:, 1],
    c=labels,
    cmap="tab20",
    s=5,
    alpha=0.6,
)
ax.set(
    title=f"{MODULE_NAME} Clusters (n={n_clusters}) — UMAP projection",
    xlabel="UMAP 1",
    ylabel="UMAP 2",
)
fig.colorbar(scatter, ax=ax, label="Cluster ID")
fig.tight_layout()
plt.show()

In [None]:
# --- Cluster size distribution ---
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
rank_ax, hist_ax = axes

# Left: cluster sizes ordered from largest to smallest.
sizes_desc = np.sort(counts)[::-1]
rank_ax.bar(np.arange(sizes_desc.size), sizes_desc, color="steelblue")
rank_ax.set(xlabel="Cluster rank", ylabel="Size", title="Cluster sizes (sorted)")

# Right: how many clusters fall into each size bin.
hist_ax.hist(counts, bins=30, color="steelblue", edgecolor="white")
hist_ax.set(xlabel="Cluster size", ylabel="Count", title="Cluster size histogram")

fig.tight_layout()
plt.show()

In [None]:
# --- Hierarchical taxonomy with SMART ---
# One vigilance value per hierarchy level.
rho_values = [0.1, 0.4, 0.7]
smart = create_smart_model(
    MODULE_NAME,
    dim=data_dim,
    rho_values=rho_values,
    overrides=PARAM_OVERRIDES,
)

# Same prepare/fit protocol as the flat module, applied to the SMART model.
X_smart = smart.prepare_data(X_reduced)
smart.fit(X_smart, verbose=True)

# hierarchy[s, l] is the cluster label of sample s at level l.
hierarchy = extract_hierarchy_labels(smart)
print(f"Hierarchy shape: {hierarchy.shape}  (samples × levels)")

for level, rho in enumerate(rho_values):
    level_cluster_count = np.unique(hierarchy[:, level]).size
    print(f"  Level {level} (rho={rho}): {level_cluster_count} clusters")

In [None]:
# --- Differential logit lens (reuses LM + W_dec_full from filter cell) ---
# Subtract global mean direction so each cluster shows *distinctive* tokens,
# not the shared centroid (which produces the same top tokens for every cluster).

# Mean decoder direction over ALL features in W_dec_full (junk included);
# subtracted from each cluster mean to remove the shared component.
global_mean_dir = W_dec_full.mean(dim=0)


def get_cluster_tokens(feature_indices, top_k=10):
    """Differential logit lens: top tokens distinctive to this cluster.

    Parameters
    ----------
    feature_indices : sequence of int
        Row indices into ``W_dec_full`` (original feature ids) of the
        cluster's members.
    top_k : int
        Number of tokens to return.

    Returns
    -------
    list of str
        Decoded tokens with the highest logits for the cluster's mean
        direction minus ``global_mean_dir``, best first. Empty when
        ``feature_indices`` is empty.
    """
    # The mean of zero rows is NaN, which would make topk return arbitrary
    # tokens; an empty cluster simply has no distinctive tokens.
    if len(feature_indices) == 0:
        return []

    dirs = W_dec_full[feature_indices]
    diff_dir = dirs.mean(dim=0) - global_mean_dir
    with torch.no_grad():
        # final_ln expects a batch dimension; add it, then drop it again.
        normed = final_ln(diff_dir.unsqueeze(0))
        logits = (normed @ unembed_weights.T).squeeze(0)
    topk = logits.topk(top_k)
    return [tokenizer.decode([tid]) for tid in topk.indices.tolist()]


# --- Show all Level-0 clusters (broad themes) ---
n_levels = hierarchy.shape[1]
# np.unique already returns its values in ascending order.
l0_ids = np.unique(hierarchy[:, 0]).tolist()

banner = "=" * 70
print(banner)
print(f"LEVEL 0 — {len(l0_ids)} broad themes (rho={rho_values[0]})")
print(banner)

for c in l0_ids:
    mask = hierarchy[:, 0] == c
    # idx maps sample rows back to original feature ids for the logit lens.
    tokens = get_cluster_tokens(idx[mask], top_k=20)

    # Count children at next level (only when a finer level exists)
    if n_levels > 1:
        n_l1 = len(np.unique(hierarchy[mask, 1]))
        child_info = f", {n_l1} L1 children"
    else:
        child_info = ""

    print(f"\nL0:{c} ({mask.sum()} features{child_info}):")
    print(f"  {tokens}")

In [None]:
# --- Explore hierarchy: drill into any cluster's children ---

def explore_cluster(level, cluster_id, top_k=10, max_children=50):
    """Show a cluster's tokens, then list its children at the next level.

    Parameters
    ----------
    level : int
        Hierarchy level (0, 1, ...) of the cluster to expand.
    cluster_id : int
        Cluster ID at that level.
    top_k : int
        Number of logit-lens tokens to show per cluster.
    max_children : int
        Max children to display (largest first).

    Raises
    ------
    IndexError
        If ``level`` is negative or not smaller than the number of
        hierarchy levels.
    """
    n_levels = hierarchy.shape[1]
    # Reject negative levels explicitly: NumPy would accept them as
    # from-the-end indices, and the "children" would then be taken from a
    # wrapped-around or mislabelled level.
    if not 0 <= level < n_levels:
        raise IndexError(
            f"level {level} is out of range for a hierarchy with {n_levels} levels"
        )

    mask = hierarchy[:, level] == cluster_id
    n = mask.sum()
    if n == 0:
        print(f"No samples in L{level} cluster {cluster_id}")
        return

    # idx maps sample rows to original feature ids used by the logit lens.
    feature_indices = idx[mask]
    tokens = get_cluster_tokens(feature_indices, top_k=top_k)
    print(f"L{level}:{cluster_id} ({n} features, rho={rho_values[level]}): {tokens}")

    child_level = level + 1
    if child_level >= n_levels:
        print("  (finest level — no children)")
        return

    # Child clusters are the distinct next-level labels among this cluster's samples.
    child_ids, child_counts = np.unique(hierarchy[mask, child_level], return_counts=True)
    order = np.argsort(-child_counts)  # largest first
    print(f"  -> {len(child_ids)} children at L{child_level} (rho={rho_values[child_level]}):\n")

    for ci in order[:max_children]:
        cid = child_ids[ci]
        # Restrict to samples of the parent that carry this child label.
        child_mask = mask & (hierarchy[:, child_level] == cid)
        child_features = idx[child_mask]
        child_tokens = get_cluster_tokens(child_features, top_k=top_k)
        print(f"    L{child_level}:{cid} ({child_counts[ci]} features): {child_tokens}")

    if len(child_ids) > max_children:
        print(f"    ... and {len(child_ids) - max_children} more")


# Example: explore Level 0, Cluster 0 — change these to explore different parts.
# Prints the cluster's distinctive tokens followed by its children at level 1.
explore_cluster(level=0, cluster_id=0)

# To drill deeper, call again with a child ID taken from the printed list:
# explore_cluster(level=1, cluster_id=<child_id_from_above>)

In [None]:
# --- Vigilance parameter sweep (optional/advanced) ---
# Sweep rho to see how cluster count changes with vigilance.
# Higher rho → more clusters (tighter match required, except BayesianART).

# Ten evenly spaced vigilance values; one fresh module is fitted per value.
rho_sweep = np.linspace(0.1, 0.95, 10)
cluster_counts = []

for rho in tqdm(rho_sweep, desc="Vigilance sweep"):
    # User overrides apply first; the swept rho always wins.
    overrides_sweep = {**PARAM_OVERRIDES, "rho": rho}
    m = create_art_module(MODULE_NAME, dim=data_dim, overrides=overrides_sweep)
    # Data is prepared by each new module instance before fitting.
    X_p = m.prepare_data(X_reduced)
    m.fit(X_p)
    cluster_counts.append(m.n_clusters)

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(rho_sweep, cluster_counts, "o-", color="steelblue")
ax.set(
    xlabel="Vigilance (rho)",
    ylabel="Number of clusters",
    title=f"Vigilance sweep — {MODULE_NAME}",
)
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

# Same numbers as the plot, in text form.
for rho, nc in zip(rho_sweep, cluster_counts):
    print(f"  rho={rho:.2f} → {nc} clusters")

## Phase 2: Checkpoint Comparison (Stub)

**Goal:** Track how SAE feature categories evolve across Pythia-70M training checkpoints.

**Approach:**
1. Load Pythia-70M at multiple checkpoints (e.g., steps 1000, 10000, 50000, 143000)
2. Extract SAE decoder weights at each checkpoint
3. Use ART's `partial_fit()` to incrementally update the taxonomy
4. Track: resonant features (stable), new categories (plastic), dormant categories (lost)

This leverages ART's core strength — stability-plasticity balance — which SAEs lack entirely (they must retrain from scratch).

*Not yet implemented — this section is a placeholder for the Phase 2 work.*