In [34]:
import numpy as np
import pandas as pd
from scipy.interpolate import interp1d

# Read in the feature data and discover null values

In [35]:
# Load the per-node table from data/nodes.csv (one row per subject / tract / node,
# with the metric columns rd, fa, cl, ad, md) and preview the first rows.
nodes = pd.read_csv('data/nodes.csv')
nodes.head()

Unnamed: 0,subjectID,tractID,nodeID,rd,fa,cl,ad,md
0,subject_000,Left Thalamic Radiation,0,0.656032,0.183053,0.081921,0.875535,0.7292
1,subject_000,Left Thalamic Radiation,1,0.613308,0.247121,0.11548,0.909085,0.711901
2,subject_000,Left Thalamic Radiation,2,0.574612,0.306726,0.151766,0.944572,0.697932
3,subject_000,Left Thalamic Radiation,3,0.549868,0.343995,0.176124,0.966964,0.6889
4,subject_000,Left Thalamic Radiation,4,0.53019,0.373869,0.194396,0.985039,0.681806


In [36]:
# Summarize the nodes table: column dtypes and non-null counts.
# Any column whose non-null count is below the number of rows has missing values.
nodes.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 154000 entries, 0 to 153999
Data columns (total 8 columns):
subjectID    154000 non-null object
tractID      154000 non-null object
nodeID       154000 non-null int64
rd           152700 non-null float64
fa           152326 non-null float64
cl           152700 non-null float64
ad           152700 non-null float64
md           152700 non-null float64
dtypes: float64(5), int64(1), object(2)
memory usage: 9.4+ MB


In [37]:
# Some metric columns have fewer non-null entries than there are rows.
# Display every row in which at least one column is missing.
nodes.loc[nodes.isnull().any(axis=1)]

Unnamed: 0,subjectID,tractID,nodeID,rd,fa,cl,ad,md
99,subject_000,Left Thalamic Radiation,99,0.574619,,0.080268,0.779967,0.643069
199,subject_000,Right Thalamic Radiation,99,0.571765,,0.127559,0.878160,0.673897
200,subject_000,Left Corticospinal,0,0.472894,,0.071492,0.677883,0.541224
300,subject_000,Right Corticospinal,0,0.528727,,0.116244,0.832071,0.629842
2099,subject_001,Left Thalamic Radiation,99,0.240981,,0.132384,0.570398,0.350787
2199,subject_001,Right Thalamic Radiation,99,0.372665,,0.170093,0.711544,0.485625
2200,subject_001,Left Corticospinal,0,0.148385,,0.091452,0.324785,0.207185
2300,subject_001,Right Corticospinal,0,0.331715,,0.080227,0.544015,0.402482
4099,subject_002,Left Thalamic Radiation,99,0.532312,,0.097986,0.786821,0.617148
4199,subject_002,Right Thalamic Radiation,99,0.593873,,0.060743,0.775232,0.654326


# Interpolate `NaN` values from nearby nodes

In [38]:
# Interpolation must never borrow values from a different subject, tract, or
# metric -- only from neighbouring nodes. To guarantee that, reshape the table
# to long form and pivot it so that nodeID is the row index and every
# (metric, tractID, subjectID) combination becomes its own column. Each column
# can then be interpolated independently along the node axis.
by_node_idx = (
    nodes
    .melt(id_vars=['subjectID', 'tractID', 'nodeID'], var_name='metric')
    .pivot_table(values='value', index='nodeID', columns=['metric', 'tractID', 'subjectID'])
)

# Show only the rows and columns of the pivoted frame that contain a null value
by_node_idx.loc[by_node_idx.isnull().any(axis=1), by_node_idx.isnull().any(axis=0)]

