### This notebook implements ASD-DiagNet with the features listed below. Each feature is derived from static functional connectivity (sFC), computed from the preprocessed ABIDE rs-fMRI data with cc200 as the brain atlas, using multiple linear regression (mlr) models, ComBat harmonization models, or normalization methods, with sFC as the dependent variable:

1. sFC: static functional connectivity (sfc_feature, file: sfc_feature_file_cc200.pkl)
2. $\Delta$ mlrA: mlr residual of sFC with age as independent variable (sfc_mlr_age_feature, file: sfc_mlr_age_feature_file_cc200.pkl)
3. $\Delta$ mlrA$_{FZ}$: mlr residual of the Fisher Z-transform of sFC (sFC$_{FZ}$) with age as independent variable (sfc_fz_mlr_age_feature, file: sfc_fz_mlr_age_feature_file_cc200.pkl)
4. $\Delta$ mlrF: mlr residual of sFC with FIQ as independent variable (sfc_mlr_FIQ_feature, file: sfc_mlr_FIQ_feature_file_cc200.pkl)
5. $\Delta$ mlrF$_{FZ}$: mlr residual of the Fisher Z-transform of sFC (sFC$_{FZ}$) with FIQ as independent variable (sfc_fz_mlr_FIQ_feature, file: sfc_fz_mlr_FIQ_feature_file_cc200.pkl)
6. $\Delta$ mlrM: mlr residual of sFC with MRI vendor as independent variable (sfc_mlr_MRI_feature, file: sfc_mlr_MRI_feature_file_cc200.pkl)
7. $\Delta$ mlrM$_{FZ}$: mlr residual of the Fisher Z-transform of sFC (sFC$_{FZ}$) with MRI vendor as independent variable (sfc_fz_mlr_MRI_feature, file: sfc_fz_mlr_MRI_feature_file_cc200.pkl)
8. $\Delta$ mlrG: mlr residual of sFC with gender as independent variable (sfc_mlr_gender_feature, file: sfc_mlr_gender_feature_file_cc200.pkl)
9. $\Delta$ mlrG$_{FZ}$: mlr residual of the Fisher Z-transform of sFC (sFC$_{FZ}$) with gender as independent variable (sfc_fz_mlr_gender_feature, file: sfc_fz_mlr_gender_feature_file_cc200.pkl)
10. $\Delta$ mlrAGM: mlr residual of sFC with age, gender and MRI vendor as independent variables (sfc_mlr_AGM_feature, file: sfc_mlr_AGM_feature_file_cc200.pkl)
11. $\Delta$ mlrAGM$_{FZ}$: mlr residual of the Fisher Z-transform of sFC (sFC$_{FZ}$) with age, gender and MRI vendor as independent variables (sfc_fz_mlr_AGM_feature, file: sfc_fz_mlr_AGM_feature_file_cc200.pkl)
12. cbA: ComBat harmonization of sFC with age as independent variable (sfc_combat_age_feature, file: sfc_combat_age_feature_file_cc200.pkl)
13. cbA$_{FZ}$: ComBat harmonization of the Fisher Z-transform of sFC (sFC$_{FZ}$) with age as independent variable (sfc_fz_combat_age_feature, file: sfc_fz_combat_age_feature_file_cc200.pkl)
14. cbF: ComBat harmonization of sFC with FIQ as independent variable (sfc_combat_FIQ_feature, file: sfc_combat_FIQ_feature_file_cc200.pkl)
15. cbF$_{FZ}$: ComBat harmonization of the Fisher Z-transform of sFC (sFC$_{FZ}$) with FIQ as independent variable (sfc_fz_combat_FIQ_feature, file: sfc_fz_combat_FIQ_feature_file_cc200.pkl)
16. cbAFG: ComBat harmonization of sFC with age, FIQ and gender as independent variables (sfc_combat_AFG_feature, file: sfc_combat_AFG_feature_file_cc200.pkl)
17. cbAFG$_{FZ}$: ComBat harmonization of the Fisher Z-transform of sFC (sFC$_{FZ}$) with age, FIQ and gender as independent variables (sfc_fz_combat_AFG_feature, file: sfc_fz_combat_AFG_feature_file_cc200.pkl)
18. $\Delta$ avg: demeaning of the sFC with the average sFC over all ABIDE subjects (sfc_res_avg_feature, file: sfc_res_avg_feature_file_cc200.pkl)
19. $\Delta$ avgSite: demeaning of the sFC at a given site with the average sFC over all ABIDE subjects from that site (sfc_res_avg_site_feature, file: sfc_res_avg_site_feature_file_cc200.pkl)
20. $\Delta$ avgSubj: demeaning of the sFC with the average sFC computed for each ABIDE subject (sfc_res_avg_subj_feature, file: sfc_res_avg_subj_feature_file_cc200.pkl)
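
To make the feature definitions above more concrete, here is a minimal NumPy sketch of the Fisher Z-transform, an mlr residual, and two of the demeaning variants. It is an illustration on synthetic data, not the code that produced the .pkl files (those come from the sfcfeatures.py module). The function names `fisher_z` and `mlr_residual`, the intercept term, and the clipping constant are assumptions made for this example.

```python
import numpy as np

def fisher_z(r, eps=1e-7):
    """Fisher Z-transform; clipping keeps arctanh finite at |r| = 1 (assumed safeguard)."""
    return np.arctanh(np.clip(r, -1 + eps, 1 - eps))

def mlr_residual(sfc, covariates):
    """Residual of a least-squares fit of every sFC feature on the covariates.

    sfc: (n_subjects, n_features), covariates: (n_subjects, n_covariates).
    An intercept column is added; this is an assumption of the sketch.
    """
    X = np.column_stack([np.ones(len(covariates)), covariates])
    beta, *_ = np.linalg.lstsq(X, sfc, rcond=None)
    return sfc - X @ beta

# Synthetic stand-in for sFC: 50 subjects, 4 connectivity values each.
rng = np.random.default_rng(0)
age = rng.uniform(6, 40, size=50)
sfc = np.tanh(0.02 * age[:, None] + rng.normal(0, 0.3, size=(50, 4)))

delta_mlrA = mlr_residual(sfc, age[:, None])               # like item 2
delta_mlrA_fz = mlr_residual(fisher_z(sfc), age[:, None])  # like item 3
delta_avg = sfc - sfc.mean(axis=0)                         # like item 18
delta_avgSubj = sfc - sfc.mean(axis=1, keepdims=True)      # like item 20
```

After the regression, the residuals are uncorrelated with the covariate, which is the point of using them as confound-corrected features.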


### This notebook also implements ASD-DiagNet using baseline sub-samples of the ABIDE sites and homogeneous sub-samples of the ABIDE subjects. 

#### The baseline sub-samples were formed by progressively selecting the sites with the greatest values of accuracy computed with ASD-DiagNet for the classification of control and autistic subjects, with sFC features and the cc200 as the brain atlas. The baseline sub-samples are:

bss_4 was formed with the 4 sites which obtained accuracies $\ge$ 70.0: 

bss_4 = ['KKI','OHSU','Olin','USM'] 

bss_5 was formed with the 5 sites which obtained accuracies $\ge$ 66.8: 

bss_5 = ['KKI','NYU','OHSU','Olin','USM'] 

bss_6 was formed with the 6 sites which obtained accuracies $\ge$ 66.4:

bss_6 =['KKI','NYU','OHSU','Olin','UCLA',
        'USM'] 

bss_7 was formed with the 7 sites which obtained accuracies > 64.6: 

bss_7 =['KKI','NYU','OHSU','Olin','UCLA',
        'USM','Yale'] 

bss_8 was formed with the 8 sites which obtained accuracies > 63.9: 

bss_8 =['KKI','NYU','OHSU','Olin','Stanford',
        'UCLA','USM','Yale'] 

bss_9 was formed with the 9 sites which obtained accuracies > 63.8: 

bss_9 = ['CMU','KKI','NYU','OHSU','Olin',
         'Stanford','UCLA','USM','Yale'] 

bss_10 was formed with the 10 sites which obtained accuracies > 63.4: 

bss_10 = ['CMU','KKI','NYU','OHSU','Olin',
          'Stanford','UCLA','UM','USM','Yale']
 
bss_11 was formed with the 11 sites which obtained accuracies > 62.4:

bss_11 =['CMU','KKI','Leuven','NYU','OHSU',
         'Olin','Stanford','UCLA','UM','USM',
         'Yale'] 

bss_12 was formed with the 12 sites which obtained accuracies > 61.4: 

bss_12 =['CMU','KKI','Leuven','NYU','OHSU',
         'Olin','Pitt','Stanford','UCLA','UM',
         'USM','Yale'] 

bss_13 was formed with the 13 sites which obtained accuracies > 55.9: 

bss_13 =['CMU','KKI','Leuven','NYU','OHSU',
         'Olin','Pitt','SDSU','Stanford','UCLA',
         'UM','USM','Yale'] 
         
bss_14 was formed with the 14 sites which obtained accuracies > 55.0: 

bss_14 =['CMU','KKI','Leuven','NYU','OHSU',
         'Olin','Pitt','SBL','SDSU','Stanford',
         'UCLA','UM','USM','Yale'] 

bss_15 was formed with the 15 sites which obtained accuracies > 54.0: 

bss_15 =['CMU','KKI','Leuven','MaxMun','NYU',
         'OHSU','Olin','Pitt','SBL','SDSU',
         'Stanford','UCLA','UM','USM','Yale'] 

bss_16 was formed with the 16 sites which obtained accuracies > 52.1: 

bss_16 =['Caltech','CMU','KKI','Leuven','MaxMun',
         'NYU','OHSU','Olin','Pitt','SBL',
          'SDSU','Stanford','UCLA','UM','USM',
          'Yale'] 

whole designates all the 17 sites:

whole = ['Caltech','CMU','KKI','Leuven','MaxMun',
         'NYU','OHSU','Olin','Pitt','SBL',
         'SDSU','Stanford','Trinity','UCLA','UM',
          'USM','Yale']
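
The progressive selection described above can be sketched as picking the top-scoring sites and growing the list one site at a time. The helper `baseline_subsample` and the site names and accuracy values below are placeholders invented for illustration; they are not results from this notebook.

```python
def baseline_subsample(site_accuracy, n_sites):
    """Return the n_sites sites with the highest accuracy, sorted by name."""
    ranked = sorted(site_accuracy, key=site_accuracy.get, reverse=True)
    return sorted(ranked[:n_sites], key=str.lower)

# Placeholder per-site accuracies (hypothetical values, for illustration only).
toy_accuracy = {'siteA': 71.0, 'siteB': 64.5, 'siteC': 68.2, 'siteD': 59.9}

toy_bss_2 = baseline_subsample(toy_accuracy, 2)
toy_bss_3 = baseline_subsample(toy_accuracy, 3)
print(toy_bss_2)  # ['siteA', 'siteC']
print(toy_bss_3)  # ['siteA', 'siteB', 'siteC']
```

Each larger sub-sample contains the previous one, which matches the nesting of bss_4 through bss_16 listed above.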

#### The homogeneous sub-samples of the ABIDE subjects were formed by grouping subjects into ranges of age and ranges of FIQ. The homogeneous sub-samples are:

hss_age_10 was formed with 12  sites with subjects for which 0< age<=10: 

hss_age_10 = ['KKI','MaxMun','NYU','OHSU','Olin',
              'Pitt','SDSU','Stanford','UCLA','UM',
              'USM','Yale'] 

hss_age_1015 was formed with 14  sites with subjects for which 10 < age<=15: 

hss_age_1015 = ['KKI','Leuven','MaxMun','NYU','OHSU',
                'Olin','Pitt','SDSU','Stanford','Trinity',
                'UCLA','UM','USM','Yale'] 

hss_age_1520 was formed with 15 sites with subjects for which 15 < age<=20:

hss_age_1520 = ['Caltech','CMU','Leuven','MaxMun','NYU',
                'OHSU','Olin','Pitt','SBL','SDSU',
                'Trinity','UCLA','UM','USM','Yale'] 

hss_age_1020 (age_1020) was formed with 17  sites with subjects for which 10 < age<=20:

hss_age_1020 = ['Caltech','CMU','KKI','Leuven','MaxMun',
                'NYU','OHSU','Olin','Pitt','SBL',
                'SDSU','Stanford','Trinity','UCLA','UM',
                'USM','Yale'] 

hss_age_20 was formed with 11  sites with subjects for which age>20:

hss_age_20  = ['Caltech','CMU','Leuven','MaxMun','NYU',
               'Olin','Pitt','SBL','Trinity','UM','USM'] 
               
hss_FIQ_89 was formed with 14 sites with subjects for which 0 < FIQ <=89, p_fold = 5:

hss_FIQ_89 = ['KKI','Leuven','MaxMun','NYU','OHSU',
              'Olin','Pitt','SDSU','Stanford','Trinity',
              'UCLA','UM','USM','Yale']

hss_FIQ_89_110 was formed with 16 sites with subjects for which 89 < FIQ <=110:

hss_FIQ_89_110 = ['Caltech','CMU','KKI','Leuven','MaxMun',
                  'NYU','OHSU','Olin','Pitt','SDSU',
                  'Stanford','Trinity','UCLA','UM','USM',
                  'Yale'] 

hss_FIQ_110  was formed with 16 sites with subjects for which FIQ > 110:

hss_FIQ_110 = ['Caltech','CMU','KKI','Leuven','MaxMun',
               'NYU','OHSU','Olin','Pitt','SDSU',
               'Stanford','Trinity','UCLA','UM','USM',
               'Yale']

hss_age_1020_FIQ_89110 was formed with 15 sites with subjects for which 89 < FIQ <=110 
and 10<age<=20:

hss_age_1020_FIQ_89110  = ['CMU','KKI','Leuven','MaxMun','NYU',
                           'OHSU','Olin','Pitt','SDSU','Stanford',
                           'Trinity','UCLA','UM','USM','Yale'] 
                           
hss_age_1020_FIQ_89 was formed with 14 sites with subjects for which 0< FIQ <=89 and 10<age<=20,
p_fold = 5:

hss_age_1020_FIQ_89 = ['KKI','Leuven','MaxMun','NYU','OHSU',
                       'Olin','Pitt','SDSU','Stanford','Trinity',
                       'UCLA','UM','USM','Yale'] 

hss_FIQ_89_bal was formed with 14 sites with subjects for which 0 < FIQ <= 89, plus a 
number of control subjects from outside the sub-sample, added to balance it:

hss_FIQ_89_bal = ['KKI','Leuven','MaxMun','NYU','OHSU',
                  'Olin','Pitt','SDSU','Stanford','Trinity',
                  'UCLA','UM','USM','Yale']
                  
hss_age_1020_FIQ_89_bal was formed with 14 sites with subjects for which 0 < FIQ <= 89 and 10 < age <= 20,
plus a number of control subjects from outside the sub-sample, added to balance it:

hss_age_1020_FIQ_89_bal  = ['KKI','Leuven','MaxMun','NYU','OHSU',
                            'Olin','Pitt','SDSU','Stanford','Trinity',
                            'UCLA','UM','USM','Yale']                   
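
The range conditions used above (for example 10 < age <= 20 and 89 < FIQ <= 110) amount to row filters on the phenotypic table. Below is a hedged pandas sketch on a toy table. The column names `SITE_ID`, `AGE_AT_SCAN` and `FIQ` are assumed to follow ABIDE-style headers, and `select_range` is a hypothetical helper; the notebook itself builds these sub-samples with the SubSamples class.

```python
import pandas as pd

# Toy phenotypic table (made-up rows, for illustration only).
pheno = pd.DataFrame({
    'SUB_ID': [1, 2, 3, 4, 5],
    'SITE_ID': ['KKI', 'NYU', 'NYU', 'USM', 'Yale'],
    'AGE_AT_SCAN': [9.5, 12.0, 18.3, 25.1, 14.7],
    'FIQ': [85, 102, 95, 118, 88],
})

def select_range(df, column, low, high=float('inf')):
    """Keep the rows for which low < df[column] <= high."""
    return df[(df[column] > low) & (df[column] <= high)]

toy_age_1020 = select_range(pheno, 'AGE_AT_SCAN', 10, 20)
toy_age_1020_FIQ_89110 = select_range(toy_age_1020, 'FIQ', 89, 110)

# Sites that still contribute subjects after both filters.
toy_sites = sorted(toy_age_1020_FIQ_89110['SITE_ID'].unique())
print(toy_sites)  # ['NYU']
```

Because a site only appears in a sub-sample if it still has subjects after filtering, the number of sites differs from one homogeneous sub-sample to the next.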

                          
#### The starting parameters to run this notebook are defined as follows:

1) Brain atlas: p_ROI = ['cc200', 'aal', 'dosenbach160','ez','ho','tt']. Note that a brain atlas other than cc200 can be used, but the features for the new atlas must first be computed with the sfcfeatures.py module.  

2) feature key: p_feature = ['sfc','sfc_mlr_age','sfc_fz_mlr_age','sfc_mlr_FIQ',
                             'sfc_fz_mlr_FIQ','sfc_mlr_MRI','sfc_fz_mlr_MRI',
                             'sfc_mlr_gender','sfc_fz_mlr_gender','sfc_mlr_AGM','sfc_fz_mlr_AGM',
                             'sfc_combat_age','sfc_fz_combat_age','sfc_combat_FIQ',
                             'sfc_fz_combat_FIQ','sfc_combat_AFG', 'sfc_fz_combat_AFG',
                             'sfc_res_avg','sfc_res_avg_site','sfc_res_avg_subj']
                             
             

                            
3) p-fold cross-validation: p_fold = 5 for one site computation, 
                            p_fold = 10 for more than one site computation
                            
4) Define site or sub-sample with:

p_center = ['Caltech','CMU','KKI','Leuven','MaxMun','NYU','OHSU', 'Olin',
            'Pitt','SBL','SDSU','Stanford','UCLA','UM','USM','Yale','Trinity',
            'bss_4','bss_5','bss_6','bss_7','bss_8','bss_9','bss_10','bss_11',
            'bss_12','bss_13','bss_14','bss_15','bss_16','whole',
            'hss_age_10','hss_age_1015','hss_age_1520','hss_age_1020',
            'hss_age_20','hss_FIQ_89','hss_FIQ_89_110','hss_FIQ_110',
            'hss_age_1020_FIQ_89110','hss_age_1020_FIQ_89','hss_FIQ_89_bal',
            'hss_age_1020_FIQ_89_bal']
            

 
5) Classification mode as site or sub-sample: 
   p_mode = ['site', 'baseline sub-sample', 'homogeneous sub-sample']

    Example: if p_center = 'Caltech', then p_mode = 'site',
             if p_center = 'bss_4', then p_mode = 'baseline sub-sample', 
             if p_center = 'hss_age_10', then p_mode = 'homogeneous sub-sample'
             
6) Using the augmentation technique: p_augmentation = [False, True]

7) Using the shuffle technique: p_shuffle = [False, True]

8) p_max_add_control: maximum number of additional control subjects used to balance a homogeneous sub-sample.

9) p_ss_length = subsamples_length[p_center]: subsample length
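
Since p_mode and the default p_fold follow from p_center, the rules in items 3) to 5) can be sketched as two small helpers. `infer_mode` and `default_folds` are hypothetical names that are not part of asdiagnetconf.py, and treating 'whole' as a baseline sub-sample is an assumption based on where it is listed above.

```python
def infer_mode(p_center):
    """Map a p_center key to its classification mode by its name prefix."""
    if p_center.startswith('bss_') or p_center == 'whole':  # 'whole': assumed baseline
        return 'baseline sub-sample'
    if p_center.startswith('hss_'):
        return 'homogeneous sub-sample'
    return 'site'

def default_folds(p_mode):
    """5 folds for a single site, 10 folds when several sites are pooled."""
    return 5 if p_mode == 'site' else 10

for center in ['Caltech', 'bss_4', 'hss_age_10']:
    mode = infer_mode(center)
    print(center, '->', mode, '| p_fold =', default_folds(mode))
```

Deriving the mode this way avoids setting p_center and p_mode to inconsistent values by hand; the special cases for individual sub-samples in the starting-parameters cell would still need to be applied afterwards.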

### Classes implemented in the module asdiagnetconf.py:

1. MultiSiteData

2. SubSamples, a subclass of the class MultiSiteData

3. HelperFunctions, a subclass of the class MultiSiteData

4. MTAutoEncoder, a subclass of the class nn.Module

5. DiagDataLoader, a subclass of the class HelperFunctions

Repository:  https://github.com/pcdslab/ASD-DiagNet-Confounds

In [None]:
# Possibility to stop warnings
import warnings

warnings.filterwarnings('ignore') 

In [None]:
import pandas as pd
import numpy as np
#import matplotlib.pyplot as plt
import os
from functools import reduce
import time
import torch.utils.tensorboard
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch
import pyprind
import sys
import pickle
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy import stats
import functools
import numpy.ma as ma # for masked arrays
import random

# sklearn library
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.model_selection import train_test_split

#asdiagnet module
from asdiagnetconf import MultiSiteData as MSD
from asdiagnetconf import SubSamples as SS
from asdiagnetconf import HelperFunctions as HF
from asdiagnetconf import MTAutoEncoder as MTA
from asdiagnetconf import DiagDataLoader as DL

In [None]:
center_list = ['Caltech','CMU','KKI','Leuven','MaxMun','NYU','OHSU', 'Olin', 'Pitt','SBL','SDSU','Stanford',
            'Trinity','UCLA','UM','USM','Yale', 'bss_4','bss_5','bss_6','bss_7','bss_8','bss_9','bss_10',
            'bss_11', 'bss_12','bss_13','bss_14','bss_15','bss_16','whole', 'hss_age_10','hss_age_1015',
            'hss_age_1520','hss_age_1020', 'hss_age_20','hss_FIQ_89','hss_FIQ_89_110','hss_FIQ_110', 
            'hss_age_1020_FIQ_89110','hss_age_1020_FIQ_89','hss_FIQ_89_bal','hss_age_1020_FIQ_89_bal']

In [None]:
subsamples_length = {'bss_4':4,'bss_5':5,'bss_6':6,'bss_7':7,'bss_8':8,'bss_9':9,'bss_10':10,
                  'bss_11':11,'bss_12':12,'bss_13':13,'bss_14':14,'bss_15':15,'bss_16':16,
                  'whole':17,'hss_age_10':12,'hss_age_1015':14,'hss_age_1520':15,'hss_age_1020':17,
                  'hss_age_20':11,'hss_FIQ_89':14,'hss_FIQ_89_110':17,'hss_FIQ_110':17, 
                  'hss_age_1020_FIQ_89110':17,'hss_age_1020_FIQ_89':14,'hss_FIQ_89_bal':14,
                   'hss_age_1020_FIQ_89_bal':14}

In [None]:
feature_list = ['sfc','sfc_mlr_age','sfc_fz_mlr_age','sfc_mlr_FIQ', 'sfc_fz_mlr_FIQ','sfc_mlr_MRI',
             'sfc_fz_mlr_MRI', 'sfc_mlr_gender','sfc_fz_mlr_gender','sfc_mlr_AGM','sfc_fz_mlr_AGM', 
             'sfc_combat_age','sfc_fz_combat_age','sfc_combat_FIQ', 'sfc_fz_combat_FIQ','sfc_combat_AFG', 
             'sfc_fz_combat_AFG', 'sfc_res_avg','sfc_res_avg_site','sfc_res_avg_subj']

In [None]:
mode_list = ['site', 'baseline sub-sample', 'homogeneous sub-sample']

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

### Starting parameters

In [None]:

p_ROI = 'cc200'
p_feature =  feature_list[0] 
p_center  =  center_list[0] 
p_mode    =  mode_list[0]
if p_mode ==  'site':
    p_fold = 5
    p_shuffle = True
else: 
    p_fold = 10
    p_shuffle = True
p_augmentation = True
if p_center in ('hss_FIQ_89', 'hss_age_1020_FIQ_89'):
    p_fold = 2
# number of extra control subjects used to balance the '_bal' sub-samples
p_max_add_control = 0
if p_center == 'hss_FIQ_89_bal':
    p_max_add_control = 34
elif p_center == 'hss_age_1020_FIQ_89_bal':
    p_max_add_control = 28
if p_mode in ('baseline sub-sample', 'homogeneous sub-sample'):
    p_ss_length = subsamples_length[p_center]

In [None]:
print("*****List of parameters****")
print('Brain atlas:' ,p_ROI)
print('p_feature: ',p_feature)
print('p_center: ',p_center)
print('p_mode: ',p_mode)
if p_mode == 'site':
    print('Site: ',p_center)
else:
    print('Sub-sample: ',p_center)
    print('Sub-sample length: ',p_ss_length)
print('Augmentation:',p_augmentation)
print('Shuffle:',p_shuffle)
print('p_fold: ',p_fold)
print('p_max_add_control: ',p_max_add_control)

In [None]:
# data_files_path: path to ABIDE data, input_data_path: path to features data,
# data_phenotypic_path: path to phenotypic file
# Please update data_path; expanduser resolves '~', which open() does not expand
data_path = os.path.expanduser('~/abide_fmri_preprocessed/')
data_phenotypic_path = data_path+'Phenotypic_V1_0b_preprocessed1.csv'
data_files_path = data_path+ 'rois_'+p_ROI+'/'
input_data_path = data_path+ 'rois_'+p_ROI+'_input/'
print('data_phenotypic_path ',data_phenotypic_path)
print('data_files_path', data_files_path)
print('input_data_path', input_data_path)

### Loading feature and eigenvalue data

In [None]:
feat_file = p_feature+'_feature_file_'+p_ROI+'.pkl'
eig_file = 'eig_data_'+feat_file
# 'with' closes each file even if unpickling fails
with open(input_data_path+feat_file, 'rb') as f:
    feat_data = pickle.load(f)
print('file for feature data loaded:', feat_file)
with open(input_data_path+eig_file, 'rb') as f:
    eig_data = pickle.load(f)
print('file for eig data loaded:', eig_file)


### Instances of the MultiSiteData (MSD), HelperFunctions (HF), and DiagDataLoader (DL) classes

In [None]:
msd = MSD(data_phenotypic_path,data_files_path)
hf = HF(data_phenotypic_path,data_files_path,feat_data,eig_data)

In [None]:
# compute subjects ids and phenotypic data
msd.get_subjects_id()
subjects_id = msd.subjects_id
msd.get_phenotypic_data()
labels = msd.labels
#compute number of features
num_feat = len(feat_data[subjects_id[0]][0])
print('num_feat:',num_feat)
#length of eig_data
num_dim = len(eig_data[subjects_id[0]]['eigvals'])
print('num_dim:',num_dim)

In [None]:
dl = DL(data_phenotypic_path,data_files_path,feat_data,eig_data,num_dim)

### Computing subsamples and centers_dict

In [None]:
if p_mode == 'baseline sub-sample' or p_mode == 'site':
    hf.get_centers_dict()
    centers_dict =  hf.centers_dict
if p_mode == 'homogeneous sub-sample' or p_mode == 'baseline sub-sample':
    ss = SS(data_phenotypic_path,data_files_path,p_center,p_max_add_control,p_ss_length)
    if p_mode == 'homogeneous sub-sample':
        ss.get_hss() 
        subsample = np.array (ss.subsample)
        centers_dict = ss.centers_dict
    elif  p_mode == 'baseline sub-sample':
        ss.get_bss() 
        subsample = np.array (ss.subsample)

## Defining training functions

In [None]:
def train(model, epoch, train_loader, p_bernoulli=None, mode='both',lam_factor=1.0):         
    model.train()
    train_losses = []
    for i,(batch_x,batch_y) in enumerate(train_loader):
        if len(batch_x) != batch_size:
            continue
        if p_bernoulli is not None:
            # Draw the mask per batch so it does not depend on batch 0 having been processed
            rand_bernoulli = torch.bernoulli(torch.full_like(batch_x, p_bernoulli)).to(device)

        data, target = batch_x.to(device), batch_y.to(device)
        optimizer.zero_grad()

        if mode in ['both', 'ae']:
            if p_bernoulli is not None:
                rec_noisy, _ = model(data*rand_bernoulli, False)
                loss_ae = criterion_ae(rec_noisy, data) / len(batch_x)
            else:
                rec, _ = model(data, False)
                loss_ae = criterion_ae(rec, data) / len(batch_x)

        if mode in ['both', 'clf']:
            rec_clean, logits = model(data, True)
            loss_clf = criterion_clf(logits, target)

        if mode == 'both':
            loss_total = loss_ae + lam_factor*loss_clf
            train_losses.append([loss_ae.detach().cpu().numpy(), 
                                 loss_clf.detach().cpu().numpy()])
        elif mode == 'ae':
            loss_total = loss_ae
            train_losses.append([loss_ae.detach().cpu().numpy(), 
                                 0.0])
        elif mode == 'clf':
            loss_total = loss_clf
            train_losses.append([0.0, 
                                 loss_clf.detach().cpu().numpy()])

        loss_total.backward()
        optimizer.step()
    #print('train_losses',train_losses)
    return train_losses       

### Defining ASD-DiagNet parameters and similarity function (sim_function) 

In [None]:
num_epochs = 50
batch_size = 8
learning_rate_ae, learning_rate_clf = 0.0001, 0.0001    
p_bernoulli = None
use_dropout = True    
augmentation = p_augmentation
aug_factor = 2
num_neighbs = 5
lim4sim = 2 
sim_function = functools.partial(hf.cal_similarity, lim=lim4sim)
run = False

### Computing classification scores with ASD-DiagNet for ABIDE sites, or for a homogeneous sub-sample, or for a baseline sub-sample. 

In [None]:
print ('p_center:',p_center)
print ('p_ROI: ', p_ROI )
print ('p_feature: ',p_feature)
print('num_epochs: ', num_epochs)
print ('p_fold: ',p_fold)
print ('shuffle: ',p_shuffle)
print('augmentation: ', augmentation, 'aug_factor: ', aug_factor, 
          'num_neighbs: ', num_neighbs, 'lim4sim: ', lim4sim,
          'number of features: ',num_feat)
    
if  p_mode == 'site':           
    subj_id_list = np.array(centers_dict[p_center])
            
    output_name = 'results_site_diag.csv'    
    results = open(output_name, 'a')
    print('Results will be written to {0}'.format(output_name))    
    results.write('##########################################################################\n'+
                   'feature: '+str(p_feature) + '-p_center: '+str(p_center) + 
                    '-p_shuffle: '+str(p_shuffle) +
                          ','+'Ac'+','+'Se'+','+'Sp'+'\n')
    results.close()
    run = True
    
elif p_mode == 'homogeneous sub-sample':    
    print ('homogeneous subsample ',p_center,':',subsample)            
    length = subsample.shape[0]
    print('length subsample:', length)
    hf.get_subj_id_list(length,subsample,centers_dict)
    subj_id_list = np.array(hf.subj_id_list)
            
    output_name = 'results_hss_diag.csv'    
    results = open(output_name, 'a')    
    print('Results will be written to {0}'.format(output_name)) 
    results.write('##########################################################################\n'+
                   'feature: '+str(p_feature) + '-p_center: '+str(p_center) + 
                    '-p_shuffle: '+str(p_shuffle) +
                          ','+'Ac'+','+'Se'+','+'Sp'+'\n')
    results.close()
    run = True
    
elif  p_mode == 'baseline sub-sample':
    print ('baseline subsample ',p_center,':',subsample)       
    length = subsample.shape[0]
    print('length subsample:', length)
    hf.get_subj_id_list(length,subsample,centers_dict)
    subj_id_list = np.array(hf.subj_id_list)
            
    output_name = 'results_bss_diag.csv'    
    results = open(output_name, 'a')    
    print('Results will be written to {0}'.format(output_name)) 
    results.write('##########################################################################\n'+
                   'feature: '+str(p_feature) + '-p_center: '+str(p_center) + 
                    '-p_shuffle: '+str(p_shuffle) +
                          ','+'Ac'+','+'Se'+','+'Sp'+'\n')
    results.close()
    run = True
    
#compute the classification scores    
if run: 
    hf.get_number_subjects(subj_id_list)
    num_control_subj, num_autism_subj = hf.control_autism
    num_subjects = len(subj_id_list)
    print( 'subjects:',num_subjects, ',control:',num_control_subj,
          ',autism:',num_autism_subj)    
    labels_list = np.array([labels[subj] for subj in subj_id_list])
    
    all_rp_res=[]
    kk=0 
    repeat = 10   
    total_time = 0    
    start_time =time.time()
    pbar = pyprind.ProgBar(repeat)
    for rp in range(repeat): 
        crossval_res_kol=[]
        if p_shuffle:
            kf = StratifiedKFold(n_splits=p_fold, random_state=1, shuffle=True)
            np.random.shuffle(subj_id_list)
            labels_list = np.array([labels[subj] for subj in subj_id_list])
        else:
            kf = StratifiedKFold(n_splits=p_fold)    
        for kk,(train_index, test_index) in enumerate(kf.split(subj_id_list, labels_list)):
            train_samples, test_samples = subj_id_list[train_index], subj_id_list[test_index]
            verbose = (kk == 0)
            thr_regs = 0.25
            reg_num =int(num_feat*thr_regs)
            hf.get_regs_inds(train_samples,reg_num)
            regions_inds = np.array(hf.regions_inds)
            num_inpp = len(regions_inds)
            lat_mult = 0.5
            n_lat = int(num_inpp*lat_mult)
            if kk == 0 and rp == 0:
                print('thr_regs,lat_mult,num_inpp,n_lat: ',thr_regs,lat_mult,num_inpp,n_lat)
            train_loader=dl.get_loader(data=feat_data, samples_list=train_samples, 
                                    batch_size=batch_size, mode='train',
                                    augmentation=augmentation, aug_factor=aug_factor, 
                                    num_neighbs=num_neighbs,eig_data=eig_data, 
                                    similarity_fn=sim_function, 
                                    verbose=verbose,regions=regions_inds)                                    
            test_loader=dl.get_loader(data=feat_data, samples_list=test_samples, 
                                   batch_size=batch_size, mode='test', augmentation=False, 
                                   verbose=verbose,regions=regions_inds)
            model = MTA(tied=True, num_inputs=num_inpp, num_latent=n_lat, use_dropout=use_dropout)
            model.to(device)
            criterion_ae = nn.MSELoss(reduction='sum') #MSE: Mean Square Error
            criterion_clf = nn.BCEWithLogitsLoss()     #BCE: Binary Cross Entropy
            optimizer = optim.SGD([{'params': model.fc_encoder.parameters(), 'lr': learning_rate_ae},
                                   {'params': model.classifier.parameters(), 'lr': learning_rate_clf}],
                                    momentum=0.9)
            for epoch in range(1, num_epochs+1):
                if epoch <= 20:
                    train_losses = train(model, epoch, train_loader, p_bernoulli, mode='both')
                else:
                    train_losses = train(model, epoch, train_loader, p_bernoulli, mode='clf')            
            res_mlp = hf.test(model, criterion_ae, test_loader, 
                              eval_classifier=True,device=device)
            print(res_mlp)
            crossval_res_kol.append(res_mlp)
       
        r = np.mean(np.array(crossval_res_kol),axis = 0)
        all_rp_res.append(r) 
                             
        hf.output_repeat_results(rp,r,start_time)        
        pbar.update() 
        
    hf.output_results(all_rp_res,repeat,output_name,
                      p_center,p_mode,p_ROI,p_feature)
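
The result files are headed with the columns Ac, Se and Sp (accuracy, sensitivity, specificity), and the loop above averages the per-fold scores with np.mean. The sketch below shows how such scores can be computed from binary labels and then averaged over folds. It assumes that label 1 marks the autism class and that each fold contains both classes; `classification_scores` is a hypothetical helper, not the hf.test method, whose exact return value is not shown in this notebook.

```python
import numpy as np

def classification_scores(y_true, y_pred):
    """Return (accuracy, sensitivity, specificity) for binary labels in {0, 1}."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    tp = np.sum((y_true == 1) & (y_pred == 1))
    tn = np.sum((y_true == 0) & (y_pred == 0))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    accuracy = (tp + tn) / len(y_true)
    sensitivity = tp / (tp + fn)  # assumes at least one positive subject
    specificity = tn / (tn + fp)  # assumes at least one negative subject
    return accuracy, sensitivity, specificity

# Two toy folds with made-up labels and predictions.
fold_1 = classification_scores([1, 1, 1, 0, 0, 0, 0, 1], [1, 0, 1, 0, 0, 1, 0, 1])
fold_2 = classification_scores([1, 0, 1, 0], [1, 0, 1, 0])

# Average over folds, mirroring np.mean(np.array(crossval_res_kol), axis=0).
mean_scores = np.mean(np.array([fold_1, fold_2]), axis=0)
print(fold_1)
print(mean_scores)
```

In the toy data, the first fold scores 0.75 on all three measures and the second fold is perfect, so the averaged scores are 0.875 each.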