In [6]:
# Random Forest (RF) for Node-Level classification - for left and right hemisphere nodes
import os
from typing import Any, Union
import numpy as np
import pandas as pd
import random
from tqdm import tqdm
from scipy.io import loadmat
from sklearn.neural_network import MLPClassifier as MLP
from sklearn.ensemble import RandomForestClassifier as RF
from sklearn.svm import SVC as SVM

from sklearn.metrics import balanced_accuracy_score
from matplotlib import pyplot as plt


import warnings
warnings.filterwarnings("ignore")

In [7]:
def calculate_metrics(y_pred, y_true):
    """Balanced accuracy of predictions `y_pred` against ground truth `y_true`.

    Note the argument order here is (predictions, ground truth); the values are
    handed to sklearn's ``balanced_accuracy_score`` by keyword so the swap to
    sklearn's (y_true, y_pred) convention is explicit.
    """
    score = balanced_accuracy_score(y_true=y_true, y_pred=y_pred)
    return score

In [8]:
# Root folder of the per-node data for one hemisphere.
# Set HEMISPHERE to 'Left_Hemis' to process the left hemisphere instead.
# Alternative root locations used previously:
# root='/home/neil/Lab_work/Jeong_Lab_Multi_Modal_MRI/Left_Hemis/Part_2/'
# root='/home/neil/Lab_work/Jeong_Lab_Multi_Modal_MRI/Right_Hemis/Part_2/'
DATA_ROOT = '/media/user1/MyHDataStor41/Soumyanil_EZ_Pred_project/Data/All_Hemispheres/'
HEMISPHERE = 'Right_Hemis'
# Trailing '' keeps the trailing path separator the original literal had.
root = os.path.join(DATA_ROOT, HEMISPHERE, 'Part_2', '')

In [9]:
# Left/Right Hemis Nodes
# Full node list. It was previously kept as a commented-out function whose second
# line had lost its '#', leaving a stray tuple expression executed at import time.
# (The original variable name suggests these are the nodes with SMOTE-augmented data.)
ALL_NODE_NUMBERS_WITH_SMOTE = (
    "504", "506", "508", "509", "510", "511", "512", "513", "514", "515", "516", "517", "518", "519", "520", "521", "522", "524", "525", "526", "529", "530", "534", "535", "536", "537", "538", "539", "540", "541", "542", "543", "546", "547", "548", "549", "551", "552", "553", "554", "555", "556", "557", "558", "559", "560", "561", "562", "563", "564", "565", "566", "567", "568", "569", "570", "571", "572", "573", "574", "575", "576", "581", "582", "584", "585", "586", "587", "588", "589", "590", "591", "592", "593", "594", "595", "596", "598", "599", "600", "601", "602", "603", "604", "605", "606", "607", "608", "609", "610", "612", "613", "614", "615", "616", "617", "618", "619", "620", "621", "622", "623", "624", "625", "627", "628", "629", "630", "632", "633", "634", "635", "636", "637", "638", "639", "640", "641", "642", "643", "644", "645", "646", "647", "648", "649", "650", "651", "652", "655", "656", "657", "658", "659", "660", "661", "662", "663", "664", "665", "666", "668", "669", "670", "671", "672", "673", "674", "675", "676", "677", "678", "681", "683", "685", "686", "690", "691", "692", "693", "694", "695", "696", "697", "698", "699", "700", "701", "702", "703", "704", "705", "706", "707", "708", "709", "710", "711", "712", "713", "714", "715", "716", "717", "718", "719", "720", "721", "722", "723", "724", "725", "726", "727", "728", "730", "731", "732", "733", "735", "736", "737", "738", "739", "740", "741", "742", "743", "744", "745", "746", "747", "748", "749", "750", "751", "756", "757", "758", "759", "760", "761", "762", "763", "764", "765", "766", "767", "770", "771", "776", "777", "778", "779", "780", "781", "782", "783", "784", "785", "786", "787", "788", "789", "790", "791", "792", "793", "795", "796", "797", "798", "799", "800", "801", "802", "803", "804", "805", "806", "808", "809", "810", "811", "812", "813", "816", "817", "818", "819", "820", "821", "822", "823", "824", "825", "826", "827", "828", "829", "830", "831", "832", "834",
    "835", "836", "837", "838", "839", "841", "842", "843", "844", "845", "846", "847", "848", "849", "850", "851", "852", "853", "854", "855", "856", "857", "858", "859", "860", "861", "862", "863", "864", "865", "866", "867", "868", "869", "870", "871", "872", "873", "874", "875", "877", "878", "879", "880", "881", "882", "883", "885", "886", "887", "888", "889", "890", "891", "892", "893", "894", "895", "896", "897", "898", "899", "900", "901", "902", "903", "904", "905", "906", "907", "908", "909", "910", "911", "912", "913", "914", "915", "916", "917", "918", "919", "920", "921", "922", "923", "924", "925", "926", "927", "928", "929", "930", "931", "932", "933", "934", "935", "937", "938", "939", "940", "941", "942", "943", "944", "945", "946", "947", "948", "949", "950", "951", "952", "953", "954", "955", "956", "957", "958", "960", "961", "962", "963", "964", "965", "968", "969", "970", "971", "973", "974", "975", "976", "977", "978", "979", "980", "981", "982", "983",
)

