In [2]:
import pandas as pd
import numpy as np
import os.path as osp
import os

import nibabel as nib
from nilearn import datasets

  from ._conv import register_converters as _register_converters
  return f(*args, **kwds)


In [3]:
def get_atlas_info(atlas_name, atlas_dir='/data/rmthomas/nilearn_data'): # HO_cort_maxprob_thr25-2mm

    # Choose one of the atlases (add more when necessary)
    # 1. AAL
    # 2. HO_cort_maxprob_thr25-2mm


    # Check if valid atlas name
    if atlas_name not in ['AAL', 'HO_cort_maxprob_thr25-2mm', 'schaefer_100', 'schaefer_400']:
        raise ValueError('atlas_name not found')

    if atlas_name == 'AAL':
        dataset = datasets.fetch_atlas_aal(version='SPM12')
        atlas_filename = dataset.maps
        labels = dataset.labels
        indices = dataset.indices
        
    if atlas_name == 'HO_cort_maxprob_thr25-2mm':
        dataset = datasets.fetch_atlas_harvard_oxford('cort-maxprob-thr25-2mm')
        atlas_filename = dataset.maps
        labels = dataset.labels[1:] # the first element is background
        indices = list(range(1,49)) # 1 to 48 regions
        
    if atlas_name == 'schaefer_100':
        atlas_filename = osp.join(atlas_dir,'schaefer/Schaefer2018_100Parcels_17Networks_order_FSLMNI152_2mm.nii')
        labels = pd.read_csv(osp.join(atlas_dir, 
                                           'schaefer/Schaefer2018_100Parcels_17Networks_table.csv'))['label']
        indices = pd.read_csv(osp.join(atlas_dir, 
                                           'schaefer/Schaefer2018_100Parcels_17Networks_table.csv'))['value']
        
    if atlas_name == 'schaefer_400':
        atlas_filename = osp.join(atlas_dir, 
                                       'schaefer/Schaefer2018_400Parcels_17Networks_order_FSLMNI152_2mm.nii')
        labels = pd.read_csv(osp.join(atlas_dir, 
                                           'schaefer/Schaefer2018_400Parcels_17Networks_table.csv'))['label']
        indices = pd.read_csv(osp.join(atlas_dir, 
                                           'schaefer/Schaefer2018_400Parcels_17Networks_table.csv'))['value']
        

    if len(labels) != len(indices):
        raise ValueError('Labels and indices should be of same size')
        
    return  labels, indices



In [4]:
def get_subj_file(sub_id,
                  atlas_name,
                  root_dir='/data_local/deeplearning/ABIDE_LC/raw/',
                 ):
    
    if atlas_name not in ['AAL', 'HO_cort_maxprob_thr25-2mm', 'schaefer_100', 'schaefer_400']:
        raise ValueError('atlas_name not found')

    if atlas_name == 'AAL':
        filename = f'{sub_id}_aal.nii.gz'
        
    if atlas_name == 'HO_cort_maxprob_thr25-2mm':
        filename = f'{sub_id}_ho_cortical.nii.gz'
        
    if atlas_name == 'schaefer_100':
        filename = f'{sub_id}_schaefer100.nii.gz'
        
    if atlas_name == 'schaefer_400':
        filename = f'{sub_id}_schaefer400.nii.gz'
        
    
    sub_atlas = nib.load(osp.join(root_dir, sub_id, 'atlases', filename)).get_data()
    sub_GM = nib.load(osp.join(root_dir, sub_id, 'T1/fast', f'{sub_id}_T1_restore_brain_pve_1.nii.gz' )).get_data()
    
    return sub_atlas, sub_GM

In [7]:
def create_roi_metrics(sub_id, atlas_name, labels, values):
                       
    # Intialize a data frame
    df = pd.DataFrame(np.zeros((len(labels),2)), index=labels, columns=['mean', 'std', 'var'])

    # Get subject specific atlas file and GM
    subj_atlas, subj_GM = get_subj_file(sub_id, atlas_name)

    # Iterate over atlas values
    for idx, val in zip(labels, values):

        # Create a mask for each value
        masked_voxels = ((subj_atlas==val) & (subj_GM>0.0)) # Metrics over all non-zero voxels
        subj_GM = subj_GM[masked_voxels]
        df['mean'].loc[idx] = subj_GM.mean() 
        df['std'].loc[idx] = subj_GM.std()
        df['var'].loc[idx] = subj_GM.var()

    return df 

In [9]:
def make_GM_files(atlas_name,
                  root_dir='/data_local/deeplearning/ABIDE_LC',
                  output_dir='/data_local/deeplearning/ABIDE_ML_inputs/',
                  subject_list_file='list_2169'):
    
    sub_ids = np.loadtxt(osp.join(root_dir, subject_list_file), dtype='str')
    nsubjects = len(sub_ids)
    
    # get the atlas roi names and values
    labels, values = get_atlas_info(atlas_name=atlas_name)
    
    for i_sub, sub_id in enumerate(sub_ids):
        if i_sub%100 == 0:
            print(f'{i_sub+1}/{nsubjects}')
        
        df = create_roi_metrics(sub_id, atlas_name, labels, values)
        subj_dir = osp.join(output_dir, sub_id, 'gray_matter', atlas_name)
        
        if not osp.exists(subj_dir):
            os.makedirs(subj_dir)
    
        outfile = osp.join(subj_dir, 'gm_metrics.csv')
        df.to_csv(outfile)

In [10]:
#['AAL', 'schaefer_100', 'HO_cort_maxprob_thr25-2mm', 'schaefer_400']
make_GM_files('AAL')

1/2169


ValueError: Shape of passed values is (2, 116), indices imply (3, 116)