In [6]:
# ============================================================
# NOTEBOOK 16: Hierarchical Cross-Modal Attention (HCMA)
# ============================================================

# ========== CELL 1: Import Libraries ==========
print("Importing libraries...")

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
import warnings
# Silence only library-churn noise. A blanket 'ignore' would also hide
# warnings that flag real problems in this notebook, e.g. PyTorch's MSELoss
# warning when prediction and target shapes differ, or scikit-learn's
# warning that R² is undefined for fewer than two samples.
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)

print("✓ Libraries imported")

# Seed PyTorch (all devices) and NumPy RNGs for reproducible runs
torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# ========== CELL 2: Load Data ==========
print("\nLoading data...")

import os

# Project root. Defaults to the original author's machine; set the
# DEPRESSION_PROJECT_ROOT environment variable to run the notebook elsewhere.
PROJECT_ROOT = Path(os.environ.get(
    'DEPRESSION_PROJECT_ROOT',
    r'C:\Users\VIJAY BHUSHAN SINGH\depression_detection_project',
))
PROCESSED_DIR = PROJECT_ROOT / 'data' / 'processed'
MODELS_DIR = PROJECT_ROOT / 'models' / 'saved_models'
RESULTS_DIR = PROJECT_ROOT / 'results'

# Create output locations up front so later torch.save calls succeed
MODELS_DIR.mkdir(parents=True, exist_ok=True)
(RESULTS_DIR / 'figures').mkdir(parents=True, exist_ok=True)

# Pre-split data; Cell 3 reads feature columns and 'PHQ8_Score' from these
train_df = pd.read_csv(PROCESSED_DIR / 'train_data.csv')
val_df = pd.read_csv(PROCESSED_DIR / 'val_data.csv')
test_df = pd.read_csv(PROCESSED_DIR / 'test_data.csv')

print(f"✓ Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")

# ========== CELL 3: Extract Modality Features ==========
print("\nExtracting modality-specific features...")

# Identify features by modality via substring matches on column names.
# NOTE(review): 'AU' and the pose tokens are matched case-sensitively, the
# rest partly case-insensitively, and a column can match more than one
# modality's patterns -- confirm against the processed CSV schema.
audio_cols = [col for col in train_df.columns if any(x in col for x in 
              ['mfcc', 'pitch', 'energy', 'spectral', 'zcr', 'rolloff', 'duration'])]
text_cols = [col for col in train_df.columns if 'bert' in col.lower() or 
             any(x in col.lower() for x in ['word', 'positive', 'negative', 'question'])]
video_cols = [col for col in train_df.columns if 'AU' in col or 'gaze' in col.lower() or 
              any(x in col for x in ['Tx', 'Ty', 'Tz', 'Rx', 'Ry', 'Rz'])]

print(f"  Audio: {len(audio_cols)} features")
print(f"  Text: {len(text_cols)} features")
print(f"  Video: {len(video_cols)} features")


def _ensure_modality_columns(name, cols):
    """
    Return ``cols`` unchanged, or a single placeholder column if it is empty.

    An empty column list would make StandardScaler fail on a zero-feature
    matrix. In that case an all-zero column named '<name>_dummy' is added to
    the train/val/test frames so every modality still has an input.

    Args:
        name: modality name, used for the dummy column and the message.
        cols: column names selected for that modality.

    Returns:
        ``cols`` if non-empty, otherwise ``['<name>_dummy']``.
    """
    if cols:
        return cols
    dummy = f'{name}_dummy'
    for split_df in (train_df, val_df, test_df):
        split_df[dummy] = 0.0
    print(f"⚠ No {name} features detected. Added dummy column.")
    return [dummy]


# Handle empty modalities (any of the three) by adding a dummy column
audio_cols = _ensure_modality_columns('audio', audio_cols)
text_cols = _ensure_modality_columns('text', text_cols)
video_cols = _ensure_modality_columns('video', video_cols)

