In [None]:
# ========== CELL 1: Import Libraries ==========
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Root directory holding the `metrics/` CSVs read below and the `figures/`
# directory the plot is saved into. Defaults to the original author's local
# path; set the DEPRESSION_RESULTS_DIR environment variable to run this
# notebook against a different checkout or machine.
RESULTS_DIR = Path(os.environ.get(
    'DEPRESSION_RESULTS_DIR',
    r'C:\Users\VIJAY BHUSHAN SINGH\depression_detection_project\results',
))

# ========== CELL 2: Load All Results ==========
print("Loading results from all models...")

# Each model's summary is read from RESULTS_DIR/metrics/<model>_results.csv,
# unpacked in the fixed order audio, text, early fusion, late fusion.
_metrics_dir = RESULTS_DIR / 'metrics'
audio_df, text_df, early_df, late_df = (
    pd.read_csv(_metrics_dir / f'{_stem}_results.csv')
    for _stem in ('audio_lstm', 'text_bert', 'early_fusion', 'late_fusion')
)

print("✓ All results loaded")

# ========== CELL 3: Create Comparison Table ==========
def _metric(df, column, method=None):
    """Return the first value of *column* in *df*, optionally for one *method*.

    When *method* is given, only rows whose ``method`` column equals it are
    considered (used for the late-fusion CSV, which holds one row per fusion
    strategy). Only the first matching row is used.

    Raises:
        ValueError: if no row matches, instead of the opaque IndexError that
            indexing an empty selection with ``[0]`` would produce.
        KeyError: if *column* (or ``method``) is not a column of *df*.
    """
    rows = df if method is None else df.loc[df['method'] == method]
    if rows.empty:
        where = f" for method={method!r}" if method is not None else ""
        raise ValueError(f"No result rows found{where} when reading {column!r}")
    return rows[column].iloc[0]


# (display name, source frame, late-fusion method or None to take the first row)
_model_sources = [
    ('Audio-LSTM', audio_df, None),
    ('Text-BERT', text_df, None),
    ('Early Fusion', early_df, None),
    ('Late Fusion (Avg)', late_df, 'simple_average'),
    ('Late Fusion (Weighted)', late_df, 'learned_weights'),
]

comparison = pd.DataFrame({
    'Model': [name for name, _, _ in _model_sources],
    'Test_MAE': [_metric(src, 'test_mae', m) for _, src, m in _model_sources],
    'Test_R2': [_metric(src, 'test_r2', m) for _, src, m in _model_sources],
})

print("\n" + "="*60)
print("FUSION METHODS COMPARISON")
print("="*60)
print(comparison.to_string(index=False))

# Lower MAE is better; idxmin returns the row label of the smallest value,
# which .loc then uses to look up that row.
best_idx = comparison['Test_MAE'].idxmin()
print(f"\n🏆 BEST MODEL: {comparison.loc[best_idx, 'Model']}")
print(f"   Test MAE: {comparison.loc[best_idx, 'Test_MAE']:.4f}")

# ========== CELL 4: Visualize ==========
def _plot_metric_bars(ax, data, column, ylabel, title, colors):
    """Draw one bar per model for *column* of *data* on *ax*, with value labels.

    *data* must have a ``Model`` column (x-axis categories) and *column*
    (bar heights). Each value label sits just beyond the tip of its bar:
    above positive bars and below negative ones (R² can be negative). The gap
    is 2% of the plotted range, so it scales with the metric instead of being
    a fixed absolute offset.
    """
    values = data[column]
    ax.bar(data['Model'], values, color=colors, alpha=0.7, edgecolor='black')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.tick_params(axis='x', rotation=45)
    # Anchor rotated labels at their right end so each one ends under its bar.
    plt.setp(ax.get_xticklabels(), ha='right', rotation_mode='anchor')
    ax.grid(alpha=0.3, axis='y')

    # Bars always start at 0, so the drawn span runs from min(v, 0) to max(v, 0).
    span = max(values.max(), 0) - min(values.min(), 0)
    offset = 0.02 * span if span > 0 else 0.01
    for i, v in enumerate(values):
        below = v < 0
        ax.text(i, v - offset if below else v + offset, f'{v:.2f}',
                ha='center', va='top' if below else 'bottom',
                fontweight='bold')


fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# One colour per model row, in the same order as `comparison`.
colors = ['steelblue', 'green', 'purple', 'orange', 'red']
_plot_metric_bars(axes[0], comparison, 'Test_MAE', 'Test MAE',
                  'Model Comparison - MAE (Lower is Better)', colors)
_plot_metric_bars(axes[1], comparison, 'Test_R2', 'Test R²',
                  'Model Comparison - R² (Higher is Better)', colors)

plt.tight_layout()
# savefig does not create missing directories, so ensure figures/ exists first.
figures_dir = RESULTS_DIR / 'figures'
figures_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(figures_dir / 'fusion_comparison_week9.png',
            dpi=300, bbox_inches='tight')
plt.show()

print("\n✓ Comparison complete!")
print("="*60)