In [None]:
# Cell 1: Imports and Setup
import gc
import glob
import json
import os
import random
import time
from datetime import datetime
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.io as sio
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from sklearn.model_selection import KFold, train_test_split
from torch.utils.data import Dataset, DataLoader, Subset

In [None]:
# Import custom modules
import sys
sys.path.append('../')

from src.models.maet_model import MAET
from src.data.seedvii_dataset import SEEDVII_Dataset
from src.training.train_multimodal import MultimodalTrainer, train_subject_dependent, train_cross_subject
from src.utils.safe_forward import safe_model_forward
from src.training.evaluation_utils import compute_metrics, plot_confusion_matrix

In [None]:
# Configuration and Setup
def set_seed(seed: int = 42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator, and torch's CPU
    generator plus the CUDA generators (current device and all devices). Also
    forces cuDNN into deterministic mode and turns off benchmark autotuning.

    Args:
        seed: Seed value applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        cuda_seeder(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

# Configuration: tunable values read by the cells below and handed to
# MultimodalTrainer (which keys the trainer actually reads is not visible here).
CONFIG = {
    'seed': 42,
    'batch_size': 32,
    'num_epochs': 100,
    'learning_rate': 1e-3,
    'weight_decay': 1e-4,
    'val_split': 0.2,
    'num_workers': 4,
    'gradient_reversal_alpha': 0.5,
    'domain_loss_weight': 0.1,
    'early_stopping': 10,
    'save_model': True,
    'model': {
        'embed_dim': 32,
        'depth': 3,
        'num_heads': 4,
        'domain_generalization': True
    }
}

# Seed every RNG from CONFIG so the seed has a single source of truth.
set_seed(CONFIG['seed'])

# Dataset configuration
DATA_DIR = "../data/SEED-VII"
SUBSET_RATIO = 0.5  # Use 50% of data for quick experiments

print("Configuration loaded")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

In [None]:
# Dataset Loading and Exploration
print("=== DATASET EXPLORATION ===")

# Load dataset for exploration (args: data dir, modality, subset ratio)
dataset = SEEDVII_Dataset(DATA_DIR, 'multimodal', SUBSET_RATIO)

print(f"Dataset size: {len(dataset)} samples")
print(f"EEG features shape: {dataset.eeg_features.shape}")
print(f"Eye features shape: {dataset.eye_features.shape}")
print(f"Number of classes: {dataset.num_classes}")
print(f"Number of subjects: {len(np.unique(dataset.subject_labels))}")

# Analyze class distribution.
# NOTE(review): `emotion_labels` is a per-sample label array. This notebook
# fancy-indexes it for stratified splits, so it cannot also be an id -> name
# dict, and arrays have no `.get()`. Generic class names are used until the
# dataset's name mapping (if it has one) is confirmed.
class_names = {class_id: f"Class_{class_id}" for class_id in range(dataset.num_classes)}
emotion_counts = pd.Series(dataset.emotion_labels).value_counts().sort_index()
print("\nEmotion class distribution:")
for class_id, count in emotion_counts.items():
    emotion_name = class_names.get(class_id, f"Class_{class_id}")
    print(f"  {class_id}: {emotion_name} - {count} samples")

# Analyze subject distribution
subject_counts = pd.Series(dataset.subject_labels).value_counts().sort_index()
print("\nSubject distribution:")
print(f"  Subjects: {subject_counts.index.tolist()}")
print(f"  Samples per subject: {subject_counts.values.tolist()}")

In [None]:
# Model Architecture Exploration
print("=== MODEL ARCHITECTURE ===")

# Create a sample model. Architecture hyper-parameters come from CONFIG['model'];
# the input dims and class count are hard-coded, as in the training cells.
sample_model = MAET(
    eeg_dim=310,
    eye_dim=33,
    num_classes=7,
    embed_dim=CONFIG['model']['embed_dim'],
    depth=CONFIG['model']['depth'],
    num_heads=CONFIG['model']['num_heads'],
    domain_generalization=CONFIG['model']['domain_generalization'],
    num_domains=20,  # NOTE(review): hard-coded domain count, presumably one per subject — confirm
)
# Inference mode for the smoke-test forward passes below.
sample_model.eval()

# Count parameters
total_params = sum(p.numel() for p in sample_model.parameters())
trainable_params = sum(p.numel() for p in sample_model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Take a single sample from the dataset. Assumes __getitem__ returns a mapping
# with tensor values under 'eeg' and 'eye' — TODO confirm against SEEDVII_Dataset.
sample_batch = dataset[0]
eeg_sample = sample_batch['eeg'].unsqueeze(0)  # prepend a batch dimension of 1
eye_sample = sample_batch['eye'].unsqueeze(0)

with torch.no_grad():
    # Test different modality combinations
    print("\nTesting forward passes:")

    # Multimodal
    pred_multi, additional = safe_model_forward(sample_model, eeg=eeg_sample, eye=eye_sample)
    print(f"  Multimodal: {pred_multi.shape}")
    # Truth-testing a multi-element tensor raises, so check for presence explicitly.
    if additional is not None:
        print(f"    Domain output: {getattr(additional, 'shape', type(additional).__name__)}")

    # EEG only
    pred_eeg, _ = safe_model_forward(sample_model, eeg=eeg_sample, eye=None)
    print(f"  EEG only: {pred_eeg.shape}")

    # Eye only
    pred_eye, _ = safe_model_forward(sample_model, eeg=None, eye=eye_sample)
    print(f"  Eye only: {pred_eye.shape}")

del sample_model  # Clean up

In [None]:
# Single Subject Experiment (Quick Test)
print("=== SINGLE SUBJECT EXPERIMENT ===")

MIN_SUBJECT_SAMPLES = 20  # minimum samples required before running the quick test
QUICK_EPOCHS = 10

# Select one subject for detailed analysis.
target_subject = 0
# flatnonzero returns the flat index array; np.where(mask) would return a 1-tuple,
# whose len() is always 1.
subject_indices = np.flatnonzero(dataset.subject_labels == target_subject)

print(f"Subject {target_subject}: {len(subject_indices)} samples")

if len(subject_indices) < MIN_SUBJECT_SAMPLES:
    print(f"Skipping: subject {target_subject} has fewer than {MIN_SUBJECT_SAMPLES} samples")
else:
    # Stratified 70/30 split so train and validation keep the class balance.
    train_indices, val_indices = train_test_split(
        subject_indices, test_size=0.3, random_state=42,
        stratify=dataset.emotion_labels[subject_indices]
    )

    train_dataset = Subset(dataset, train_indices)
    val_dataset = Subset(dataset, val_indices)

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

    print(f"Train samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    # Quick training
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MAET(eeg_dim=310, eye_dim=33, num_classes=7, embed_dim=32).to(device)

    trainer = MultimodalTrainer(CONFIG)
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'])
    criterion = nn.CrossEntropyLoss()

    print(f"\nQuick training ({QUICK_EPOCHS} epochs):")
    train_history = []

    for epoch in range(QUICK_EPOCHS):
        train_loss, train_acc = trainer.train_epoch(model, train_loader, optimizer, criterion)
        val_loss, val_acc, val_f1, val_preds, val_labels = trainer.validate(
            model, val_loader, criterion
        )

        train_history.append({
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'val_f1': val_f1
        })

        if epoch % 2 == 0:
            print(f"  Epoch {epoch+1}: Train {train_acc:.1f}%, Val {val_acc:.1f}%")

    # Plot training history
    epochs = [h['epoch'] for h in train_history]
    train_accs = [h['train_acc'] for h in train_history]
    val_accs = [h['val_acc'] for h in train_history]

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(epochs, train_accs, 'b-', label='Training Accuracy')
    ax.plot(epochs, val_accs, 'r-', label='Validation Accuracy')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title(f'Subject {target_subject} Training Progress')
    ax.legend()
    ax.grid(True)
    plt.show()

    print(f"\nBest validation accuracy: {max(val_accs):.2f}%")

    # The optimizer references the model's parameters; drop it too so the
    # model's memory can actually be released.
    del model, trainer, optimizer

In [None]:
# Multimodal vs Unimodal Comparison
print("=== MODALITY COMPARISON ===")

modalities = ['multimodal', 'eeg', 'eye']
MIN_SUBJECT_SAMPLES = 20  # minimum samples for subject 0 before a modality is run
MODALITY_EPOCHS = 15
results_comparison = {}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for modality in modalities:
    print(f"\nTesting {modality} modality...")

    # Load modality-specific dataset
    mod_dataset = SEEDVII_Dataset(DATA_DIR, modality, SUBSET_RATIO)

    # Use first subject for quick test (flat index array, not np.where's 1-tuple)
    subject_indices = np.flatnonzero(mod_dataset.subject_labels == 0)

    if len(subject_indices) < MIN_SUBJECT_SAMPLES:
        print(f"  Skipping {modality}: only {len(subject_indices)} samples for subject 0")
        continue

    train_indices, val_indices = train_test_split(
        subject_indices, test_size=0.3, random_state=42,
        stratify=mod_dataset.emotion_labels[subject_indices]
    )

    train_dataset = Subset(mod_dataset, train_indices)
    val_dataset = Subset(mod_dataset, val_indices)

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

    # Re-seed so every modality starts from the same RNG state (weight init and
    # shuffle order); the comparison then reflects the modality, not the draw.
    set_seed(CONFIG['seed'])

    # Create modality-specific model.
    # NOTE(review): a 0 dim is passed for the unused modality — confirm MAET supports it.
    model = MAET(
        eeg_dim=310 if modality in ['eeg', 'multimodal'] else 0,
        eye_dim=33 if modality in ['eye', 'multimodal'] else 0,
        num_classes=7,
        embed_dim=32
    ).to(device)

    trainer = MultimodalTrainer(CONFIG)
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'])
    criterion = nn.CrossEntropyLoss()

    # Quick training, tracking the best validation accuracy seen.
    best_val_acc = 0.0
    for epoch in range(MODALITY_EPOCHS):
        train_loss, train_acc = trainer.train_epoch(model, train_loader, optimizer, criterion)
        val_loss, val_acc, val_f1, _, _ = trainer.validate(model, val_loader, criterion)
        best_val_acc = max(best_val_acc, val_acc)

    results_comparison[modality] = best_val_acc
    print(f"  {modality}: {best_val_acc:.2f}%")

    # Optimizer holds parameter references; delete it with the model.
    del model, trainer, optimizer

# Comparison visualization
if results_comparison:
    plot_modalities = list(results_comparison.keys())
    accuracies = list(results_comparison.values())

    fig, ax = plt.subplots(figsize=(8, 5))
    bars = ax.bar(plot_modalities, accuracies, color=['skyblue', 'lightgreen', 'salmon'])

    # Add value labels on bars
    for bar, acc in zip(bars, accuracies):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 1,
                f'{acc:.1f}%', ha='center', va='bottom', fontweight='bold')

    ax.set_ylabel('Best Validation Accuracy (%)')
    ax.set_title('Modality Comparison (Subject 0)')
    # Keep a non-degenerate y-range even if every accuracy is 0.
    ax.set_ylim(0, max(max(accuracies) * 1.2, 1.0))
    ax.grid(True, alpha=0.3)
    plt.show()

    print("\nModality Ranking:")
    # Rank by accuracy (the dict value), best first.
    sorted_results = sorted(results_comparison.items(), key=lambda item: item[1], reverse=True)
    for rank, (modality, accuracy) in enumerate(sorted_results, 1):
        print(f"  {rank}. {modality}: {accuracy:.2f}%")

In [None]:
# Cross-Subject Analysis Preview
print("=== CROSS-SUBJECT ANALYSIS PREVIEW ===")

train_subjects = [0, 1, 2]
test_subject = 3
MIN_TEST_SAMPLES = 10
CROSS_SUBJECT_EPOCHS = 20

print("Quick cross-subject generalization test...")
print(f"Training on subjects {train_subjects}, testing on subject {test_subject}")

# Flat index arrays (np.where(mask) would return a 1-tuple).
train_indices = np.flatnonzero(np.isin(dataset.subject_labels, train_subjects))
test_indices = np.flatnonzero(dataset.subject_labels == test_subject)

if len(test_indices) < MIN_TEST_SAMPLES or len(train_indices) == 0:
    print(f"Skipping: {len(train_indices)} train / {len(test_indices)} test samples available")
else:
    print(f"Train samples: {len(train_indices)}")
    print(f"Test samples: {len(test_indices)}")

    train_dataset = Subset(dataset, train_indices)
    test_dataset = Subset(dataset, test_indices)

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

    # Train model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MAET(
        eeg_dim=310, eye_dim=33, num_classes=7, embed_dim=32,
        domain_generalization=True,
        # NOTE(review): one domain per training subject. This lines up with the
        # domain labels only if they are 0..len(train_subjects)-1 (true here since
        # train_subjects == [0, 1, 2]) — confirm how MultimodalTrainer builds them.
        num_domains=len(train_subjects),
    ).to(device)

    trainer = MultimodalTrainer(CONFIG)
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'])
    criterion = nn.CrossEntropyLoss()

    print("\nTraining progress:")
    for epoch in range(CROSS_SUBJECT_EPOCHS):
        train_loss, train_acc = trainer.train_epoch(model, train_loader, optimizer, criterion)
        if epoch % 5 == 0:
            print(f"  Epoch {epoch + 1}: Train {train_acc:.1f}%")

    # Test on the held-out subject
    test_loss, test_acc, test_f1, test_preds, test_labels = trainer.validate(
        model, test_loader, criterion
    )

    print("\nCross-subject test results:")
    print(f"  Test accuracy: {test_acc:.2f}%")
    print(f"  Test F1-score: {test_f1:.2f}%")

    # Confusion matrix over the full label set, so it stays num_classes x num_classes
    # even when some classes never appear in the test labels or predictions.
    num_classes = dataset.num_classes
    # NOTE(review): `emotion_labels` is a per-sample array (no `.values()`), so
    # generic class names are used until a name mapping is confirmed.
    class_names = [f"Class_{class_id}" for class_id in range(num_classes)]
    cm = confusion_matrix(test_labels, test_preds, labels=list(range(num_classes)))

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title(f'Cross-Subject Confusion Matrix (Test Subject {test_subject})')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    plt.show()

    # Optimizer holds parameter references; delete it with the model.
    del model, trainer, optimizer

# %%
# Cell 8: Attention Visualization (if applicable)
print("=== ATTENTION ANALYSIS ===")

# Attention-weight visualization is not implemented yet; this cell only lists
# the planned analyses and defines a placeholder entry point.

print("Attention visualization functionality to be implemented...")
print("This would include:")
print("- Visualization of multi-head attention weights")
print("- Analysis of which EEG/eye features are most important")
print("- Temporal attention patterns")
print("- Cross-modal attention interactions")

# Placeholder for attention extraction
def extract_attention_weights(model, dataloader, num_samples=5):
    """Extract attention weights from transformer blocks (not implemented yet).

    Intended to run ``model`` over ``dataloader`` and collect attention maps for
    up to ``num_samples`` inputs, for the analyses listed above.

    Raises:
        NotImplementedError: always, until attention extraction is implemented.
            Raising makes accidental callers fail loudly instead of silently
            receiving ``None``.
    """
    raise NotImplementedError(
        "extract_attention_weights is a placeholder; attention extraction is not implemented yet"
    )

# %%
# Cell 9: Hyperparameter Analysis
print("=== HYPERPARAMETER SENSITIVITY ===")

# Test different embedding dimensions
embed_dims = [16, 32, 64]
embed_results = {}
QUICK_SUBSET_SIZE = 1000  # cap on samples used for this quick sweep
HPARAM_EPOCHS = 10

print("Testing different embedding dimensions...")

# Draw ONE subset and ONE train/val split before the sweep, so every embed_dim
# is scored on identical data. Cap the draw at len(dataset) because sampling
# without replacement fails when size exceeds the population.
subset_rng = np.random.default_rng(CONFIG['seed'])
quick_indices = subset_rng.choice(
    len(dataset), size=min(QUICK_SUBSET_SIZE, len(dataset)), replace=False
)
quick_dataset = Subset(dataset, quick_indices.tolist())

train_indices, val_indices = train_test_split(
    range(len(quick_dataset)), test_size=0.3, random_state=42
)
train_subset = Subset(quick_dataset, train_indices)
val_subset = Subset(quick_dataset, val_indices)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for embed_dim in embed_dims:
    print(f"\nTesting embed_dim={embed_dim}...")

    # Same RNG state (weight init + shuffle order) for every configuration.
    set_seed(CONFIG['seed'])

    train_loader = DataLoader(train_subset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=16, shuffle=False)

    # Create model with specific embed_dim
    model = MAET(
        eeg_dim=310, eye_dim=33, num_classes=7,
        embed_dim=embed_dim, depth=2, num_heads=2
    ).to(device)

    trainer = MultimodalTrainer(CONFIG)
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'])
    criterion = nn.CrossEntropyLoss()

    # Quick training, tracking the best validation accuracy seen.
    best_val_acc = 0.0
    for epoch in range(HPARAM_EPOCHS):
        train_loss, train_acc = trainer.train_epoch(model, train_loader, optimizer, criterion)
        val_loss, val_acc, val_f1, _, _ = trainer.validate(model, val_loader, criterion)
        best_val_acc = max(best_val_acc, val_acc)

    embed_results[embed_dim] = best_val_acc
    print(f"  Best validation accuracy: {best_val_acc:.2f}%")

    # Optimizer holds parameter references; delete it with the model.
    del model, trainer, optimizer

# Visualize hyperparameter results
if embed_results:
    dims = list(embed_results.keys())
    accs = list(embed_results.values())

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(dims, accs, 'bo-', linewidth=2, markersize=8)
    ax.set_xlabel('Embedding Dimension')
    ax.set_ylabel('Best Validation Accuracy (%)')
    ax.set_title('Embedding Dimension vs Performance')
    ax.grid(True, alpha=0.3)
    ax.set_xticks(dims)

    # Annotate points
    for dim, acc in zip(dims, accs):
        ax.annotate(f'{acc:.1f}%', (dim, acc), textcoords="offset points",
                    xytext=(0, 10), ha='center')

    plt.show()

# %%
# Cell 10: Summary and Next Steps
print("=== EXPERIMENT SUMMARY ===")

print("\nCompleted experiments:")
print("✓ Dataset exploration and analysis")
print("✓ Model architecture testing")
print("✓ Single subject training validation")
print("✓ Multimodal vs unimodal comparison")
print("✓ Cross-subject generalization preview")
print("✓ Hyperparameter sensitivity analysis")

print("\nKey findings:")
# A results dict may exist but be empty (e.g. every run was skipped), and max()
# on an empty dict raises ValueError, so only rank non-empty results.
modality_scores = globals().get('results_comparison') or {}
if modality_scores:
    best_modality = max(modality_scores, key=modality_scores.get)
    print(f"- Best performing modality: {best_modality} ({modality_scores[best_modality]:.1f}%)")

embed_scores = globals().get('embed_results') or {}
if embed_scores:
    best_embed = max(embed_scores, key=embed_scores.get)
    print(f"- Optimal embedding dimension: {best_embed} ({embed_scores[best_embed]:.1f}%)")

print("\nNext steps for full experiments:")
print("- Run complete subject-dependent validation")
print("- Perform full cross-subject (LOSO) evaluation")
print("- Implement attention visualization")
print("- Conduct comprehensive ablation studies")
print("- Optimize hyperparameters with proper grid search")
print("- Add statistical significance testing")

print("\nTo run full experiments:")
print("1. Increase SUBSET_RATIO to 1.0 for full dataset")
print("2. Run train_subject_dependent() and train_cross_subject()")
print("3. Use proper cross-validation with multiple random seeds")
print("4. Implement early stopping and model checkpointing")

# %%
# Cell 11: Cleanup
print("=== CLEANUP ===")

# Clear variables and free memory. Subsets and DataLoaders keep references to
# their underlying datasets, so they must be dropped too; deleting `dataset`
# alone would not release its memory.
for _leftover_name in (
    'dataset', 'quick_dataset', 'mod_dataset',
    'train_dataset', 'val_dataset', 'test_dataset',
    'train_subset', 'val_subset',
    'train_loader', 'val_loader', 'test_loader',
):
    globals().pop(_leftover_name, None)
del _leftover_name

# Collect reference cycles before asking CUDA to release cached blocks.
gc.collect()

# Clear GPU memory if using CUDA
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("Memory cleanup completed")
print("Notebook execution finished successfully!")