# 01 - Running the Main Regressions


In [1]:
import numpy as np
import pandas as pd

import pickle


from sklearn.model_selection import RandomizedSearchCV
from sklearn.model_selection import KFold

from scipy import stats

from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score

from sklearn.feature_selection import SelectKBest
from sklearn.feature_selection import mutual_info_regression

from sklearn.pipeline import Pipeline

import xgboost as xgb

## Reading the data

Please note that we cannot share the original data used in the paper, so we provide example files in the `example_data/` folder so that this notebook can still be executed.

In [2]:
# Some people don't have node-wise features, so selecting only the subjects who do
all_people = [100206, 100307, 100408, 100610, 101107, 101309, 101915, 102008, 102311, 102513, 102614, 102715, 102816, 103010, 103111, 103212, 103414, 103515, 103818, 104012, 104416, 104820, 105014, 105115, 105216, 105620, 105923, 106016, 106319, 106521, 106824, 107018, 107321, 107422, 107725, 108020, 108121, 108222, 108323, 108525, 108828, 109123, 109830, 110007, 110411, 110613, 111009, 111211, 111312, 111413, 111514, 111716, 112112, 112314, 112516, 112920, 113215, 113316, 113619, 113922, 114217, 114318, 114419, 114621, 114823, 115017, 115219, 115320, 115724, 115825, 116524, 116726, 117021, 117122, 117324, 117930, 118023, 118124, 118225, 118528, 118730, 118831, 118932, 119025, 119126, 120111, 120212, 120414, 120515, 120717, 121416, 121618, 121921, 122317, 122620, 122822, 123117, 123420, 123521, 123723, 123824, 123925, 124220, 124422, 124624, 124826, 125222, 125424, 125525, 126325, 126426, 126628, 127226, 127327, 127630, 127731, 127832, 127933, 128026, 128127, 128632, 128935, 129028, 129129, 129331, 129634, 130013, 130114, 130316, 130417, 130518, 130619, 130720, 130821, 130922, 131217, 131419, 131722, 131823, 131924, 132017, 132118, 133019, 133625, 133827, 133928, 134021, 134223, 134324, 134425, 134627, 134728, 134829, 135124, 135225, 135528, 135629, 135730, 135932, 136126, 136227, 136631, 136732, 136833, 137027, 137128, 137229, 137431, 137532, 137633, 137936, 138130, 138231, 138332, 138534, 138837, 139233, 139435, 139637, 139839, 140117, 140319, 140824, 140925, 141119, 141422, 141826, 142828, 143224, 143325, 143426, 143830, 144125, 144226, 144428, 144731, 144832, 144933, 145127, 145632, 145834, 146129, 146331, 146432, 146533, 146735, 146836, 146937, 147030, 147636, 147737, 148032, 148133, 148335, 148436, 148840, 148941, 149236, 149337, 149539, 149741, 149842, 150625, 150726, 150928, 151223, 151324, 151425, 151526, 151627, 151728, 151829, 151930, 152225, 152427, 152831, 153025, 153126, 153227, 153429, 153631, 153732, 153833, 153934, 154229, 154330, 154431, 154532, 
154734, 154835, 154936, 155635, 155938, 156031, 156233, 156334, 156435, 156536, 156637, 157336, 157437, 157942, 158035, 158136, 158338, 158540, 158843, 159138, 159239, 159340, 159744, 160123, 160729, 160830, 161327, 161630, 161731, 161832, 162026, 162228, 162329, 162733, 162935, 163129, 163331, 163432, 163836, 164030, 164131, 164636, 164939, 165032, 165436, 165638, 165840, 165941, 166438, 166640, 167036, 167238, 167440, 167743, 168139, 168240, 168341, 168745, 168947, 169040, 169343, 169444, 169545, 169949, 170631, 171330, 171532, 171633, 172029, 172130, 172332, 172433, 172534, 172938, 173334, 173435, 173536, 173637, 173738, 173839, 173940, 174437, 174841, 175035, 175136, 175237, 175338, 175439, 175540, 175742, 176037, 176239, 176441, 176542, 176744, 176845, 177140, 177241, 177645, 177746, 178142, 178243, 178647, 178748, 178849, 178950, 179245, 179346, 180129, 180230, 180432, 180533, 180735, 180836, 180937, 181131, 181232, 181636, 182032, 182436, 182739, 182840, 183034, 185038, 185139, 185341, 185442, 185846, 185947, 186040, 186141, 186444, 186545, 186848, 187143, 187345, 187547, 187850, 188145, 188347, 188448, 188549, 188751, 189349, 189450, 189652, 190031, 191033, 191235, 191336, 191437, 191841, 191942, 192035, 192136, 192237, 192439, 192540, 192641, 192843, 193239, 193845, 194140, 194443, 194645, 194746, 194847, 195041, 195445, 195647, 195849, 195950, 196144, 196346, 196750, 197348, 197550, 198047, 198249, 198350, 198451, 198653, 198855, 199150, 199251, 199352, 199453, 199655, 199958, 200008, 200109, 200311, 200513, 200614, 200917, 201111, 201414, 201515, 201818, 202113, 202719, 203418, 203923, 204016, 204218, 204319, 204420, 204521, 204622, 205119, 205220, 205725, 205826, 206222, 206323, 206525, 206727, 206828, 206929, 207123, 207426, 208024, 208125, 208226, 208327, 208630, 209127, 209228, 209329, 209834, 209935, 210011, 210112, 210415, 210617, 211114, 211215, 211316, 211417, 211619, 211720, 211821, 211922, 212015, 212116, 212217, 212318, 212419, 212823, 213017, 
213421, 213522, 214019, 214221, 214423, 214524, 214625, 214726, 217126, 217429, 219231, 220721, 221319, 223929, 224022, 227432, 227533, 228434, 231928, 233326, 236130, 237334, 238033, 239944, 245333, 246133, 248339, 249947, 250427, 250932, 251833, 255639, 255740, 256540, 257542, 257845, 257946, 263436, 268749, 268850, 270332, 274542, 275645, 280739, 280941, 281135, 283543, 285345, 285446, 286347, 286650, 287248, 289555, 290136, 293748, 295146, 297655, 298051, 298455, 299154, 299760, 300618, 300719, 303119, 303624, 304020, 304727, 305830, 307127, 308129, 308331, 309636, 310621, 311320, 314225, 316633, 316835, 318637, 320826, 321323, 322224, 325129, 329440, 329844, 330324, 333330, 334635, 336841, 339847, 341834, 342129, 346137, 346945, 348545, 349244, 350330, 352132, 352738, 353740, 356948, 358144, 360030, 361234, 361941, 365343, 366042, 366446, 368551, 368753, 371843, 376247, 377451, 378756, 378857, 379657, 380036, 381038, 381543, 382242, 385046, 385450, 386250, 387959, 389357, 390645, 391748, 392447, 392750, 393247, 393550, 394956, 395251, 395756, 395958, 397154, 397760, 397861, 406432, 406836, 412528, 414229, 415837, 422632, 424939, 429040, 432332, 433839, 436239, 436845, 441939, 445543, 448347, 449753, 453441, 456346, 459453, 465852, 467351, 475855, 479762, 480141, 481951, 485757, 486759, 495255, 497865, 499566, 500222, 506234, 510326, 512835, 513736, 517239, 519950, 520228, 522434, 523032, 524135, 525541, 529549, 529953, 530635, 531536, 536647, 540436, 541943, 545345, 547046, 548250, 553344, 555348, 555651, 557857, 559053, 561242, 561444, 562345, 562446, 565452, 566454, 567052, 567961, 568963, 570243, 571144, 572045, 573249, 573451, 576255, 579665, 579867, 580044, 580347, 580650, 580751, 581349, 581450, 583858, 585256, 585862, 586460, 587664, 588565, 592455, 594156, 597869, 598568, 599065, 599469, 599671, 601127, 604537, 609143, 611938, 613538, 614439, 615744, 616645, 617748, 618952, 620434, 622236, 623844, 626648, 627549, 627852, 628248, 633847, 638049, 645450, 
645551, 647858, 654350, 654754, 656253, 656657, 657659, 660951, 663755, 664757, 665254, 667056, 668361, 671855, 672756, 673455, 677766, 677968, 679568, 679770, 680250, 680957, 683256, 685058, 686969, 690152, 693764, 695768, 700634, 702133, 704238, 705341, 706040, 707749, 709551, 713239, 715041, 715647, 715950, 720337, 724446, 725751, 727553, 727654, 729254, 729557, 731140, 732243, 734045, 735148, 737960, 742549, 744553, 748258, 749058, 749361, 751348, 753150, 753251, 756055, 759869, 761957, 765056, 767464, 769064, 770352, 771354, 773257, 779370, 782561, 783462, 784565, 788876, 789373, 792564, 792766, 792867, 793465, 800941, 802844, 803240, 810843, 812746, 814649, 816653, 818859, 820745, 825048, 826353, 826454, 833148, 833249, 835657, 837560, 837964, 841349, 843151, 844961, 845458, 849264, 849971, 852455, 856766, 856968, 857263, 859671, 861456, 865363, 867468, 870861, 871762, 871964, 872158, 872562, 872764, 873968, 877168, 877269, 880157, 882161, 885975, 887373, 889579, 891667, 894067, 894673, 894774, 896778, 896879, 898176, 899885, 901038, 901139, 901442, 904044, 907656, 910241, 910443, 912447, 917255, 917558, 919966, 922854, 923755, 930449, 932554, 937160, 947668, 951457, 952863, 955465, 957974, 958976, 959574, 965367, 965771, 966975, 978578, 979984, 983773, 984472, 987983, 990366, 991267, 992673, 992774, 993675, 994273, 996782]