metric,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa
tractID,Callosum Forceps Major,Callosum Forceps Major,Callosum Forceps Major,Callosum Forceps Major,Left Arcuate,Left Arcuate,Left Cingulum Hippocampus,Left Cingulum Hippocampus,Left Cingulum Hippocampus,Left Corticospinal,...,Right Thalamic Radiation,Right Thalamic Radiation,Right Thalamic Radiation,Right Uncinate,Right Uncinate,Right Uncinate,Right Uncinate,Right Uncinate,Right Uncinate,Right Uncinate
subjectID,subject_013,subject_039,subject_061,subject_070,subject_051,subject_061,subject_019,subject_060,subject_061,subject_000,...,subject_073,subject_075,subject_076,subject_019,subject_026,subject_042,subject_056,subject_062,subject_065,subject_073
nodeID,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3,Unnamed: 19_level_3,Unnamed: 20_level_3,Unnamed: 21_level_3
0,,0.204551,,,0.179262,0.177054,0.180685,0.19007,0.187525,,...,0.230178,0.191712,0.189012,,0.179535,0.178694,0.18361,,0.18625,0.176448
1,0.274187,0.262941,0.505873,0.271225,0.236669,0.236865,0.210511,0.20222,0.222604,0.341655,...,0.270223,0.209479,0.266586,0.236529,0.221781,0.224746,0.218422,0.216726,0.213763,0.223417
2,0.32625,0.307926,0.482015,0.319023,0.29887,0.288474,0.233045,0.214528,0.262054,0.380395,...,0.314196,0.226802,0.326556,0.301793,0.257182,0.27158,0.249845,0.248501,0.245365,0.267737
3,0.337527,0.33255,0.458337,0.34163,0.352558,0.330694,0.247546,0.223219,0.29099,0.4023,...,0.345945,0.239311,0.361582,0.354039,0.272753,0.301833,0.279195,0.268104,0.273859,0.298692
4,0.341623,0.340057,0.453747,0.350528,0.392831,0.378554,0.264519,0.228656,0.310104,0.42918,...,0.370762,0.24908,0.38817,0.370351,0.282364,0.32541,0.316253,0.278656,0.308146,0.320947
5,0.358933,0.343922,0.425955,0.369938,0.421998,0.423745,0.280566,0.232134,0.324181,0.468649,...,0.386055,0.258027,0.39656,0.384095,0.293963,0.34527,0.342079,0.282399,0.342367,0.341777
6,0.389969,0.37055,0.402147,0.392694,0.431251,0.442995,0.298984,0.23785,0.335728,0.511123,...,0.396626,0.267588,0.391463,0.397777,0.302578,0.367088,0.347042,0.286884,0.36962,0.361118
11,0.48132,0.476745,0.438012,0.464462,0.438958,0.51565,0.328839,0.285459,0.357349,0.59587,...,0.402196,0.312212,0.444821,0.436617,0.337558,0.412653,0.294103,0.341477,0.369376,0.376663
19,0.698855,0.655657,0.552387,0.636075,0.523916,0.502649,0.385963,0.320291,0.390599,0.665425,...,0.385826,0.328781,0.357297,0.403516,0.399058,0.379055,0.419434,0.372982,0.447271,0.47836
22,0.701255,0.693943,0.656692,0.699202,0.427661,0.539608,0.402883,0.363187,0.412225,0.652735,...,0.398033,0.342794,0.338804,0.384954,0.417771,0.378874,0.410862,0.402993,0.528836,0.489482


In [39]:
# First attempt: the built-in `.interpolate` method. Its handling of NaN values
# at the edges of a series may be surprising: trailing NaNs are forward-filled
# with the most recent valid value and leading NaNs are back-filled with the
# next valid value, rather than being extrapolated.
interpolated = by_node_idx.interpolate(limit_direction='both', method='linear')
interpolated.loc[by_node_idx.isnull().any(axis=1), by_node_idx.isnull().any(axis=0)]

