# PhysioNet Motor Imagery - Complete Model Comparison

## Comprehensive Comparison of All 7 Models

This notebook compares:
- **EEG-ARNN Models**: Baseline, Adaptive Gating (with ES/AS/GS channel selection)
- **Legacy Methods**: FBCSP, CNN-SAE, EEGNet, ACS-SE-CNN, G-CARM (with channel selection)

## Analyses:
1. **Full-Channel Performance** - All 64 channels
2. **Channel Selection Performance** - Top-k channels (k=10,15,20,25,30)
3. **Accuracy Drop Analysis** - Robustness to channel reduction
4. **Optimal k-Value** - Best accuracy/channel trade-off
5. **Statistical Comparisons** - Significance tests
6. **World-Class Visualizations** - Publication-ready figures

## Input Files:
**Full-channel results (7 files):**
- `eeg_arnn_baseline_results.csv`, `eeg_arnn_adaptive_results.csv`
- `legacy_fbcsp_results.csv`, `legacy_cnn_sae_results.csv`, `legacy_eegnet_results.csv`, `legacy_acs_se_cnn_results.csv`, `legacy_g_carm_results.csv`

**Channel selection results (7 files):**
- `eeg_arnn_baseline_retrain_results.csv`, `eeg_arnn_adaptive_retrain_results.csv`
- `legacy_fbcsp_retrain_results.csv`, `legacy_cnn_sae_retrain_results.csv`, `legacy_eegnet_retrain_results.csv`, `legacy_acs_se_cnn_retrain_results.csv`, `legacy_g_carm_retrain_results.csv`

## Output:
- Comprehensive comparison tables
- Statistical test results
- Publication-quality visualizations (metric box plots, mean-accuracy bar chart, multi-metric radar chart)
- Summary CSVs

## 1. Setup and Imports

In [None]:
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

# Keep notebook output readable: hide library warnings for the whole session.
warnings.filterwarnings('ignore')

# One seaborn theme shared by every figure below: larger fonts, white grid.
sns.set_context('notebook', font_scale=1.2)
sns.set_style('whitegrid')

print("Libraries loaded!")

## 2. Load All Results

In [None]:
# Every per-model CSV is expected under this directory.
results_dir = Path('results')

# Display name -> CSV with per-subject results on all 64 channels.
full_channel_files = {
    'Baseline EEG-ARNN': 'eeg_arnn_baseline_results.csv',
    'Adaptive Gating EEG-ARNN': 'eeg_arnn_adaptive_results.csv',
    'FBCSP': 'legacy_fbcsp_results.csv',
    'CNN-SAE': 'legacy_cnn_sae_results.csv',
    'EEGNet': 'legacy_eegnet_results.csv',
    'ACS-SE-CNN': 'legacy_acs_se_cnn_results.csv',
    'G-CARM': 'legacy_g_carm_results.csv'
}

# Display name -> CSV with channel-selection (retrain on top-k channels) results.
retrain_files = {
    'Baseline EEG-ARNN': 'eeg_arnn_baseline_retrain_results.csv',
    'Adaptive Gating EEG-ARNN': 'eeg_arnn_adaptive_retrain_results.csv',
    'FBCSP': 'legacy_fbcsp_retrain_results.csv',
    'CNN-SAE': 'legacy_cnn_sae_retrain_results.csv',
    'EEGNet': 'legacy_eegnet_retrain_results.csv',
    'ACS-SE-CNN': 'legacy_acs_se_cnn_retrain_results.csv',
    'G-CARM': 'legacy_g_carm_retrain_results.csv'
}


def _load_result_tables(file_map, label, unit):
    """Read each CSV in ``file_map`` that exists under ``results_dir``.

    Args:
        file_map: mapping of model display name -> CSV file name.
        label: tag printed in the progress line (e.g. ``'retrain'``).
        unit: noun printed after the row count (e.g. ``'rows'``).

    Returns:
        dict of model display name -> DataFrame, containing only the models
        whose CSV was found. Missing files are reported and skipped.
    """
    tables = {}
    for name, csv_name in file_map.items():
        csv_path = results_dir / csv_name
        if not csv_path.exists():
            print(f"Warning: {csv_name} not found")
            continue
        tables[name] = pd.read_csv(csv_path)
        print(f"Loaded {label}: {name:30s} ({len(tables[name])} {unit})")
    return tables


full_results = _load_result_tables(full_channel_files, 'full-channel', 'subjects')
print()
retrain_results = _load_result_tables(retrain_files, 'retrain', 'rows')

print(f"\nFull-channel models loaded: {len(full_results)}")
print(f"Channel selection models loaded: {len(retrain_results)}")

## 3. Aggregate Statistics

