# WoMAD Development Notebook

### Setup and Configurations

In [None]:
# Install basic dependencies
!pip install --quiet requests nilearn nibabel brainspace numpy pandas scikit-learn torch torch-geometric scipy rich>=13.5.2
!pip install --quiet optuna

# Install torch_geometric and its dependencies
!pip install --quiet torch_geometric torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu121.html

import os
import io
import time
import zipfile
from typing import List, Dict, Tuple, Any, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.data as PyG_Data
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd
import nibabel as nib

import shap   # For result interp module
import pytest # For TESTS module
import optuna # For hyperparameter module

  Preparing metadata (setup.py) ... done
  Preparing metadata (setup.py) ... done
  Building wheel for torch_cluster (setup.py) ... done
  Building wheel for torch_spline_conv (setup.py) ... done


AttributeError: partially initialized module 'fsspec' has no attribute 'utils' (most likely due to a circular import)

In [None]:
# Project Paths:
import os
# Mount Google Drive (temporary, Colab only).
from google.colab import drive
drive.mount('/content/drive')

# Data root on Drive — Colab-specific; the .py files in the repo use a different root.
project_root = "/content/drive/MyDrive/WoMAD/WoMAD/Notebooks/data"

# One directory per processing stage, all under the data root.
unprocessed_path, processed_path, model_ready_path = (
    os.path.join(project_root, stage_dir)
    for stage_dir in ("HCP_zipped", "processed", "model_ready")
)

# Text file listing the IDs of subjects who completed the full 3T task protocol.
sub_list_txt_path = os.path.join(project_root, "full_3T_task_subjects.txt")

# WoMAD-specific variables.
# Each task maps to the EV (condition) names used when isolating trials; the WM
# conditions are every combination of load (0-back / 2-back) and stimulus category.
target_subtasks = {
    "WM"      : [f"{load}_{stimulus}"
                 for load in ("0bk", "2bk")
                 for stimulus in ("body", "faces", "places", "tools")],
    "EMOTION" : ["fear", "neut"],
    "LANGUAGE": ["math", "story"],
}

# Target tasks, in the same order as the keys above.
target_tasks = list(target_subtasks)

# Repetition time; isolate_trials() divides EV onsets/durations by it to get volume
# indices (presumably seconds — the HCP 3T TR is 720 ms).
TR = 0.72

rest_tasks, run_direction = ["REST1", "REST2"], ["LR", "RL"]

# Subjects with the full 3T imaging protocol completed, read from sub_list_txt_path.
with open(sub_list_txt_path, "r") as file:
    raw_list = file.read()

# Treat commas as whitespace so that "100206, 100408", one-ID-per-line files,
# and mixtures of the two all parse; empty tokens are dropped by split().
str_list = raw_list.replace(",", " ").split()
num_list = [int(subID) for subID in str_list]

full_3T_task_subjects = num_list

# Model configuration dictionaries.
lstm_config = dict(hidden_size=128, num_layers=2, dropout=0.2)

# total_input_feats is temporary: 360 parcels + 128 LSTM features + 248 (a guess
# for the 4D network's output size).
fusion_config = dict(total_input_feats=360 + 128 + 248, hidden_size=128)

# Relative weights of the two training objectives.
training_loss_weights = dict(overall_loss_weight=0.5, node_loss_weight=0.5)

# Randomly selected subject list for the pilot.
pilot_subjects = [
    283543, 180937, 379657, 145632, 100206,
    270332, 707749, 454140, 194847, 185038,
]

# Temporary subject list used while developing.
dev_subjects = [100206, 100408]

Mounted at /content/drive


In [None]:
# Granger causality analysis (GCA) settings.
gca_config = dict(max_lag=1, significance_level=0.05, pairwise=True)

# Hidden Markov model (HMM) settings.
hmm_config = dict(n_states=3, n_iter=100, covariance_type="full")


### Data Module

In [None]:
# File Access:
def generate_paths(task : str = target_tasks[0],
                   run  : str = run_direction[0],
                   sub_list : list = dev_subjects,
                   base_dir : str = "../data/HCP_zipped"):
    """
    Builds the per-subject "Results" and EV directory paths for one task run.

    Arguments:
        task             (str): The target task from HCP -> [W]orking [M]emory ("WM") by default
        run              (str): "LR" or "RL" -> run_direction[0] ("LR") by default
        sub_list (list of int): Target subject IDs (dev_subjects by default)
        base_dir         (str): Directory that contains one folder per subject.
                                Defaults to the repo-relative "../data/HCP_zipped"; in this
                                notebook pass unprocessed_path to use the configured data root.

    Returns:
        Dictionary mapping each subject ID to a tuple of path strings:
        paths = {
            subject_id : ("<base_dir>/<id>/MNINonLinear/Results/",
                          "<base_dir>/<id>/MNINonLinear/Results/tfMRI_<TASK>_<RUN>/EVs/")
        }
        Both paths end with "/" because callers append file names by string concatenation.
    """
    root = base_dir.rstrip("/")
    paths = {}

    for subject in sub_list:
        subject_path    = f"{root}/{subject}/MNINonLinear/Results/"
        subject_ev_path = f"{subject_path}tfMRI_{task}_{run}/EVs/"
        paths[subject]  = (subject_path, subject_ev_path)

    return paths


def load_data_from_path(task : str = target_tasks[0],
                        run  : str = run_direction[0],
                        subject : int = dev_subjects[0],
                        subtask : str = target_subtasks["WM"][0]):
    """
    Reads one subject's BOLD CIFTI time series and one EV file.

    Arguments:
        task    (str): The target task from HCP -> [W]orking [M]emory ("WM") by default
        run     (str): "LR" or "RL" -> run_direction[0] ("LR") by default
        subject (int): ID of the target subject
        subtask (str): The target subtask / EV name -> Example: "0bk_tools"

    Returns:
        (fmri_timeseries, ev_file) on success, where
            fmri_timeseries = {subject: (bold_img, bold_header, bold_data)}
            ev_file         = raw text of "<subtask>.txt"
        None if anything fails (path lookup, file read, or nibabel load); the
        error type and message are printed.
    """
    try:
        # Build paths for this subject only: relying on generate_paths' default
        # sub_list (dev_subjects) raises KeyError for any other subject ID.
        paths = generate_paths(task = task, run = run, sub_list = [subject])
        results_dir, ev_dir = paths[subject]

        bold_ts_path = results_dir + f"tfMRI_{task}_{run}/tfMRI_{task}_{run}_Atlas_MSMAll_hp0_clean_rclean_tclean.dtseries.nii"
        ev_file_path = ev_dir + f"{subtask}.txt"

        with open(ev_file_path, "r") as ev:
            ev_file = ev.read()

        bold_ts = nib.load(bold_ts_path)
        bold_data = bold_ts.get_fdata()
        bold_header = bold_ts.header

    except Exception as e:
        # Best-effort loader: report the cause instead of hiding it, and return
        # None explicitly so callers can skip this subject/subtask.
        print(f"Error loading time series and EV files from path! ({type(e).__name__}: {e})")
        return None

    fmri_timeseries = {subject: (bold_ts, bold_header, bold_data)}

    return fmri_timeseries, ev_file


