In [2]:
# Fix Missing Import
from torch.nn import GRUCell
print("✅ GRUCell import added successfully!")

✅ GRUCell import added successfully!


# Advanced Brain GNN Models for Neuroimaging Analysis

Implementation of five state-of-the-art Graph Neural Network architectures for brain connectivity analysis:

1. **BrainGNN** - ROI-aware convolutions with interpretable pooling for biomarker discovery
2. **Local-to-Global GNN (LG-GNN)** - Hierarchical learning from ROI to subject relationships
3. **Dynamic Multi-Site GCN (DG-DMSGCN)** - Multi-site adaptation with temporal features
4. **IFC-GNN** - Temporal functional connectivity interactions with deep feature selection
5. **RAGNN** - Hemispheric asymmetry learning for EEG-based analysis

Each model addresses specific challenges in brain connectivity analysis and neurological disorder classification.

In [3]:
# Core Dependencies and Imports
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import GRUCell
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool, GATConv, SAGEConv
from torch_geometric.data import Data, Batch
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

# Set device and random seeds for reproducibility
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")



Using device: cuda
PyTorch version: 2.5.1+cu121
CUDA available: True
CUDA version: 12.1
GPU: NVIDIA GeForce RTX 2050


In [4]:
# Data Loading and Preprocessing Functions
class ABIDEDataset(Dataset):
    """ABIDE Dataset for brain connectivity analysis"""
    
    def __init__(self, data_dir: str, phenotypic_file: str = None):
        self.data_dir = data_dir
        self.phenotypic_file = phenotypic_file
        self.connectivity_matrices = []
        self.time_series_data = []  # Store original time series
        self.labels = []
        self.subjects = []
        self.sites = []
        self.demographic_features = []
        self.load_data()
        
    def load_data(self):
        """Load time series data and compute connectivity matrices"""
        print("Loading ABIDE time series data...")
        
        # Load phenotypic data if available
        phenotypic_data = None
        if self.phenotypic_file and os.path.exists(self.phenotypic_file):
            phenotypic_data = pd.read_csv(self.phenotypic_file)
            print(f"Loaded phenotypic data with {len(phenotypic_data)} subjects")
        
        # Process ROI time series files
        roi_dir = os.path.join(self.data_dir, 'Outputs', 'cpac', 'nofilt_noglobal', 'rois_cc400')
        if not os.path.exists(roi_dir):
            raise FileNotFoundError(f"ROI directory not found: {roi_dir}")
            
        roi_files = [f for f in os.listdir(roi_dir) if f.endswith('.1D')]
        print(f"Found {len(roi_files)} time series files")
        
        for roi_file in roi_files:
            try:
                # Extract subject information from filename
                parts = roi_file.replace('_rois_cc400.1D', '').split('_')
                if len(parts) >= 2:
                    site = parts[0]
                    subject_id = '_'.join(parts[1:])
                else:
                    site = 'Unknown'
                    subject_id = parts[0]
                
                # Load time series data
                file_path = os.path.join(roi_dir, roi_file)
                time_series = np.loadtxt(file_path)
                
                # Handle different data formats
                if time_series.ndim == 1:
                    # Single time point - reshape to (1, n_rois)
                    time_series = time_series.reshape(1, -1)
                
                # Calculate connectivity matrix from time series
                # time_series shape: (n_timepoints, n_rois)
                connectivity_matrix = np.corrcoef(time_series.T)  # Correlation between ROIs
                
                # Handle NaN values that might occur with constant time series
                connectivity_matrix = np.nan_to_num(connectivity_matrix, nan=0.0)
                
                # Ensure diagonal is 1
                np.fill_diagonal(connectivity_matrix, 1.0)
                
                # Get phenotypic information
                dx_group = 1  # Default: ASD
                age = 25.0    # Default age
                sex = 1       # Default: Male
                
                if phenotypic_data is not None:
                    # Try to match subject
                    subject_matches = phenotypic_data[
                        (phenotypic_data['SITE_ID'] == site) & 
                        (phenotypic_data['SUB_ID'].astype(str) == subject_id)
                    ]
                    if not subject_matches.empty:
                        subject_info = subject_matches.iloc[0]
                        dx_group = subject_info.get('DX_GROUP', 1)
                        age = subject_info.get('AGE_AT_SCAN', 25.0)
                        sex = subject_info.get('SEX', 1)
                
                # Store data
                self.connectivity_matrices.append(connectivity_matrix)
                self.time_series_data.append(time_series)  # Store original time series
                self.labels.append(dx_group - 1)  # Convert to 0/1
                self.subjects.append(f"{site}_{subject_id}")
                self.sites.append(site)
                self.demographic_features.append([age, sex])
                
            except Exception as e:
                print(f"Error processing {roi_file}: {e}")
                continue
        
        print(f"Successfully loaded {len(self.connectivity_matrices)} subjects")
        print(f"Time series shape: {self.time_series_data[0].shape if self.time_series_data else 'N/A'}")
        print(f"Connectivity matrix shape: {self.connectivity_matrices[0].shape if self.connectivity_matrices else 'N/A'}")
        print(f"Label distribution: {np.bincount(self.labels)}")
        
    def __len__(self):
        return len(self.connectivity_matrices)
    
    def __getitem__(self, idx):
        connectivity = torch.FloatTensor(self.connectivity_matrices[idx])
        time_series = torch.FloatTensor(self.time_series_data[idx])
        label = torch.LongTensor([self.labels[idx]])
        demographics = torch.FloatTensor(self.demographic_features[idx])
        site = self.sites[idx]
        
        return {
            'connectivity': connectivity,
            'time_series': time_series,  # Original time series data
            'label': label,
            'demographics': demographics,
            'site': site,
            'subject': self.subjects[idx]
        }

def create_graph_from_connectivity(connectivity_matrix, threshold=0.3):
    """Create PyTorch Geometric graph from connectivity matrix"""
    n_nodes = connectivity_matrix.shape[0]
    
    # Apply threshold and create adjacency matrix
    adj_matrix = (np.abs(connectivity_matrix) > threshold).astype(float)
    
    # Get edge indices
    edge_index = np.where(adj_matrix)
    edge_index = torch.LongTensor(np.vstack(edge_index))
    
    # Edge weights
    edge_weights = connectivity_matrix[edge_index[0], edge_index[1]]
    edge_attr = torch.FloatTensor(edge_weights).unsqueeze(1)
    
    # Node features (ROI indices and connectivity statistics)
    node_features = []
    for i in range(n_nodes):
        roi_connectivity = connectivity_matrix[i]
        features = [
            i / n_nodes,  # Normalized ROI index
            np.mean(roi_connectivity),  # Mean connectivity
            np.std(roi_connectivity),   # Std connectivity
            np.sum(roi_connectivity > threshold)  # Degree
        ]
        node_features.append(features)
    
    node_features = torch.FloatTensor(node_features)
    
    return Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr)

def create_graph_from_timeseries(time_series, connectivity_matrix, threshold=0.3):
    """Create enhanced graph using both time series and connectivity information"""
    n_timepoints, n_rois = time_series.shape
    
    # Apply threshold and create adjacency matrix
    adj_matrix = (np.abs(connectivity_matrix) > threshold).astype(float)
    
    # Get edge indices
    edge_index = np.where(adj_matrix)
    edge_index = torch.LongTensor(np.vstack(edge_index))
    
    # Edge weights from connectivity
    edge_weights = connectivity_matrix[edge_index[0], edge_index[1]]
    edge_attr = torch.FloatTensor(edge_weights).unsqueeze(1)
    
    # Enhanced node features using time series statistics
    node_features = []
    for i in range(n_rois):
        roi_timeseries = time_series[:, i]
        roi_connectivity = connectivity_matrix[i]
        
        # Time series features
        mean_signal = np.mean(roi_timeseries)
        std_signal = np.std(roi_timeseries)
        max_signal = np.max(roi_timeseries)
        min_signal = np.min(roi_timeseries)
        
        # Connectivity features
        mean_connectivity = np.mean(roi_connectivity)
        std_connectivity = np.std(roi_connectivity)
        degree = np.sum(roi_connectivity > threshold)
        
        features = [
            i / n_rois,  # Normalized ROI index
            mean_signal,
            std_signal,
            max_signal,
            min_signal,
            mean_connectivity,
            std_connectivity,
            degree / n_rois  # Normalized degree
        ]
        node_features.append(features)
    
    node_features = torch.FloatTensor(node_features)
    
    return Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr)

print("Enhanced data loading utilities defined successfully!")
print("Now supports both connectivity matrices and time series data!")

Enhanced data loading utilities defined successfully!
Now supports both connectivity matrices and time series data!


## Model 1: BrainGNN - ROI-aware Graph Neural Network

BrainGNN introduces ROI-aware convolutions and interpretable pooling specifically designed for brain connectivity analysis. It focuses on identifying disease-relevant biomarkers through specialized graph operations.

In [5]:
# BrainGNN Components