# Load the node-wise features, index them by subject ID and keep only the
# subjects listed in `all_people` (in that order)
node_features = (
    pd.read_csv("../example_data/node_wise_features.csv")
    .set_index("Subjects")
    .loc[all_people]
)

# Number of feature columns available for the regressions
no_features = node_features.shape[1]

# Display the resulting DataFrame
node_features

Unnamed: 0_level_0,fs_l_bankssts_myel,fs_l_caudalanteriorcingulate_myel,fs_l_caudalmiddlefrontal_myel,fs_l_cuneus_myel,fs_l_entorhinal_myel,fs_l_fusiform_myel,fs_l_inferiorparietal_myel,fs_l_inferiortemporal_myel,fs_l_isthmuscingulate_myel,fs_l_lateraloccipital_myel,...,fs_r_temporalpole_gauscurv,fs_r_temporalpole_foldind,fs_r_temporalpole_curvind,fs_r_transversetemporal_thck,fs_r_transversetemporal_area,fs_r_transversetemporal_grayvol,fs_r_transversetemporal_meancurv,fs_r_transversetemporal_gauscurv,fs_r_transversetemporal_foldind,fs_r_transversetemporal_curvind
Subjects,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
100206,68.3802,84.6834,27.1701,81.5622,15.4853,26.4674,37.7649,5.2173,73.7679,35.1480,...,16.4163,90.7641,0.6600,78.7549,99.7925,91.7324,76.9609,41.1895,33.5080,57.8402
100307,62.4157,70.4599,53.1805,4.8509,89.1462,86.7884,84.3975,71.3113,85.3448,8.1502,...,1.5649,67.1979,24.4786,87.6043,36.0189,74.0866,47.5527,33.1548,61.2061,34.1861
100408,52.3308,94.0239,78.9857,21.0805,52.3540,37.1030,92.0178,31.1592,54.9603,48.9203,...,85.2391,73.8248,78.9037,0.7619,31.9986,52.5146,44.4433,28.8116,83.4564,32.9102
100610,74.5052,67.3638,23.2021,53.7091,68.0855,92.9996,90.1983,49.6079,12.6569,53.3050,...,65.4194,4.6989,97.9936,69.5503,35.4276,3.7882,98.1156,43.8501,83.3597,40.8268
101107,81.6852,54.5506,22.5713,46.0172,91.5473,91.1040,69.4039,50.9502,62.8165,40.1263,...,22.1308,81.5856,1.1503,1.5984,2.3075,62.8794,68.4576,21.0574,35.0529,1.0676
101309,75.7167,2.4871,71.8991,81.9042,21.4685,33.0208,16.1911,87.9001,99.9863,84.8817,...,36.2613,91.7308,68.2630,91.2377,53.3133,49.5246,60.1782,3.6767,91.6736,6.3599
101915,71.8100,31.4061,11.9436,7.7799,22.2760,7.1954,42.9989,97.9421,25.2908,85.3087,...,9.1849,61.0502,46.3617,19.8850,31.0203,39.2119,6.1010,69.8063,79.5966,46.5499
102008,53.0536,28.6786,33.8731,7.3382,19.1064,82.7264,71.0988,25.7998,8.5103,63.6334,...,99.2972,64.6250,90.5323,80.7202,8.2576,84.0195,2.7313,27.9080,43.8133,31.7952
102311,3.0686,35.1263,42.0029,31.5949,54.4138,43.9702,90.3930,43.9115,63.1818,44.2193,...,21.2655,97.6428,43.6382,43.6117,37.1245,64.2670,19.7846,29.8667,32.1215,72.5157
102513,48.7164,29.8567,37.7470,75.3322,48.6451,36.5460,51.3165,58.9368,82.5574,15.4167,...,45.8396,98.2710,12.5781,1.3336,92.9276,40.4374,7.9632,24.1986,69.4835,65.8238


