# Late Fusion by Ambiguity Groups with EEG

This notebook extends the ambiguity group analysis to include EEG data.

**Modalities:**
1. Physiology (POST)
2. Behavior
3. Gaze
4. EEG ← NEW

**Analysis:**
- Split trials into ambiguity groups (Low=0, Medium=3, High=6)
- Run late fusion for each group with and without EEG
- Compare EEG contribution across ambiguity levels

In [1]:
import pickle
import json
import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import accuracy_score, f1_score
from sklearn.impute import SimpleImputer
from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

np.random.seed(42)
sns.set_style('whitegrid')

## 1. Load Features

In [2]:
# Load existing featureswith open('../../data/results/features_POST/extracted_features_POST.pkl', 'rb') as f:    feature_data = pickle.load(f)merged_df = feature_data['merged_df']physio_cols = feature_data['physio_cols']behavior_cols = feature_data['behavior_cols']gaze_cols = feature_data['gaze_cols']# Load EEG featureswith open('../../data/results/features_POST/eeg_features_POST.pkl', 'rb') as f:    eeg_data = pickle.load(f)eeg_features_df = eeg_data['eeg_features_df']eeg_cols = eeg_data['feature_columns']print(f"✓ Loaded features")print(f"  Original trials: {len(merged_df)}")print(f"  EEG trials: {len(eeg_features_df)}")


✓ Loaded features
  Original trials: 12511
  EEG trials: 10


In [3]:
# Merge EEG features
merged_with_eeg = merged_df.merge(
    eeg_features_df,
    on=['subject_id', 'trial_id'],
    how='inner'
)

print(f"✓ Merged data: {len(merged_with_eeg)} trials")
print(f"  Subjects: {merged_with_eeg['subject_id'].nunique()}")
print(f"\nAmbiguity distribution:")
print(merged_with_eeg['ambiguity'].value_counts().sort_index())

KeyError: 'subject_id'

## 2. Late Fusion Function

In [None]:
def weighted_late_fusion(X_modalities, y, subjects, modality_names):
    """
    Weighted late fusion using logistic regression meta-learner.
    Returns subject-level accuracy for proper SEM calculation.
    """
    logo = LeaveOneGroupOut()
    base_models = [RandomForestClassifier(n_estimators=100, max_depth=5, 
                                          min_samples_split=10, min_samples_leaf=5, 
                                          random_state=42,
                                          class_weight='balanced')
                   for _ in X_modalities]
    
    subject_accs = {}
    subject_f1s = {}
    all_weights = []
    
    for train_idx, test_idx in logo.split(X_modalities[0], y, subjects):
        train_probs, test_probs = [], []
        
        for X, model in zip(X_modalities, base_models):
            X_train, X_test = X[train_idx], X[test_idx]
            y_train = y[train_idx]
            
            model.fit(X_train, y_train)
            train_probs.append(model.predict_proba(X_train)[:, 1])
            test_probs.append(model.predict_proba(X_test)[:, 1])
        
        train_probs = np.column_stack(train_probs)
        test_probs = np.column_stack(test_probs)
        y_train, y_test = y[train_idx], y[test_idx]
        
        meta = LogisticRegression(random_state=42, max_iter=1000)
        meta.fit(train_probs, y_train)
        weights = meta.coef_[0]
        y_pred = meta.predict(test_probs)
        
        test_subject = subjects[test_idx][0]
        acc = accuracy_score(y_test, y_pred)
        f1 = f1_score(y_test, y_pred, average='weighted', zero_division=0)
        subject_accs[test_subject] = acc
        subject_f1s[test_subject] = f1
        all_weights.append(weights)
    
    subject_acc_values = np.array(list(subject_accs.values()))
    subject_f1_values = np.array(list(subject_f1s.values()))
    
    avg_weights = np.mean(all_weights, axis=0)
    norm_weights = np.exp(avg_weights) / np.sum(np.exp(avg_weights))
    
    return {
        'accuracy_mean': np.mean(subject_acc_values),
        'accuracy_sem': stats.sem(subject_acc_values),
        'accuracy_std': np.std(subject_acc_values),
        'f1_mean': np.mean(subject_f1_values),
        'f1_sem': stats.sem(subject_f1_values),
        'weights': norm_weights,
        'n_trials': len(y),
        'n_subjects': len(subject_accs),
    }