# Extract and normalize each modality
def prepare_modality(df, cols, scaler=None):
    """
    Extract ``cols`` from ``df`` as a NumPy array and standardize it.

    Args:
        df: DataFrame holding the feature columns.
        cols: column names to extract, in order.
        scaler: an already-fitted scaler to reuse; when None, a new
            StandardScaler is fitted on this data.

    Returns:
        ``(X, scaler)`` when a new scaler was fitted (the training split).
        Only the transformed ``X`` when ``scaler`` was supplied.
        Note that the return type differs between the two modes.
    """
    features = df[cols].values
    if scaler is not None:
        # Held-out split: reuse the training statistics
        return scaler.transform(features)
    fitted_scaler = StandardScaler()
    return fitted_scaler.fit_transform(features), fitted_scaler

# Each modality's scaler is fitted on the training split only and reused for
# validation and test, so no statistics come from held-out data.

# Audio
X_train_audio, audio_scaler = prepare_modality(train_df, audio_cols)
X_val_audio, X_test_audio = (
    prepare_modality(split_df, audio_cols, audio_scaler) for split_df in (val_df, test_df)
)

# Text
X_train_text, text_scaler = prepare_modality(train_df, text_cols)
X_val_text, X_test_text = (
    prepare_modality(split_df, text_cols, text_scaler) for split_df in (val_df, test_df)
)

# Video
X_train_video, video_scaler = prepare_modality(train_df, video_cols)
X_val_video, X_test_video = (
    prepare_modality(split_df, video_cols, video_scaler) for split_df in (val_df, test_df)
)

# Regression targets
y_train, y_val, y_test = (
    split_df['PHQ8_Score'].values for split_df in (train_df, val_df, test_df)
)

print("✓ All modalities prepared and normalized")

# ========== CELL 4: Create Dataset ==========
class HCMADataset(Dataset):
    """
    Map-style dataset pairing per-modality feature arrays with labels.

    Item ``i`` is ``(audio[i], text[i], video[i], labels[i])``, each a
    float32 tensor.

    Args:
        audio, text, video: array-likes whose first axis indexes samples.
        labels: array-like of targets, one per sample.

    Raises:
        ValueError: if the four inputs do not have the same number of samples.
    """

    def __init__(self, audio, text, video, labels):
        self.audio = self._to_float_tensor(audio)
        self.text = self._to_float_tensor(text)
        self.video = self._to_float_tensor(video)
        self.labels = self._to_float_tensor(labels)

        # __len__ is driven by labels. Without this check, a longer modality
        # would have its extra rows silently ignored, and a shorter one would
        # only fail later with an IndexError inside the DataLoader.
        sizes = {
            'audio': len(self.audio),
            'text': len(self.text),
            'video': len(self.video),
            'labels': len(self.labels),
        }
        if len(set(sizes.values())) != 1:
            raise ValueError(
                f"HCMADataset inputs must have the same number of samples, got {sizes}"
            )

    @staticmethod
    def _to_float_tensor(values):
        """Copy array-like ``values`` into a new float32 tensor."""
        return torch.tensor(np.asarray(values), dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.audio[idx], self.text[idx], self.video[idx], self.labels[idx]

train_dataset = HCMADataset(X_train_audio, X_train_text, X_train_video, y_train)
val_dataset = HCMADataset(X_val_audio, X_val_text, X_val_video, y_val)
test_dataset = HCMADataset(X_test_audio, X_test_text, X_test_video, y_test)

batch_size = 4
# The HCMA model uses BatchNorm1d, which raises in training mode on a batch of
# a single sample. Drop the last training batch only when it would hold
# exactly one sample; otherwise every sample is used each epoch.
train_drop_last = len(train_dataset) % batch_size == 1
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          drop_last=train_drop_last)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"✓ Dataloaders created (batch_size={batch_size})")

# ============================================================
# Cells 5-10: attention layers, HCMA model, training setup and functions, training loop, and saving
# ============================================================

# NOTE: An empty text modality is handled in Cell 3 by adding a constant 0.0 dummy column.


# ========== CELL 5: Define HCMA Components ==========