In [3]:
# Load the nine factor values of every person, indexed by subject ID
fac_data = pd.read_csv("../example_data/factors_nine.csv").set_index("Subject")
fac_data

Unnamed: 0_level_0,FAC1,FAC2,FAC3,FAC4,FAC5,FAC6,FAC7,FAC8,FAC9
Subject,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
100206,-2.19322,-3.24582,-0.35650,-0.19601,-3.27659,1.04789,-3.29911,-2.54841,-2.13112
100307,0.10101,4.40783,-2.28429,3.43902,2.80362,-3.04794,2.34130,-0.66747,0.84274
100408,-0.83148,-4.72360,-0.68458,1.37980,4.53747,0.33145,-4.69920,-2.34972,2.07273
100610,-3.14493,-4.11641,-4.05372,4.44089,1.99478,-2.67238,2.46204,-4.28783,-3.20867
101107,-4.96598,-4.03152,-3.51658,1.45762,-3.86265,-2.62140,1.51533,3.55707,1.57449
101309,-0.78382,0.43219,0.30508,0.50734,0.17685,4.80278,2.25869,-0.66330,-0.47164
101410,-3.71168,0.12864,2.85109,4.08283,0.40834,0.44076,3.41883,-1.25552,-1.89356
101915,-4.23780,3.73291,2.10656,-2.81845,2.57776,-4.23772,1.14446,-2.95911,0.41416
102008,-0.88916,1.57972,-4.12997,2.53589,-0.02123,2.27484,3.83379,4.23613,2.90230
102311,3.85916,0.01410,2.85137,2.43636,3.67560,-2.78860,-1.02555,1.30852,-0.20688


