# %%
import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Data Preprocessing

# %%
def preprocess_data(raw_data):
    """
    Return the multimodal recording prepared for graph construction.

    This is currently a pass-through placeholder: the very same object that
    was passed in is returned, unmodified and uncopied. A real pipeline would
    typically add steps such as:
      - artifact removal and filtering
      - signal normalization and alignment
      - extraction of per-ROI time series
    """
    # TODO: replace this pass-through with the actual preprocessing pipeline.
    return raw_data

# Graph Construction

# %%
def construct_graph(processed_data, threshold=0.5):
    """
    Build an unweighted connectivity graph from per-ROI time series.

    Pearson correlations between ROI time series (rows) are used as a proxy
    for connectivity; two ROIs are linked when the absolute correlation is at
    least ``threshold``. Self-loops are removed.

    Parameters:
      processed_data: 2D NumPy array where each row is the time series of a ROI.
      threshold: Minimum absolute correlation (inclusive) for an edge.

    Returns:
      G: An undirected NetworkX graph built from the adjacency matrix.
      adjacency_matrix: The binary (0.0 / 1.0) thresholded adjacency matrix
        with a zero diagonal -- NOT the raw correlation matrix.
    """
    # Pearson correlation between rows (ROIs). np.atleast_2d keeps the
    # single-ROI case working: corrcoef then returns a 0-d value, which
    # fill_diagonal below would reject.
    connectivity_matrix = np.atleast_2d(np.corrcoef(processed_data))

    # A zero-variance ROI yields NaN correlations; treat those pairs as
    # unconnected explicitly instead of relying on NaN comparison semantics.
    finite = np.isfinite(connectivity_matrix)
    adjacency_matrix = np.zeros(connectivity_matrix.shape)
    adjacency_matrix[finite] = np.abs(connectivity_matrix[finite]) >= threshold
    np.fill_diagonal(adjacency_matrix, 0)

    # Construct an undirected graph from the binary adjacency matrix.
    G = nx.from_numpy_array(adjacency_matrix)
    return G, adjacency_matrix

def compute_graph_metrics(G):
    """
    Summarize ``G`` with graph metrics intended for use as features.

    Returns:
      A dict with two entries:
        'clustering': per-node local clustering coefficients, as returned
          by nx.clustering(G).
        'global_efficiency': the global efficiency of the whole graph, as
          returned by nx.global_efficiency(G).
    """
    # Values are evaluated left to right: clustering first, then efficiency.
    return {
        'clustering': nx.clustering(G),
        'global_efficiency': nx.global_efficiency(G),
    }

# Graph Neural Network

# %%
class GCNLayer(nn.Module):
    def __init__(self, in_features, out_features, activation=True):
        """
        A single GCN layer implementing:
          H_out = σ( A_norm * H_in * W )
        where A_norm is the symmetric normalized adjacency matrix and σ is
        ReLU when ``activation`` is True, or the identity otherwise.

        Parameters:
          in_features: Size of each input node feature vector.
          out_features: Size of each output node feature vector.
          activation: Apply ReLU to the output (default True). Set to False
            for an output layer whose result is used as raw class logits.
        """
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.activation = activation

    def forward(self, H, A_norm):
        # torch.mm needs 2-D operands; presumably H is (num_nodes, in_features)
        # and A_norm is (num_nodes, num_nodes) -- confirm against callers.
        # Aggregate neighbour features first, then project (W plus bias).
        H = torch.mm(A_norm, H)
        H = self.linear(H)
        return F.relu(H) if self.activation else H

class GCN(nn.Module):
    def __init__(self, in_features, hidden_features, num_classes):
        """
        A simple 2-layer GCN for node classification.

        The hidden layer uses ReLU; the output layer is linear, so ``forward``
        returns unconstrained per-node class logits suitable for
        nn.CrossEntropyLoss (which applies log-softmax itself).
        """
        super(GCN, self).__init__()
        self.gcn1 = GCNLayer(in_features, hidden_features)
        # No ReLU on the logits: clamping them to >= 0 stops the model from
        # expressing negative evidence, and wherever a pre-activation is
        # negative the ReLU passes zero gradient, so that logit stops learning.
        self.gcn2 = GCNLayer(hidden_features, num_classes, activation=False)

    def forward(self, H, A_norm):
        H = self.gcn1(H, A_norm)
        return self.gcn2(H, A_norm)

def normalize_adjacency(adjacency_matrix):
    """
    Compute the symmetric normalized adjacency matrix with self-loops:
      A_norm = D^(-1/2) (A + I) D^(-1/2)
    where D is the diagonal degree (row-sum) matrix of A + I.

    Parameters:
      adjacency_matrix: Square 2D array-like of (possibly weighted) edges.

    Returns:
      A_norm as a float32 torch tensor with the same shape as the input.

    Nodes whose degree in A + I is not positive (only possible with negative
    edge weights) get a zero scaling factor. Their rows and columns are zeroed
    instead of producing NaN or a singular-matrix error.
    """
    A = np.asarray(adjacency_matrix, dtype=float)
    A = A + np.eye(A.shape[0])
    degree = A.sum(axis=1)

    # D is diagonal, so D^(-1/2) is an elementwise power of the degrees;
    # there is no need to build and invert a dense N x N matrix.
    d_inv_sqrt = np.zeros_like(degree)
    positive = degree > 0
    d_inv_sqrt[positive] = degree[positive] ** -0.5

    # Scaling row i by d_i and column j by d_j equals D^(-1/2) A D^(-1/2).
    A_norm = d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]
    return torch.from_numpy(A_norm).float()

