In [1]:
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

In [2]:
# Paths
# Project root; every input and output location below is derived from it.
root_p = '/home/surchs/sim_big/PROJECT/abide_hps/'
# The phenotype input and both output tables live in the same sub-directory.
pheno_dir = os.path.join(root_p, 'pheno')

# Pheno
abide1_p = os.path.join(pheno_dir, 'abide_1_complete.csv')

# Data
ct_p = os.path.join(root_p, 'ct')
fc_p = os.path.join(root_p, 'fc')

# File templates (filled in with str.format; '{:07}' zero-pads to 7 digits)
ct_t = '{}+{:07}_{}+{}_native_rms_rsl_tlaplace_30mm_left.txt'
fc_t = 'fmri_{:07}_session_1_run1.nii.gz'

# Out_path
train_p = os.path.join(pheno_dir, 'ABIDE_TRAIN.csv')
validate_p = os.path.join(pheno_dir, 'ABIDE_VALIDATE.csv')

In [3]:
# Load the phenotype table
pheno = pd.read_csv(abide1_p)
# Add a sample matching column for site: a copy of 'Site' in which the
# 'Leuven_1' and 'Leuven_2' labels are collapsed into a single 'Leuven'
pheno['match_Site'] = pheno['Site'].replace({'Leuven_1': 'Leuven',
                                             'Leuven_2': 'Leuven'})

In [4]:
# Find available data
# Build one boolean inclusion mask per criteria set, with one entry per row of
# `pheno`. Both sets require the 'ct_available' and 'fc_available' flags; they
# differ in the 'Ratings' threshold and in how 'status' is checked:
#   lenient: Ratings > 1 and status is anything but 'Fail'
#   strict:  Ratings > 2 and status is exactly 'OK'
# The masks are computed column-wise instead of looping over iterrows().
# astype(bool) applies the same truthiness rule the original per-row `and`
# used, and guarantees that every mask entry is a real boolean.
_has_both = pheno['ct_available'].astype(bool) & pheno['fc_available'].astype(bool)

# .tolist() keeps the masks plain Python lists, as before.
sub_ind_lenient = (_has_both
                   & (pheno['Ratings'] > 1)
                   & (pheno['status'] != 'Fail')).tolist()

sub_ind_strict = (_has_both
                  & (pheno['Ratings'] > 2)
                  & (pheno['status'] == 'OK')).tolist()

In [5]:
# Number of subjects that pass the strict criteria (count of True mask entries)
np.sum(sub_ind_strict)

184

In [6]:
# Row subsets of the phenotype table selected by each inclusion mask
lenient, strict = pheno[sub_ind_lenient], pheno[sub_ind_strict]

In [7]:
# Subjects per matched site within each diagnosis group, for the lenient sample
lenient.groupby('DX_GROUP')['match_Site'].value_counts()

DX_GROUP  match_Site
Autism    NYU           72
          USM           46
          OHSU          36
          UCLA_1        35
          Leuven        28
          Pitt          25
          Trinity       20
          KKI           19
          UM_1          14
          SBL           13
          Yale          13
          Olin          11
          Caltech        9
          UCLA_2         9
          CMU_b          8
          CMU_a          6
          MaxMun_a       6
          MaxMun_b       6
          SDSU           5
          MaxMun_d       4
          MaxMun_c       1
          Stanford       1
          UM_2           1
Control   NYU           96
          USM           41
          OHSU          33
          KKI           31
          UCLA_1        30
          Leuven        28
          UM_1          27
          Trinity       24
          Pitt          21
          SDSU          18
          Caltech       16
          MaxMun_a      12
          Olin          11
       

In [9]:
# Find sites that have at least 4 of ASD and TDC each
# NOTE(review): this list is maintained by hand, not derived from the counts.
# In the lenient counts printed above, SDSU (5 Autism / 18 Control) and
# MaxMun_a (6 / 12) also have at least 4 of each but are not listed here.
# Possibly the intended rule is at least 4 per group in each half of the later
# 50/50 split - TODO confirm the criterion.
good_sites = [
    'NYU', 'USM', 'OHSU', 'UCLA_1',
    'Pitt', 'Trinity', 'KKI', 'Leuven',
    'UM_1', 'SBL', 'Olin', 'Caltech',
]

In [10]:
# Flag each lenient subject whose matched site is one of the selected sites.
# Series.isin does the membership test in a single vectorized call, replacing
# the iterrows() loop and the redundant `True if ... else False`; .tolist()
# keeps the result a plain list of bools, usable as a row mask as before.
lenient_ind = lenient['match_Site'].isin(good_sites).tolist()

In [11]:
# Keep only the lenient subjects that belong to one of the selected sites
sample = lenient[lenient_ind]

In [12]:
# (rows, columns) of the site-filtered sample
sample.shape

(696, 97)

In [13]:
# Split the sample's row labels into two equal halves, stratified jointly on
# matched site and diagnosis group. The fixed random_state keeps the split
# repeatable across runs.
train, test = train_test_split(
    sample.index.values,
    train_size=0.5,
    test_size=0.5,
    stratify=sample[['match_Site', 'DX_GROUP']].values,
    random_state=2,
)

In [14]:
# Look up the rows whose index labels were assigned to the training half
train_sample = sample.loc[train]

In [15]:
# Look up the rows whose index labels were assigned to the held-out half
test_sample = sample.loc[test]

In [16]:
# (rows, columns) of the training half
train_sample.shape

(348, 97)

In [17]:
# (rows, columns) of the held-out half
test_sample.shape

(348, 97)

In [22]:
# Subjects per site within each diagnosis group for the training half.
# This groups on the original 'Site' column rather than 'match_Site', so
# labels that were merged for stratification are listed separately here.
train_sample.groupby('DX_GROUP')['Site'].value_counts()

DX_GROUP  Site    
Autism    NYU         36
          USM         23
          OHSU        18
          UCLA_1      18
          Pitt        12
          KKI         10
          Trinity     10
          Leuven_1     7
          Leuven_2     7
          UM_1         7
          Olin         6
          SBL          6
          Caltech      4
Control   NYU         48
          USM         20
          OHSU        16
          KKI         15
          UCLA_1      15
          UM_1        14
          Trinity     12
          Pitt        11
          Caltech      8
          Leuven_1     7
          Leuven_2     7
          Olin         6
          SBL          5
Name: Site, dtype: int64

In [21]:
# Subjects per site within each diagnosis group for the held-out half.
# This groups on the original 'Site' column rather than 'match_Site', so
# labels that were merged for stratification are listed separately here.
test_sample.groupby('DX_GROUP')['Site'].value_counts()

DX_GROUP  Site    
Autism    NYU         36
          USM         23
          OHSU        18
          UCLA_1      17
          Pitt        13
          Trinity     10
          KKI          9
          Leuven_1     7
          Leuven_2     7
          SBL          7
          UM_1         7
          Caltech      5
          Olin         5
Control   NYU         48
          USM         21
          OHSU        17
          KKI         16
          UCLA_1      15
          UM_1        13
          Trinity     12
          Pitt        10
          Caltech      8
          Leuven_2     8
          Leuven_1     6
          Olin         5
          SBL          5
Name: Site, dtype: int64

In [20]:
# Write both halves to disk without the index column. The held-out half
# (test_sample) is the one stored at the "validate" path.
for split_frame, out_path in ((train_sample, train_p), (test_sample, validate_p)):
    split_frame.to_csv(out_path, index=False)