# WoMAD Development Notebook

### Setup and Configurations

In [2]:
# Install basic dependencies
!pip install --quiet requests nilearn nibabel brainspace numpy pandas scikit-learn torch torch-geometric scipy rich>=13.5.2
!pip install --quiet optuna

# Install torch_geometric and its dependencies
!pip install --quiet torch_geometric torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu121.html

import os
import io
import time
import zipfile
from typing import List, Dict, Tuple, Any, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.data as PyG_Data
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd
import nibabel as nib

import shap   # For result interp module
import pytest # For TESTS module
import optuna # For hyperparameter module

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m413.9/413.9 kB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m108.0/108.0 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m210.0/210.0 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.5/54.5 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for torch_scatter (setup.py) ... [?25l[?25hdone
  Building wheel for torch_sparse (setup.py) ... [?25l[?25hdone
  Building wheel for torch_cluster (setup.py) ... [?25l[?25hdone
  Building wheel for torch_spline_conv (setup.py) ... [?25l[?25hdone


In [4]:
# Project Paths:
# This cell defines notebook-wide globals (paths, task lists, model configs) that
# later cells use as default argument values, so it must run before them.
import os
from google.colab import drive; drive.mount('/content/drive') # Temporary for Colab
project_root = "/content/drive/MyDrive/WoMAD/WoMAD/Notebooks/data" # Google Colab-specific (different in the .py files and the repo)

# Sub-directories of project_root for the different stages of the data.
unprocessed_path = os.path.join(project_root, "HCP_zipped")
processed_path   = os.path.join(project_root, "processed")
model_ready_path = os.path.join(project_root, "model_ready")

# Text file holding the subject IDs that are parsed further below.
sub_list_txt_path = os.path.join(project_root, "full_3T_task_subjects.txt")

# WoMAD-specific variables:
target_tasks = ["WM", "EMOTION", "LANGUAGE"]

# Subtask names per task; each name is used as an EV file name ("<subtask>.txt")
# by load_data_from_path().
target_subtasks = {
    "WM"      : ["0bk_body", "0bk_faces", "0bk_places", "0bk_tools",
                 "2bk_body", "2bk_faces", "2bk_places", "2bk_tools"],
    "EMOTION" : ["fear", "neut"],
    "LANGUAGE": ["math", "story"],
}

# Same value as the default `TR` argument of isolate_trials();
# presumably the fMRI repetition time in seconds — TODO confirm.
TR = 0.72

rest_tasks    = ["REST1", "REST2"]
run_direction = ["LR"   , "RL"]

# Subjects with full 3T imaging protocol completed:
full_3T_task_subjects = []

# The file is read as one comma-separated list of integer subject IDs;
# empty entries (e.g. from a trailing comma) are skipped.
with open(sub_list_txt_path, "r") as file:
    raw_list = file.read()
    str_list = raw_list.strip().split(",")
    num_list = [int(subID.strip()) for subID in str_list if subID.strip()]

full_3T_task_subjects = num_list

# Model config dictionaries
lstm_config = {
    "hidden_size" : 128,
    "num_layers"  :   2,
    "dropout"     : 0.2
}

fusion_config = {
    "total_input_feats" : 736, # Temporary: 360 parcels, 128 from the LSTM, and 248 guessing from the 4D network
    "hidden_size"       : 128
}

# The two weights below sum to 1.0; presumably they balance the overall and
# per-node loss terms during training — TODO confirm against the training loop.
training_loss_weights = {
    "overall_loss_weight" : 0.5,
    "node_loss_weight"    : 0.5
}

# Randomly selected subject list for the pilot:
pilot_subjects = [283543, 180937, 379657, 145632, 100206, 270332, 707749, 454140, 194847, 185038]

# Temporary variables for development:
# (used as the default subject list by generate_paths() and WoMAD_data)
dev_subjects = [100206, 100408]

Mounted at /content/drive


In [5]:
# Granger causality analysis (GCA) settings
gca_config = dict(
    max_lag = 1,
    significance_level = 0.05,
    pairwise = True,
)

# Hidden Markov model (HMM) settings
hmm_config = dict(
    n_states = 3,
    n_iter = 100,
    covariance_type = "full",
)


### Data Module

In [6]:
# File Access:
def generate_paths(task : str = target_tasks[0],
                   run  : str = run_direction[0],
                   sub_list : list = dev_subjects):
    """
    Builds the per-subject results directory and EV directory paths.

    Arguments:
        task             (str): Target HCP task; defaults to the first entry of target_tasks ("WM")
        run              (str): Run direction; defaults to the first entry of run_direction ("LR")
        sub_list (list of int): Subject IDs to build paths for; defaults to dev_subjects

    Returns:
        Dictionary keyed by subject ID, each value a tuple of two strings:
        paths = {
            subject_id (int) : ("Main 'Results' directory path", "EV directory path")
        }

    Path layout:
        Results : ../data/HCP_zipped/{subject-ID}/MNINonLinear/Results/
        EVs     : ../data/HCP_zipped/{subject-ID}/MNINonLinear/Results/tfMRI_{TASK}_{RUN}/EVs/
    """
    # NOTE(review): the base directory is a hard-coded relative path and does not use
    # `unprocessed_path` from the configuration cell — confirm which one is intended.
    return {
        subject: (
            f"../data/HCP_zipped/{subject}/MNINonLinear/Results/",
            f"../data/HCP_zipped/{subject}/MNINonLinear/Results/tfMRI_{task}_{run}/EVs/",
        )
        for subject in sub_list
    }


def load_data_from_path(task : str = target_tasks[0],
                        run  : str = run_direction[0],
                        subject : int = dev_subjects[0],
                        subtask : str = target_subtasks["WM"][0]):
    """
    Reads one subject's BOLD time series and one EV file for a task/run/subtask.

    Arguments:
        task    (str): Target HCP task; defaults to the first entry of target_tasks
        run     (str): Run direction; defaults to the first entry of run_direction
        subject (int): ID of specific target subject
        subtask (str): The target subtask in string format -> Example: "0bk_tools"

    Returns:
        On success, a tuple (fmri_timeseries, ev_file) where
            fmri_timeseries = {subject: (nibabel image, image header, data array)}
            ev_file         = raw text content of "<subtask>.txt"
        On any loading error, a message with the cause is printed and None is
        returned, so callers must check for None before unpacking.
    """
    try:
        # Build the paths for the requested subject explicitly. Relying on the default
        # subject list of generate_paths() raised a KeyError for every subject that was
        # not part of dev_subjects.
        paths = generate_paths(task = task, run = run, sub_list = [subject])
        results_dir, ev_dir = paths[subject]

        bold_ts_path = results_dir + f"tfMRI_{task}_{run}/tfMRI_{task}_{run}_Atlas_MSMAll_hp0_clean_rclean_tclean.dtseries.nii"
        ev_file_path = ev_dir + f"{subtask}.txt"

        with open(ev_file_path, "r") as ev:
            ev_file = ev.read()

        bold_ts = nib.load(bold_ts_path)
        # NOTE(review): for CIFTI dtseries files nibabel presumably returns the data as
        # (timepoints, grayordinates), while isolate_trials() slices axis 1 as time —
        # confirm the orientation and transpose if needed.
        bold_data = bold_ts.get_fdata()
        bold_header = bold_ts.header

        fmri_timeseries = {subject: (bold_ts, bold_header, bold_data)}

        return fmri_timeseries, ev_file

    except Exception as e:
        # Best-effort loader: report what failed (and why) instead of hiding the cause.
        print("Error loading time series and EV files from path!")
        print(f"  subject={subject}, task={task}, run={run}, subtask={subtask} -> {type(e).__name__}: {e}")
        return None


# Preprocessing:
## Parse and Isolate Trials
def isolate_trials(fmri_ts, ev_file, TR : float = 0.72):
    """
    Parses through the data and isolates each task trial using EV files.

    Arguments:
        fmri_ts (dict): {subject: (image, header, data)} as returned by load_data_from_path().
                        `data` is sliced along axis 1, i.e. it is treated as
                        (nodes, timepoints).
        ev_file  (str): Text content of an EV file. Each row must hold at least
                        two numeric columns: onset and duration. Any further
                        columns are ignored.
        TR     (float): Repetition time used to convert onsets/durations into
                        sample indices (same unit as the EV file values).

    Returns:
        List of {"subject", "onset", "duration", "data"} dictionaries, one per EV row.
        Subjects whose EV content cannot be parsed are skipped with a message.
    """
    trial_list = []

    for subject, (bold_ts, bold_header, bold_data) in fmri_ts.items():
        try:
            # ndmin=2 keeps a single-row EV file as shape (1, n_columns). Without it,
            # np.loadtxt returns a 1-D array and the row unpacking below fails.
            ev_data = np.loadtxt(io.StringIO(ev_file), ndmin = 2)
        except ValueError:
            print(f"Could not parse EV file for subject {subject}.")
            continue

        # An empty EV file simply contributes no trials.
        if ev_data.size == 0:
            continue

        # Onset and duration are both required to locate a trial.
        if ev_data.shape[1] < 2:
            print(f"Could not parse EV file for subject {subject}.")
            continue

        for onset, duration in ev_data[:, :2]:
            # Round outwards so the window fully covers the trial.
            start_idx = int(np.floor(onset / TR))
            end_idx = int(np.ceil((onset + duration) / TR))

            trial_list.append({
                "subject" : subject,
                "onset" : onset,
                "duration" : duration,
                "data" : bold_data[:, start_idx:end_idx]
            })

    return trial_list

## Normalization
def normalize_data(data : np.ndarray, norm_mode: str = "z_score"):
    """
    Row-wise normalization of a 2D time series array.

    Arguments:
        data (np.ndarray): Time series data; statistics are taken along axis 1,
                           i.e. each row is normalized independently.
        norm_mode   (str): "z_score" (subtract row mean, divide by row std) or
                           "min_max" (rescale each row to the [0, 1] range).

    Returns:
        Numpy array with the normalized data. Rows without any spread (std or
        range equal to zero) are divided by 1 instead of 0. For an unknown
        norm_mode a message is printed and a copy of the input is returned.
    """
    values = np.array(data)

    if norm_mode not in ("z_score", "min_max"): # For now ...
        print(f"Normalization mode '{norm_mode}' not defined.\nReturning data as is.")
        return values

    if norm_mode == "z_score":
        offset = values.mean(axis = 1, keepdims = True)
        spread = values.std(axis = 1, keepdims = True)
    else:
        offset = values.min(axis = 1, keepdims = True)
        spread = values.max(axis = 1, keepdims = True) - offset

    # Constant rows have zero spread; divide those by 1 to avoid division by zero.
    safe_spread = np.where(spread == 0, 1.0, spread)

    return (values - offset) / safe_spread

## Save to Pandas DataFrame
def save_to_df(trial_list : List[Dict[str, Any]],
               file_name : str,
               output_dir : str = processed_path):
    """
    Builds a Pandas DataFrame from the isolated trials and pickles it to disk.

    Arguments:
        trial_list (list): List of {"subject", "onset", "duration", "data"} dictionaries.
        file_name   (str): Name of output file (without the ".pkl" extension).
        output_dir  (str): Directory for saving the output file; created if missing.

    Returns:
        The pd.DataFrame that was written to "<output_dir>/<file_name>.pkl".
    """
    trials_frame = pd.DataFrame(trial_list)

    # Make sure the target directory exists before writing.
    os.makedirs(output_dir, exist_ok = True)

    save_path = os.path.join(output_dir, f"{file_name}.pkl")
    trials_frame.to_pickle(save_path)
    print(f"Data saved to {save_path}")

    return trials_frame


# Initial Processing (with the WoMAD_data class):
class WoMAD_data(Dataset):
    """
    Dataset of isolated task trials.

    Each entry of `self.data` is a trial dictionary. The methods below add keys to it:
        _load_data()             -> "subject", "onset", "duration", "data", "task", "run", "subtask"
        basic_processing()       -> normalized "data", "norm_mode"
        calc_func_connectivity() -> "corr_matrix"
        calc_basic_stats()       -> "stats" (required before indexing the dataset)
    """
    def __init__(self,
                 task : str,
                 runs : list = run_direction,
                 subjects : list = dev_subjects,
                 output_dir : str = processed_path):
        """
        Basic configuration of the dataset.

        Arguments:
            task (str): The target task ("WM")
            runs (list): Run directions
            subjects (list): List of target subject IDs
            output_dir (str): Directory for saving processed data
        """
        self.task = task
        self.runs = runs
        self.subjects = subjects
        self.output_dir = output_dir

        # List of trial dictionaries; filled by _load_data().
        self.data = []

    def __len__(self) -> int:
        """
        Returns the number of isolated trials currently held in the dataset.
        """
        # The dataset is indexed per trial (see __getitem__), so its length is the
        # number of trials. The previous implementation read a non-existent attribute.
        return len(self.data)

    def __getitem__(self, indx: int):
        """
        Loads one trial's data and returns:
            - input timeseries (X)
            - overall target (Y_overall)
            - node target (Y_node)

        Raises:
            KeyError: If calc_basic_stats() has not been run for this trial.
        """
        trial = self.data[indx]

        if "stats" not in trial:
            raise KeyError("Trial has no 'stats' entry; run calc_basic_stats() before indexing the dataset.")

        # Input data with shape (target_nodes, timepoints)
        data_np = trial["data"]
        data_tensor = torch.from_numpy(data_np).float()

        # Overall target:
        # TODO: Add overall target components.

        # Placeholder:
        overall_target_np = trial["stats"]["trial_mean"]
        overall_target_tensor = torch.tensor([overall_target_np]).float()

        # Node target with shape (target_nodes)
        node_target_np = trial["stats"]["mean_per_node"]
        node_target_tensor = torch.from_numpy(node_target_np).float()

        return data_tensor, overall_target_tensor, node_target_tensor

    def _load_data(self):
        """
        Load and parse the data using NIfTI and EV files.

        Iterates over every subject, run and subtask of the configured task.
        Combinations whose files cannot be loaded are reported and skipped.

        Returns:
            The list of trial dictionaries (also stored in self.data).
        """
        parsed_data = []

        subtasks = target_subtasks.get(self.task, [])

        for subject in self.subjects:
            for run in self.runs:
                for subtask in subtasks:
                    loaded = load_data_from_path(task = self.task,
                                                 run = run,
                                                 subject = subject,
                                                 subtask = subtask)

                    # The loader reports failures by returning None.
                    if loaded is None:
                        print(f"Skipping subject {subject}, run {run}, subtask {subtask}: data could not be loaded.")
                        continue

                    fmri_ts, ev_file = loaded
                    trial_list_subtask = isolate_trials(fmri_ts, ev_file)

                    for trial_dict in trial_list_subtask:
                        trial_dict["task"] = self.task
                        trial_dict["run"] = run
                        trial_dict["subtask"] = subtask
                        parsed_data.append(trial_dict)

        self.data = parsed_data

        return self.data

    def basic_processing(self, norm_mode : str = "z_score",
                         file_name : str = "processed_fMRI_data"):
        """
        Normalization and saving with normalize_data() and save_to_df().

        Arguments:
            norm_mode (str): Normalization mode passed to normalize_data().
            file_name (str): Prefix of the output file; the task name and the
                             normalization mode are appended to it.

        Returns:
            The DataFrame returned by save_to_df() (also stored in self.processed_df).
        """
        processed_trials = []

        # Trials are updated in place: "data" is replaced by its normalized version.
        for trial in self.data:
            normalized_trial = normalize_data(trial["data"], norm_mode = norm_mode)

            trial["data"] = normalized_trial
            trial["norm_mode"] = norm_mode
            processed_trials.append(trial)

        file_to_save_processed_data = f"{file_name}_{self.task}_{norm_mode}"
        self.processed_df = save_to_df(processed_trials,
                                       file_to_save_processed_data,
                                       self.output_dir)

        self.data = processed_trials

        return self.processed_df

    def _calc_corr_matrix(self, trial_data: np.ndarray) -> dict:
        """
        Calculates whole-brain and network-level correlation matrices for a single isolated trial.

        NOTE:   Network-level correlations require network masks
                or specific voxel/parcel definitions.
                These have not yet been defined in the config files.

        Arguments:
            trial_data (np.ndarray): The isolated trial's time series data
                                     (rows are correlated with each other).

        Returns:
            A dictionary with "whole_brain" and "network_level" entries. Both currently
            refer to the same whole-brain matrix; NaN correlations are replaced by 0.
        """
        # All voxels (whole_brain)
        whole_brain_corr_mat = np.corrcoef(trial_data)

        # Rows without variance produce NaN correlations.
        whole_brain_corr_mat[np.isnan(whole_brain_corr_mat)] = 0

        # TODO: Create network-level correlation

        return {
            "whole_brain"  : whole_brain_corr_mat,
            "network_level": whole_brain_corr_mat   # Temp solution until the network-level is defined.
        }

    def calc_func_connectivity(self):
        """
        Calculates the correlation matrices for all isolated trials.

        Trials without a "data" entry are skipped.

        Returns:
            The updated list of trial dictionaries with the correlation matrices added.
        """
        for trial in self.data:
            trial_ts_data = trial.get("data") # Should be normalized

            # Same guard as in calc_basic_stats(): nothing to correlate without data.
            if trial_ts_data is None:
                continue

            trial["corr_matrix"] = self._calc_corr_matrix(trial_ts_data)

        return self.data

    def calc_basic_stats(self):
        """
        Calculate basic statistics (mean and std) of
        the activity for isolated trials across all nodes/voxels.

        Returns:
            Updated list of trial dictionaries with "stats" key added.
        """
        for trial in self.data:
            trial_ts_data = trial.get("data")

            if trial_ts_data is None:
                continue

            # Mean activity
            node_mean_activity = np.mean(trial_ts_data, axis = 1)

            # Standard deviation
            node_std_activity = np.std(trial_ts_data, axis = 1)

            # Overall average for entire trial
            trial_mean_activity = np.mean(node_mean_activity)
            trial_mean_of_std   = np.mean(node_std_activity)

            trial["stats"] = {
                "mean_per_node": node_mean_activity,
                "std_per_node" : node_std_activity,
                "trial_mean"   : trial_mean_activity,
                "overall_std"  : trial_mean_of_std
            }

        return self.data

    def visualize_correlations(self,
                               target_trial_idx: int = 0,
                               matrix_type: str = "whole_brain"):
        """
        Generate a heatmap visualization for one trial's correlation matrix.

        Arguments:
            target_trial_idx (int): Index of the trial in self.data
            matrix_type      (str): The type of matrix to plot
                                    "whole_brain" or "network_level"

        Raises:
            ValueError: If the trial has no correlation matrices yet or
                        matrix_type is not one of the stored matrices.
        """
        # NOTE(review): `plt` and `sns` are not imported in the notebook's setup cell;
        # presumably matplotlib.pyplot and seaborn — confirm they are imported before use.
        trial = self.data[target_trial_idx]
        corr_matrices = trial.get("corr_matrix")

        if corr_matrices is None:
            raise ValueError("No correlation matrices found for this trial; run calc_func_connectivity() first.")

        matrix_to_plot = corr_matrices.get(matrix_type)

        if matrix_to_plot is None:
            raise ValueError(f"Unknown matrix_type '{matrix_type}'. Available: {sorted(corr_matrices)}")

        plt.figure(figsize = (10, 8))
        sns.heatmap(matrix_to_plot,
                    cmap = "RdBu_r",
                    vmin = -1, vmax = 1,
                    square = True,
                    cbar_kws = {"label" : "Pearson Correlation Coefficient"})

        subject_id = trial.get("subject", "N/A")

        # Trials loaded before the "task" key existed fall back to the dataset's task.
        task = trial.get("task", self.task)
        subtask = trial.get("subtask", "N/A")

        plt.title(f"{matrix_type.title()} Correlation Matrix\nSubject: {subject_id}, {task}: {subtask}, Trial {target_trial_idx}")
        plt.xlabel("Node/Voxel Index")
        plt.ylabel("Node/Voxel Index")

        plt.show()

# TODO: Add function for "validation set processing" which can process non-HCP data.

### Model Setup Module

### The Dynamic Input module

In [7]:
# Dynamic Adapter
TARGET_NODE_COUNT = 360       # 360 parcels based on HCP and Glasser parcellation

class DynamicInput(nn.Module):
    """
    Input adapter for the WoMAD modules (Info flow or Core): brings the node
    dimension of any incoming batch to a fixed size, whatever number of
    voxels/nodes the input has.
    """
    def __init__(self, target_nodes: int = TARGET_NODE_COUNT):
        super().__init__()
        self.target_nodes = target_nodes

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Resizes the node dimension of the input to self.target_nodes.

        Input:
            x : 3D tensor with shape (batch, current_nodes, timepoints)

        Output:
            3D tensor with shape (batch, target_nodes, timepoints):
              - same node count     -> the input tensor itself
              - more nodes than target  -> adaptive average pooling over the node axis
              - fewer nodes than target -> zero padding appended after the last node
        """
        _, current_nodes, _ = x.shape
        missing_nodes = self.target_nodes - current_nodes

        # Already the right size: nothing to adapt.
        if missing_nodes == 0:
            return x

        # Too few nodes: pad the node axis at its end. F.pad reads the pad tuple from
        # the last dimension backwards, so the first pair (timepoints) stays (0, 0).
        if missing_nodes > 0:
            return F.pad(x, (0, 0, 0, missing_nodes))

        # Too many nodes.
        # TODO: Add voxel-to-parcel downsampling (requires a parcellation map)
        # Current solution: adaptive average pooling, which works on the last axis,
        # so the node axis is moved there and back.
        nodes_last = x.transpose(1, 2)
        pooled = F.adaptive_avg_pool1d(nodes_last, self.target_nodes)

        return pooled.transpose(1, 2)

#### Info flow submodule


In [8]:
#GCA
from scipy import stats
from typing import Tuple, Dict, Optional, Any
import warnings


class GrangerCausalityAnalysis:
    """
    Granger causality analysis (GCA) between the rows of a multi-node time series.

    For every ordered pair (source, target) an F-test compares a "full" lagged
    linear model of the target with a "reduced" model that leaves out the lags
    of the source. Two model definitions are available:

    - conditional (default): full model = lags of ALL nodes,
      reduced model = lags of all nodes except the source.
    - pairwise: full model = lags of target and source,
      reduced model = lags of the target only.

    Results are stored on the instance (adjacency_matrix, f_statistics, p_values)
    with the convention matrix[source, target].
    """

    def __init__(self,
                 max_lag: int = 1,
                 significance_level: float = 0.05,
                 pairwise: bool = False,
                 config: Optional[dict] = None):
        """
        Initialize GCA analysis.

        Args:
            max_lag: Maximum lag for VAR model (default=1 for fMRI)
            significance_level: P-value threshold for significance testing
            pairwise: If True, test each pair with a bivariate model instead of
                      conditioning on all other nodes
            config: Optional dictionary with any of the keys "max_lag",
                    "significance_level" and "pairwise". Values found in it
                    take precedence over the corresponding arguments.

        Raises:
            ValueError: If max_lag is smaller than 1.
        """
        if config is not None:
            max_lag = config.get("max_lag", max_lag)
            significance_level = config.get("significance_level", significance_level)
            pairwise = config.get("pairwise", pairwise)

        if max_lag < 1:
            raise ValueError("max_lag must be a positive integer")

        self.max_lag = int(max_lag)
        self.significance_level = significance_level
        self.pairwise = bool(pairwise)
        self.adjacency_matrix = None
        self.f_statistics = None
        self.p_values = None

    def _create_lagged_data(self,
                            data: np.ndarray,
                            lag: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Create lagged versions of time series for VAR modeling.

        Args:
            data: Time series (n_nodes, n_timepoints)
            lag: Number of time lags

        Returns:
            X: Lagged predictors (n_samples, n_nodes * lag). Columns are grouped by
               lag: block i (0-based) holds lag i+1 of every node, so the lags of
               node j are found in columns j + i * n_nodes.
            Y: Target values (n_samples, n_nodes)
        """
        n_nodes, n_timepoints = data.shape

        # Target: values at time t
        Y = data[:, lag:].T  # (n_samples, n_nodes)

        # Predictors: values at time t-1, t-2, ..., t-lag
        X = np.hstack([data[:, lag - l:n_timepoints - l].T
                       for l in range(1, lag + 1)])  # (n_samples, n_nodes * lag)

        return X, Y

    def _fit_var_model(self,
                       X: np.ndarray,
                       y: np.ndarray) -> Tuple[Optional[np.ndarray], float]:
        """
        Fit VAR model using OLS regression.

        Args:
            X: Predictor matrix
            y: Target vector

        Returns:
            coefficients: Regression coefficients (intercept first), or None if
                          the fit failed
            residual_variance: Variance of residuals, or np.inf if the fit failed
        """
        # Add intercept
        X_with_intercept = np.column_stack([np.ones(X.shape[0]), X])

        # OLS solution: beta = (X'X)^(-1) X'y (pseudo-inverse tolerates collinearity)
        try:
            XtX_inv = np.linalg.pinv(X_with_intercept.T @ X_with_intercept)
            coefficients = XtX_inv @ X_with_intercept.T @ y

            # Calculate residuals
            y_pred = X_with_intercept @ coefficients
            residuals = y - y_pred
            residual_variance = np.var(residuals)

            return coefficients, residual_variance
        except np.linalg.LinAlgError:
            return None, np.inf

    def _granger_causality_test(self,
                                X_full: np.ndarray,
                                X_reduced: np.ndarray,
                                y: np.ndarray,
                                rss_full: Optional[float] = None) -> Tuple[float, float]:
        """
        Perform F-test for Granger causality.

        Args:
            X_full: Full model predictors (includes candidate cause)
            X_reduced: Reduced model predictors (excludes candidate cause)
            y: Target time series
            rss_full: Optional pre-computed result of _fit_var_model(X_full, y)[1].
                      Lets callers fit a full model once and reuse it across tests.

        Returns:
            f_statistic: F-test statistic
            p_value: P-value of the test
            (0.0, 1.0) is returned whenever the test cannot be computed.
        """
        n_samples = len(y)

        # Degrees of freedom
        df_full = X_full.shape[1] + 1  # +1 for intercept
        df_reduced = X_reduced.shape[1] + 1
        df_diff = df_full - df_reduced
        df_residual = n_samples - df_full

        # Checked before fitting: without residual degrees of freedom no fit is needed.
        if df_residual <= 0 or df_diff <= 0:
            return 0.0, 1.0

        # Fit full model (unless the caller already did)
        if rss_full is None:
            _, rss_full = self._fit_var_model(X_full, y)

        # Fit reduced model
        _, rss_reduced = self._fit_var_model(X_reduced, y)

        if rss_full == np.inf or rss_reduced == np.inf:
            return 0.0, 1.0

        # F-statistic. _fit_var_model returns residual variances rather than sums of
        # squares; both models share the same sample count, so the factor cancels.
        if rss_full > 0:
            f_stat = ((rss_reduced - rss_full) / df_diff) / (rss_full / df_residual)
            f_stat = max(0.0, f_stat)  # Ensure non-negative
        else:
            f_stat = 0.0

        # P-value: the survival function keeps precision for very small p-values,
        # where 1 - cdf would round to 0.
        p_value = float(stats.f.sf(f_stat, df_diff, df_residual))

        return f_stat, p_value

    def compute_connectivity(self,
                            data: np.ndarray,
                            verbose: bool = False) -> Dict[str, np.ndarray]:
        """
        Compute Granger causality for all ordered node pairs.

        Args:
            data: fMRI time series (n_nodes, n_timepoints)
            verbose: Print progress after each target node

        Returns:
            Dict with adjacency matrix, F-statistics, and p-values, each of shape
            (n_nodes, n_nodes) and indexed as [source, target]. If the series is
            too short for the model, no pair is marked as significant.

        Raises:
            ValueError: If data is not two-dimensional.
        """
        data = np.asarray(data, dtype=float)
        if data.ndim != 2:
            raise ValueError("data must be a 2D array of shape (n_nodes, n_timepoints)")

        n_nodes, n_timepoints = data.shape

        # Initialize output matrices ("no effect" everywhere)
        self.f_statistics = np.zeros((n_nodes, n_nodes))
        self.p_values = np.ones((n_nodes, n_nodes))

        # At least one sample must remain after removing the lagged start.
        can_test = n_nodes > 1 and n_timepoints > self.max_lag

        if can_test:
            # Create lagged data
            X_all, Y = self._create_lagged_data(data, self.max_lag)
            n_samples = X_all.shape[0]

            # The conditional model uses every column of X_all. Without residual
            # degrees of freedom every test would return (0.0, 1.0), so skip the fits.
            if not self.pairwise and n_samples - (X_all.shape[1] + 1) <= 0:
                can_test = False

        if can_test:
            # Column offsets of the lag blocks inside X_all.
            lag_offsets = np.arange(self.max_lag) * n_nodes

            # Test each directed pair
            for target in range(n_nodes):
                y = Y[:, target]

                if self.pairwise:
                    # Reduced model: only the target's own past.
                    X_own = X_all[:, target + lag_offsets]
                    rss_full = None
                else:
                    # Full model: all nodes predict target. It is identical for
                    # every source, so it is fitted once per target.
                    _, rss_full = self._fit_var_model(X_all, y)

                for source in range(n_nodes):
                    if source == target:
                        continue

                    source_cols = source + lag_offsets

                    if self.pairwise:
                        X_full = np.hstack([X_own, X_all[:, source_cols]])
                        f_stat, p_val = self._granger_causality_test(X_full, X_own, y)
                    else:
                        # Reduced model: exclude source node
                        X_reduced = np.delete(X_all, source_cols, axis=1)
                        f_stat, p_val = self._granger_causality_test(X_all, X_reduced, y,
                                                                     rss_full=rss_full)

                    self.f_statistics[source, target] = f_stat
                    self.p_values[source, target] = p_val

                if verbose:
                    print(f"GCA: finished target node {target + 1}/{n_nodes}")

        # Create binary adjacency matrix based on significance
        self.adjacency_matrix = (self.p_values < self.significance_level).astype(float)
        np.fill_diagonal(self.adjacency_matrix, 0)

        return {
            'adjacency': self.adjacency_matrix,
            'f_statistics': self.f_statistics,
            'p_values': self.p_values
        }

    def get_hub_nodes(self, top_k: int = 10) -> Dict[str, Any]:
        """
        Identify hub nodes based on connectivity.

        Args:
            top_k: Number of top hubs to return (clamped to [0, n_nodes])

        Returns:
            Dict with
              'hub_nodes_out' / 'hub_nodes_in': node indices (lists of int) sorted by
                  decreasing out-/in-degree
              'out_degree' / 'in_degree': degree arrays for all nodes

        Raises:
            ValueError: If compute_connectivity() has not been run yet.
        """
        if self.adjacency_matrix is None:
            raise ValueError("Run compute_connectivity first")

        out_degree = np.sum(self.adjacency_matrix, axis=1)
        in_degree = np.sum(self.adjacency_matrix, axis=0)

        # Clamp on both sides. Slicing the descending order from the front also avoids
        # the `[-0:]` pitfall, which returned every node for top_k == 0.
        top_k = max(0, min(top_k, len(out_degree)))

        return {
            'hub_nodes_out': np.argsort(out_degree)[::-1][:top_k].tolist(),
            'hub_nodes_in': np.argsort(in_degree)[::-1][:top_k].tolist(),
            'out_degree': out_degree,
            'in_degree': in_degree
        }


def compute_gca_for_trial(trial_data: np.ndarray,
                          max_lag: int = 1,
                          significance_level: float = 0.05) -> Dict:
    """
    One-call wrapper: runs GrangerCausalityAnalysis on a single trial and
    bundles the result matrices with a few summary numbers.

    Args:
        trial_data: fMRI data (n_nodes, n_timepoints)
        max_lag: VAR model lag
        significance_level: P-value threshold

    Returns:
        Dict with 'adjacency', 'f_statistics', 'p_values' and a 'summary' dict
        holding the number of significant connections, the connection density
        (significant / possible directed connections) and the hub node lists.
    """
    analyzer = GrangerCausalityAnalysis(max_lag=max_lag,
                                        significance_level=significance_level)
    connectivity = analyzer.compute_connectivity(trial_data)
    hubs = analyzer.get_hub_nodes()

    adjacency = connectivity['adjacency']
    node_count = adjacency.shape[0]

    n_significant = int(np.sum(adjacency))
    # Directed graph without self-connections.
    n_possible = node_count * (node_count - 1)

    if n_possible > 0:
        density = n_significant / n_possible
    else:
        density = 0

    summary = {
        'n_significant': n_significant,
        'density': density,
        'hub_nodes_out': hubs['hub_nodes_out'],
        'hub_nodes_in': hubs['hub_nodes_in']
    }

    return {
        'adjacency': adjacency,
        'f_statistics': connectivity['f_statistics'],
        'p_values': connectivity['p_values'],
        'summary': summary
    }

In [10]:
#HMM
!pip install --quiet hmmlearn
from hmmlearn import hmm

from sklearn.decomposition import PCA

def run_hmm_stating(trial_data: np.ndarray,
                    n_states: int = 3,
                    n_iter: int = 100,
                    n_pca_components: int = None,
                    covariance_type: str = "full",
                    random_state: int = 42) -> Dict:
    """
    Identifies hidden brain states using Hidden Markov Model.

    Args:
        trial_data (np.ndarray): fMRI data (n_nodes, n_timepoints).
        n_states: Number of hidden states to detect.
        n_iter: Number of EM iterations.
        n_pca_components: Number of principal components to use for PCA.
                          If None, the data is reduced automatically to at most
                          min(n_timepoints - 1, n_nodes, 20) components.
        covariance_type: Covariance type passed to the Gaussian HMM
                         (default "full").
        random_state: Seed for the HMM initialization, for reproducible results.

    Returns:
        Dict with state_sequence, state_means, transition_matrix, n_states,
        pca_used and n_features_used. When pca_used is True, state_means are
        expressed in PCA component space, not in node space.
    """

    #HMM expects (n_samples, n_features) = (timepoints, nodes)
    X = trial_data.T
    n_samples, n_features = X.shape

    # Use PCA if features > samples to avoid degenerate solution
    pca_used = False
    if n_pca_components is None:
        # Keep n_features such that model parameters < data points.
        # Clamped to at least 1 so very short trials do not request <= 0 components.
        max_components = max(1, min(n_samples - 1, n_features, 20))
        if n_features > max_components:
            n_pca_components = max_components

    if n_pca_components and n_pca_components < n_features:
        pca = PCA(n_components=n_pca_components, random_state=42)
        X = pca.fit_transform(X)
        pca_used = True

    # Build the model once. Previously a second constructor call replaced the first
    # one and silently dropped the n_iter argument and the random seed.
    model = hmm.GaussianHMM(
        n_components=n_states,
        covariance_type=covariance_type,
        n_iter=n_iter,
        random_state=random_state
    )

    model.fit(X)

    state_sequence = model.predict(X)

    return {
        'state_sequence': state_sequence,
        'state_means': model.means_,
        'transition_matrix': model.transmat_,
        'n_states': n_states,
        'pca_used': pca_used,
        'n_features_used': X.shape[1]
    }

def compute_state_connectivity(trial_data: np.ndarray,
                               state_sequence: np.ndarray) -> Dict[int, np.ndarray]:
    """
    Builds one correlation matrix per HMM state from the timepoints
    that were assigned to that state.

    Args:
        trial_data: fMRI data (n_nodes, n_timepoints)
        state_sequence: HMM state labels for each timepoint

    Returns:
        Dict mapping state_id -> correlation matrix. States with fewer than
        two timepoints are left out; NaN correlations are replaced by 0.
    """
    connectivity_by_state = {}

    for label in np.unique(state_sequence):
        in_state = state_sequence == label

        # A correlation needs at least two timepoints.
        if np.count_nonzero(in_state) >= 2:
            # Correlate the nodes using only this state's timepoints.
            state_corr = np.corrcoef(trial_data[:, in_state])
            state_corr[np.isnan(state_corr)] = 0

            connectivity_by_state[int(label)] = state_corr

    return connectivity_by_state

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/166.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m166.0/166.0 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [11]:
#Topological Similarity & Graph Overlap Analysis
class TopologicalSimilarity:
    """
    Compute and compare topological properties of brain connectivity graphs.

    This implements Step 5 of the Information Flow Module:
    Compare graph architecture using key topological metrics.

    Metrics computed:
    - Degree distribution
    - Clustering coefficient
    - Global efficiency
    - Modularity

    Every metric first binarizes the adjacency matrix (any non-zero entry
    counts as an edge) and clears the diagonal, so edge weights and signs
    are ignored.
    """

    def __init__(self, n_nodes: int = 360):
        # Kept for reference only: each metric takes the node count from the
        # adjacency matrix it receives.
        self.n_nodes = n_nodes

    @staticmethod
    def _binarize(adj: np.ndarray, dtype=int) -> np.ndarray:
        """Return a fresh 0/1 copy of ``adj`` (non-zero -> 1) without self-loops."""
        binary_adj = (adj != 0).astype(dtype)
        np.fill_diagonal(binary_adj, 0)
        return binary_adj

    def degree_distribution(self, adj: np.ndarray) -> Dict[str, Any]:
        """
        Compute degree distribution statistics.

        Args:
            adj: Square adjacency matrix.

        Returns:
            Dict with the per-node degree vector (row sums of the binarized
            matrix) plus its mean, standard deviation, maximum and minimum.
        """
        binary_adj = self._binarize(adj)

        # Degree = number of non-zero entries in each row.
        degree = np.sum(binary_adj, axis=1)

        return {
            'degrees': degree,
            'mean_degree': float(np.mean(degree)),
            'std_degree': float(np.std(degree)),
            'max_degree': int(np.max(degree)),
            'min_degree': int(np.min(degree))
        }

    def clustering_coefficient(self, adj: np.ndarray) -> Dict[str, Any]:
        """
        Compute the clustering coefficient for each node.

        A node's neighbours are the non-zero columns of its row. The number
        of entries set among those neighbours is halved, i.e. each connection
        is assumed to appear in both directions (exact for a symmetric matrix).

        Args:
            adj: Square adjacency matrix.

        Returns:
            Dict with the per-node coefficients and their mean / std.
            Nodes with fewer than two neighbours get 0.
        """
        binary_adj = self._binarize(adj)

        n = binary_adj.shape[0]
        clustering = np.zeros(n)

        for i in range(n):
            neighbors = np.where(binary_adj[i] > 0)[0]
            k = len(neighbors)

            if k < 2:
                # No pair of neighbours exists; the coefficient stays 0.
                continue

            # Edges among the neighbours vs. the number of neighbour pairs.
            subgraph = binary_adj[np.ix_(neighbors, neighbors)]
            neighbor_edges = np.sum(subgraph) / 2
            max_edges = k * (k - 1) / 2

            clustering[i] = neighbor_edges / max_edges

        return {
            'node_clustering': clustering,
            'mean_clustering': float(np.mean(clustering)),
            'std_clustering': float(np.std(clustering))
        }

    def global_efficiency(self, adj: np.ndarray) -> float:
        """
        Compute global efficiency of the network.

        Efficiency is the mean of 1/d(i, j) over all ordered node pairs
        i != j, where d is the shortest-path length (in hops) on the
        binarized graph; unreachable pairs contribute 0.

        Args:
            adj: Square adjacency matrix.

        Returns:
            Global efficiency in [0, 1]. Graphs with fewer than two nodes
            have no node pairs and return 0.0.
        """
        binary_adj = self._binarize(adj)
        n = binary_adj.shape[0]

        if n < 2:
            # n * (n - 1) would be 0 below; there are no pairs to average.
            return 0.0

        dist = np.where(binary_adj > 0, 1.0, np.inf)
        np.fill_diagonal(dist, 0)

        # Floyd-Warshall with the two inner loops vectorised. Row k and
        # column k do not change during step k (dist[k, k] == 0), so updating
        # the whole matrix at once matches the element-wise algorithm.
        for k in range(n):
            via_k = dist[:, k, None] + dist[None, k, :]
            np.minimum(dist, via_k, out=dist)

        # Only finite, non-zero distances have a meaningful inverse; the
        # diagonal (0) and unreachable pairs (inf) stay at 0.
        inv_dist = np.zeros_like(dist)
        reachable = np.isfinite(dist) & (dist > 0)
        inv_dist[reachable] = 1.0 / dist[reachable]

        return float(np.sum(inv_dist) / (n * (n - 1)))

    def modularity(self, adj: np.ndarray, n_communities: int = None) -> Dict[str, Any]:
        """
        Compute modularity Q of a spectral-clustering partition.

        The binarized matrix is symmetrised, partitioned with scikit-learn's
        ``SpectralClustering`` (precomputed affinity), and scored with
        Q = 1/(2m) * sum_ij [A_ij - k_i k_j / (2m)] * delta(c_i, c_j).

        Args:
            adj: Square adjacency matrix.
            n_communities: Number of clusters to request. When None it is
                derived from the node count and clamped to the range 2..5.

        Returns:
            Dict with 'modularity', 'community_labels' and 'n_communities'.
            If clustering fails, all nodes are put in one community (a
            warning is emitted). A graph without edges returns modularity 0.0.
        """
        import warnings
        from sklearn.cluster import SpectralClustering

        binary_adj = self._binarize(adj, dtype=float)

        # Make symmetric
        symmetric_adj = (binary_adj + binary_adj.T) / 2

        # Determine number of communities
        if n_communities is None:
            n_communities = min(5, max(2, int(np.sqrt(adj.shape[0] / 2))))

        # Spectral clustering. The small constant keeps the affinity matrix
        # strictly positive so disconnected graphs can still be partitioned.
        try:
            clustering = SpectralClustering(
                n_clusters=n_communities,
                affinity='precomputed',
                random_state=42,
                assign_labels='kmeans'
            )
            labels = clustering.fit_predict(symmetric_adj + 0.01)
        except Exception as err:
            # Best-effort fallback, but do not hide the failure and do not
            # swallow KeyboardInterrupt / SystemExit like a bare except would.
            warnings.warn(
                f"Spectral clustering failed ({err!r}); "
                "assigning all nodes to a single community."
            )
            labels = np.zeros(adj.shape[0], dtype=int)

        # Calculate modularity Q
        m = np.sum(symmetric_adj) / 2
        if m == 0:
            return {'modularity': 0.0, 'community_labels': labels, 'n_communities': n_communities}

        k = np.sum(symmetric_adj, axis=1)
        same_community = labels[:, None] == labels[None, :]
        modularity_matrix = symmetric_adj - np.outer(k, k) / (2 * m)
        Q = np.sum(modularity_matrix[same_community]) / (2 * m)

        return {
            'modularity': float(Q),
            'community_labels': labels,
            'n_communities': len(np.unique(labels))
        }


class GraphOverlapAnalysis:
    """
    Analyze overlap between connectivity graphs from different tasks.

    This implements Step 4 of the Information Flow Module:
    Compare graphs to quantify WM involvement in non-WM tasks.

    Metrics:
    - Jaccard Index (edge-wise overlap)
    - Pearson correlation (weighted similarity)
    - WM involvement score
    """

    def __init__(self, n_nodes: int = 360):
        # Kept for reference only; the methods work on whatever matrix
        # size they are given.
        self.n_nodes = n_nodes

    def jaccard_index(self,
                      adj1: np.ndarray,
                      adj2: np.ndarray,
                      threshold_percentile: float = None) -> Dict[str, float]:
        """
        Compute Jaccard Index for edge-wise overlap.

        Jaccard = |A ∩ B| / |A ∪ B|

        Args:
            adj1, adj2: Square adjacency matrices of the same shape.
            threshold_percentile: If given (0 included) and both matrices
                contain positive entries, an edge is kept only when its
                weight reaches that percentile of the matrix's own positive
                weights. With None, every non-zero entry is an edge.

        Returns:
            Dict with the Jaccard index and the raw edge counts (entries of
            the binarized matrices, diagonal excluded). The index is 0 when
            neither graph has an edge.
        """
        # ``is not None`` rather than truthiness: a percentile of 0 is a
        # valid request and must not silently fall back to "no threshold".
        use_threshold = (
            threshold_percentile is not None
            and np.any(adj1 > 0)
            and np.any(adj2 > 0)
        )

        # Binarize
        if use_threshold:
            thresh1 = np.percentile(adj1[adj1 > 0], threshold_percentile)
            thresh2 = np.percentile(adj2[adj2 > 0], threshold_percentile)
            binary1 = (adj1 >= thresh1).astype(int)
            binary2 = (adj2 >= thresh2).astype(int)
        else:
            binary1 = (adj1 != 0).astype(int)
            binary2 = (adj2 != 0).astype(int)

        np.fill_diagonal(binary1, 0)
        np.fill_diagonal(binary2, 0)

        # Jaccard calculation
        intersection = np.sum(binary1 & binary2)
        union = np.sum(binary1 | binary2)

        jaccard = intersection / union if union > 0 else 0

        return {
            'jaccard_index': float(jaccard),
            'intersection_edges': int(intersection),
            'union_edges': int(union),
            'edges_in_graph1': int(np.sum(binary1)),
            'edges_in_graph2': int(np.sum(binary2))
        }

    def pearson_correlation(self,
                           adj1: np.ndarray,
                           adj2: np.ndarray) -> Dict[str, float]:
        """
        Compute Pearson correlation between flattened adjacency matrices.

        Only the strict upper triangle (diagonal excluded) of each matrix is
        compared.

        Returns:
            Dict with 'pearson_r' and 'p_value'. When either vector is
            constant the correlation is undefined and (0.0, 1.0) is returned.
        """
        # Imported here so the method does not depend on an earlier cell
        # having bound the name ``stats``.
        from scipy import stats

        # Get upper triangle (excluding diagonal)
        mask = np.triu(np.ones_like(adj1, dtype=bool), k=1)

        vec1 = adj1[mask]
        vec2 = adj2[mask]

        if np.std(vec1) == 0 or np.std(vec2) == 0:
            return {'pearson_r': 0.0, 'p_value': 1.0}

        r, p = stats.pearsonr(vec1, vec2)

        return {
            'pearson_r': float(r),
            'p_value': float(p)
        }

    def compute_overlap(self,
                        adj1: np.ndarray,
                        adj2: np.ndarray,
                        threshold_percentile: float = 10.0) -> Dict[str, Any]:
        """
        Compute all overlap metrics between two graphs.

        Returns:
            {'jaccard': <jaccard_index result>,
             'correlation': <pearson_correlation result>}
        """
        jaccard = self.jaccard_index(adj1, adj2, threshold_percentile)
        correlation = self.pearson_correlation(adj1, adj2)

        return {
            'jaccard': jaccard,
            'correlation': correlation
        }

    def wm_involvement_score(self,
                             wm_adjacency: np.ndarray,
                             task_adjacency: np.ndarray) -> Dict[str, float]:
        """
        Calculate WM involvement score for a non-WM task.

        Score = fraction of WM network edges present in task network
        (any non-zero entry counts as an edge; the diagonal is ignored).
        The score is 0 when the WM network has no edges.
        """
        wm_binary = (wm_adjacency != 0).astype(int)
        task_binary = (task_adjacency != 0).astype(int)

        np.fill_diagonal(wm_binary, 0)
        np.fill_diagonal(task_binary, 0)

        wm_edges = np.sum(wm_binary)
        overlap = np.sum(wm_binary & task_binary)

        score = overlap / wm_edges if wm_edges > 0 else 0

        return {
            'wm_involvement_score': float(score),
            'wm_edges_in_task': int(overlap),
            'total_wm_edges': int(wm_edges),
            'total_task_edges': int(np.sum(task_binary))
        }

    def compare_multiple_tasks(self,
                               task_adjacencies: Dict[str, np.ndarray],
                               threshold_percentile: float = 10.0) -> 'pd.DataFrame':
        """
        Compare all pairs of task graphs.

        Args:
            task_adjacencies: Mapping task name -> adjacency matrix.
            threshold_percentile: Forwarded to ``compute_overlap`` for the
                Jaccard binarization (same default as ``compute_overlap``).

        Returns:
            DataFrame with one row per unordered task pair and the columns
            task1, task2, jaccard_index, pearson_r. The columns are present
            even when there are fewer than two tasks (zero rows).
        """
        import pandas as pd
        from itertools import combinations

        columns = ['task1', 'task2', 'jaccard_index', 'pearson_r']
        results = []

        # combinations() walks the keys in insertion order, pairing each task
        # with every later one.
        for task1, task2 in combinations(task_adjacencies, 2):
            overlap = self.compute_overlap(
                task_adjacencies[task1],
                task_adjacencies[task2],
                threshold_percentile
            )
            results.append({
                'task1': task1,
                'task2': task2,
                'jaccard_index': overlap['jaccard']['jaccard_index'],
                'pearson_r': overlap['correlation']['pearson_r']
            })

        return pd.DataFrame(results, columns=columns)

In [12]:
#Temporal GNN & Info Flow Integration
class TemporalGraphBuilder:
    """
    Turn one fMRI trial into a list of PyG graphs, one per timepoint.

    Each graph carries:
    - x:          the BOLD value of every node at that timepoint (one feature)
    - edge_index: positions where the GCA adjacency is positive
    - edge_attr:  the F-statistic found at each of those positions
    - hmm_state:  the HMM state label of that timepoint

    The edge structure is computed once and shared by all graphs of a trial.
    """

    def __init__(self, n_nodes: int = 360):
        self.n_nodes = n_nodes

    def build_temporal_graphs(self,
                              trial_data: np.ndarray,
                              gca_result: Dict,
                              hmm_result: Dict) -> List:
        """
        Build the per-timepoint graph sequence for a trial.

        Args:
            trial_data: (n_nodes, n_timepoints) BOLD data
            gca_result: Output from compute_gca_for_trial(); the keys
                'adjacency' and 'f_statistics' are read here.
            hmm_result: Output from run_hmm_analysis(); the key
                'state_sequence' is read here.

        Returns:
            List of PyG Data objects, ordered by timepoint.

        Raises:
            ImportError: if torch_geometric is not available (HAS_PYG false).
        """
        if not HAS_PYG:
            raise ImportError("torch_geometric required. Install with: pip install torch_geometric")

        # Unpacking also insists on a 2-D input.
        _, n_timepoints = trial_data.shape
        adjacency = gca_result['adjacency']
        f_stats = gca_result['f_statistics']
        state_sequence = hmm_result['state_sequence']

        # Static topology: (row, col) coordinates of positive entries -> (2, num_edges).
        edge_index = torch.tensor(np.array(np.where(adjacency > 0)), dtype=torch.long)

        # One F-statistic per edge, picked in the same row-major order.
        edge_attr = torch.tensor(f_stats[adjacency > 0], dtype=torch.float).unsqueeze(1)

        n_labelled = len(state_sequence)

        def _graph_at(t: int):
            # Timepoints beyond the end of the state sequence default to state 0.
            state = int(state_sequence[t]) if t < n_labelled else 0
            return Data(
                x=torch.tensor(trial_data[:, t], dtype=torch.float).unsqueeze(1),
                edge_index=edge_index,
                edge_attr=edge_attr,
                hmm_state=state
            )

        return [_graph_at(t) for t in range(n_timepoints)]


class TemporalGNN(nn.Module):
    """
    Temporal Graph Neural Network for learning information flow patterns.

    Architecture:
    1. GAT layers for spatial message passing (across brain regions)
    2. LSTM for temporal dynamics (across timepoints)
    3. Graph-level readout for trial embedding

    This is Step 3 of the Information Flow Module.

    Args:
        in_channels: Features per node in the input graphs.
        hidden_channels: Per-head width of the intermediate GAT layers.
        out_channels: Width of the last GAT layer and of the returned embeddings.
        num_gat_layers: Number of GAT layers (>= 1).
        num_heads: Attention heads of every GAT layer except the last of a
            multi-layer stack, which uses a single head.
        lstm_hidden: Hidden size of the LSTM.
        lstm_layers: Number of stacked LSTM layers.
        dropout: Dropout used in the GAT layers, between them, and between
            LSTM layers (only when lstm_layers > 1).

    Raises:
        ImportError: if torch_geometric is not available (HAS_PYG false).
        ValueError: if num_gat_layers < 1.
    """

    def __init__(self,
                 in_channels: int = 1,
                 hidden_channels: int = 64,
                 out_channels: int = 32,
                 num_gat_layers: int = 2,
                 num_heads: int = 4,
                 lstm_hidden: int = 128,
                 lstm_layers: int = 2,
                 dropout: float = 0.2):
        super(TemporalGNN, self).__init__()

        if not HAS_PYG:
            raise ImportError("torch_geometric required")
        if num_gat_layers < 1:
            raise ValueError(f"num_gat_layers must be >= 1, got {num_gat_layers}")

        self.hidden_channels = hidden_channels
        self.out_channels = out_channels

        # GAT layers for spatial processing
        self.gat_layers = nn.ModuleList()

        if num_gat_layers == 1:
            # A single layer has to emit `out_channels` itself, because the
            # LSTM below is sized for that width. Heads are averaged
            # (concat=False) instead of concatenated to keep the width fixed.
            self.gat_layers.append(
                GATConv(in_channels, out_channels, heads=num_heads, concat=False, dropout=dropout)
            )
        else:
            # First GAT layer: heads concatenated -> hidden_channels * num_heads
            self.gat_layers.append(
                GATConv(in_channels, hidden_channels, heads=num_heads, concat=True, dropout=dropout)
            )

            # Middle GAT layers keep the concatenated width
            for _ in range(num_gat_layers - 2):
                self.gat_layers.append(
                    GATConv(hidden_channels * num_heads, hidden_channels, heads=num_heads, concat=True, dropout=dropout)
                )

            # Final GAT layer: single head projecting down to out_channels
            self.gat_layers.append(
                GATConv(hidden_channels * num_heads, out_channels, heads=1, concat=False, dropout=dropout)
            )

        # LSTM for temporal processing
        self.lstm = nn.LSTM(
            input_size=out_channels,
            hidden_size=lstm_hidden,
            num_layers=lstm_layers,
            batch_first=True,
            dropout=dropout if lstm_layers > 1 else 0
        )

        # Output projection
        self.output_proj = nn.Linear(lstm_hidden, out_channels)

        self.dropout = nn.Dropout(dropout)

    def forward(self, graph_sequence: List) -> Dict[str, torch.Tensor]:
        """
        Process a sequence of temporal graphs.

        Only ``x`` and ``edge_index`` of each graph are used; ``edge_attr``
        is not passed to the GAT layers.

        Args:
            graph_sequence: List of PyG Data objects (one per timepoint),
                all with the same number of nodes.

        Returns:
            Dict with
            - 'node_embeddings':   (n_nodes, out_channels)
            - 'graph_embedding':   (1, out_channels), mean over nodes
            - 'temporal_features': LSTM output for every node and timepoint
        """
        # Process each graph through GAT layers
        node_embeddings_sequence = []
        last_layer = len(self.gat_layers) - 1

        for data in graph_sequence:
            x = data.x
            edge_index = data.edge_index

            # Apply GAT layers; activation and dropout only between layers
            for i, gat_layer in enumerate(self.gat_layers):
                x = gat_layer(x, edge_index)
                if i < last_layer:
                    x = F.elu(x)
                    x = self.dropout(x)

            node_embeddings_sequence.append(x)

        # Stack temporal sequence: (n_nodes, n_timepoints, out_channels)
        node_temporal = torch.stack(node_embeddings_sequence, dim=1)

        # Apply LSTM across time; with batch_first=True each node is treated
        # as one sequence in the batch.
        lstm_out, (h_n, _) = self.lstm(node_temporal)

        # Final node embeddings (last LSTM layer's final hidden state)
        node_embeddings = self.output_proj(h_n[-1])

        # Graph-level embedding (mean pooling over nodes)
        graph_embedding = torch.mean(node_embeddings, dim=0, keepdim=True)

        return {
            'node_embeddings': node_embeddings,
            'graph_embedding': graph_embedding,
            'temporal_features': lstm_out
        }


class WoMAD_info_flow(nn.Module):
    """
    Complete Information Flow Module for WoMAD.

    Integrates:
    1. GCA (Granger Causality Analysis) - effective connectivity
    2. HMM (Hidden Markov Model) - brain state detection
    3. Temporal GNN - spatiotemporal learning
    4. Graph Overlap - task comparison
    5. Topological Similarity - network metrics

    Config keys read (all optional): "target_nodes", "hidden_size",
    "gca_max_lag", "gca_significance", "hmm_n_states".

    NOTE(review): the next notebook cell defines another class that is also
    called ``WoMAD_info_flow``. Once that cell has run, this definition is no
    longer reachable by name - one of the two should be renamed.
    """

    def __init__(self, config: dict):
        super(WoMAD_info_flow, self).__init__()

        self.target_nodes = config.get("target_nodes", 360)
        self.hidden_size = config.get("hidden_size", 128)

        # GCA parameters
        self.gca_max_lag = config.get("gca_max_lag", 1)
        self.gca_significance = config.get("gca_significance", 0.05)

        # HMM parameters
        self.hmm_n_states = config.get("hmm_n_states", 3)

        # Temporal GNN (if available)
        if HAS_PYG:
            self.temporal_gnn = TemporalGNN(
                in_channels=1,
                hidden_channels=64,
                out_channels=self.hidden_size,
                num_gat_layers=2,
                num_heads=4
            )
        else:
            self.temporal_gnn = None

        # Graph builder
        self.graph_builder = TemporalGraphBuilder(self.target_nodes) if HAS_PYG else None

        # Analysis modules (non-learnable)
        self.gca_analyzer = GrangerCausalityAnalysis(
            max_lag=self.gca_max_lag,
            significance_level=self.gca_significance
        )
        self.topo_analyzer = TopologicalSimilarity(self.target_nodes)
        self.overlap_analyzer = GraphOverlapAnalysis(self.target_nodes)

    def analyze_trial(self, trial_data: np.ndarray) -> Dict[str, Any]:
        """
        Run complete information flow analysis on a single trial.

        Args:
            trial_data: fMRI data (n_nodes, n_timepoints)

        Returns:
            Dict with the keys 'gca', 'hmm', 'state_connectivity' and
            'topology' (degree, clustering, global_efficiency, modularity
            of the GCA adjacency).
        """
        results = {}

        # Step 1: GCA
        gca_result = compute_gca_for_trial(
            trial_data,
            max_lag=self.gca_max_lag,
            significance_level=self.gca_significance
        )
        results['gca'] = gca_result

        # Step 2: HMM
        hmm_result = run_hmm_analysis(trial_data, n_states=self.hmm_n_states)
        results['hmm'] = hmm_result

        # Step 3: State-wise connectivity
        state_connectivity = compute_state_connectivity(
            trial_data, hmm_result['state_sequence']
        )
        results['state_connectivity'] = state_connectivity

        # Step 4: Topological metrics of the GCA graph
        gca_adjacency = gca_result['adjacency']
        topo_metrics = {
            'degree': self.topo_analyzer.degree_distribution(gca_adjacency),
            'clustering': self.topo_analyzer.clustering_coefficient(gca_adjacency),
            'global_efficiency': self.topo_analyzer.global_efficiency(gca_adjacency),
            'modularity': self.topo_analyzer.modularity(gca_adjacency)
        }
        results['topology'] = topo_metrics

        return results

    def forward(self, trial_data: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the learnable components.

        GCA and HMM run on a detached NumPy copy of the input, so no gradient
        flows through them; only the Temporal GNN is differentiable.

        Args:
            trial_data: Tensor (n_nodes, n_timepoints)

        Returns:
            Dict with embeddings from Temporal GNN. Without torch_geometric,
            zero-filled 'node_embeddings' and 'graph_embedding' placeholders
            are returned instead.
        """
        # Convert to numpy for GCA/HMM
        data_np = trial_data.detach().cpu().numpy()

        # Run analysis with the same GCA settings analyze_trial() uses, so
        # the configured significance level is honoured on this path as well.
        gca_result = compute_gca_for_trial(
            data_np,
            max_lag=self.gca_max_lag,
            significance_level=self.gca_significance
        )
        hmm_result = run_hmm_analysis(data_np, n_states=self.hmm_n_states)

        # Build graphs and run GNN
        if self.temporal_gnn is not None and self.graph_builder is not None:
            graphs = self.graph_builder.build_temporal_graphs(data_np, gca_result, hmm_result)
            gnn_output = self.temporal_gnn(graphs)
            return gnn_output
        else:
            # Return placeholder if GNN not available; keep it on the same
            # device as the input so downstream ops do not mix devices.
            return {
                'node_embeddings': torch.zeros(data_np.shape[0], self.hidden_size,
                                               device=trial_data.device),
                'graph_embedding': torch.zeros(1, self.hidden_size,
                                               device=trial_data.device)
            }

In [13]:
from torch_geometric.nn import GATConv

class WoMAD_info_flow(nn.Module):
    """
    Single-layer graph-attention variant of the information-flow module.

    One multi-head GATConv passes BOLD activity along the supplied edges,
    followed by a linear layer that folds the concatenated heads back to
    ``hidden_size`` features per node.

    Config keys read (both optional): "target_nodes" (default 360) and
    "hidden_size" (default 128).

    NOTE(review): the previous notebook cell already defines a class called
    ``WoMAD_info_flow`` (the GCA/HMM/Temporal-GNN pipeline). Running this
    cell rebinds the name to this class - rename one of the two if both are
    meant to be used.
    """

    def __init__(self, config: dict):
        super(WoMAD_info_flow, self).__init__()

        # Tolerate a missing config; the defaults match the old constants.
        config = config or {}
        self.target_nodes = config.get("target_nodes", 360)
        self.hidden_size = config.get("hidden_size", 128)

        num_heads = 4

        # The GNN layer: learns directed attention between parcels.
        #   in_channels:  BOLD activity at each node (one value)
        #   out_channels: hidden representation of "info flow", per head
        #   edge_dim:     required for GATConv to use edge_attr at all -
        #                 without it the edge features passed to forward()
        #                 are ignored by the attention computation.
        self.gnn_manifold = GATConv(in_channels=1,
                                    out_channels=self.hidden_size,
                                    heads=num_heads,
                                    concat=True,
                                    edge_dim=1)

        # Linear layer to condense the concatenated heads back to hidden_size
        self.post_gnn = nn.Linear(self.hidden_size * num_heads, self.hidden_size)

    def forward(self, x, edge_index, edge_attr):
        """
        Run one round of attention-based message passing.

        Args:
            x: BOLD signal per node, (num_nodes, 1) as GATConv expects
               node features - TODO confirm callers do not add a batch axis.
            edge_index: The directed connections from GCA (2, num_edges)
            edge_attr: The strength of those connections (num_edges, 1)

        Returns:
            Tensor (num_nodes, hidden_size) of node representations.
        """
        # Message passing along the GCA-defined directed paths, with the
        # connection strengths feeding into the attention scores.
        out = self.gnn_manifold(x, edge_index, edge_attr)
        out = F.elu(self.post_gnn(out))

        return out

#### Core submodule

In [15]:
# Core module
class WoMAD_core(nn.Module):
    """
    Core WoMAD network.

    Pipeline: dynamic input adapter -> segmentation placeholder (Conv1d) ->
    two parallel branches (LSTM for temporal features, 2-D ConvNet for
    spatiotemporal features) -> fusion layer -> two heads (one overall
    activity score and one score per node).
    """

    def __init__(self, config: dict):
        """
        Sets up the complete WoMAD model with all modules and submodules.
        (Each module and submodule includes a dynamic input layer that matches the size of input.)

        Arguments:
            config (dict): Settings for the model. The keys
                "target_parcellation", "lstm_config" and "fusion_config" are
                taken from here when present; any key that is missing falls
                back to the attribute of the same name on ``WoMAD_config``.
        """
        super().__init__()

        overrides = config if isinstance(config, dict) else {}

        def _setting(name: str):
            # The training code assembles a config dict with these exact
            # keys; WoMAD_config is only consulted for what the dict lacks.
            if name in overrides:
                return overrides[name]
            return getattr(WoMAD_config, name)

        target_nodes = _setting("target_parcellation")          # 360, Temporary.
        lstm_cfg     = _setting("lstm_config")
        fusion_cfg   = _setting("fusion_config")

        lstm_h_size = lstm_cfg["hidden_size"]                   # 128
        fusion_h_size = fusion_cfg["hidden_size"]

        conv4d_out_size = 64                                    # From simplified config

        # Dynamic Adapter (sized from the same node count as every layer below)
        self.dyn_input_adapter = DynamicInput(target_nodes = target_nodes)

        # Core Module
        ## Submodule A: 3D-UNet
        ## Input shape = (batch, target_nodes, timepoints)
        self.segment_and_label = nn.Sequential(
            nn.Conv1d(in_channels = target_nodes,
                      out_channels = target_nodes,
                      kernel_size = 3, padding = 1),
            nn.ReLU(),
            nn.Identity()          # FIX: Placeholder should be replaced with the 3D-UNet.
        )

        ## Parallel submodule B-1: LSTM (Temporal features)
        self.temporal_lstm = nn.LSTM(input_size  = target_nodes,
                                     hidden_size = lstm_h_size,
                                     num_layers  = lstm_cfg["num_layers"],
                                     dropout     = lstm_cfg["dropout"],
                                     batch_first = True)

        ## Parallel submodule B-2: ConvNet4D (Spatiotemporal features)
        ## Currently a 2-D ConvNet over the (nodes, timepoints) plane.
        self.spatiotemporal_cnv4d = nn.Sequential(
            nn.Conv2d(in_channels  = 1,
                      out_channels = 32,
                      kernel_size  = (3, 3), padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = (2, 2)),
            nn.Conv2d(in_channels  = 32,
                      out_channels = conv4d_out_size,
                      kernel_size  = (3, 3), padding = 1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten()
        )

        ## Submodule C: Fusion Layer
        ## Input = LSTM hidden state concatenated with the ConvNet features.
        fusion_input_size = lstm_h_size + conv4d_out_size       # 192
        self.fusion_block = nn.Sequential(
            nn.Linear(fusion_input_size, fusion_h_size),
            nn.ReLU()
        )

        ### Overall, WM-based activity score:
        self.overall_activity_score = nn.Linear(fusion_h_size, 1)

        ### Node-based (voxel-based or parcel-based) activity scores:
        self.node_wise_activity_scores = nn.Linear(fusion_h_size, target_nodes)

    def _prepare_4d_data(self, input: torch.Tensor) -> torch.Tensor:
        """
        Helper method that shapes the data for the spatiotemporal ConvNet.

        Input:
            Tensor with shape (batch, target_nodes, timepoints).
            target_nodes is the flat spatial dimension.

        Output:
            Tensor with shape (batch, C=1, target_nodes, timepoints): a
            single channel axis is inserted so nn.Conv2d can consume it.
            Mapping nodes onto an X*Y*Z grid is still a TODO.
        """
        # Unpacking doubles as a check that the input is 3-D.
        batch, nodes, timepoints = input.shape
        # TODO: Create the mapping array to place nodes into a X*Y*Z grid.

        four_dim_data = input.unsqueeze(1)

        return four_dim_data

    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        The forward pass that manages how the data passes through modules.

        Sequence for forward pass:
            Input -> Segmentation (3D-UNet) -> (LSTM, Conv4D) -> Fusion

        Input:
            Tensor with shape (batch, current_nodes, timepoints)

        Returns:
            (overall_score, node_scores): tensors of shape (batch, 1) and
            (batch, target_nodes).
        """
        # Dynamic Input Adaption
        x_dynamically_adapted = self.dyn_input_adapter(input)

        # 3D-UNet
        unet_out_timeseries = self.segment_and_label(x_dynamically_adapted)

        # LSTM: wants (batch, timepoints, nodes); keep the last layer's final state
        x_for_lstm = unet_out_timeseries.transpose(1, 2)
        _, (h_n, _) = self.temporal_lstm(x_for_lstm)
        lstm_out = h_n[-1]

        # ConvNet4D
        x_for_conv4d = self._prepare_4d_data(unet_out_timeseries)
        conv4d_out = self.spatiotemporal_cnv4d(x_for_conv4d)

        # Fusion Layer
        fused_feats = torch.cat([lstm_out, conv4d_out], dim = 1)
        shared_features = self.fusion_block(fused_feats)

        overall_score = self.overall_activity_score(shared_features)
        node_scores   = self.node_wise_activity_scores(shared_features)

        return overall_score, node_scores


def model_config(config: dict) -> WoMAD_core:
    """
    Build a WoMAD_core model and place it on the GPU when requested.

    Argument:
        config (dict): WoMAD config dictionary; config["system"]["use_gpu"]
                       decides whether a CUDA device should be used.

    Returns:
        WoMAD_core: Model ready to be trained. It is moved to CUDA only if
                    the config asks for it and CUDA is actually available.
    """
    womad = WoMAD_core(config)

    gpu_requested = config["system"]["use_gpu"]
    if not gpu_requested:
        return womad

    # Requested, but fall back to CPU silently when no CUDA device exists.
    if torch.cuda.is_available():
        womad.cuda()

    return womad

### Hyperparameter Module - NOT FINALIZED

In [14]:
def define_search_space(trial: optuna.Trial) -> Dict[str, Any]:
    """
    Define the search space for hyperparameter optimization.

    NOT FINALIZED: every value is still a constant placeholder (0). The
    ``trial`` argument is accepted but not yet used to sample anything.

    Arguments:
        trial (optuna.Trial): The trial object that suggests parameters

    Returns:
        Dict[str, Any]: A dictionary of suggested hyperparameters with the
        keys "learning_rate", "batch_size", "epochs", "hidden_layers" and
        "dropout_rate".
    """
    # Main parameters (Learning rate, batch size, number of epochs)
    # TODO: replace the placeholders with trial.suggest_* calls.
    learning_rate = 0
    batch_size = 0
    epochs = 0
    # Model-specific parameters (Hidden layers, dropout rates)
    hidden_layers = 0
    dropout_rate = 0

    suggested_parameters = {
        "learning_rate": learning_rate,
        "batch_size"   : batch_size,
        "epochs"       : epochs,
        "hidden_layers": hidden_layers,
        "dropout_rate" : dropout_rate
    }

    # The dictionary was built but never handed back, so callers received
    # None instead of the documented dict.
    return suggested_parameters

def objective(trial: optuna.Trial) -> float:
    """
    TO DO: Define the objective function for Optuna.

    Current draft: take the hyperparameters proposed for ``trial``, merge
    them into the "training" section of the config, run the pipeline and
    return the metric it reports.
    """
    trial_params = define_search_space(trial)

    # NOTE(review): ``config`` is not created in this function (the loader
    # below is commented out), so it has to exist as a notebook-level
    # variable before this is called.
    # config = WoMAD_config.load_config()
    # NOTE(review): the training module reads config["training_config"],
    # whereas this writes to config["training"] - confirm which key is meant.
    config["training"].update(trial_params)

    return run_pipeline(config)

def run_hyperparameter_optim():
    """
    TO DO: Define the main hyperparameter search function.

    Placeholder: the body contains no statements yet, so calling this
    function does nothing and returns None. The comments below outline the
    planned steps.
    """
    # Step 1: Define the target for optimization (min loss, min MSE, etc.)

    # Step 2: Print the results (best trial, best parameters, best target metric)

    # Step 3: Save the same results to a file as well

# Entry point for the hyperparameter search.
# NOTE(review): inside a notebook kernel __name__ is presumably "__main__",
# so this call fires whenever the cell is executed. run_hyperparameter_optim()
# is still an empty stub, so the call currently has no effect.
if __name__ == "__main__":
    run_hyperparameter_optim()

### Model Train Module

In [19]:
import numpy as np
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset

from sklearn.model_selection import KFold

# Removed relative imports. WoMAD_core and run_valid_epoch are assumed to be
# globally available from previously executed cells. WoMAD_config references
# are handled by assembling a comprehensive 'config' dictionary.
# from . import WoMAD_config
# from .model_setup_module import DynamicInput, WoMAD_core
# from .model_valid_module import run_valid_epoch

def WoMAD_optimizer(model: nn.Module, config: dict) -> torch.optim.Optimizer:
    """
    Build the Adam optimizer used to train a WoMAD model.

    The learning rate is taken from config["training_config"]["learning_rate"];
    all other Adam settings keep their PyTorch defaults.
    """
    training_settings = config["training_config"]
    learning_rate = training_settings["learning_rate"]

    trainable = model.parameters()
    return torch.optim.Adam(trainable, lr = learning_rate)

def WoMAD_loss_function(config: dict) -> Dict[str, nn.Module]:
    """
    Return the loss functions for the two model outputs.

    Both the overall score and the node-wise scores are trained with their
    own mean-squared-error loss. ``config`` is accepted for interface
    symmetry but is not read.
    """
    loss_names = ("overall_score_loss", "node_score_loss")

    # A separate MSELoss instance per output head.
    return {name: nn.MSELoss() for name in loss_names}

def run_training_epoch(model: WoMAD_core, data_loader: DataLoader,
                       optimizer: torch.optim.Optimizer,
                       loss_funcs: Dict[str, nn.Module],
                       epoch: int, config: dict):
    """
    Function to run a single training epoch.

    Arguments:
        model: The network to train; called as ``model(data)`` and expected
               to return (overall_pred, node_pred).
        data_loader: Yields (data, overall_target, node_target) batches.
        optimizer: Optimizer stepping the model's parameters.
        loss_funcs: Dict with the keys "overall_score_loss" and
                    "node_score_loss".
        epoch: Zero-based epoch index (only used for the log line).
        config: Must contain config["training_loss_weights"] with the keys
                "overall_loss_weight" and "node_loss_weight".

    Returns:
        float: Sample-weighted average of the combined loss over the epoch
               (NaN if the loader produced no samples).
    """
    model.train()
    total_train_loss = 0.0
    total_samples = 0

    # Access training_loss_weights from the passed config dictionary
    overall_weight = config["training_loss_weights"]["overall_loss_weight"]
    node_weight    = config["training_loss_weights"]["node_loss_weight"]

    overall_loss_fn = loss_funcs["overall_score_loss"]
    node_loss_fn    = loss_funcs["node_score_loss"]

    # Batches must live on the same device as the model's weights (the model
    # may have been moved to CUDA); None means the model has no parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    for data, overall_target, node_target in data_loader:
        if device is not None:
            data           = data.to(device)
            overall_target = overall_target.to(device)
            node_target    = node_target.to(device)

        overall_target = overall_target.float()
        node_target    = node_target.float()

        optimizer.zero_grad()

        # Forward pass to return (overall and node-wise prediction)
        overall_pred, node_pred = model(data)

        # Match the prediction to the target's shape explicitly. A bare
        # .squeeze() would also drop the batch axis of a one-sample batch and
        # would broadcast to (batch, batch) against a (batch, 1) target.
        overall_pred = overall_pred.reshape(overall_target.shape)

        # Calculate losses
        loss_overall = overall_loss_fn(overall_pred, overall_target)
        loss_nodes   = node_loss_fn(node_pred, node_target)

        combined_loss = (overall_weight * loss_overall) + (node_weight * loss_nodes)

        # Backpropagate
        combined_loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the mean.
        batch_size = data.size(0)
        total_train_loss += combined_loss.item() * batch_size
        total_samples += batch_size

    # An empty loader has no average; report NaN instead of dividing by zero.
    avg_loss = total_train_loss / total_samples if total_samples > 0 else float("nan")
    print(f"Epoch {epoch+1:02d} | Training Loss: {avg_loss: .6f}")
    return avg_loss

def run_kfold_training(dataset, config_in: dict):
    """
    Executes K-fold cross validation for training.

    A fresh model and optimizer are created for every fold, so folds do not
    share weights.

    Arguments:
        dataset (Dataset): The WoMAD data which contains all target subject data.
        config_in  (dict): Configuration dictionary (optional, can be None, empty
                           or partial). Recognised keys: "system" and
                           "training_config" (learning_rate, k_folds,
                           num_epochs, batch_size).

    Returns:
        List of dictionaries with training stats for each training fold:
        {"fold": <1-based index>, "history": {"train_loss", "valid_loss", "val_metrics"}}.
    """
    # The docstring promises an optional argument, so accept None as "no overrides".
    config_in = config_in or {}

    # Create a comprehensive config dictionary based on global variables and passed config_in
    # This ensures all necessary parameters are available for model initialization and training
    config = {
        "training_loss_weights": training_loss_weights, # global variable from earlier cell
        "lstm_config": lstm_config, # global variable from earlier cell
        "fusion_config": fusion_config, # global variable from earlier cell
        "target_parcellation": TARGET_NODE_COUNT, # global variable from earlier cell
        "target_timepoints": 20, # Default value, if not explicitly passed/defined elsewhere
        "system": config_in.get("system", {"use_gpu": False}) # Default system config
    }

    # Define 'training_config' with defaults and then update with any provided in config_in
    training_config = {
        "learning_rate": 0.001,
        "k_folds": 5,
        "num_epochs": 10,
        "batch_size": 32
    }
    training_config.update(config_in.get("training_config", {}))
    config["training_config"] = training_config

    k_folds = training_config["k_folds"]
    num_epochs = training_config["num_epochs"]
    batch_size = training_config["batch_size"]

    # Fixed seed so the fold assignment is reproducible across runs.
    kfold = KFold(n_splits = k_folds, shuffle = True, random_state = 42)
    all_kfold_train_stats = []

    loss_funcs = WoMAD_loss_function(config)

    print(f"K-fold cross-validation for {k_folds} folds over {len(dataset)} trials:")

    for fold, (train_indx, valid_indx) in enumerate(kfold.split(dataset)):
        print(f"\nFold {fold + 1}/{k_folds}:")

        train_subset = Subset(dataset, train_indx)
        valid_subset = Subset(dataset, valid_indx)

        train_loader = DataLoader(train_subset, batch_size = batch_size, shuffle = True)
        valid_loader = DataLoader(valid_subset, batch_size = batch_size, shuffle = False)

        print(f"Train samples: {len(train_subset)}, Validation samples: {len(valid_subset)}")

        # Model initiation and setup
        # WoMAD_core is assumed to be globally available.
        model = WoMAD_core(config) # Pass the comprehensive config
        # TODO: Add the device logic (model.cuda()) - this should use config["system"]["use_gpu"]
        optimizer = WoMAD_optimizer(model, config) # Pass config

        fold_history = {"train_loss"  : [],
                        "valid_loss"  : [],
                        "val_metrics" : []}

        for epoch in range(num_epochs):
            train_loss = run_training_epoch(model, train_loader, optimizer, loss_funcs, epoch, config) # Pass config
            fold_history["train_loss"].append(train_loss)

            # run_valid_epoch is assumed to be globally available.
            # Pass the config dictionary to run_valid_epoch
            valid_loss, val_metrics = run_valid_epoch(model, valid_loader, loss_funcs, epoch, config)

            fold_history["valid_loss"].append(valid_loss)
            fold_history["val_metrics"].append(val_metrics)

        all_kfold_train_stats.append({"fold": fold + 1, "history": fold_history})

    # Announce completion once, after the last fold - not after every fold.
    print("\nK-fold training complete.")

    return all_kfold_train_stats

### Model Valid Module

In [21]:
import torch
import torch.nn as nn

from torch.utils.data import DataLoader

from typing import Dict, Tuple, Any

# WoMAD_config import is removed as config is now passed as an argument
# from . import WoMAD_config

def calc_graph_overlap():
    """
    Placeholder for the Info-Flow graph-overlap calculation (not implemented yet).

    TO DO: build a graph network from the info-flow output and compare graphs.

    Planned steps:
        1. Normalization of nodes and edge weights.
        2. Overlap calculation:
             - edge-wise percentage with thresholding and the Jaccard index,
             - weighted correlation with vectorization and Pearson's r.
        3. Topological similarity: compare architectures using key topological metrics.

    Returns:
        None: The function currently performs no computation.
    """
    return None

def calc_dice_coeff(predicted_labels: torch.Tensor, true_labels: torch.Tensor) -> float:
    """
    Calculate the Dice coefficient (F1 score) for segmentation-like tasks.

    Computed as 2 * sum(pred * true) / (sum(pred) + sum(true)).

    Args:
        predicted_labels (torch.Tensor): Predicted binary labels.
        true_labels (torch.Tensor): Ground truth binary labels.

    Returns:
        float: Dice coefficient. Returns 1.0 when the denominator is not
               positive (i.e. neither tensor contains positive labels).

    Raises:
        ValueError: If the two tensors do not have the same shape.
    """
    if predicted_labels.shape != true_labels.shape:
        raise ValueError("Predicted and true labels must have the same shape.")

    intersection = (predicted_labels * true_labels).sum()
    denominator = predicted_labels.sum() + true_labels.sum()

    if denominator > 0:
        return ((2. * intersection) / denominator).item()

    # No positive labels on either side: treat as perfect agreement.
    # Return a plain Python float here; calling .item() on the literal 1.0
    # (as the previous implementation did) raises AttributeError.
    return 1.0

def calc_r_sqrd(predictions: torch.Tensor, targets: torch.Tensor) -> float:
    """
    Calculate the R^2 score (coefficient of determination).

    R^2 = 1 - SS_residual / SS_total, where SS_total is the squared deviation
    of the targets from their mean and SS_residual is the squared error of the
    predictions.

    Args:
        predictions (torch.Tensor): Predicted values.
        targets (torch.Tensor): True values.

    Returns:
        float: R^2 score. When the targets are constant (SS_total == 0) a small
               epsilon is added to the denominator, so the result is 1.0 for a
               perfect prediction and a large negative number otherwise.
    """
    ss_total = torch.sum((targets - targets.mean()) ** 2)
    ss_residual = torch.sum((targets - predictions) ** 2)

    if ss_total == 0:
        # Add epsilon to avoid division by zero. Convert with .item() so this
        # branch returns a Python float like the regular path (it previously
        # leaked a torch.Tensor to callers).
        return (1.0 - ss_residual / (ss_total + 1e-8)).item()

    r_sqrd = 1 - (ss_residual / ss_total)
    return r_sqrd.item()

def calc_all_metrics(overall_pred: torch.Tensor,
                       node_pred: torch.Tensor,
                       overall_target: torch.Tensor,
                       node_target: torch.Tensor,
                       config: dict  # Add config as an argument here
                       ) -> Dict[str, float]:
    """
    Compute the evaluation metrics for the overall and node-wise outputs.

    For each output level the mean squared error (via F.mse_loss) and the
    R^2 score (via calc_r_sqrd) are computed.

    Args:
        overall_pred (torch.Tensor): Predicted overall scores.
        node_pred (torch.Tensor): Predicted node-wise scores.
        overall_target (torch.Tensor): True overall scores.
        node_target (torch.Tensor): True node-wise scores.
        config (dict): Configuration dictionary (accepted but not read here).

    Returns:
        Dict[str, float]: Metrics keyed by "MSE_overall", "MSE_node",
                          "R_squared_overall" and "R_squared_node".
    """
    metrics_dict: Dict[str, float] = {}

    # Mean squared error per output level.
    metrics_dict["MSE_overall"] = F.mse_loss(overall_pred, overall_target).item()
    metrics_dict["MSE_node"] = F.mse_loss(node_pred, node_target).item()

    # Coefficient of determination per output level.
    metrics_dict["R_squared_overall"] = calc_r_sqrd(overall_pred, overall_target)
    metrics_dict["R_squared_node"] = calc_r_sqrd(node_pred, node_target)

    # The Dice coefficient is intentionally left out: the outputs are
    # regression-like, so it would only apply after thresholding, e.g.
    # calc_dice_coeff((overall_pred > 0.5).float(), (overall_target > 0.5).float())

    return metrics_dict

def run_valid_epoch(model: nn.Module, data_loader: DataLoader,
                    loss_funcs: Dict[str, nn.Module],
                    epoch: int, config: dict # config added as argument
                    ) -> Tuple[float, Dict[str, float]]:
    """
    Runs one validation epoch and calculates loss and metrics.

    The model is switched to evaluation mode and every batch is evaluated
    under torch.no_grad(). The per-batch loss is the weighted sum of the
    overall-score loss and the node-score loss, and the epoch loss is the
    sample-weighted average of the batch losses.

    Args:
        model (nn.Module): Model returning (overall_pred, node_pred) for a batch.
        data_loader (DataLoader): Yields (data, overall_target, node_target).
        loss_funcs (Dict[str, nn.Module]): Must contain "overall_score_loss"
            and "node_score_loss".
        epoch (int): Zero-based epoch index (only used for logging).
        config (dict): Must contain config["training_loss_weights"] with
            "overall_loss_weight" and "node_loss_weight".

    Returns:
        Tuple[float, Dict[str, float]]: Average validation loss and the
        metrics dictionary produced by calc_all_metrics.

    Raises:
        ValueError: If the data loader yields no samples.
    """
    model.eval()
    # NOTE(review): the model is left in eval mode on return - confirm that the
    # training step switches it back to train mode.
    total_val_loss = 0
    total_samples = 0

    all_overall_pred = []
    all_overall_target = []
    all_node_pred = []
    all_node_target = []

    # Access training_loss_weights from the passed config dictionary
    overall_weight = config["training_loss_weights"]["overall_loss_weight"]
    node_weight    = config["training_loss_weights"]["node_loss_weight"]

    overall_loss_fn = loss_funcs["overall_score_loss"]
    node_loss_fn    = loss_funcs["node_score_loss"]

    with torch.no_grad():
        for data, overall_target, node_target in data_loader:
            overall_target = overall_target.float()
            node_target = node_target.float()

            overall_pred, node_pred = model(data)

            # NOTE(review): squeeze() drops every size-1 dim, so a batch of one
            # sample becomes 0-dim - confirm this matches the target shape.
            loss_overall = overall_loss_fn(overall_pred.squeeze(), overall_target)
            loss_node    = node_loss_fn(node_pred, node_target)
            combined_loss = (overall_weight  * loss_overall) + (node_weight * loss_node)

            # Weight by batch size so the epoch average is per-sample.
            # Assumes data exposes .size(0) as the batch size - TODO confirm.
            total_val_loss += combined_loss.item() * data.size(0)
            total_samples  += data.size(0)

            all_overall_pred.append(overall_pred)
            all_node_pred.append(node_pred)

            all_overall_target.append(overall_target)
            all_node_target.append(node_target)

    if total_samples == 0:
        # Fail with an explicit message instead of the opaque ZeroDivisionError
        # the averaging step below would otherwise raise.
        raise ValueError("Validation data loader yielded no samples; "
                         "cannot compute validation loss or metrics.")

    avg_loss = total_val_loss / total_samples

    print(f"Epoch {epoch+1:02d} | Validation Loss: {avg_loss: .6f}")

    final_overall_pred   = torch.cat(all_overall_pred).squeeze()
    final_overall_target = torch.cat(all_overall_target)

    final_node_pred   = torch.cat(all_node_pred)
    final_node_target = torch.cat(all_node_target)

    metrics = calc_all_metrics(final_overall_pred,
                               final_node_pred,
                               final_overall_target,
                               final_node_target,
                               config) # Pass config to calc_all_metrics

    print(f"Validation metrics: {metrics}")

    return avg_loss, metrics


### Result Interpretation Module

In [23]:
import numpy as np
import shap
import torch
from typing import Dict, Any

# Removed relative import, as config should be passed as an argument
# from . import WoMAD_config

def predict():
    """
    Inference entry point (not implemented yet).

    TO DO: Create the functions that allow us to use the model for inference.

    Returns:
        str: A fixed placeholder string standing in for the actual prediction.
    """
    placeholder_result = "prediction_placeholder"
    return placeholder_result

def visualize_and_interpret(model: WoMAD_core,
                            data_loader: DataLoader,
                            config: dict):
    """
    Generate figures and saliency maps for interpreting the results.

    Not implemented yet: the function only puts the model into evaluation
    mode and returns the constant 0. `data_loader` and `config` are accepted
    but not read.

    Planned steps:
        - Visualize output (predicted vs. actual score).
        - SHAP analysis based on timeseries and final fused outputs.
        - Save visuals.
    """
    model.eval()

    placeholder_status = 0
    return placeholder_status

def run_pipeline_with_valid_dataset(model: WoMAD_core,
                                    valid_loader: DataLoader,
                                    loss_funcs: Dict[str, nn.Module],
                                    config: dict):
    """
    Running full validation and analysis after training.

    Runs the model over every batch of `valid_loader` without gradient
    tracking, concatenates all predictions and targets, computes the metrics
    with calc_all_metrics, and then calls visualize_and_interpret.

    Args:
        model (WoMAD_core): Trained model returning (overall_pred, node_pred).
        valid_loader (DataLoader): Yields (data, overall_target, node_target).
        loss_funcs (Dict[str, nn.Module]): Currently unused; kept so existing
            callers do not need to change.
        config (dict): Configuration dictionary forwarded to the metric and
            visualization helpers.

    Returns:
        Dict[str, float]: The metrics dictionary from calc_all_metrics.
    """
    model.eval()

    all_overall_preds = []
    all_node_preds = []
    all_overall_targets = []
    all_node_targets = []

    with torch.no_grad():
        for data, overall_target, node_target in valid_loader:
            overall_pred, node_pred = model(data)

            all_overall_preds.append(overall_pred)
            all_node_preds.append(node_pred)

            # Cast targets to float exactly as run_valid_epoch does; without it
            # integer targets fail in the metric code (mean() needs a floating dtype).
            all_overall_targets.append(overall_target.float())
            all_node_targets.append(node_target.float())

    final_overall_pred = torch.cat(all_overall_preds).squeeze()
    final_node_pred = torch.cat(all_node_preds)
    final_overall_target = torch.cat(all_overall_targets)
    final_node_target = torch.cat(all_node_targets)

    # Metrics
    metrics = calc_all_metrics(final_overall_pred,
                               final_node_pred,
                               final_overall_target,
                               final_node_target,
                               config)
    print(f"Final validation metrics: {metrics}")

    # Visuals
    visualize_and_interpret(model, valid_loader, config)

    print("Validation complete.")

    # TODO: Save final metrics (no output directory is defined yet).

    return metrics

## TESTS

In [25]:
# TODO: Create dummy data OR sample a tiny subset of the actual dataset
# TODO: Create tests for each and every single function or method.

# TESTS:
def test_data_module():
    """Placeholder test for the data module; contains no assertions yet."""
    return None

def test_model_setup():
    """Placeholder test for the model setup module; contains no assertions yet."""
    return None

def test_model_training():
    """Placeholder test for the model training module; contains no assertions yet."""
    return None

def test_model_validation():
    """Placeholder test for the model validation module; contains no assertions yet."""
    return None

def test_result_interpretation():
    """Placeholder test for the result interpretation module; contains no assertions yet."""
    return None

## WoMAD Main

In [27]:
# Terminal functions and UI
## print status, success, or error
## clear screen
## Welcome/Completion

def run_WoMAD(config):
    """
    Top-level entry point for the WoMAD pipeline (not implemented yet).

    Planned stages:
        1. Environment setup
        2. Data and initial processing
        3. Model setup
        4. Training (and hyperparameter search)
        5. Post-training: Analysis, Visualization, and Interpretation

    Args:
        config: Configuration object for the run; currently not read.

    Returns:
        None
    """
    return None

# Script entry point: build a demonstration config and hand it to run_WoMAD.
if __name__ == "__main__":
    # Define a placeholder config using existing global variables for demonstration
    # In a real application, this would be loaded from a config file.
    # NOTE: training_loss_weights, lstm_config, fusion_config and TARGET_NODE_COUNT
    # are not defined in this cell; they must already exist as globals from
    # earlier cells or this block raises NameError.
    config = {
        # Read by the validation code as config["training_loss_weights"][...].
        "training_loss_weights": training_loss_weights,
        "lstm_config": lstm_config,
        "fusion_config": fusion_config,
        "target_parcellation": TARGET_NODE_COUNT,
        "target_timepoints": 20, # Assuming a default, adjust as needed
        # GPU flag is derived from the runtime rather than hard-coded.
        "system": {"use_gpu": torch.cuda.is_available()},
        "training_config": { # Add a default training config as it's used elsewhere
            "learning_rate": 0.001,
            "k_folds": 5,
            "num_epochs": 10,
            "batch_size": 32
        }
    }
    # run_WoMAD is currently a stub, so this call has no effect yet.
    run_WoMAD(config)

NameError: name 'config' is not defined