In [None]:
# Full-channel summary: one row per model, each metric averaged over the rows
# (subjects) of that model's full-channel CSV.
full_summary_columns = [
    'Model', 'Channels', 'Accuracy', 'Std_Acc', 'Precision', 'Recall',
    'F1-Score', 'AUC-ROC', 'Specificity', 'N_subjects',
]
full_summary_data = []

for model_name, df in full_results.items():
    full_summary_data.append({
        'Model': model_name,
        'Channels': 64,  # full-channel runs use all 64 channels
        'Accuracy': df['accuracy'].mean(),
        'Std_Acc': df['accuracy'].std(),
        'Precision': df['precision'].mean(),
        'Recall': df['recall'].mean(),
        'F1-Score': df['f1_score'].mean(),
        'AUC-ROC': df['auc_roc'].mean(),
        'Specificity': df['specificity'].mean(),
        'N_subjects': len(df)
    })

# Explicit columns keep sort_values working even when no CSVs were loaded.
full_summary_df = pd.DataFrame(full_summary_data, columns=full_summary_columns).sort_values(
    'Accuracy', ascending=False)

print("\n" + "="*140)
print("FULL-CHANNEL PERFORMANCE (64 channels)")
print("="*140)
print(full_summary_df.to_string(index=False, float_format='%.4f'))
print("="*140)

# Best channel-selection result for each (model, selection method): the k whose
# mean 'avg_accuracy' is highest, with metrics averaged over that k's rows.
best_retrain_columns = [
    'Model', 'Method', 'Best_k', 'Accuracy', 'Std_Acc', 'Acc_Drop',
    'F1-Score', 'AUC-ROC',
]
best_retrain_data = []

for model_name, df in retrain_results.items():
    if df.empty:
        continue

    # A CSV may hold several selection methods (e.g. ES/AS/GS); a file without
    # a 'method' column is treated as a single method named after the model.
    has_method = 'method' in df.columns
    methods = df['method'].unique() if has_method else [model_name.upper()]

    for method in methods:
        method_df = df[df['method'] == method] if has_method else df

        # idxmax() on the per-k means returns the best k *value*. It is not a
        # row label of method_df, so it must not be passed to method_df.loc.
        mean_acc_by_k = method_df.groupby('k')['avg_accuracy'].mean().dropna()
        if mean_acc_by_k.empty:
            continue
        best_k = mean_acc_by_k.idxmax()

        best_k_df = method_df[method_df['k'] == best_k]

        best_retrain_data.append({
            'Model': model_name,
            'Method': method,
            'Best_k': int(best_k),
            'Accuracy': best_k_df['avg_accuracy'].mean(),
            'Std_Acc': best_k_df['avg_accuracy'].std(),
            'Acc_Drop': best_k_df['accuracy_drop'].mean(),
            'F1-Score': best_k_df['avg_f1_score'].mean(),
            'AUC-ROC': best_k_df['avg_auc_roc'].mean()
        })

best_retrain_df = pd.DataFrame(best_retrain_data, columns=best_retrain_columns).sort_values(
    'Accuracy', ascending=False)

print("\n" + "="*140)
print("BEST CHANNEL SELECTION PERFORMANCE (optimal k per method)")
print("="*140)
print(best_retrain_df.to_string(index=False, float_format='%.4f'))
print("="*140)

## 4. Statistical Comparison

In [None]:
print("\n" + "="*80)
print("STATISTICAL SIGNIFICANCE TESTS (Paired t-test)")
print("="*80 + "\n")

# All tests use the full-channel results loaded in section 2 (`full_results`).
# Paired t-tests compare per-subject accuracies element-wise, so rows are paired
# by position. NOTE(review): this assumes every results CSV lists the same
# subjects in the same order; confirm against the training scripts.
baseline_model = 'Baseline EEG-ARNN'
adaptive_model = 'Adaptive Gating EEG-ARNN'

if adaptive_model in full_results and baseline_model in full_results:
    adaptive_acc = full_results[adaptive_model]['accuracy'].values
    baseline_acc = full_results[baseline_model]['accuracy'].values

    if len(adaptive_acc) == len(baseline_acc):
        t_stat, p_value = stats.ttest_rel(adaptive_acc, baseline_acc)
        improvement = adaptive_acc.mean() - baseline_acc.mean()

        print("Adaptive Gating vs Baseline EEG-ARNN:")
        print(f"  Mean improvement: {improvement:.4f} ({improvement/baseline_acc.mean()*100:.2f}%)")
        print(f"  t-statistic: {t_stat:.4f}")
        print(f"  p-value: {p_value:.4f}")
        print(f"  Significant: {'Yes' if p_value < 0.05 else 'No'} (alpha=0.05)")
        print()
    else:
        # Report rather than silently skip: unequal counts cannot be paired.
        print("Adaptive Gating vs Baseline EEG-ARNN: skipped "
              f"(subject counts differ: {len(adaptive_acc)} vs {len(baseline_acc)})\n")

