# Feature Extraction with Target Selection

In [1]:
import torch
import sys
sys.path.insert(0, '..')

from src.config import SyntheticConfig, ExtractionConfig
from src.synthetic import generate_feature_basis, generate_representations
from src.extraction import (
    resolve_tau,
    compute_nullspace,
    extract_feature,
    find_neighbors,
    build_neighbor_matrix,
    find_monosemantic_targets,
)

# Seed torch's global RNG so the run is reproducible.
SEED = 42
torch.manual_seed(SEED)

<torch._C.Generator at 0x10d73b530>

In [3]:
# Synthetic-data settings for this run.
syn_config = SyntheticConfig(
    d=5,
    n=5,
    num_representations=50,
    sparsity_mode="variable",
    k_min=1,
    k=3,
    coef_min_floor=0.3,  # Only used for non-BG modes
    positive_only=True,
)
# tau / neg_tau are not passed here, so they keep ExtractionConfig's defaults.
ext_config = ExtractionConfig(epsilon=0.0)

config_summary = (
    f"d={syn_config.d}, n={syn_config.n}, k={syn_config.k}",
    f"sparsity_mode={syn_config.sparsity_mode}",
    f"tau={ext_config.tau}, neg_tau={ext_config.neg_tau}",
)
print("\n".join(config_summary))

d=5, n=5, k=3
sparsity_mode=variable
tau=None, neg_tau=None


In [5]:
# NOTE(review): this cell fails as saved (AttributeError below). The value
# returned by `generate_feature_basis` is a `FeatureBasisResult`, not a tensor,
# so `features.shape` raises. Later cells index `features[feature_idx]` and get
# a tensor of torch.Size([5]), so those outputs came from earlier kernel state.
# `features` presumably needs to be unwrapped from the result object here.
# TODO: confirm the attribute name on FeatureBasisResult and fix.
features = generate_feature_basis(syn_config.d, syn_config.n)
# In the saved run this call accepted the result object without raising; the
# failure happens on the first print below.
representations, coefficients = generate_representations(features, syn_config)

print(f"Features shape: {features.shape}")
print(f"Representations shape: {representations.shape}")
print(f"Coefficients shape: {coefficients.shape}")

AttributeError: 'FeatureBasisResult' object has no attribute 'shape'

In [72]:
# Count rows of `coefficients` that have exactly one nonzero entry along dim 1.
active_per_row = (coefficients != 0).sum(dim=1)
num_monosemantic = (active_per_row == 1).sum()
print(f"{num_monosemantic} monosemantic representations")

16 monosemantic representations


In [73]:
# Resolve the threshold from both configs; it is then passed to
# build_neighbor_matrix, and targets are picked from the resulting matrix.
tau = resolve_tau(ext_config, syn_config)
print("Using tau = " + format(tau, ".4f"))

neighbor_matrix = build_neighbor_matrix(representations, tau)
target_indices = find_monosemantic_targets(neighbor_matrix)

Using tau = 0.0150


In [74]:
def get_active_features(coef_row: torch.Tensor) -> set[int]:
    """Return the positions of the nonzero entries of ``coef_row``.

    Args:
        coef_row: Coefficient row for one representation (used as a 1-D
            vector in this notebook).

    Returns:
        Set of integer indices whose coefficient is nonzero.
    """
    nonzero_positions = coef_row.nonzero(as_tuple=True)[0]
    return {int(position) for position in nonzero_positions.tolist()}

In [75]:
# Display the target indices returned by find_monosemantic_targets.
target_indices

tensor([ 0,  3, 10, 19, 22])

In [76]:
num_targets = len(target_indices)
print(f"{num_targets} found targets")

# Iterate plain Python ints so the printed index and the row lookup agree.
for target in target_indices.tolist():
    active = get_active_features(coefficients[target])
    print(f"Representation {target} has features {active}")

5 found targets
Representation 0 has features {2}
Representation 3 has features {0}
Representation 10 has features {1}
Representation 19 has features {4}
Representation 22 has features {3}


## Extract the feature for the first target

In [88]:
# Walk through one target by hand before looping over all of them.
idx = target_indices[0]
# Unpack instead of .pop(): a target is expected to have exactly one active
# feature, and this raises ValueError instead of silently picking an arbitrary
# one when it does not.
(feature_idx,) = get_active_features(coefficients[idx])
# Indices where row `idx` of neighbor_matrix is nonzero, i.e. this target's
# neighbours.
neighbor_set = torch.where(neighbor_matrix[idx])[0]

In [96]:
# NOTE(review): the third argument is a literal 0.0. ext_config.epsilon is also
# 0.0; TODO confirm whether these are meant to be the same setting.
nullspace = compute_nullspace(representations, neighbor_set, 0.0)

In [97]:
# Compare the shape of this target's generated feature with the nullspace.
# NOTE(review): as saved, cell In [5] raised before `features` was usable as a
# tensor, so this output came from earlier kernel state. TODO: re-run after a
# fresh Restart & Run All.
features[feature_idx].shape, nullspace.shape

(torch.Size([5]), torch.Size([1, 5]))

In [98]:
# Inner product of the generated feature with each row of `nullspace`.
# NOTE(review): reading a value near 1 as "feature recovered" assumes both
# vectors are unit-norm. TODO: confirm the normalization of each.
features[feature_idx] @ nullspace.T

tensor([1.])

In [77]:
# Run extraction for every target. Initialise the result list here so the cell
# works on a fresh kernel and gives the same result when re-run.
extracted_features = []
for idx in target_indices.tolist():
    neighbor_mask = neighbor_matrix[idx]
    neighbor_indices = torch.where(neighbor_mask)[0]

    nullspace = compute_nullspace(
        representations,
        neighbor_indices,
        0.0,
    )
    # The nullspace has no rows for this neighbourhood: nothing to extract.
    if nullspace.shape[0] == 0:
        continue
    feature = extract_feature(representations, neighbor_indices, nullspace)
    extracted_features.append(feature)

# torch.stack raises on an empty sequence, so build the empty (0, d) result
# directly when every target was skipped.
if extracted_features:
    extracted_features = torch.stack(extracted_features)  # (m, d)
else:
    extracted_features = torch.empty(0, representations.shape[1])