# Small subset of the full list used for quick runs.
DEBUG_NODE_NUMBERS = ("504", "506", "508", "509", "510")


def get_list_of_node_nums(full: bool = False):
    """Return the node numbers (as strings) to process.

    Args:
        full: if False (default) return the small debug subset
            ["504", "506", "508", "509", "510"]; if True return every node in
            ALL_NODE_NUMBERS_WITH_SMOTE.

    Returns:
        A new list on every call, so callers may modify it freely.
    """
    return list(ALL_NODE_NUMBERS_WITH_SMOTE if full else DEBUG_NODE_NUMBERS)

In [10]:
# Nodes processed in this run, and how many there are.
right_hemis_nodes = get_list_of_node_nums()
print(f"{len(right_hemis_nodes)}")

5


In [11]:
# Column layout of the concatenated feature matrix X, in storage order:
# (modality index, modality name, column slice).
# In the combination code `num`, modality index i is selected by bit 2 ** (4 - i),
# i.e. T1=16, T2=8, FLAIR=4, DWI=2, DWIC=1.
_MODALITY_LAYOUT = (
    (0, "T1", slice(None, 300)),
    (1, "T2", slice(300, 500)),
    (2, "FLAIR", slice(500, 700)),
    (3, "DWI", slice(700, 1400)),
    (4, "DWIC", slice(1400, None)),
)


def get_modality(num: int, X: np.ndarray, fill_zeros: bool = False) -> "tuple[np.ndarray, dict[int, str]]":
    """Select the feature columns for one modality combination.

    Args:
        num: combination code in 1..31. It is a 5-bit mask with
            T1=16, T2=8, FLAIR=4, DWI=2, DWIC=1, e.g. 31 = all five and
            13 = T2 + FLAIR + DWIC.
        X: feature matrix whose columns are laid out as in _MODALITY_LAYOUT
            (T1 [:300], T2 [300:500], FLAIR [500:700], DWI [700:1400], DWIC [1400:]).
        fill_zeros: if True, right-pad the selection with zero columns so the
            result has the same number of columns as X.

    Returns:
        (X_modality, dict_mod). X_modality holds the selected column blocks
        concatenated in T1..DWIC order. dict_mod maps modality index to name
        for the selected modalities, in ascending index order.

    Raises:
        ValueError: if num is not in 1..31.
    """
    if num not in range(1, 32):
        raise ValueError(f"num should be between 1 and 31, got {num}")
    num = int(num)

    dict_mod = {}
    parts = []
    for idx, name, cols in _MODALITY_LAYOUT:
        if num & (1 << (4 - idx)):
            dict_mod[idx] = name
            parts.append(X[:, cols])

    # A single modality is returned as a slice of X (as before); several are concatenated.
    X_modality = parts[0] if len(parts) == 1 else np.concatenate(parts, axis=1)

    if fill_zeros:
        pad = np.zeros((X.shape[0], X.shape[1] - X_modality.shape[1]))
        X_modality = np.concatenate((X_modality, pad), axis=1)

    return X_modality, dict_mod

