In [1]:
import os
import numpy as np
import scipy.io
import pandas as pd
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from utils import *  # NeuroGraph
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, confusion_matrix
from sklearn.model_selection import KFold


## Helper function for generating multiple sets of arguments from one object ##

In [2]:
# Helper function from one dictionary to multiple sets of params
def grid_from_param(dic): #, append_model_name = True, model_name_default = 'default'):
    """Expand one parameter dictionary into the full grid of combinations.

    Every key whose value is exactly a ``list`` is treated as a tuning axis;
    all other entries are copied unchanged into each result. The first axis
    (in dictionary order) varies fastest.

    Each returned dictionary also carries a ``'tune_name'`` string of the form
    ``'_<key><value>_<key><value>_'`` built from the chosen axis values, with
    every ``'/'`` replaced by ``'_'``.

    Returns a list of dictionaries, one per combination (an empty list when
    any axis is an empty list).
    """
    # Collect the tuning axes and how many choices each one offers
    axes = [key for key in dic if type(dic[key]) == list]
    sizes = [len(dic[key]) for key in axes]

    # Total number of combinations is the product of the axis sizes
    total = 1
    for size in sizes:
        total *= size

    grid = []
    for flat_index in range(total):
        combo = dic.copy()
        combo['tune_name'] = '_'
        # Decode flat_index as a mixed-radix number, least significant axis first
        remainder = flat_index
        for key, size in zip(axes, sizes):
            remainder, pick = divmod(remainder, size)
            combo[key] = dic[key][pick]
            # '/' is replaced so the tag does not contain path separators
            combo['tune_name'] = (combo['tune_name'] + f"{key}{combo[key]}_").replace('/', '_')
        grid.append(combo)
    return grid

## Hyperparameter settings

In [3]:

class Args: # now it's just a wrapper for compatibility. Everything now packed up in the dictionary.
    """Attribute-style wrapper around one hyperparameter dictionary.

    The dictionary is expected to hold one concrete value per setting (i.e. a
    single point of a tuning grid). Use :meth:`tuning_list` to turn a
    dictionary containing list-valued entries into one ``Args`` per
    combination.
    """

    def __init__(self, param_dict) -> None:
        """Copy the settings out of ``param_dict``.

        Raises:
            KeyError: if any required key is missing.
        """
        # Required settings
        self.dataset = param_dict['dataset']
        self.dataset_dir = param_dict['dataset_dir']
        self.edge_dir_prefix = param_dict['edge_dir_prefix']
        self.model = param_dict['model']
        self.num_classes = param_dict['num_classes']
        self.weight_decay = param_dict['weight_decay']
        self.batch_size = param_dict['batch_size']
        self.hidden_mlp = param_dict['hidden_mlp']
        self.hidden = param_dict['hidden']
        self.num_layers = param_dict['num_layers']
        self.runs = param_dict['runs']
        self.lr = param_dict['lr']
        self.epochs = param_dict['epochs']
        self.edge_percent_to_keep = param_dict['edge_percent_to_keep']
        self.seed = param_dict['seed']

        # Optional settings: 5 folds by default, no tuning tag by default
        self.n_splits = param_dict.get('n_splits', 5)
        self.tune_name = param_dict.get('tune_name')

        # The device was selected by a conditional on the model name whose two
        # branches were both "cpu"; everything therefore runs on the CPU.
        self.device = "cpu"

    @staticmethod
    def tuning_list(param_dicts : dict):
        """Build one ``Args`` per combination of the list-valued entries.

        Declared as a static method so it works both as ``Args.tuning_list(d)``
        and when called on an instance.
        """
        return [Args(params) for params in grid_from_param(param_dicts)]


# #print([x['lr'] for x in grid_from_param(argsdict)])
# args_list = Args.tuning_list(argsDictTune)
# fix_seed(args_list[0].seed)

# #print(len(Args.tuning_list(argsDictTune)))



## Reading our Datasets
We use our HCP correlation-matrix dataset together with its train/test split file and label file.

HCP data is downloaded from https://drive.google.com/drive/folders/166wCCtPOEL0O25FxzwB0I8AQA8b6Q9U1?usp=drive_link 

The other files are in the `data` folder.

We use our ADNI dataset.