# Loss Function

# %%
def compute_loss(outputs, labels, class_weights, lateralization_term=0.0):
    """
    Combine class-weighted cross entropy with an extra lateralization penalty.

    Parameters:
      outputs: Raw, unnormalized class scores (logits) from the model.
      labels: Integer class index for each prediction.
      class_weights: Per-class rescaling weights for the cross entropy
        (None weights all classes equally).
      lateralization_term: Penalty added on top of the cross entropy
        (e.g. a symmetry constraint); 0.0 by default.

    Returns:
      The loss tensor: weighted-mean cross entropy plus lateralization_term.
    """
    # Functional form of nn.CrossEntropyLoss(weight=class_weights), using the
    # same defaults (mean reduction, ignore_index=-100, no label smoothing).
    weighted_ce = F.cross_entropy(outputs, labels, weight=class_weights)
    return weighted_ce + lateralization_term

# Training and Evaluation

# %%
def train_model(model, optimizer, A_norm, features, labels, num_epochs=100, class_weights=None):
    """
    Train the GCN model with full-batch gradient steps.

    Parameters:
      model: An instance of the GCN model.
      optimizer: Optimizer (e.g., Adam), presumably over ``model``'s parameters.
      A_norm: Normalized adjacency matrix (tensor).
      features: Node feature matrix (tensor).
      labels: True labels for each node (tensor).
      num_epochs: Number of training epochs.
      class_weights: Optional per-class weighting tensor for handling
        imbalance. None (the default) weights every class equally, for any
        number of classes. (The previous default, a tensor literal, was built
        once at import time, shared by all calls, fixed to exactly two
        classes and placed on the CPU.)

    Returns:
      model: The same model object that was passed in, after training.
    """
    model.train()
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        outputs = model(features, A_norm)

        # Placeholder: no lateralization penalty is applied yet.
        # Replace with your computed lateralization penalty if available.
        lateralization_term = 0.0
        loss = compute_loss(outputs, labels, class_weights, lateralization_term)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch+1}/{num_epochs}: Loss = {loss.item():.4f}")
    return model

def evaluate_model(model, A_norm, features, labels):
    """
    Score the model's node predictions against ``labels``.

    Puts the model in eval mode, runs one forward pass without gradient
    tracking, takes the arg-max class per row, and returns the fraction of
    predictions equal to the labels as a Python float.
    """
    model.eval()
    with torch.no_grad():
        predicted_class = model(features, A_norm).argmax(dim=1)
        correct = predicted_class.eq(labels)
        return correct.float().mean().item()

# Main

# %%
if __name__ == "__main__":
    # Example: Assume raw_data is loaded as a NumPy array (rows: ROIs, columns: time points)
    # For demonstration purposes, we create synthetic data.
    # Seed NumPy and PyTorch so the synthetic data, the random labels and the
    # model's weight initialization are identical on every run.
    seed = 0
    np.random.seed(seed)
    torch.manual_seed(seed)

    num_rois = 50  # number of brain regions (nodes)
    time_points = 120  # number of time points per region
    raw_data = np.random.rand(num_rois, time_points)

    # 1. Preprocess Data
    processed_data = preprocess_data(raw_data)

    # 2. Construct Graph
    # construct_graph's second return value is the binary thresholded
    # adjacency matrix (not the raw correlations), so name it as such.
    # NOTE: independent uniform noise gives near-zero pairwise correlations,
    # so threshold=0.6 will very likely produce few or no edges here; use
    # real recordings or a lower threshold for a meaningful graph.
    G, adjacency_matrix = construct_graph(processed_data, threshold=0.6)
    metrics = compute_graph_metrics(G)
    print("Graph Metrics:", metrics)

    # 3. Prepare Graph Neural Network inputs
    A_norm = normalize_adjacency(adjacency_matrix)

    # Create dummy node features (for example, using the processed data's statistics)
    # Here we use the mean of each ROI time series as a feature.
    features_np = np.mean(processed_data, axis=1, keepdims=True)
    features = torch.from_numpy(features_np).float()

    # Dummy labels for each node (for example, two classes: epileptic focus vs. non-focus)
    labels = torch.randint(0, 2, (num_rois,))

    # 4. Instantiate GCN Model
    in_features = features.shape[1]
    hidden_features = 16
    num_classes = 2
    model = GCN(in_features, hidden_features, num_classes)

    # 5. Set up optimizer
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    # 6. Train the Model
    model = train_model(model, optimizer, A_norm, features, labels, num_epochs=50)

    # 7. Evaluate the Model
    # Evaluation reuses the training nodes and labels, so this is in-sample
    # (training) accuracy, not an estimate of generalization.
    acc = evaluate_model(model, A_norm, features, labels)
    print(f"Training Accuracy: {acc*100:.2f}%")