class ROIAwareConv(nn.Module):
    """ROI-aware convolution for brain connectivity"""
    
    def __init__(self, in_features, out_features, n_rois=400):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.n_rois = n_rois
        
        # ROI-specific transformations
        self.roi_linear = nn.Linear(in_features, out_features)
        self.roi_attention = nn.MultiheadAttention(out_features, num_heads=8, batch_first=True)
        
        # Graph convolution
        self.graph_conv = GCNConv(in_features, out_features)
        
        # Combination layer
        self.combine = nn.Linear(out_features * 2, out_features)
        self.dropout = nn.Dropout(0.3)
        
    def forward(self, x, edge_index, batch=None):
        # ROI-aware processing
        roi_features = self.roi_linear(x)
        roi_attended, _ = self.roi_attention(roi_features.unsqueeze(0), 
                                           roi_features.unsqueeze(0), 
                                           roi_features.unsqueeze(0))
        roi_attended = roi_attended.squeeze(0)
        
        # Graph convolution
        graph_features = self.graph_conv(x, edge_index)
        
        # Combine features
        combined = torch.cat([roi_attended, graph_features], dim=1)
        output = self.combine(combined)
        output = F.relu(output)
        output = self.dropout(output)
        
        return output

class BiomarkerPooling(nn.Module):
    """Interpretable pooling for biomarker discovery"""
    
    def __init__(self, in_features, pool_ratio=0.5):
        super().__init__()
        self.pool_ratio = pool_ratio
        self.score_layer = nn.Linear(in_features, 1)
        self.feature_transform = nn.Linear(in_features, in_features)
        
    def forward(self, x, batch=None):
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        
        # Calculate importance scores
        scores = self.score_layer(x).squeeze(-1)
        
        # Get top-k nodes (biomarkers)
        batch_size = batch.max().item() + 1
        pooled_features = []
        biomarker_indices = []
        
        for i in range(batch_size):
            mask = (batch == i)
            node_scores = scores[mask]
            node_features = x[mask]
            
            k = max(1, int(self.pool_ratio * node_features.size(0)))
            top_k_indices = torch.topk(node_scores, k)[1]
            
            # Pool top-k features
            selected_features = node_features[top_k_indices]
            pooled = torch.mean(selected_features, dim=0)
            
            pooled_features.append(pooled)
            biomarker_indices.append(top_k_indices)
        
        pooled_features = torch.stack(pooled_features)
        return pooled_features, biomarker_indices

class BrainGNN(nn.Module):
    """BrainGNN for brain connectivity analysis"""
    
    def __init__(self, input_dim=4, hidden_dim=128, num_classes=2, n_rois=400):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.n_rois = n_rois
        
        # ROI-aware convolutions
        self.conv1 = ROIAwareConv(input_dim, hidden_dim, n_rois)
        self.conv2 = ROIAwareConv(hidden_dim, hidden_dim, n_rois)
        self.conv3 = ROIAwareConv(hidden_dim, hidden_dim, n_rois)
        
        # Biomarker pooling
        self.biomarker_pool = BiomarkerPooling(hidden_dim, pool_ratio=0.3)
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim // 2, num_classes)
        )
        
        # Batch normalization
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.bn3 = nn.BatchNorm1d(hidden_dim)
        
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        # ROI-aware convolutions with residual connections
        x1 = self.conv1(x, edge_index, batch)
        x1 = self.bn1(x1)
        
        x2 = self.conv2(x1, edge_index, batch)
        x2 = self.bn2(x2)
        x2 = x2 + x1  # Residual connection
        
        x3 = self.conv3(x2, edge_index, batch)
        x3 = self.bn3(x3)
        x3 = x3 + x2  # Residual connection
        
        # Biomarker pooling
        pooled_features, biomarker_indices = self.biomarker_pool(x3, batch)
        
        # Classification
        output = self.classifier(pooled_features)
        
        return {
            'logits': output,
            'biomarker_indices': biomarker_indices,
            'features': pooled_features
        }

print("BrainGNN model defined successfully!")

BrainGNN model defined successfully!


## Model 2: Local-to-Global GNN (LG-GNN)

LG-GNN implements hierarchical learning that captures both local ROI interactions and global brain network patterns through a multi-scale architecture.

In [6]:
# Local-to-Global GNN Components

class LocalConvBlock(nn.Module):
    """Local convolution for capturing immediate ROI interactions"""
    
    def __init__(self, in_features, out_features):
        super().__init__()
        self.conv = GCNConv(in_features, out_features)
        self.local_attention = nn.MultiheadAttention(out_features, num_heads=4, batch_first=True)
        self.norm = nn.LayerNorm(out_features)
        self.dropout = nn.Dropout(0.2)
        
    def forward(self, x, edge_index, batch=None):
        # Local graph convolution
        x_conv = self.conv(x, edge_index)
        x_conv = F.relu(x_conv)
        
        # Local attention
        if batch is not None:
            # Group by batch for attention
            batch_size = batch.max().item() + 1
            attended_features = []
            
            for i in range(batch_size):
                mask = (batch == i)
                batch_features = x_conv[mask].unsqueeze(0)
                attended, _ = self.local_attention(batch_features, batch_features, batch_features)
                attended_features.append(attended.squeeze(0))
            
            x_attended = torch.cat(attended_features, dim=0)
        else:
            x_attended, _ = self.local_attention(x_conv.unsqueeze(0), x_conv.unsqueeze(0), x_conv.unsqueeze(0))
            x_attended = x_attended.squeeze(0)
        
        # Residual connection and normalization
        output = self.norm(x_attended + x_conv)
        output = self.dropout(output)
        
        return output

class GlobalConvBlock(nn.Module):
    """Global convolution for capturing brain-wide patterns"""
    
    def __init__(self, in_features, out_features):
        super().__init__()
        self.conv = GATConv(in_features, out_features, heads=8, concat=False)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.global_transform = nn.Linear(out_features, out_features)
        self.norm = nn.LayerNorm(out_features)
        self.dropout = nn.Dropout(0.3)
        
    def forward(self, x, edge_index, batch=None):
        # Global graph attention
        x_global = self.conv(x, edge_index)
        x_global = F.relu(x_global)
        
        # Global context aggregation
        if batch is not None:
            batch_size = batch.max().item() + 1
            global_context = []
            
            for i in range(batch_size):
                mask = (batch == i)
                batch_features = x_global[mask]
                # Average pooling for global context
                global_vec = torch.mean(batch_features, dim=0, keepdim=True)
                global_context.extend([global_vec] * batch_features.size(0))
            
            global_context = torch.cat(global_context, dim=0)
        else:
            global_context = torch.mean(x_global, dim=0, keepdim=True).expand_as(x_global)
        
        # Transform global context
        global_context = self.global_transform(global_context)
        
        # Combine local and global information
        output = self.norm(x_global + global_context)
        output = self.dropout(output)
        
        return output