# Preprocessing:
## Parse and Isolate Trials
def isolate_trials(fmri_ts, ev_file, TR : float = 0.72, time_axis : int = 1):
    """
    Parses through the data and isolates each task trial using an EV file.

    Arguments:
        fmri_ts   (dict): {subject: (bold_img, bold_header, bold_data)}, as returned
                          by load_data_from_path(); only bold_data is used.
        ev_file    (str): Raw EV text with one trial per row; the first column is the
                          onset and the second the duration (further columns are ignored).
        TR       (float): Repetition time used to convert onset/duration to volume indices.
        time_axis  (int): Axis of bold_data that indexes time. Default 1, i.e. data
                          shaped (nodes, timepoints).
                          NOTE(review): nibabel usually loads CIFTI dtseries as
                          (time, grayordinates); confirm the layout and pass time_axis=0
                          if so.

    Returns:
        List of {"subject", "onset", "duration", "data"} dicts, one per EV row, where
        "data" is bold_data restricted to volumes floor(onset/TR) .. ceil((onset+duration)/TR) - 1.
    """
    trial_list = []

    for subject, (_, _, bold_data) in fmri_ts.items():
        try:
            # ndmin=2 keeps a single-row EV file as a (1, n_columns) array. Without it,
            # the row collapses to 1-D and unpacking each "row" fails.
            ev_data = np.loadtxt(io.StringIO(ev_file), ndmin = 2)
        except ValueError:
            print(f"Could not parse EV file for subject {subject}.")
            continue

        for row in ev_data:
            onset, duration = row[0], row[1]
            start_idx = int(np.floor(onset / TR))
            end_idx = int(np.ceil((onset + duration) / TR))

            # Slice only the time axis, keeping every other axis whole.
            window = [slice(None)] * np.ndim(bold_data)
            window[time_axis] = slice(start_idx, end_idx)
            trial_data = bold_data[tuple(window)]

            trial_list.append({
                "subject" : subject,
                "onset" : onset,
                "duration" : duration,
                "data" : trial_data
            })

    return trial_list

## Normalization
def normalize_data(data : np.ndarray, norm_mode: str = "z_score"):
    """
    Normalizes each row of a 2-D array of fMRI time series independently.

    Arguments:
        data (np.ndarray): Time series shaped (voxels/nodes, time_points); statistics
                           are taken along axis 1.
        norm_mode   (str): "z_score" -> (x - row mean) / row std
                           "min_max" -> (x - row min) / (row max - row min)
                           anything else -> a message is printed and the data is
                           returned unchanged (as a NumPy array).

    Returns:
        Numpy array with the same shape as the input.
    """
    arr = np.array(data)

    if norm_mode not in ("z_score", "min_max"):  # For now ...
        print(f"Normalization mode '{norm_mode}' not defined.\nReturning data as is.")
        return arr

    if norm_mode == "z_score":
        offset = arr.mean(axis = 1, keepdims = True)
        scale = arr.std(axis = 1, keepdims = True)
    else:
        offset = arr.min(axis = 1, keepdims = True)
        scale = arr.max(axis = 1, keepdims = True) - offset

    # Constant rows have zero spread; dividing by 1 leaves them merely shifted.
    scale[scale == 0] = 1.0

    return (arr - offset) / scale

## Save to Pandas DataFrame
def save_to_df(trial_list : List[Dict[str, Any]],
               file_name : str,
               output_dir : str = processed_path):
    """
    Builds a DataFrame from the trial dictionaries and pickles it to "<output_dir>/<file_name>.pkl".

    Arguments:
        trial_list (list): Trial dicts (e.g. "subject", "onset", "duration", "data"); each
                           key becomes a column.
        file_name   (str): Output file name without the ".pkl" extension.
        output_dir  (str): Destination directory; created if it does not exist.

    Returns:
        The DataFrame that was saved.
    """
    trials_frame = pd.DataFrame(trial_list)

    os.makedirs(output_dir, exist_ok = True)
    pickle_target = os.path.join(output_dir, f"{file_name}.pkl")
    trials_frame.to_pickle(pickle_target)

    print(f"Data saved to {pickle_target}")
    return trials_frame


