# Japanese subset_ main thesis project -ReHo

In [1]:
import pandas as pd
import os
import torch
from torch import nn
from torch.utils.data import DataLoader, ConcatDataset, TensorDataset
from sklearn.model_selection import KFold
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score,confusion_matrix
import statistics
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.metrics import roc_curve
from sklearn.metrics import RocCurveDisplay


### Network architecture

In [2]:
#Define a Convolutional Neural Network : BASED ON https://www.biorxiv.org/content/10.1101/2019.12.17.879346v1

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        
        self.downsample = nn.AvgPool3d(2, stride=2, padding=0)
        
        self.CNNlayer = nn.Sequential(
            nn.Conv3d(1, 64, kernel_size=3, stride=1),
            nn.ELU(),
            nn.Conv3d(64, 16, kernel_size=3, stride=1),
            nn.ELU(),
            nn.MaxPool3d(2)
        )
        
        self.flat1 = nn.Linear(160000, 16)   
        self.flat2 = nn.Linear(16, 1)
            
    def forward(self, x):
        x=self.downsample(x)
        #print(f'avg-pool: {x.size()}\n----------')
        #print(f'number of nan in this layer = {torch.isnan(x).sum()}')
        
        x=self.CNNlayer(x)
        #print(f'convolution1+2+maxpool: {x.size()} \n----------')
        
        x=x.reshape(x.size(0), -1)
        #print(f'reshape after cnn: {x.size()}\n----------')
        
        x=F.elu(self.flat1(x))
        #print(f'fully-connected1: {x.size()}\n----------')
                    
        x=self.flat2(x)
        #print(f'fully-connected2: {x.size()}\n----------')
        
        return x

### lunch wandb

In [3]:
import wandb
!wandb login 390734ff44d817dbba59927d4eb542e564627b3b

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /data/zmohaghegh/.netrc


### funtion for preparaing summary measure for feeding to neural network 

In [4]:
def preparaing_summary_measure(mdd_test_site_reho_load,control_test_site_reho_load):
    
    # load numpy 
    control_reho_zero_nan2= control_test_site_reho_load
    mdd_reho_zero_nan2 = mdd_test_site_reho_load

    # create empty numpy 
    control_reho_zero_nan3 =np.array(control_reho_zero_nan2)
    mdd_reho_zero_nan3 = np.array(mdd_reho_zero_nan2)

    # zero_nan_control
    for i in range(len(control_test_site_reho_load)):    
        control_reho_zero_nan3[i][0] =np.nan_to_num(control_test_site_reho_load[i][0],copy=True)
        control_reho_zero_nan3[i][1] =np.nan_to_num(control_test_site_reho_load[i][1],copy=True)

    # zero_nan_mdd
    for i in range(len(mdd_test_site_reho_load)):    
        mdd_reho_zero_nan3[i][0] =np.nan_to_num(mdd_test_site_reho_load[i][0],copy=True)
        mdd_reho_zero_nan3[i][1] =np.nan_to_num(mdd_test_site_reho_load[i][1],copy=True)

    # add one channel to zero nan
    control_reho_4d_zero_nan = [[np.reshape(c[0], (1, 91, 109, 91)), c[1]] for c in control_reho_zero_nan3]
    mdd_reho_4d_zero_nan = [[np.reshape(m[0], (1, 91, 109, 91)), m[1]] for m in mdd_reho_zero_nan3]

    # concat mdd and contor data  
    dataset_Japan_test_reho_zero_nan= ConcatDataset([control_reho_4d_zero_nan, mdd_reho_4d_zero_nan])
    
    #dataset_Japan_test_reho_zero_nan[10][0][0,:,54,45]
    
    return dataset_Japan_test_reho_zero_nan

## Importing Japanese IDs for each site

In [5]:
COI_MDD=pd.read_csv('COI_MDD.csv')
UTO_MDD=pd.read_csv('UTO_MDD.csv')
HUH_MDD=pd.read_csv('HUH_MDD.csv')
HKH_MDD=pd.read_csv('HKH_MDD.csv')
HRC_MDD=pd.read_csv('HRC_MDD.csv')
KUT_MDD=pd.read_csv('KUT_MDD.csv')