class IntraModalAttention(nn.Module):
    """
    Level 1: scaled dot-product self-attention within a single modality.

    The modality vector is projected to query/key/value vectors of size
    ``hidden_dim`` and treated as a sequence of exactly one token.

    NOTE(review): with one token the score tensor is [batch, 1, 1], so the
    softmax is identically 1.0. As a result:
      * the output always equals ``self.value(x)``;
      * the query/key projections receive zero gradient;
      * the returned weights carry no information.
    This layer does not weight individual features against each other; that
    would require treating features (or time steps) as separate tokens.

    Args:
        input_dim: size of the last dimension of the input.
        hidden_dim: size of the query/key/value projections and of the output.
    """
    
    def __init__(self, input_dim, hidden_dim=64):
        super(IntraModalAttention, self).__init__()
        
        self.query = nn.Linear(input_dim, hidden_dim)
        self.key = nn.Linear(input_dim, hidden_dim)
        self.value = nn.Linear(input_dim, hidden_dim)
        # Standard 1/sqrt(d) scaling of dot-product scores
        self.scale = hidden_dim ** 0.5
        
    def forward(self, x):
        """
        Args:
            x: tensor of shape [batch, input_dim].

        Returns:
            (attended, attention_weights): shapes [batch, hidden_dim] and [batch].
        """
        # Add a length-1 sequence dimension: [batch, 1, input_dim]
        x = x.unsqueeze(1)
        
        Q = self.query(x)  # [batch, 1, hidden]
        K = self.key(x)    # [batch, 1, hidden]
        V = self.value(x)  # [batch, 1, hidden]
        
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # [batch, 1, 1]
        attention_weights = torch.softmax(scores, dim=-1)
        
        attended = torch.matmul(attention_weights, V)  # [batch, 1, hidden]
        
        # Squeeze only the singleton attention axes. A bare .squeeze() also
        # drops the batch axis when batch == 1 and returns a 0-d tensor.
        return attended.squeeze(1), attention_weights.squeeze(-1).squeeze(-1)
class CrossModalAttention(nn.Module):
    """
    Level 2: cross-attention in which one modality (query) attends to another
    (key/value).

    Both inputs are treated as sequences of exactly one token.

    NOTE(review): with a single key/value token the softmax over the
    [batch, 1, 1] scores is identically 1.0, so the output equals
    ``self.value(kv_mod)`` and does not depend on ``query_mod`` at all. The
    query/key projections receive zero gradient, and the returned weights
    are always 1.

    Args:
        query_dim: feature size of the query modality.
        kv_dim: feature size of the key/value modality.
        hidden_dim: size of the projections and of the output.
    """
    
    def __init__(self, query_dim, kv_dim, hidden_dim=64):
        super(CrossModalAttention, self).__init__()
        
        self.query = nn.Linear(query_dim, hidden_dim)
        self.key = nn.Linear(kv_dim, hidden_dim)
        self.value = nn.Linear(kv_dim, hidden_dim)
        # Standard 1/sqrt(d) scaling of dot-product scores
        self.scale = hidden_dim ** 0.5
        
    def forward(self, query_mod, kv_mod):
        """
        Args:
            query_mod: tensor of shape [batch, query_dim].
            kv_mod: tensor of shape [batch, kv_dim].

        Returns:
            (attended, attention_weights): shapes [batch, hidden_dim] and [batch].
        """
        # Add a length-1 sequence dimension to both inputs
        query_mod = query_mod.unsqueeze(1)
        kv_mod = kv_mod.unsqueeze(1)
        
        Q = self.query(query_mod)  # [batch, 1, hidden]
        K = self.key(kv_mod)       # [batch, 1, hidden]
        V = self.value(kv_mod)     # [batch, 1, hidden]
        
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # [batch, 1, 1]
        attention_weights = torch.softmax(scores, dim=-1)
        
        attended = torch.matmul(attention_weights, V)  # [batch, 1, hidden]
        
        # Squeeze only the singleton attention axes. A bare .squeeze() also
        # drops the batch axis when batch == 1 and returns a 0-d tensor.
        return attended.squeeze(1), attention_weights.squeeze(-1).squeeze(-1)

