# ART-Organized SAE Feature Taxonomy

Cluster pre-trained SAE decoder weight vectors from **Pythia-70M** using Adaptive Resonance Theory (ART) modules to build hierarchical feature taxonomies.

**Pipeline:** SAE W_dec (32K × 512) → standardize → PCA → ART clustering → hierarchical taxonomy via SMART

**Key idea:** ART's vigilance parameter provides a single, principled knob for controlling cluster granularity — something SAEs lack. SMART builds a hierarchy by stacking multiple vigilance levels.

In [None]:
# --- Setup ---
# Prepares the environment: on Colab, makes sure the project repo is checked
# out and is the working directory, installs dependencies, then imports
# everything the rest of the notebook uses.
import sys, os

# Colab is detected by the presence of the `google.colab` module in this
# interpreter's loaded modules.
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    # Clone or pull the repo, then cd into it
    if os.path.exists("art_utils.py"):
        # Already inside the repo (cell re-run)
        !git pull
    elif os.path.exists("ART-SAE-taxonomy"):
        # Repo was cloned earlier but we are still in its parent directory
        os.chdir("ART-SAE-taxonomy")
        !git pull
    else:
        # First run: fresh clone, then enter the checkout
        !git clone https://github.com/syre-ai/ART-SAE-taxonomy.git
        os.chdir("ART-SAE-taxonomy")
    # Dependencies are installed only on Colab; artlib is pinned to 0.1.7
    !pip install -q artlib==0.1.7 eai-sparsify transformers torch umap-learn plotly tqdm pandas

import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from tqdm.auto import tqdm

# Project-local helpers; presumably resolved from the repo directory entered
# above (on Colab) or from the current working directory otherwise.
from art_utils import (
    create_art_module,
    create_smart_model,
    extract_hierarchy_labels,
    list_available_modules,
)

print("Setup complete.")

In [None]:
# ============================================================
# CONFIGURATION — edit this cell to change experiments
# ============================================================
# Every later cell reads these names; re-run the notebook from here after
# changing any of them.

LAYER_IDX = 3              # Pythia-70M layer (0-5); selects the "layers.<LAYER_IDX>" SAE
MODULE_NAME = "GPUFuzzyART"  # GPU-accelerated FuzzyART (see list_available_modules())
PARAM_OVERRIDES = {}       # e.g. {"rho": 0.8}; forwarded as `overrides` to the art_utils factories
PCA_DIMS = 50              # Reduce 512 dims to this; set None to skip PCA
QUICK_TEST = True          # Subsample to 1000 features for fast iteration; False shuffles the full set
RANDOM_SEED = 42           # Seeds the NumPy generator, PCA and UMAP

In [None]:
# --- Load SAE decoder weights ---
from sparsify import Sae

# Load the published SAEs and take the decoder matrix of the configured layer.
saes = Sae.load_many("EleutherAI/sae-pythia-70m-32k")
layer_key = f"layers.{LAYER_IDX}"
W_dec = saes[layer_key].W_dec.detach().numpy()  # (32768, 512)
print(f"Loaded W_dec for {layer_key}: shape {W_dec.shape}")

# `idx` records, for each row of the reordered W_dec, the row it came from in
# the matrix loaded above. Quick-test mode draws 1000 rows without
# replacement; otherwise all rows are shuffled to reduce ordering effects
# in ART.
rng = np.random.default_rng(RANDOM_SEED)
if QUICK_TEST:
    idx = rng.choice(W_dec.shape[0], size=1000, replace=False)
else:
    idx = rng.permutation(W_dec.shape[0])
W_dec = W_dec[idx]

print(
    f"Quick test mode: subsampled to {W_dec.shape[0]} features"
    if QUICK_TEST
    else f"Shuffled {W_dec.shape[0]} features"
)

In [None]:
# --- Preprocessing: StandardScaler + optional PCA ---
# Standardise each decoder dimension, then optionally project onto the
# leading PCA_DIMS principal components.
scaler = StandardScaler().fit(W_dec)
X_scaled = scaler.transform(W_dec)

if PCA_DIMS is None or PCA_DIMS >= X_scaled.shape[1]:
    # PCA disabled, or it would not reduce the dimensionality
    X_reduced = X_scaled
    print(f"No PCA applied. Using {X_reduced.shape[1]} dims.")