metric,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa
tractID,Callosum Forceps Major,Callosum Forceps Major,Callosum Forceps Major,Callosum Forceps Major,Left Arcuate,Left Arcuate,Left Cingulum Hippocampus,Left Cingulum Hippocampus,Left Cingulum Hippocampus,Left Corticospinal,...,Right Thalamic Radiation,Right Thalamic Radiation,Right Thalamic Radiation,Right Uncinate,Right Uncinate,Right Uncinate,Right Uncinate,Right Uncinate,Right Uncinate,Right Uncinate
subjectID,subject_013,subject_039,subject_061,subject_070,subject_051,subject_061,subject_019,subject_060,subject_061,subject_000,...,subject_073,subject_075,subject_076,subject_019,subject_026,subject_042,subject_056,subject_062,subject_065,subject_073
nodeID,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3,Unnamed: 19_level_3,Unnamed: 20_level_3,Unnamed: 21_level_3
0,0.274187,0.204551,0.505873,0.271225,0.179262,0.177054,0.180685,0.19007,0.187525,0.341655,...,0.230178,0.191712,0.189012,0.236529,0.179535,0.178694,0.18361,0.216726,0.18625,0.176448
1,0.274187,0.262941,0.505873,0.271225,0.236669,0.236865,0.210511,0.20222,0.222604,0.341655,...,0.270223,0.209479,0.266586,0.236529,0.221781,0.224746,0.218422,0.216726,0.213763,0.223417
2,0.32625,0.307926,0.482015,0.319023,0.29887,0.288474,0.233045,0.214528,0.262054,0.380395,...,0.314196,0.226802,0.326556,0.301793,0.257182,0.27158,0.249845,0.248501,0.245365,0.267737
3,0.337527,0.33255,0.458337,0.34163,0.352558,0.330694,0.247546,0.223219,0.29099,0.4023,...,0.345945,0.239311,0.361582,0.354039,0.272753,0.301833,0.279195,0.268104,0.273859,0.298692
4,0.341623,0.340057,0.453747,0.350528,0.392831,0.378554,0.264519,0.228656,0.310104,0.42918,...,0.370762,0.24908,0.38817,0.370351,0.282364,0.32541,0.316253,0.278656,0.308146,0.320947
5,0.358933,0.343922,0.425955,0.369938,0.421998,0.423745,0.280566,0.232134,0.324181,0.468649,...,0.386055,0.258027,0.39656,0.384095,0.293963,0.34527,0.342079,0.282399,0.342367,0.341777
6,0.389969,0.37055,0.402147,0.392694,0.431251,0.442995,0.298984,0.23785,0.335728,0.511123,...,0.396626,0.267588,0.391463,0.397777,0.302578,0.367088,0.347042,0.286884,0.36962,0.361118
11,0.48132,0.476745,0.438012,0.464462,0.438958,0.51565,0.328839,0.285459,0.357349,0.59587,...,0.402196,0.312212,0.444821,0.436617,0.337558,0.412653,0.294103,0.341477,0.369376,0.376663
19,0.698855,0.655657,0.552387,0.636075,0.523916,0.502649,0.385963,0.320291,0.390599,0.665425,...,0.385826,0.328781,0.357297,0.403516,0.399058,0.379055,0.419434,0.372982,0.447271,0.47836
22,0.701255,0.693943,0.656692,0.699202,0.427661,0.539608,0.402883,0.363187,0.412225,0.652735,...,0.398033,0.342794,0.338804,0.384954,0.417771,0.378874,0.410862,0.402993,0.528836,0.489482


In [40]:
# Instead, we may want to fill NaN values at either end of the node range by
# linear extrapolation. But pandas does not currently support extrapolation
# See this issue: https://github.com/pandas-dev/pandas/issues/16284
# And this stalled PR: https://github.com/pandas-dev/pandas/pull/16513
# Until that's fixed, we can perform the interpolation column by column using the
# apply method. This is SLOW, but it does what we want

def interp_linear_with_extrap(series):
    """Linearly interpolate a series, extrapolating beyond its valid range.

    The series index is used as the x-axis and the non-null values as the
    known points. Null entries between known points are interpolated and null
    entries before the first / after the last known point are linearly
    extrapolated.

    Parameters
    ----------
    series : pandas.Series
        Numeric values with a numeric index; NaN marks the entries to fill.

    Returns
    -------
    numpy.ndarray
        Array with one value per entry of ``series``, in index order. If the
        series has fewer than two non-null values a line cannot be fitted, so
        the values are returned unchanged (NaNs stay NaN) instead of raising.
    """
    # Compute the validity mask once and reuse it for both x and y
    valid = series.notnull().values
    positions = series.index.values

    # interp1d needs at least two points; without this guard a single sparse
    # column would abort an entire DataFrame.apply call with a ValueError
    if valid.sum() < 2:
        return np.array(series.values, dtype=float)

    f = interp1d(positions[valid], series.values[valid], kind='linear', fill_value='extrapolate')
    return f(positions)