## The main loop

This is the main loop that predicts the nine factors. To keep this notebook simple, it is only executed for the first factor. Edit `values_to_predict` if you want to predict more factors.

For each factor, a dictionary (called `all_dic` in the code) stores the results of all the prediction steps and is saved to disk using `pickle`. This dictionary contains all the information necessary for the analysis that will be conducted in the next notebook.

Depending on the hyperparameter search parameters that you define, this might take quite a while to conclude.

In [4]:
# Factors (columns of `fac_data`) for which the regression is run
values_to_predict = ["FAC1"]

# To predict all nine factors, uncomment the line below
#values_to_predict = ['FAC1', 'FAC2', 'FAC3', 'FAC4', 'FAC5', 'FAC6', 'FAC7', 'FAC8', 'FAC9']

In [5]:
# Nested cross-validation, run once per factor in `values_to_predict`:
#   * the outer K-fold yields the train/test splits used to score the model,
#   * the inner K-fold is used by the randomised hyper-parameter search.
# All inputs, per-fold results and averaged scores of a factor are gathered
# in `all_dic` and pickled to "dic_all_<factor>_results.pkl".
for column in values_to_predict:

    print("###################")
    print("Running code for " + column + "...")

    # Where everything will be stored
    all_dic = {}

    # Features and the factor's values in the same dataframe
    all_df = node_features.join(fac_data[column])

    # There might be some NaNs in the factors, so not considering those
    filtered_df = all_df.loc[~np.isnan(all_df[column])]

    # Filtering out the data itself (X) and the labels (y)
    X = filtered_df.iloc[:, 0:no_features]
    y = filtered_df.loc[:, column]
    all_dic['X'] = X
    all_dic['y'] = y

    # Search space for the XGBoost step of the pipeline (hence the 'xgb__' prefix).
    # 'reg:squarederror' is the current name of the squared-error objective;
    # 'reg:linear' is its deprecated alias.
    parameters = {'xgb__nthread': [1],
      'xgb__objective': ['reg:squarederror'],
      'xgb__learning_rate': [0.005, 0.05, .5], #so called 'eta' value
      'xgb__max_depth': [3, 5, 7],
      'xgb__min_child_weight': [7, 10, 30, 50],
      'xgb__subsample': [0.1, 0.3, 0.7, 1.0],
      'xgb__colsample_bytree': [0.4, 0.6, 1.0],
      'xgb__colsample_bynode': [0.4, 0.6, 1.0],
      'xgb__colsample_bylevel': [0.4, 0.6, 1.0],
      'xgb__n_estimators': [750],
      'xgb__reg_lambda': [1, 3, 5, 10, 50],
      'xgb__reg_alpha': [1, 3, 5, 10, 50],
      'xgb__gamma': [1, 3, 5, 10]}

    # Information about how the nested cross-validation will be. Usage of
    # random_state guarantees that the folds will always be divided the same
    # way when executing this same code
    inner_cv = KFold(n_splits=3, shuffle=True, random_state=0)
    outer_cv = KFold(n_splits=5, shuffle=True, random_state=0)
    n_outer_splits = outer_cv.get_n_splits()

    # Never ask SelectKBest for more features than are actually available
    n_selected = min(450, no_features)

    # Necessary to calculate the averaged scores over the outer folds
    score_mse = 0.0
    score_mae = 0.0
    score_r2 = 0.0
    score_per = 0.0

    # Main loop, splitting the data into train/test sets using KFold;
    # `ind` identifies the fold
    for ind, (train_index, test_index) in enumerate(outer_cv.split(X)):

        selector = SelectKBest(mutual_info_regression, k=n_selected)
        clf_xgb = xgb.XGBRegressor()
        pipe = Pipeline(steps=[('kbest', selector), ('xgb', clf_xgb)])

        # There are too many parameters to search for, so randomised search is used.
        # random_state seeds the sampling of the hyper-parameter combinations so
        # the same candidates are tried on every run (the feature scoring and
        # the regressor itself are not seeded here).
        # The former `iid=False` argument is not passed: it was deprecated in
        # scikit-learn 0.22 and removed in 0.24.
        clf = RandomizedSearchCV(pipe,
                                param_distributions = parameters,
                                n_iter = 80, # Increase this value to try more combinations
                                scoring = 'neg_mean_squared_error',
                                cv=inner_cv,
                                return_train_score=True,
                                random_state=0,
                                n_jobs = -1)

        # A dictionary used to store the information about this fold
        dic_split = {}

        # Storing the train and test indexes that allow the construction of train/test sets
        dic_split['train_index'] = train_index
        dic_split['test_index'] = test_index

        X_train, X_test = X.iloc[train_index].copy(), X.iloc[test_index].copy()
        y_train, y_test = y.iloc[train_index].copy(), y.iloc[test_index].copy()

        # No need to normalise input for Xgboost, so just fitting the data
        clf.fit(X_train, y_train)
        y_pred = clf.predict(X_test)

        # Calculating the metrics for this fold
        mae = mean_absolute_error(y_test, y_pred)
        mse = mean_squared_error(y_test, y_pred)
        r2  = r2_score(y_test, y_pred)
        pearson_r = stats.pearsonr(y_pred, y_test)  # (correlation, p-value)

        score_r2 += r2
        score_mae += mae
        score_mse += mse
        score_per += pearson_r[0]

        # Storing all the information about this fold
        dic_split['y_pred'] = y_pred
        dic_split['r2'] = r2
        dic_split['mae'] = mae
        dic_split['mse'] = mse
        dic_split['pearson_r'] = pearson_r
        dic_split['clf_obj'] = clf
        dic_split['kbests'] = clf.best_estimator_.named_steps['kbest'].get_support()

        # Storing this fold's dictionary in the main one
        all_dic['split'+str(ind)] = dic_split

    # Averaging the scores over the outer folds
    score_mse /= n_outer_splits
    score_mae /= n_outer_splits
    score_r2 /= n_outer_splits
    score_per /= n_outer_splits

    print("--- Results:")
    print("Pearson r: " + str(score_per))
    print("MAE: " + str(score_mae))
    print("MSE: " + str(score_mse))
    print("r^2: " + str(score_r2))

    # Storing the results in the main dictionary
    all_dic['mae'] = score_mae
    all_dic['mse'] = score_mse
    all_dic['r2'] = score_r2
    all_dic['pearson_r'] = score_per

    # To save the results and be further analysed; the context manager makes
    # sure the file is flushed and closed even if pickling fails
    with open("dic_all_" + column + "_results.pkl", "wb") as results_file:
        pickle.dump(all_dic, results_file)

###################
Running code for FAC1...
--- Results:
Pearson r: -0.005459543825447927
MAE: 2.5903732699746618
MSE: 8.835178538343232
r^2: -0.02094532240322966