MDD_sites=[COI_MDD,UTO_MDD,HUH_MDD,HKH_MDD,HRC_MDD,KUT_MDD]
MDD_sites_concat=pd.concat([COI_MDD,UTO_MDD,HUH_MDD,HKH_MDD,HRC_MDD,KUT_MDD])

In [6]:
COI_Control=pd.read_csv('COI_Control.csv')
UTO_Control=pd.read_csv('UTO_Control.csv')
HUH_Control=pd.read_csv('HUH_Control.csv')
HKH_Control=pd.read_csv('HKH_Control.csv')
HRC_Control=pd.read_csv('HRC_Control.csv')
KUT_Control=pd.read_csv('KUT_Control.csv')

Control_sites=[COI_Control,UTO_Control,HUH_Control,HKH_Control,HRC_Control,KUT_Control]
Control_sites_concat=pd.concat([COI_Control,UTO_Control,HUH_Control,HKH_Control,HRC_Control,KUT_Control])

In [7]:
mdd_base_path = '/dbstore/zmohaghegh/Japanese_subset/summary_measures/MDD_reho/ReHo_Normalised_z/'
control_base_path = '/dbstore/zmohaghegh/Japanese_subset/summary_measures/Control_reho/ReHo_Normalised_z/'

In [8]:
#len(mdd_test_site_reho_load)
#len(mdd_test_site_reho_load[0])
#mdd_test_site_reho_load[0][0].shape
#test_dataset_cv =  preparaing_summary_measure(mdd_test_site_reho_load, mdd_test_site_reho_load) 
#print(len(test_dataset_cv))
#test_dataset_cv[0][0].shape
#len(MDD_sites_concat['participants_id'])

In [24]:
site_num=1

In [25]:
Control_test= Control_sites[site_num] 
ID_test_Control=[]
for j in range(len(Control_test['participants_id'])):
    ID_test_Control.append(Control_test['participants_id'][j])
print(f'test Control dataset size : {len(ID_test_Control)}')

test Control dataset size : 62


In [26]:
ID_train_Control=[]
for ids in Control_sites_concat['participants_id']:
    if ids not in ID_test_Control:
        ID_train_Control.append(ids)
print(f'train Control dataset size : {len(ID_train_Control)}')

train Control dataset size : 189


## Leave on site out : Cross validation loop for each Site 