# Run the extrapolating interpolation on every column of the pivoted frame
extrapolated = by_node_idx.apply(interp_linear_with_extrap, axis=0)

# Inspect the previously-null rows and columns to confirm they are now filled
extrapolated.loc[by_node_idx.isnull().any(axis=1), by_node_idx.isnull().any(axis=0)]

metric,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa,fa
tractID,Callosum Forceps Major,Callosum Forceps Major,Callosum Forceps Major,Callosum Forceps Major,Left Arcuate,Left Arcuate,Left Cingulum Hippocampus,Left Cingulum Hippocampus,Left Cingulum Hippocampus,Left Corticospinal,...,Right Thalamic Radiation,Right Thalamic Radiation,Right Thalamic Radiation,Right Uncinate,Right Uncinate,Right Uncinate,Right Uncinate,Right Uncinate,Right Uncinate,Right Uncinate
subjectID,subject_013,subject_039,subject_061,subject_070,subject_051,subject_061,subject_019,subject_060,subject_061,subject_000,...,subject_073,subject_075,subject_076,subject_019,subject_026,subject_042,subject_056,subject_062,subject_065,subject_073
nodeID,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3,Unnamed: 19_level_3,Unnamed: 20_level_3,Unnamed: 21_level_3
0,0.222124,0.204551,0.529732,0.223427,0.179262,0.177054,0.180685,0.19007,0.187525,0.302915,...,0.230178,0.191712,0.189012,0.171265,0.179535,0.178694,0.18361,0.184951,0.18625,0.176448
1,0.274187,0.262941,0.505873,0.271225,0.236669,0.236865,0.210511,0.20222,0.222604,0.341655,...,0.270223,0.209479,0.266586,0.236529,0.221781,0.224746,0.218422,0.216726,0.213763,0.223417
2,0.32625,0.307926,0.482015,0.319023,0.29887,0.288474,0.233045,0.214528,0.262054,0.380395,...,0.314196,0.226802,0.326556,0.301793,0.257182,0.27158,0.249845,0.248501,0.245365,0.267737
3,0.337527,0.33255,0.458337,0.34163,0.352558,0.330694,0.247546,0.223219,0.29099,0.4023,...,0.345945,0.239311,0.361582,0.354039,0.272753,0.301833,0.279195,0.268104,0.273859,0.298692
4,0.341623,0.340057,0.453747,0.350528,0.392831,0.378554,0.264519,0.228656,0.310104,0.42918,...,0.370762,0.24908,0.38817,0.370351,0.282364,0.32541,0.316253,0.278656,0.308146,0.320947
5,0.358933,0.343922,0.425955,0.369938,0.421998,0.423745,0.280566,0.232134,0.324181,0.468649,...,0.386055,0.258027,0.39656,0.384095,0.293963,0.34527,0.342079,0.282399,0.342367,0.341777
6,0.389969,0.37055,0.402147,0.392694,0.431251,0.442995,0.298984,0.23785,0.335728,0.511123,...,0.396626,0.267588,0.391463,0.397777,0.302578,0.367088,0.347042,0.286884,0.36962,0.361118
11,0.48132,0.476745,0.438012,0.464462,0.438958,0.51565,0.328839,0.285459,0.357349,0.59587,...,0.402196,0.312212,0.444821,0.436617,0.337558,0.412653,0.294103,0.341477,0.369376,0.376663
19,0.698855,0.655657,0.552387,0.636075,0.523916,0.502649,0.385963,0.320291,0.390599,0.665425,...,0.385826,0.328781,0.357297,0.403516,0.399058,0.379055,0.419434,0.372982,0.447271,0.47836
22,0.701255,0.693943,0.656692,0.699202,0.427661,0.539608,0.402883,0.363187,0.412225,0.652735,...,0.398033,0.342794,0.338804,0.384954,0.417771,0.378874,0.410862,0.402993,0.528836,0.489482


# Restructure node dataframe as a feature matrix