class HierarchicalPooling(nn.Module):
    """Hierarchical pooling from local to global representations"""
    
    def __init__(self, in_features, reduction_factor=4):
        super().__init__()
        self.reduction_factor = reduction_factor
        self.local_pool = nn.Linear(in_features, in_features // reduction_factor)
        self.global_pool = nn.Linear(in_features, in_features // reduction_factor)
        self.combine = nn.Linear(in_features // reduction_factor * 2, in_features)
        self.attention = nn.MultiheadAttention(in_features, num_heads=4, batch_first=True)
        
    def forward(self, local_features, global_features, batch=None):
        # Pool local and global features
        local_pooled = self.local_pool(local_features)
        global_pooled = self.global_pool(global_features)
        
        # Combine features
        combined = torch.cat([local_pooled, global_pooled], dim=1)
        combined = self.combine(combined)
        combined = F.relu(combined)
        
        # Apply attention for final representation
        if batch is not None:
            batch_size = batch.max().item() + 1
            final_features = []
            
            for i in range(batch_size):
                mask = (batch == i)
                batch_features = combined[mask].unsqueeze(0)
                attended, _ = self.attention(batch_features, batch_features, batch_features)
                # Global pooling for final representation
                pooled = torch.mean(attended.squeeze(0), dim=0)
                final_features.append(pooled)
            
            return torch.stack(final_features)
        else:
            attended, _ = self.attention(combined.unsqueeze(0), combined.unsqueeze(0), combined.unsqueeze(0))
            return torch.mean(attended.squeeze(0), dim=0, keepdim=True)

class LGGNN(nn.Module):
    """Local-to-Global Graph Neural Network"""
    
    def __init__(self, input_dim=4, hidden_dim=128, num_classes=2):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        
        # Local processing layers
        self.local_conv1 = LocalConvBlock(input_dim, hidden_dim)
        self.local_conv2 = LocalConvBlock(hidden_dim, hidden_dim)
        
        # Global processing layers
        self.global_conv1 = GlobalConvBlock(hidden_dim, hidden_dim)
        self.global_conv2 = GlobalConvBlock(hidden_dim, hidden_dim)
        
        # Hierarchical pooling
        self.hierarchical_pool = HierarchicalPooling(hidden_dim)
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 4, num_classes)
        )
        
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        # Local feature extraction
        local_1 = self.local_conv1(x, edge_index, batch)
        local_2 = self.local_conv2(local_1, edge_index, batch)
        
        # Global feature extraction
        global_1 = self.global_conv1(local_2, edge_index, batch)
        global_2 = self.global_conv2(global_1, edge_index, batch)
        
        # Hierarchical pooling
        pooled_features = self.hierarchical_pool(local_2, global_2, batch)
        
        # Classification
        output = self.classifier(pooled_features)
        
        return {
            'logits': output,
            'local_features': local_2,
            'global_features': global_2,
            'pooled_features': pooled_features
        }

print("Local-to-Global GNN model defined successfully!")

Local-to-Global GNN model defined successfully!


## Model 3: Dynamic Multi-Site GCN (DG-DMSGCN)

DG-DMSGCN addresses multi-site variability in neuroimaging data through dynamic graph construction and site-adaptive mechanisms with temporal feature modeling.

In [7]:
# Dynamic Multi-Site GCN Components

class DynamicGraphConstruction(nn.Module):
    """Dynamic graph construction for adaptive connectivity"""
    
    def __init__(self, node_features, hidden_dim=64, top_k=20):
        super().__init__()
        self.node_features = node_features
        self.hidden_dim = hidden_dim
        self.top_k = top_k
        
        self.feature_transform = nn.Sequential(
            nn.Linear(node_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        self.edge_predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        
    def forward(self, node_features, batch=None):
        """Construct dynamic graphs based on node features"""
        # Transform node features
        transformed_features = self.feature_transform(node_features)
        
        if batch is None:
            batch = torch.zeros(node_features.size(0), dtype=torch.long, device=node_features.device)
        
        batch_size = batch.max().item() + 1
        edge_indices = []
        edge_weights = []
        
        for b in range(batch_size):
            mask = (batch == b)
            batch_features = transformed_features[mask]
            n_nodes = batch_features.size(0)
            
            # Compute pairwise features
            edge_scores = []
            edges = []
            
            for i in range(n_nodes):
                for j in range(i + 1, n_nodes):
                    # Concatenate node features
                    edge_feature = torch.cat([batch_features[i], batch_features[j]], dim=0)
                    score = self.edge_predictor(edge_feature)
                    edge_scores.append(score)
                    edges.append([i, j])
            
            if edges:
                edge_scores = torch.stack(edge_scores).squeeze()
                edges = torch.tensor(edges, device=node_features.device)
                
                # Select top-k edges
                k = min(self.top_k, len(edges))
                top_k_indices = torch.topk(edge_scores, k)[1]
                selected_edges = edges[top_k_indices]
                selected_weights = edge_scores[top_k_indices]
                
                # Add reverse edges
                reverse_edges = torch.stack([selected_edges[:, 1], selected_edges[:, 0]], dim=1)
                all_edges = torch.cat([selected_edges, reverse_edges], dim=0)
                all_weights = torch.cat([selected_weights, selected_weights], dim=0)
                
                # Adjust indices for batch
                batch_offset = torch.sum(batch < b).item()
                all_edges = all_edges + batch_offset
                
                edge_indices.append(all_edges.t())
                edge_weights.append(all_weights)
        
        if edge_indices:
            edge_index = torch.cat(edge_indices, dim=1)
            edge_weight = torch.cat(edge_weights, dim=0)
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long, device=node_features.device)
            edge_weight = torch.empty(0, device=node_features.device)
        
        return edge_index, edge_weight

class SiteAdaptiveLayer(nn.Module):
    """Site-adaptive normalization and feature transformation"""
    
    def __init__(self, features_dim, num_sites=10):
        super().__init__()
        self.features_dim = features_dim
        self.num_sites = num_sites
        
        # Site-specific parameters
        self.site_embeddings = nn.Embedding(num_sites, features_dim)
        self.site_scales = nn.Embedding(num_sites, features_dim)
        self.site_shifts = nn.Embedding(num_sites, features_dim)
        
        # Adaptive transformation
        self.adaptive_transform = nn.Sequential(
            nn.Linear(features_dim * 2, features_dim),
            nn.ReLU(),
            nn.Linear(features_dim, features_dim)
        )
        
        # Initialize parameters
        nn.init.ones_(self.site_scales.weight)
        nn.init.zeros_(self.site_shifts.weight)
        
    def forward(self, x, site_indices):
        """Apply site-adaptive transformation"""
        # Get site-specific parameters
        site_emb = self.site_embeddings(site_indices)
        site_scale = self.site_scales(site_indices)
        site_shift = self.site_shifts(site_indices)
        
        # Site-adaptive normalization
        normalized_x = x * site_scale + site_shift
        
        # Combine with site embeddings
        combined = torch.cat([normalized_x, site_emb], dim=-1)
        adapted_x = self.adaptive_transform(combined)
        
        return adapted_x

class TemporalFeatureExtractor(nn.Module):
    """Extract temporal features from connectivity patterns"""
    
    def __init__(self, input_dim, hidden_dim=64, sequence_length=10):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.sequence_length = sequence_length
        
        # Temporal convolutions
        self.temporal_conv1 = nn.Conv1d(input_dim, hidden_dim, kernel_size=3, padding=1)
        self.temporal_conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
        
        # LSTM for temporal modeling
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True, bidirectional=True)
        
        # Output projection
        self.output_proj = nn.Linear(hidden_dim * 2, input_dim)
        
    def forward(self, x):
        """Extract temporal features"""
        batch_size, n_nodes, features = x.shape
        
        # Reshape for temporal processing
        x_temp = x.transpose(1, 2)  # (batch, features, nodes)
        
        # Temporal convolutions
        temp_conv1 = F.relu(self.temporal_conv1(x_temp))
        temp_conv2 = F.relu(self.temporal_conv2(temp_conv1))
        
        # Prepare for LSTM (treat nodes as sequence)
        temp_conv2 = temp_conv2.transpose(1, 2)  # (batch, nodes, hidden)
        
        # LSTM processing
        lstm_out, _ = self.lstm(temp_conv2)
        
        # Project back to original dimension
        temporal_features = self.output_proj(lstm_out)
        
        return temporal_features

class DGDMSGCN(nn.Module):
    """Dynamic Multi-Site Graph Convolutional Network"""
    
    def __init__(self, input_dim=4, hidden_dim=128, num_classes=2, num_sites=10):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.num_sites = num_sites
        
        # Dynamic graph construction
        self.dynamic_graph = DynamicGraphConstruction(input_dim, hidden_dim//2)
        
        # Site-adaptive layers
        self.site_adaptive1 = SiteAdaptiveLayer(hidden_dim, num_sites)
        self.site_adaptive2 = SiteAdaptiveLayer(hidden_dim, num_sites)
        
        # Temporal feature extraction
        self.temporal_extractor = TemporalFeatureExtractor(input_dim, hidden_dim//2)
        
        # Graph convolution layers
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        
        # Multi-scale pooling
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.max_pool = nn.AdaptiveMaxPool1d(1)
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, num_classes)
        )
        
        # Site encoding
        self.site_encoder = nn.Embedding(num_sites, 16)
        
    def encode_sites(self, sites):
        """Encode site names to indices"""
        site_names = list(set(sites))
        site_to_idx = {site: idx % self.num_sites for idx, site in enumerate(site_names)}
        return torch.tensor([site_to_idx.get(site, 0) for site in sites], 
                          dtype=torch.long, device=next(self.parameters()).device)
    
    def forward(self, data, sites):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        # Encode sites
        site_indices = self.encode_sites(sites)
        
        # Extract temporal features
        batch_size = batch.max().item() + 1
        temporal_features = []
        
        for b in range(batch_size):
            mask = (batch == b)
            batch_x = x[mask].unsqueeze(0)
            temp_feat = self.temporal_extractor(batch_x)
            temporal_features.append(temp_feat.squeeze(0))
        
        x_temporal = torch.cat(temporal_features, dim=0)
        
        # Dynamic graph construction
        dynamic_edge_index, dynamic_edge_weight = self.dynamic_graph(x_temporal, batch)
        
        # Combine original and dynamic edges
        combined_edge_index = torch.cat([edge_index, dynamic_edge_index], dim=1)
        
        # Graph convolutions with site adaptation
        x1 = F.relu(self.conv1(x_temporal, combined_edge_index))
        x1_adapted = self.site_adaptive1(x1, site_indices[batch])
        
        x2 = F.relu(self.conv2(x1_adapted, combined_edge_index))
        x2_adapted = self.site_adaptive2(x2, site_indices[batch])
        
        x3 = F.relu(self.conv3(x2_adapted, combined_edge_index))
        
        # Multi-scale pooling
        pooled_features = []
        for b in range(batch_size):
            mask = (batch == b)
            batch_features = x3[mask].unsqueeze(0).transpose(1, 2)
            
            avg_pooled = self.global_pool(batch_features).squeeze(-1)
            max_pooled = self.max_pool(batch_features).squeeze(-1)
            
            combined_pool = torch.cat([avg_pooled, max_pooled], dim=1)
            pooled_features.append(combined_pool.squeeze(0))
        
        pooled_features = torch.stack(pooled_features)
        
        # Classification
        output = self.classifier(pooled_features)
        
        return {
            'logits': output,
            'dynamic_edges': dynamic_edge_index,
            'temporal_features': x_temporal,
            'adapted_features': x2_adapted
        }

print("Dynamic Multi-Site GCN model defined successfully!")

Dynamic Multi-Site GCN model defined successfully!


## Model 4: IFC-GNN - Interaction-based Functional Connectivity GNN

IFC-GNN models temporal interactions in functional connectivity through specialized convolutions and deep feature selection mechanisms.

In [8]:
# IFC-GNN Components

class InteractionConv(nn.Module):
    """Interaction-based convolution for functional connectivity"""
    
    def __init__(self, in_features, out_features, interaction_dim=32):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.interaction_dim = interaction_dim
        
        # Interaction encoding
        self.node_encoder = nn.Linear(in_features, interaction_dim)
        self.edge_encoder = nn.Linear(1, interaction_dim)  # Edge weights
        
        # Interaction computation
        self.interaction_net = nn.Sequential(
            nn.Linear(interaction_dim * 3, interaction_dim * 2),
            nn.ReLU(),
            nn.Linear(interaction_dim * 2, interaction_dim),
            nn.ReLU(),
            nn.Linear(interaction_dim, out_features)
        )
        
        # Message passing
        self.message_net = nn.Sequential(
            nn.Linear(out_features * 2, out_features),
            nn.ReLU(),
            nn.Linear(out_features, out_features)
        )
        
        # Update function
        self.update_net = GRUCell(out_features, out_features)
        
    def forward(self, x, edge_index, edge_attr=None):
        """Forward pass with interaction modeling"""
        row, col = edge_index
        
        # Encode nodes and edges
        node_emb = self.node_encoder(x)
        
        if edge_attr is None:
            edge_attr = torch.ones(edge_index.size(1), 1, device=x.device)
        edge_emb = self.edge_encoder(edge_attr)
        
        # Compute interactions
        source_nodes = node_emb[row]
        target_nodes = node_emb[col]
        
        # Combine source, target, and edge information
        interaction_input = torch.cat([source_nodes, target_nodes, edge_emb], dim=1)
        interactions = self.interaction_net(interaction_input)
        
        # Aggregate messages
        num_nodes = x.size(0)
        messages = torch.zeros(num_nodes, self.out_features, device=x.device)
        
        for i in range(edge_index.size(1)):
            src, tgt = row[i], col[i]
            message = self.message_net(torch.cat([interactions[i], node_emb[src]], dim=0).unsqueeze(0))
            messages[tgt] += message.squeeze(0)
        
        # Update node representations
        h_prev = torch.zeros(num_nodes, self.out_features, device=x.device)
        updated_nodes = self.update_net(messages, h_prev)
        
        return updated_nodes

class TemporalInteractionBlock(nn.Module):
    """Temporal interaction modeling for functional connectivity"""
    
    def __init__(self, hidden_dim, num_time_steps=5):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_time_steps = num_time_steps
        
        # Temporal encoding
        self.temporal_encoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )
        
        # Interaction LSTM
        self.interaction_lstm = nn.LSTM(
            hidden_dim, hidden_dim, 
            num_layers=2, 
            batch_first=True, 
            bidirectional=True
        )
        
        # Attention mechanism
        self.temporal_attention = nn.MultiheadAttention(
            hidden_dim * 2, num_heads=8, batch_first=True
        )
        
        # Output projection
        self.output_proj = nn.Linear(hidden_dim * 2, hidden_dim)
        
    def forward(self, x, batch=None):
        """Model temporal interactions"""
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        
        batch_size = batch.max().item() + 1
        temporal_features = []
        
        for b in range(batch_size):
            mask = (batch == b)
            batch_x = x[mask]  # (num_nodes, hidden_dim)
            
            # Create temporal sequence
            temporal_seq = []
            for t in range(self.num_time_steps):
                # Add temporal encoding
                temp_encoded = self.temporal_encoder(batch_x)
                temporal_seq.append(temp_encoded)
            
            temporal_seq = torch.stack(temporal_seq, dim=1)  # (num_nodes, time_steps, hidden_dim)
            
            # LSTM processing
            lstm_out, _ = self.interaction_lstm(temporal_seq)  # (num_nodes, time_steps, hidden_dim*2)
            
            # Temporal attention
            attended, _ = self.temporal_attention(lstm_out, lstm_out, lstm_out)
            
            # Aggregate over time
            temporal_agg = torch.mean(attended, dim=1)  # (num_nodes, hidden_dim*2)
            
            # Project to output dimension
            output = self.output_proj(temporal_agg)  # (num_nodes, hidden_dim)
            temporal_features.append(output)
        
        return torch.cat(temporal_features, dim=0)