In [12]:
def load_train_data(root: str, node_num: str, j: int, fill_zeros: bool = False):
    """Load the augmented training set of one node, restricted to modality combination `j`.

    Reads ``<root>/Node_<node_num>/Aug_Train_Data/ALL_Patients/X_train_aug.mat``
    (variable ``X_aug_train``) and ``Y_train_aug.mat`` (variable ``Y_aug_train``).

    Args:
        root: hemisphere data directory.
        node_num: node identifier as a string, e.g. "504".
        j: modality combination code (1..31), passed to get_modality.
        fill_zeros: forwarded to get_modality; pads the selection with zero
            columns back to the full feature width.

    Returns:
        (X, Y, dict_mod): selected feature columns, 1-D label vector, and the
        {index: name} map of selected modalities from get_modality.

    Raises:
        ValueError: if the label array is not a vector, or its length differs
            from the number of feature rows.
    """
    train_path = os.path.join(root, 'Node_' + node_num, 'Aug_Train_Data', 'ALL_Patients')
    raw_path_x = os.path.join(train_path, "X_train_aug.mat")
    raw_path_y = os.path.join(train_path, "Y_train_aug.mat")

    X_mat = loadmat(raw_path_x)["X_aug_train"]
    X_mat_modality, dict_mod = get_modality(j, X_mat, fill_zeros=fill_zeros)

    Y_raw = loadmat(raw_path_y)["Y_aug_train"]
    # With loadmat's default squeeze_me=False the labels come back 2-D. Accept
    # both a row (1, n) and a column (n, 1) vector; the previous
    # reshape(shape[1]) only handled the row form.
    if sum(dim > 1 for dim in Y_raw.shape) > 1:
        raise ValueError(f"Expected a label vector in {raw_path_y}, got shape {Y_raw.shape}")
    Y_mat = Y_raw.reshape(-1)

    if X_mat_modality.shape[0] != Y_mat.shape[0]:
        raise ValueError(
            f"Node {node_num}: {X_mat_modality.shape[0]} feature rows but {Y_mat.shape[0]} labels"
        )

    return X_mat_modality, Y_mat, dict_mod

In [13]:
def load_test_data(root: str, node_num: str, j: int, fill_zeros: bool = False):
    """Load the original (non-augmented) validation set of one node for combination `j`.

    Reads ``<root>/Node_<node_num>/Orig_Val_Data/ALL_Patients/X_valid_orig.mat``
    (variable ``X_orig_valid``) and ``Y_valid_orig.mat`` (variable ``Y_orig_valid``).

    Args:
        root: hemisphere data directory.
        node_num: node identifier as a string, e.g. "504".
        j: modality combination code (1..31), passed to get_modality.
        fill_zeros: forwarded to get_modality; pads the selection with zero
            columns back to the full feature width. This replaces the
            commented-out ``fill_zeros=True`` call that was kept here.

    Returns:
        (X, Y, dict_mod): selected feature columns, 1-D label vector, and the
        {index: name} map of selected modalities from get_modality.

    Raises:
        ValueError: if the label array is not a vector, or its length differs
            from the number of feature rows.
    """
    val_path = os.path.join(root, 'Node_' + node_num, 'Orig_Val_Data', 'ALL_Patients')
    raw_path_x = os.path.join(val_path, "X_valid_orig.mat")
    raw_path_y = os.path.join(val_path, "Y_valid_orig.mat")

    X_mat = loadmat(raw_path_x)["X_orig_valid"]
    X_mat_modality, dict_mod = get_modality(j, X_mat, fill_zeros=fill_zeros)

    Y_raw = loadmat(raw_path_y)["Y_orig_valid"]
    # With loadmat's default squeeze_me=False the labels come back 2-D. Accept
    # both a row (1, n) and a column (n, 1) vector; the previous
    # reshape(shape[1]) only handled the row form.
    if sum(dim > 1 for dim in Y_raw.shape) > 1:
        raise ValueError(f"Expected a label vector in {raw_path_y}, got shape {Y_raw.shape}")
    Y_mat = Y_raw.reshape(-1)

    if X_mat_modality.shape[0] != Y_mat.shape[0]:
        raise ValueError(
            f"Node {node_num}: {X_mat_modality.shape[0]} feature rows but {Y_mat.shape[0]} labels"
        )

    return X_mat_modality, Y_mat, dict_mod

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

