# $\epsilon$-Orthogonal Feature Basis

In [49]:
import sys
from collections import Counter

import torch
import torch.nn.functional as F
from tqdm.notebook import tqdm

# Put the parent directory on the import path before importing from `src`.
sys.path.insert(0, '..')

from src.synthetic import generate_feature_basis, generate_representations
from src.config import SyntheticConfig, ExtractionConfig
from src.extraction import resolve_tau, compute_nullspace, build_neighbor_matrix, find_monosemantic_targets
from src.metrics import match_features

# Seed torch's global RNG up front. The seed is a named constant so it is easy
# to find and change in one place.
SEED = 67
torch.manual_seed(SEED)

<torch._C.Generator at 0x11461f550>

In [50]:
# Synthetic-data settings for this experiment. The whole config is later handed
# to `generate_representations` and `resolve_tau`; `d` and `n` also size the
# feature frame.
syn_config = SyntheticConfig(
    d=100,  # ambient dimension (reported as "dimensions" in the frame summary)
    n=110,  # number of features (reported as "features"); n > d here
    num_representations=5000,
    sparsity_mode="bernoulli_gaussian",
    k=6,  # presumably the target number of active features per representation — TODO confirm
    positive_only=True,  # presumably restricts coefficients to non-negative values — TODO confirm
)

## Generate $\epsilon$-orthogonal feature frame

In [51]:
# Build the feature frame, then configure extraction with the epsilon the frame
# actually achieved (printed below as its coherence).
result = generate_feature_basis(syn_config.d, syn_config.n)
ext_config = ExtractionConfig(epsilon=result.achieved_epsilon)
print(f"Feature frame: {syn_config.n} features in {syn_config.d} dimensions, coherence={result.achieved_epsilon:.4f}")

Feature frame: 110 features in 100 dimensions, coherence=0.0370


In [52]:
# Generate synthetic representations from the feature frame. `coeffs[i]` holds
# the per-feature coefficients of representation i; the ground-truth cell below
# treats a nonzero entry as an active feature.
reps, coeffs = generate_representations(result.features, syn_config)

## Ground truth: which representations are truly monosemantic?

In [53]:
# Ground truth: a representation is monosemantic when exactly one of its
# coefficients is nonzero. Collect those representation indices together with
# the set of features they isolate.
true_mono_indices = set()
mono_features = set()
for rep_idx, coeff_row in enumerate(coeffs):
    is_active = coeff_row != 0
    if int(is_active.sum()) != 1:
        continue
    true_mono_indices.add(rep_idx)
    mono_features.add(is_active.nonzero().flatten()[0].item())

print(f"{len(true_mono_indices)} monosemantic representations (exactly 1 active feature)")
print(f"{len(mono_features)}/{syn_config.n} unique features appear monosemantically")

84 monosemantic representations (exactly 1 active feature)
61/110 unique features appear monosemantically


## Evaluate `find_monosemantic_targets`

In [54]:
# Resolve the neighbor threshold tau (presumably a cosine-similarity cutoff —
# it is printed next to cosine similarities later on) and build the neighbor matrix.
tau = resolve_tau(ext_config, syn_config, epsilon=result.achieved_epsilon)
neighbor_matrix = build_neighbor_matrix(reps, tau)
# Rows of `neighbor_matrix` are used as boolean masks elsewhere in this notebook,
# so summing along dim=1 gives each representation's neighbor count.
neighbor_counts = neighbor_matrix.sum(dim=1)
# Representation indices picked by the selector under evaluation.
target_indices = find_monosemantic_targets(neighbor_matrix)
target_set = set(target_indices.tolist())

print(f"tau = {tau:.6f}")
print(f"Targets returned: {len(target_set)}")

tau = 0.522555
Targets returned: 923


In [55]:
# Compare the selector's targets against the ground-truth monosemantic set.
true_positives = true_mono_indices & target_set
false_negatives = true_mono_indices - target_set
false_positives = target_set - true_mono_indices

print(f"True positives  (mono & target):     {len(true_positives)}")
print(f"False negatives (mono, not target):   {len(false_negatives)}")
print(f"False positives (target, not mono):   {len(false_positives)}")
# Recall is undefined when the sample contains no monosemantic representation;
# report NaN instead of raising ZeroDivisionError (the later cells guard the
# same empty case).
target_recall = len(true_positives) / len(true_mono_indices) if true_mono_indices else float("nan")
print(f"Recall: {target_recall:.2%}")

True positives  (mono & target):     0
False negatives (mono, not target):   84
False positives (target, not mono):   923
Recall: 0.00%


### Diagnosing the minimality filter

In [56]:
# Why are truly monosemantic representations failing the minimality filter?
# For the first ten false negatives (in index order), show the rep's own
# neighbor count next to the smallest neighbor count inside its neighborhood.
if not false_negatives:
    print("No false negatives")
