# Debug Feature Extraction

Step-by-step investigation of the feature-extraction algorithm when representations have more than one active feature on average (k > 1).

In [None]:
import torch
import sys
sys.path.insert(0, '..')

from src.config import SyntheticConfig, ExtractionConfig
from src.synthetic import generate_feature_basis, generate_representations
from src.extraction import (
    cluster_by_neighbors,
    resolve_tau,
    compute_nullspace,
    extract_feature,
)

# Seed torch's global RNG (intended to make the synthetic data below reproducible).
# NOTE(review): only torch is seeded here; if src.synthetic also draws from
# numpy or the stdlib `random` module, those streams are not fixed — confirm.
torch.manual_seed(42)

## Configuration

In [None]:
# Synthetic-data settings.
syn_config = SyntheticConfig(
    d=16,
    n=16,
    epsilon=0.0,
    num_representations=1000,
    # Sparsity model; the available options are "fixed" and "bernoulli_gaussian".
    sparsity_mode="bernoulli_gaussian",
    # With bernoulli_gaussian, theta = k/n, so the expected number of active
    # features per representation is E[||z||_0] = k.
    k=3,
    coef_min_floor=0.5,  # only consulted by the non-BG sparsity modes
    positive_only=False,  # BG permits negative coefficients by default
)

# Extraction thresholds.
ext_config = ExtractionConfig(tau=0.5, neg_tau=0.05, epsilon=0.0)

# Echo the key settings, one per line.
print(
    "\n".join(
        [
            f"d={syn_config.d}, n={syn_config.n}, k={syn_config.k}",
            f"sparsity_mode={syn_config.sparsity_mode}",
            f"tau={ext_config.tau}, neg_tau={ext_config.neg_tau}",
        ]
    )
)

## Generate Data

In [None]:
# Build the ground-truth feature basis, then sample representations from it.
features = generate_feature_basis(syn_config.d, syn_config.n, syn_config.epsilon)
representations, coefficients = generate_representations(features, syn_config)

# Report the shape of each generated tensor.
for label, arr in (
    ("Features", features),
    ("Representations", representations),
    ("Coefficients", coefficients),
):
    print(f"{label} shape: {arr.shape}")

In [None]:
def get_active_features(coef_row: torch.Tensor) -> set[int]:
    """Return the indices (along dim 0) where *coef_row* is nonzero.

    For a 1-D coefficient row these are the indices of the active features.
    """
    (active_idx, *_) = coef_row.nonzero(as_tuple=True)
    return {int(j) for j in active_idx.tolist()}

# Set of active (nonzero-coefficient) feature indices for each representation.
# Iterating the sliced tensor yields the same rows as coefficients[i].
repr_features = [
    get_active_features(coef_row)
    for coef_row in coefficients[: syn_config.num_representations]
]

# Show the first few. Slicing (not range(5)) avoids an IndexError when fewer
# than 5 representations are generated.
for i, active in enumerate(repr_features[:5]):
    print(f"Repr {i}: active features = {active}")

## Cosine Similarity Analysis

In [None]:
# Pairwise cosine similarity between all representations (1e-8 guards zero norms).
norms = torch.norm(representations, dim=1)
cos_sim_matrix = (representations @ representations.T) / (norms[:, None] * norms[None, :] + 1e-8)

# Pick a specific representation to analyze
target_idx = 0
target_features = repr_features[target_idx]
print(f"Target repr {target_idx} has features: {target_features}")
if not target_features:
    # Under Bernoulli sparsity a representation may have no active features;
    # then nothing can "share" a feature with it and the comparison is moot.
    print(f"Warning: target repr {target_idx} has no active features; consider another target_idx.")

# Separate the *other* representations by whether they share any feature with
# the target. The target itself is excluded: its self-similarity is ~1 for any
# nonzero vector, which would pin the "sharing" group's max at 1 and inflate its mean.
others = [i for i in range(syn_config.num_representations) if i != target_idx]
shares_feature = [i for i in others if repr_features[i] & target_features]
no_shared_feature = [i for i in others if not (repr_features[i] & target_features)]

print(f"Shares feature: {len(shares_feature)}, No shared: {len(no_shared_feature)}")

In [None]:
# Absolute cosine similarities of every representation to the target
target_sims = cos_sim_matrix[target_idx]

sims_sharing = target_sims[shares_feature].abs()
sims_not_sharing = target_sims[no_shared_feature].abs()


def _describe(sims: torch.Tensor) -> str:
    """Format min/max/mean of *sims*, or a placeholder if the group is empty.

    Reducing an empty tensor with min()/max() raises, and either group can be
    empty (e.g. when the target has no active features).
    """
    if sims.numel() == 0:
        return "  (no representations in this group)"
    return f"  min={sims.min():.4f}, max={sims.max():.4f}, mean={sims.mean():.4f}"


