# IG² Attribution Rankings

This notebook visualizes the top features identified by IG² (Integrated Gradients squared) attribution.

**Purpose:**
- Identify the most important features for bias prediction per demographic
- Compare feature importance across demographics
- Understand which features have the highest attribution scores

**Input Data:**
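- IG² attribution scores: one 1-D tensor of shape `[100000]` (one score per feature) per demographic, stored either directly or under a `feature_scores` key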

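**Output** (figures and the statistics table are saved to `notebooks/visualizations/assets/`):
- 3×3 grid of ranked bar charts (top-`TOP_K` features per demographic, 20 by default)
- Score distribution histograms (one panel per demographic)
- Per-demographic score statistics table (CSV)
- Bar chart comparing the maximum IG² score across demographics
- Printed list of features that appear in the top-50 of several demographics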

In [None]:
import os
import sys
import warnings
from collections import Counter
from pathlib import Path

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Make the project's `src` package importable. The kernel is expected to start in
# notebooks/visualizations/, which puts the project root two directories up.
NOTEBOOK_DIR = Path(os.getcwd())
PROJECT_ROOT = NOTEBOOK_DIR.parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))  # prepend so the local `src` takes priority

from src.visualization import (
    ensure_korean_font,
    get_demographic_labels,
    load_demographics,
    load_ig2_results,
    plot_ig2_rankings,
    setup_korean_font,  # NOTE(review): imported but unused in this notebook
)

# NOTE(review): this silences *every* warning for the rest of the kernel session;
# consider narrowing it to the specific category/message it was added for.
warnings.filterwarnings('ignore')

print(f"Project root: {PROJECT_ROOT}\nNotebook dir: {NOTEBOOK_DIR}")

In [None]:
# Seaborn style/context go FIRST: `sns.set_style` writes `font.family` and
# `font.sans-serif` into matplotlib's rcParams. Running it after the Korean font
# setup would reset those keys, and Korean titles/labels could render as empty
# boxes. `sns.set_context` only rescales sizes.
sns.set_style('whitegrid')
sns.set_context('paper')

# Configure a Korean-capable font for matplotlib (helper auto-detects the font).
# NOTE(review): assumes ensure_korean_font works by setting matplotlib rcParams,
# which is why it must run after the seaborn calls above. Confirm in src.visualization.
font_name = ensure_korean_font()

In [None]:
# --- Configuration: everything a reader may want to tune lives in this cell ---

# Input locations, plus the output folder for every figure/table this notebook writes.
DATA_DIR = PROJECT_ROOT / "data"
RESULTS_DIR = PROJECT_ROOT / "results"
ASSETS_DIR = PROJECT_ROOT / "notebooks" / "visualizations" / "assets"
ASSETS_DIR.mkdir(parents=True, exist_ok=True)

# Which pipeline stage's IG² results to visualize: 'pilot', 'medium', or 'full'.
STAGE = "pilot"

# SAE settings, kept for reference. In this notebook they are only printed below.
SAE_TYPE = "gated"       # 'standard' or 'gated'
LAYER_QUANTILE = "q2"    # 'q1', 'q2', or 'q3'

# Number of top-ranked features drawn per demographic in the ranking plot.
TOP_K = 20

print(
    f"Data directory: {DATA_DIR}\n"
    f"Results directory: {RESULTS_DIR}\n"
    f"\nStage: {STAGE}\n"
    f"SAE type: {SAE_TYPE}\n"
    f"Layer quantile: {LAYER_QUANTILE}\n"
    f"Top-K: {TOP_K}"
)

## Load Data

In [None]:
# Demographic metadata, plus the parallel Korean / English display-label lists.
demographics_dict = load_demographics(DATA_DIR)
demographic_labels_ko, demographic_labels_en = get_demographic_labels(demographics_dict)

# IG² results for STAGE, collected from every per-demographic results directory.
# Keyed by Korean demographic label (later cells look results up by `demo_ko`).
# Each value is either a score tensor or a dict holding it under 'feature_scores'.
ig2_results = load_ig2_results(RESULTS_DIR, stage=STAGE)

print(f"Loaded IG² results for {len(ig2_results)} demographics")
for demo in ig2_results:
    print(f"  - {demo}")

## Plot Top-K Rankings

In [None]:
# Ranked bar charts of the TOP_K highest-scoring features for each demographic.
# `save_path` is where the helper is asked to write the PNG.
rankings_png = ASSETS_DIR / f"ig2_rankings_{STAGE}_top{TOP_K}.png"

fig = plot_ig2_rankings(
    ig2_results=ig2_results,
    demographic_labels_ko=demographic_labels_ko,
    demographic_labels_en=demographic_labels_en,
    save_path=rankings_png,
    top_k=TOP_K,
    figsize=(20, 15),
)
plt.show()

## Score Distribution Analysis

In [None]:
# Histogram of IG² scores for each demographic (log-scaled counts), one panel per
# label. The grid is 3 columns wide with as many rows as needed, so more than 9
# demographics no longer raise IndexError, and unused panels are hidden.
n_cols = 3
panel_labels = list(zip(demographic_labels_ko, demographic_labels_en))
n_rows = max(1, -(-len(panel_labels) // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 4 * n_rows))
axes = axes.flatten()

for ax, (demo_ko, demo_en) in zip(axes, panel_labels):
    title = f"{demo_ko} ({demo_en})"

    if demo_ko not in ig2_results:
        ax.text(0.5, 0.5, 'No data', ha='center', va='center')
        ax.set_title(title, fontsize=13)
        continue

    # A result is either the score tensor itself or a dict holding it.
    scores = ig2_results[demo_ko]
    if isinstance(scores, dict):
        scores = scores['feature_scores']
    # detach() lets tensors that still track gradients convert to NumPy.
    scores = scores.detach().cpu().numpy()

    ax.hist(scores, bins=50, color='steelblue', alpha=0.7, edgecolor='black')
    ax.set_xlabel('IG² Score', fontsize=11)
    ax.set_ylabel('Frequency', fontsize=11)
    ax.set_title(title, fontsize=13)
    ax.set_yscale('log')
    ax.grid(alpha=0.3)

# Hide the panels left over when the label count is not a multiple of n_cols.
for ax in axes[len(panel_labels):]:
    ax.set_visible(False)

plt.tight_layout()
plt.savefig(ASSETS_DIR / f"ig2_score_distributions_{STAGE}.png", dpi=300, bbox_inches='tight')
plt.show()

## Comparative Statistics

In [None]:
# Summary statistics of the IG² scores for every demographic that has results.
stats = []

for demo_ko in demographic_labels_ko:
    if demo_ko not in ig2_results:
        continue

    # A result is either the score tensor itself or a dict holding it.
    scores = ig2_results[demo_ko]
    if isinstance(scores, dict):
        scores = scores['feature_scores']
    # detach() lets tensors that still track gradients convert to NumPy.
    scores = scores.detach().cpu().numpy()

    stats.append({
        'Demographic': demo_ko,
        'Max Score': scores.max(),
        'Mean Score': scores.mean(),
        'Median Score': np.median(scores),
        'Std Dev': scores.std(),  # NumPy default: population std (ddof=0)
        # Mean of the 10 largest scores (of all scores if there are fewer than 10).
        'Top-10 Mean': np.sort(scores)[-10:].mean(),
        # `!= 0` rather than `> 0`, so negative attributions also count as non-zero
        # and the value matches the column name.
        'Non-zero %': (scores != 0).mean() * 100,
    })

df_stats = pd.DataFrame(stats)
print("\nIG² Score Statistics by Demographic:")
print("=" * 80)
print(df_stats.to_string(index=False))

# Save to CSV
stats_csv_path = ASSETS_DIR / f"ig2_statistics_{STAGE}.csv"
df_stats.to_csv(stats_csv_path, index=False)
print(f"\nSaved to {stats_csv_path}")

## Cross-Demographic Comparison

In [None]:
# Largest IG² score per demographic, one bar per demographic that has results.
# `available_demos` keeps the order of demographic_labels_ko.
available_demos = [d for d in demographic_labels_ko if d in ig2_results]

max_scores = []
for d in available_demos:
    result = ig2_results[d]
    feature_scores = result['feature_scores'] if isinstance(result, dict) else result
    max_scores.append(feature_scores.max().item())

positions = np.arange(len(available_demos))
fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(positions, max_scores, color='coral', alpha=0.7, edgecolor='black')
ax.set_xticks(positions)
ax.set_xticklabels(available_demos, rotation=45, ha='right')
ax.set_ylabel('Max IG² Score', fontsize=13)
ax.set_title('Max IG² Score Comparison', fontsize=16, pad=15)
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig(ASSETS_DIR / f"ig2_max_scores_comparison_{STAGE}.png", dpi=300, bbox_inches='tight')
plt.show()

## Top Features Across All Demographics

In [None]:
# Find features that rank in the top-`top_k_per_demo` of more than one demographic,
# separating shared high-attribution features from demographic-specific ones.
top_k_per_demo = 50

# Same definition as in the comparison cell. It is recomputed here so this cell
# does not depend on state left behind by another cell.
available_demos = [d for d in demographic_labels_ko if d in ig2_results]

feature_counts = Counter()   # feature index -> number of demographics ranking it top-k
feature_demo_mapping = {}    # feature index -> those demographics, in label order

for demo_ko in available_demos:
    scores = ig2_results[demo_ko]
    if isinstance(scores, dict):
        scores = scores['feature_scores']

    # torch.topk raises when k exceeds the size of the ranked (last) dimension.
    k = min(top_k_per_demo, scores.shape[-1])
    top_indices = torch.topk(scores, k=k).indices.tolist()
    feature_counts.update(top_indices)
    for idx in top_indices:
        feature_demo_mapping.setdefault(idx, []).append(demo_ko)

# Show features appearing in multiple demographics
print(f"\nFeatures appearing in top-{top_k_per_demo} for multiple demographics:")
print("=" * 80)

# most_common() orders by count, descending; equal counts keep first-seen order.
multi_demo_features = [(feat, count) for feat, count in feature_counts.most_common() if count > 1]

for feat_idx, count in multi_demo_features[:20]:
    demos = feature_demo_mapping[feat_idx]
    suffix = '...' if len(demos) > 5 else ''
    print(f"Feature {feat_idx:6d}: appears in {count} demographics - {', '.join(demos[:5])}{suffix}")

print(f"\nTotal unique features in top-{top_k_per_demo}: {len(feature_counts)}")
print(f"Features in multiple demographics: {len(multi_demo_features)}")

## Interpretation

### What to Look For:

1. **Score Magnitude:**
   - Which demographics have the highest attribution scores?
   - Are scores concentrated in a few features or distributed?

2. **Top Features:**
   - Which features consistently appear at the top?
   - Are there demographic-specific vs. shared top features?

3. **Score Distribution:**
   - Is the distribution heavy-tailed (a few high scores, many low ones)? The histograms use a log-scaled y-axis to make the tail visible.
   - What percentage of features have non-zero scores?

### Next Steps:

1. Examine activation patterns for top features
2. Test suppression effects on highest-scoring features
3. Investigate feature overlap with UMAP clustering