# Initial Processing (with the WoMAD_data class):
class WoMAD_data(Dataset):
    """
    Trial-level dataset for one HCP task.

    self.data is a list of per-trial dictionaries. A typical order of calls, where each
    step adds keys that later steps read:
        _load_data()             -> "subject", "onset", "duration", "data", "task", "run", "subtask"
        basic_processing()       -> normalizes "data", adds "norm_mode", pickles a DataFrame
        calc_basic_stats()       -> "stats"        (required by __getitem__)
        calc_func_connectivity() -> "corr_matrix"  (required by visualize_correlations)
    """
    def __init__(self,
                 task : str,
                 runs : list = run_direction,
                 subjects : list = dev_subjects,
                 output_dir : str = processed_path):
        """
        Basic configuration of the dataset. No files are read here; call
        _load_data() to populate self.data.

        Arguments:
            task        (str): The target task (e.g. "WM"); trials are only loaded for
                               tasks that are keys of target_subtasks.
            runs       (list): Run directions
            subjects   (list): List of target subject IDs
            output_dir  (str): Directory for saving processed data
        """
        self.task = task
        self.runs = runs
        self.subjects = subjects
        self.output_dir = output_dir

        # One dictionary per isolated trial; filled by _load_data().
        self.data = []

    def __len__(self) -> int:
        """
        Returns the number of isolated trials currently in the dataset.
        """
        return len(self.data)

    def __getitem__(self, indx: int):
        """
        Returns one trial as three float tensors:
            - input timeseries (X)       : trial["data"]
            - overall target (Y_overall) : shape (1,), trial["stats"]["trial_mean"]
            - node target (Y_node)       : trial["stats"]["mean_per_node"]

        Requires calc_basic_stats() to have been run (reads trial["stats"]).
        """
        trial = self.data[indx]

        # Input data, expected shape (target_nodes, timepoints).
        data_np = trial["data"]
        data_tensor = torch.from_numpy(data_np).float()

        # Overall target:
        # TODO: Add overall target components.
        # Placeholder: the trial-wide mean activity.
        overall_target_np = trial["stats"]["trial_mean"]
        overall_target_tensor = torch.tensor([overall_target_np]).float()

        # Node target: per-node mean activity, shape (target_nodes,).
        node_target_np = trial["stats"]["mean_per_node"]
        node_target_tensor = torch.from_numpy(node_target_np).float()

        return data_tensor, overall_target_tensor, node_target_tensor

    def _load_data(self):
        """
        Loads and parses the BOLD (CIFTI) and EV files of every subject x run x subtask.

        Subject/run/subtask combinations whose files cannot be loaded are reported and
        skipped. Each resulting trial dict is tagged with "task", "run" and "subtask".

        Returns:
            The new self.data list.
        """
        parsed_data = []
        subtasks = target_subtasks.get(self.task, [])

        for subject in self.subjects:
            for run in self.runs:
                for subtask in subtasks:
                    # NOTE(review): load_data_from_path() re-reads the same BOLD file for
                    # every subtask of a run; caching it per (subject, run) would save I/O.
                    loaded = load_data_from_path(task = self.task,
                                                 run = run,
                                                 subject = subject,
                                                 subtask = subtask)

                    # load_data_from_path() returns None on failure; skip instead of
                    # crashing on tuple unpacking.
                    if loaded is None:
                        print(f"Skipping subject {subject}, run {run}, subtask {subtask}.")
                        continue

                    fmri_ts, ev_file = loaded
                    for trial_dict in isolate_trials(fmri_ts, ev_file):
                        trial_dict["task"] = self.task
                        trial_dict["run"] = run
                        trial_dict["subtask"] = subtask
                        parsed_data.append(trial_dict)

        self.data = parsed_data

        return self.data

    def basic_processing(self, norm_mode : str = "z_score",
                         file_name : str = "processed_fMRI_data"):
        """
        Normalizes every trial with normalize_data() and saves them with save_to_df().

        The trial dicts are updated in place ("data" replaced, "norm_mode" added).
        The pickle is written to "<output_dir>/<file_name>_<task>_<norm_mode>.pkl".

        Returns:
            The saved DataFrame (also stored as self.processed_df).
        """
        processed_trials = []

        for trial in self.data:
            normalized_trial = normalize_data(trial["data"], norm_mode = norm_mode)

            trial["data"] = normalized_trial
            trial["norm_mode"] = norm_mode
            processed_trials.append(trial)

        file_to_save_processed_data = f"{file_name}_{self.task}_{norm_mode}"
        self.processed_df = save_to_df(processed_trials,
                                       file_to_save_processed_data,
                                       self.output_dir)

        self.data = processed_trials

        return self.processed_df

    def _calc_corr_matrix(self, trial_data: np.ndarray) -> dict:
        """
        Calculates the whole-brain correlation matrix for a single isolated trial.

        np.corrcoef treats each row of trial_data as one variable, so rows are
        expected to be nodes/voxels. NaNs (e.g. from constant rows) are replaced by 0.

        NOTE:   Network-level correlations require network masks
                or specific voxel/parcel definitions.
                These have not yet been defined in the config files.

        Arguments:
            trial_data (np.ndarray): The isolated trial's time series data.

        Returns:
            {"whole_brain": matrix, "network_level": matrix}. The network-level entry
            is currently the same whole-brain matrix.
        """
        # All voxels (whole_brain)
        whole_brain_corr_mat = np.corrcoef(trial_data)

        whole_brain_corr_mat[np.isnan(whole_brain_corr_mat)] = 0

        # TODO: Create network-level correlation

        return {
            "whole_brain"  : whole_brain_corr_mat,
            "network_level": whole_brain_corr_mat   # Temp solution until the network-level is defined.
        }

    def calc_func_connectivity(self):
        """
        Calculates the correlation matrices for all isolated trials.

        Trials without a "data" entry are skipped, matching calc_basic_stats().

        Returns:
            The updated list of trial dictionaries with "corr_matrix" added.
        """
        for trial in self.data:
            trial_ts_data = trial.get("data") # Should be normalized

            if trial_ts_data is None:
                continue

            trial["corr_matrix"] = self._calc_corr_matrix(trial_ts_data)

        return self.data

    def calc_basic_stats(self):
        """
        Calculates per-node mean and std (along axis 1) for every isolated trial,
        plus their averages over nodes.

        Returns:
            Updated list of trial dictionaries with a "stats" key:
            {"mean_per_node", "std_per_node", "trial_mean", "overall_std"}.
        """
        for trial in self.data:
            trial_ts_data = trial.get("data")

            if trial_ts_data is None:
                continue

            # Mean activity
            node_mean_activity = np.mean(trial_ts_data, axis = 1)

            # Standard deviation
            node_std_activity = np.std(trial_ts_data, axis = 1)

            # Overall average for entire trial
            trial_mean_activity = np.mean(node_mean_activity)
            trial_mean_of_std   = np.mean(node_std_activity)

            trial["stats"] = {
                "mean_per_node": node_mean_activity,
                "std_per_node" : node_std_activity,
                "trial_mean"   : trial_mean_activity,
                "overall_std"  : trial_mean_of_std
            }

        return self.data

    def visualize_correlations(self,
                               target_trial_idx: int = 0,
                               matrix_type: str = "whole_brain"):
        """
        Generates a heatmap of one trial's correlation matrix.

        Arguments:
            target_trial_idx (int): Index of the trial in self.data
            matrix_type      (str): The type of matrix to plot
                                    "whole_brain" or "network_level"

        Raises:
            ValueError: if calc_func_connectivity() has not been run for the trial,
                        or matrix_type is not one of the stored matrices.
        """
        trial = self.data[target_trial_idx]
        corr_matrices = trial.get("corr_matrix")

        if corr_matrices is None:
            raise ValueError("No correlation matrices for this trial; run calc_func_connectivity() first.")

        matrix_to_plot = corr_matrices.get(matrix_type)

        if matrix_to_plot is None:
            raise ValueError(f"Unknown matrix_type '{matrix_type}'; expected one of {list(corr_matrices)}.")

        # NOTE(review): plt (matplotlib.pyplot) and sns (seaborn) are not imported anywhere
        # in this notebook; import them before calling this method.
        plt.figure(figsize = (10, 8))
        sns.heatmap(matrix_to_plot,
                    cmap = "RdBu_r",
                    vmin = -1, vmax = 1,
                    square = True,
                    cbar_kws = {"label" : "Pearson Correlation Coefficient"})

        subject_id = trial.get("subject", "N/A")

        task = trial.get("task", "N/A")
        subtask = trial.get("subtask", "N/A")

        plt.title(f"{matrix_type.title()} Correlation Matrix\nSubject: {subject_id}, {task}: {subtask}, Trial {target_trial_idx}")
        plt.xlabel("Node/Voxel Index")
        plt.ylabel("Node/Voxel Index")

        plt.show()

# TODO: Add function for "validation set processing" which can process non-HCP data.

NameError: name 'np' is not defined

### Model Setup Module

### The Dynamic Input module

In [None]:
# Dynamic Adapter
TARGET_NODE_COUNT = 360       # 360 parcels based on HCP and Glasser parcellation

