# Bias Feature Cluster Visualization (UMAP)

This notebook projects the SAE decoder directions of bias-related features into 2D with UMAP and visualizes how they are distributed.

**Purpose:**
- Visualize how bias features cluster based on their SAE decoder directions
- Compare feature distributions across 9 demographic dimensions
- Identify shared vs. unique bias features

**Input Data:**
- SAE decoder weights: [100,000 features × 4,096 hidden dims]
- IG² attribution results: top-k features per demographic

**Output:**
- 3×3 grid of UMAP scatter plots (one panel per demographic)
- Feature frequency histogram (how many demographics each feature appears in)

In [None]:
import os
import sys
import warnings
from pathlib import Path

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

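# Add project root to path (assumes the notebook runs from notebooks/visualizations/)
PROJECT_ROOT = Path(os.getcwd()).parent.parent
# Prepend so the project's `src` package takes precedence over any installed one
sys.path.insert(0, str(PROJECT_ROOT))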

from src.visualization import (
    setup_korean_font,
    load_demographics,
    load_ig2_results,
    load_sae_decoder_weights,
    prepare_umap_data,
    plot_umap_clusters,
    plot_feature_frequency_histogram,
    get_demographic_labels
)

warnings.filterwarnings('ignore')

print(f"Project root: {PROJECT_ROOT}")

## Configuration

In [None]:
# Setup Korean font for matplotlib
setup_korean_font()

# Seaborn style
sns.set_style('whitegrid')
sns.set_context('paper')

In [None]:
# Paths
DATA_DIR = PROJECT_ROOT / "data"
RESULTS_DIR = PROJECT_ROOT / "results"
ASSETS_DIR = PROJECT_ROOT / "notebooks" / "visualizations" / "assets"
ASSETS_DIR.mkdir(exist_ok=True, parents=True)

# Stage: 'pilot', 'medium', 'full', or 'mock'
STAGE = "mock"

# Visualization parameters
TOP_K = 100  # Number of top features per demographic

print(f"Data stage: {STAGE}")
print(f"Top-k features: {TOP_K}")

## Load Data

In [None]:
# Load demographics
demographics_dict = load_demographics(DATA_DIR)
demographic_labels_ko, demographic_labels_en = get_demographic_labels(demographics_dict)

print(f"Demographics: {len(demographic_labels_ko)}")
for ko, en in zip(demographic_labels_ko, demographic_labels_en):
    print(f"  - {ko} ({en})")

In [None]:
# Load IG² results
ig2_results = load_ig2_results(RESULTS_DIR, stage=STAGE)

print(f"\nIG² results loaded for {len(ig2_results)} demographics")
for demo, data in ig2_results.items():
    scores = data['feature_scores'] if isinstance(data, dict) else data
    print(f"  - {demo}: {len(scores)} features")
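`prepare_umap_data` selects the top-`TOP_K` features per demographic internally, but its ranking rule is not shown here. The following is a minimal sketch assuming `feature_scores[i]` holds the IG² score of SAE feature `i`; the helper name `top_k_features` and its `by_abs` option (rank by magnitude, in case signed scores are kept) are hypothetical.

```python
import heapq
from typing import List, Sequence


def top_k_features(feature_scores: Sequence[float], k: int, by_abs: bool = False) -> List[int]:
    """Return the indices of the k highest-scoring SAE features, best first.

    Hypothetical helper, not part of src.visualization.
    Assumption: feature_scores[i] is the IG² attribution score of feature i.
    """
    # Rank by magnitude if requested (only meaningful if scores are signed)
    key = (lambda i: abs(feature_scores[i])) if by_abs else (lambda i: feature_scores[i])
    # heapq.nlargest avoids a full sort; equal scores keep their original order
    return heapq.nlargest(k, range(len(feature_scores)), key=key)


# Toy scores for five features
print(top_k_features([0.1, 0.9, -0.7, 0.4, 0.9], k=3))               # [1, 4, 3]
print(top_k_features([0.1, 0.9, -0.7, 0.4, 0.9], k=3, by_abs=True))  # [1, 4, 2]
```

`heapq.nlargest` runs in roughly O(n log k), which matters when scoring all 100,000 features per demographic.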

In [None]:
# Load SAE decoder weights
decoder_weights = load_sae_decoder_weights(RESULTS_DIR, stage=STAGE)

print(f"\nDecoder weights shape: {decoder_weights.shape}")
print(f"  Features (SAE dictionary size): {decoder_weights.shape[0]:,}")
print(f"  Hidden dim (d_model): {decoder_weights.shape[1]:,}")
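Each row of `decoder_weights` is one SAE feature's direction in the model's hidden space. Before UMAP it is common to compare these directions by angle rather than length; whether `prepare_umap_data` does so is not stated. A minimal pure-Python sketch of gathering and L2-normalizing rows, using the hypothetical helper `normalized_decoder_rows`:

```python
import math
from typing import List, Sequence


def normalized_decoder_rows(
    decoder: Sequence[Sequence[float]], feature_ids: Sequence[int]
) -> List[List[float]]:
    """Gather the decoder rows of the given SAE features, scaled to unit L2 norm.

    Hypothetical helper. Assumption: decoder[i] is feature i's direction
    (one row of the [n_features x d_model] matrix).
    """
    rows = []
    for fid in feature_ids:
        row = [float(x) for x in decoder[fid]]
        norm = math.sqrt(sum(x * x for x in row))
        # Leave all-zero (dead) rows unchanged instead of dividing by zero
        rows.append([x / norm for x in row] if norm > 0 else row)
    return rows


toy_decoder = [[3.0, 4.0], [0.0, 0.0], [1.0, 1.0]]
print(normalized_decoder_rows(toy_decoder, [0, 1]))  # [[0.6, 0.8], [0.0, 0.0]]
```

If `decoder_weights` is a torch tensor, the same step would be `torch.nn.functional.normalize(decoder_weights[ids], dim=1)`, where `ids` is a hypothetical index list.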

## Apply UMAP Dimensionality Reduction

In [None]:
# Prepare UMAP data
embeddings, all_features, demographic2topfeatures = prepare_umap_data(
    decoder_weights=decoder_weights,
    ig2_results=ig2_results,
    top_k=TOP_K
)

print(f"\nUMAP embeddings shape: {embeddings.shape}")
print(f"Total unique features selected: {len(all_features)}")
print("\nTop features per demographic:")
for demo, features in demographic2topfeatures.items():
    print(f"  - {demo}: {len(features)} features")
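The printed `len(all_features)` can be smaller than 9 × `TOP_K` because a feature ranked highly for several demographics is embedded only once. A sketch of that union step, assuming `all_features` is the sorted union of the per-demographic lists (the helper `build_feature_index` is hypothetical):

```python
from typing import Dict, List, Sequence, Tuple


def build_feature_index(
    demographic2topfeatures: Dict[str, Sequence[int]]
) -> Tuple[List[int], Dict[int, int]]:
    """Union the per-demographic top-k lists into one sorted feature list.

    Hypothetical helper. Returns (all_features, feature2row), where feature2row
    maps an SAE feature id to its row in the matrix that would be passed to UMAP.
    """
    # A set removes duplicates, so shared features are embedded only once
    all_features = sorted({f for feats in demographic2topfeatures.values() for f in feats})
    feature2row = {f: row for row, f in enumerate(all_features)}
    return all_features, feature2row


toy = {"gender": [12, 7, 40], "age": [7, 3, 40]}
features, feature2row = build_feature_index(toy)
print(features)     # [3, 7, 12, 40]
print(feature2row)  # {3: 0, 7: 1, 12: 2, 40: 3}
```

The decoder rows indexed by `all_features` could then be projected with umap-learn, e.g. `umap.UMAP(n_components=2, metric="cosine", random_state=42).fit_transform(X)`; the metric and seed are illustrative choices, not this notebook's settings.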

## Visualize Feature Clusters

In [None]:
# Plot UMAP clusters
fig = plot_umap_clusters(
    embeddings=embeddings,
    all_features=all_features,
    demographic2topfeatures=demographic2topfeatures,
    demographic_labels_ko=demographic_labels_ko,
    demographic_labels_en=demographic_labels_en,
    save_path=ASSETS_DIR / f"umap_bias_features_{STAGE}_top{TOP_K}.png",
    figsize=(18, 18),
    top_k=TOP_K
)

plt.show()
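In a 3×3 grid, panel `i` sits at `divmod(i, 3)`. One plausible way for each panel to show a demographic is to draw every embedded point in grey and highlight those in that demographic's top-k. A sketch of that bookkeeping; the helper `panel_layout` is hypothetical and `plot_umap_clusters` may work differently:

```python
from typing import Dict, List, Sequence, Tuple


def panel_layout(
    demographics: Sequence[str],
    all_features: Sequence[int],
    demographic2topfeatures: Dict[str, Sequence[int]],
    ncols: int = 3,
) -> List[Tuple[int, int, str, List[bool]]]:
    """Return (row, col, demographic, highlight_mask) for each grid panel.

    Hypothetical helper. highlight_mask[i] is True when all_features[i] is in
    that demographic's top-k, so the panel can color those points.
    """
    layout = []
    for i, demo in enumerate(demographics):
        row, col = divmod(i, ncols)  # fill the grid row by row
        selected = set(demographic2topfeatures.get(demo, ()))
        mask = [f in selected for f in all_features]
        layout.append((row, col, demo, mask))
    return layout


toy = {"gender": [7, 40], "age": [3], "race": [], "religion": [40]}
for row, col, demo, mask in panel_layout(list(toy), [3, 7, 40], toy):
    print(row, col, demo, mask)
```

Each `(row, col)` would index the axes array returned by `plt.subplots(3, 3)`, and the mask selects which rows of `embeddings` to color.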

## Feature Frequency Analysis

In [None]:
# Plot feature frequency histogram
fig = plot_feature_frequency_histogram(
    demographic2topfeatures=demographic2topfeatures,
    save_path=ASSETS_DIR / f"feature_frequency_{STAGE}_top{TOP_K}.png",
    figsize=(16, 4)
)

plt.show()

## Feature Overlap Analysis

In [None]:
from collections import Counter

from src.visualization.umap_utils import compute_feature_overlap, get_feature_frequency

# Compute pairwise overlaps (shared top-k features per demographic pair)
overlaps = compute_feature_overlap(demographic2topfeatures)

print("\nTop 10 Pairwise Feature Overlaps:")
print("=" * 60)
for (demo1, demo2), overlap in sorted(overlaps.items(), key=lambda x: -x[1])[:10]:
    print(f"{demo1:15s} ↔ {demo2:15s}: {overlap:3d} shared features")

# Feature frequency: number of demographics whose top-k includes each feature
feature_freq = get_feature_frequency(demographic2topfeatures)

print("\n\nFeature Frequency Statistics:")
print("=" * 60)
# Count how many features share each frequency value
freq_counts = Counter(feature_freq.values())

for freq in sorted(freq_counts, reverse=True):
    print(f"Features appearing in {freq} demographics: {freq_counts[freq]}")
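The two helpers imported above are not shown. Assuming `compute_feature_overlap` returns intersection sizes keyed by demographic pair (consistent with the `:3d` format) and `get_feature_frequency` counts how many demographics include each feature, equivalent stdlib versions might look like this (the names `pairwise_overlap` and `feature_frequency` are hypothetical):

```python
from collections import Counter
from itertools import combinations
from typing import Dict, Sequence, Tuple


def pairwise_overlap(d2f: Dict[str, Sequence[int]]) -> Dict[Tuple[str, str], int]:
    """Hypothetical helper: top-k features shared by each unordered demographic pair."""
    sets = {demo: set(feats) for demo, feats in d2f.items()}
    return {(a, b): len(sets[a] & sets[b]) for a, b in combinations(sets, 2)}


def feature_frequency(d2f: Dict[str, Sequence[int]]) -> Counter:
    """Hypothetical helper: for each feature, how many demographics include it."""
    # set(feats) so a feature repeated within one list is counted once
    return Counter(f for feats in d2f.values() for f in set(feats))


toy = {"gender": [1, 2, 3], "age": [2, 3, 4], "race": [3, 5]}
print(pairwise_overlap(toy))
# Distribution: how many features appear in exactly n demographics
print(Counter(feature_frequency(toy).values()))
```

With 9 demographics, `combinations` yields 36 pairs, so the cell above prints the 10 largest of 36 overlaps.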

## Summary Statistics

In [None]:
from src.visualization.feature_selection import get_shared_features, get_unique_features

# Shared features (appear in multiple demographics)
shared_features = get_shared_features(demographic2topfeatures, min_demographics=2)
print(f"\nShared features (≥2 demographics): {len(shared_features)}")

highly_shared = get_shared_features(demographic2topfeatures, min_demographics=5)
print(f"Highly shared features (≥5 demographics): {len(highly_shared)}")

# Unique features per demographic
unique_features = get_unique_features(demographic2topfeatures)
print("\nUnique features per demographic:")
for demo, features in unique_features.items():
    print(f"  - {demo}: {len(features)} unique features")
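Read literally, a feature is *shared* when it appears in at least `min_demographics` top-k lists and *unique* when it appears in exactly one. A sketch of those definitions; the helpers below are hypothetical, and `get_shared_features` / `get_unique_features` may differ, for example in return types:

```python
from collections import Counter
from typing import Dict, List, Sequence, Set


def _frequency(d2f: Dict[str, Sequence[int]]) -> Counter:
    # Number of demographics whose top-k list contains each feature
    return Counter(f for feats in d2f.values() for f in set(feats))


def shared_features(d2f: Dict[str, Sequence[int]], min_demographics: int = 2) -> Set[int]:
    """Hypothetical helper: features in at least min_demographics top-k lists."""
    return {f for f, n in _frequency(d2f).items() if n >= min_demographics}


def unique_features(d2f: Dict[str, Sequence[int]]) -> Dict[str, List[int]]:
    """Hypothetical helper: per demographic, features found in no other top-k list."""
    freq = _frequency(d2f)
    return {demo: sorted(f for f in set(feats) if freq[f] == 1) for demo, feats in d2f.items()}


toy = {"gender": [1, 2, 3], "age": [2, 3, 4], "race": [3, 5]}
print(sorted(shared_features(toy)))     # [2, 3]
print(sorted(shared_features(toy, 3)))  # [3]
print(unique_features(toy))             # {'gender': [1], 'age': [4], 'race': [5]}
```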

## Interpretation

### What to Look For:

1. **Clustering Patterns:**
   - Do bias features for a demographic cluster together?
   - Are there distinct regions for different demographics?
   - Do some demographics share feature clusters?

2. **Feature Overlap:**
   - Which demographics have the most shared features?
   - Are there "universal" bias features appearing across many demographics?
   - Which demographics have mostly unique features?

3. **Spatial Distribution:**
   - Are positive/negative bias features separated?
   - Do related demographics (e.g., gender, sexuality) have nearby clusters?
   - Are there outlier features far from the main clusters?
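The first question under **Clustering Patterns** can be turned into a number: compare the mean pairwise distance among one demographic's points with the mean over all points, where a ratio well below 1 suggests a tight cluster. This is a heuristic added here, not part of the notebook, and the function names are hypothetical. UMAP does not reliably preserve large-scale distances, so the ratio is more trustworthy when computed on the original decoder vectors; `math.dist` accepts points of any dimension.

```python
import math
from itertools import combinations
from typing import Sequence

Point = Sequence[float]


def mean_pairwise_distance(points: Sequence[Point]) -> float:
    """Hypothetical helper: average Euclidean distance over all unordered pairs."""
    dists = [math.dist(p, q) for p, q in combinations(points, 2)]
    return sum(dists) / len(dists) if dists else 0.0


def tightness_ratio(group: Sequence[Point], all_points: Sequence[Point]) -> float:
    """Hypothetical heuristic: mean distance inside `group` over that of `all_points`.

    Values well below 1 suggest the group forms a tight cluster.
    """
    overall = mean_pairwise_distance(all_points)
    return mean_pairwise_distance(group) / overall if overall > 0 else float("nan")


# Toy 2D embedding: two well-separated pairs of points
points = [(0.0, 0.0), (0.0, 1.0), (10.0, 0.0), (10.0, 1.0)]
print(round(tightness_ratio(points[:2], points), 3))
```

Pairwise distances cost O(n²), which is fine for a few hundred top-k features per demographic.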

### Next Steps:

1. Examine IG² rankings for top features
2. Analyze activation patterns across prompts
3. Test suppression/amplification effects
4. Investigate specific high-frequency features