class DeepFeatureSelection(nn.Module):
    """Deep feature selection for connectivity patterns"""
    
    def __init__(self, input_dim, selection_ratio=0.5):
        super().__init__()
        self.input_dim = input_dim
        self.selection_ratio = selection_ratio
        
        # Feature importance network
        self.importance_net = nn.Sequential(
            nn.Linear(input_dim, input_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(input_dim * 2, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, input_dim),
            nn.Sigmoid()
        )
        
        # Feature transformation network
        self.transform_net = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, input_dim)
        )
        
        # Gating mechanism
        self.gate = nn.Sequential(
            nn.Linear(input_dim * 2, input_dim),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        """Select and transform features"""
        # Compute feature importance
        importance_scores = self.importance_net(x)
        
        # Transform features
        transformed_features = self.transform_net(x)
        
        # Gating
        gate_input = torch.cat([x, transformed_features], dim=-1)
        gate_values = self.gate(gate_input)
        
        # Apply gating and importance weighting
        selected_features = transformed_features * gate_values * importance_scores
        
        # Add residual connection
        output = selected_features + x * (1 - gate_values)
        
        return output, importance_scores

class IFCGNN(nn.Module):
    """Interaction-based Functional Connectivity GNN"""
    
    def __init__(self, input_dim=4, hidden_dim=128, num_classes=2, interaction_dim=64):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.interaction_dim = interaction_dim
        
        # Input projection
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        
        # Interaction convolutions
        self.interaction_conv1 = InteractionConv(hidden_dim, hidden_dim, interaction_dim)
        self.interaction_conv2 = InteractionConv(hidden_dim, hidden_dim, interaction_dim)
        self.interaction_conv3 = InteractionConv(hidden_dim, hidden_dim, interaction_dim)
        
        # Temporal interaction modeling
        self.temporal_block = TemporalInteractionBlock(hidden_dim)
        
        # Deep feature selection
        self.feature_selection = DeepFeatureSelection(hidden_dim)
        
        # Self-attention pooling
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 4, num_classes)
        )
        
        # Normalization layers
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)
        
    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        
        # Input projection
        x = self.input_proj(x)
        
        # Interaction convolutions with residual connections
        x1 = self.interaction_conv1(x, edge_index, edge_attr)
        x1 = self.norm1(x1 + x)
        
        x2 = self.interaction_conv2(x1, edge_index, edge_attr)
        x2 = self.norm2(x2 + x1)
        
        x3 = self.interaction_conv3(x2, edge_index, edge_attr)
        x3 = self.norm3(x3 + x2)
        
        # Temporal interaction modeling
        temporal_features = self.temporal_block(x3, batch)
        
        # Deep feature selection
        selected_features, importance_scores = self.feature_selection(temporal_features)
        
        # Self-attention pooling
        batch_size = batch.max().item() + 1
        pooled_features = []
        
        for b in range(batch_size):
            mask = (batch == b)
            batch_features = selected_features[mask].unsqueeze(0)
            
            # Self-attention
            attended, attention_weights = self.self_attention(
                batch_features, batch_features, batch_features
            )
            
            # Global pooling
            pooled = torch.mean(attended.squeeze(0), dim=0)
            pooled_features.append(pooled)
        
        pooled_features = torch.stack(pooled_features)
        
        # Classification
        output = self.classifier(pooled_features)
        
        return {
            'logits': output,
            'importance_scores': importance_scores,
            'temporal_features': temporal_features,
            'selected_features': selected_features,
            'pooled_features': pooled_features
        }

print("IFC-GNN model defined successfully!")

IFC-GNN model defined successfully!


## Model 5: RAGNN - Region-Aware Graph Neural Network

RAGNN incorporates hemispheric asymmetry learning and EEG-based analysis for comprehensive brain connectivity modeling with region-specific processing.

In [9]:
# RAGNN Components