class DynamicInput(nn.Module):
    """
    Resizes the node axis of (batch, nodes, timepoints) input to a fixed node count,
    so the downstream WoMAD modules (info flow, core) always see the same width.
    The layer has no learnable parameters.
    """
    def __init__(self, target_nodes: int = TARGET_NODE_COUNT):
        super().__init__()
        self.target_nodes = target_nodes

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input:
            x : 3D tensor with shape (batch, current_nodes, timepoints)

        Output:
            3D tensor with shape (batch, target_nodes, timepoints):
              - unchanged (same object) when current_nodes == target_nodes
              - adaptive average pooling over the node axis when there are more nodes
                (averages neighbouring node indices; not a learned projection)
              - zero rows appended along the node axis when there are fewer nodes
        """
        _, node_count, _ = x.shape
        surplus = node_count - self.target_nodes

        if surplus == 0:
            return x

        if surplus < 0:
            # Pad spec is (last-dim left, last-dim right, dim-1 left, dim-1 right):
            # time is untouched, zeros are added after the last node.
            return F.pad(x, (0, 0, 0, -surplus))

        # TODO: Add voxel-to-parcel downsampling (requires a parcellation map).
        # Pool over nodes: move them to the last axis, pool, then move them back.
        pooled = F.adaptive_avg_pool1d(x.transpose(1, 2), self.target_nodes)
        return pooled.transpose(1, 2)

#### Info flow submodule


In [None]:
#GCA
from scipy import stats
from typing import Tuple, Dict, Optional
import warnings


class GrangerCausalityAnalysis(nn.Module):
    """
    Information-flow helpers for the WoMAD model:
      - a learnable, bias-free node-to-node linear map applied one timestep ahead
        (a VAR(1)-style effective-connectivity layer, "GCA"), and
      - a Gaussian HMM state estimator (dynamic functional connectivity).

    Config keys (all optional):
        "target_nodes"                (int, default 360): number of nodes/parcels
        "lstm_config"["hidden_size"]  (int, default 128): stored as self.lstm_hidden
    """
    def __init__(self, config: dict):
        super().__init__()
        config = config or {}
        self.target_nodes = config.get("target_nodes", 360)
        self.lstm_hidden = (config.get("lstm_config") or {}).get("hidden_size", 128)

        # Directed weights between nodes. nn.Linear computes y = x @ W.T, so W[i, j]
        # is how much node j at time t contributes to node i at time t+1.
        self.gca_weights = nn.Linear(self.target_nodes, self.target_nodes, bias = False)

        # Temporal GNN components
        # Placeholder for torch_geometric layers
        # self.gnn_layer = GATConv(in_channels=..., out_channels=...)

    def _calculate_gca(self, x: torch.Tensor) -> torch.Tensor:
        """
        One-step linear prediction of every node at t+1 from all nodes at t.

        Input:  x with shape (batch, nodes, timepoints); nodes must equal self.target_nodes.
        Output: tensor (batch, nodes, timepoints - 1); column t is the prediction
                for timepoint t+1 made from timepoint t.

        NOTE: no regression is solved here; the weights are only fitted if a training
        loss compares this output with x[:, :, 1:].
        """
        past = x[:, :, :-1]   # All timepoints except the last
        predicted_future = self.gca_weights(past.transpose(1, 2))
        return predicted_future.transpose(1, 2)

    def _get_hmm_states(self, trial_data, n_states=3):
        """
        Identifies discrete states over time with a Gaussian HMM (full covariance,
        100 EM iterations).

        The HMM is fitted on CPU via hmmlearn and is not differentiable.

        Arguments:
            trial_data (torch.Tensor): shape (nodes, time)
            n_states           (int): number of hidden states

        Returns:
            np.ndarray of state labels, one per timepoint.
        """
        # Imported here so the method does not depend on a later notebook cell.
        from hmmlearn import hmm

        X = trial_data.cpu().detach().numpy().T   # (time, nodes) = (samples, features)
        model = hmm.GaussianHMM(n_components=n_states, covariance_type="full", n_iter=100)
        model.fit(X)
        return model.predict(X)

    def forward(self, input: torch.Tensor, module_selection: str):
        """
        Routes the input through the selected submodule.

        module_selection == "gca" -> _calculate_gca(input); any other value returns
        the input unchanged.
        """
        if module_selection == "gca":
            return self._calculate_gca(input)

        return input


In [None]:
#HMM
!pip install --quiet hmmlearn
from hmmlearn import hmm

from sklearn.decomposition import PCA

def run_hmm_stating(trial_data: np.ndarray,
                    n_states: int = 3,
                    n_iter: int = 100,
                    n_pca_components: Optional[int] = None) -> Dict:
    """
    Identifies hidden brain states in one trial with a Gaussian Hidden Markov Model.

    Args:
        trial_data (np.ndarray): fMRI data shaped (n_nodes, n_timepoints).
        n_states (int): Number of hidden states to detect.
        n_iter (int): Maximum number of EM iterations.
        n_pca_components (int | None): Number of PCA components to fit the HMM on.
            None -> min(n_timepoints - 1, n_nodes, 20) when that is smaller than n_nodes.
            PCA is skipped when the value is falsy or not smaller than n_nodes.

    Returns:
        Dict with 'state_sequence', 'state_means', 'transition_matrix',
        'n_states', 'pca_used' and 'n_features_used'.
    """
    # HMM expects (n_samples, n_features) = (timepoints, nodes)
    X = trial_data.T
    n_samples, n_features = X.shape

    # With more features than samples the per-state Gaussians are degenerate,
    # so reduce dimensionality first: keep parameters below the number of data points.
    if n_pca_components is None:
        max_components = min(n_samples - 1, n_features, 20)
        if n_features > max_components:
            n_pca_components = max_components

    pca_used = False
    if n_pca_components and n_pca_components < n_features:
        pca = PCA(n_components=n_pca_components, random_state=42)
        X = pca.fit_transform(X)
        pca_used = True

    # Diagonal covariance for efficiency. This is the only model that is fitted, so
    # n_iter and random_state actually take effect.
    model = hmm.GaussianHMM(
        n_components=n_states,
        covariance_type="diag",
        n_iter=n_iter,
        random_state=42
    )

    model.fit(X)

    state_sequence = model.predict(X)

    return {
        'state_sequence': state_sequence,
        'state_means': model.means_,
        'transition_matrix': model.transmat_,
        'n_states': n_states,
        'pca_used': pca_used,
        'n_features_used': X.shape[1]
    }

def compute_state_connectivity(trial_data: np.ndarray,
                               state_sequence: np.ndarray) -> Dict[int, np.ndarray]:
    """
    Builds one functional-connectivity (correlation) matrix per HMM state.

    Args:
        trial_data: fMRI data (n_nodes, n_timepoints); rows are treated as variables.
        state_sequence: HMM state label for each timepoint.

    Returns:
        Dict mapping int(state_id) -> (n_nodes, n_nodes) correlation matrix, in
        ascending state order. States seen at fewer than two timepoints are omitted,
        and NaN correlations (e.g. from constant rows) are set to 0.
    """
    connectivity_by_state = {}
    labels, occurrences = np.unique(state_sequence, return_counts=True)

    for label, occurrence in zip(labels, occurrences):
        # A correlation needs at least two samples.
        if occurrence < 2:
            continue

        in_state = state_sequence == label
        state_corr = np.corrcoef(trial_data[:, in_state])
        state_corr[np.isnan(state_corr)] = 0

        connectivity_by_state[int(label)] = state_corr

    return connectivity_by_state

In [None]:
#Topological Similarity & Graph Overlap Analysis

In [None]:
#Temporal GNN & Info Flow Integration


In [None]:
from torch_geometric.nn import GATConv

class WoMAD_info_flow(nn.Module):
    """
    Info-flow manifold: a multi-head graph-attention layer (PyG GATConv) followed by
    a linear layer that maps the concatenated heads back to hidden_size.

    Config keys (all optional):
        "target_nodes"                (int, default 360): number of parcels
        "lstm_config"["hidden_size"]  (int, default 128): size of the info-flow representation
    """
    def __init__(self, config: dict):
        super().__init__()
        config = config or {}
        self.target_nodes = config.get("target_nodes", 360)
        self.hidden_size = (config.get("lstm_config") or {}).get("hidden_size", 128)
        self.heads = 4

        # The GNN layer learns directed attention between parcels.
        # in_channels=1: one scalar (BOLD value) per node.
        # out_channels: per-head hidden representation of "info flow".
        self.gnn_manifold = GATConv(in_channels=1,
                                    out_channels=self.hidden_size,
                                    heads=self.heads,
                                    concat=True)

        # concat=True yields hidden_size * heads features; condense back to hidden_size.
        self.post_gnn = nn.Linear(self.hidden_size * self.heads, self.hidden_size)

    def forward(self, x, edge_index, edge_attr):
        """
        Passes BOLD values along the given directed edges, then applies Linear + ELU.

        x:          node features with one channel per node
        edge_index: directed connections (e.g. from GCA), shape (2, num_edges)
        edge_attr:  connection strengths, shape (num_edges, 1)

        NOTE(review): PyG's GATConv takes node features shaped (num_nodes, in_channels);
        batched graphs are normally flattened into one disjoint graph — confirm callers
        do this rather than passing (batch, nodes, 1).
        NOTE(review): GATConv presumably uses edge_attr only when built with edge_dim,
        which is not set here — confirm the edge strengths are not being ignored.
        """
        out = self.gnn_manifold(x, edge_index, edge_attr)
        out = F.elu(self.post_gnn(out))

        return out

#### Core submodule

In [None]:
# Core module
class WoMAD_core(nn.Module):
    def __init__(self, config: dict):
        """
        Sets up the WoMAD core model:
            DynamicInput -> segmentation placeholder -> (LSTM, 2-D conv) in parallel
            -> fusion -> overall score head and node-wise score head.

        Arguments:
            config (dict): Model configuration. All keys below are optional; the
                           defaults mirror this notebook's lstm_config / fusion_config.
                               "target_nodes"  : int, default 360
                               "lstm_config"   : {"hidden_size": 128, "num_layers": 2, "dropout": 0.2}
                               "fusion_config" : {"hidden_size": 128}
        """
        super().__init__()
        config = config or {}

        # Merge any user-supplied sub-configs over the defaults so partial dicts work.
        lstm_cfg = {"hidden_size": 128, "num_layers": 2, "dropout": 0.2}
        lstm_cfg.update(config.get("lstm_config") or {})
        fusion_cfg = {"hidden_size": 128}
        fusion_cfg.update(config.get("fusion_config") or {})

        target_nodes    = config.get("target_nodes", 360)     # 360, Temporary.
        lstm_h_size     = lstm_cfg["hidden_size"]             # 128
        fusion_h_size   = fusion_cfg["hidden_size"]           # 128
        conv4d_out_size = 64                                   # From simplified config

        # Dynamic Adapter
        self.dyn_input_adapter = DynamicInput(target_nodes = target_nodes)

        # Core Module
        ## Submodule A: 3D-UNet
        ## Input shape = (batch, target_nodes, timepoints)
        self.segment_and_label = nn.Sequential(
            nn.Conv1d(in_channels = target_nodes,
                      out_channels = target_nodes,
                      kernel_size = 3, padding = 1),
            nn.ReLU(),
            nn.Identity()          # FIX: Placeholder should be replaced with the 3D-UNet.
        )

        ## Parallel submodule B-1: LSTM (Temporal features); consumes (batch, time, nodes)
        self.temporal_lstm = nn.LSTM(input_size  = target_nodes,
                                     hidden_size = lstm_h_size,
                                     num_layers  = lstm_cfg["num_layers"],
                                     dropout     = lstm_cfg["dropout"],
                                     batch_first = True)

        ## Parallel submodule B-2: "ConvNet4D" (Spatiotemporal features).
        ## Currently a 2-D conv over the (nodes, timepoints) plane; needs nodes >= 2 and
        ## timepoints >= 2 because of the 2x2 max-pool.
        self.spatiotemporal_cnv4d = nn.Sequential(
            nn.Conv2d(in_channels  = 1,
                      out_channels = 32,
                      kernel_size  = (3, 3), padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = (2, 2)),
            nn.Conv2d(in_channels  = 32,
                      out_channels = conv4d_out_size,
                      kernel_size  = (3, 3), padding = 1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten()
        )

        ## Submodule C: Fusion Layer (LSTM last hidden state + conv features)
        fusion_input_size = lstm_h_size + conv4d_out_size      # 192 with the defaults
        self.fusion_block = nn.Sequential(
            nn.Linear(fusion_input_size, fusion_h_size),
            nn.ReLU()
        )

        ### Overall, WM-based activity score:
        self.overall_activity_score = nn.Linear(fusion_h_size, 1)

        ### Node-based (voxel-based or parcel-based) activity scores:
        self.node_wise_activity_scores = nn.Linear(fusion_h_size, target_nodes)

    def _prepare_4d_data(self, input: torch.Tensor) -> torch.Tensor:
        """
        Helper that shapes data for the spatiotemporal conv branch.

        Input:
            Tensor with shape (batch, target_nodes, timepoints).
            target_nodes is the flat spatial dimension.

        Output:
            4-D tensor with shape (batch, 1, target_nodes, timepoints): a single-channel
            "image" for nn.Conv2d. (Not yet a true 4-D (time, X, Y, Z) volume.)
        """
        batch, nodes, timepoints = input.shape   # Unpacking also enforces a 3-D input.
        # TODO: Create the mapping array to place nodes into a X*Y*Z grid.

        four_dim_data = input.unsqueeze(1)

        return four_dim_data

    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        The forward pass that manages how the data passes through modules.

        Sequence for forward pass:
            Input -> Segmentation (3D-UNet) -> (LSTM, Conv4D) -> Fusion

        Input:
            Tensor with shape (batch, current_nodes, timepoints)

        Returns:
            (overall_score, node_scores) with shapes (batch, 1) and (batch, target_nodes).
        """
        # Dynamic Input Adaption
        x_dynamically_adapted = self.dyn_input_adapter(input)

        # 3D-UNet
        unet_out_timeseries = self.segment_and_label(x_dynamically_adapted)

        # LSTM: final hidden state of the last layer
        x_for_lstm = unet_out_timeseries.transpose(1, 2)
        _, (h_n, _) = self.temporal_lstm(x_for_lstm)
        lstm_out = h_n[-1]

        # ConvNet4D
        x_for_conv4d = self._prepare_4d_data(unet_out_timeseries)
        conv4d_out = self.spatiotemporal_cnv4d(x_for_conv4d)

        # Fusion Layer
        fused_feats = torch.cat([lstm_out, conv4d_out], dim = 1)
        shared_features = self.fusion_block(fused_feats)

        overall_score = self.overall_activity_score(shared_features)
        node_scores   = self.node_wise_activity_scores(shared_features)

        return overall_score, node_scores