## 3. Create Ambiguity Groups

In [None]:
# Create ambiguity groups
merged_with_eeg['ambiguity_group'] = merged_with_eeg['ambiguity'].replace({0:'Low', 3:'Medium', 6:'High'})

print("Ambiguity group distribution:")
print(merged_with_eeg['ambiguity_group'].value_counts())

## 4. Run Late Fusion for Each Ambiguity Group

In [None]:
modality_names_no_eeg = ['Physiology (POST)', 'Behavior', 'Gaze']
modality_names_with_eeg = ['Physiology (POST)', 'Behavior', 'Gaze', 'EEG']

group_results_no_eeg = {}
group_results_with_eeg = {}

for group in ['Low', 'Medium', 'High']:
    print(f"\n{'='*80}")
    print(f"Ambiguity Group: {group}")
    print(f"{'='*80}")
    
    group_data = merged_with_eeg[merged_with_eeg['ambiguity_group'] == group]
    
    n_subjects = group_data['subject_id'].nunique()
    print(f"Trials: {len(group_data)}")
    print(f"Subjects: {n_subjects}")
    print(f"Outcome distribution: {group_data['outcome'].value_counts().to_dict()}")
    
    if n_subjects < 3:
        print(f"⚠ Skipping - insufficient subjects")
        continue
    
    # Prepare features
    X_physio = SimpleImputer(strategy='mean').fit_transform(group_data[physio_cols])
    X_behavior = SimpleImputer(strategy='mean').fit_transform(group_data[behavior_cols])
    X_gaze = SimpleImputer(strategy='mean').fit_transform(group_data[gaze_cols])
    X_eeg = SimpleImputer(strategy='mean').fit_transform(group_data[eeg_cols])
    y = group_data['outcome'].values
    subjects = group_data['subject_id'].values
    
    # WITHOUT EEG
    print(f"\n--- WITHOUT EEG ---")
    X_modalities_no_eeg = [X_physio, X_behavior, X_gaze]
    results_no_eeg = weighted_late_fusion(X_modalities_no_eeg, y, subjects, modality_names_no_eeg)
    group_results_no_eeg[group] = results_no_eeg
    
    print(f"Accuracy: {results_no_eeg['accuracy_mean']:.3f} ± {results_no_eeg['accuracy_sem']:.3f} (SEM)")
    print(f"Modality Weights:")
    for name, w in zip(modality_names_no_eeg, results_no_eeg['weights']):
        print(f"  {name}: {w:.3f}")
    
    # WITH EEG
    print(f"\n--- WITH EEG ---")
    X_modalities_with_eeg = [X_physio, X_behavior, X_gaze, X_eeg]
    results_with_eeg = weighted_late_fusion(X_modalities_with_eeg, y, subjects, modality_names_with_eeg)
    group_results_with_eeg[group] = results_with_eeg
    
    print(f"Accuracy: {results_with_eeg['accuracy_mean']:.3f} ± {results_with_eeg['accuracy_sem']:.3f} (SEM)")
    print(f"Modality Weights:")
    for name, w in zip(modality_names_with_eeg, results_with_eeg['weights']):
        print(f"  {name}: {w:.3f}")
    
    # EEG contribution
    improvement = results_with_eeg['accuracy_mean'] - results_no_eeg['accuracy_mean']
    print(f"\nEEG Contribution: {improvement:+.3f} ({(improvement/results_no_eeg['accuracy_mean']*100):+.2f}%)")

## 5. Comparison Across Groups

In [None]:
# Create comparison dataframe
comparison_data = []