if baseline_model in full_results:
    baseline_acc = full_results[baseline_model]['accuracy'].values

    print(f"\nComparison against {baseline_model}:\n")

    for model_name, df in full_results.items():
        if model_name == baseline_model:
            continue

        model_acc = df['accuracy'].values

        if len(model_acc) != len(baseline_acc):
            print(f"{model_name}: skipped "
                  f"(subject counts differ: {len(model_acc)} vs {len(baseline_acc)})\n")
            continue

        t_stat, p_value = stats.ttest_rel(model_acc, baseline_acc)
        improvement = model_acc.mean() - baseline_acc.mean()

        print(f"{model_name}:")
        print(f"  Improvement: {improvement:+.4f} ({improvement/baseline_acc.mean()*100:+.2f}%)")
        print(f"  p-value: {p_value:.4f}")
        print(f"  Significant: {'Yes' if p_value < 0.05 else 'No'}")
        print()

## 5. Visualizations

### 5.1 Box Plot Comparison

In [None]:
# One box per model for each per-subject metric in the full-channel results.
fig, axes = plt.subplots(2, 3, figsize=(20, 12))
fig.suptitle('Model Performance Comparison Across All Metrics', fontsize=16, fontweight='bold')

metrics = ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc', 'specificity']
metric_names = ['Accuracy', 'Precision', 'Recall', 'F1-Score', 'AUC-ROC', 'Specificity']

for idx, (metric, metric_name) in enumerate(zip(metrics, metric_names)):
    ax = axes[idx // 3, idx % 3]

    data_for_plot = []
    labels_for_plot = []

    for model_name, df in full_results.items():
        if metric in df.columns:
            # Drop NaNs: a single NaN makes matplotlib draw an empty box.
            data_for_plot.append(df[metric].dropna().values)
            labels_for_plot.append(model_name.replace(' EEG-ARNN', '').replace(' ', '\n'))

    if data_for_plot:
        bp = ax.boxplot(data_for_plot, patch_artist=True)
        # Set tick labels explicitly. boxplot's `labels=` keyword was renamed
        # `tick_labels` in Matplotlib 3.9; this form works on every version.
        # Boxes sit at positions 1..n by default.
        ax.set_xticks(range(1, len(labels_for_plot) + 1))
        ax.set_xticklabels(labels_for_plot)

        for patch in bp['boxes']:
            patch.set_facecolor('skyblue')
            patch.set_alpha(0.7)

    ax.set_ylabel(metric_name, fontsize=12)
    ax.set_title(f'{metric_name} Distribution', fontsize=12, fontweight='bold')
    ax.tick_params(axis='x', rotation=45)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
results_dir.mkdir(parents=True, exist_ok=True)
boxplot_path = results_dir / 'model_comparison_boxplots.png'
plt.savefig(boxplot_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"Box plots saved to {boxplot_path}")

### 5.2 Bar Chart - Mean Performance

In [None]:
# Mean full-channel accuracy per model, with +/- 1 std error bars.
fig, ax = plt.subplots(figsize=(14, 8))

model_names = list(full_results.keys())
accuracies = [full_results[m]['accuracy'].mean() for m in model_names]
# std() is NaN for a single-subject model; treat it as "no error bar" so the
# label positions and y-limits below stay finite.
std_accs = [float(np.nan_to_num(full_results[m]['accuracy'].std())) for m in model_names]

# Adaptive = green, Baseline = blue, every other (legacy) model = grey.
bar_colors = ['#2ecc71' if 'Adaptive' in m else '#3498db' if 'Baseline' in m else '#95a5a6'
              for m in model_names]

x_pos = np.arange(len(model_names))
ax.bar(x_pos, accuracies, yerr=std_accs, capsize=5, alpha=0.8, color=bar_colors)

ax.set_xlabel('Model', fontsize=14, fontweight='bold')
ax.set_ylabel('Accuracy', fontsize=14, fontweight='bold')
ax.set_title('Mean Accuracy Comparison Across All Models', fontsize=16, fontweight='bold')
ax.set_xticks(x_pos)
ax.set_xticklabels([m.replace(' EEG-ARNN', '') for m in model_names], rotation=45, ha='right')
if model_names:
    # The top limit clears the tallest error bar plus the value label drawn
    # above it; the bottom limit is never negative.
    y_lower = max(0.0, min(accuracies) - 0.05)
    y_upper = max(acc + std for acc, std in zip(accuracies, std_accs)) + 0.06
    ax.set_ylim([y_lower, y_upper])
ax.grid(True, alpha=0.3, axis='y')

for i, (acc, std) in enumerate(zip(accuracies, std_accs)):
    ax.text(i, acc + std + 0.01, f'{acc:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold')

plt.tight_layout()
results_dir.mkdir(parents=True, exist_ok=True)
bar_chart_path = results_dir / 'model_comparison_accuracy.png'
plt.savefig(bar_chart_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"Bar chart saved to {bar_chart_path}")

### 5.3 Radar Chart - Multi-Metric Comparison

In [None]:
from math import pi

# Radar (polar) chart of each model's mean full-channel metrics.
metrics = ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc', 'specificity']
metric_labels = ['Accuracy', 'Precision', 'Recall', 'F1-Score', 'AUC-ROC', 'Specificity']

num_vars = len(metrics)
angles = [n / float(num_vars) * 2 * pi for n in range(num_vars)]
# Repeat the first angle so each polygon closes back on its starting point.
angles += angles[:1]

fig, ax = plt.subplots(figsize=(12, 12), subplot_kw=dict(projection='polar'))

colors = ['#2ecc71', '#3498db', '#e74c3c', '#f39c12', '#9b59b6', '#1abc9c', '#34495e']

for idx, (model_name, df) in enumerate(full_results.items()):
    values = [df[m].mean() for m in metrics]
    values += values[:1]  # close the polygon, matching `angles`
    color = colors[idx % len(colors)]

    ax.plot(angles, values, 'o-', linewidth=2, label=model_name.replace(' EEG-ARNN', ''),
            color=color)
    ax.fill(angles, values, alpha=0.15, color=color)

ax.set_xticks(angles[:-1])
ax.set_xticklabels(metric_labels, size=12)
ax.set_ylim(0, 1.0)
ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_yticklabels(['0.2', '0.4', '0.6', '0.8', '1.0'], size=10)
ax.grid(True)
ax.set_title('Multi-Metric Performance Comparison', size=16, fontweight='bold', pad=20)
ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1), fontsize=10)