else:
    pca = PCA(n_components=PCA_DIMS, random_state=RANDOM_SEED)
    X_reduced = pca.fit_transform(X_scaled)
    explained = pca.explained_variance_ratio_.sum()
    print(f"PCA: {X_scaled.shape[1]} → {PCA_DIMS} dims ({explained:.1%} variance explained)")

# Dimensionality handed to the ART module factories in later cells
data_dim = X_reduced.shape[1]
print(f"Final data shape: {X_reduced.shape}")

In [None]:
# --- Create ART module ---
# Show the available module names first, then build the configured one.
list_available_modules()
print()

model = create_art_module(MODULE_NAME, dim=data_dim, overrides=PARAM_OVERRIDES)
print(f"\nCreated {MODULE_NAME} for {data_dim}-dim data")
if hasattr(model, "params"):
    print(f"Parameters: {model.params}")
else:
    # Modules without a `params` attribute still emit one (empty) line
    print("")

In [None]:
# --- Flat clustering ---
# Fit the ART module once on the module-prepared data and summarise the
# resulting partition.
X_prepared = model.prepare_data(X_reduced)
print(f"Data after prepare_data: shape {X_prepared.shape}")

model.fit(X_prepared, verbose=True)

n_clusters = model.n_clusters
labels = model.labels_
print(f"\nClusters found: {n_clusters}")
print(f"Samples: {len(labels)}")

# Cluster size stats (`counts` is reused by the size-distribution plots)
unique, counts = np.unique(labels, return_counts=True)
print(
    f"Cluster sizes — min: {np.min(counts)}, max: {np.max(counts)}, "
    f"median: {np.median(counts):.0f}, mean: {np.mean(counts):.1f}"
)

In [None]:
# --- UMAP visualization of flat clusters ---
import umap

# Project the preprocessed features to 2-D and colour each point by its
# flat cluster label.
reducer = umap.UMAP(n_components=2, random_state=RANDOM_SEED, n_neighbors=15)
embedding = reducer.fit_transform(X_reduced)

fig, ax = plt.subplots(figsize=(10, 8))
scatter = ax.scatter(
    embedding[:, 0],
    embedding[:, 1],
    c=labels,
    cmap="tab20",
    s=5,
    alpha=0.6,
)
ax.set(
    title=f"{MODULE_NAME} Clusters (n={n_clusters}) — UMAP projection",
    xlabel="UMAP 1",
    ylabel="UMAP 2",
)
fig.colorbar(scatter, ax=ax, label="Cluster ID")
fig.tight_layout()
plt.show()

In [None]:
# --- Cluster size distribution ---
# Left: cluster sizes in descending order. Right: histogram of the sizes.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].bar(np.arange(len(counts)), np.sort(counts)[::-1], color="steelblue")
axes[0].set(xlabel="Cluster rank", ylabel="Size", title="Cluster sizes (sorted)")

axes[1].hist(counts, bins=30, color="steelblue", edgecolor="white")
axes[1].set(xlabel="Cluster size", ylabel="Count", title="Cluster size histogram")

fig.tight_layout()
plt.show()

In [None]:
# --- Hierarchical taxonomy with SMART ---
# One vigilance value per hierarchy level, passed to the SMART factory.
rho_values = [0.3, 0.6, 0.85]
smart = create_smart_model(
    MODULE_NAME,
    dim=data_dim,
    rho_values=rho_values,
    overrides=PARAM_OVERRIDES,
)

X_smart = smart.prepare_data(X_reduced)
smart.fit(X_smart, verbose=True)

# Label matrix indexed as [sample, level]
hierarchy = extract_hierarchy_labels(smart)
print(f"Hierarchy shape: {hierarchy.shape}  (samples × levels)")

# Report how many distinct labels appear at each level
for i, rho in enumerate(rho_values):
    n = np.unique(hierarchy[:, i]).size
    print(f"  Level {i} (rho={rho}): {n} clusters")

In [None]:
# --- Sunburst visualization of hierarchy ---
# One column per hierarchy level, named after its index and vigilance.
level_names = [f"L{i}_rho{r}" for i, r in enumerate(rho_values)]

# Cluster IDs become strings for plotly; every sample contributes a weight
# of 1 to its path.
df_hier = pd.DataFrame(hierarchy, columns=level_names).astype(str)
df_hier["count"] = 1