In [41]:
# With the NaN values filled in, reshape the nodes data into a feature matrix
# with one row per subject and one column for each combination of metric,
# tractID, and nodeID.
features = (
    extrapolated
    .stack(['subjectID', 'tractID', 'metric'])
    .unstack(['metric', 'tractID', 'nodeID'])
)

# The reshape can introduce new NaNs where a subject has no values for a
# column; show only the rows and columns that contain them. (Only the last
# expression of a cell is displayed, so no intermediate preview is made here.)
features.loc[features.isnull().any(axis='columns'), features.isnull().any(axis='rows')]

metric,ad,cl,fa,md,rd,ad,cl,fa,md,rd,...,ad,cl,fa,md,rd,ad,cl,fa,md,rd
tractID,Left Arcuate,Left Arcuate,Left Arcuate,Left Arcuate,Left Arcuate,Left Cingulum Hippocampus,Left Cingulum Hippocampus,Left Cingulum Hippocampus,Left Cingulum Hippocampus,Left Cingulum Hippocampus,...,Right Cingulum Cingulate,Right Cingulum Cingulate,Right Cingulum Cingulate,Right Cingulum Cingulate,Right Cingulum Cingulate,Right SLF,Right SLF,Right SLF,Right SLF,Right SLF
nodeID,0,0,0,0,0,0,0,0,0,0,...,99,99,99,99,99,99,99,99,99,99
subjectID,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3,Unnamed: 19_level_3,Unnamed: 20_level_3,Unnamed: 21_level_3
subject_023,0.852916,0.045554,0.176999,0.734345,0.675059,0.977274,0.065991,0.186132,0.820333,0.741862,...,0.853455,0.053262,0.190378,0.724465,0.659971,0.884258,0.07498,0.175159,0.742691,0.671908
subject_027,,,,,,0.962808,0.05072,0.20829,0.811707,0.736156,...,0.83932,0.048976,0.181425,0.719115,0.659012,0.921127,0.077892,0.181005,0.769035,0.692989
subject_038,0.911449,0.05937,0.18442,0.772645,0.703243,0.878935,0.074144,0.181416,0.736416,0.665157,...,0.939585,0.076427,0.174893,0.789101,0.713859,0.923406,0.054425,0.173639,0.790608,0.724209
subject_040,0.977176,0.085052,0.174533,0.816673,0.736421,1.560208,0.095872,0.189767,1.282849,1.14417,...,,,,,,0.900383,0.076856,0.180645,0.752998,0.679306
subject_041,,,,,,0.911045,0.065033,0.198547,0.76429,0.690912,...,0.871059,0.072684,0.182816,0.730666,0.660469,1.026543,0.06915,0.177091,0.865812,0.785446
subject_042,0.940042,0.07163,0.179833,0.790106,0.715139,0.813971,0.063752,0.176345,0.691247,0.629885,...,0.852317,0.034491,0.19678,0.734492,0.675579,0.942572,0.049164,0.176377,0.808927,0.742105
subject_046,0.960439,0.079244,0.181361,0.801074,0.721392,0.912339,0.078776,0.173535,0.765762,0.692473,...,1.025163,0.072618,0.174074,0.863008,0.781931,0.964137,0.077655,0.191413,0.801165,0.719679
subject_049,0.947209,0.068121,0.170057,0.8029,0.730746,1.124104,0.07963,0.179142,0.938235,0.8453,...,0.82405,0.050495,0.146653,0.718678,0.665992,,,,,
subject_053,0.91649,0.050371,0.185794,0.78088,0.713075,0.781525,0.055183,0.195107,0.660691,0.600274,...,0.879706,0.067085,0.196715,0.736693,0.665186,0.986633,0.072076,0.187604,0.827019,0.747212
subject_056,0.960623,0.071695,0.180217,0.807981,0.73166,,,,,,...,0.962687,0.044846,0.195716,0.821065,0.750254,0.943617,0.07694,0.183124,0.788895,0.711534


In [42]:
# We're almost there. Put the multi-indexed columns into a predictable order:
# the full metric x tractID x nodeID product, with nodeID varying fastest.
# `reindex` is used rather than `.loc[:, new_columns]` because label-based
# selection raises a KeyError on pandas >= 1.0 if any combination in the
# product is absent from the columns; `reindex` yields an all-NaN column for
# such a combination instead.
new_columns = pd.MultiIndex.from_product(features.columns.levels, names=['metric', 'tractID', 'nodeID'])
features = features.reindex(columns=new_columns)
features.head()