In [9]:
for site_num, site in enumerate(MDD_sites):
    site_name= site['site'][0]
    
    print(site_name)
    ##################### load test ######################
    print('loading test dataset')
    
    ID_test_MDD=[]
    for i in range(len(site['participants_id'])):
        ID_test_MDD.append(site['participants_id'][i])
    print(f'test MDD dataset size : {len(ID_test_MDD)}')
    
    Control_test= Control_sites[site_num] 
    ID_test_Control=[]
    for j in range(len(Control_test['participants_id'])):
        ID_test_Control.append(Control_test['participants_id'][j])
    print(f'test Control dataset size : {len(ID_test_Control)}')
    
    test_site_mdd_file_path = [mdd_base_path + f'ReHo_z_{test_ids_MDD}.nii' for test_ids_MDD in ID_test_MDD]
    #test_site_control_file_path = [control_base_path + f'ReHo_z_{test_ids_Control}.nii' for test_ids_Control in ID_test_Control]

    mdd_test_site_reho_load = [[nib.load(m).get_fdata(),1] for m in test_site_mdd_file_path]
    #control_test_site_reho_load = [[nib.load(c).get_fdata(),0] for c in test_site_control_file_path]
    
    # prepare test data for feeding to network
    print('concating train dataset')
    test_dataset_cv =  preparaing_summary_measure(mdd_test_site_reho_load, mdd_test_site_reho_load) 

    ##################### load train ######################
    print('loading train dataset')
    
    ID_train_MDD=[]
    for ids in MDD_sites_concat['participants_id']:
        if ids not in ID_test_MDD:
            ID_train_MDD.append(ids)
    print(f'train MDD dataset size : {len(ID_train_MDD)}')
    
    ID_train_Control=[]
    for ids in Control_sites_concat['participants_id']:
        if ids not in ID_test_Control:
            ID_train_Control.append(ids)
    print(f'train Control dataset size : {len(ID_train_Control)}')
            
    
    train_site_mdd_file_path = [mdd_base_path + f'ReHo_z_{train_ids_MDD}.nii' for train_ids_MDD in ID_train_MDD]
    #train_site_control_file_path = [control_base_path + f'ReHo_z_{train_ids_Control}.nii' for train_ids_Control in ID_train_Control]
    
    mdd_train_site_reho_load = [[nib.load(m).get_fdata(),1] for m in train_site_mdd_file_path]
    #control_train_site_reho_load = [[nib.load(c).get_fdata(),0] for c in train_site_control_file_path]
    
    # prepare train data for feeding to network
    print('concating train dataset')
    train_dataset_cv =  preparaing_summary_measure(mdd_train_site_reho_load,mdd_train_site_reho_load)
    
    ######################################################################################################
    
    wandb.init(project='Leave-one-site-out-japanese-reho')


    ############################### train and validation loop CV  ########################################
    k_folds = 5
    kfold_results = {}
    kfold = KFold(n_splits=k_folds, shuffle=True)

    #torch.manual_seed(42)
    num_epochs = 1
    batch_size = 10
    learning_rate= 0.001


    #Define a Loss function 
    loss_function = nn.BCEWithLogitsLoss()

    for fold, (train_ids, valid_ids) in enumerate(kfold.split(train_dataset_cv)):
        best_loss_cv= None

        print(f"FOLD {fold}\n--------------------------------")

        # Sample elements randomly from a given list of ids,
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
        valid_subsampler = torch.utils.data.SubsetRandomSampler(valid_ids)

        # Define data loaders for training and testing data in this fold
        train_loader = torch.utils.data.DataLoader(train_dataset_cv, batch_size=batch_size, sampler=train_subsampler)
        valid_loader = torch.utils.data.DataLoader(train_dataset_cv, batch_size=batch_size, sampler=valid_subsampler)

        #define network
        network = ConvNet()
        network = network.double()

        # create our optimizer
        optimizer = optim.SGD(network.parameters(), momentum=0.9, lr = learning_rate, weight_decay=1e-3)

        # in the training loop:

        network.train() # prepare model for training

        for epoch in range(0, num_epochs):
            print(f'*********Starting epoch {epoch+1}')

            train_loss_cv = 0
            total =0
            correct=0

            # train model/network 
            for i, data in enumerate(train_loader, 0):
                #print(f'train {i}')

                # get the inputs; data is a list of [inputs, labels]
                inputs, labels = data

                # zero the gradient buffers
                optimizer.zero_grad()

                # forward pass
                outputs = network(inputs)

                # print(outputs.size)
                outputss=outputs.squeeze(1) #### [10,1] ---> [10]

                # prediction 
                predicted = outputss.data > 0.0

                labels=labels.double()

                #calcuate loss/error
                loss = loss_function(outputss, labels)

                # backward pass
                loss.backward()

                # Does the update , gradient descent
                optimizer.step() 

                correct += (predicted == labels).sum().item()
                train_loss_cv += loss.item()
                total += labels.size(0)

            wandb.log({"epoch_cv": epoch , "train_Loss_cv": train_loss_cv/total, "train_acc_cv": 100 * correct / total })

            print(f'train loss= {train_loss_cv/total}')
            print(f'train Acc= {100 * correct / total}')


            print('Training process has finished.')
            print('Starting testing')

            # validate the network using the validation data, for this fold
            correct= 0
            total = 0
            valid_loss_cv=0

            network.eval()

            with torch.no_grad():
                for i, data in enumerate(valid_loader, 0):
                    #print(f'test {i}')

                    inputs, lables = data

                    outputs = network(inputs)

                    outputss=outputs.squeeze(1) #[10,1] ---> [10]
                    lables=lables.double()

                    # prediction 
                    predicted = outputss.data > 0.0

                    loss = loss_function(outputss, lables)

                    valid_loss_cv += loss.item()
                    total += lables.size(0)
                    correct += (predicted == lables).sum().item()

                wandb.log({ "validation_acc_cv": 100 * correct /total, "validation_Loss_cv": valid_loss_cv/total })

                current_valid_loss_cv = valid_loss_cv/total

                print(f'valid_acc :{100 * correct /total}')
                print(f'valid_loss: {valid_loss_cv/total}')

                if not best_loss_cv or best_loss_cv > current_valid_loss_cv:
                    best_loss_cv = current_valid_loss_cv

                    print('Saving best valid -trained model.')
                    path_best_loss_reho = f'/data/zmohaghegh/TempStats_3D-CNN/leave_one_site_out_best_model_reho/model-japanese_best-reho-fold-{fold}.pth'
                    torch.save(network.state_dict(), path_best_loss_reho )

                print('validation process has finished')

        print('Accuracy for fold %d: %d %%' % (fold, 100.0 * correct / total))
        print('--------------------------------')

        kfold_results[fold] = 100.0 * (correct / total)



    ##############################Result of Validation for each fold ##################################

    print(f"K-FOLD CROSS VALIDATION RESULTS FOR japanese reho {k_folds} FOLDS\n--------------------------------")
    _sum = 0.0

    for key, value in kfold_results.items():
        print(f'Fold {key}: {value} %')
        _sum += value

    print(f'Average: {_sum/len(kfold_results.items())} %')

    ############################ ### Test loop for Cross validation  #############################################

    #wandb.init(project='Leave-one-site-out-japanese-reho')
    
    test_loader  = torch.utils.data.DataLoader(test_dataset_cv , batch_size=batch_size, shuffle=False)

    bal_acc_fold=[]
    F1_score_fold=[]

    for k in np.arange(5): 

        print(f'Start TEST FOR FOLD {k}')

        path_fold = f'/data/zmohaghegh/TempStats_3D-CNN/leave_one_site_out_best_model_reho/model-japanese_best-reho-fold-{k}.pth'

        network.load_state_dict(torch.load(path_fold))

        test_loss_cv=0
        total = 0
        correct=0

        F1_labels=[]
        F1_pred=[]

        network.eval() # preoare model for test and evaluation

        with torch.no_grad():
            #print('Start testing CV...')
            for i, data in enumerate(test_loader, 0):
                #print(f'test {i}')

                inputs, lables = data

                outputs = network(inputs)

                lables=lables.double()
                outputss=outputs.squeeze(1) #[10,1] ---> [10]

                #do the prediction 
                predicted = outputss.data > 0.0

                #calculate loss
                loss = loss_function(outputss, lables)

                test_loss_cv += loss.item()
                correct += (predicted == lables).sum().item()
                total += lables.size(0)

                if i==0:
                    F1_labels=lables.int().numpy()
                    F1_pred=predicted.int().numpy()
                else:
                    F1_labels= np.concatenate((F1_labels, lables.int().numpy()))
                    F1_pred = np.concatenate((F1_pred, predicted.int().numpy()))

            wandb.log({ "test_Acc_CV": 100 * correct /total  , "test_Loss_CV": test_loss_cv/total })

            acc = accuracy_score(F1_labels, F1_pred)
            bal_acc= balanced_accuracy_score(F1_labels, F1_pred)

            F1_Score = f1_score(F1_labels, F1_pred, average='weighted')
            #tn, fp, fn, tp = confusion_matrix(F1_labels, F1_pred).ravel()

            F1_score_fold.append(F1_Score)
            bal_acc_fold.append(bal_acc)

            #PLotting Confusion matrix and ROC curve

            #conf_matrix = confusion_matrix(F1_labels, F1_pred)
            #conf_matrix_display = ConfusionMatrixDisplay(conf_matrix).plot()

            #fp_rate, tp_rate, threshold = roc_curve(F1_labels, F1_pred)
            #ROC_display = RocCurveDisplay(fpr=fp_rate, tpr=tp_rate).plot()


            wandb.log({ "test_balanced_Acc_CV": bal_acc , "test_Acc_CV": 100 * correct /total, "test_F1_score_CV": F1_Score , "test_Loss_CV": test_loss_cv/total })

            print(f'test_Acc_CV": {100 * correct /total}')
            print(f'F1_score CV ReHo :{F1_Score}')
            print(f'Balanced ACC CV ReHo :{bal_acc}')
            print(f'Loss CV ReHo : {test_loss_cv/total}')



    ############################ ### Result of test loop average  #############################################       
    F1_score_avg= sum(F1_score_fold)/len(F1_score_fold)
    F1_score_std= statistics.pstdev(F1_score_fold)
    bal_acc_avg = 100 * (sum(bal_acc_fold)/len(bal_acc_fold))
    bal_acc_std = 100* statistics.pstdev(bal_acc_fold)

    print(f' Site name : {site_name}')
    print(f' @@@@@@@@@@ Average Balance ACC japan ReHo = {bal_acc_avg}')
    print(f'standard deviation :{bal_acc_std}')
    print(f' @@@@@@@@@@@ Average F1_score japan ReHo = { F1_score_avg}')
    print(f'standard deviation :{F1_score_std}')
    
    wandb.log({ "SITE": site_name, "test_balanced_Acc_Average": bal_acc_avg, "test_F1_score_CV":F1_score_avg})