def model_config(config: dict) -> WoMAD_core:
    """
    Builds a WoMAD_core model and moves it to the GPU when that is requested and possible.

    Argument:
        config (dict): WoMAD config dictionary; config["system"]["use_gpu"] must exist.

    Returns:
        WoMAD_core: Model ready to be trained.
    """
    womad = WoMAD_core(config)

    gpu_requested = config["system"]["use_gpu"]
    if gpu_requested and torch.cuda.is_available():
        womad.cuda()

    return womad

### Hyperparameter Module - NOT FINALIZED

In [None]:
def define_search_space(trial: "optuna.Trial") -> Dict[str, Any]:
    """
    Defines the search space for hyperparameter optimization.

    NOT FINALIZED: every value below is still a 0 placeholder and `trial` is not
    queried yet. TODO: replace the placeholders with trial.suggest_float /
    suggest_int / suggest_categorical calls once the search ranges are decided.

    Arguments:
        trial (optuna.Trial): The trial object that suggests parameters

    Returns:
        Dict[str, Any]: {"learning_rate", "batch_size", "epochs", "hidden_layers",
        "dropout_rate"} -> suggested value.
    """
    # Main parameters (Learning rate, batch size, number of epochs)
    learning_rate = 0
    batch_size = 0
    epochs = 0
    # Model-specific parameters (Hidden layers, dropout rates)
    hidden_layers = 0
    dropout_rate = 0

    suggested_parameters = {
        "learning_rate": learning_rate,
        "batch_size"   : batch_size,
        "epochs"       : epochs,
        "hidden_layers": hidden_layers,
        "dropout_rate" : dropout_rate
    }

    # objective() feeds this into config["training"].update(...), so it must be returned.
    return suggested_parameters