metric,ad,ad,ad,ad,ad,ad,ad,ad,ad,ad,...,rd,rd,rd,rd,rd,rd,rd,rd,rd,rd
tractID,Callosum Forceps Major,Callosum Forceps Major,Callosum Forceps Major,Callosum Forceps Major,Callosum Forceps Major,Callosum Forceps Major,Callosum Forceps Major,Callosum Forceps Major,Callosum Forceps Major,Callosum Forceps Major,...,Right Uncinate,Right Uncinate,Right Uncinate,Right Uncinate,Right Uncinate,Right Uncinate,Right Uncinate,Right Uncinate,Right Uncinate,Right Uncinate
nodeID,0,1,2,3,4,5,6,7,8,9,...,90,91,92,93,94,95,96,97,98,99
subjectID,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3,Unnamed: 19_level_3,Unnamed: 20_level_3,Unnamed: 21_level_3
subject_000,0.805295,0.84256,0.890478,0.918051,0.928333,0.928146,0.939865,0.966454,1.009985,1.055974,...,0.52344,0.527103,0.534135,0.538084,0.537271,0.541258,0.557063,0.580076,0.61598,0.669231
subject_001,0.79959,0.830287,0.853408,0.878371,0.904229,0.917874,0.920824,0.907716,0.920902,0.958725,...,0.471962,0.485115,0.507145,0.531644,0.55726,0.58149,0.604882,0.627365,0.649681,0.665835
subject_002,0.818287,0.862789,0.91715,0.942286,0.959843,0.956531,0.938447,0.927608,0.938578,0.965086,...,0.532762,0.529145,0.530006,0.540152,0.555245,0.568348,0.581784,0.600493,0.622814,0.646989
subject_003,0.819124,0.849479,0.902716,0.940494,0.959856,0.960385,0.983349,1.026268,1.060786,1.074687,...,0.490735,0.5057,0.518365,0.522972,0.526619,0.539454,0.56035,0.584355,0.621255,0.652261
subject_004,0.816927,0.870511,0.913039,0.935921,0.967326,0.974853,0.98078,0.994712,1.012808,1.039488,...,0.452986,0.467585,0.488844,0.512869,0.538053,0.565963,0.589491,0.606863,0.635141,0.658221


In [43]:
# Some NaN values may still remain. After the interpolation above, the only
# ones left should be those created by stacking and unstacking when a subject
# is missing an entire tract. Fill each such column with the median of that
# column across all other subjects.
features = features.fillna(features.median())

In [44]:
# `features` has now been interpolated and median-filled, so it is expected to
# be a well structured feature matrix without null values (not asserted here).
print('et voilà')

et voilà


# Get group indices for sparse-group lasso

`bundle_group_membership` will store the group indices for the groups associated with bundle-metric combinations.
There should be $n_{\textrm{bundles}} \times n_{\textrm{metrics}}$ of these groups and we can create the
indices by unfolding the bundle and metric levels of the `features.columns` MultiIndex.

In [45]:
# Positions of the metric and tract levels within the column MultiIndex
metric_level = features.columns.names.index('metric')
tract_level = features.columns.names.index('tractID')
n_tracts = len(features.columns.levels[tract_level])

# `MultiIndex.labels` was renamed to `MultiIndex.codes` in pandas 0.24 and
# removed in pandas 1.0; use whichever attribute this pandas version provides.
column_codes = getattr(features.columns, 'codes', None)
if column_codes is None:
    column_codes = features.columns.labels

# One group index per feature column: metric_code * n_tracts + tract_code.
# Cast to int64 before the arithmetic so that narrow integer codes cannot
# overflow during the multiplication.
bundle_group_membership = (
    np.asarray(column_codes[metric_level], dtype=np.int64) * n_tracts
    + np.asarray(column_codes[tract_level], dtype=np.int64)
)

`prox_group_membership` will hold the proximity group associations for each feature column.

