# GNN Models for Brain Connectivity Analysis

This notebook contains the definitions of the Graph Neural Network (GNN) models used for autism spectrum disorder (ASD) classification from brain connectivity graphs.

### 1. Import Required Libraries
This cell imports all necessary libraries, including PyTorch and PyTorch Geometric modules.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, GraphConv, global_mean_pool, global_max_pool

### 2. Define BrainConnectivityGNN Model
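This cell contains the Python class definition for `BrainConnectivityGNN`. The model runs two parallel graph-convolution branches (`GCNConv` and `GraphConv`) over the connectivity graph and combines their outputs. It then adds a graph-attention layer to a residual convolution branch, and classifies each graph from its mean- and max-pooled node features.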

In [None]:
class BrainConnectivityGNN(nn.Module):
    """
    Graph Neural Network for graph-level classification of brain connectivity graphs.

    Architecture, as implemented in ``forward``:
    1. Two parallel two-layer message-passing branches over the same edges:
       a ``GCNConv`` branch ("local", batch-normalised after each layer) and
       a ``GraphConv`` branch ("global", no batch norm).
    2. Concatenation of both branch outputs, followed by batch norm and ReLU.
    3. A multi-head ``GATConv`` attention branch and a ``GCNConv`` residual
       branch, both computed from the concatenated features and summed.
    4. Per-graph mean + max pooling of node features, then an MLP classifier.

    NOTE(review): both convolution branches aggregate over the same
    ``edge_index``, so "global" labels the second branch rather than implying a
    wider receptive field. Pooling is plain mean + max (not weighted by
    connectivity strength), and the "residual" branch is a learned convolution
    added to the attention output rather than an identity skip connection.

    Args:
        input_dim: Number of input features per node.
        hidden_dim: Width of the hidden node representations.
        num_classes: Number of output classes.
        attention_heads: Number of attention heads; heads are averaged
            (``concat=False``), so the attention output width is ``hidden_dim``.
        classifier_dropout: Dropout probability after the first classifier layer.
        classifier_dropout_final: Dropout probability after the second
            classifier layer.
    """

    def __init__(self, input_dim=3, hidden_dim=128, num_classes=2, *,
                 attention_heads=4, classifier_dropout=0.5,
                 classifier_dropout_final=0.3):
        super().__init__()

        # Branch A ("local"): GCNConv layers.
        self.local_conv1 = GCNConv(input_dim, hidden_dim)
        self.local_conv2 = GCNConv(hidden_dim, hidden_dim)

        # Branch B ("global"): GraphConv layers over the same edges.
        self.global_conv1 = GraphConv(input_dim, hidden_dim)
        self.global_conv2 = GraphConv(hidden_dim, hidden_dim)

        # Attention over the concatenated branch outputs (width hidden_dim * 2).
        self.attention = GATConv(hidden_dim * 2, hidden_dim,
                                 heads=attention_heads, concat=False)

        # Parallel convolution branch that is added to the attention output.
        self.residual_conv = GCNConv(hidden_dim * 2, hidden_dim)

        # Classification head; its input is [mean-pool || max-pool].
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(classifier_dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(classifier_dropout_final),
            nn.Linear(hidden_dim // 2, num_classes),
        )

        # Batch normalisation layers (bn3 normalises the concatenated features).
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.bn3 = nn.BatchNorm1d(hidden_dim * 2)
        self.bn4 = nn.BatchNorm1d(hidden_dim)

    def forward(self, data):
        """
        Classify each graph contained in ``data``.

        Args:
            data: A PyTorch Geometric ``Data``/``Batch``-like object exposing
                ``x`` (node features), ``edge_index`` and ``batch``. ``batch``
                may be ``None`` for a single un-batched graph.

        Returns:
            Log-probabilities of shape ``[num_graphs, num_classes]``.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        if batch is None:
            # Single graph: assign every node to graph 0 so pooling yields one row.
            batch = x.new_zeros(x.size(0), dtype=torch.long)

        # Branch A: GCNConv -> BN -> ReLU, twice.
        local_x = F.relu(self.bn1(self.local_conv1(x, edge_index)))
        local_x = F.relu(self.bn2(self.local_conv2(local_x, edge_index)))

        # Branch B: GraphConv -> ReLU, twice.
        global_x = F.relu(self.global_conv1(x, edge_index))
        global_x = F.relu(self.global_conv2(global_x, edge_index))

        # Combine both branches along the feature axis.
        combined_x = torch.cat([local_x, global_x], dim=1)
        combined_x = F.relu(self.bn3(combined_x))

        # Attention branch plus parallel residual branch, summed element-wise.
        attention_x = F.relu(self.attention(combined_x, edge_index))
        residual_x = F.relu(self.bn4(self.residual_conv(combined_x, edge_index)))
        final_x = attention_x + residual_x

        # Graph-level readout: concatenate mean and max pooling per graph.
        graph_repr = torch.cat(
            [global_mean_pool(final_x, batch), global_max_pool(final_x, batch)],
            dim=1,
        )

        return F.log_softmax(self.classifier(graph_repr), dim=1)

### 3. Define HierarchicalBrainGNN Model
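This cell contains the Python class definition for `HierarchicalBrainGNN`. The model learns brain connectivity patterns in successive stages, from individual regions of interest (ROIs) up to the whole brain. It stacks two ROI-level graph convolutions, a graph-attention layer and a final graph convolution, then mean-pools the node features of each graph for classification.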

In [None]:
class HierarchicalBrainGNN(nn.Module):
    """
    Stacked GNN whose stages are labelled as brain-connectivity "levels".

    1. ROI level: two ``GCNConv`` layers with dropout between them.
    2. "Network" level: a multi-head ``GATConv`` attention layer.
    3. Whole-brain level: one more ``GCNConv`` layer, then mean pooling over
       each graph's nodes and an MLP classifier.

    NOTE(review): every stage operates on the same nodes and ``edge_index``;
    no grouping or coarsening of ROIs into networks happens between levels.

    Args:
        input_dim: Number of input features per node.
        hidden_dim: Width of the hidden node representations.
        num_classes: Number of output classes.
        attention_heads: Number of attention heads; heads are averaged
            (``concat=False``), so the attention output width is ``hidden_dim``.
        roi_dropout: Dropout probability between the two ROI-level convolutions.
        classifier_dropout: Dropout probability inside the classifier head.
    """

    def __init__(self, input_dim=3, hidden_dim=64, num_classes=2, *,
                 attention_heads=8, roi_dropout=0.5, classifier_dropout=0.4):
        super().__init__()

        # Functional dropout rate (no parameters, so not part of state_dict).
        self.roi_dropout = roi_dropout

        # Level 1: ROI-level convolutions.
        self.roi_conv1 = GCNConv(input_dim, hidden_dim)
        self.roi_conv2 = GCNConv(hidden_dim, hidden_dim)

        # Level 2: attention over the same graph.
        self.network_attention = GATConv(hidden_dim, hidden_dim,
                                         heads=attention_heads, concat=False)

        # Level 3: final convolution before pooling.
        self.brain_conv = GCNConv(hidden_dim, hidden_dim)

        # Final classification head on the mean-pooled graph representation.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(classifier_dropout),
            nn.Linear(hidden_dim // 2, num_classes),
        )

    def forward(self, data):
        """
        Classify each graph contained in ``data``.

        Args:
            data: A PyTorch Geometric ``Data``/``Batch``-like object exposing
                ``x`` (node features), ``edge_index`` and ``batch``. ``batch``
                may be ``None`` for a single un-batched graph.

        Returns:
            Log-probabilities of shape ``[num_graphs, num_classes]``.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        if batch is None:
            # Single graph: assign every node to graph 0 so pooling yields one row.
            batch = x.new_zeros(x.size(0), dtype=torch.long)

        # Level 1: ROI-level processing (dropout only active in training mode).
        roi_features = F.relu(self.roi_conv1(x, edge_index))
        roi_features = F.dropout(roi_features, p=self.roi_dropout,
                                 training=self.training)
        roi_features = F.relu(self.roi_conv2(roi_features, edge_index))

        # Level 2: attention layer.
        network_features = F.relu(self.network_attention(roi_features, edge_index))

        # Level 3: final convolution.
        brain_features = F.relu(self.brain_conv(network_features, edge_index))

        # Graph-level representation: mean over each graph's nodes.
        graph_repr = global_mean_pool(brain_features, batch)

        return F.log_softmax(self.classifier(graph_repr), dim=1)

### 4. Instantiate and Summarize Models
This cell demonstrates how to create instances of both GNN models. It selects the compute device (GPU if available, otherwise CPU), defines sample input dimensions, and prints each model's parameter count.

In [None]:
# Select the compute device: GPU when CUDA is available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Sample input dimensions (intended to match the per-ROI fMRI data features).
input_dim = 3
num_classes = 2

# Hidden widths used for the two example models.
BRAIN_HIDDEN_DIM = 128
HIERARCHICAL_HIDDEN_DIM = 64


def count_parameters(model):
    """Return the total number of parameters (trainable or not) in ``model``."""
    return sum(p.numel() for p in model.parameters())


# Instantiate the BrainConnectivityGNN model
model_brain = BrainConnectivityGNN(
    input_dim=input_dim,
    hidden_dim=BRAIN_HIDDEN_DIM,
    num_classes=num_classes,
).to(device)

# Instantiate the HierarchicalBrainGNN model
model_hierarchical = HierarchicalBrainGNN(
    input_dim=input_dim,
    hidden_dim=HIERARCHICAL_HIDDEN_DIM,
    num_classes=num_classes,
).to(device)

print("\nModels Created:")
print(f"   BrainConnectivityGNN: {count_parameters(model_brain):,} parameters")
print(f"   HierarchicalBrainGNN: {count_parameters(model_hierarchical):,} parameters")
print(f"   Input features per ROI: {input_dim}")