class HemisphericAsymmetryModule(nn.Module):
    """Module for learning hemispheric asymmetry patterns"""
    
    def __init__(self, hidden_dim, n_rois=400):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_rois = n_rois
        self.hemisphere_size = n_rois // 2
        
        # Hemisphere-specific encoders
        self.left_encoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        self.right_encoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Asymmetry computation
        self.asymmetry_net = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Tanh()
        )
        
        # Cross-hemispheric attention
        self.cross_attention = nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True)
        
        # Integration layer
        self.integration = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )
        
    def forward(self, x, batch=None):
        """Learn hemispheric asymmetry patterns"""
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        
        batch_size = batch.max().item() + 1
        asymmetry_features = []
        
        for b in range(batch_size):
            mask = (batch == b)
            batch_x = x[mask]
            
            # Assume first half is left hemisphere, second half is right
            n_nodes = batch_x.size(0)
            left_nodes = batch_x[:n_nodes//2]
            right_nodes = batch_x[n_nodes//2:]
            
            # Hemisphere-specific encoding
            left_encoded = self.left_encoder(left_nodes)
            right_encoded = self.right_encoder(right_nodes)
            
            # Cross-hemispheric attention
            left_query = left_encoded.unsqueeze(0)
            right_key_value = right_encoded.unsqueeze(0)
            left_attended, _ = self.cross_attention(left_query, right_key_value, right_key_value)
            left_attended = left_attended.squeeze(0)
            
            right_query = right_encoded.unsqueeze(0)
            left_key_value = left_encoded.unsqueeze(0)
            right_attended, _ = self.cross_attention(right_query, left_key_value, left_key_value)
            right_attended = right_attended.squeeze(0)
            
            # Compute asymmetry scores
            asymmetry_input = torch.cat([left_encoded, right_encoded], dim=1)
            asymmetry_scores = self.asymmetry_net(asymmetry_input)
            
            # Integrate information
            left_integrated = self.integration(torch.cat([left_encoded, left_attended, asymmetry_scores.expand_as(left_encoded)], dim=1))
            right_integrated = self.integration(torch.cat([right_encoded, right_attended, asymmetry_scores.expand_as(right_encoded)], dim=1))
            
            # Combine hemispheres
            integrated_features = torch.cat([left_integrated, right_integrated], dim=0)
            asymmetry_features.append(integrated_features)
        
        return torch.cat(asymmetry_features, dim=0)

class RegionSpecificConv(nn.Module):
    """Region-specific convolution for different brain areas"""
    
    def __init__(self, in_features, out_features, num_regions=8):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_regions = num_regions
        
        # Region-specific convolutions
        self.region_convs = nn.ModuleList([
            GCNConv(in_features, out_features) for _ in range(num_regions)
        ])
        
        # Region classifier
        self.region_classifier = nn.Sequential(
            nn.Linear(in_features, num_regions),
            nn.Softmax(dim=-1)
        )
        
        # Adaptive combination
        self.combination_net = nn.Sequential(
            nn.Linear(out_features * num_regions, out_features * 2),
            nn.ReLU(),
            nn.Linear(out_features * 2, out_features)
        )
        
    def forward(self, x, edge_index, batch=None):
        """Apply region-specific convolutions"""
        # Classify nodes into regions
        region_probs = self.region_classifier(x)
        
        # Apply region-specific convolutions
        region_outputs = []
        for i, conv in enumerate(self.region_convs):
            region_output = conv(x, edge_index)
            region_outputs.append(region_output)
        
        # Weighted combination based on region probabilities
        weighted_outputs = []
        for i, output in enumerate(region_outputs):
            weight = region_probs[:, i:i+1]
            weighted_output = output * weight
            weighted_outputs.append(weighted_output)
        
        # Combine all regions
        combined_output = torch.cat(weighted_outputs, dim=1)
        final_output = self.combination_net(combined_output)
        
        return final_output, region_probs

class EEGInspiredPooling(nn.Module):
    """EEG-inspired pooling for frequency band analysis"""
    
    def __init__(self, hidden_dim, num_freq_bands=5):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_freq_bands = num_freq_bands
        
        # Frequency band decomposition
        self.freq_transforms = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            ) for _ in range(num_freq_bands)
        ])
        
        # Band-specific attention
        self.band_attention = nn.ModuleList([
            nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True)
            for _ in range(num_freq_bands)
        ])
        
        # Frequency integration
        self.freq_integration = nn.Sequential(
            nn.Linear(hidden_dim * num_freq_bands, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )
        
    def forward(self, x, batch=None):
        """EEG-inspired pooling with frequency bands"""
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        
        batch_size = batch.max().item() + 1
        freq_features = []
        
        for b in range(batch_size):
            mask = (batch == b)
            batch_x = x[mask].unsqueeze(0)
            
            # Decompose into frequency bands
            band_features = []
            for i, (transform, attention) in enumerate(zip(self.freq_transforms, self.band_attention)):
                # Transform to frequency domain
                freq_transformed = transform(batch_x.squeeze(0))
                freq_transformed = freq_transformed.unsqueeze(0)
                
                # Apply attention within band
                attended, _ = attention(freq_transformed, freq_transformed, freq_transformed)
                
                # Pool within band
                band_pooled = torch.mean(attended.squeeze(0), dim=0)
                band_features.append(band_pooled)
            
            # Integrate frequency bands
            integrated_freq = torch.cat(band_features, dim=0)
            integrated_output = self.freq_integration(integrated_freq)
            freq_features.append(integrated_output)
        
        return torch.stack(freq_features)

class RAGNN(nn.Module):
    """Region-Aware Graph Neural Network with Hemispheric Asymmetry"""
    
    def __init__(self, input_dim=4, hidden_dim=128, num_classes=2, num_regions=8, num_freq_bands=5):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.num_regions = num_regions
        self.num_freq_bands = num_freq_bands
        
        # Input projection
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        
        # Region-specific convolutions
        self.region_conv1 = RegionSpecificConv(hidden_dim, hidden_dim, num_regions)
        self.region_conv2 = RegionSpecificConv(hidden_dim, hidden_dim, num_regions)
        
        # Hemispheric asymmetry module
        self.hemisphere_module = HemisphericAsymmetryModule(hidden_dim)
        
        # Standard convolution layers
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        self.conv4 = GCNConv(hidden_dim, hidden_dim)
        
        # EEG-inspired pooling
        self.eeg_pooling = EEGInspiredPooling(hidden_dim, num_freq_bands)
        
        # Multi-head attention for final integration
        self.final_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 4, num_classes)
        )
        
        # Normalization layers
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)
        self.norm4 = nn.LayerNorm(hidden_dim)
        
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        # Input projection
        x = self.input_proj(x)
        
        # Region-specific convolutions
        x1, region_probs1 = self.region_conv1(x, edge_index, batch)
        x1 = self.norm1(x1 + x)
        
        x2, region_probs2 = self.region_conv2(x1, edge_index, batch)
        x2 = self.norm2(x2 + x1)
        
        # Hemispheric asymmetry learning
        x_asym = self.hemisphere_module(x2, batch)
        x_asym = self.norm3(x_asym + x2)
        
        # Standard convolutions
        x3 = F.relu(self.conv3(x_asym, edge_index))
        x3 = self.norm4(x3 + x_asym)
        
        x4 = F.relu(self.conv4(x3, edge_index))
        
        # EEG-inspired pooling
        pooled_features = self.eeg_pooling(x4, batch)
        
        # Final attention
        attended_features, attention_weights = self.final_attention(
            pooled_features.unsqueeze(1), 
            pooled_features.unsqueeze(1), 
            pooled_features.unsqueeze(1)
        )
        final_features = attended_features.squeeze(1)
        
        # Classification
        output = self.classifier(final_features)
        
        return {
            'logits': output,
            'region_probs1': region_probs1,
            'region_probs2': region_probs2,
            'hemispheric_features': x_asym,
            'attention_weights': attention_weights,
            'final_features': final_features
        }

print("RAGNN model defined successfully!")

RAGNN model defined successfully!


## Training and Evaluation Utilities

Comprehensive training framework with gradient accumulation, early stopping, learning rate scheduling, and model evaluation utilities.

In [10]:
# Training and Evaluation Framework

class EarlyStopping:
    """Early stopping utility with patience and best model saving"""
    
    def __init__(self, patience=10, min_delta=0.001, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None
        self.counter = 0
        self.best_weights = None
        
    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
        else:
            self.counter += 1
            
        return self.counter >= self.patience
    
    def save_checkpoint(self, model):
        if self.restore_best_weights:
            self.best_weights = {name: param.clone() for name, param in model.named_parameters()}
    
    def restore_best_weights_fn(self, model):
        if self.best_weights:
            for name, param in model.named_parameters():
                param.data.copy_(self.best_weights[name])

def create_data_loaders(dataset, batch_size=32, train_ratio=0.8, val_ratio=0.1):
    """Create train/validation/test data loaders"""
    n_samples = len(dataset)
    indices = list(range(n_samples))
    np.random.shuffle(indices)
    
    train_split = int(train_ratio * n_samples)
    val_split = int((train_ratio + val_ratio) * n_samples)
    
    train_indices = indices[:train_split]
    val_indices = indices[train_split:val_split]
    test_indices = indices[val_split:]
    
    train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
    val_sampler = torch.utils.data.SubsetRandomSampler(val_indices)
    test_sampler = torch.utils.data.SubsetRandomSampler(test_indices)
    
    def collate_fn(batch):
        """Custom collate function for graph data with time series support"""
        graphs = []
        labels = []
        demographics = []
        sites = []
        subjects = []
        
        for item in batch:
            # Create enhanced graph using both time series and connectivity
            if 'time_series' in item:
                graph = create_graph_from_timeseries(
                    item['time_series'].numpy(),
                    item['connectivity'].numpy(),
                    threshold=0.3
                )
            else:
                # Fallback to connectivity-only graph
                graph = create_graph_from_connectivity(item['connectivity'].numpy())
            
            graphs.append(graph)
            labels.append(item['label'])
            demographics.append(item['demographics'])
            sites.append(item['site'])
            subjects.append(item['subject'])
        
        # Batch graphs
        batched_graphs = Batch.from_data_list(graphs)
        labels = torch.cat(labels, dim=0)
        demographics = torch.stack(demographics, dim=0)
        
        return {
            'graphs': batched_graphs,
            'labels': labels,
            'demographics': demographics,
            'sites': sites,
            'subjects': subjects
        }
    
    train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler, collate_fn=collate_fn)
    val_loader = DataLoader(dataset, batch_size=batch_size, sampler=val_sampler, collate_fn=collate_fn)
    test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler, collate_fn=collate_fn)
    
    return train_loader, val_loader, test_loader