Suppose there are $n_{\textrm{nodes}}$ per tract and we want to group nodes by proximity into groups of size $k$. Then there are $J = n_{\textrm{nodes}} - k + 1$ of these proximity groups per tract. For example, with 10 nodes labelled $i = 0, \ldots, 9$, and a proximity group size of $k = 3$, we will have 8 groups, labelled $j = 0, \ldots, 7$. In this example the node IDs would map to the following proximity groups:

\begin{array}{|c|l|}
    \hline
    \textrm{nodeID} & \textrm{proximity group membership} \\
    \hline
    0 & (0) \\
    1 & (0, 1) \\
    2 & (0, 1, 2) \\
    3 & (1, 2, 3) \\
    4 & (2, 3, 4) \\
    5 & (3, 4, 5) \\
    6 & (4, 5, 6) \\
    7 & (5, 6, 7) \\
    8 & (6, 7) \\
    9 & (7) \\
    \hline
\end{array}

By inspection, we see that the number of proximity groups associated with node $i$ is given by

\begin{equation}
    \xi_i = \min(i + 1, k, n_{\textrm{nodes}} - i),
\end{equation}

and that the map from `nodeID` to the tuple of associated proximity groups is given by

\begin{equation}
    \gamma_i = \bigl[0, \ldots, \min(i + 1, k, n_{\textrm{nodes}} - i)\bigr) + \max(0, i + 1 - k),
\end{equation}

where $[ \cdots )$ represents an integer range that is inclusive on the left and exclusive on the right (mimicking the behavior of `np.arange` or Python's built-in `range`) and adding a scalar to this range means adding that scalar element-wise to all of the elements.

Lastly, when we finish assigning groups for the first bundle, we want to continue to increment proximity group indices even as we reset to the next bundle. So we add $J$ to the starting index each time we step over to a new bundle.

In the code snippet below, $n_{\textrm{nodes}} \rightarrow$ `n_nodes`, $k \rightarrow$ `prox_group_len`,
$J \rightarrow$ `n_prox_groups`.

In [46]:
node_level = features.columns.names.index('nodeID')
n_nodes = len(features.columns.levels[node_level])
prox_group_len = 3
# prox_group_len should be odd for symmetry considerations later on in the code.
assert prox_group_len % 2 == 1, 'prox_group_len must be odd'
# Number of proximity groups per bundle-metric group
n_prox_groups = n_nodes - prox_group_len + 1

# `MultiIndex.labels` was renamed to `MultiIndex.codes` in pandas 0.24 and
# removed in pandas 1.0; use whichever attribute this pandas version provides.
column_codes = getattr(features.columns, 'codes', None)
if column_codes is None:
    column_codes = features.columns.labels
node_idx = np.array(column_codes[node_level], dtype=np.int64)

# For each feature column i with node code nid, list the proximity groups the
# node belongs to: a range of min(nid + 1, prox_group_len, n_nodes - nid)
# consecutive indices starting at max(0, nid + 1 - prox_group_len), offset by
# n_prox_groups for every preceding bundle-metric group so indices never repeat.
prox_group_membership = [
    np.arange(0, min(nid + 1, prox_group_len, n_nodes - nid)) + max(0, nid + 1 - prox_group_len)
    + bundle_group_membership[i] * n_prox_groups
    for i, nid in enumerate(node_idx)
]

## Read in the targets that we'd like to predict

In [48]:
# Load data/subjects.csv as the target matrix, indexed by subjectID, and
# discard the leftover 'Unnamed: 0' column before previewing the first rows
targets = (
    pd.read_csv('data/subjects.csv', index_col='subjectID')
    .drop('Unnamed: 0', axis=1)
)
targets.head()

Unnamed: 0_level_0,Age,Gender,Handedness,IQ,IQ_Matrix,IQ_Vocab
subjectID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
subject_000,20,Male,,139.0,65.0,77.0
subject_001,31,Male,,129.0,58.0,74.0
subject_002,18,Female,,130.0,63.0,70.0
subject_003,28,Male,Right,,,
subject_004,29,Male,,,,


We now have a feature matrix and some targets. We may want to fill in the null values in the target matrix as well. That part is much more context-dependent. After that, we are ready to start learning about our data.