for group in ['Low', 'Medium', 'High']:
    if group in group_results_no_eeg:
        res_no_eeg = group_results_no_eeg[group]
        res_with_eeg = group_results_with_eeg[group]
        
        comparison_data.append({
            'Group': group,
            'N_Trials': res_with_eeg['n_trials'],
            'N_Subjects': res_with_eeg['n_subjects'],
            'Acc_No_EEG': res_no_eeg['accuracy_mean'],
            'Acc_No_EEG_SEM': res_no_eeg['accuracy_sem'],
            'Acc_With_EEG': res_with_eeg['accuracy_mean'],
            'Acc_With_EEG_SEM': res_with_eeg['accuracy_sem'],
            'EEG_Improvement': res_with_eeg['accuracy_mean'] - res_no_eeg['accuracy_mean'],
            'EEG_Weight': res_with_eeg['weights'][-1],
            'Physio_Weight_NoEEG': res_no_eeg['weights'][0],
            'Behavior_Weight_NoEEG': res_no_eeg['weights'][1],
            'Gaze_Weight_NoEEG': res_no_eeg['weights'][2],
            'Physio_Weight_WithEEG': res_with_eeg['weights'][0],
            'Behavior_Weight_WithEEG': res_with_eeg['weights'][1],
            'Gaze_Weight_WithEEG': res_with_eeg['weights'][2],
        })

comparison_df = pd.DataFrame(comparison_data)

print("\n" + "="*80)
print("COMPARISON ACROSS AMBIGUITY GROUPS")
print("="*80)
print(comparison_df[['Group', 'N_Trials', 'Acc_No_EEG', 'Acc_With_EEG', 'EEG_Improvement', 'EEG_Weight']].to_string(index=False))

## 6. Visualizations

In [None]:
# Accuracy comparison with/without EEG
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Accuracy comparison
ax = axes[0]
x = np.arange(len(comparison_df))
width = 0.35

ax.bar(x - width/2, comparison_df['Acc_No_EEG'], width, 
       yerr=comparison_df['Acc_No_EEG_SEM'], capsize=5,
       label='Without EEG', color='steelblue', alpha=0.7)
ax.bar(x + width/2, comparison_df['Acc_With_EEG'], width,
       yerr=comparison_df['Acc_With_EEG_SEM'], capsize=5,
       label='With EEG', color='coral', alpha=0.7)

ax.set_xlabel('Ambiguity Group')
ax.set_ylabel('Accuracy')
ax.set_title('Accuracy: With vs Without EEG')
ax.set_xticks(x)
ax.set_xticklabels(comparison_df['Group'])
ax.set_ylim([0, 1])
ax.legend()
ax.grid(alpha=0.3, axis='y')
ax.axhline(0.5, color='red', linestyle='--', alpha=0.3)

# EEG improvement
ax = axes[1]
bars = ax.bar(comparison_df['Group'], comparison_df['EEG_Improvement'], 
              color='mediumseagreen', alpha=0.7)
ax.set_xlabel('Ambiguity Group')
ax.set_ylabel('Accuracy Improvement')
ax.set_title('EEG Contribution by Ambiguity')
ax.axhline(0, color='black', linestyle='-', linewidth=0.8)
ax.grid(alpha=0.3, axis='y')

# Add value labels
for i, v in enumerate(comparison_df['EEG_Improvement']):
    ax.text(i, v + 0.005, f'{v:+.3f}', ha='center', va='bottom')

plt.tight_layout()
plt.show()

In [None]:
# EEG weight by ambiguity group
fig, ax = plt.subplots(figsize=(10, 6))

ax.bar(comparison_df['Group'], comparison_df['EEG_Weight'], 
       color='orange', alpha=0.7)
ax.set_xlabel('Ambiguity Group')
ax.set_ylabel('EEG Weight in Fusion')
ax.set_title('EEG Modality Weight by Ambiguity Level')
ax.grid(alpha=0.3, axis='y')