def train_model(model, train_loader, val_loader, num_epochs=100, learning_rate=0.001, 
                accumulation_steps=4, device='cuda'):
    """Train model with gradient accumulation and early stopping"""
    
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    criterion = nn.CrossEntropyLoss()
    early_stopping = EarlyStopping(patience=15, min_delta=0.001)
    
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []
    
    print(f"Training on {device} for {num_epochs} epochs...")
    print(f"Gradient accumulation steps: {accumulation_steps}")
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        
        optimizer.zero_grad()
        
        for batch_idx, batch in enumerate(train_loader):
            graphs = batch['graphs'].to(device)
            labels = batch['labels'].to(device)
            sites = batch['sites']
            
            # Forward pass
            if hasattr(model, 'encode_sites'):  # For DGDMSGCN
                outputs = model(graphs, sites)
            else:
                outputs = model(graphs)
            
            logits = outputs['logits']
            loss = criterion(logits, labels) / accumulation_steps
            
            # Backward pass
            loss.backward()
            
            # Gradient accumulation
            if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(train_loader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                optimizer.zero_grad()
            
            # Statistics
            train_loss += loss.item() * accumulation_steps
            _, predicted = torch.max(logits.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for batch in val_loader:
                graphs = batch['graphs'].to(device)
                labels = batch['labels'].to(device)
                sites = batch['sites']
                
                if hasattr(model, 'encode_sites'):  # For DGDMSGCN
                    outputs = model(graphs, sites)
                else:
                    outputs = model(graphs)
                
                logits = outputs['logits']
                loss = criterion(logits, labels)
                
                val_loss += loss.item()
                _, predicted = torch.max(logits.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
        
        # Calculate averages
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        train_acc = 100 * train_correct / train_total
        val_acc = 100 * val_correct / val_total
        
        # Store metrics
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)
        
        # Learning rate scheduling
        scheduler.step(avg_val_loss)
        
        # Print progress
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}]')
            print(f'Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%')
            print(f'Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.2f}%')
            print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')
            print('-' * 50)
        
        # Early stopping
        if early_stopping(avg_val_loss, model):
            print(f'Early stopping triggered at epoch {epoch+1}')
            break
    
    # Restore best weights
    early_stopping.restore_best_weights_fn(model)
    
    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accuracies': train_accuracies,
        'val_accuracies': val_accuracies,
        'final_epoch': epoch + 1
    }

def evaluate_model(model, test_loader, device='cuda'):
    """Comprehensive model evaluation"""
    model.eval()
    model = model.to(device)
    
    all_predictions = []
    all_labels = []
    all_probabilities = []
    test_loss = 0.0
    
    criterion = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for batch in test_loader:
            graphs = batch['graphs'].to(device)
            labels = batch['labels'].to(device)
            sites = batch['sites']
            
            if hasattr(model, 'encode_sites'):  # For DGDMSGCN
                outputs = model(graphs, sites)
            else:
                outputs = model(graphs)
            
            logits = outputs['logits']
            loss = criterion(logits, labels)
            test_loss += loss.item()
            
            # Get predictions and probabilities
            probabilities = F.softmax(logits, dim=1)
            _, predicted = torch.max(logits, 1)
            
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())
    
    # Calculate metrics
    avg_test_loss = test_loss / len(test_loader)
    accuracy = accuracy_score(all_labels, all_predictions)
    
    # ROC AUC for binary classification
    if len(np.unique(all_labels)) == 2:
        probabilities_class1 = np.array(all_probabilities)[:, 1]
        auc_score = roc_auc_score(all_labels, probabilities_class1)
    else:
        auc_score = None
    
    # Classification report
    class_report = classification_report(all_labels, all_predictions, 
                                       target_names=['Control', 'ASD'], 
                                       output_dict=True)
    
    print("=== Model Evaluation Results ===")
    print(f"Test Loss: {avg_test_loss:.4f}")
    print(f"Test Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    if auc_score:
        print(f"ROC AUC Score: {auc_score:.4f}")
    print("\nClassification Report:")
    print(classification_report(all_labels, all_predictions, target_names=['Control', 'ASD']))
    
    return {
        'test_loss': avg_test_loss,
        'accuracy': accuracy,
        'auc_score': auc_score,
        'predictions': all_predictions,
        'labels': all_labels,
        'probabilities': all_probabilities,
        'classification_report': class_report
    }

def plot_training_history(history):
    """Plot training history"""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Training and validation loss
    axes[0, 0].plot(history['train_losses'], label='Training Loss', color='blue')
    axes[0, 0].plot(history['val_losses'], label='Validation Loss', color='red')
    axes[0, 0].set_title('Model Loss')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True)
    
    # Training and validation accuracy
    axes[0, 1].plot(history['train_accuracies'], label='Training Accuracy', color='blue')
    axes[0, 1].plot(history['val_accuracies'], label='Validation Accuracy', color='red')
    axes[0, 1].set_title('Model Accuracy')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy (%)')
    axes[0, 1].legend()
    axes[0, 1].grid(True)
    
    # Loss difference
    loss_diff = np.array(history['val_losses']) - np.array(history['train_losses'])
    axes[1, 0].plot(loss_diff, color='purple')
    axes[1, 0].set_title('Validation - Training Loss')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Loss Difference')
    axes[1, 0].grid(True)
    axes[1, 0].axhline(y=0, color='black', linestyle='--', alpha=0.5)
    
    # Accuracy difference
    acc_diff = np.array(history['val_accuracies']) - np.array(history['train_accuracies'])
    axes[1, 1].plot(acc_diff, color='orange')
    axes[1, 1].set_title('Validation - Training Accuracy')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Accuracy Difference (%)')
    axes[1, 1].grid(True)
    axes[1, 1].axhline(y=0, color='black', linestyle='--', alpha=0.5)
    
    plt.tight_layout()
    plt.show()

print("Training and evaluation utilities defined successfully!")

Training and evaluation utilities defined successfully!


## Model Comparison and Execution

Instantiate all five models and demonstrate their usage with the ABIDE dataset.

In [11]:
# Model Instantiation and Comparison

def initialize_models(input_dim=8, hidden_dim=128, num_classes=2):
    """Initialize all five brain GNN models with updated input dimension"""
    
    models = {
        'BrainGNN': BrainGNN(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_classes=num_classes,
            n_rois=400
        ),
        
        'LG-GNN': LGGNN(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_classes=num_classes
        ),
        
        'DG-DMSGCN': DGDMSGCN(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_classes=num_classes,
            num_sites=10
        ),
        
        'IFC-GNN': IFCGNN(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_classes=num_classes,
            interaction_dim=64
        ),
        
        'RAGNN': RAGNN(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_classes=num_classes,
            num_regions=8,
            num_freq_bands=5
        )
    }
    
    return models

def count_parameters(model):
    """Count total and trainable parameters"""
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params

def compare_models(models):
    """Compare model architectures and parameters"""
    print("=== Model Comparison ===")
    print(f"{'Model':<12} {'Total Params':<15} {'Trainable':<15} {'Memory (MB)':<12}")
    print("-" * 60)
    
    for name, model in models.items():
        total, trainable = count_parameters(model)
        # Rough memory estimation (assuming float32)
        memory_mb = total * 4 / (1024 * 1024)
        
        print(f"{name:<12} {total:<15,} {trainable:<15,} {memory_mb:<12.2f}")
    
    print("-" * 60)

# Initialize all models
print("Initializing all five Brain GNN models...")
models = initialize_models()

# Compare models
compare_models(models)

# Test model forward passes with dummy data
print("\n=== Testing Model Forward Passes ===")

# Create dummy graph data with updated feature dimension
dummy_x = torch.randn(100, 8)  # 100 nodes, 8 features (updated)
dummy_edge_index = torch.randint(0, 100, (2, 200))  # 200 edges
dummy_edge_attr = torch.randn(200, 1)  # Edge attributes
dummy_batch = torch.zeros(100, dtype=torch.long)  # Single graph

dummy_data = Data(x=dummy_x, edge_index=dummy_edge_index, edge_attr=dummy_edge_attr, batch=dummy_batch)
dummy_sites = ['Site1'] * 1  # For DG-DMSGCN

for name, model in models.items():
    try:
        model.eval()
        with torch.no_grad():
            if name == 'DG-DMSGCN':
                output = model(dummy_data, dummy_sites)
            else:
                output = model(dummy_data)
        
        logits = output['logits']
        print(f"{name}: Output shape = {logits.shape} ✓")
        
    except Exception as e:
        print(f"{name}: Error - {str(e)} ✗")

print("\nAll models initialized and tested successfully!")

Initializing all five Brain GNN models...
=== Model Comparison ===
Model        Total Params    Trainable       Memory (MB) 
------------------------------------------------------------
BrainGNN     390,979         390,979         1.49        
LG-GNN       543,330         543,330         2.07        
DG-DMSGCN    276,715         276,715         1.06        
IFC-GNN      1,842,594       1,842,594       7.03        
RAGNN        1,965,619       1,965,619       7.50        
------------------------------------------------------------

=== Testing Model Forward Passes ===
BrainGNN: Output shape = torch.Size([1, 2]) ✓
LG-GNN: Output shape = torch.Size([1, 2]) ✓
=== Model Comparison ===
Model        Total Params    Trainable       Memory (MB) 
------------------------------------------------------------
BrainGNN     390,979         390,979         1.49        
LG-GNN       543,330         543,330         2.07        
DG-DMSGCN    276,715         276,715         1.06        
IFC-GNN      1,84

In [12]:
# Load ABIDE Dataset and Prepare for Training

# Initialize dataset
print("Loading ABIDE dataset...")
data_dir = "/home/moew/Documents/ABIDE/abide_data"
phenotypic_file = "/home/moew/Documents/ABIDE/Phenotypic_V1_0b_preprocessed1.csv"