else:
    print(f"{len(false_negatives)} truly mono representations were NOT selected as targets:")
    for idx in sorted(false_negatives)[:10]:
        own_count = int(neighbor_counts[idx].item())
        # The row of neighbor_matrix acts as a mask over all representations.
        neighborhood_min = int(neighbor_counts[neighbor_matrix[idx]].min().item())
        print(f"  Rep {idx}: {own_count} neighbors, min in neighborhood: {neighborhood_min}")

84 truly mono representations were NOT selected as targets:
  Rep 105: 70 neighbors, min in neighborhood: 12
  Rep 149: 61 neighbors, min in neighborhood: 12
  Rep 173: 62 neighbors, min in neighborhood: 12
  Rep 245: 61 neighbors, min in neighborhood: 8
  Rep 249: 60 neighbors, min in neighborhood: 6
  Rep 274: 53 neighbors, min in neighborhood: 6
  Rep 301: 48 neighbors, min in neighborhood: 9
  Rep 348: 61 neighbors, min in neighborhood: 8
  Rep 473: 67 neighbors, min in neighborhood: 13
  Rep 478: 62 neighbors, min in neighborhood: 8


In [None]:
# For each false negative, find the min-neighbor rep in its neighborhood
# and examine its relationship to the monosemantic feature direction.
features = result.features  # (n, d) ground truth feature basis

print("For each mono rep: what is the min-neighbor rep, and how does it relate to the mono feature?\n")
for idx in sorted(false_negatives)[:10]:
    # A false negative is monosemantic by construction, so exactly one
    # coefficient is nonzero: that entry identifies its feature.
    mono_coeff = coeffs[idx]
    mono_feat_idx = torch.nonzero(mono_coeff != 0).flatten()[0].item()
    mono_feat = features[mono_feat_idx]  # the true feature direction
    mono_coeff_val = mono_coeff[mono_feat_idx].item()

    # Find the neighbor with the minimum neighbor count (first one on ties).
    neighbors = torch.where(neighbor_matrix[idx])[0]
    min_count_in_neighborhood = neighbor_counts[neighbors].min().item()
    min_neighbors = neighbors[neighbor_counts[neighbors] == min_count_in_neighborhood]
    min_rep_idx = min_neighbors[0].item()

    # Cosine similarity between the min-neighbor rep and the mono feature direction
    min_rep = reps[min_rep_idx]
    cossim_with_feat = F.cosine_similarity(min_rep.unsqueeze(0), mono_feat.unsqueeze(0)).item()

    # Cosine similarity between the mono rep and the min-neighbor rep
    cossim_mono_min = F.cosine_similarity(reps[idx].unsqueeze(0), min_rep.unsqueeze(0)).item()

    # What features are active in the min-neighbor rep?
    min_coeff = coeffs[min_rep_idx]
    min_active = torch.nonzero(min_coeff != 0).flatten().tolist()
    has_mono_feat = mono_feat_idx in min_active

    print(f"Mono rep {idx} (feature {mono_feat_idx}, coeff={mono_coeff_val:.3f}, {int(neighbor_counts[idx])} neighbors):")
    print(f"  Min-neighbor rep {min_rep_idx}: {int(min_count_in_neighborhood)} neighbors, {len(min_active)} active features")
    print(f"  cossim(min_rep, mono_feature_dir) = {cossim_with_feat:.4f}   (tau = {tau:.4f})")
    print(f"  cossim(mono_rep, min_rep)          = {cossim_mono_min:.4f}")
    print(f"  Has mono feature #{mono_feat_idx}? {has_mono_feat}", end="")
    if has_mono_feat:
        # Compare the shared feature's coefficient with the largest-magnitude
        # coefficient in the min-neighbor rep. Computed only in this branch,
        # where at least one coefficient is known to be nonzero, so max() is
        # never taken over an empty tensor and the ratio never divides by zero.
        min_coeff_for_mono = min_coeff[mono_feat_idx].item()
        min_max_coeff = min_coeff[min_coeff != 0].abs().max().item()
        print(f"  (coeff={min_coeff_for_mono:.3f}, max_coeff={min_max_coeff:.3f}, ratio={min_coeff_for_mono/min_max_coeff:.2f})")
    else:
        print("  — connected via cross-feature interference alone")
    print()

In [57]:
# What do the false positives look like?
# For every false positive, record how many features are active and, when at
# least two are, how strongly the largest coefficient dominates the runner-up.
fp_num_features = []
fp_dominance_ratios = []

for fp_idx in false_positives:
    fp_coeffs = coeffs[fp_idx]
    nonzero_vals = fp_coeffs[fp_coeffs != 0]
    num_active = len(nonzero_vals)
    fp_num_features.append(num_active)
    if num_active < 2:
        continue
    magnitudes = nonzero_vals.abs().sort(descending=True).values
    fp_dominance_ratios.append((magnitudes[0] / magnitudes[1]).item())