In [4]:
def _load_mat_column(mat_path, variable, column_name, n_subjects):
    """Load one per-subject variable from a .mat file as a two-column DataFrame.

    Element ``[0]`` of every row of the stored array is taken as that
    subject's value. The result has an ``Image_ID`` column numbered
    ``1..n_subjects`` followed by a column named ``column_name``.
    """
    raw = scipy.io.loadmat(mat_path)[variable]
    frame = pd.DataFrame([row[0] for row in raw])
    # Image_ID is a running 1-based counter; insert() raises if the row count
    # differs from n_subjects.
    frame.insert(0, 'Image_ID', range(1, 1 + n_subjects))
    frame.columns = ['Image_ID', column_name]
    return frame


def read_adni_data(args):
    """Read the ADNI .mat files found under ``args.dataset_dir``.

    ``args.dataset_dir`` is used as a plain string prefix, so it must end with
    a path separator.

    Returns:
        dict with keys
        ``'fMRI_data'`` (list with one entry per subject, taken from the
        ``fmri_signal`` variable), and ``'ICV_data'``, ``'AGE_data'``,
        ``'gender_data'``, ``'DX_data'`` (DataFrames with an ``Image_ID``
        column plus ``EstimatedTotalIntraCranialVol`` / ``Age`` / ``Gender`` /
        ``Diagnosis`` respectively).
    """
    # Per-subject signals: element [0] of each row of the stored array
    fmri_raw = scipy.io.loadmat(args.dataset_dir + "fmri_signal.mat")['fmri_signal']
    fMRI_data = [row[0] for row in fmri_raw]
    n_subjects = len(fMRI_data)

    # (file name, variable inside the file, output column, key in the result)
    tables = [
        ("ICV.mat", 'ICV', 'EstimatedTotalIntraCranialVol', 'ICV_data'),
        ("AGE.mat", 'AGE', 'Age', 'AGE_data'),
        ("gender.mat", 'gender', 'Gender', 'gender_data'),
        ("DX.mat", 'DX', 'Diagnosis', 'DX_data'),
    ]

    data_dict = {'fMRI_data': fMRI_data}
    for file_name, variable, column_name, key in tables:
        data_dict[key] = _load_mat_column(
            args.dataset_dir + file_name, variable, column_name, n_subjects
        )
    return data_dict

In [5]:
def _select_edges(edge_attr, keep_fraction):
    """Threshold ``edge_attr`` in place and return the retained edge list.

    Entries at or below the ``100 * (1 - keep_fraction)`` percentile are set
    to zero (the caller's array is modified). The remaining non-zero entries
    form the edge list, which is then trimmed or padded so that it holds
    ``int(keep_fraction * number_of_matrix_entries)`` edges whenever possible.

    Returns:
        (edge_index, edge_weight): a ``(2, E)`` NumPy index array and a float
        tensor of length ``E``.
    """
    threshold = np.percentile(edge_attr, 100 * (1 - keep_fraction))
    edge_attr[edge_attr <= threshold] = 0

    target_edge_count = int(keep_fraction * (edge_attr.shape[0] * edge_attr.shape[1]))
    edge_index = np.vstack(np.nonzero(edge_attr))
    edge_weight = torch.tensor(edge_attr[edge_index[0], edge_index[1]], dtype=torch.float)
    current_edge_count = edge_weight.shape[0]

    # Adjust the number of edges without randomness
    if current_edge_count > target_edge_count:
        # Keep only the strongest edges
        keep = torch.argsort(edge_weight, descending=True)[:target_edge_count]
        edge_index = edge_index[:, keep.cpu().numpy()]
        edge_weight = edge_weight[keep]
    elif current_edge_count < target_edge_count:
        # NOTE(review): padding repeats the weakest edges that are already in
        # the list, so the result contains duplicate edges (and can still be
        # short when fewer edges exist than are missing). Confirm this is the
        # intended way to reach a fixed edge count.
        extra = torch.argsort(edge_weight, descending=False)[:target_edge_count - current_edge_count]
        edge_index = np.hstack([edge_index, edge_index[:, extra.cpu().numpy()]])
        edge_weight = torch.cat([edge_weight, edge_weight[extra]])

    return edge_index, edge_weight


