In [4]:
"""
NOTEBOOK 11: Late Fusion - Ensemble of Expert Models

SAVE AS: notebooks/modeling/05_late_fusion.ipynb

WHAT THIS DOES:
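- Uses already-trained Audio, Text, Video models (falls back to untrained weights if a checkpoint cannot be loaded)
- Combines their predictions (simple average and learned weighted average)
- Learns fusion weights with Ridge regression on the per-modality predictions
- Saves metrics so the results can be compared with early fusion
- Skips modalities that have no matching feature columns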
"""

# ========== CELL 1: Import Libraries ==========
print("Importing libraries...")

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from pathlib import Path
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

print("✓ Libraries imported")

# Prefer the GPU when PyTorch can see one; models and tensors below are moved here.
_backend = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_backend)
print(f"Using device: {device}")

# ========== CELL 2: Load Data ==========
print("\nLoading data...")

# Every artefact lives under one project root; the paths are built from the
# exact same raw strings so they resolve identically on every OS.
_PROJECT_ROOT = r'C:\Users\VIJAY BHUSHAN SINGH\depression_detection_project'
PROCESSED_DIR = Path(_PROJECT_ROOT + r'\data\processed')
MODELS_DIR = Path(_PROJECT_ROOT + r'\models\saved_models')
RESULTS_DIR = Path(_PROJECT_ROOT + r'\results')

# One CSV per split: train_data.csv, val_data.csv, test_data.csv.
train_df, val_df, test_df = (
    pd.read_csv(PROCESSED_DIR / f'{split}_data.csv')
    for split in ('train', 'val', 'test')
)

print("✓ Data loaded")

# ========== CELL 3: Identify Features by Modality ==========
print("\nIdentifying features...")

# Column-name fragments that mark each modality. Audio and head-pose fragments
# are matched case-sensitively; text fragments (and 'gaze') are matched against
# the lowercased column name.
# NOTE(review): plain substring matching can put one column into several
# modalities (e.g. any name containing 'AU' or 'Rx') — confirm against the schema.
_AUDIO_KEYS = ('mfcc', 'pitch', 'energy', 'spectral', 'zcr', 'rolloff', 'duration')
_TEXT_KEYS = ('word', 'positive', 'negative', 'question')
_POSE_KEYS = ('Tx', 'Ty', 'Tz', 'Rx', 'Ry', 'Rz')


def _contains_any(name, fragments):
    """Return True if any of ``fragments`` occurs as a substring of ``name``."""
    return any(frag in name for frag in fragments)


def _is_text(col):
    """Text feature: BERT columns or lexical/sentiment keyword columns (case-insensitive)."""
    lowered = col.lower()
    return 'bert' in lowered or _contains_any(lowered, _TEXT_KEYS)


def _is_video(col):
    """Video feature: action units ('AU'), gaze, or head-pose translation/rotation."""
    return 'AU' in col or 'gaze' in col.lower() or _contains_any(col, _POSE_KEYS)


audio_cols = list(filter(lambda col: _contains_any(col, _AUDIO_KEYS), train_df.columns))
text_cols = list(filter(_is_text, train_df.columns))
video_cols = list(filter(_is_video, train_df.columns))

for _label, _cols in (('Audio', audio_cols), ('Text', text_cols), ('Video', video_cols)):
    print(f"  {_label}: {len(_cols)} features")

# ========== CELL 4: Prepare Data ==========
print("\nPreparing modality-specific data...")

def prepare_modality_data(df, feature_cols, scaler=None, target_col='PHQ8_Score'):
    """Extract and standardize one modality's features from ``df``.

    Args:
        df: DataFrame holding the feature columns and ``target_col``.
        feature_cols: column names of the modality; may be empty (list or
            pandas Index — hence the explicit ``len`` check below).
        scaler: optional already-fitted StandardScaler. When given it is only
            applied (``transform``), so validation/test data can reuse the
            training statistics; when None a new scaler is fitted on ``df``.
        target_col: name of the regression target column.

    Returns:
        ``(X, y, scaler)``: the standardized feature matrix (None when
        ``feature_cols`` is empty), the target values, and the scaler that
        was used (None when there are no features).
    """
    y = df[target_col].values
    if len(feature_cols) == 0:
        return None, y, None
    X = df[feature_cols].values
    if scaler is None:
        scaler = StandardScaler()
        X = scaler.fit_transform(X)
    else:
        X = scaler.transform(X)
    return X, y, scaler

def _split_and_scale(cols):
    """Fit a scaler on the training split for ``cols`` and apply it to val/test.

    Returns ``(X_train, X_val, X_test, scaler)``. All four are None when
    ``cols`` is empty, so an absent modality propagates as None downstream.
    """
    X_tr, _, sc = prepare_modality_data(train_df, cols)
    if X_tr is None:
        return None, None, None, None
    # Reuse the training statistics; fitting separate scalers on val/test
    # would standardize each split against its own distribution.
    return X_tr, sc.transform(val_df[cols].values), sc.transform(test_df[cols].values), sc


# Regression targets for every split; the evaluation cells below need all three.
y_train = train_df['PHQ8_Score'].values
y_val = val_df['PHQ8_Score'].values
y_test = test_df['PHQ8_Score'].values

X_train_audio, X_val_audio, X_test_audio, scaler_audio = _split_and_scale(audio_cols)
X_train_text, X_val_text, X_test_text, scaler_text = _split_and_scale(text_cols)
X_train_video, X_val_video, X_test_video, scaler_video = _split_and_scale(video_cols)

# Report which modalities actually have features instead of claiming "all".
_available = [
    name for name, X in (('audio', X_train_audio), ('text', X_train_text), ('video', X_train_video))
    if X is not None
]
print(f"✓ Data prepared for modalities: {', '.join(_available) if _available else 'none'}")

# ========== CELL 5: Load or Create Models ==========
print("\nLoading trained models...")

def create_simple_model(input_size):
    """Build a small MLP regressor ``input_size -> 128 -> 64 -> 1``.

    Each hidden stage is Linear + ReLU + Dropout(0.3); the last Linear layer
    produces one output per sample. Layers are created in the same order as a
    hand-written ``nn.Sequential``, so state_dict keys are 0/3/6.
    """
    layers = []
    fan_in = input_size
    for width in (128, 64):
        layers += [nn.Linear(fan_in, width), nn.ReLU(), nn.Dropout(0.3)]
        fan_in = width
    layers.append(nn.Linear(fan_in, 1))
    return nn.Sequential(*layers)

# One MLP per modality that has features; modalities without features get None
# so the later cells skip them.
audio_model = create_simple_model(len(audio_cols)).to(device) if X_train_audio is not None else None
text_model  = create_simple_model(len(text_cols)).to(device) if X_train_text is not None else None
video_model = create_simple_model(len(video_cols)).to(device) if X_train_video is not None else None


def _load_checkpoint(model, filename, label):
    """Best-effort load of ``MODELS_DIR / filename`` into ``model``.

    Args:
        model: module to load weights into, or None when the modality has no features.
        filename: checkpoint file name inside MODELS_DIR.
        label: human-readable modality name used in the status message.

    Returns:
        True if the weights were loaded. False if the model was skipped or kept
        its random initialisation. The reason is always printed, never swallowed.
    """
    if model is None:
        print(f"⚠ {label} model skipped (no {label.lower()} features)")
        return False
    try:
        state_dict = torch.load(MODELS_DIR / filename, map_location=device)
        model.load_state_dict(state_dict)
    except FileNotFoundError:
        print(f"⚠ {label} model missing, using untrained")
        return False
    except Exception as err:  # best effort: corrupt file, state_dict key/shape mismatch, ...
        print(f"⚠ {label} model could not be loaded ({type(err).__name__}: {str(err)[:200]}); using untrained")
        return False
    print(f"✓ {label} model loaded")
    return True


# NOTE(review): the checkpoint names suggest LSTM/BERT models, while the models
# above are plain MLPs. If the architectures differ, load_state_dict raises and
# that modality falls back to untrained weights — confirm the saved format.
_model_loaded = {
    'audio': _load_checkpoint(audio_model, 'audio_lstm_best.pth', 'Audio'),
    'text': _load_checkpoint(text_model, 'text_bert_best.pth', 'Text'),
    'video': _load_checkpoint(video_model, 'video_lstm_best.pth', 'Video'),
}
untrained_modalities = [
    name for name, model in (('audio', audio_model), ('text', text_model), ('video', video_model))
    if model is not None and not _model_loaded[name]
]
if untrained_modalities:
    print(f"⚠ Randomly initialised models in use for: {', '.join(untrained_modalities)}; "
          "their predictions carry no learned signal")

# ========== CELL 6: Predictions ==========
def get_predictions_safe(model, X, device):
    """Run ``model`` on ``X`` without gradients and return NumPy predictions.

    Args:
        model: torch module, or None when the modality is unavailable.
        X: array-like feature matrix (rows are samples), or None.
        device: torch device (or device string) the model lives on.

    Returns:
        None if ``model`` or ``X`` is None. Otherwise the model output as a
        NumPy array; a trailing size-1 output dimension is dropped, so an
        ``(N, 1)`` output becomes shape ``(N,)`` — including when N == 1.
    """
    if X is None or model is None:
        return None
    model.eval()  # inference mode (e.g. disables Dropout)
    X_tensor = torch.as_tensor(np.asarray(X), dtype=torch.float32, device=device)
    with torch.no_grad():
        pred = model(X_tensor)
    if pred.dim() > 1 and pred.shape[-1] == 1:
        # squeeze(-1), not squeeze(): a single-row batch would otherwise
        # collapse to a 0-d array and break the metric functions.
        pred = pred.squeeze(-1)
    return pred.cpu().numpy()

# Models in fixed (audio, text, video) order; a None model or None input
# yields a None prediction for that modality.
_models = (audio_model, text_model, video_model)


def _predict_split(inputs):
    """Predict every modality of one split; returns an (audio, text, video) tuple."""
    return tuple(get_predictions_safe(mdl, feats, device) for mdl, feats in zip(_models, inputs))


train_pred_audio, train_pred_text, train_pred_video = _predict_split((X_train_audio, X_train_text, X_train_video))
val_pred_audio, val_pred_text, val_pred_video = _predict_split((X_val_audio, X_val_text, X_val_video))
test_pred_audio, test_pred_text, test_pred_video = _predict_split((X_test_audio, X_test_text, X_test_video))

print("✓ Predictions computed")

# ========== CELL 7: Simple Average Fusion ==========
# Unweighted mean over whichever modalities produced predictions.
val_preds  = [p for p in [val_pred_audio, val_pred_text, val_pred_video] if p is not None]
test_preds = [p for p in [test_pred_audio, test_pred_text, test_pred_video] if p is not None]
# np.mean over an empty list yields a NaN scalar (its RuntimeWarning is hidden by
# the global warnings filter), which would only fail later inside the metric call.
if not val_preds or not test_preds:
    raise ValueError("No modality produced predictions; cannot compute average fusion.")

val_pred_avg  = np.mean(val_preds, axis=0)
test_pred_avg = np.mean(test_preds, axis=0)

val_mae_avg  = mean_absolute_error(y_val, val_pred_avg)
test_mae_avg = mean_absolute_error(y_test, test_pred_avg)
test_rmse_avg = np.sqrt(mean_squared_error(y_test, test_pred_avg))
test_r2_avg   = r2_score(y_test, test_pred_avg)

print(f"Simple Average Fusion → Test MAE: {test_mae_avg:.4f}, R²: {test_r2_avg:.4f}")

# ========== CELL 8: Weighted Fusion ==========
def _stack_predictions(*preds):
    """Column-stack the non-None prediction vectors, keeping argument order.

    Raises:
        ValueError: if every modality is missing (instead of numpy's less
            descriptive "need at least one array" error).
    """
    present = [p for p in preds if p is not None]
    if not present:
        raise ValueError("No modality predictions available for weighted fusion.")
    return np.column_stack(present)


X_train_fusion = _stack_predictions(train_pred_audio, train_pred_text, train_pred_video)
X_val_fusion   = _stack_predictions(val_pred_audio, val_pred_text, val_pred_video)
X_test_fusion  = _stack_predictions(test_pred_audio, test_pred_text, test_pred_video)

# Learn one linear weight per available modality (plus an intercept).
# NOTE(review): this is fitted on the base models' *training-set* predictions.
# If those models were trained on the same rows, these predictions are
# optimistic and the weights may not transfer; fitting on held-out predictions
# is the usual stacking practice — confirm which is intended.
fusion_model = Ridge(alpha=1.0)
fusion_model.fit(X_train_fusion, y_train)

weights = fusion_model.coef_
bias    = fusion_model.intercept_
# Relative share of each modality by |coefficient| (signs are discarded).
# An all-zero coefficient vector would otherwise divide 0/0 into NaNs.
_abs_weights = np.abs(weights)
_abs_total = _abs_weights.sum()
weights_normalized = _abs_weights / _abs_total if _abs_total > 0 else np.zeros_like(_abs_weights)

val_pred_weighted  = fusion_model.predict(X_val_fusion)
test_pred_weighted = fusion_model.predict(X_test_fusion)

val_mae_weighted  = mean_absolute_error(y_val, val_pred_weighted)
test_mae_weighted = mean_absolute_error(y_test, test_pred_weighted)
test_rmse_weighted = np.sqrt(mean_squared_error(y_test, test_pred_weighted))
test_r2_weighted   = r2_score(y_test, test_pred_weighted)

print(f"Weighted Fusion → Test MAE: {test_mae_weighted:.4f}, R²: {test_r2_weighted:.4f}")
print(f"Fusion Weights (normalized): {weights_normalized}")

# ========== CELL 9: Individual Performance ==========
def compute_mae(pred, y):
    """Return the MAE of ``pred`` against ``y``, or None when ``pred`` is None."""
    if pred is None:
        return None
    return mean_absolute_error(y, pred)

# Per-modality test MAE (None for modalities without predictions).
audio_mae, text_mae, video_mae = (
    compute_mae(modality_pred, y_test)
    for modality_pred in (test_pred_audio, test_pred_text, test_pred_video)
)

print(f"\nIndividual MAEs → Audio: {audio_mae}, Text: {text_mae}, Video: {video_mae}")

# ========== CELL 10: Save Results ==========
# DataFrame.to_csv does not create missing folders, so ensure results/metrics exists.
metrics_dir = RESULTS_DIR / 'metrics'
metrics_dir.mkdir(parents=True, exist_ok=True)

results = {
    'model': ['Late Fusion (Average)', 'Late Fusion (Weighted)'],
    'method': ['simple_average', 'learned_weights'],
    'val_mae': [val_mae_avg, val_mae_weighted],
    'test_mae': [test_mae_avg, test_mae_weighted],
    'test_rmse': [test_rmse_avg, test_rmse_weighted],
    'test_r2': [test_r2_avg, test_r2_weighted]
}

results_df = pd.DataFrame(results)
results_df.to_csv(metrics_dir / 'late_fusion_results.csv', index=False)
print(f"✓ Late fusion results saved")

# One row per modality that took part in the fusion, in the same order as the
# columns of the fusion design matrix (and therefore of ``weights``).
weights_df = pd.DataFrame({
    'modality': [m for m, p in zip(['audio', 'text', 'video'], [train_pred_audio, train_pred_text, train_pred_video]) if p is not None],
    'weight': weights,
    'normalized': weights_normalized,
})
weights_df.to_csv(metrics_dir / 'fusion_weights.csv', index=False)
print(f"✓ Fusion weights saved")


Importing libraries...
✓ Libraries imported
Using device: cpu

Loading data...
✓ Data loaded

Identifying features...
  Audio: 68 features
  Text: 0 features
  Video: 72 features

Preparing modality-specific data...
✓ Data prepared for all modalities

Loading trained models...
⚠ Audio model missing, using untrained
✓ Text model loaded
⚠ Video model missing, using untrained
✓ Predictions computed
Simple Average Fusion → Test MAE: 10.1560, R²: -1.8379
Weighted Fusion → Test MAE: 6.4185, R²: -0.2014
Fusion Weights (normalized): [0.5419234  0.45807663]

Individual MAEs → Audio: 10.14736270904541, Text: None, Video: 10.164694786071777
✓ Late fusion results saved
✓ Fusion weights saved