print("False positive active feature counts:")
feature_count_hist = Counter(fp_num_features)
for num_active in sorted(feature_count_hist):
    occurrences = feature_count_hist[num_active]
    print(f"  {num_active} active features: {occurrences} ({occurrences/len(false_positives):.1%})")

if fp_dominance_ratios:
    ratios = torch.tensor(fp_dominance_ratios)
    num_ratios = len(ratios)
    print("\nDominance ratio (max_coeff / 2nd_coeff) for multi-feature FPs:")
    print(f"  Mean={ratios.mean():.2f}, Median={ratios.median():.2f}, Min={ratios.min():.2f}, Max={ratios.max():.2f}")
    print(f"  Ratio > 3 (one dominates): {(ratios > 3).sum().item()}/{num_ratios}")
    print(f"  Ratio < 1.5 (no dominant):  {(ratios < 1.5).sum().item()}/{num_ratios}")

False positive active feature counts:
  3 active features: 2 (0.2%)
  4 active features: 32 (3.5%)
  5 active features: 66 (7.2%)
  6 active features: 141 (15.3%)
  7 active features: 159 (17.2%)
  8 active features: 158 (17.1%)
  9 active features: 160 (17.3%)
  10 active features: 91 (9.9%)
  11 active features: 63 (6.8%)
  12 active features: 29 (3.1%)
  13 active features: 15 (1.6%)
  14 active features: 3 (0.3%)
  15 active features: 3 (0.3%)
  16 active features: 1 (0.1%)

Dominance ratio (max_coeff / 2nd_coeff) for multi-feature FPs:
  Mean=1.19, Median=1.15, Min=1.00, Max=1.98
  Ratio > 3 (one dominates): 0/923
  Ratio < 1.5 (no dominant):  877/923


In [58]:
# Neighbor count distributions
def _mean_median(counts):
    """Format the mean and median of a tensor of neighbor counts to one decimal."""
    as_float = counts.float()
    return f"mean={as_float.mean():.1f}, median={as_float.median():.1f}"


print(f"All representations:  {_mean_median(neighbor_counts)}")
print(f"Targets:             {_mean_median(neighbor_counts[target_indices])}")
if true_mono_indices:
    mono_counts = neighbor_counts[list(true_mono_indices)].float()
    print(f"Truly monosemantic:  {_mean_median(mono_counts)}")

All representations:  mean=30.7, median=30.0
Targets:             mean=12.0, median=11.0
Truly monosemantic:  mean=59.0, median=61.0


## Post-hoc spectral gap filtering

Instead of pre-selecting targets via neighbor-count minimality, run extraction on
ALL unique neighbor-set representatives and use the SVD singular value spectrum to
identify the monosemantic ones: a large ratio $\sigma_1 / \sigma_2$ between the top two
singular values (the "spectral gap") indicates a single dominant feature direction.

In [None]:
epsilon = ext_config.epsilon

# Get ALL unique neighbor-set representatives (no minimality filter).
# Rows are deduplicated by the raw bytes of their neighbor mask; the first
# representation index seen for each distinct mask becomes its representative.
seen_rows: dict[bytes, int] = {}
all_representatives: list[int] = []
for i in range(neighbor_matrix.shape[0]):
    row_key = neighbor_matrix[i].cpu().numpy().tobytes()
    if row_key not in seen_rows:
        seen_rows[row_key] = i
        all_representatives.append(i)

print(f"Total unique neighbor sets: {len(all_representatives)}")

# Run extraction on each, collecting singular value spectra.
# Representatives that cannot be processed are still skipped (best effort), but
# each skip is tallied by reason so failures are visible instead of silent.
results = []
skip_reasons = Counter()

for idx in tqdm(all_representatives, desc="Extracting"):
    neighbor_mask = neighbor_matrix[idx]
    neighbor_indices = torch.where(neighbor_mask)[0]

    if len(neighbor_indices) < 2:
        skip_reasons["fewer than 2 neighbors"] += 1
        continue

    try:
        nullspace = compute_nullspace(reps, neighbor_indices, epsilon)
        if nullspace.shape[0] == 0:
            skip_reasons["empty nullspace"] += 1
            continue

        # Presumably the rows of `nullspace` form an orthonormal basis, which
        # would make this a projection of the neighbors onto that subspace —
        # TODO confirm against compute_nullspace.
        neighbors = reps[neighbor_indices]
        projected = neighbors @ nullspace.T @ nullspace

        # The leading right-singular vector is taken as the extracted feature.
        _, S, Vh = torch.linalg.svd(projected, full_matrices=False)
        feature = Vh[0]

        num_true = (coeffs[idx] != 0).sum().item()
        # Spectral gap = sigma_1 / sigma_2, or inf when there is no meaningful
        # second singular value.
        # NOTE(review): a numerically zero projection (S[0] ~ 0 as well) also
        # yields inf here — confirm whether such cases should be excluded.
        gap = (S[0] / S[1]).item() if len(S) > 1 and S[1] > 1e-10 else float('inf')

        results.append({
            'idx': idx,
            'feature': feature,
            'singular_values': S,
            'num_true_features': num_true,
            'is_mono': num_true == 1,
            'spectral_gap': gap,
        })
    except (ValueError, RuntimeError) as err:
        skip_reasons[f"{type(err).__name__} during extraction"] += 1
        continue