COI
loading test dataset
test MDD dataset size : 71
concating train dataset


  
  if __name__ == '__main__':


loading train dataset
train MDD dataset size : 184
concating train dataset


[34m[1mwandb[0m: Currently logged in as: [33mzahramhn[0m (use `wandb login --relogin` to force relogin)


FOLD 0
--------------------------------
*********Starting epoch 1


  allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag


train loss= 0.006837654547152946
train Acc= 100.0
Training process has finished.
Starting testing
valid_acc :100.0
valid_loss: 5.179052929375922e-05
Saving best valid -trained model.
validation process has finished
Accuracy for fold 0: 100 %
--------------------------------
FOLD 1
--------------------------------
*********Starting epoch 1
train loss= 0.005770599559262714
train Acc= 96.59863945578232
Training process has finished.
Starting testing
valid_acc :100.0
valid_loss: 5.152145186582089e-07
Saving best valid -trained model.
validation process has finished
Accuracy for fold 1: 100 %
--------------------------------
FOLD 2
--------------------------------
*********Starting epoch 1
train loss= 0.005716263624398767
train Acc= 96.59863945578232
Training process has finished.
Starting testing
valid_acc :100.0
valid_loss: 5.615741433182553e-07
Saving best valid -trained model.
validation process has finished
Accuracy for fold 2: 100 %
--------------------------------
FOLD 3
------------

  
  if __name__ == '__main__':


loading train dataset
train MDD dataset size : 193
concating train dataset


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
SITE,COI
epoch_cv,0
train_Loss_cv,0.00551
train_acc_cv,96.61017
_runtime,983
_timestamp,1619824210
_step,20
validation_acc_cv,100.0
validation_Loss_cv,0.0
learning rate,0.001


0,1
epoch_cv,▁▁▁▁▁
train_Loss_cv,▃▁▁█▁
train_acc_cv,█▁▁▄▁
_runtime,▁▁▂▂▄▄▅▅▇▇▇▇▇▇▇▇█████
_timestamp,▁▁▂▂▄▄▅▅▇▇▇▇▇▇▇▇█████
_step,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
validation_acc_cv,▁▁▁▁▁
validation_Loss_cv,▆▁▁█▁
learning rate,▁▁▁▁▁
test_Acc_CV,▁▁▁▁▁▁▁▁▁▁


FOLD 0
--------------------------------
*********Starting epoch 1
train loss= 0.005357299153040962
train Acc= 96.75324675324676
Training process has finished.
Starting testing
valid_acc :100.0
valid_loss: 3.2347506466703426e-06
Saving best valid -trained model.
validation process has finished
Accuracy for fold 0: 100 %
--------------------------------
FOLD 1
--------------------------------
*********Starting epoch 1
train loss= 0.004386184571425297
train Acc= 100.0
Training process has finished.
Starting testing
valid_acc :100.0
valid_loss: 1.9539277590568118e-07
Saving best valid -trained model.
validation process has finished
Accuracy for fold 1: 100 %
--------------------------------
FOLD 2
--------------------------------
*********Starting epoch 1


KeyboardInterrupt: 