def objective(trial: optuna.Trial) -> float:
    """
    Optuna objective: evaluates one hyperparameter configuration and returns its
    final validation metric.

    TO DO: Define the objective function for Optuna.
    NOTE(review): `config` and `run_pipeline` are not defined anywhere in this
    notebook (the config loader call is commented out), so this currently raises
    NameError when called.
    """
    suggested = define_search_space(trial)

    # config = WoMAD_config.load_config()
    training_section = config["training"]
    training_section.update(suggested)

    return run_pipeline(config)

def run_hyperparameter_optim():
    """
    Entry point for the hyperparameter search — currently a stub that does nothing
    and returns None.

    TO DO: Define the main hyperparameter search function:
      - define the optimization target (min loss, min MSE, etc.),
      - print and save the results (best trial, best parameters, best target metric),
      - save them to a file as well.
    """
    return None


if __name__ == "__main__":
    run_hyperparameter_optim()

### Model Train Module

In [None]:
import numpy as np
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset

from sklearn.model_selection import KFold

from . import WoMAD_config

from .model_setup_module import DynamicInput, WoMAD_core
from .model_valid_module import run_valid_epoch

def WoMAD_optimizer(model: nn.Module, config: dict) -> torch.optim.Optimizer:
    """
    Configure the Adam optimizer for a WoMAD model.

    The learning rate is taken from ``config["training"]["learning_rate"]``
    when that key is present. This is the section the hyperparameter search
    writes to. Otherwise the rate falls back to
    ``WoMAD_config.training_config["learning_rate"]``.

    Arguments:
        model (nn.Module): Model whose parameters will be optimized.
        config     (dict): Configuration dictionary.

    Returns:
        torch.optim.Optimizer: An Adam optimizer over ``model.parameters()``.
    """
    training_cfg = (config or {}).get("training") or {}
    lr = training_cfg.get("learning_rate")
    if lr is None:
        # Previously the only source of the learning rate; kept as a fallback
        # so configs without a "training" section behave as before.
        lr = WoMAD_config.training_config["learning_rate"]
    return torch.optim.Adam(model.parameters(), lr = lr)

def WoMAD_loss_function(config: dict) -> Dict[str, nn.Module]:
    """
    Build the loss functions used for WoMAD training and validation.

    Arguments:
        config (dict): Configuration dictionary (currently unused).

    Returns:
        Dict[str, nn.Module]: Mean-squared-error losses for the overall score
        ("overall_score_loss") and the node-wise scores ("node_score_loss").
    """
    loss_names = ("overall_score_loss", "node_score_loss")
    # One independent MSELoss instance per output head.
    return {name: nn.MSELoss() for name in loss_names}