print(f"Successfully extracted from {len(results)} representatives")
if skip_reasons:
    print(f"Skipped {sum(skip_reasons.values())} representatives:")
    for reason, count in skip_reasons.most_common():
        print(f"  {reason}: {count}")

In [None]:
# Spectral gap distributions: monosemantic vs multi-feature
INF_GAP = float('inf')


def _report_gaps(header, gaps, finite_gaps):
    """Print the group's size and inf-gap count, then stats over its finite gaps (if any)."""
    print(f"{header} ({len(gaps)} reps, {gaps.count(INF_GAP)} with inf gap):")
    if len(finite_gaps) > 0:
        print(f"  Mean={finite_gaps.mean():.2f}, Median={finite_gaps.median():.2f}, Min={finite_gaps.min():.2f}, Max={finite_gaps.max():.2f}")


mono_gaps, poly_gaps = [], []
for entry in results:
    (mono_gaps if entry['is_mono'] else poly_gaps).append(entry['spectral_gap'])

# Infinite gaps are counted separately and left out of the summary statistics.
mono_finite = torch.tensor([gap for gap in mono_gaps if gap != INF_GAP])
poly_finite = torch.tensor([gap for gap in poly_gaps if gap != INF_GAP])

_report_gaps("Monosemantic", mono_gaps, mono_finite)
_report_gaps("\nMulti-feature", poly_gaps, poly_finite)

In [None]:
# Threshold sweep: precision/recall tradeoff for spectral gap filtering
# Precision = monosemantic share of the selected extractions; recall = share of
# all monosemantic extractions that were selected. Both fall back to 0 when
# their denominator is empty.
total_mono = sum(1 for entry in results if entry['is_mono'])

print(f"{'Threshold':>10} {'Selected':>8} {'True Mono':>10} {'Precision':>10} {'Recall':>8}")
for threshold in (1.5, 2.0, 3.0, 5.0, 8.0, 10.0, 15.0, 20.0):
    passing = [entry for entry in results if entry['spectral_gap'] >= threshold]
    num_selected = len(passing)
    num_selected_mono = sum(1 for entry in passing if entry['is_mono'])
    if num_selected:
        precision = num_selected_mono / num_selected
    else:
        precision = 0
    if total_mono > 0:
        recall = num_selected_mono / total_mono
    else:
        recall = 0
    print(f"{threshold:>10.1f} {num_selected:>8} {num_selected_mono:>10} {precision:>10.1%} {recall:>8.1%}")

In [None]:
# Do high-spectral-gap extractions (even from multi-feature reps) recover true features?
SPECTRAL_THRESHOLD = 5.0

high_gap = [entry for entry in results if entry['spectral_gap'] >= SPECTRAL_THRESHOLD]
if len(high_gap) > 0:
    # Match every high-gap extraction against the ground-truth feature frame.
    extracted = torch.stack([entry['feature'] for entry in high_gap])
    matching, scores = match_features(extracted, result.features, threshold=0.9)
    num_matched = len(matching)

    print(f"High spectral gap (>={SPECTRAL_THRESHOLD}) extractions: {len(high_gap)}")
    print(f"  Matched to a true feature (|cossim| > 0.9): {num_matched}")
    if num_matched > 0:
        # Only strictly positive scores enter the mean (presumably unmatched
        # entries score 0 — TODO confirm against match_features).
        matched_scores = scores[scores > 0]
        print(f"  Mean alignment of matches: {matched_scores.mean():.3f}")

    # Split the high-gap extractions by whether their source rep is truly monosemantic.
    mono_high, poly_high = [], []
    for entry in high_gap:
        (mono_high if entry['is_mono'] else poly_high).append(entry)
    print(f"  From monosemantic reps: {len(mono_high)}")
    print(f"  From multi-feature reps: {len(poly_high)}")

    if len(poly_high) > 0:
        # Repeat the matching on the multi-feature subset alone.
        poly_extracted = torch.stack([entry['feature'] for entry in poly_high])
        poly_matching, poly_scores = match_features(poly_extracted, result.features, threshold=0.9)
        print(f"  Multi-feature extractions matching a true feature: {len(poly_matching)}/{len(poly_high)}")