# EEG-SAE Feature Analysis

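This notebook provides interactive analysis of the features of a Sparse Autoencoder (SAE) trained on activations of REVE, an EEG transformer. Feature statistics are computed on the BCI Competition IV 2a (BCIC2A) four-class motor imagery dataset.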

**Prerequisites:** Run the training and feature computation scripts first:
```bash
bash scripts/01_run_train.sh            # Train SAE
bash scripts/02_run_compute_features.sh  # Compute feature statistics
```

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import torch
import matplotlib.pyplot as plt
import plotly.io as pio
pio.renderers.default = 'notebook'

from analysis.utils import (
    load_feature_data,
    calculate_class_entropy,
    compute_class_selectivity,
    feature_overview_scatter,
    dead_feature_summary,
    plot_channel_time_activation,
    plot_channel_time_activation_plotly,
    plot_topomap,
    plot_feature_topomap,
    plot_per_class_activation,
    plot_per_class_activation_plotly,
    get_top_class_features,
    plot_eeg_trial,
    plot_activation_distribution,
    plot_feature_comparison_grid,
    BCIC2A_CLASS_NAMES,
)
from tasks.utils import BCIC2A_CHANNELS

## 1. Load Feature Statistics

Load the precomputed SAE feature data (mean activations, sparsity, per-class stats, top-activating trials).
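`load_feature_data` is a project helper whose body is not shown here. If you want to check what the `.npz` file contains before loading it, NumPy can list the stored arrays directly. The sketch below writes a small stand-in file so it runs on its own. Its keys (`mean_acts`, `sparsity`, `per_class_mean_acts`) are borrowed from the names used later in this notebook and are an assumption about the real file layout. To inspect the real file, call `describe_npz(FEATURE_DATA_PATH)`; `describe_npz` itself is a hypothetical helper.

```python
import os
import tempfile

import numpy as np


def describe_npz(path):
    """Return {array_name: (shape, dtype)} for every array stored in an .npz file."""
    with np.load(path) as data:  # NpzFile supports the context-manager protocol
        return {name: (data[name].shape, str(data[name].dtype)) for name in data.files}


# Stand-in file with hypothetical keys and sizes (d_sae=8, 4 classes)
tmp_path = os.path.join(tempfile.mkdtemp(), 'demo_feature_data.npz')
np.savez(
    tmp_path,
    mean_acts=np.zeros(8, dtype=np.float32),
    sparsity=np.zeros(8, dtype=np.float32),
    per_class_mean_acts=np.zeros((4, 8), dtype=np.float32),
)

for name, (shape, dtype) in describe_npz(tmp_path).items():
    print(f'{name:>22}: shape={shape}, dtype={dtype}')
```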

In [None]:
# ── Configuration ──
FEATURE_DATA_PATH = '../results/feature_data/sae_feature_data.npz'
CHANNEL_NAMES = BCIC2A_CHANNELS
N_TIME_PATCHES = 5  # (1024 - 200) // (200 - 20) + 1 = 4 + 1 = 5 time patches per channel

# Load
stats = load_feature_data(FEATURE_DATA_PATH, device='cpu')
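The `N_TIME_PATCHES` value in the configuration cell follows the usual sliding-window count. With patches of 200 samples overlapping by 20, each new patch starts 180 samples after the previous one, and a trailing partial window is discarded. Here is a minimal sketch, assuming the trial length, patch size and overlap are exactly the numbers in that comment. `n_time_patches` is a hypothetical helper and is not part of the repository.

```python
def n_time_patches(n_samples, patch_size, overlap):
    """Number of full sliding-window patches along the time axis.

    Consecutive patches start `patch_size - overlap` samples apart;
    a trailing partial window is discarded.
    """
    stride = patch_size - overlap
    if stride <= 0:
        raise ValueError('overlap must be smaller than patch_size')
    if n_samples < patch_size:
        return 0
    return (n_samples - patch_size) // stride + 1


print(n_time_patches(1024, 200, 20))  # 824 // 180 + 1 = 4 + 1 = 5
```

Under this windowing, the last patch covers samples 720 to 919. The final 104 samples of a 1024-sample trial therefore fall outside every patch, unless the model pads the input.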

## 2. Feature Overview

### 2.1 Dead Feature Summary
How many features are alive vs dead (never fire)?
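A feature counts as dead when it never fires. The sketch below shows the bookkeeping, assuming that `stats['sparsity']` holds each feature's firing frequency, meaning the fraction of samples on which its activation is nonzero. `dead_feature_summary` in `analysis.utils` may use a different criterion, and `summarize_dead_features` is hypothetical.

```python
import numpy as np


def summarize_dead_features(sparsity, threshold=0.0):
    """Split features into alive/dead from their firing frequency.

    A feature is dead if its firing frequency is <= threshold (0.0 = never fired).
    """
    freq = np.asarray(sparsity, dtype=np.float64)
    alive_mask = freq > threshold
    n_total = freq.size
    n_alive = int(alive_mask.sum())
    return {
        'n_total': n_total,
        'n_alive': n_alive,
        'n_dead': n_total - n_alive,
        'frac_dead': (n_total - n_alive) / n_total,
        'alive_mask': alive_mask,
    }


demo = summarize_dead_features([0.0, 0.02, 0.0, 0.5])
print(demo['n_alive'], demo['n_dead'], f"{demo['frac_dead']:.1%}")  # 2 2 50.0%
```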

In [None]:
summary = dead_feature_summary(stats)
print(f"Total features: {summary['n_total']}")
print(f"Alive features: {summary['n_alive']} ({1 - summary['frac_dead']:.1%})")
print(f"Dead features:  {summary['n_dead']} ({summary['frac_dead']:.1%})")

### 2.2 Sparsity vs Activation Scatter Plot

Interactive scatter of all alive features. Each dot is one SAE feature:
- **x**: log₁₀(sparsity) — how often it fires
- **y**: log₁₀(mean activation) — how strongly it fires
- **color**: class entropy — low = class-specific, high = class-general

Hover to see feature index and selectivity.
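The notebook imports `calculate_class_entropy` but does not show its body. One common definition of class entropy, assumed in this sketch, works in two steps. First, each feature's four per-class mean activations are normalized into a distribution. Second, the Shannon entropy of that distribution is taken. The result is 0 when the feature fires for a single class and ln 4 ≈ 1.386 when it fires equally for all four. The sketch uses the natural log; a base-2 variant would put the maximum at 2 bits. `class_entropy` is hypothetical.

```python
import numpy as np


def class_entropy(per_class_mean_acts, eps=1e-12):
    """Shannon entropy (nats) of each feature's normalized per-class activation profile.

    per_class_mean_acts: array of shape [n_classes, d_sae] with non-negative entries.
    Returns shape [d_sae]; 0 = one class only, log(n_classes) = uniform.
    All-zero (dead) features get entropy 0.
    """
    acts = np.clip(np.asarray(per_class_mean_acts, dtype=np.float64), 0.0, None)
    probs = acts / (acts.sum(axis=0, keepdims=True) + eps)
    # Use log(1) = 0 where p == 0 so that 0 * log(0) contributes nothing
    log_probs = np.log(np.where(probs > 0, probs, 1.0))
    return -(probs * log_probs).sum(axis=0)


demo_acts = np.array([
    [1.0, 2.0],
    [1.0, 0.0],
    [1.0, 0.0],
    [1.0, 0.0],
])
print(class_entropy(demo_acts))  # column 0 is uniform, column 1 fires for one class only
```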

In [None]:
alive_mask = summary['alive_mask']
fig = feature_overview_scatter(stats, mask=alive_mask)
fig.show()

### 2.3 Class Selectivity Distribution

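Selectivity index: (max_class − mean_others) / (max_class + mean_others), where max_class is the feature's mean activation on its preferred class and mean_others is its mean activation averaged over the other three classes.  
For non-negative activations the index lies in [0, 1]: 0 means equal activation for every class, and 1 means the feature fires for one motor imagery class only.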
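The selectivity formula can be written directly in NumPy. This sketch makes two assumptions: that `per_class_mean_acts` has shape `[n_classes, d_sae]`, as the indexing in Section 6 suggests, and that activations are non-negative. `class_selectivity` is a hypothetical stand-in for `compute_class_selectivity`, whose implementation is not shown.

```python
import numpy as np


def class_selectivity(per_class_mean_acts, eps=1e-8):
    """(max_class - mean_others) / (max_class + mean_others) for every feature.

    per_class_mean_acts: array of shape [n_classes, d_sae].
    Returns (selectivity [d_sae], preferred_class [d_sae]).
    """
    acts = np.asarray(per_class_mean_acts, dtype=np.float64)
    n_classes = acts.shape[0]
    preferred_class = acts.argmax(axis=0)
    max_class = acts.max(axis=0)
    # Mean over the other classes: remove the maximum once from the column sum
    mean_others = (acts.sum(axis=0) - max_class) / (n_classes - 1)
    selectivity = (max_class - mean_others) / (max_class + mean_others + eps)
    return selectivity, preferred_class


acts = np.array([
    [1.0, 1.0, 3.0],
    [0.0, 1.0, 1.0],
    [0.0, 1.0, 1.0],
    [0.0, 1.0, 1.0],
])
sel, pref = class_selectivity(acts)
print(np.round(sel, 3), pref)  # selectivity ≈ [1.0, 0.0, 0.5]
```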

In [None]:
selectivity, preferred_class = compute_class_selectivity(stats['per_class_mean_acts'])
SEL_THRESHOLD = 0.1  # features above this index count as class-selective (used in both panels)

fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Histogram of selectivity over alive features
axes[0].hist(selectivity[alive_mask].numpy(), bins=50, color='#3498db', edgecolor='black', linewidth=0.5)
axes[0].set_xlabel('Class Selectivity Index')
axes[0].set_ylabel('Count')
axes[0].set_title('Distribution of Class Selectivity (alive features)')
axes[0].axvline(x=SEL_THRESHOLD, color='red', linestyle='--', label=f'selectivity={SEL_THRESHOLD}')
axes[0].legend()

# Number of selective alive features per preferred class
class_counts = []
for cls_idx in range(4):
    count = ((preferred_class[alive_mask] == cls_idx) & (selectivity[alive_mask] > SEL_THRESHOLD)).sum().item()
    class_counts.append(count)
axes[1].bar([BCIC2A_CLASS_NAMES[i] for i in range(4)], class_counts,
            color=['#3498db', '#e74c3c', '#2ecc71', '#f39c12'], edgecolor='black')
axes[1].set_ylabel('Number of selective features')
axes[1].set_title(f'Features per preferred class (selectivity > {SEL_THRESHOLD})')

plt.tight_layout()
plt.show()

## 3. Inspect Individual Features

Pick a specific SAE feature to examine in detail.

In [None]:
# ── Pick a feature to inspect ──
# Default: the most class-selective alive feature (or set FEATURE_IDX by hand)
sel_values = selectivity.clone()
sel_values[~alive_mask] = -1
FEATURE_IDX = int(sel_values.argmax().item())
print(f'Most selective alive feature: {FEATURE_IDX}')
print(f'  Selectivity: {selectivity[FEATURE_IDX]:.4f}')
print(f'  Preferred class: {BCIC2A_CLASS_NAMES[int(preferred_class[FEATURE_IDX])]}')
print(f'  Mean activation: {stats["mean_acts"][FEATURE_IDX]:.6f}')
print(f'  Sparsity: {stats["sparsity"][FEATURE_IDX]:.6f}')
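The most selective feature is not always the most informative one. A feature that fires on only a handful of tokens can score a selectivity near 1 from very few activations. The sketch below offers an alternative pick that also requires a minimum firing frequency, assuming `stats['sparsity']` is a firing frequency. `pick_selective_feature` and the `min_sparsity` default are hypothetical choices, not values from the project.

```python
import numpy as np


def pick_selective_feature(selectivity, sparsity, alive_mask, min_sparsity=1e-3):
    """Index of the most class-selective alive feature firing at least `min_sparsity` of the time.

    Accepts NumPy arrays, lists, or CPU torch tensors (converted via NumPy).
    """
    sel = np.array(selectivity, dtype=np.float64)  # copy, so the input is not modified
    eligible = np.asarray(alive_mask, dtype=bool) & (np.asarray(sparsity, dtype=np.float64) >= min_sparsity)
    if not eligible.any():
        raise ValueError('no alive feature reaches min_sparsity; lower the threshold')
    sel[~eligible] = -np.inf
    return int(sel.argmax())


print(pick_selective_feature(
    selectivity=[0.95, 0.50, 0.70],
    sparsity=[1e-5, 0.01, 0.02],
    alive_mask=[True, True, True],
))  # feature 0 is most selective but fires too rarely, so feature 2 is picked
```

In the cell above, this would replace the arg-max line with `FEATURE_IDX = pick_selective_feature(selectivity, stats['sparsity'], alive_mask)`.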

### 3.1 Per-Class Activation

How does this feature activate across the 4 motor imagery classes?

In [None]:
fig = plot_per_class_activation_plotly(
    stats['per_class_mean_acts'].numpy(), FEATURE_IDX
)
fig.show()

### 3.2 Load a Sample Trial & Visualize Token-Level Activations

To see channel × time activation patterns, we need to run the SAE on real trials.
This requires loading the REVE model and the SAE checkpoint. Sections 3.3 to 3.6, 5 and 6 all use the activations computed here, so skip them too if you only want the precomputed statistics.

In [None]:
# ── Load model and SAE ──
from src.sae_training.hooked_eeg_transformer import HookedEEGTransformer
from src.sae_training.sparse_autoencoder import SparseAutoencoder
from downstream.bcic2a_dataset import get_bcic2a_dataloaders

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SAE_PATH = '../checkpoints/sae_latest.pt'  # adjust to your SAE checkpoint
DB_PATH = '/path/to/bcic2a.lmdb'            # adjust to your data path
BLOCK_LAYER = -2
MODULE_NAME = 'resid'

model = HookedEEGTransformer(
    model_name='brain-bzh/reve-base',
    channel_names=CHANNEL_NAMES,
    device=DEVICE,
)

sae = SparseAutoencoder.load_from_pretrained(SAE_PATH, device=DEVICE)
sae.eval()

loaders = get_bcic2a_dataloaders(db_path=DB_PATH, batch_size=16, test_subject_ids=[9])
print(f'Loaded model (d_in={model.embed_dim}) and SAE (d_sae={sae.d_sae})')

In [None]:
# ── Get a batch and compute SAE activations ──
batch = next(iter(loaders['test']))
eeg, labels = batch
eeg = eeg.to(DEVICE)

with torch.no_grad():
    activations = model.get_activations(eeg, BLOCK_LAYER, MODULE_NAME)
    feature_acts = sae.encode(activations.to(sae.dtype))  # [B, N_tokens, d_sae]

print(f'EEG shape: {eeg.shape}')
print(f'Activations shape: {activations.shape}')
print(f'SAE feature acts shape: {feature_acts.shape}')

### 3.3 Channel × Time Activation Heatmap

Shows how the selected SAE feature activates across channels (y-axis) and time patches (x-axis).  
This is the EEG analogue of PatchSAE's image patch segmentation mask.
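Building the heatmap amounts to taking one feature's activation on every token and reshaping the token axis into `[channels, time patches]`. This sketch makes the same assumption as the reshapes in Sections 3.5 and 6: the first `C × T` tokens are in channel-major order, with all time patches of channel 0, then all of channel 1, and so on, and any extra tokens come after them. `feature_grid` is a hypothetical helper.

```python
import numpy as np


def feature_grid(token_acts, feature_idx, n_channels, n_patches):
    """Reshape one SAE feature's per-token activations into [n_channels, n_patches].

    token_acts: array of shape [n_tokens, d_sae] for a single trial.
    Assumes channel-major token order; tokens beyond n_channels * n_patches are ignored.
    """
    token_acts = np.asarray(token_acts)
    n_grid = n_channels * n_patches
    if token_acts.shape[0] < n_grid:
        raise ValueError(f'expected at least {n_grid} tokens, got {token_acts.shape[0]}')
    return token_acts[:n_grid, feature_idx].reshape(n_channels, n_patches)


# Toy trial: 3 channels x 2 patches + 1 extra token, d_sae = 4
toy = np.arange(7 * 4, dtype=np.float32).reshape(7, 4)
print(feature_grid(toy, feature_idx=1, n_channels=3, n_patches=2))
```

On the real data, `feature_grid(trial_acts, FEATURE_IDX, len(CHANNEL_NAMES), N_TIME_PATCHES)` should give the matrix drawn by the heatmap if the ordering assumption holds.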

In [None]:
# Pick a trial from the batch
trial_idx = 0
trial_acts = feature_acts[trial_idx].cpu().numpy()  # [N_tokens, d_sae]
trial_eeg = eeg[trial_idx].cpu().numpy()              # [C, T]
trial_label = labels[trial_idx].item()

print(f'Trial {trial_idx} — Class: {BCIC2A_CLASS_NAMES.get(trial_label, trial_label)}')

# Channel × Time heatmap
fig = plot_channel_time_activation_plotly(
    trial_acts, FEATURE_IDX, CHANNEL_NAMES, N_TIME_PATCHES,
    title=f'Feature {FEATURE_IDX} — {BCIC2A_CLASS_NAMES.get(trial_label, "?")} trial'
)
fig.show()

### 3.4 Topomap

Scalp topography showing which channels this feature focuses on (aggregated over time).
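The topomap collapses the channel × time grid to one value per channel before interpolating over the scalp. The sketch below shows that reduction. It assumes that `aggregation` selects between the mean and the maximum over time patches; only `'mean'` appears in this notebook, so the `'max'` option is an assumption. `channel_profile` is hypothetical.

```python
import numpy as np


def channel_profile(grid, aggregation='mean'):
    """Collapse a [n_channels, n_patches] activation grid to one value per channel."""
    grid = np.asarray(grid, dtype=np.float64)
    if aggregation == 'mean':
        return grid.mean(axis=1)
    if aggregation == 'max':
        return grid.max(axis=1)
    raise ValueError(f'unknown aggregation: {aggregation!r}')


grid = np.array([[0.0, 2.0], [1.0, 1.0], [0.0, 0.0]])
print(channel_profile(grid, 'mean'), channel_profile(grid, 'max'))
```

The two choices answer different questions. The mean favors channels where the feature is active throughout the trial, while the max highlights channels with a brief burst in a single time patch.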

In [None]:
fig = plot_feature_topomap(
    trial_acts, FEATURE_IDX, CHANNEL_NAMES, N_TIME_PATCHES,
    aggregation='mean',
    title=f'Feature {FEATURE_IDX} Topomap'
)
plt.show()

### 3.5 Raw EEG of the Trial

Plot the trial's raw EEG, highlighting the three channels on which this feature has the highest mean activation (averaged over time patches).

In [None]:
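# Per-token activations of this feature; the first C * T tokens are reshaped
# channel-major into a [channels, time patches] grid (any extra tokens are dropped)
acts_2d = trial_acts[:, FEATURE_IDX][:len(CHANNEL_NAMES) * N_TIME_PATCHES]
acts_2d = acts_2d.reshape(len(CHANNEL_NAMES), N_TIME_PATCHES)
channel_acts = acts_2d.mean(axis=1)                # mean over time patches -> [C]
top_ch_idx = np.argsort(channel_acts)[-3:][::-1]   # 3 most active channels, strongest first
highlight = [CHANNEL_NAMES[i] for i in top_ch_idx]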

fig = plot_eeg_trial(
    trial_eeg, CHANNEL_NAMES, sample_rate=250,
    title=f'EEG Trial (class={BCIC2A_CLASS_NAMES.get(trial_label, "?")}) — top channels for feature {FEATURE_IDX} highlighted',
    highlight_channels=highlight, scale=2.0,
)
plt.show()

### 3.6 Activation Distribution for One Trial

Show which SAE features fire for this trial. Activations are mean-pooled over all tokens, giving one value per feature, and plotted as a bar chart across all features.
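A partial sort is enough to list the strongest features as numbers instead of reading them off the chart. The sketch below is minimal; `top_k_features` is hypothetical. On the real data it would be called as `top_k_features(trial_acts.mean(axis=0), k=10)`.

```python
import numpy as np


def top_k_features(mean_pooled, k=10):
    """Indices and values of the k largest entries of a [d_sae] vector, largest first."""
    acts = np.asarray(mean_pooled)
    if k < 1:
        raise ValueError('k must be at least 1')
    k = min(k, acts.size)
    idx = np.argpartition(acts, -k)[-k:]    # unordered top-k in linear time
    idx = idx[np.argsort(acts[idx])[::-1]]  # sort those k in descending order
    return idx, acts[idx]


idx, vals = top_k_features([0.1, 0.0, 0.7, 0.3, 0.5], k=3)
print(idx, vals)
```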

In [None]:
mean_pooled = trial_acts.mean(axis=0)  # [d_sae]
fig = plot_activation_distribution(
    mean_pooled, top_k=10,
    title=f'SAE Activation Distribution — {BCIC2A_CLASS_NAMES.get(trial_label, "?")} trial'
)
fig.show()

## 4. Top Features Per Class

List the 10 SAE features most strongly associated with each motor imagery class, as ranked by `get_top_class_features` from the per-class mean activations.
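`get_top_class_features` returns a DataFrame, but its ranking criterion is not shown here. This sketch assumes one plausible criterion: rank features by their mean activation on the target class and report the margin over the mean of the other classes. The margin separates broadly active features from class-specific ones. `rank_class_features` is hypothetical.

```python
import numpy as np


def rank_class_features(per_class_mean_acts, cls_idx, top_k=10):
    """Rank features by mean activation on class `cls_idx`.

    per_class_mean_acts: array of shape [n_classes, d_sae].
    Returns (feature indices, class activation, margin over mean of other classes).
    """
    acts = np.asarray(per_class_mean_acts, dtype=np.float64)
    target = acts[cls_idx]
    others = np.delete(acts, cls_idx, axis=0).mean(axis=0)
    order = np.argsort(target)[::-1][:top_k]
    return order, target[order], target[order] - others[order]


demo_acts = np.array([
    [0.9, 0.5, 0.1],
    [0.9, 0.0, 0.0],
    [0.9, 0.0, 0.0],
    [0.9, 0.0, 0.0],
])
order, class_act, margin = rank_class_features(demo_acts, cls_idx=0, top_k=3)
print(order, class_act, margin)
```

In this toy example feature 0 tops the list for class 0, but its margin is zero because it is equally active for every class. Feature 1 is weaker but specific to class 0.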

In [None]:
for cls_idx in range(4):
    cls_name = BCIC2A_CLASS_NAMES[cls_idx]
    df = get_top_class_features(stats['per_class_mean_acts'].numpy(), cls_idx, top_k=10)
    print(f'\n=== Top 10 features for {cls_name} (class {cls_idx}) ===')
    print(df.to_string(index=False))

## 5. Multi-Feature Comparison Grid

Compare the four most class-selective alive features side by side: a channel × time heatmap and a topomap on the sample trial, plus a per-class mean activation bar chart.

In [None]:
# Pick top-4 most selective features
sel_alive = selectivity.clone()
sel_alive[~alive_mask] = -1
top_4_features = sel_alive.topk(4).indices.tolist()
print(f'Comparing features: {top_4_features}')

fig = plot_feature_comparison_grid(
    trial_acts, top_4_features, CHANNEL_NAMES, N_TIME_PATCHES,
    per_class_mean_acts=stats['per_class_mean_acts'].numpy(),
)
plt.show()

## 6. Feature Exploration: Class-Specific View

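For each motor imagery class, take the feature with the highest mean activation on that class. Plot its per-class mean activations (top row) and its channel × time heatmap on the sample trial (bottom row).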
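The cell takes the arg-max of each class's mean activation. As a result, a feature that is strongly active for every class can end up as the top feature of several classes at once. The sketch below compares that rule with an alternative that ranks by margin over the other classes. `top_feature_per_class` and its `by` parameter are hypothetical.

```python
import numpy as np


def top_feature_per_class(per_class_mean_acts, by='mean'):
    """Return one feature index per class.

    by='mean':   arg-max of the class's mean activation.
    by='margin': arg-max of class mean minus the mean over the other classes.
    """
    acts = np.asarray(per_class_mean_acts, dtype=np.float64)  # [n_classes, d_sae]
    n_classes = acts.shape[0]
    if by == 'mean':
        scores = acts
    elif by == 'margin':
        others = (acts.sum(axis=0, keepdims=True) - acts) / (n_classes - 1)
        scores = acts - others
    else:
        raise ValueError(f'unknown criterion: {by!r}')
    return scores.argmax(axis=1)


demo_acts = np.array([
    [5.0, 1.0, 0.0],
    [5.0, 0.0, 1.0],
])
print(top_feature_per_class(demo_acts, 'mean'), top_feature_per_class(demo_acts, 'margin'))
```

With `by='mean'`, both classes pick the shared feature 0. With `by='margin'`, each class picks the feature that separates it from the other class. On the real data, `top_feature_per_class(per_class, by='margin')[cls_idx]` could replace `np.argmax(per_class[cls_idx])` in the cell.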

In [None]:
per_class = stats['per_class_mean_acts'].numpy()

fig, axes = plt.subplots(2, 4, figsize=(20, 10))

for cls_idx in range(4):
    top_feat = int(np.argmax(per_class[cls_idx]))
    cls_name = BCIC2A_CLASS_NAMES[cls_idx]
    
    # Row 1: Per-class bar chart
    ax = axes[0, cls_idx]
    ax.bar([BCIC2A_CLASS_NAMES[i] for i in range(4)], per_class[:, top_feat],
           color=['#3498db', '#e74c3c', '#2ecc71', '#f39c12'], edgecolor='black', linewidth=0.5)
    ax.set_title(f'Top feature for {cls_name}\n(feat {top_feat})', fontsize=11)
    ax.set_ylabel('Mean activation')
    ax.tick_params(axis='x', rotation=30)
    
    # Row 2: channel × time heatmap of this feature on the sample trial
    ax2 = axes[1, cls_idx]
    acts = trial_acts[:, top_feat][:len(CHANNEL_NAMES) * N_TIME_PATCHES]
    heatmap = acts.reshape(len(CHANNEL_NAMES), N_TIME_PATCHES)
    im = ax2.imshow(heatmap, aspect='auto', cmap='RdYlBu_r', interpolation='nearest')
    ax2.set_yticks(range(len(CHANNEL_NAMES)))
    ax2.set_yticklabels(CHANNEL_NAMES, fontsize=7)
    ax2.set_xlabel('Time patch')
    ax2.set_title(f'Channel × Time (trial {trial_idx})', fontsize=11)
    plt.colorbar(im, ax=ax2, shrink=0.8)

plt.suptitle('Top feature per motor imagery class', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()