class MRIModalityBranch(nn.Module):
    """
    1D CNN branch for a single MRI modality; depth adapts to the input length.

    Layers:
      * block 1 (always): Conv1d 1->16, ReLU, BatchNorm, MaxPool(2)
      * block 2 (only if the length after block 1 is > 50): Conv1d ->32, ReLU, BatchNorm, MaxPool(2)
      * Dropout
      * block 3 (only if the remaining length is > 25): Conv1d ->64, MaxPool(2)

    Each conv block reads however many channels the previous layers produced.
    Previously block 3 hard-coded 32 input channels and crashed whenever block 2
    was skipped.

    After construction ``self.final_channels`` is the channel count of the last
    conv block and ``self.output_size`` is the flattened feature length returned
    by ``forward``.

    Raises:
        ValueError: if input_size < 2 (the first MaxPool1d(2) needs length >= 2).
    """
    def __init__(self, input_size, dropout_rate=0.5):
        super().__init__()
        if input_size < 2:
            raise ValueError(f"input_size must be at least 2, got {input_size}")

        self.layers = nn.ModuleList()

        # Block 1: always present.
        self.layers.append(nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm1d(16),
            nn.MaxPool1d(2)
        ))
        channels = 16
        current_size = input_size // 2  # MaxPool1d(2) floors odd lengths

        # Block 2: only for long enough inputs.
        if current_size > 50:
            self.layers.append(nn.Sequential(
                nn.Conv1d(channels, 32, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.BatchNorm1d(32),
                nn.MaxPool1d(2)
            ))
            channels = 32
            current_size = current_size // 2

        self.layers.append(nn.Dropout(dropout_rate))

        # Block 3: input channels follow whatever the previous blocks produced.
        if current_size > 25:
            self.layers.append(nn.Sequential(
                nn.Conv1d(channels, 64, kernel_size=3, padding=1),
                nn.MaxPool1d(2)
            ))
            channels = 64
            current_size = current_size // 2

        self.final_channels = channels
        self.output_size = current_size * self.final_channels

    def forward(self, x):
        """Map a (batch, features) tensor to (batch, self.output_size)."""
        x = x.unsqueeze(1)  # add channel dimension: (batch, 1, features)
        for layer in self.layers:
            x = layer(x)
        return x.view(x.size(0), -1)

class HierarchicalMRIModel(nn.Module):
    """
    Hierarchical 1D CNN over several MRI modalities.

    One MRIModalityBranch is built per modality with a positive length. The
    branch outputs are concatenated and passed to an MLP head that produces
    2 logits.

    Args:
        modality_lengths: mapping modality name -> number of feature columns;
            modalities with length 0 get no branch.
        dropout_rate: dropout probability used in the branches and the head.

    Raises:
        ValueError: if no modality has a positive length.
    """
    def __init__(self, modality_lengths, dropout_rate=0.5):
        super().__init__()

        self.branches = nn.ModuleDict()
        self.modality_lengths = modality_lengths

        total_features = 0
        for modality_name, length in modality_lengths.items():
            if length > 0:
                branch = MRIModalityBranch(length, dropout_rate)
                self.branches[modality_name] = branch
                total_features += branch.output_size

        # Fail at construction instead of building a Linear(0, 128) head that
        # can only error on the first forward pass.
        if len(self.branches) == 0:
            raise ValueError("No valid modalities provided")

        # MLP classifier after concatenation
        self.classifier = nn.Sequential(
            nn.Linear(total_features, 128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 2)  # Binary classification
        )

    def forward(self, x_dict):
        """Classify a batch given as {modality name: (batch, features) tensor}.

        Every modality that has a branch must be present in ``x_dict``; extra
        keys are ignored.

        Raises:
            ValueError: if any modality with a branch is missing from x_dict.
                Previously such modalities were silently skipped, which led to
                an opaque shape-mismatch error in the classifier.
        """
        missing = [name for name in self.branches if name not in x_dict]
        if missing:
            raise ValueError(f"Missing inputs for modalities: {missing}")

        features = [branch(x_dict[name]) for name, branch in self.branches.items()]
        return self.classifier(torch.cat(features, dim=1))

In [16]:
def get_modality_data(num, X):
    """
    Split X into per-modality column blocks for modality combination `num`.

    Uses the same 5-bit combination code as get_modality:
    T1=16, T2=8, FLAIR=4, DWI=2, DWIC=1. The previous version used the reverse
    mapping (T1=1 ... DWIC=16), so the same `num` selected different modalities
    here than in get_modality.

    Args:
        num: combination code in 1..31.
        X: feature matrix laid out as T1 [:300], T2 [300:500], FLAIR [500:700],
           DWI [700:1400], DWIC [1400:].

    Returns:
        (modality_dict, modality_lengths):
          * modality_dict: name -> X[:, cols] for each selected modality, in
            T1, T2, FLAIR, DWI, DWIC order;
          * modality_lengths: name -> number of columns for all five modalities
            (0 for unselected ones), taken from the actual slice widths.

    Raises:
        ValueError: if num is not in 1..31.
    """
    if num not in range(1, 32):
        raise ValueError(f"num should be between 1 and 31, got {num}")
    num = int(num)

    layout = (
        ("T1", 16, slice(None, 300)),
        ("T2", 8, slice(300, 500)),
        ("FLAIR", 4, slice(500, 700)),
        ("DWI", 2, slice(700, 1400)),
        ("DWIC", 1, slice(1400, None)),
    )

    modality_dict = {}
    modality_lengths = {name: 0 for name, _, _ in layout}
    for name, bit, cols in layout:
        if num & bit:
            block = X[:, cols]
            modality_dict[name] = block
            modality_lengths[name] = block.shape[1]

    return modality_dict, modality_lengths

In [19]:
def evaluate(model, data_loader, device, modality_names=None):
    """
    Balanced accuracy of `model` over every batch in `data_loader`.

    Each batch is expected to be ``(x_1, ..., x_k, y)`` as built by
    prepare_data_loaders, where x_i is the tensor for ``modality_names[i]``.

    Args:
        model: a HierarchicalMRIModel-like module taking {modality name: tensor}.
        data_loader: iterable of batches as described above.
        device: device the inputs are moved to.
        modality_names: modality name of each x tensor, in batch order. Defaults
            to ``model.branches`` key order. That matches prepare_data_loaders when
            its modality dict has the same key order as the dict used to build
            the model (true for dicts produced by get_modality_data).

    Returns:
        Balanced accuracy (via calculate_metrics).

    Note:
        Previously the input dict was keyed by the tensors themselves instead of
        modality names, so the model could never find its inputs.
    """
    if modality_names is None:
        modality_names = list(model.branches.keys())

    model.eval()
    y_true_all = []
    y_pred_all = []

    with torch.no_grad():
        for batch in data_loader:
            *x_tensors, y = batch
            if len(x_tensors) != len(modality_names):
                raise ValueError(
                    f"Batch has {len(x_tensors)} inputs but {len(modality_names)} modality names were given"
                )
            x_dict = {name: tensor.to(device) for name, tensor in zip(modality_names, x_tensors)}

            outputs = model(x_dict)
            _, predicted = torch.max(outputs, 1)

            y_true_all.extend(y.cpu().numpy())
            y_pred_all.extend(predicted.cpu().numpy())

    return calculate_metrics(y_pred_all, y_true_all)

In [17]:
def train_model(model, train_loader, val_loader, device, epochs=30, patience=5, lr=0.001):
    """
    Train `model` with Adam + cross-entropy, early-stopping on validation balanced accuracy.

    Batches are expected to be ``(x_1, ..., x_k, y)`` with the x tensors in
    ``model.branches`` key order. That matches prepare_data_loaders when its
    modality dict has the model's key order.

    Args:
        model: HierarchicalMRIModel-like module taking {modality name: tensor}.
        train_loader: training batches.
        val_loader: batches used for early stopping (evaluated with evaluate()).
        device: device the inputs/labels are moved to.
        epochs: maximum number of epochs.
        patience: stop after this many epochs without a strict improvement.
        lr: Adam learning rate.

    Returns:
        (model, train_losses, val_accs, best_val_acc). The returned model has the
        weights of the best validation epoch loaded, whether training stopped
        early or ran all epochs.

    Fixes vs. the previous version:
      * the input dict is keyed by modality name (it was keyed by tensors);
      * the best weights are deep-copied: ``state_dict().copy()`` is shallow and
        its tensors alias the live parameters, so later optimizer steps
        overwrote the "best" snapshot;
      * the first epoch always sets a snapshot, so restoring can't hit an
        unbound variable when validation accuracy stays at 0.
    """
    modality_names = list(model.branches.keys())
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_val_acc = 0
    best_model_state = None
    patience_counter = 0

    train_losses = []
    val_accs = []

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for batch in train_loader:
            *x_tensors, y = batch
            x_dict = {name: tensor.to(device) for name, tensor in zip(modality_names, x_tensors)}
            y = y.to(device)

            optimizer.zero_grad()
            loss = criterion(model(x_dict), y)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / max(len(train_loader), 1)
        train_losses.append(avg_loss)

        val_acc = evaluate(model, val_loader, device)
        val_accs.append(val_acc)

        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Val Acc: {val_acc:.4f}")

        if best_model_state is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, train_losses, val_accs, best_val_acc

In [20]:
def prepare_data_loaders(X_modality_dict, Y, batch_size=32, shuffle=True):
    """
    Wrap per-modality feature arrays and labels in a DataLoader.

    Args:
        X_modality_dict: modality name -> array-like, expected (n_samples, n_features).
        Y: labels of length n_samples; converted to int64 as CrossEntropyLoss expects.
        batch_size: samples per batch.
        shuffle: reshuffle every epoch. Pass False for evaluation loaders, where
            ordering doesn't matter.

    Returns:
        (data_loader, modality_names): each batch is (x_1, ..., x_k, y), with the
        x tensors in the order of ``modality_names`` (the dict's key order).
    """
    modality_names = list(X_modality_dict.keys())

    # torch.tensor copies and converts dtype explicitly (float labels from .mat
    # files become int64), replacing the legacy FloatTensor/LongTensor constructors.
    x_tensors = [
        torch.tensor(np.asarray(X_modality_dict[name]), dtype=torch.float32)
        for name in modality_names
    ]
    y_tensor = torch.tensor(np.asarray(Y), dtype=torch.long)

    dataset = TensorDataset(*x_tensors, y_tensor)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    return data_loader, modality_names

In [21]:
# Main execution
def main(
    root='/media/user1/MyHDataStor41/Soumyanil_EZ_Pred_project/Data/All_Hemispheres/Right_Hemis/Part_2/',
    results_dir='/media/user1/MyHDataStor41/Soumyanil_EZ_Pred_project/magmsforEZprediction/hierarchical_cnn_results/',
):
    """
    Train and evaluate a HierarchicalMRIModel for every node and all 31 modality combinations.

    Per-node accuracies (one row per combination) are written to
    ``<results_dir>/Node_<n>/Eval_Results/hierarchical_cnn_results_val_ALL_modality_combination.xlsx``.
    All rows are also written to
    ``<results_dir>/hierarchical_cnn_ALL_nodes_combined.xlsx``.

    Args:
        root: hemisphere data directory passed to load_train_data/load_test_data.
        results_dir: directory receiving the Excel result files.
    """
    node_numbers = get_list_of_node_nums()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    results = []

    for node_num in node_numbers:
        print(f'Processing Node: {node_num}')

        # load_train_data/load_test_data return (X, Y, dict_mod), not per-modality
        # dicts. Load every column once per node (combination 31 concatenates all
        # five modality blocks, i.e. the whole matrix), then split per
        # combination with get_modality_data instead of re-reading files 31 times.
        X_train_full, Y_train, _ = load_train_data(root, node_num, 31)
        X_test_full, Y_test, _ = load_test_data(root, node_num, 31)

        node_results = []

        for j in range(1, 32):
            print(f'\nModality Combination: {j}')

            X_train_dict, modality_lengths = get_modality_data(j, X_train_full)
            X_test_dict, _ = get_modality_data(j, X_test_full)

            train_loader, _ = prepare_data_loaders(X_train_dict, Y_train)
            test_loader, _ = prepare_data_loaders(X_test_dict, Y_test, batch_size=64)

            model = HierarchicalMRIModel(modality_lengths).to(device)
            print(f"Model created with modalities: {list(X_train_dict.keys())}")

            # NOTE(review): the test loader doubles as the early-stopping
            # validation set, so the reported accuracy is optimistically biased.
            model, _, _, _ = train_model(model, train_loader, test_loader, device)

            final_acc = evaluate(model, test_loader, device)
            print(f"Final balanced accuracy: {final_acc:.4f}")

            node_results.append({
                'node': node_num,
                'modality_combo': j,
                'modalities': list(X_train_dict.keys()),
                'accuracy': final_acc
            })

        results.extend(node_results)

        df_node = pd.DataFrame(
            [r['accuracy'] for r in node_results],
            columns=[f'Node_{node_num}']
        )

        save_path = os.path.join(results_dir, f"Node_{node_num}", "Eval_Results")
        os.makedirs(save_path, exist_ok=True)
        save_filepath = os.path.join(save_path, "hierarchical_cnn_results_val_ALL_modality_combination.xlsx")
        df_node.to_excel(save_filepath, index=False, sheet_name='Sheet1')

    # results_dir is always defined here (previously `path` was only set inside
    # the node loop, so an empty node list raised NameError).
    os.makedirs(results_dir, exist_ok=True)
    df_all = pd.DataFrame(results)
    df_all.to_excel(os.path.join(results_dir, "hierarchical_cnn_ALL_nodes_combined.xlsx"), index=False)


In [None]:
# Main loop to run the baseline models over all the nodes (for all 3 trials)

# Baseline to run: 'MLP', 'RF' or 'SVM'. This used to be read from a `model`
# variable that no cell defines, so the cell failed on a fresh kernel.
MODEL_NAME = 'RF'
NUM_TRIALS = 3
SEED = 0  # trial i is fitted with random_state=SEED + i: distinct but reproducible
RESULTS_ROOT = f"/media/user1/MyHDataStor41/Soumyanil_EZ_Pred_project/magmsforEZprediction/sota_node_level/{MODEL_NAME}_Results/"
# The file name follows MODEL_NAME (it was hard-coded to "RF_" for every model).
RESULTS_FILENAME = f"{MODEL_NAME}_results_val_ALL_modality_combination_test.xlsx"


def build_baseline(model_name, seed):
    """Return a new, unfitted sklearn classifier for `model_name`, seeded with `seed`."""
    if model_name == 'MLP':
        return MLP(hidden_layer_sizes=(256,), learning_rate_init=0.01, random_state=seed, max_iter=1000, early_stopping=False)
    if model_name == 'RF':
        return RF(n_estimators=100, random_state=seed)
    if model_name == 'SVM':
        return SVM(C=1.0, kernel='rbf', max_iter=-1, random_state=seed)
    raise NotImplementedError("Unknown Model.")


build_baseline(MODEL_NAME, SEED)  # fail fast on an unknown MODEL_NAME before loading data

node_numbers_with_smote = get_list_of_node_nums()

for node_num in node_numbers_with_smote:
    print(f'Node num: {node_num}')

    # Mean balanced accuracy over trials, one entry per modality combination 1..31.
    val_bal_acc_per_modality_list = []

    for j in range(1, 32):
        X_train, Y_train, dict_mod = load_train_data(root, node_num, j)
        X_test, Y_test, _ = load_test_data(root, node_num, j)
        print(f"Modality Combination: {dict_mod}")

        val_bal_acc_list = []
        for i in range(NUM_TRIALS):
            print(f'Training Trial {i+1} of Node number {node_num}')
            clf = build_baseline(MODEL_NAME, SEED + i)
            clf.fit(X_train, Y_train)

            print(f'Evaluating Trial {i+1} of Node number: {node_num}')
            y_pred = clf.predict(X_test)
            val_bal_acc_list.append(calculate_metrics(y_pred, Y_test))

        val_bal_acc_per_modality_list.append(np.mean(val_bal_acc_list))

    df_val = pd.DataFrame(val_bal_acc_per_modality_list, columns=['Node_' + node_num])

    save_path = os.path.join(RESULTS_ROOT, "Node_" + node_num, "Eval_Results")
    os.makedirs(save_path, exist_ok=True)
    df_val.to_excel(os.path.join(save_path, RESULTS_FILENAME), index=False, sheet_name='Sheet1')


Node num: 504
Train data for Node 504, Modality 1: Class 0 count = 60, Class 1 count = 60
X_train shape: (120, 499)
Test data for Node 504, Modality 1: Class 0 count = 10, Class 1 count = 0
X_test shape: (10, 499)


ValueError: Stop here

In [12]:
# Combine all node results into one dataframe: one column per node, one row per
# modality combination.
#
# Reads the per-node files written by the baseline main loop, which saves them as
# <sota_node_level>/<MODEL>_Results/Node_<n>/Eval_Results/<MODEL>_results_val_ALL_modality_combination_test.xlsx.
# This cell previously read .../baselines_node_level/Node_<n>_Results/..., a layout
# the main loop never writes.
combine_model_name = 'RF'
base_path = f"/media/user1/MyHDataStor41/Soumyanil_EZ_Pred_project/magmsforEZprediction/sota_node_level/{combine_model_name}_Results/"
results_filename = f"{combine_model_name}_results_val_ALL_modality_combination_test.xlsx"

node_nums = get_list_of_node_nums()

file_paths_val = [
    os.path.join(base_path, "Node_" + node_num, "Eval_Results", results_filename)
    for node_num in node_nums
]

missing_files = [p for p in file_paths_val if not os.path.exists(p)]
if missing_files:
    raise FileNotFoundError(f"Missing per-node result files: {missing_files}")

# Read every file, then concatenate once (concat inside the loop re-copies the
# growing frame on every iteration).
per_node_frames = [pd.read_excel(p) for p in file_paths_val]
combined_df_val = (
    pd.concat(per_node_frames, axis=1).reset_index(drop=True)
    if per_node_frames
    else pd.DataFrame()
)

combined_df_val.to_excel(
    f'{combine_model_name}_results_val_ALL_modality_combination_combined_Right_Hemis_test.xlsx',
    index=False,
)