try:
    dataset = ABIDEDataset(data_dir, phenotypic_file)
    print(f"Dataset loaded successfully with {len(dataset)} subjects")
    
    # Create data loaders
    print("Creating data loaders...")
    train_loader, val_loader, test_loader = create_data_loaders(
        dataset, batch_size=8, train_ratio=0.7, val_ratio=0.15
    )
    
    print(f"Train batches: {len(train_loader)}")
    print(f"Validation batches: {len(val_loader)}")
    print(f"Test batches: {len(test_loader)}")
    
    # Test data loading
    print("\nTesting data loading...")
    sample_batch = next(iter(train_loader))
    print(f"Sample batch graphs shape: {sample_batch['graphs'].x.shape}")
    print(f"Sample batch labels shape: {sample_batch['labels'].shape}")
    print(f"Sample batch sites: {sample_batch['sites'][:3]}...")
    
except Exception as e:
    print(f"Error loading dataset: {e}")
    print("Creating synthetic data for demonstration...")
    
    # Create synthetic dataset for demonstration
    class SyntheticDataset(Dataset):
        def __init__(self, n_samples=100):
            self.n_samples = n_samples
            
        def __len__(self):
            return self.n_samples
        
        def __getitem__(self, idx):
            # Generate random connectivity matrix
            connectivity = np.random.randn(200, 200) * 0.3
            connectivity = (connectivity + connectivity.T) / 2
            np.fill_diagonal(connectivity, 1.0)
            
            return {
                'connectivity': torch.FloatTensor(connectivity),
                'label': torch.LongTensor([idx % 2]),  # Binary labels
                'demographics': torch.FloatTensor([25.0 + np.random.randn(), 1]),
                'site': f'Site_{idx % 5}',
                'subject': f'Subject_{idx:03d}'
            }
    
    print("Using synthetic dataset for demonstration...")
    dataset = SyntheticDataset(n_samples=80)
    train_loader, val_loader, test_loader = create_data_loaders(
        dataset, batch_size=4, train_ratio=0.7, val_ratio=0.15
    )
    
    print(f"Synthetic dataset created with {len(dataset)} samples")
    print(f"Train batches: {len(train_loader)}")
    print(f"Validation batches: {len(val_loader)}")
    print(f"Test batches: {len(test_loader)}")

print("\nData loading completed successfully!")

Loading ABIDE dataset...
Loading ABIDE time series data...
Loaded phenotypic data with 1112 subjects
Found 882 time series files
Successfully loaded 882 subjects
Time series shape: (176, 392)
Connectivity matrix shape: (392, 392)
Label distribution: [882]
Dataset loaded successfully with 882 subjects
Creating data loaders...
Train batches: 78
Validation batches: 17
Test batches: 17

Testing data loading...
Successfully loaded 882 subjects
Time series shape: (176, 392)
Connectivity matrix shape: (392, 392)
Label distribution: [882]
Dataset loaded successfully with 882 subjects
Creating data loaders...
Train batches: 78
Validation batches: 17
Test batches: 17

Testing data loading...
Sample batch graphs shape: torch.Size([3136, 8])
Sample batch labels shape: torch.Size([8])
Sample batch sites: ['UCLA', 'Leuven', 'Olin']...

Data loading completed successfully!
Sample batch graphs shape: torch.Size([3136, 8])
Sample batch labels shape: torch.Size([8])
Sample batch sites: ['UCLA', 'Leuven'

In [13]:
# Training Demonstration - Select and Train a Model

# Select model for training demonstration
selected_model_name = 'BrainGNN'  # Change this to test different models
selected_model = models[selected_model_name]

print(f"=== Training {selected_model_name} ===")
print(f"Model architecture: {selected_model}")
print(f"Training on device: {device}")

# Training hyperparameters
training_config = {
    'num_epochs': 50,  # Reduced for demonstration
    'learning_rate': 0.001,
    'accumulation_steps': 2,
    'device': device
}

print(f"Training configuration: {training_config}")

# Train the selected model
try:
    print(f"\nStarting training for {selected_model_name}...")
    training_history = train_model(
        selected_model, 
        train_loader, 
        val_loader, 
        **training_config
    )
    
    print(f"\nTraining completed in {training_history['final_epoch']} epochs")
    print(f"Best validation accuracy: {max(training_history['val_accuracies']):.2f}%")
    
    # Plot training history
    plot_training_history(training_history)
    
    # Evaluate on test set
    print("\n=== Test Set Evaluation ===")
    test_results = evaluate_model(selected_model, test_loader, device)
    
    # Save the trained model
    model_save_path = f'{selected_model_name.lower()}_best_model.pth'
    torch.save(selected_model.state_dict(), model_save_path)
    print(f"\nModel saved to {model_save_path}")
    
except Exception as e:
    print(f"Training error: {e}")
    import traceback
    traceback.print_exc()

print(f"\n{selected_model_name} training and evaluation completed!")

=== Training BrainGNN ===
Model architecture: BrainGNN(
  (conv1): ROIAwareConv(
    (roi_linear): Linear(in_features=8, out_features=128, bias=True)
    (roi_attention): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
    )
    (graph_conv): GCNConv(8, 128)
    (combine): Linear(in_features=256, out_features=128, bias=True)
    (dropout): Dropout(p=0.3, inplace=False)
  )
  (conv2): ROIAwareConv(
    (roi_linear): Linear(in_features=128, out_features=128, bias=True)
    (roi_attention): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
    )
    (graph_conv): GCNConv(128, 128)
    (combine): Linear(in_features=256, out_features=128, bias=True)
    (dropout): Dropout(p=0.3, inplace=False)
  )
  (conv3): ROIAwareConv(
    (roi_linear): Linear(in_features=128, out_features=128, bias=True)
    (roi_attention): MultiheadAttention(
      (out_proj): NonDynami

KeyboardInterrupt: 

In [14]:
# Optimized Training Demonstration - Precomputed Graphs

# First, let's create an optimized dataset that precomputes all graphs
class OptimizedABIDEDataset(Dataset):
    """Optimized ABIDE Dataset with precomputed graphs"""
    
    def __init__(self, original_dataset, threshold=0.3):
        self.threshold = threshold
        self.graphs = []
        self.labels = []
        self.demographics = []
        self.sites = []
        self.subjects = []
        
        print("Precomputing graphs for optimized training...")
        print(f"Processing {len(original_dataset)} samples...")
        
        # Precompute all graphs
        for i in range(len(original_dataset)):
            if i % 100 == 0:
                print(f"Processing sample {i}/{len(original_dataset)}")
                
            item = original_dataset[i]
            
            # Create enhanced graph using both time series and connectivity
            graph = create_graph_from_timeseries(
                item['time_series'].numpy(),
                item['connectivity'].numpy(),
                threshold=threshold
            )
            
            self.graphs.append(graph)
            self.labels.append(item['label'])
            self.demographics.append(item['demographics'])
            self.sites.append(item['site'])
            self.subjects.append(item['subject'])
        
        print("Graph precomputation completed!")
        
    def __len__(self):
        return len(self.graphs)
    
    def __getitem__(self, idx):
        return {
            'graph': self.graphs[idx],
            'label': self.labels[idx],
            'demographics': self.demographics[idx],
            'site': self.sites[idx],
            'subject': self.subjects[idx]
        }

def create_optimized_data_loaders(dataset, batch_size=8, train_ratio=0.8, val_ratio=0.1):
    """Create optimized data loaders with precomputed graphs"""
    n_samples = len(dataset)
    indices = list(range(n_samples))
    np.random.shuffle(indices)
    
    train_split = int(train_ratio * n_samples)
    val_split = int((train_ratio + val_ratio) * n_samples)
    
    train_indices = indices[:train_split]
    val_indices = indices[train_split:val_split]
    test_indices = indices[val_split:]
    
    train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
    val_sampler = torch.utils.data.SubsetRandomSampler(val_indices)
    test_sampler = torch.utils.data.SubsetRandomSampler(test_indices)
    
    def optimized_collate_fn(batch):
        """Optimized collate function for precomputed graphs"""
        graphs = [item['graph'] for item in batch]
        labels = torch.cat([item['label'] for item in batch], dim=0)
        demographics = torch.stack([item['demographics'] for item in batch], dim=0)
        sites = [item['site'] for item in batch]
        subjects = [item['subject'] for item in batch]
        
        # Batch graphs efficiently
        batched_graphs = Batch.from_data_list(graphs)
        
        return {
            'graphs': batched_graphs,
            'labels': labels,
            'demographics': demographics,
            'sites': sites,
            'subjects': subjects
        }
    
    train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler, 
                             collate_fn=optimized_collate_fn, num_workers=0, pin_memory=True)
    val_loader = DataLoader(dataset, batch_size=batch_size, sampler=val_sampler, 
                           collate_fn=optimized_collate_fn, num_workers=0, pin_memory=True)
    test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler, 
                            collate_fn=optimized_collate_fn, num_workers=0, pin_memory=True)
    
    return train_loader, val_loader, test_loader

# Create optimized dataset
print("Creating optimized dataset with precomputed graphs...")
optimized_dataset = OptimizedABIDEDataset(dataset, threshold=0.3)

# Create optimized data loaders with smaller batch size for faster iteration
print("Creating optimized data loaders...")
opt_train_loader, opt_val_loader, opt_test_loader = create_optimized_data_loaders(
    optimized_dataset, batch_size=8, train_ratio=0.8, val_ratio=0.1
)

