# Debug Feature Extraction

Step-by-step investigation of the extraction algorithm with k>1.

In [19]:
import torch
import sys
sys.path.insert(0, '..')

from src.config import SyntheticConfig, ExtractionConfig
from src.synthetic import generate_feature_basis, generate_representations
from src.extraction import (
    cluster_by_neighbors,
    resolve_tau,
    compute_nullspace,
    extract_feature,
    find_neighbors
)

torch.manual_seed(42)

<torch._C.Generator at 0x112e3f530>

## Configuration

In [44]:
syn_config = SyntheticConfig(
    d=16,
    n=16,
    epsilon=0.0,
    num_representations=1000,
    sparsity_mode="fixed",  # Options: "fixed", "bernoulli_gaussian"
    k=3,  # For bernoulli_gaussian: theta = k/n, E[||z||_0] = k
    coef_min_floor=0.3,  # Only used for non-BG modes
    positive_only=True,  # BG allows negative by default
)
ext_config = ExtractionConfig(
    epsilon=0.0,
)

print(f"d={syn_config.d}, n={syn_config.n}, k={syn_config.k}")
print(f"sparsity_mode={syn_config.sparsity_mode}")
print(f"tau={ext_config.tau}, neg_tau={ext_config.neg_tau}")

d=16, n=16, k=3
sparsity_mode=fixed
tau=None, neg_tau=None


## Generate Data

In [138]:
features = generate_feature_basis(syn_config.d, syn_config.n, syn_config.epsilon)
representations_old, coefficients_old = generate_representations(features, syn_config)

print(f"Features shape: {features.shape}")
print(f"Representations shape: {representations.shape}")
print(f"Coefficients shape: {coefficients.shape}")

Features shape: torch.Size([16, 16])
Representations shape: torch.Size([1002, 16])
Coefficients shape: torch.Size([1002, 16])


In [139]:
mono_coeff = torch.nn.functional.one_hot(torch.tensor(9), num_classes=syn_config.n).float()
mono_repr = mono_coeff @ features

In [140]:
representations = torch.vstack([representations_old, mono_repr])
coefficients = torch.vstack([coefficients_old, mono_coeff])

In [141]:
target_idx = -1

In [142]:
r, c = representations[target_idx], coefficients[target_idx]

In [145]:
def get_active_features(coef_row: torch.Tensor) -> set[int]:
    return set(torch.where(coef_row != 0)[0].tolist())

repr_features = [get_active_features(coefficients[i]) for i in range(representations.shape[0])]

# show first few
for i in range(5):
    print(f"repr {i}: active features = {repr_features[i]}")

repr 0: active features = {1, 11, 4}
repr 1: active features = {8, 13, 14}
repr 2: active features = {11, 3, 7}
repr 3: active features = {3, 6, 14}
repr 4: active features = {8, 9, 11}


In [146]:
get_active_features(c)

{9}

## Cosine Similarity Analysis

In [147]:
norms = torch.norm(representations, dim=1)
cos_sim_matrix = (representations @ representations.T) / (norms[:, None] * norms[None, :] + 1e-8)

# Pick a specific representation to analyze
target_idx = -1
target_features = repr_features[target_idx]
print(f"Target repr {target_idx} has features: {target_features}")

# Separate representations by whether they share any feature with target
shares_feature = [i for i in range(syn_config.num_representations) if repr_features[i] & target_features]
no_shared_feature = [i for i in range(syn_config.num_representations) if not (repr_features[i] & target_features)]

print(f"Shares feature: {len(shares_feature)}, No shared: {len(no_shared_feature)}")

Target repr -1 has features: {9}
Shares feature: 196, No shared: 804


In [148]:
# Cosine similarities to target
target_sims = cos_sim_matrix[target_idx]

sims_sharing = target_sims[shares_feature].abs()
sims_not_sharing = target_sims[no_shared_feature].abs()

print("Sharing a feature:")
print(f"  min={sims_sharing.min():.4f}, max={sims_sharing.max():.4f}, mean={sims_sharing.mean():.4f}")
print("Not sharing:")
print(f"  min={sims_not_sharing.min():.4f}, max={sims_not_sharing.max():.4f}, mean={sims_not_sharing.mean():.4f}")

Sharing a feature:
  min=0.2679, max=0.9114, mean=0.5441
Not sharing:
  min=0.0000, max=0.0000, mean=0.0000


## Clustering

In [149]:
tau = resolve_tau(ext_config, syn_config)
print(f"Using tau = {tau:.4f}")

clusters = cluster_by_neighbors(representations, tau)
print(f"Number of clusters: {len(clusters)}")

# Show cluster sizes
neighbor_sizes = [len(ns) for ns in clusters.keys()]
print(f"Neighbor set sizes: min={min(neighbor_sizes)}, max={max(neighbor_sizes)}, mean={sum(neighbor_sizes)/len(neighbor_sizes):.1f}")

Using tau = 0.0150
Number of clusters: 473
Neighbor set sizes: min=197, max=534, mean=489.5


In [187]:
neighbor_indices = find_neighbors(representations, target_idx, tau)
# target = representations[target_idx]

# norms = torch.norm(representations, dim=1)
# target_norm = norms[target_idx]

# # dots = representations @ target
# # cosine_sims = dots / (norms * target_norm + 1e-8)

# # # Find neighbors using absolute cosine similarity
# # # This catches both aligned (+) and anti-aligned (-) representations
# # neighbor_mask = torch.abs(cosine_sims) >= tau
# # neighbor_indices = torch.where(neighbor_mask)[0]
neighbor_indices_items = [i.item() for i in neighbor_indices]
len(neighbor_indices)