def load_data_from_args(args: Args):
    """Build the list of graph samples for the dataset named in ``args``.

    For ``args.dataset == "ADNI"``: labels are read from ``y.csv`` under
    ``args.dataset_dir``, restricted to diagnoses 0 and 2 (2 is relabelled 1),
    and one ``Data`` object is built per remaining row from the subject's
    edge CSV at ``<dataset_dir>fmri_edge/<edge_dir_prefix><IID>.csv``.
    Subjects whose edge file does not exist are reported and skipped.

    Each ``Data`` has ``x`` = the thresholded edge matrix, ``edge_index`` /
    ``edge_attr`` = the retained edges and their weights, ``y`` = the label.

    Raises:
        NotImplementedError: for ``args.dataset == "HCP"`` (no loader yet).
        ValueError: for any other dataset name.
    """
    # Label path
    labels_file = args.dataset_dir + 'y.csv'
    # Load labels
    labels_df = pd.read_csv(labels_file)

    # for ADNI Dataset
    if args.dataset == "ADNI":
        fMRI_data = read_adni_data(args)['fMRI_data']

        # only keep healthy control and AD. namely 2 and 0
        labels_df = labels_df[labels_df['Diagnosis'].isin([2, 0])].reset_index(drop=True)
        # change all 2 to 1
        labels_df['Diagnosis'] = labels_df['Diagnosis'].replace({2: 1})

        dataset = []
        for i in range(len(labels_df)):
            IID = labels_df['IID'][i]
            y = torch.tensor(labels_df['Diagnosis'][i], dtype=torch.long)

            # z-score normalization for each column of each subject
            # NOTE(review): the normalised signal is computed but never used;
            # node features below come from the edge matrix. Confirm whether
            # `x` was meant to be this signal instead.
            subject_data = fMRI_data[IID]
            # fill 0 with 1
            subject_data[subject_data == 0] = 1
            subject_data = (subject_data - np.mean(subject_data, axis=0)) / np.std(subject_data, axis=0)

            edge_file = args.dataset_dir + 'fmri_edge/' + args.edge_dir_prefix + str(IID) + '.csv'
            try:
                edge_attr = pd.read_csv(edge_file)
            except FileNotFoundError:
                # Only a missing file is skipped; parse errors and interrupts
                # propagate instead of being reported as "not found".
                print('File \"' + edge_file + '\" not found. Skipping.')
                continue

            # Copy so the in-place edits below never hit a read-only view
            edge_attr = edge_attr.to_numpy(copy=True)
            np.fill_diagonal(edge_attr, 0)

            edge_index, filtered_edge_attr = _select_edges(edge_attr, args.edge_percent_to_keep)

            # Create the Data object (x is the thresholded edge matrix)
            data = Data(x=torch.tensor(edge_attr, dtype=torch.float),
                        edge_index=torch.tensor(edge_index, dtype=torch.long),
                        edge_attr=filtered_edge_attr,
                        y=y)

            # Append the processed data
            dataset.append(data)
        return dataset

    # for HCP
    elif args.dataset == "HCP":
        # The earlier stub loaded a file from a hard-coded absolute path,
        # discarded it and returned None, which only failed later in the
        # cross-validation split. Fail here with a clear message instead.
        raise NotImplementedError("Loading the HCP dataset is not implemented yet.")

    raise ValueError(f"Unsupported dataset: {args.dataset!r}")


## Train/Test Functions ##