print(f"Optimized train batches: {len(opt_train_loader)}")
print(f"Optimized validation batches: {len(opt_val_loader)}")
print(f"Optimized test batches: {len(opt_test_loader)}")

print("✅ Optimized data loading setup completed!")

Creating optimized dataset with precomputed graphs...
Precomputing graphs for optimized training...
Processing 882 samples...
Processing sample 0/882
Processing sample 100/882
Processing sample 200/882
Processing sample 300/882
Processing sample 400/882
Processing sample 500/882
Processing sample 600/882
Processing sample 700/882
Processing sample 800/882
Graph precomputation completed!
Creating optimized data loaders...
Optimized train batches: 89
Optimized validation batches: 11
Optimized test batches: 12
✅ Optimized data loading setup completed!


In [16]:
# Fast Training Demonstration with Optimized Data Loading

# Select model for fast training demonstration
selected_model_name = 'BrainGNN'  # Start with BrainGNN as it worked well
selected_model = models[selected_model_name].to(device)

print(f"=== Fast Training Demo: {selected_model_name} ===")
print(f"Model parameters: {sum(p.numel() for p in selected_model.parameters()):,}")
print(f"Training on device: {device}")

# Optimized training hyperparameters for fast demonstration
fast_training_config = {
    'num_epochs': 5,  # Just 5 epochs for quick demo
    'learning_rate': 0.01,  # Higher learning rate for faster convergence
    'device': device
}

print(f"Fast training configuration: {fast_training_config}")

# Simple training function optimized for speed
def fast_train_model(model, train_loader, val_loader, num_epochs=5, learning_rate=0.01, device='cuda'):
    """Fast training function optimized for demonstration"""
    
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    
    print(f"\n🚀 Starting fast training on {device} for {num_epochs} epochs...")
    print(f"Training batches: {len(train_loader)}, Validation batches: {len(val_loader)}")
    
    for epoch in range(num_epochs):
        start_time = time.time()
        
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        
        for batch_idx, batch in enumerate(train_loader):
            graphs = batch['graphs'].to(device)
            labels = batch['labels'].to(device)
            
            # Forward pass
            optimizer.zero_grad()
            outputs = model(graphs)
            logits = outputs['logits']
            loss = criterion(logits, labels)
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            # Statistics
            train_loss += loss.item()
            _, predicted = torch.max(logits.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            
            # Print progress every 20 batches
            if (batch_idx + 1) % 20 == 0:
                print(f"  Batch {batch_idx+1}/{len(train_loader)}: Loss = {loss.item():.4f}")
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for batch in val_loader:
                graphs = batch['graphs'].to(device)
                labels = batch['labels'].to(device)
                
                outputs = model(graphs)
                logits = outputs['logits']
                loss = criterion(logits, labels)
                
                val_loss += loss.item()
                _, predicted = torch.max(logits.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
        
        # Calculate metrics
        epoch_time = time.time() - start_time
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        train_acc = 100 * train_correct / train_total
        val_acc = 100 * val_correct / val_total
        
        print(f"Epoch {epoch+1}/{num_epochs} ({epoch_time:.1f}s):")
        print(f"  Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        print("-" * 50)
    
    return {
        'final_train_acc': train_acc,
        'final_val_acc': val_acc,
        'final_train_loss': avg_train_loss,
        'final_val_loss': avg_val_loss
    }

# Test a single batch first to make sure everything works
print("\n🔍 Testing single batch...")
sample_batch = next(iter(opt_train_loader))
print(f"Batch graphs shape: {sample_batch['graphs'].x.shape}")
print(f"Batch labels shape: {sample_batch['labels'].shape}")

# Test forward pass
with torch.no_grad():
    test_graphs = sample_batch['graphs'].to(device)
    test_output = selected_model(test_graphs)
    print(f"Model output shape: {test_output['logits'].shape}")

print("✅ Single batch test successful!")

# Run fast training
try:
    import time
    start_time = time.time()
    
    # Call function with correct parameters
    results = fast_train_model(
        selected_model, 
        opt_train_loader, 
        opt_val_loader, 
        num_epochs=fast_training_config['num_epochs'],
        learning_rate=fast_training_config['learning_rate'],
        device=fast_training_config['device']
    )
    
    total_time = time.time() - start_time
    
    print(f"\n🎉 Fast training completed in {total_time:.1f} seconds!")
    print(f"Final training accuracy: {results['final_train_acc']:.2f}%")
    print(f"Final validation accuracy: {results['final_val_acc']:.2f}%")
    
except Exception as e:
    print(f"❌ Training error: {e}")
    import traceback
    traceback.print_exc()

print(f"\n✅ {selected_model_name} fast training demonstration completed!")

=== Fast Training Demo: BrainGNN ===
Model parameters: 390,979
Training on device: cuda
Fast training configuration: {'num_epochs': 5, 'learning_rate': 0.01, 'device': device(type='cuda')}

🔍 Testing single batch...
Batch graphs shape: torch.Size([3136, 8])
Batch labels shape: torch.Size([8])
Model output shape: torch.Size([8, 2])
✅ Single batch test successful!

🚀 Starting fast training on cuda for 5 epochs...
Training batches: 89, Validation batches: 11
  Batch 20/89: Loss = 0.0000
  Batch 40/89: Loss = 0.0000
  Batch 60/89: Loss = 0.0000
  Batch 80/89: Loss = 0.0000
Epoch 1/5 (87.6s):
  Train Loss: 0.0099, Train Acc: 100.00%
  Val Loss: 0.0000, Val Acc: 100.00%
--------------------------------------------------
  Batch 20/89: Loss = 0.0000
  Batch 40/89: Loss = 0.0000
  Batch 60/89: Loss = 0.0000
  Batch 80/89: Loss = 0.0000
Epoch 2/5 (87.7s):
  Train Loss: 0.0000, Train Acc: 100.00%
  Val Loss: 0.0005, Val Acc: 100.00%
--------------------------------------------------
  Batch 20/8

In [None]:
# Comprehensive Model Comparison and Analysis

# Check dataset class distribution
print("=== Dataset Analysis ===")
all_labels = [dataset[i]['label'].item() for i in range(len(dataset))]
unique_labels, counts = np.unique(all_labels, return_counts=True)
print(f"Class distribution: {dict(zip(unique_labels, counts))}")
total_samples = len(all_labels)
print(f"Class 0: {counts[0] if len(counts) > 0 else 0} samples ({counts[0]/total_samples*100:.1f}%)")
print(f"Class 1: {counts[1] if len(counts) > 1 else 0} samples ({counts[1]/total_samples*100 if len(counts) > 1 else 0:.1f}%)")

# The high accuracy suggests we might have class imbalance
# Let's test all models quickly with 2 epochs each

def quick_test_model(model_name, model, train_loader, val_loader, device='cuda'):
    """Quick test of model performance"""
    print(f"\n🔬 Quick test: {model_name}")
    
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    # Single epoch test
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    
    for batch in train_loader:
        graphs = batch['graphs'].to(device)
        labels = batch['labels'].to(device)
        
        optimizer.zero_grad()
        
        # Handle different model interfaces
        try:
            if hasattr(model, 'encode_sites') and 'sites' in batch:
                outputs = model(graphs, batch['sites'])
            else:
                outputs = model(graphs)
            
            if isinstance(outputs, dict):
                logits = outputs['logits']
            else:
                logits = outputs
                
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            _, predicted = torch.max(logits.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            
        except Exception as e:
            print(f"  ❌ Error in {model_name}: {e}")
            return {
                'status': 'error',
                'error': str(e),
                'train_acc': 0,
                'params': sum(p.numel() for p in model.parameters())
            }
    
    train_acc = 100 * train_correct / train_total
    model_params = sum(p.numel() for p in model.parameters())
    
    print(f"  ✅ {model_name}: {train_acc:.1f}% accuracy, {model_params:,} parameters")
    
    return {
        'status': 'success',
        'train_acc': train_acc,
        'train_loss': train_loss / len(train_loader),
        'params': model_params
    }

# Test all models
print("\n=== Quick Model Testing ===")
model_results = {}

for model_name, model in models.items():
    try:
        result = quick_test_model(model_name, model, opt_train_loader, opt_val_loader, device)
        model_results[model_name] = result
    except Exception as e:
        print(f"❌ Failed to test {model_name}: {e}")
        model_results[model_name] = {'status': 'failed', 'error': str(e)}

# Summary table
print("\n=== Model Performance Summary ===")
print("Model        Status      Train Acc   Parameters  Notes")
print("-" * 65)

for model_name, result in model_results.items():
    if result['status'] == 'success':
        print(f"{model_name:<12} {'✅ Working':<11} {result['train_acc']:>8.1f}% {result['params']:>10,}")
    else:
        print(f"{model_name:<12} {'❌ Error':<11} {'N/A':>8} {'N/A':>10}  {result.get('error', 'Unknown error')[:20]}...")

print(f"\n✅ Model comparison completed!")
print(f"Dataset size: {len(optimized_dataset)} samples")
print(f"Batch size: 8, Training batches: {len(opt_train_loader)}")

=== Dataset Analysis ===
Class distribution: {np.int64(0): np.int64(882)}
Class 0: 882 samples (100.0%)
Class 1: 0 samples (0.0%)

=== Quick Model Testing ===

🔬 Quick test: BrainGNN