197

## Analyze Single Cluster

In [188]:
neighbor_set = set(neighbor_indices_items)
pos_indices = neighbor_indices_items
neg_indices = [i for i in range(representations.shape[0]) if i not in neighbor_set]

print(f"Positive set size: {len(pos_indices)}")
print(f"Negative set size: {len(neg_indices)}")

Positive set size: 197
Negative set size: 804


In [189]:
# Count features in positive set
feature_counts = {}
for i in pos_indices:
    for f in repr_features[i]:
        feature_counts[f] = feature_counts.get(f, 0) + 1

sorted_features = sorted(feature_counts.items(), key=lambda x: -x[1])
print("Feature counts in positive set:")
for f, count in sorted_features:
    print(f"  Feature {f}: {count}/{len(pos_indices)} ({100*count/len(pos_indices):.0f}%)")

Feature counts in positive set:
  Feature 9: 197/197 (100%)
  Feature 10: 36/197 (18%)
  Feature 13: 32/197 (16%)
  Feature 5: 30/197 (15%)
  Feature 8: 28/197 (14%)
  Feature 4: 28/197 (14%)
  Feature 11: 27/197 (14%)
  Feature 1: 27/197 (14%)
  Feature 7: 26/197 (13%)
  Feature 2: 26/197 (13%)
  Feature 3: 26/197 (13%)
  Feature 15: 23/197 (12%)
  Feature 14: 22/197 (11%)
  Feature 0: 21/197 (11%)
  Feature 12: 21/197 (11%)
  Feature 6: 19/197 (10%)


## Nullspace Computation

In [190]:
nullspace = compute_nullspace(representations, neighbor_indices, 0.0, neg_tau=ext_config.neg_tau, verbose=True)

11:13:37 | INFO | Nullspace computation:
  n_neg=804, rms_norm=1.1916
  epsilon_tilde = sqrt(804) * 1.1916 * 0.0 = 0.000001
  singular values below threshold: 1
  first SV below threshold: 0.000001


In [191]:
nullspace

tensor([[-0.1894,  0.3829,  0.2578, -0.2098,  0.3358,  0.1459,  0.1695,  0.1030,
          0.4089,  0.2581, -0.0641, -0.1452,  0.4965, -0.0752,  0.0037, -0.1495]])

In [192]:
target_features

{9}

In [193]:
for i, feature in enumerate(features):
    coeffs = nullspace @ feature
    is_in_nullspace = torch.isclose((coeffs ** 2).sum(), torch.tensor(0.0))
    if is_in_nullspace:
        print(f"Feautre {i} is in nullspace")
    else:
        print(f"Feautre {i} is NOT in nullspace")

Feautre 0 is in nullspace
Feautre 1 is in nullspace
Feautre 2 is in nullspace
Feautre 3 is in nullspace
Feautre 4 is in nullspace
Feautre 5 is in nullspace
Feautre 6 is in nullspace
Feautre 7 is in nullspace
Feautre 8 is in nullspace
Feautre 9 is NOT in nullspace
Feautre 10 is in nullspace
Feautre 11 is in nullspace
Feautre 12 is in nullspace
Feautre 13 is in nullspace
Feautre 14 is in nullspace
Feautre 15 is in nullspace


## Feature Extraction

In [194]:
if nullspace.shape[0] > 0:
    extracted = extract_feature(representations, neighbor_indices, nullspace, verbose=True)
    
    # Check alignment with ground truth features
    alignments = torch.abs(extracted @ features.T)
    best_match = alignments.argmax().item()
    best_alignment = alignments[best_match].item()
    
    print(f"Best match: feature {best_match} with alignment {best_alignment:.4f}")
    print(f"Top 5 alignments: {alignments.topk(5)}")
else:
    print("Empty nullspace!")

11:13:55 | INFO | Feature extraction:
  n_neighbors=197
  first SV=9.4144, second SV=0.0000
  ratio=15029061.49


Best match: feature 9 with alignment 1.0000
Top 5 alignments: torch.return_types.topk(
values=tensor([1.0000e+00, 1.0431e-07, 7.4506e-08, 5.7742e-08, 5.4017e-08]),
indices=tensor([ 9,  6, 12,  0, 15]))


## Manual Investigation

Add your own debugging cells below.

In [195]:
if nullspace.shape[0] == 0:
    raise ValueError("Nullspace is empty, cannot extract feature")

# Get neighbor representations: (m, d) - rows are vectors
neighbors = representations[neighbor_indices]

# Project ALL neighbors onto nullspace
# For row vectors: projected = X @ nullspace.T @ nullspace
# This applies the projection matrix P = nullspace.T @ nullspace to each row
projected = neighbors @ nullspace.T @ nullspace  # (m, d)

# SVD to find direction of maximum variance
# projected = U @ diag(S) @ Vh
# Vh[0] (first row) is the dominant direction in d-space
_, S, Vh = torch.linalg.svd(projected, full_matrices=False)

# First right singular vector is the feature (already unit norm)
feature = Vh[0]  # (d,)


In [196]:
for s in S:
    print(s)

tensor(9.4144)
tensor(6.1641e-07)
tensor(1.0047e-07)
tensor(9.0771e-08)
tensor(8.1612e-08)
tensor(7.0151e-08)
tensor(6.1893e-08)
tensor(5.4658e-08)
tensor(4.2041e-08)
tensor(3.9400e-08)
tensor(3.5088e-08)
tensor(3.3267e-08)
tensor(2.3394e-08)
tensor(1.9129e-08)
tensor(1.3865e-08)
tensor(8.1618e-10)