plt.tight_layout()
results_dir.mkdir(parents=True, exist_ok=True)
radar_path = results_dir / 'model_comparison_radar.png'
plt.savefig(radar_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"Radar chart saved to {radar_path}")

## 6. Ranking Table

In [None]:
# Rank models by a weighted composite of their mean full-channel metrics.
ranking_columns = ['Model', 'Mean_Accuracy', 'Mean_F1', 'Mean_AUC_ROC']
ranking_data = [
    {
        'Model': model_name,
        'Mean_Accuracy': df['accuracy'].mean(),
        'Mean_F1': df['f1_score'].mean(),
        'Mean_AUC_ROC': df['auc_roc'].mean(),
    }
    for model_name, df in full_results.items()
]

# Explicit columns keep the score computation valid when no models were loaded.
ranking_df = pd.DataFrame(ranking_data, columns=ranking_columns)
# Weights must match the formula printed in the table title below.
ranking_df['Overall_Score'] = (
    ranking_df['Mean_Accuracy'] * 0.4 +
    ranking_df['Mean_F1'] * 0.3 +
    ranking_df['Mean_AUC_ROC'] * 0.3
)
ranking_df = ranking_df.sort_values('Overall_Score', ascending=False).reset_index(drop=True)
ranking_df['Rank'] = range(1, len(ranking_df) + 1)

print("\n" + "="*100)
print("MODEL RANKING (Overall Score = 0.4*Acc + 0.3*F1 + 0.3*AUC)")
print("="*100)
print(ranking_df[['Rank', 'Model', 'Mean_Accuracy', 'Mean_F1', 'Mean_AUC_ROC', 'Overall_Score']]
      .to_string(index=False, float_format='%.4f'))
print("="*100)

## 7. Save Final Summary

In [None]:
# Persist the summary tables built in sections 3 (aggregate statistics) and
# 6 (ranking). `full_summary_df` is the full-channel summary table.
results_dir.mkdir(parents=True, exist_ok=True)
summary_path = results_dir / 'final_comparison_summary.csv'
channel_selection_path = results_dir / 'best_channel_selection_summary.csv'
ranking_path = results_dir / 'final_ranking.csv'

full_summary_df.to_csv(summary_path, index=False)
best_retrain_df.to_csv(channel_selection_path, index=False)
ranking_df.to_csv(ranking_path, index=False)

print("\nFinal summary saved:")
print(f"  - {summary_path}")
print(f"  - {channel_selection_path}")
print(f"  - {ranking_path}")

print("\n" + "="*80)
print("COMPARISON COMPLETE!")
print("="*80)