In [6]:
def train(model, args: "Args", train_loader, optimizer=None):
    """Run one training epoch and return the accumulated loss.

    Args:
        model: module called as ``model(batch)`` and returning class logits.
        args: settings object; ``lr``, ``weight_decay`` and ``device`` are read.
        train_loader: iterable of batches exposing ``.to(device)`` and ``.y``,
            with a ``.dataset`` attribute supporting ``len``.
        optimizer: optional optimiser to reuse. When omitted a fresh ``Adam``
            is created for this call, which means its running moment
            estimates start from scratch on every call; pass one optimiser
            across epochs to keep that state.

    Returns:
        A detached tensor: the sum of the per-batch (mean-reduced) losses
        divided by ``len(train_loader.dataset)``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    if optimizer is None:
        optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(args.device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        # Accumulate a detached copy so each batch's autograd graph can be
        # freed instead of being kept alive until the end of the epoch.
        total_loss += loss.detach()
    return total_loss/len(train_loader.dataset)

@torch.no_grad()
def test(model, args: Args, loader):
    """Evaluate a binary classifier on ``loader`` and return its metrics.

    The model is put in eval mode and called as ``model(batch)``; column 1 of
    the softmax over its output is used as the positive-class probability.

    Returns:
        dict with keys ``'accuracy'``, ``'auroc'``, ``'sensitivity'``,
        ``'specificity'`` and ``'f1_score'``. ``'auroc'`` is ``nan`` when the
        labels contain a single class, since it is undefined in that case.
    """
    model.eval()
    all_preds, all_probs, all_labels = [], [], []
    for data in loader:
        data = data.to(args.device)
        out = model(data)
        # torch.softmax avoids relying on an `F` alias that this notebook
        # never imports explicitly.
        probs = torch.softmax(out, dim=1)  # Calculate probabilities
        preds = out.argmax(dim=1)
        all_preds.append(preds.cpu().numpy())
        all_probs.append(probs.cpu().numpy()[:, 1])  # Keep the probabilities of the positive class
        all_labels.append(data.y.cpu().numpy())

    all_preds = np.concatenate(all_preds)
    all_probs = np.concatenate(all_probs)
    all_labels = np.concatenate(all_labels)

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_preds)
    if np.unique(all_labels).size < 2:
        # ROC AUC needs both classes to be present
        auroc = float('nan')
    else:
        auroc = roc_auc_score(all_labels, all_probs)
    f1 = f1_score(all_labels, all_preds)

    # Fixing the label set keeps the matrix 2x2 even when one class is absent
    # from both labels and predictions, so the unpacking cannot fail.
    tn, fp, fn, tp = confusion_matrix(all_labels, all_preds, labels=[0, 1]).ravel()
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)

    metrics = {
        'accuracy': accuracy,
        'auroc': auroc,
        'sensitivity': sensitivity,
        'specificity': specificity,
        'f1_score': f1
    }

    return metrics

# test for multiclass
def test_multiclass(model, args: Args, loader):
    """Evaluate a multi-class classifier on ``loader`` and return its metrics.

    Returns:
        dict with keys ``'accuracy'``, ``'auroc'`` (one-vs-rest),
        ``'sensitivity'`` and ``'specificity'`` (unweighted means over the
        classes present in the confusion matrix) and ``'f1_score'``
        (weighted average).
    """
    model.eval()
    all_preds, all_probs, all_labels = [], [], []
    with torch.no_grad():
        for data in loader:
            data = data.to(args.device)
            out = model(data)
            # torch.softmax avoids relying on an `F` alias that this notebook
            # never imports explicitly.
            probs = torch.softmax(out, dim=1)  # calculate probabilities for each class
            preds = out.argmax(dim=1)
            all_preds.append(preds.cpu().numpy())
            all_probs.append(probs.cpu().numpy())
            all_labels.append(data.y.cpu().numpy())

    all_preds = np.concatenate(all_preds)
    all_probs = np.concatenate(all_probs, axis=0)
    all_labels = np.concatenate(all_labels)

    # metrics
    accuracy = accuracy_score(all_labels, all_preds)
    auroc = roc_auc_score(all_labels, all_probs, multi_class='ovr')  # ovr should be used for multiclass
    f1 = f1_score(all_labels, all_preds, average='weighted')  # weighted should be used for multiclass

    # Per-class counts from the confusion matrix (rows = true, columns = predicted)
    cm = confusion_matrix(all_labels, all_preds)
    tp = np.diag(cm)
    fn = cm.sum(axis=1) - tp
    fp = cm.sum(axis=0) - tp
    tn = cm.sum() - (tp + fn + fp)
    sensitivities = tp / (tp + fn)
    specificities = tn / (tn + fp)

    metrics = {
        'accuracy': accuracy,
        'auroc': auroc,
        'sensitivity': np.mean(sensitivities),
        'specificity': np.mean(specificities),
        'f1_score': f1
    }

    return metrics

def bench_from_args(args: Args, verbose = False):
    """Cross-validate one configuration and return the mean test metrics.

    For every fold of a shuffled ``KFold`` split, 20% of the training part is
    held out for validation, a model is trained for ``args.epochs`` epochs
    (repeated ``args.runs`` times), the weights with the best validation
    AUROC are checkpointed under ``./checkpoints/``, and that checkpoint is
    evaluated on the fold's test part.

    Args:
        args: configuration for this run.
        verbose: when true, print the model, per-epoch train/val/test metrics
            and per-fold results.

    Returns:
        dict mapping each metric name to its mean over the folds' test
        metrics.
    """
    full_data = load_data_from_args(args)

    # Initialize KFold
    kf = KFold(n_splits=args.n_splits, shuffle=True, random_state=args.seed)

    checkpoints_dir = './checkpoints/'
    os.makedirs(checkpoints_dir, exist_ok=True)
    checkpoint_path = f"{checkpoints_dir}{args.dataset}_{args.edge_dir_prefix.split('/')[0]}_{args.model}{args.tune_name}task-checkpoint-best-auroc.pkl"

    fold_metrics = []

    for fold, (train_idx, test_idx) in enumerate(kf.split(full_data)):
        print(f"Fold {fold + 1}/{args.n_splits}")

        # Create train, validation and test sets for this fold
        train_data = [full_data[i] for i in train_idx]
        test_data = [full_data[i] for i in test_idx]
        train_data, val_data = train_test_split(train_data, test_size=0.2, random_state=args.seed)

        # create data loaders
        train_loader = DataLoader(train_data, args.batch_size, shuffle=True)
        val_loader = DataLoader(val_data, args.batch_size, shuffle=False)
        test_loader = DataLoader(test_data, args.batch_size, shuffle=False)

        for index in range(args.runs):
            # NOTE(security): eval() resolves the layer class from a
            # configuration string; only use trusted configuration values.
            gnn = eval(args.model)
            model = ResidualGNNs(args, train_data, args.hidden, args.hidden_mlp, args.num_layers, gnn).to(args.device) ## apply GNN*
            if (verbose):
                print(model)

            best_val_auroc = float('-inf')
            checkpoint_saved = False
            for epoch in range(args.epochs):
                loss = train(model, args, train_loader)
                val_metrics = test(model, args, val_loader)
                if (verbose):
                    train_metrics = test(model, args, train_loader)
                    test_metrics = test(model, args, test_loader)
                    print("epoch: {}, loss: {}, \ntrain_metrics:{}, \nval_metrics:{}, \ntest_metrics:{}".format(epoch, np.round(loss.item(),6), train_metrics, val_metrics, test_metrics))

                # Treat an undefined (nan) AUROC as the worst possible score,
                # and always write a checkpoint on the first epoch so the file
                # loaded below belongs to this run rather than an earlier one.
                val_auroc = val_metrics['auroc']
                score = float('-inf') if np.isnan(val_auroc) else val_auroc
                if not checkpoint_saved or score > best_val_auroc:
                    best_val_auroc = score
                    checkpoint_saved = True
                    torch.save(model.state_dict(), checkpoint_path)

        # test the best checkpoint of this fold
        model.load_state_dict(torch.load(checkpoint_path, map_location=args.device))
        model.eval()
        test_metrics = test(model, args, test_loader)
        # Record the held-out test metrics of the checkpointed model, not the
        # validation metrics of the last epoch.
        fold_metrics.append(test_metrics)
        if (verbose):
            print(f"Fold {fold + 1} Test Metrics: {test_metrics}")

    # Aggregate results
    avg_metrics = {key: np.mean([fold[key] for fold in fold_metrics]) for key in fold_metrics[0].keys()}

    if verbose:
        print(f"Average Metrics: {avg_metrics}")

    return avg_metrics

## Execution ##

### demo of full usage ###

In [7]:
# Tuning configuration. Every list-valued entry is a tuning axis: the grid
# helper called through Args.tuning_list expands them into one Args per
# combination.
argsDictTune_a = {
    # choose dataset from: ADNI(BOLD), HCP(CORR), BOLD+CORR
    'dataset' : "ADNI",
    # data path (used as a string prefix, so it must end with '/')
    'dataset_dir' : "../../data/ADNI/", # TODO(review): original note asked "locally changed?" -- confirm this path before running
    # edge files to evaluate, written as "<sub-directory>/<file-name prefix>";
    # the subject ID and ".csv" are appended by the data loader
    'edge_dir_prefix' : [
        'pearson_correlation/pearson_correlation',
        "cosine_similarity/cosine_similarity",
        "KNN_Graph/knn_graph_",
        "Euclidean_Distance/distance_matrix_",
        "spearman_correlation/spearman_correlation",
        "kendall_correlation/kendall_correlation",
        "partial_correlation/partial_correlation",
        "cross_correlation/cross_correlation",        
        "pairwise_PC_aHOFC/aHOFC",
        "pairwise_PC_dHOFC/dHOFC",
        "pairwise_PC_tHOFC/tHOFC",
        "correlations_correlation/correlations_correlation",
        "associated_high_order_fc/associated_high_order_fc",
        # currently disabled:
        # "mutual_information/mutual_information",
        # "granger_causality/granger_causality",
        # added later (currently disabled):
        # "CityblockDistance/distance_matrix_",
        # "DTWDistance/DTW_distance_",
        # "EMDDistance/EMDdistance_matrix_",
        # "WaveletCoherence/coherence_matrix_",
        # "coherence_matrix/coherence_matrix",
        # "combined_correlation/combined_correlation",
        # "lingam/lingam",
        # "generalised_synchronisation_matrix/generalised_synchronisation_matrix",
        # "patels_conditional_dependence_measures_kappa/patels_conditional_dependence_measures_kappa",
        # "patels_conditional_dependence_measures_tau/patels_conditional_dependence_measures_tau",

    ],
    # choose from: GCNConv, GINConv, SGConv, GeneralConv, GATConv
    'model' : "GCNConv" ,
    'num_classes' : 2,  # ADNI - binary classification
    'weight_decay' : 0.0005,
    'batch_size' : [16, 32],
    'hidden_mlp' : 64,
    'hidden' : 32,
    'num_layers' : [2, 3, 4],
    'runs' : 1,
    'lr' : [1e-3, 1e-4, 1e-5],
    # NOTE: each epoch count is a separate grid point, i.e. a separate training run
    'epochs' : [100, 200, 300, 400, 500],
    'edge_percent_to_keep' : [0.05],
    'n_splits' : 5,
    'seed' : 42,
}
#print([x['lr'] for x in grid_from_param(argsdict)])
args_list_a = Args.tuning_list(argsDictTune_a)
# fix_seed is not defined in this notebook (presumably provided by `utils`);
# it is called once, before the whole sweep.
fix_seed(args_list_a[0].seed)
test_metric_list_a = []
# Benchmark every configuration; each result is printed and appended to
# result.txt so partial sweeps are kept if the run is interrupted.
for args in args_list_a:
    met = bench_from_args(args, verbose=False)
    test_metric_list_a.append(met)
    print(args.tune_name)
    print(met)
    with open("result.txt", "a") as f:
        f.write(f"{args.tune_name}\n{met}\n--------------------------------\n")

full_data: [Data(x=[100, 100], edge_index=[2, 500], edge_attr=[500], y=1), Data(x=[100, 100], edge_index=[2, 500], edge_attr=[500], y=1), Data(x=[100, 100], edge_index=[2, 500], edge_attr=[500], y=1), Data(x=[100, 100], edge_index=[2, 500], edge_attr=[500], y=1), Data(x=[100, 100], edge_index=[2, 500], edge_attr=[500], y=1), Data(x=[100, 100], edge_index=[2, 500], edge_attr=[500], y=1), Data(x=[100, 100], edge_index=[2, 500], edge_attr=[500], y=1), Data(x=[100, 100], edge_index=[2, 500], edge_attr=[500], y=1), Data(x=[100, 100], edge_index=[2, 500], edge_attr=[500], y=1), Data(x=[100, 100], edge_index=[2, 500], edge_attr=[500], y=1), Data(x=[100, 100], edge_index=[2, 500], edge_attr=[500], y=1), Data(x=[100, 100], edge_index=[2, 500], edge_attr=[500], y=1), Data(x=[100, 100], edge_index=[2, 500], edge_attr=[500], y=0), Data(x=[100, 100], edge_index=[2, 500], edge_attr=[500], y=0), Data(x=[100, 100], edge_index=[2, 500], edge_attr=[500], y=1), Data(x=[100, 100], edge_index=[2, 500], edg

KeyboardInterrupt: 