print("Sharing a feature:")
print(_describe(sims_sharing))
print("Not sharing:")
print(_describe(sims_not_sharing))

## Clustering

In [None]:
tau = resolve_tau(ext_config, syn_config)
print(f"Using tau = {tau:.4f}")

clusters = cluster_by_neighbors(representations, tau)
print(f"Number of clusters: {len(clusters)}")

# Size of each cluster's neighbor set (the dict keys). These are not member
# counts; the members are the dict values (see the next cell).
neighbor_sizes = [len(ns) for ns in clusters]
if neighbor_sizes:
    print(f"Neighbor set sizes: min={min(neighbor_sizes)}, max={max(neighbor_sizes)}, mean={sum(neighbor_sizes)/len(neighbor_sizes):.1f}")
else:
    # min()/max() would raise on an empty list.
    print("Neighbor set sizes: no clusters found")

## Analyze a Single Cluster

In [None]:
# Pick the first cluster whose *neighbor set* (dict key) has at least 2
# representations; those neighbors form the positive set below.
cluster_list = [(ns, members) for ns, members in clusters.items() if len(ns) >= 2]
if not cluster_list:
    raise ValueError(
        "No cluster has a neighbor set of size >= 2; check tau for this data."
    )
neighbor_set, cluster_members = cluster_list[0]

# Sorted so the positive-set order (and anything derived from it) is deterministic.
pos_indices = sorted(neighbor_set)
# Set lookup keeps the complement construction linear in num_representations.
pos_index_set = set(pos_indices)
neg_indices = [i for i in range(syn_config.num_representations) if i not in pos_index_set]

print(f"Positive set size: {len(pos_indices)}")
print(f"Negative set size: {len(neg_indices)}")

In [None]:
# Tally how often each ground-truth feature is active within the positive set.
feature_counts: dict[int, int] = {}
for idx in pos_indices:
    for feat in repr_features[idx]:
        feature_counts[feat] = 1 + feature_counts.setdefault(feat, 0)

# Most frequent first; sorted() is stable even with reverse=True, so ties keep
# their first-seen order.
sorted_features = sorted(feature_counts.items(), key=lambda item: item[1], reverse=True)

n_pos = len(pos_indices)
print("Feature counts in positive set:")
for feat, hits in sorted_features:
    print(f"  Feature {feat}: {hits}/{n_pos} ({100*hits/n_pos:.0f}%)")

## Nullspace Computation

In [None]:
# Reference direction for the positive set: the mean of its representations.
neighbor_indices = torch.tensor(pos_indices)
neighbors = representations[neighbor_indices]
neighbor_mean = neighbors.mean(dim=0)
neighbor_mean_norm = neighbor_mean.norm()

# |cos| between every representation and that mean (1e-8 guards a zero norm).
all_norms = representations.norm(dim=1)
all_dots = representations @ neighbor_mean
all_cosine_sims = (all_dots / (all_norms * neighbor_mean_norm + 1e-8)).abs()

# Count the non-neighbors whose |cos| is at or below the neg_tau cutoff.
# NOTE(review): presumably mirrors the negative-set filter inside
# compute_nullspace — confirm against src.extraction.
neg_tau = ext_config.neg_tau
_neg_idx = torch.tensor(neg_indices, dtype=torch.long)
filtered_neg_indices = _neg_idx[all_cosine_sims[_neg_idx] <= neg_tau].tolist()
print(f"Non-neighbors passing neg_tau={neg_tau}: {len(filtered_neg_indices)}/{len(neg_indices)}")

In [None]:
# NOTE(review): the third positional argument is hard-coded to 0.0; if it is
# compute_nullspace's epsilon, ext_config.epsilon (also 0.0 here) may be the
# intended source — confirm against src.extraction.
nullspace = compute_nullspace(
    representations,
    neighbor_indices,
    0.0,
    neg_tau=ext_config.neg_tau,
    verbose=True,
)
print("Nullspace shape:", nullspace.shape)

## Feature Extraction

In [None]:
if nullspace.shape[0] > 0:
    extracted = extract_feature(representations, neighbor_indices, nullspace, verbose=True)

    # |<extracted, f>| against each ground-truth feature (presumably the rows of
    # `features`). NOTE(review): this equals |cosine| only if both sides are
    # unit-norm — confirm what extract_feature / generate_feature_basis return.
    alignments = torch.abs(extracted @ features.T)
    best_match = alignments.argmax().item()
    best_alignment = alignments[best_match].item()

    print(f"Best match: feature {best_match} with alignment {best_alignment:.4f}")
    # topk raises if k exceeds the number of entries, so cap it at the feature count.
    top_k = min(5, alignments.numel())
    print(f"Top {top_k} alignments: {alignments.topk(top_k)}")
else:
    print("Empty nullspace!")

## Manual Investigation

Add your own debugging cells below.

In [None]:
# Your debugging code here