print("✓ Attention components defined")

# ========== CELL 6: Define Complete HCMA Model ==========

class HCMA(nn.Module):
    """
    Hierarchical Cross-Modal Attention network (regression head).

    Architecture:
    1. Project each modality to a common ``hidden_dim`` space (Linear + ReLU)
    2. Level 1: intra-modal attention per modality -> hidden_dim//2 each
    3. Level 2: cross-modal attention for all six ordered modality pairs
       -> hidden_dim//4 each
    4. Level 3: concatenate the 3 intra + 6 cross outputs and fuse through
       three Linear -> BatchNorm1d -> ReLU -> Dropout stages (256 -> 128 -> 64)
    5. Linear output layer (64 -> output_dim)

    NOTE(review): IntraModalAttention and CrossModalAttention unsqueeze their
    inputs to a single token, so their softmax weights are always 1.0. Each
    attention module therefore reduces to its value projection, and the
    returned attention weights are constant.

    NOTE: BatchNorm1d raises in training mode when a batch contains a single
    sample, so training batches must hold at least two samples.

    Args:
        audio_dim, text_dim, video_dim: number of input features per modality.
        hidden_dim: width of the common projection space.
        output_dim: width of the output layer (1 for scalar regression).
    """
    
    def __init__(self, audio_dim, text_dim, video_dim, hidden_dim=128, output_dim=1):
        super(HCMA, self).__init__()
        
        # === Modality Projection ===
        self.audio_proj = nn.Linear(audio_dim, hidden_dim)
        self.text_proj = nn.Linear(text_dim, hidden_dim)
        self.video_proj = nn.Linear(video_dim, hidden_dim)
        
        # === Level 1: Intra-Modal Attention ===
        self.audio_intra_attn = IntraModalAttention(hidden_dim, hidden_dim//2)
        self.text_intra_attn = IntraModalAttention(hidden_dim, hidden_dim//2)
        self.video_intra_attn = IntraModalAttention(hidden_dim, hidden_dim//2)
        
        # === Level 2: Cross-Modal Attention ===
        # Text <-> Video
        self.text_to_video_attn = CrossModalAttention(hidden_dim//2, hidden_dim//2, hidden_dim//4)
        self.video_to_text_attn = CrossModalAttention(hidden_dim//2, hidden_dim//2, hidden_dim//4)
        
        # Audio <-> Text
        self.audio_to_text_attn = CrossModalAttention(hidden_dim//2, hidden_dim//2, hidden_dim//4)
        self.text_to_audio_attn = CrossModalAttention(hidden_dim//2, hidden_dim//2, hidden_dim//4)
        
        # Audio <-> Video
        self.audio_to_video_attn = CrossModalAttention(hidden_dim//2, hidden_dim//2, hidden_dim//4)
        self.video_to_audio_attn = CrossModalAttention(hidden_dim//2, hidden_dim//2, hidden_dim//4)
        
        # === Level 3: Hierarchical Fusion ===
        # Width of the concatenation in forward(): 3 intra + 6 cross outputs
        fusion_dim = (hidden_dim//2) * 3 + (hidden_dim//4) * 6
        
        self.fusion_layer1 = nn.Linear(fusion_dim, 256)
        self.fusion_bn1 = nn.BatchNorm1d(256)
        self.fusion_dropout1 = nn.Dropout(0.3)
        
        self.fusion_layer2 = nn.Linear(256, 128)
        self.fusion_bn2 = nn.BatchNorm1d(128)
        self.fusion_dropout2 = nn.Dropout(0.3)
        
        self.fusion_layer3 = nn.Linear(128, 64)
        self.fusion_bn3 = nn.BatchNorm1d(64)
        self.fusion_dropout3 = nn.Dropout(0.2)
        
        # Output
        self.output_layer = nn.Linear(64, output_dim)
        
    def forward(self, audio, text, video):
        """
        Args:
            audio, text, video: 2-D tensors of shape [batch, <modality>_dim].

        Returns:
            (output, attention_weights):
              * output is [batch] when output_dim == 1, otherwise
                [batch, output_dim];
              * attention_weights maps 'audio_intra' / 'text_intra' /
                'video_intra' to the Level-1 attention weights.
        """
        # === Step 1: Project to common space ===
        audio_h = torch.relu(self.audio_proj(audio))  # [batch, hidden]
        text_h = torch.relu(self.text_proj(text))
        video_h = torch.relu(self.video_proj(video))
        
        # === Step 2: Level 1 - Intra-Modal Attention ===
        audio_intra, audio_intra_weights = self.audio_intra_attn(audio_h)
        text_intra, text_intra_weights = self.text_intra_attn(text_h)
        video_intra, video_intra_weights = self.video_intra_attn(video_h)
        
        # === Step 3: Level 2 - Cross-Modal Attention ===
        # Text <-> Video
        text_from_video, _ = self.text_to_video_attn(text_intra, video_intra)
        video_from_text, _ = self.video_to_text_attn(video_intra, text_intra)
        
        # Audio <-> Text
        audio_from_text, _ = self.audio_to_text_attn(audio_intra, text_intra)
        text_from_audio, _ = self.text_to_audio_attn(text_intra, audio_intra)
        
        # Audio <-> Video
        audio_from_video, _ = self.audio_to_video_attn(audio_intra, video_intra)
        video_from_audio, _ = self.video_to_audio_attn(video_intra, audio_intra)
        
        # === Step 4: Level 3 - Hierarchical Fusion ===
        # Concatenate all attended representations along the feature axis
        fused = torch.cat([
            audio_intra, text_intra, video_intra,  # Intra-modal
            text_from_video, video_from_text,      # Text-Video cross
            audio_from_text, text_from_audio,      # Audio-Text cross
            audio_from_video, video_from_audio     # Audio-Video cross
        ], dim=1)
        
        # Fusion layers
        out = torch.relu(self.fusion_bn1(self.fusion_layer1(fused)))
        out = self.fusion_dropout1(out)
        
        out = torch.relu(self.fusion_bn2(self.fusion_layer2(out)))
        out = self.fusion_dropout2(out)
        
        out = torch.relu(self.fusion_bn3(self.fusion_layer3(out)))
        out = self.fusion_dropout3(out)
        
        # Final prediction: [batch, output_dim]
        output = self.output_layer(out)
        
        # Return prediction and attention weights for explainability
        attention_weights = {
            'audio_intra': audio_intra_weights,
            'text_intra': text_intra_weights,
            'video_intra': video_intra_weights
        }
        
        # squeeze(-1), not squeeze(): a batch of one must stay shape [1]
        # rather than collapse to a 0-d tensor, which would mismatch the
        # labels in the loss and cannot be iterated by list.extend.
        return output.squeeze(-1), attention_weights

# Build HCMA with input widths taken from the selected feature columns,
# then move it to the compute device.
model = HCMA(
    audio_dim=len(audio_cols),
    text_dim=len(text_cols),
    video_dim=len(video_cols),
    hidden_dim=128,
).to(device)

print("✓ HCMA Model created")
print(f"\nModel architecture:")
print(model)

# Parameter counts: all tensors vs. only those with requires_grad=True
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, trainable in param_sizes if trainable)
print(f"\nParameters:")
print(f"  Total: {total_params:,}")
print(f"  Trainable: {trainable_params:,}")

# ========== CELL 7: Training Setup ==========
# Loss: mean squared error for PHQ-8 score regression.
# Optimizer: AdamW with decoupled weight decay.
# Scheduler: halves the LR after 5 epochs without improvement in the
# monitored metric.
criterion = nn.MSELoss()
optimizer = optim.AdamW(params=model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
)

print(
    "\n✓ Training setup complete",
    "  Loss: MSE",
    "  Optimizer: AdamW (lr=0.001)",
    "  Scheduler: ReduceLROnPlateau",
    sep="\n",
)

# ========== CELL 8: Training Functions ==========
def train_epoch(model, dataloader, criterion, optimizer, device):
    """
    Run one training epoch and return ``(avg_loss, mae)``.

    Each ``(audio, text, video, labels)`` batch is:
      * moved to ``device``;
      * passed through ``model``;
      * used for one optimizer step, with gradient norm clipped to 1.0.

    Args:
        model: module whose forward returns ``(predictions, extras)``.
        dataloader: yields ``(audio, text, video, labels)`` batches.
        criterion: loss function; assumed mean-reduced (e.g. ``nn.MSELoss()``).
        optimizer: optimizer over the model's parameters.
        device: device the batches are moved to.

    Returns:
        avg_loss: per-sample mean loss over the epoch. Each batch loss is
            weighted by its batch size, so a short final batch is not
            over-weighted.
        mae: mean absolute error of the predictions made during the epoch.
            These come from train mode (e.g. dropout active) with weights
            changing between batches, so they are not directly comparable
            to evaluation MAE.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0
    predictions = []
    targets = []
    
    for audio, text, video, labels in dataloader:
        audio, text, video, labels = audio.to(device), text.to(device), video.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs, _ = model(audio, text, video)
        # Align with labels. A model that calls .squeeze() returns a 0-d
        # tensor for a one-sample batch, which would mismatch labels in the
        # loss and cannot be iterated by list.extend.
        outputs = outputs.reshape(labels.shape)
        loss = criterion(outputs, labels)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        batch_n = labels.size(0)
        total_loss += loss.item() * batch_n
        n_samples += batch_n
        predictions.extend(outputs.detach().cpu().numpy())
        targets.extend(labels.cpu().numpy())
    
    if n_samples == 0:
        raise ValueError("train_epoch() received a dataloader with no samples")
    
    avg_loss = total_loss / n_samples
    mae = mean_absolute_error(targets, predictions)
    
    return avg_loss, mae

def evaluate(model, dataloader, criterion, device):
    """
    Evaluate ``model`` on ``dataloader`` in eval mode without gradients.

    Args:
        model: module whose forward returns ``(predictions, attention_weights)``.
        dataloader: yields ``(audio, text, video, labels)`` batches.
        criterion: loss function; assumed mean-reduced (e.g. ``nn.MSELoss()``).
        device: device the batches are moved to.

    Returns:
        Tuple ``(avg_loss, mae, rmse, r2, predictions, targets, all_attention)``:
          * avg_loss: per-sample mean loss (batch losses weighted by batch size).
          * mae / rmse / r2: scikit-learn metrics over all predictions.
          * predictions / targets: flat per-sample lists (labels assumed 1-D).
          * all_attention: the attention dicts returned by the model, one per batch.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    predictions = []
    targets = []
    all_attention = []
    
    with torch.no_grad():
        for audio, text, video, labels in dataloader:
            audio, text, video, labels = audio.to(device), text.to(device), video.to(device), labels.to(device)
            
            outputs, attn_weights = model(audio, text, video)
            # Align with labels. A model that calls .squeeze() returns a 0-d
            # tensor for a one-sample batch, which would mismatch labels in
            # the loss and cannot be iterated by list.extend.
            outputs = outputs.reshape(labels.shape)
            loss = criterion(outputs, labels)
            
            batch_n = labels.size(0)
            total_loss += loss.item() * batch_n
            n_samples += batch_n
            predictions.extend(outputs.cpu().numpy())
            targets.extend(labels.cpu().numpy())
            all_attention.append(attn_weights)
    
    if n_samples == 0:
        raise ValueError("evaluate() received a dataloader with no samples")
    
    avg_loss = total_loss / n_samples
    mae = mean_absolute_error(targets, predictions)
    rmse = np.sqrt(mean_squared_error(targets, predictions))
    # scikit-learn warns and returns NaN for R² with fewer than two samples
    r2 = r2_score(targets, predictions)
    
    return avg_loss, mae, rmse, r2, predictions, targets, all_attention

print("✓ Training functions defined")

# ========== CELL 9: Train HCMA Model ==========
print("\n" + "="*60)
print("TRAINING HIERARCHICAL CROSS-MODAL ATTENTION (HCMA)")
print("="*60)
print("\n🎯 Target: MAE < 3.0")
print("⏰ This will take 30-40 minutes...\n")

# Early stopping: stop after `patience` consecutive epochs without a new
# best validation MAE.
num_epochs = 60
patience = 12
best_val_mae = float('inf')
patience_counter = 0

# Per-epoch curves, pickled in Cell 10
train_losses, val_losses = [], []
train_maes, val_maes = [], []

for epoch in range(num_epochs):
    train_loss, train_mae = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_mae, val_rmse, val_r2, *_ = evaluate(model, val_loader, criterion, device)

    train_losses.append(train_loss)
    train_maes.append(train_mae)
    val_losses.append(val_loss)
    val_maes.append(val_mae)

    # The plateau scheduler monitors validation MAE (mode='min')
    scheduler.step(val_mae)

    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}, MAE: {train_mae:.4f}")
        print(f"  Val Loss: {val_loss:.4f}, MAE: {val_mae:.4f}, RMSE: {val_rmse:.4f}, R²: {val_r2:.4f}")

    if val_mae < best_val_mae:
        # New best: record it, persist the weights plus a resumable
        # checkpoint ('epoch' is 0-based), then reset patience.
        best_val_mae = val_mae
        torch.save(model.state_dict(), MODELS_DIR / 'hcma_best.pth')
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_mae': val_mae,
        }
        torch.save(checkpoint, MODELS_DIR / 'hcma_checkpoint.pth')
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"\nEarly stopping at epoch {epoch+1}")
        break

print(f"\n✓ Training complete!")
print(f"Best validation MAE: {best_val_mae:.4f}")

print(
    "\n🎉 TARGET ACHIEVED! MAE < 3.0 ✅"
    if best_val_mae < 3.0
    else f"\n⚠ Close to target. Current: {best_val_mae:.4f}, Target: < 3.0"
)

# ========== CELL 10: Save for Next Notebook ==========
import pickle

# Training curves and best validation MAE, for analysis in the next notebook
history = dict(
    train_losses=train_losses,
    val_losses=val_losses,
    train_maes=train_maes,
    val_maes=val_maes,
    best_val_mae=best_val_mae,
)

history_path = MODELS_DIR / 'hcma_training_history.pkl'
with open(history_path, 'wb') as history_file:
    pickle.dump(history, history_file)

print(f"\n✓ Model and history saved")
print(f"  Model: {MODELS_DIR / 'hcma_best.pth'}")
print(f"  Checkpoint: {MODELS_DIR / 'hcma_checkpoint.pth'}")
print(f"  History: {history_path}")

print("\n🎯 Next: Run Notebook 17 for training visualization and evaluation")
print("="*60)

Importing libraries...
✓ Libraries imported
Using device: cpu

Loading data...
✓ Train: 11, Val: 2, Test: 3

Extracting modality-specific features...
  Audio: 68 features
  Text: 0 features
  Video: 72 features
⚠ No text features detected. Added dummy column.
✓ All modalities prepared and normalized
✓ Dataloaders created (batch_size=4)
✓ Attention components defined
✓ HCMA Model created

Model architecture:
HCMA(
  (audio_proj): Linear(in_features=68, out_features=128, bias=True)
  (text_proj): Linear(in_features=1, out_features=128, bias=True)
  (video_proj): Linear(in_features=72, out_features=128, bias=True)
  (audio_intra_attn): IntraModalAttention(
    (query): Linear(in_features=128, out_features=64, bias=True)
    (key): Linear(in_features=128, out_features=64, bias=True)
    (value): Linear(in_features=128, out_features=64, bias=True)
  )
  (text_intra_attn): IntraModalAttention(
    (query): Linear(in_features=128, out_features=64, bias=True)
    (key): Linear(in_features=128,