fig = px.sunburst(
    df_hier,
    path=level_names,
    values="count",
    title=f"SMART Hierarchy — {MODULE_NAME} base, rho={rho_values}",
)
fig.update_layout(width=700, height=700)
fig.show()

In [None]:
# --- Feature interpretation: tokens whose embeddings best align with each cluster ---
# For the first few flat clusters, average the original (non-PCA) decoder
# directions of the member features and list the vocabulary tokens whose
# input embeddings have the highest cosine similarity with that mean
# direction. This is an embedding-alignment heuristic; no activations are
# computed.
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
lm = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")
lm.eval()

# Get the embedding matrix (vocab_size × d_model)
embed_weights = lm.gpt_neox.embed_in.weight.detach()  # (50304, 512)
# Unit-normalise the rows once: this does not depend on the cluster, so it is
# computed outside the loop. The clamp avoids dividing by a zero norm.
embed_normed = embed_weights / embed_weights.norm(dim=1, keepdim=True).clamp(min=1e-8)

# Use the original (non-PCA) decoder directions for interpretability
sae = saes[f"layers.{LAYER_IDX}"]
W_dec_full = sae.W_dec.detach()  # (32768, 512) on CPU

n_show = min(5, n_clusters)  # Show top tokens for first few clusters
top_k_tokens = 10

print(f"Top-{top_k_tokens} tokens for first {n_show} clusters:\n")
for cluster_id in range(n_show):
    mask = labels == cluster_id
    if mask.sum() == 0:
        continue

    # Map clustered rows back to rows of the full decoder matrix. W_dec was
    # reordered as W_dec[idx] in BOTH modes (a 1000-row subsample when
    # QUICK_TEST is set, a full permutation otherwise), so `idx` is the
    # row -> original-feature lookup in either case. Using the positions of
    # `mask` directly would pick unrelated features after the shuffle.
    feature_indices = idx[mask]
    cluster_dirs = W_dec_full[feature_indices]  # (n_features_in_cluster, 512)

    # Mean direction of the cluster, unit-normalised; the clamp guards against
    # member directions that cancel out to a zero vector.
    mean_dir = cluster_dirs.mean(dim=0)
    mean_dir = mean_dir / mean_dir.norm().clamp(min=1e-8)

    # Cosine similarity with all token embeddings
    sims = embed_normed @ mean_dir
    topk = sims.topk(top_k_tokens)

    tokens = [tokenizer.decode([tid]) for tid in topk.indices.tolist()]
    print(f"Cluster {cluster_id} ({mask.sum()} features): {tokens}")

In [None]:
# --- Vigilance parameter sweep (optional/advanced) ---
# Refit a fresh module at each vigilance value and record how many clusters
# it produces.
# Higher rho → more clusters (tighter match required, except BayesianART).

rho_sweep = np.linspace(0.1, 0.95, 10)
cluster_counts = []

for rho in tqdm(rho_sweep, desc="Vigilance sweep"):
    # Configured overrides, with rho replaced by the current sweep value
    overrides_sweep = {**PARAM_OVERRIDES, "rho": rho}
    m = create_art_module(MODULE_NAME, dim=data_dim, overrides=overrides_sweep)
    X_p = m.prepare_data(X_reduced)
    m.fit(X_p)
    cluster_counts.append(m.n_clusters)

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(rho_sweep, cluster_counts, "o-", color="steelblue")
ax.set(
    xlabel="Vigilance (rho)",
    ylabel="Number of clusters",
    title=f"Vigilance sweep — {MODULE_NAME}",
)
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

# Same numbers as the plot, in text form
for rho, nc in zip(rho_sweep, cluster_counts):
    print(f"  rho={rho:.2f} → {nc} clusters")

## Phase 2: Checkpoint Comparison (Stub)

**Goal:** Track how SAE feature categories evolve across Pythia-70M training checkpoints.

**Approach:**
1. Load Pythia-70M at multiple checkpoints (e.g., steps 1000, 10000, 50000, 143000)
2. Extract SAE decoder weights at each checkpoint
3. Use ART's `partial_fit()` to incrementally update the taxonomy
4. Track: resonant features (stable), new categories (plastic), dormant categories (lost)

This leverages ART's core strength — stability-plasticity balance — which SAEs lack entirely (they must retrain from scratch).

*Not yet implemented — this section is a placeholder for the planned Phase 2 work.*