# Add value labels
for i, v in enumerate(comparison_df['EEG_Weight']):
    ax.text(i, v + 0.01, f'{v:.3f}', ha='center', va='bottom')

plt.tight_layout()
plt.show()

In [None]:
# Modality weights comparison (with vs without EEG)
fig, axes = plt.subplots(2, 3, figsize=(16, 10))

for idx, group in enumerate(['Low', 'Medium', 'High']):
    # Without EEG
    ax = axes[0, idx]
    weights_no_eeg = [
        comparison_df.loc[comparison_df['Group'] == group, 'Physio_Weight_NoEEG'].values[0],
        comparison_df.loc[comparison_df['Group'] == group, 'Behavior_Weight_NoEEG'].values[0],
        comparison_df.loc[comparison_df['Group'] == group, 'Gaze_Weight_NoEEG'].values[0]
    ]
    ax.bar(modality_names_no_eeg, weights_no_eeg, color='steelblue', alpha=0.7)
    ax.set_title(f'{group} - Without EEG')
    ax.set_ylabel('Weight')
    ax.set_ylim([0, 1])
    ax.grid(alpha=0.3, axis='y')
    
    # With EEG
    ax = axes[1, idx]
    weights_with_eeg = [
        comparison_df.loc[comparison_df['Group'] == group, 'Physio_Weight_WithEEG'].values[0],
        comparison_df.loc[comparison_df['Group'] == group, 'Behavior_Weight_WithEEG'].values[0],
        comparison_df.loc[comparison_df['Group'] == group, 'Gaze_Weight_WithEEG'].values[0],
        comparison_df.loc[comparison_df['Group'] == group, 'EEG_Weight'].values[0]
    ]
    colors = ['steelblue', 'coral', 'mediumseagreen', 'orange']
    ax.bar(modality_names_with_eeg, weights_with_eeg, color=colors, alpha=0.7)
    ax.set_title(f'{group} - With EEG')
    ax.set_ylabel('Weight')
    ax.set_ylim([0, 1])
    ax.tick_params(axis='x', rotation=45)
    ax.grid(alpha=0.3, axis='y')

plt.suptitle('Modality Weights by Ambiguity Group', fontsize=14, y=1.00)
plt.tight_layout()
plt.show()

## 7. Summary

In [None]:
print("\n" + "="*80)
print("SUMMARY: EEG CONTRIBUTION ACROSS AMBIGUITY LEVELS")
print("="*80)

for _, row in comparison_df.iterrows():
    print(f"\n{row['Group']} Ambiguity:")
    print(f"  Accuracy without EEG: {row['Acc_No_EEG']:.3f} ± {row['Acc_No_EEG_SEM']:.3f}")
    print(f"  Accuracy with EEG:    {row['Acc_With_EEG']:.3f} ± {row['Acc_With_EEG_SEM']:.3f}")
    print(f"  EEG improvement:      {row['EEG_Improvement']:+.3f} ({(row['EEG_Improvement']/row['Acc_No_EEG']*100):+.2f}%)")
    print(f"  EEG weight in fusion: {row['EEG_Weight']:.3f} ({row['EEG_Weight']*100:.1f}%)")

print("\n" + "="*80)
print("KEY FINDINGS:")
print("="*80)

max_improvement_group = comparison_df.loc[comparison_df['EEG_Improvement'].idxmax(), 'Group']
max_improvement_val = comparison_df.loc[comparison_df['EEG_Improvement'].idxmax(), 'EEG_Improvement']
max_weight_group = comparison_df.loc[comparison_df['EEG_Weight'].idxmax(), 'Group']
max_weight_val = comparison_df.loc[comparison_df['EEG_Weight'].idxmax(), 'EEG_Weight']

print(f"1. EEG provides largest improvement in {max_improvement_group} ambiguity ({max_improvement_val:+.3f})")
print(f"2. EEG has highest weight in {max_weight_group} ambiguity ({max_weight_val:.3f})")
print(f"3. EEG contribution varies across ambiguity levels")

print("\n" + "="*80)