def run_training_epoch(model: WoMAD_core, data_loader: DataLoader,
                       optimizer: torch.optim.Optimizer,
                       loss_funcs: Dict[str, nn.Module],
                       epoch: int, config: dict):
    """
    Run a single training epoch and return the per-sample average loss.

    Each batch is unpacked as ``(data, overall_target, node_target)``. The
    optimized loss is
    ``overall_loss_weight * overall_loss + node_loss_weight * node_loss``,
    with both weights read from ``WoMAD_config.training_loss_weights``.

    Arguments:
        model       (WoMAD_core): Model returning ``(overall_pred, node_pred)``.
        data_loader (DataLoader): Training batches.
        optimizer   (Optimizer) : Optimizer that steps the model parameters.
        loss_funcs  (dict)      : Must contain "overall_score_loss" and "node_score_loss".
        epoch       (int)       : Zero-based epoch index (used for logging only).
        config      (dict)      : Configuration dictionary (currently unused).

    Returns:
        float: Combined loss averaged over all samples in the epoch.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    model.train()
    total_train_loss = 0.0
    total_samples = 0

    overall_weight = WoMAD_config.training_loss_weights["overall_loss_weight"]
    node_weight    = WoMAD_config.training_loss_weights["node_loss_weight"]

    overall_loss_fn = loss_funcs["overall_score_loss"]
    node_loss_fn    = loss_funcs["node_score_loss"]

    # Send batches to wherever the model's parameters live (the model may have
    # been moved to the GPU during setup).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    for data, overall_target, node_target in data_loader:
        if device is not None:
            data = data.to(device)
            overall_target = overall_target.to(device)
            node_target = node_target.to(device)
        overall_target = overall_target.float()
        node_target    = node_target.float()

        optimizer.zero_grad()

        # Forward pass returns (overall prediction, node-wise prediction)
        overall_pred, node_pred = model(data)

        # Match the target's shape exactly. A bare `.squeeze()` turns a batch of
        # one into a 0-d tensor and would broadcast to (B, B) against (B, 1) targets.
        loss_overall = overall_loss_fn(overall_pred.reshape(overall_target.shape), overall_target)
        loss_nodes   = node_loss_fn(node_pred, node_target)

        combined_loss = (overall_weight * loss_overall) + (node_weight * loss_nodes)

        # Backpropagate
        combined_loss.backward()
        optimizer.step()

        # Weight each batch's mean loss by its size for a per-sample average.
        batch_size = data.size(0)
        total_train_loss += combined_loss.item() * batch_size
        total_samples += batch_size

    if total_samples == 0:
        raise ValueError("Training data loader yielded no samples.")

    avg_loss = total_train_loss / total_samples
    print(f"Epoch {epoch+1:02d} | Training Loss: {avg_loss: .6f}")
    return avg_loss

def run_kfold_training(dataset, config: dict):
    """
    Execute K-fold cross-validation training.

    For each fold, a fresh ``WoMAD_core`` model and optimizer are built. The
    model is trained for ``num_epochs`` epochs on the training split and
    evaluated on the held-out split after every epoch. ``k_folds``,
    ``num_epochs`` and ``batch_size`` come from
    ``WoMAD_config.training_config``. Folds are shuffled with a fixed seed (42)
    for reproducibility.

    Arguments:
        dataset (Dataset): The WoMAD data which contains all target subject data.
        config     (dict): Configuration dictionary.

    Returns:
        List[dict]: One entry per fold of the form
        ``{"fold": <1-based index>, "history": {"train_loss": [...],
        "valid_loss": [...], "val_metrics": [...]}}`` with one list item per epoch.
    """
    k_folds = WoMAD_config.training_config["k_folds"]
    num_epochs = WoMAD_config.training_config["num_epochs"]
    batch_size = WoMAD_config.training_config["batch_size"]

    kfold = KFold(n_splits = k_folds, shuffle = True, random_state = 42)
    all_kfold_train_stats = []

    loss_funcs = WoMAD_loss_function(config)

    print(f"K-fold cross-validation for {k_folds} folds over {len(dataset)} trials:")

    for fold, (train_indx, valid_indx) in enumerate(kfold.split(dataset)):
        print(f"\nFold {fold + 1}/{k_folds}:")

        train_subset = Subset(dataset, train_indx)
        valid_subset = Subset(dataset, valid_indx)

        train_loader = DataLoader(train_subset, batch_size = batch_size, shuffle = True)
        valid_loader = DataLoader(valid_subset, batch_size = batch_size, shuffle = False)

        print(f"Train samples: {len(train_subset)}, Validation samples: {len(valid_subset)}")

        # Model initiation and setup: a fresh model per fold so folds don't leak.
        model = WoMAD_core(config)
        # TODO: Add the device logic (model.cuda())
        optimizer = WoMAD_optimizer(model, config)

        fold_history = {"train_loss"  : [],
                        "valid_loss"  : [],
                        "val_metrics" : []}

        for epoch in range(num_epochs):
            train_loss = run_training_epoch(model, train_loader, optimizer, loss_funcs, epoch, config)
            fold_history["train_loss"].append(train_loss)

            # Was `val_loader`, an undefined name (NameError on the first epoch).
            valid_loss, val_metrics = run_valid_epoch(model, valid_loader, loss_funcs, epoch, config)

            fold_history["valid_loss"].append(valid_loss)
            fold_history["val_metrics"].append(val_metrics)

        all_kfold_train_stats.append({"fold": fold + 1, "history": fold_history})

    # Report completion once, after every fold has run (previously printed per fold).
    print("\nK-fold training complete.")

    return all_kfold_train_stats

ImportError: attempted relative import with no known parent package

### Model Valid Module

In [None]:
import torch
import torch.nn as nn

from torch.utils.data import DataLoader

from typing import Dict, Tuple, Any

from . import WoMAD_config

def calc_graph_overlap():
    """
    Placeholder for the Info-Flow graph-overlap metric (not implemented yet).

    Planned approach:
        1. Build a graph network from the info-flow output.
        2. Normalize the nodes and edge weights.
        3. Measure overlap two ways:
             - edge-wise percentage overlap after thresholding (Jaccard index);
             - weighted correlation of the vectorized edge weights (Pearson's r).
        4. Compare topological similarity using key topological metrics.

    Returns:
        None until implemented.
    """
    return None

def calc_dice_coeff():
    """
    Placeholder for the Dice coefficient metric (not implemented yet).

    Returns:
        int: Always 0 for now.
    """
    return 0

def calc_r_sqrd():
    """
    Placeholder for the coefficient of determination (R^2) metric.

    Not implemented yet.

    Returns:
        int: Always 0 for now.
    """
    return 0

def calc_all_metrics(overall_pred=None, node_pred=None,
                     overall_target=None, node_target=None):
    """
    Calculate all validation metrics (Dice, MSE, R^2, overall score).

    NOTE: This is for the information flow module.

    Validation code calls this with four tensors, so the signature accepts
    them; every argument is optional so a zero-argument call still works.
    Dice and R^2 are still placeholders provided by `calc_dice_coeff` and
    `calc_r_sqrd`.

    Arguments:
        overall_pred   (Tensor, optional): Predicted overall scores.
        node_pred      (Tensor, optional): Predicted node-wise scores (not used yet).
        overall_target (Tensor, optional): Target overall scores.
        node_target    (Tensor, optional): Target node-wise scores (not used yet).

    Returns:
        dict: Keys "Dice_coefficient", "Mean_Sqrd_Error", "R_squared" and
        "Overall_score". "Mean_Sqrd_Error" is the MSE between the flattened
        overall predictions and targets, or NaN when either is not supplied.

    Raises:
        ValueError: If the overall predictions and targets differ in size.
    """
    dice_score = calc_dice_coeff()
    r_sqrd = calc_r_sqrd()

    # `calc_mse` was never defined (NameError); compute the MSE directly.
    # NOTE(review): MSE is reported for the overall score only; node-wise MSE
    # is not included yet — confirm this is the intended metric.
    if overall_pred is None or overall_target is None:
        mean_sqrd_err = float("nan")
    else:
        pred = torch.as_tensor(overall_pred, dtype=torch.float32).detach().cpu().reshape(-1)
        target = torch.as_tensor(overall_target, dtype=torch.float32).detach().cpu().reshape(-1)
        if pred.numel() != target.numel():
            raise ValueError(
                f"overall_pred has {pred.numel()} elements but overall_target has {target.numel()}."
            )
        mean_sqrd_err = torch.mean((pred - target) ** 2).item()

    overall_score = (dice_score + r_sqrd) / 2

    metrics_dict = {
        "Dice_coefficient": dice_score,
        "Mean_Sqrd_Error" : mean_sqrd_err,
        "R_squared"       : r_sqrd,
        "Overall_score"   : overall_score
    }

    return metrics_dict

def run_valid_epoch(model, data_loader: DataLoader,
                    loss_funcs: Dict[str, nn.Module],
                    epoch: int, config: dict
                    ) -> Tuple[float, Dict[str, float]]:
    """
    Run one validation epoch and calculate the loss and metrics.

    Each batch is unpacked as ``(data, overall_target, node_target)``. The loss
    uses the same weighting as training:
    ``overall_loss_weight * overall_loss + node_loss_weight * node_loss``,
    with both weights read from ``WoMAD_config.training_loss_weights``.

    Arguments:
        model       (nn.Module) : Model returning ``(overall_pred, node_pred)``.
        data_loader (DataLoader): Validation batches.
        loss_funcs  (dict)      : Must contain "overall_score_loss" and "node_score_loss".
        epoch       (int)       : Zero-based epoch index (used for logging only).
        config      (dict)      : Configuration dictionary (currently unused).

    Returns:
        Tuple[float, dict]: Combined loss averaged over all samples, and the
        metrics dictionary returned by ``calc_all_metrics``.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    model.eval()
    total_val_loss = 0.0
    total_samples = 0

    all_overall_pred = []
    all_overall_target = []
    all_node_pred = []
    all_node_target = []

    overall_weight = WoMAD_config.training_loss_weights["overall_loss_weight"]
    node_weight    = WoMAD_config.training_loss_weights["node_loss_weight"]

    overall_loss_fn = loss_funcs["overall_score_loss"]
    node_loss_fn    = loss_funcs["node_score_loss"]

    # Send batches to wherever the model's parameters live (the model may be on GPU).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    with torch.no_grad():
        for data, overall_target, node_target in data_loader:
            if device is not None:
                data = data.to(device)
                overall_target = overall_target.to(device)
                node_target = node_target.to(device)
            overall_target = overall_target.float()
            node_target = node_target.float()

            overall_pred, node_pred = model(data)
            # Align with the target shape instead of `.squeeze()`, which collapses
            # a batch of one to a 0-d tensor and can broadcast against (B, 1) targets.
            overall_pred = overall_pred.reshape(overall_target.shape)

            loss_overall = overall_loss_fn(overall_pred, overall_target)
            loss_node    = node_loss_fn(node_pred, node_target)
            combined_loss = (overall_weight  * loss_overall) + (node_weight * loss_node)

            # Weight each batch's mean loss by its size for a per-sample average.
            batch_size = data.size(0)
            total_val_loss += combined_loss.item() * batch_size
            total_samples  += batch_size

            all_overall_pred.append(overall_pred)
            all_node_pred.append(node_pred)

            all_overall_target.append(overall_target)
            all_node_target.append(node_target)

    if total_samples == 0:
        raise ValueError("Validation data loader yielded no samples.")

    avg_loss = total_val_loss / total_samples

    print(f"Epoch {epoch+1:02d} | Validation Loss: {avg_loss: .6f}")

    final_overall_pred   = torch.cat(all_overall_pred)
    final_overall_target = torch.cat(all_overall_target)

    final_node_pred   = torch.cat(all_node_pred)
    final_node_target = torch.cat(all_node_target)

    metrics = calc_all_metrics(final_overall_pred,
                               final_node_pred,
                               final_overall_target,
                               final_node_target)

    print(f"Validation metrics: {metrics}")

    return avg_loss, metrics

### Result Interpretation Module

In [None]:
import numpy as np
import shap
import torch
from typing import Dict, Any

from . import WoMAD_config

def predict():
    """
    Use the trained model for inference (not implemented yet).

    Raises:
        NotImplementedError: Always, until inference is implemented. The
        previous body returned an undefined name (``prediction``), which
        raised an unhelpful NameError instead.
    """
    raise NotImplementedError("predict() is not implemented yet.")

def visualize_and_interpret(model: "WoMAD_core",
                            data_loader: "DataLoader",
                            config: dict):
    """
    Generate figures and saliency maps for interpreting the results.

    Not implemented yet. The function currently only switches ``model`` to
    evaluation mode and returns 0. The planned steps are listed as TODOs
    below.

    Arguments:
        model       (WoMAD_core): Trained model.
        data_loader (DataLoader): Data to visualize/interpret (not used yet).
        config      (dict)      : Configuration dictionary (not used yet).

    Returns:
        int: 0 (placeholder status value).
    """
    # Annotations are strings because this cell imports neither WoMAD_core nor
    # DataLoader; bare names would raise NameError when the function is defined.
    model.eval()

    # TODO: Visualize output (predicted vs. actual score)

    # TODO: SHAP analysis based on timeseries and final fused outputs

    # TODO: Save visuals
    return 0

def _collect_valid_predictions(model, data_loader):
    """Run ``model`` over ``data_loader`` without gradients; return CPU predictions and float targets."""
    # Send inputs to wherever the model's parameters live.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    preds   = {"overall": [], "node": []}
    targets = {"overall": [], "node": []}

    with torch.no_grad():
        for data, overall_target, node_target in data_loader:
            if device is not None:
                data = data.to(device)
            overall_pred, node_pred = model(data)

            overall_target = overall_target.float()
            node_target = node_target.float()

            # Align overall predictions with the target shape before collecting.
            preds["overall"].append(overall_pred.reshape(overall_target.shape).cpu())
            preds["node"].append(node_pred.cpu())
            targets["overall"].append(overall_target.cpu())
            targets["node"].append(node_target.cpu())

    if not preds["overall"]:
        raise ValueError("Validation data loader yielded no batches.")

    return ({key: torch.cat(chunks) for key, chunks in preds.items()},
            {key: torch.cat(chunks) for key, chunks in targets.items()})


def run_pipeline_with_valid_dataset(model: "WoMAD_core",
                                    valid_loader: "DataLoader",
                                    loss_funcs: "Dict[str, nn.Module]",
                                    config: dict):
    """
    Run full validation and analysis after training.

    Collects predictions over ``valid_loader``, computes metrics with
    ``calc_all_metrics``, prints them and calls ``visualize_and_interpret``.

    Arguments:
        model        (WoMAD_core): Trained model returning ``(overall_pred, node_pred)``.
        valid_loader (DataLoader): Batches of ``(data, overall_target, node_target)``.
        loss_funcs   (dict)      : Loss functions (currently unused; kept for interface compatibility).
        config       (dict)      : Configuration dictionary.

    Returns:
        dict: The metrics dictionary returned by ``calc_all_metrics``.

    Raises:
        ValueError: If ``valid_loader`` yields no batches.
    """
    model.eval()

    # Validation loop (previously called the undefined placeholder `which_function_is_this`).
    all_preds, all_targets = _collect_valid_predictions(model, valid_loader)

    # Metrics
    metrics = calc_all_metrics(all_preds["overall"],
                               all_preds["node"],
                               all_targets["overall"],
                               all_targets["node"])
    print(f"Final validation metrics: {metrics}")

    # Visuals
    visualize_and_interpret(model, valid_loader, config)

    print("Validation complete.")

    # TODO: Save final metrics
    # print(f"Outputs saved to: {output_dir}")

    return metrics

## TESTS

In [None]:
# TODO: Create dummy data OR sample a tiny subset of the actual dataset
# TODO: Create tests for each and every single function or method.

import unittest


def _not_written_yet(area: str) -> None:
    """Mark a placeholder test as skipped (pytest treats ``unittest.SkipTest`` as a skip)."""
    raise unittest.SkipTest(f"TODO: tests for the {area} are not written yet.")


# TESTS:
# Each placeholder previously had no body (a SyntaxError). They now report as
# skipped rather than silently passing.
def test_data_module():
    """TODO: Test the data module."""
    _not_written_yet("data module")

def test_model_setup():
    """TODO: Test the model setup module."""
    _not_written_yet("model setup module")

def test_model_training():
    """TODO: Test the model training module."""
    _not_written_yet("model training module")

def test_model_validation():
    """TODO: Test the model validation module."""
    _not_written_yet("model validation module")

def test_result_interpretation():
    """TODO: Test the result interpretation module."""
    _not_written_yet("result interpretation module")

## WoMAD Main

In [None]:
# Terminal functions and UI
## print status, success, or error
## clear screen
## Welcome/Completion

def run_WoMAD(config):
    """
    Run the complete WoMAD pipeline end to end (not implemented yet).

    Planned stages:
        1. Environment setup.
        2. Data loading and initial processing.
        3. Model setup.
        4. Training (and hyperparameter search).
        5. Post-training analysis, visualization and interpretation.

    Arguments:
        config (dict): Configuration dictionary.

    Raises:
        NotImplementedError: Always, until the stages above are implemented.
    """
    # The previous body contained only comments, which is a SyntaxError.
    raise NotImplementedError("run_WoMAD is not implemented yet.")

if __name__ == "__main__":
    # NOTE(review): `config` must already exist as a global here; the loader
    # call below is still commented out.
    # config = WoMAD_config.load_config()
    run_WoMAD(config)