In [1]:
import os
import numpy as np
import scipy.io
import pandas as pd
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from utils import *  # NeuroGraph
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, confusion_matrix

## Helper function for generating multiple sets of arguments from one object ##

In [2]:
# Helper function from one dictionary to multiple sets of params
def grid_from_param(dic):
    """Expand one parameter dictionary into a full grid of dictionaries.

    Every value in ``dic`` that is exactly a ``list`` is treated as a tuning
    axis; all other values are copied unchanged. One dictionary is produced
    per combination of list entries, with the first list-valued key varying
    fastest. Each result also gets a ``'tune_name'`` string of the form
    ``'_<key><value>_<key><value>_'`` describing the chosen combination
    (just ``'_'`` when there are no list-valued keys).

    Returns a list of dictionaries (empty if any tuning list is empty).
    """
    # Keys whose values are lists, in dictionary order, and the length of each list.
    axis_keys = [key for key, value in dic.items() if type(value) is list]
    axis_sizes = [len(dic[key]) for key in axis_keys]

    # Mixed-radix strides: strides[k] is the product of all earlier axis sizes.
    strides = []
    total = 1
    for size in axis_sizes:
        strides.append(total)
        total *= size

    # Decode every flat index into one choice per axis.
    grid = []
    for flat_index in range(total):
        combo = dic.copy()
        combo['tune_name'] = '_'
        for key, stride, size in zip(axis_keys, strides, axis_sizes):
            combo[key] = dic[key][flat_index // stride % size]
            combo['tune_name'] += f"{key}{combo[key]}_"
        grid.append(combo)
    return grid

## Hyperparameters setting

In [3]:
# Base hyperparameter dictionary. Any value written as a list is a tuning axis:
# grid_from_param expands it into one configuration per list entry.
argsDictTune = {
    # dataset form: ADNI(BOLD), HCP(CORR), BOLD+CORR
    # (load_data_from_args currently only builds data for "ADNI")
    'dataset' : "ADNI",
    # data path; file names are appended directly, so keep the trailing slash
    'dataset_dir' : "../../data/ADNI/", # NOTE: machine-specific, adjust to the local data location
    # connectivity files are read from
    #   <dataset_dir>fmri_edge/<edge_dir_prefix><IID>.csv
    # alternatives that have been used:
    #'edge_dir_prefix' : "pearsonpearson_/pearsonpearson_",
    'edge_dir_prefix' : "cosinecosine_/cosinecosine_",
    #'edge_dir_prefix' : "pairwise_PC_aHOFC/aHOFC",
    # GNN layer, choose from: GCNConv, GINConv, SGConv, GeneralConv, GATConv
    'model' : "GCNConv" ,
    'num_classes' : 2,  # ADNI - binary classification
    # Adam weight decay
    'weight_decay' : 0.0005,
    # graphs per mini-batch in every DataLoader
    'batch_size' : 16,
    # sizes / depth handed to ResidualGNNs
    'hidden_mlp' : 64,
    'hidden' : 32,
    'num_layers' : 3,
    # number of repeated runs per configuration
    'runs' : 1,
    # learning rates to sweep (a list, hence a tuning axis)
    'lr' : [1e-3, 5e-4, 2e-4, 1e-4, 5e-5, 1e-5],
    #'lr' : [1e-3],
    'epochs' : 200,
    # seed passed to fix_seed
    'seed' : 42,
}

class Args: # now it's just a wrapper for compatibility. Everything now packed up in the dictionary.
    """Attribute-style view of one fully resolved hyperparameter dictionary.

    ``param_dict`` must contain every key read below (see ``argsDictTune``);
    a missing key raises ``KeyError``. ``tune_name`` is optional and defaults
    to ``None``.
    """

    def __init__(self, param_dict) -> None:
        # wrapped. see argDict above.
        self.dataset = param_dict['dataset']
        self.dataset_dir = param_dict['dataset_dir']
        self.edge_dir_prefix = param_dict['edge_dir_prefix']
        self.model = param_dict['model']
        self.num_classes = param_dict['num_classes']
        self.weight_decay = param_dict['weight_decay']
        self.batch_size = param_dict['batch_size']
        self.hidden_mlp = param_dict['hidden_mlp']
        self.hidden = param_dict['hidden']
        self.num_layers = param_dict['num_layers']
        self.runs = param_dict['runs']
        self.lr = param_dict['lr']
        self.epochs = param_dict['epochs']
        self.seed = param_dict['seed']
        # GATConv is always kept on the CPU (as in the original code; the reason
        # is not visible here). Every other model uses CUDA, but only when a GPU
        # is actually available, so the notebook also runs on CPU-only machines.
        wants_gpu = self.model != "GATConv"
        self.device = "cuda" if wants_gpu and torch.cuda.is_available() else "cpu"
        self.tune_name = param_dict['tune_name'] if "tune_name" in param_dict else None

    @staticmethod
    def tuning_list(param_dicts : dict):
        """Expand a tuning dictionary into one ``Args`` per hyperparameter combination.

        Declared as a static method so it works both as ``Args.tuning_list(d)``
        and on an instance.
        """
        return [Args(params) for params in grid_from_param(param_dicts)]


# Expand the tuning dictionary into one Args object per hyperparameter combination.
args_list = Args.tuning_list(argsDictTune)
# Seed once, using the seed of the first configuration
# (fix_seed is not defined in this notebook; presumably it comes from `from utils import *`).
fix_seed(args_list[0].seed)

# len(args_list) equals the number of hyperparameter combinations.



## Reading our Datasets
Use our HCP correlation-matrix dataset together with its train/test split file and label file.

HCP data is downloaded from https://drive.google.com/drive/folders/166wCCtPOEL0O25FxzwB0I8AQA8b6Q9U1?usp=drive_link 

The other files are in the data folder.

Use our ADNI dataset.

In [4]:
def _load_mat_first_column(dataset_dir, variable):
    """Read ``<dataset_dir>/<variable>.mat`` and return ``row[0]`` for every row of the array stored under ``variable``."""
    raw = scipy.io.loadmat(os.path.join(dataset_dir, variable + ".mat"))[variable]
    return [row[0] for row in raw]


def read_adni_data(args):
    """Load the ADNI ``.mat`` files found in ``args.dataset_dir``.

    Files read (each stores one variable named like the file):
    ``fmri_signal.mat``, ``ICV.mat``, ``AGE.mat``, ``gender.mat``, ``DX.mat``.

    Paths are built with ``os.path.join``, so ``dataset_dir`` works with or
    without a trailing separator.

    Returns
    -------
    dict with the keys
        ``'fMRI_data'``   : list with one entry per row of ``fmri_signal``
        ``'ICV_data'``    : DataFrame ``['Image_ID', 'EstimatedTotalIntraCranialVol']``
        ``'AGE_data'``    : DataFrame ``['Image_ID', 'Age']``
        ``'gender_data'`` : DataFrame ``['Image_ID', 'Gender']``
        ``'DX_data'``     : DataFrame ``['Image_ID', 'Diagnosis']``

    ``Image_ID`` is a 1-based running index over the fMRI entries.

    Raises
    ------
    FileNotFoundError / KeyError
        If a ``.mat`` file or its expected variable is missing.
    ValueError
        If a covariate file does not have as many rows as ``fmri_signal``.
    """
    fMRI_data = _load_mat_first_column(args.dataset_dir, 'fmri_signal')
    # Image_ID: 1-based running index, one id per fMRI entry.
    image_ids = range(1, 1 + len(fMRI_data))

    # (variable / file stem, column name in the resulting table)
    covariates = (
        ('ICV', 'EstimatedTotalIntraCranialVol'),
        ('AGE', 'Age'),
        ('gender', 'Gender'),
        ('DX', 'Diagnosis'),
    )
    tables = {}
    for variable, column_name in covariates:
        table = pd.DataFrame(_load_mat_first_column(args.dataset_dir, variable))
        table.insert(0, 'Image_ID', image_ids)
        table.columns = ['Image_ID', column_name]
        tables[variable] = table

    data_dict = {
        'fMRI_data': fMRI_data,
        'ICV_data': tables['ICV'],
        'AGE_data': tables['AGE'],
        'gender_data': tables['gender'],
        'DX_data': tables['DX'],
    }
    return data_dict

In [5]:
def _connectivity_to_graph(connectivity, label, keep_percentile=90):
    """Build one PyG ``Data`` graph from a subject's connectivity table.

    Node features ``x`` are the rows of the full matrix (diagonal included).
    Edges are the non-zero entries that remain after zeroing the diagonal and
    every entry below the ``keep_percentile``-th percentile of the matrix;
    ``edge_attr`` holds the corresponding matrix values as a 1-D tensor.
    """
    # Work on an explicit private copy: DataFrame.to_numpy() may hand back a
    # view, a copy, or a read-only array depending on pandas version/dtypes, so
    # mutating its result in place is not reliable.
    adjacency = connectivity.to_numpy(dtype=float, copy=True)

    # Snapshot the node features before the matrix is sparsified
    # (torch.tensor copies the data).
    x = torch.tensor(adjacency, dtype=torch.float)

    np.fill_diagonal(adjacency, 0)
    # Threshold that separates the largest (100 - keep_percentile)% of the entries.
    threshold = np.percentile(adjacency, keep_percentile)
    # Keep only entries at or above the threshold; everything else becomes 0.
    adjacency[adjacency < threshold] = 0

    edge_index = np.vstack(np.nonzero(adjacency))
    # Keep only the values that belong to the surviving edges (1-D).
    edge_weights = adjacency[edge_index[0], edge_index[1]]

    return Data(
        x=x,
        edge_index=torch.tensor(edge_index, dtype=torch.long),
        edge_attr=torch.tensor(edge_weights, dtype=torch.float),
        y=torch.tensor(label, dtype=torch.long),
    )


def load_data_from_args(args:Args):
    """Build the graph dataset selected by ``args`` and split it into train/test.

    For ``args.dataset == "ADNI"`` the label table ``<dataset_dir>/y.csv``
    (columns ``IID`` and ``Diagnosis``) is read, only diagnoses 0 and 2 are
    kept (2 is relabelled to 1), and for every remaining subject the file
    ``<dataset_dir>/fmri_edge/<edge_dir_prefix><IID>.csv`` is turned into a
    graph. Subjects whose file is missing are reported and skipped, as are
    graphs that do not have exactly 1000 edges.

    Only the label table and the per-subject connectivity CSVs are read; the
    graphs do not use the raw BOLD time series from the ``.mat`` files.

    Returns
    -------
    (train_data, test_data) : two lists of ``Data`` objects from an 80/20
        split with ``random_state=42``.

    Raises
    ------
    NotImplementedError
        For ``args.dataset == "HCP"`` (loader not written yet).
    ValueError
        For any other unknown dataset name.
    """
    if args.dataset == "HCP":
        # TODO: build the HCP graphs from the ids_train.csv / ids_test.csv split files.
        raise NotImplementedError("HCP loading is not implemented yet; use dataset='ADNI'.")
    if args.dataset != "ADNI":
        raise ValueError(f"Unknown dataset {args.dataset!r}; expected 'ADNI' or 'HCP'.")

    # Graphs with any other edge count are dropped ("special case" filter of the
    # original code; presumably 10% of a 100x100 matrix -- TODO confirm).
    expected_num_edges = 1000

    # Load labels
    labels_df = pd.read_csv(os.path.join(args.dataset_dir, 'y.csv'))
    # only keep the two diagnosis codes 0 and 2 ...
    labels_df = labels_df[labels_df['Diagnosis'].isin([2, 0])].reset_index(drop=True)
    # ... and relabel 2 as 1 so the task is binary {0, 1}
    labels_df['Diagnosis'] = labels_df['Diagnosis'].replace({2: 1})

    dataset = []
    for IID, diagnosis in zip(labels_df['IID'], labels_df['Diagnosis']):
        edge_path = os.path.join(args.dataset_dir, 'fmri_edge', args.edge_dir_prefix + str(IID) + '.csv')
        try:
            connectivity = pd.read_csv(edge_path)
        except FileNotFoundError:
            # Only a genuinely missing file is skipped; any other error
            # (parse failure, interrupt, ...) is allowed to surface.
            print(f'File "{edge_path}" not found. Skipping.')
            continue

        data = _connectivity_to_graph(connectivity, diagnosis)
        if data.edge_index.shape[1] != expected_num_edges:
            continue
        dataset.append(data)

    # get train and test data
    train_data, test_data = train_test_split(dataset, test_size=0.2, random_state=42)
    return train_data, test_data
    

## Train/Test Functions ##

In [6]:
def train(model, args: Args, train_loader, optimizer=None):
    """Train ``model`` for one pass over ``train_loader`` and return the mean loss.

    Parameters
    ----------
    model : torch.nn.Module called as ``model(batch)`` and returning class logits.
    args : configuration providing ``lr``, ``weight_decay`` and ``device``.
    train_loader : iterable of batches exposing ``.to(device)`` and ``.y``;
        must also expose ``.dataset`` (its length is the sample count).
    optimizer : optional optimizer to reuse. When omitted, a new ``Adam`` is
        created for this call, which means its moment estimates start from
        scratch; pass the same optimizer on every epoch to keep that state.

    Returns
    -------
    0-dim tensor (detached from the autograd graph): the cross-entropy loss
    averaged over samples, i.e. each batch's mean loss is weighted by the
    number of labels in the batch before dividing by ``len(train_loader.dataset)``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    if optimizer is None:
        optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    model.train()
    # Tensor accumulator so the return value supports ``.item()`` even for an empty loader.
    total_loss = torch.zeros((), device=args.device)
    for data in train_loader:
        data = data.to(args.device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        # criterion returns the batch MEAN; weight it by the batch size so the
        # final division yields a per-sample average. detach() keeps the
        # accumulator from holding on to every batch's autograd graph.
        total_loss += loss.detach() * data.y.size(0)
    return total_loss/len(train_loader.dataset)

@torch.no_grad()
def test(model, args: Args, loader):
    """Evaluate a binary classifier on ``loader`` and return summary metrics.

    The model is put in eval mode and run without gradient tracking (the
    decorator covers the whole function). Column 1 of the softmax output is
    used as the positive-class probability, so the model must emit two logits
    per sample and labels are expected to be 0/1.

    Returns
    -------
    dict with ``'accuracy'``, ``'auroc'``, ``'sensitivity'`` (tp / (tp + fn)),
    ``'specificity'`` (tn / (tn + fp)) and ``'f1_score'``.
    """
    model.eval()
    pred_chunks, prob_chunks, label_chunks = [], [], []
    for data in loader:
        data = data.to(args.device)
        logits = model(data)
        # torch.softmax avoids depending on an `F` alias that no cell imports explicitly.
        probs = torch.softmax(logits, dim=1)
        pred_chunks.append(logits.argmax(dim=1).cpu().numpy())
        prob_chunks.append(probs.cpu().numpy()[:, 1])  # probability of the positive class
        label_chunks.append(data.y.cpu().numpy())

    all_preds = np.concatenate(pred_chunks)
    all_probs = np.concatenate(prob_chunks)
    all_labels = np.concatenate(label_chunks)

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_preds)
    auroc = roc_auc_score(all_labels, all_probs)
    f1 = f1_score(all_labels, all_preds)

    # Binary confusion matrix flattened row by row: tn, fp, fn, tp.
    tn, fp, fn, tp = confusion_matrix(all_labels, all_preds).ravel()

    metrics = {
        'accuracy': accuracy,
        'auroc': auroc,
        'sensitivity': tp / (tp + fn),
        'specificity': tn / (tn + fp),
        'f1_score': f1
    }

    return metrics

# test for multiclass
@torch.no_grad()
def test_multiclass(model, args: Args, loader):
    """Evaluate a multi-class classifier on ``loader`` and return summary metrics.

    The model is put in eval mode and run without gradient tracking.

    Returns
    -------
    dict with
        ``'accuracy'``    : plain accuracy
        ``'auroc'``       : one-vs-rest AUROC computed from the softmax probabilities
        ``'sensitivity'`` : per-class tp / (tp + fn), averaged over classes
        ``'specificity'`` : per-class tn / (tn + fp), averaged over classes
        ``'f1_score'``    : support-weighted F1
    """
    model.eval()
    pred_chunks, prob_chunks, label_chunks = [], [], []
    for data in loader:
        data = data.to(args.device)
        logits = model(data)
        # Probabilities for every class. torch.softmax avoids depending on an
        # `F` alias that no cell imports explicitly.
        probs = torch.softmax(logits, dim=1)
        pred_chunks.append(logits.argmax(dim=1).cpu().numpy())
        prob_chunks.append(probs.cpu().numpy())
        label_chunks.append(data.y.cpu().numpy())

    all_preds = np.concatenate(pred_chunks)
    all_probs = np.concatenate(prob_chunks, axis=0)
    all_labels = np.concatenate(label_chunks)

    # Compute the metrics.
    accuracy = accuracy_score(all_labels, all_preds)
    # 'ovr' = one-vs-rest strategy for the multi-class AUROC.
    auroc = roc_auc_score(all_labels, all_probs, multi_class='ovr')
    # Weighted average of the per-class F1 scores.
    f1 = f1_score(all_labels, all_preds, average='weighted')
    # Multi-class confusion matrix: rows are true classes, columns are predictions.
    cm = confusion_matrix(all_labels, all_preds)

    # Sensitivity and specificity are computed per class, then averaged.
    sensitivities = []
    specificities = []
    for cls in range(cm.shape[0]):
        tp = cm[cls, cls]
        fn = cm[cls, :].sum() - tp
        fp = cm[:, cls].sum() - tp
        tn = cm.sum() - (tp + fn + fp)
        sensitivities.append(tp / (tp + fn))
        specificities.append(tn / (tn + fp))

    metrics = {
        'accuracy': accuracy,
        'auroc': auroc,
        'sensitivity': np.mean(sensitivities),
        'specificity': np.mean(specificities),
        'f1_score': f1
    }

    return metrics

def bench_from_args(args: Args, verbose = False):
    """Train and evaluate one configuration; return its test metrics.

    Steps:
      1. Load the train/test graphs and carve a validation set (12.5%,
         ``random_state=123``) out of the training graphs.
      2. For each of ``args.runs`` runs, train a fresh ``ResidualGNNs`` model
         for ``args.epochs`` epochs, checkpointing the weights whenever the
         validation AUROC improves.
      3. Reload the best checkpoint of the run and evaluate it on the test set.

    Note: ``train`` is called without an optimizer argument, so optimizer
    state is whatever ``train`` sets up on each call.

    Parameters
    ----------
    args : configuration object (see ``Args``).
    verbose : when true, print the model and per-epoch train/val/test metrics.

    Returns
    -------
    dict of test metrics as produced by ``test``. With a single run this is
    that run's dictionary; with several runs every metric is averaged over
    the runs. ``None`` when ``args.runs`` is 0.
    """
    # get train and test data
    train_data, test_data = load_data_from_args(args)
    train_data, val_data = train_test_split(train_data, test_size=0.125, random_state=123)

    # create data loaders
    train_loader = DataLoader(train_data, args.batch_size, shuffle=True)
    val_loader = DataLoader(val_data, args.batch_size, shuffle=False)
    test_loader = DataLoader(test_data, args.batch_size, shuffle=False)

    checkpoints_dir = './checkpoints/'
    os.makedirs(checkpoints_dir, exist_ok=True)
    # Args built directly from a dict (not via the grid) carry tune_name=None.
    tune_tag = args.tune_name if args.tune_name is not None else '_'
    checkpoint_path = f"{checkpoints_dir}{args.dataset}_{args.edge_dir_prefix.split('/')[0]}_{args.model}{tune_tag}task-checkpoint-best-auroc.pkl"

    # NOTE(review): eval() resolves the layer class from a configuration string.
    # That is fine for the fixed names used in this notebook, but it must never
    # be fed untrusted input; an explicit name -> class mapping would be safer.
    gnn = eval(args.model)

    run_metrics = []
    for _ in range(args.runs):
        model = ResidualGNNs(args, train_data, args.hidden, args.hidden_mlp, args.num_layers, gnn).to(args.device) ## apply GNN*
        if (verbose):
            print(model)

        # Start below every possible AUROC so the first epoch always writes a
        # checkpoint; otherwise a stale file from an earlier session could be loaded.
        best_val_auroc = float('-inf')
        checkpoint_written = False
        for epoch in range(args.epochs):
            loss = train(model, args, train_loader)
            val_metrics = test(model, args, val_loader)
            if (verbose):
                train_metrics = test(model, args, train_loader)
                test_metrics = test(model, args, test_loader)
                print("epoch: {}, loss: {}, \ntrain_metrics:{}, \nval_metrics:{}, \ntest_metrics:{}".format(epoch, np.round(loss.item(),6), train_metrics, val_metrics, test_metrics))

            if val_metrics['auroc'] > best_val_auroc:
                best_val_auroc = val_metrics['auroc']
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_written = True

        # test the model with the best weights found in THIS run
        if checkpoint_written:
            model.load_state_dict(torch.load(checkpoint_path))
        test_metrics = test(model, args, test_loader)
        if (verbose):
            print('test_metrics:', test_metrics)
        run_metrics.append(test_metrics)

    if not run_metrics:
        return None
    if len(run_metrics) == 1:
        return run_metrics[0]
    # Several runs: report the mean of every metric.
    return {key: float(np.mean([m[key] for m in run_metrics])) for key in run_metrics[0]}

## Execution ##

### demo of full usage ###

In [7]:
# Full-usage demo: same settings as argsDictTune above, but with a different
# connectivity measure (edge_dir_prefix). The learning rate is the only tuning axis.
argsDictTune_a = {
    # dataset form: ADNI(BOLD), HCP(CORR), BOLD+CORR
    'dataset' : "ADNI",
    # data path; file names are appended directly, so keep the trailing slash
    'dataset_dir' : "../../data/ADNI/", # NOTE: machine-specific, adjust to the local data location
    # connectivity files are read from
    #   <dataset_dir>fmri_edge/<edge_dir_prefix><IID>.csv
    # other measures that can be selected instead:
    #'edge_dir_prefix' : "pearsonpearson_/pearsonpearson_",
    #'edge_dir_prefix' : "cosinecosine_/cosinecosine_",
    # 'edge_dir_prefix' : "pairwise_PC_aHOFC/aHOFC",
    # 'edge_dir_prefix' : "pairwise_PC_dHOFC/dHOFC",
    # 'edge_dir_prefix' : "pairwise_PC_tHOFC/tHOFC",
    # 'edge_dir_prefix' : "correlations_correlation/correlations_correlation",
    # 'edge_dir_prefix' : "partial_correlation/partial_correlation",
    'edge_dir_prefix' : "associated_high_order_fc/associated_high_order_fc",
    # GNN layer, choose from: GCNConv, GINConv, SGConv, GeneralConv, GATConv
    'model' : "GCNConv" ,
    'num_classes' : 2,  # ADNI - binary classification
    'weight_decay' : 0.0005,
    'batch_size' : 16,
    'hidden_mlp' : 64,
    'hidden' : 32,
    'num_layers' : 3,
    'runs' : 1,
    # learning rates to sweep (a list, hence a tuning axis)
    'lr' : [1e-3, 5e-4, 2e-4, 1e-4, 5e-5, 1e-5],
    #'lr' : [1e-3],
    'epochs' : 200,
    'seed' : 42,
}
# One Args object per learning rate.
args_list_a = Args.tuning_list(argsDictTune_a)
# NOTE(review): the seed is set once for the whole sweep, so (assuming fix_seed
# seeds the global RNGs) every configuration after the first starts from the
# RNG state left behind by the previous one -- confirm this is intended.
fix_seed(args_list_a[0].seed)
# Test-metric dictionaries, in the same order as args_list_a.
test_metric_list_a = []
for args in args_list_a:
    met = bench_from_args(args, verbose=False)
    test_metric_list_a.append(met)
    print(args.tune_name)
    print(met)

_lr0.001_
{'accuracy': 0.75, 'auroc': 0.8398481973434535, 'sensitivity': 0.788235294117647, 'specificity': 0.6451612903225806, 'f1_score': 0.8220858895705522}


KeyboardInterrupt: 