In [31]:
import pandas as pd
import numpy as np
import os.path as osp

import nibabel as nib
from nilearn import datasets

In [52]:
def get_atlas_info(atlas_name, atlas_dir='/data/rmthomas/nilearn_data'):
    """Return the region labels and region indices of a supported atlas.

    Supported values of ``atlas_name`` (add more when necessary):

    * ``'AAL'`` -- fetched with ``datasets.fetch_atlas_aal(version='SPM12')``.
    * ``'HO_cort_maxprob_thr25-2mm'`` -- fetched with
      ``datasets.fetch_atlas_harvard_oxford('cort-maxprob-thr25-2mm')``.
    * ``'schaefer_100'`` / ``'schaefer_400'`` -- read from the parcel table
      CSV stored under ``<atlas_dir>/schaefer/``.

    Parameters
    ----------
    atlas_name : str
        One of the names listed above.
    atlas_dir : str
        Directory that contains the ``schaefer`` sub-folder. Only used for
        the Schaefer atlases.

    Returns
    -------
    labels, indices
        Two equally long sequences. For the fetched atlases they come from
        the fetched dataset (Harvard-Oxford: labels without the background
        entry and the integers 1..48). For the Schaefer atlases they are the
        ``'label'`` and ``'value'`` columns of the table CSV (pandas Series).

    Raises
    ------
    ValueError
        If ``atlas_name`` is not supported, or if the labels and indices do
        not have the same length.
    """
    # Parcel tables of the locally stored Schaefer atlases, relative to atlas_dir.
    schaefer_tables = {
        'schaefer_100': 'schaefer/Schaefer2018_100Parcels_17Networks_table.csv',
        'schaefer_400': 'schaefer/Schaefer2018_400Parcels_17Networks_table.csv',
    }

    # Check if valid atlas name
    if atlas_name not in ['AAL', 'HO_cort_maxprob_thr25-2mm', 'schaefer_100', 'schaefer_400']:
        raise ValueError('atlas_name not found')

    if atlas_name == 'AAL':
        dataset = datasets.fetch_atlas_aal(version='SPM12')
        labels = dataset.labels
        indices = dataset.indices
    elif atlas_name == 'HO_cort_maxprob_thr25-2mm':
        dataset = datasets.fetch_atlas_harvard_oxford('cort-maxprob-thr25-2mm')
        labels = dataset.labels[1:]  # the first element is background
        # Kept as a fixed 1..48 list so the length check below still catches
        # a label list of unexpected size.
        indices = list(range(1, 49))
    else:
        # Only the Schaefer names are left; parse the table once and take
        # both columns from the same frame.
        table = pd.read_csv(osp.join(atlas_dir, schaefer_tables[atlas_name]))
        labels = table['label']
        indices = table['value']

    if len(labels) != len(indices):
        raise ValueError('Labels and indices should be of same size')

    return labels, indices



In [54]:
def get_subj_file(sub_id,
                  atlas_name,
                  root_dir='/data_local/deeplearning/ABIDE_LC/raw/',
                 ):
    """Build the paths to a subject's atlas image and grey-matter image.

    Parameters
    ----------
    sub_id : str
        Subject identifier; used as the sub-directory name under
        ``root_dir`` and as the prefix of both file names.
    atlas_name : str
        One of ``'AAL'``, ``'HO_cort_maxprob_thr25-2mm'``,
        ``'schaefer_100'`` or ``'schaefer_400'``.
    root_dir : str
        Directory holding one sub-directory per subject.

    Returns
    -------
    sub_atlas, sub_GM : str
        ``<root_dir>/<sub_id>/atlases/<atlas file>`` and
        ``<root_dir>/<sub_id>/T1/fast/<sub_id>_T1_restore_brain_pve_1.nii.gz``.
        The paths are only assembled; nothing is checked on disk.

    Raises
    ------
    ValueError
        If ``atlas_name`` is not one of the supported names.
    """
    # File-name suffix (appended to sub_id) for each supported atlas.
    # NOTE(review): '_aal_nii.gz' breaks the '.nii.gz' pattern of the other
    # entries -- confirm against the files on disk before changing it.
    atlas_suffixes = (
        ('AAL', '_aal_nii.gz'),
        ('HO_cort_maxprob_thr25-2mm', '_ho_cortical.nii.gz'),
        ('schaefer_100', '_schaefer100.nii.gz'),
        ('schaefer_400', '_schaefer400.nii.gz'),
    )

    for known_name, suffix in atlas_suffixes:
        if atlas_name == known_name:
            filename = f'{sub_id}{suffix}'
            break
    else:
        raise ValueError('atlas_name not found')

    subject_dir = osp.join(root_dir, sub_id)
    sub_atlas = osp.join(subject_dir, 'atlases', filename)
    sub_GM = osp.join(subject_dir, 'T1/fast', f'{sub_id}_T1_restore_brain_pve_1.nii.gz')

    return sub_atlas, sub_GM

In [20]:
# Scratch cell: fetch the 'cort-maxprob-thr25-2mm' Harvard-Oxford atlas for inspection.
xx=datasets.fetch_atlas_harvard_oxford('cort-maxprob-thr25-2mm')

In [32]:
# Scratch cell: load the atlas referenced by xx.maps with nibabel.
x=nib.load(xx.maps)

In [34]:
# Scratch cell: pull the atlas voxel array out of the loaded image.
# Image.get_data() was deprecated in nibabel 3.0 and removed in 5.0;
# np.asanyarray(img.dataobj) is the documented drop-in replacement and,
# unlike get_fdata(), does not force the data to floating point.
v = np.asanyarray(x.dataobj)

In [43]:
# Scratch cell: display the integers 1..48 (the same expression get_atlas_info
# uses as the Harvard-Oxford region indices).
list(range(1,49))

[1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48]