In [None]:
import glob
import os
import re
import shutil
import random
import string

import numpy as np
import pandas as pd

import nilearn
from nilearn import plotting, image
from nilearn.input_data import NiftiMasker
import nibabel as nib
from nipype.interfaces import ants
import nighres

import subprocess
import json
import multiprocessing as mp
from functools import partial
import joblib
from joblib import Parallel, delayed
import itertools

import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
def find_rois(sub, atlas_name='MASSP', space='T1w'):
    """Locate the ROI mask files of one subject for a given atlas.

    Parameters
    ----------
    sub : str
        Subject label exactly as it appears in the file names (e.g. '002').
    atlas_name : str
        One of 'THAL', 'ATAG', 'MASSP', 'CORT', 'Pauli', 'FPN', 'WM-rep',
        'HCP_MMP1' or 'str'.
    space : str
        Only consulted for 'THAL' and 'ATAG'. 'MNI152NLin2009cAsym' / 'mni'
        selects the group-level ATAG masks; no MNI-space masks are defined for
        'THAL', so an empty dict is returned in that case.

    Returns
    -------
    dict
        Maps ROI label -> mask file name, in sorted file-name order. Empty when
        no mask file is found.

    Raises
    ------
    ValueError
        If ``atlas_name`` is not one of the supported atlases.
    """
    # atlas name -> sub-folder of ../derivatives holding the subject-space masks
    subject_mask_dirs = {
        'THAL': 'masks_thal_func',          # thalamus subregions
        'ATAG': 'masks_atag_func',          # ATAG atlas
        'MASSP': 'masks_massp_func',        # MASSP atlas
        'CORT': 'masks_cortex_func',        # Harvard-Oxford atlas
        'Pauli': 'masks_Pauli_func',        # Pauli atlas
        'FPN': 'masks_FPN_func',            # FPN masks constructed from Brodmann areas (Pijnenburg 2022)
        'WM-rep': 'masks_WM-rep_func',
        'HCP_MMP1': 'masks_HCP_MMP1_func',
        'str': 'masks_str_func',
    }
    if atlas_name not in subject_mask_dirs:
        raise ValueError(f'Unknown atlas_name {atlas_name!r}; expected one of {sorted(subject_mask_dirs)}')

    in_mni = space in ('MNI152NLin2009cAsym', 'mni')
    if in_mni and atlas_name == 'THAL':
        # No MNI-space thalamus masks exist; previously this path crashed on unbound variables.
        return {}

    if in_mni and atlas_name == 'ATAG':
        # ROIs in MNI09c-space
        mask_dir = '/home/Public/trondheim/sourcedata/masks/MNI152NLin2009cAsym_res-1p5'
        fns = sorted(glob.glob(mask_dir + '/space-*'))
        label_regex = re.compile(r'.*space-(?P<space>[a-zA-Z0-9]+)_res-1p5_label-(?P<label>[a-zA-Z0-9]+)_probseg_def-img\.nii\.gz')
    else:
        mask_glob = f'../derivatives/{subject_mask_dirs[atlas_name]}/sub-{sub}/anat/sub-{sub}_*.nii.gz'
        fns = sorted(glob.glob(mask_glob))
        # ATAG labels are purely alphanumeric; every other atlas allows any non-space characters
        label_chars = r'[a-zA-Z0-9]+' if atlas_name == 'ATAG' else r'\S+'
        label_regex = re.compile(r'.*space-(?P<space>[a-zA-Z0-9]+)_desc-mask-(?P<label>' + label_chars + r')\.nii\.gz')

    names = [label_regex.match(fn).group('label') for fn in fns]
    return dict(zip(names, fns))

def load_atlas(sub, atlas_name='MASSP', space='T1w'):
    """Load all ROI masks of a subject/atlas into a single atlas object.

    Parameters
    ----------
    sub, atlas_name, space
        Forwarded unchanged to ``find_rois``.

    Returns
    -------
    AttrDict or int
        A dict that is also attribute-accessible, with ``maps`` (the mask
        images concatenated by ``nilearn.image.concat_imgs``) and ``labels``
        (list of ROI names, same order as the maps). Returns ``0`` after
        emitting a warning when no ROI file was found.
    """
    # Local imports: `warnings` is not part of the notebook's import cell.
    import warnings
    from nilearn import image

    roi_dict = find_rois(sub, atlas_name, space)
    if len(roi_dict) == 0:
        warnings.warn(f'No ROIs found for sub-{sub} atlas-{atlas_name} space-{space}. Returning 0 to prevent error')
        return 0
    combined = image.concat_imgs(list(roi_dict.values()))

    class AttrDict(dict):
        """dict whose keys can also be read/written as attributes."""
        def __init__(self, *args, **kwargs):
            super(AttrDict, self).__init__(*args, **kwargs)
            self.__dict__ = self

    # A real list (not a dict view) so that callers can index and slice the labels.
    roi_atlas = AttrDict({'maps': combined,
                          'labels': list(roi_dict.keys())})

    return roi_atlas

# 1. Extract signals from each ROI
## Manual coded extraction - Slow

In [None]:
def get_epi(sub, ses, task, run, use_hp=False, base_dir='../derivatives/fmriprep/fmriprep'):
    """Build the file name of a preprocessed T1w-space BOLD run.

    With ``use_hp=True`` the file is looked up under
    ``../derivatives/high_passed_func`` and ``base_dir`` is ignored; otherwise
    it is looked up under ``base_dir``. The path is only constructed, its
    existence is not checked.
    """
    root = '../derivatives/high_passed_func' if use_hp else base_dir
    relative_fn = f'sub-{sub}/ses-{ses}/func/sub-{sub}_ses-{ses}_task-{task}_run-{run}_space-T1w_desc-preproc_bold.nii.gz'
    return os.path.join(root, relative_fn)

def _make_psc(data):
    """Express a 4D image as percent signal change around each voxel's temporal mean.

    ``data`` is handed to nilearn's ``mean_img`` / ``math_img`` as given.
    Voxels whose mean is exactly 0 are divided by 1 instead, to avoid a
    division by zero.
    """
    temporal_mean = image.mean_img(data)

    # Swap zero means for ones so the division below is always defined
    safe_mean = temporal_mean.get_fdata()
    safe_mean[safe_mean == 0] = 1
    denom = image.new_img_like(temporal_mean, safe_mean)

    formula = 'data / denom[..., np.newaxis] * 100 - 100'
    return image.math_img(formula, data=data, denom=denom)

def do_extract(to_run, atlas_name='MASSP', overwrite=False, to_psc=False, use_hp=False):
    """Extract one (mask-weighted) mean time course per ROI for a single run.

    Parameters
    ----------
    to_run : tuple
        ``(sub, ses, task, run)``; ``sub`` is zero-padded to three characters.
    atlas_name : str
        Atlas passed to ``load_atlas``; also part of the output file name.
    overwrite : bool
        Recompute even if the output .tsv already exists.
    to_psc : bool
        Convert the EPI to percent signal change first (adds ``_psc`` to the
        output file name).
    use_hp : bool
        Read the high-passed EPI instead of the fmriprep output (adds ``_hp``
        to the output file name).

    Returns
    -------
    pandas.DataFrame or None
        One column per ROI, one row per volume. ``None`` when the EPI file is
        missing or no ROI masks are available.
    """
    sub, ses, task, run = to_run
    sub = str(sub).zfill(3)
    print(f'Extracting from sub-{sub}/ses-{ses}/sub-{sub}_ses-{ses}_task-{task}_run-{run}', end='')

    epi_fn = get_epi(sub, ses, task, run, use_hp)
    if not os.path.exists(epi_fn):
        print('...doesnt exist, skipping')
        return None

    psc_fn = '_psc' if to_psc else ''
    hp_fn = '_hp' if use_hp else ''   # might wanna have the hp data handy
    output_fn = f'../derivatives/extracted_signals/sub-{sub}/ses-{ses}/func/sub-{sub}_ses-{ses}_task-{task}_run-{run}_desc-{atlas_name}-signals{psc_fn}{hp_fn}.tsv'

    # Check the cache before touching the (large) EPI or computing PSC.
    if os.path.exists(output_fn) and not overwrite:
        print(f'{output_fn} already run, loading previous result...')
        # NOTE(review): the cached frame has 'volume' as a regular column, whereas a
        # fresh result has it as the index — kept as-is for backward compatibility.
        return pd.read_csv(output_fn, sep='\t')

    # load atlas; load_atlas returns 0 (falsy) when no ROI file was found
    atlas = load_atlas(sub, atlas_name=atlas_name)
    if not atlas:
        print(f'...no ROIs found for atlas-{atlas_name}, skipping')
        return None

    ## dont really need to convert to psc here
    epi = _make_psc(epi_fn) if to_psc else nib.load(epi_fn)

    # load & reshape to (n_voxels, n_volumes); np.prod replaces np.product (removed in NumPy 2)
    epi_flat = epi.get_fdata().reshape((np.prod(epi.shape[:3]), epi.shape[-1]))

    dfs = []
    for i, label in enumerate(atlas.labels):
        print('.', end='')
        mask = image.index_img(atlas.maps, i)
        mask_flat = mask.get_fdata().ravel()
        print(f'There are {np.count_nonzero(mask_flat)} voxels in region {label}')
        if mask_flat.sum() == 0: # if there are no voxels in the mask then add one voxel so code doesn't crash
            mask_flat[-1] = 1
        # mask values act as weights, so non-binary masks give a weighted mean
        signal = pd.DataFrame(np.average(epi_flat, weights=mask_flat, axis=0), columns=[label])
        signal.index.name = 'volume'
        dfs.append(signal)

    df = pd.concat(dfs, axis=1)
    os.makedirs(os.path.dirname(output_fn), exist_ok=True)
    df.to_csv(output_fn, sep='\t')
    print(output_fn)
    return df

In [None]:
# find all available functional runs, extract sub/ses/task/run info
all_runs = sorted(glob.glob('../derivatives/fmriprep/fmriprep/sub-*/ses-*/func/*space-T1w*_bold.nii.gz'))
# raw string: \d and \S are regex classes, not (invalid) Python string escapes
regex = re.compile(r'.*sub-(?P<sub>\d+)_ses-(?P<ses>\S+)_task-(?P<task>\S+)_run-(?P<run>\d)_space-T1w*')
# one (sub, ses, task, run) tuple of strings per file
all_combs = [regex.match(x).group('sub', 'ses', 'task', 'run') for x in all_runs]
all_combs[-10:]

In [None]:
# Runs of the sstmsit session (the glob matches every task in it), restricted to sub-041.
all_runs = sorted(glob.glob('../derivatives/fmriprep/fmriprep/sub-*/ses-sstmsit/func/*task-*space-T1w*_bold.nii.gz'))
# raw string: \d and \S are regex classes, not (invalid) Python string escapes
regex = re.compile(r'.*sub-(?P<sub>\d+)_ses-(?P<ses>\S+)_task-(?P<task>\S+)_run-(?P<run>\d)_space-T1w*')
all_combs = [regex.match(x).group('sub', 'ses', 'task', 'run') for x in all_runs]
# all_combs = [x for x in all_combs if not '015' in x] # bad hp data for sub 15????
# NOTE(review): the original trailing comment here mentioned "bad hp data for sub 26",
# but the filter keeps sub-041 only — confirm which subject was meant.
all_combs = [x for x in all_combs if x[0] == '041']
# all_combs = [x for x in all_combs if x[0] in ['002','003','004','005','006','007','008','009','010','011']]
all_combs

In [None]:
# MSIT runs of the sstmsit session, for a hand-picked set of subjects.
# NOTE(review): the original heading said "just extract SST", but the glob selects task-msit — confirm.
all_runs = sorted(glob.glob('../derivatives/fmriprep/fmriprep/sub-*/ses-sstmsit/func/*task-msit*space-T1w*_bold.nii.gz'))
# raw string: \d and \S are regex classes, not (invalid) Python string escapes
regex = re.compile(r'.*sub-(?P<sub>\d+)_ses-(?P<ses>\S+)_task-(?P<task>\S+)_run-(?P<run>\d)_space-T1w*')
all_combs = [regex.match(x).group('sub', 'ses', 'task', 'run') for x in all_runs]
# all_combs = [x for x in all_combs if not '015' in x] # bad hp data for sub 15????
# all_combs = [x for x in all_combs if '041' in x] # bad hp data for sub 26????
all_combs = [x for x in all_combs if x[0] in ['004','008','010','013','019','027']]
all_combs

In [None]:
# Runs of the rbrevl session whose file name matches '*task-rb*', all subjects.
# NOTE(review): the heading says RBREVL, but this glob would not match 'task-revl' files — confirm intent.
all_runs = sorted(glob.glob('../derivatives/fmriprep/fmriprep/sub-*/ses-rbrevl/func/*task-rb*space-T1w*_bold.nii.gz'))
# raw string: \d and \S are regex classes, not (invalid) Python string escapes
regex = re.compile(r'.*sub-(?P<sub>\d+)_ses-(?P<ses>\S+)_task-(?P<task>\S+)_run-(?P<run>\d)_space-T1w*')
all_combs = [regex.match(x).group('sub', 'ses', 'task', 'run') for x in all_runs]
# Optional subject filters (currently disabled, every subject is kept):
# all_combs = [x for x in all_combs if not '015' in x] # bad hp data for sub 15????
# all_combs = [x for x in all_combs if '041' in x] # bad hp data for sub 26????
# all_combs = [x for x in all_combs if x[0] in ['002','003','004','005','006','007','008','009','010','011']]
all_combs

In [None]:
# All runs of the rlsat session.
# NOTE(review): the original heading said "just extract MSIT", but the glob selects ses-rlsat — confirm.
all_runs = sorted(glob.glob('../derivatives/fmriprep/fmriprep/sub-*/ses-rlsat/func/*task-*space-T1w*_bold.nii.gz'))
# raw string: \d and \S are regex classes, not (invalid) Python string escapes
regex = re.compile(r'.*sub-(?P<sub>\d+)_ses-(?P<ses>\S+)_task-(?P<task>\S+)_run-(?P<run>\d)_space-T1w*')
all_combs = [regex.match(x).group('sub', 'ses', 'task', 'run') for x in all_runs]
all_combs

In [None]:
# all_subs = np.arange(2,28)
# all_ses = ['rlsat', 'rbrevl', 'anatomical', 'sstmsit']
# all_tasks = ['rs', 'rlsat', 'rb', 'revl', 'sst', 'msit']
# all_runs = [1,2,3]

# all_combs = list(itertools.product(all_subs,all_ses,all_tasks,all_runs))
# all_combs = [x for x in all_combs if (x[1]=='rlsat' and x[2]=='rlsat') or (x[1]=='rbrevl' and x[2] in ['rb', 'revl'] and x[3]<3) or (x[1]=='sstmsit' and x[2] in ['sst', 'msit'] and x[3]<3) or (x[1]=='anatomical' and x[2]=='rs' and x[3]<3)]
# #do_extract(all_combs[0], overwrite=True)

In [None]:
def check_affines(sub):
    """Check whether all T1w-space functional runs of a subject share one affine.

    Parameters
    ----------
    sub : int or str
        Subject number; zero-padded to three characters.

    Returns
    -------
    bool
        True when every ``*_space-T1w_desc-preproc_bold.nii.gz`` file of the
        subject has exactly the same affine as the first one. Also True when
        the subject has no such files (nothing to disagree), where the
        previous version raised an IndexError.
    """
    sub = str(sub).zfill(3)
    all_funcs = sorted(glob.glob(f'../derivatives/fmriprep/fmriprep/sub-{sub}/ses*/func/sub*_space-T1w_desc-preproc_bold.nii.gz'))
    if not all_funcs:
        return True
    all_affines = np.array([nib.load(x).affine for x in all_funcs])
    # exact element-wise comparison of every affine against the first one
    return bool((all_affines[0] == all_affines).all())

In [None]:
#all_atlases=['MASSP','CORT','ATAG']#,'THAL'] #['Pauli']
# all_atlases=['HCP_MMP1']
# all_atlases = ['CORT']
all_atlases = ['str','MASSP','CORT','ATAG']

hp_options= [True,False]
overwrite=False
psc=False

# check_affines() opens every functional image of a subject, so evaluate it once
# per subject instead of once per (atlas, filter, run) combination.
affines_ok = {}

for atlas_name in all_atlases:
    for hp in hp_options:
        for comb in all_combs:
            print(f'atlas-{atlas_name} hp-{hp}')
            print(comb)
            sub = comb[0]
            if sub not in affines_ok:
                affines_ok[sub] = check_affines(sub)
            if affines_ok[sub]:
                do_extract(comb, atlas_name=atlas_name, overwrite=overwrite, to_psc=psc, use_hp=hp)
            else:
                print(f'Affines for sub {sub} not identical')

In [None]:
#all_atlases=['MASSP','CORT','ATAG']#,'THAL'] #['Pauli']
all_atlases=['THAL']

hp_options= [False]#[True,False]
overwrite=False
psc=False

# check_affines() opens every functional image of a subject, so evaluate it once
# per subject instead of once per run.
affines_ok = {}

for atlas_name in all_atlases:
    for hp in hp_options:
        for comb in all_combs:
            print(f'atlas-{atlas_name} hp-{hp}')
            print(comb)
            sub = comb[0]
            if sub not in affines_ok:
                affines_ok[sub] = check_affines(sub)
            if affines_ok[sub]:
                do_extract(comb, atlas_name=atlas_name, overwrite=overwrite, to_psc=psc, use_hp=hp)
            else:
                print(f'Affines for sub {sub} not identical')

In [None]:
#all_atlases=['MASSP','CORT','ATAG']#,'THAL'] #['Pauli']
all_atlases=['WM-rep']

hp_options= [False]#[True,False]
overwrite=False
psc=False

# check_affines() opens every functional image of a subject, so evaluate it once
# per subject instead of once per run.
affines_ok = {}

for atlas_name in all_atlases:
    for hp in hp_options:
        for comb in all_combs:
            print(f'atlas-{atlas_name} hp-{hp}')
            print(comb)
            sub = comb[0]
            print(sub)
            if sub not in affines_ok:
                affines_ok[sub] = check_affines(sub)
            if affines_ok[sub]:
                do_extract(comb, atlas_name=atlas_name, overwrite=overwrite, to_psc=psc, use_hp=hp)
            else:
                print(f'Affines for sub {sub} not identical')

In [None]:
# List every preprocessed T1w-space BOLD file of sub-026, across sessions.
sub026_pattern = '../derivatives/fmriprep/fmriprep/sub-026/ses*/func/sub*_space-T1w_desc-preproc_bold.nii.gz'
all_funcs = glob.glob(sub026_pattern)
all_funcs.sort()
all_funcs

In [None]:
# Sanity check: do all T1w-space functional runs of sub-026 share the same affine?
check_affines('026')

# 2. Extract the signal of every voxel in the thalamus ROIs

In [None]:
def find_rois(sub, atlas_name='MASSP', space='T1w'):
    """Locate thalamus-related ROI mask files of one subject.

    Parameters
    ----------
    sub : str
        Subject label exactly as it appears in the file names (e.g. '002').
    atlas_name : str
        'THAL', 'ATAG' or 'MASSP'. For 'MASSP' only masks whose file name
        contains 'Tha' are returned.
    space : str
        Only consulted for 'THAL' and 'ATAG'. 'MNI152NLin2009cAsym' / 'mni'
        selects the group-level ATAG masks; no MNI-space masks are defined for
        'THAL', so an empty dict is returned in that case.

    Returns
    -------
    dict
        Maps ROI label -> mask file name, in sorted file-name order. Empty when
        no mask file is found.

    Raises
    ------
    ValueError
        If ``atlas_name`` is not supported (previously an UnboundLocalError).
    """
    in_mni = space in ('MNI152NLin2009cAsym', 'mni')
    label_chars = r'\S+'

    if atlas_name == 'THAL':
        if in_mni:
            # No MNI-space thalamus masks exist; previously this path crashed on unbound variables.
            return {}
        mask_glob = f'../derivatives/masks_thal_func/sub-{sub}/anat/sub-{sub}_*.nii.gz'
    elif atlas_name == 'ATAG':
        if in_mni:
            # ROIs in MNI09c-space
            mask_dir = '/home/Public/trondheim/sourcedata/masks/MNI152NLin2009cAsym_res-1p5'
            fns = sorted(glob.glob(mask_dir + '/space-*'))
            mni_regex = re.compile(r'.*space-(?P<space>[a-zA-Z0-9]+)_res-1p5_label-(?P<label>[a-zA-Z0-9]+)_probseg_def-img\.nii\.gz')
            return dict(zip([mni_regex.match(fn).group('label') for fn in fns], fns))
        mask_glob = f'../derivatives/masks_atag_func/sub-{sub}/anat/sub-{sub}_*.nii.gz'
        label_chars = r'[a-zA-Z0-9]+'   # ATAG labels are purely alphanumeric
    elif atlas_name == 'MASSP':
        mask_glob = f'../derivatives/masks_massp_func/sub-{sub}/anat/sub-{sub}_*Tha*.nii.gz' # only thalamus
    else:
        raise ValueError(f"Unknown atlas_name {atlas_name!r}; expected 'THAL', 'ATAG' or 'MASSP'")

    fns = sorted(glob.glob(mask_glob))
    label_regex = re.compile(r'.*space-(?P<space>[a-zA-Z0-9]+)_desc-mask-(?P<label>' + label_chars + r')\.nii\.gz')
    names = [label_regex.match(fn).group('label') for fn in fns]
    return dict(zip(names, fns))

def load_atlas(sub, atlas_name='MASSP', space='T1w'):
    """Load all ROI masks of a subject/atlas into a single atlas object.

    Parameters
    ----------
    sub, atlas_name, space
        Forwarded unchanged to ``find_rois``.

    Returns
    -------
    AttrDict or int
        A dict that is also attribute-accessible, with ``maps`` (the mask
        images concatenated by ``nilearn.image.concat_imgs``) and ``labels``
        (list of ROI names, same order as the maps). Returns ``0`` after
        emitting a warning when no ROI file was found, matching the earlier
        definition of this function in the notebook.
    """
    # Local imports: `warnings` is not part of the notebook's import cell.
    import warnings
    from nilearn import image

    roi_dict = find_rois(sub, atlas_name, space)
    if len(roi_dict) == 0:
        # Without this guard concat_imgs is called on an empty collection.
        warnings.warn(f'No ROIs found for sub-{sub} atlas-{atlas_name} space-{space}. Returning 0 to prevent error')
        return 0
    combined = image.concat_imgs(list(roi_dict.values()))

    class AttrDict(dict):
        """dict whose keys can also be read/written as attributes."""
        def __init__(self, *args, **kwargs):
            super(AttrDict, self).__init__(*args, **kwargs)
            self.__dict__ = self

    # A real list (not a dict view) so that callers can index and slice the labels.
    roi_atlas = AttrDict({'maps': combined,
                          'labels': list(roi_dict.keys())})

    return roi_atlas

def get_epi(sub, ses, task, run, use_hp=False, base_dir='../derivatives/fmriprep/fmriprep'):
    """Return the path of a preprocessed T1w-space BOLD run (existence is not checked).

    ``use_hp=True`` points to ``../derivatives/high_passed_func`` regardless of
    ``base_dir``; otherwise the file is located under ``base_dir``.
    """
    if use_hp:
        base_dir = '../derivatives/high_passed_func'
    return os.path.join(
        base_dir,
        f'sub-{sub}/ses-{ses}/func/sub-{sub}_ses-{ses}_task-{task}_run-{run}_space-T1w_desc-preproc_bold.nii.gz',
    )

def _make_psc(data):
    """Rescale a 4D image to percent signal change relative to each voxel's mean over time.

    ``data`` is passed to nilearn's ``mean_img`` / ``math_img`` unchanged.
    A voxel mean of exactly 0 is replaced by 1 so that the division is defined.
    """
    voxel_mean_img = image.mean_img(data)

    # Guard against division by zero: zero means become 1
    voxel_mean = voxel_mean_img.get_fdata()
    zero_mean = voxel_mean == 0
    voxel_mean[zero_mean] = 1
    denom = image.new_img_like(voxel_mean_img, voxel_mean)

    psc_img = image.math_img('data / denom[..., np.newaxis] * 100 - 100',
                             data=data, denom=denom)
    return psc_img

def do_extract(to_run, atlas='MASSP', overwrite=False, to_psc=False, use_hp=False):
    """Extract the time course of every single voxel inside each ROI of an atlas.

    Parameters
    ----------
    to_run : tuple
        ``(sub, ses, task, run)``; ``sub`` is zero-padded to three characters.
    atlas : str
        Atlas name passed to ``load_atlas``; also part of the output file name.
    overwrite : bool
        Recompute even if the output .tsv already exists.
    to_psc : bool
        Convert the EPI to percent signal change first (adds ``_psc`` to the
        output file name).
    use_hp : bool
        Read the high-passed EPI (adds ``_hp`` to the output file name).

    Returns
    -------
    pandas.DataFrame or None
        One column ``<label>_<n>`` per voxel with a mask value > 0 (``n``
        counts voxels within the ROI), one row per volume. ``None`` when the
        EPI file is missing or no ROI masks are available.

    Notes
    -----
    Output is always written below ``../derivatives/extracted_signals_thal_voxels``.
    The non-high-passed branch used to write to ``..._that_voxels`` (typo);
    files produced before this fix still live in that folder.
    """
    sub, ses, task, run = to_run
    sub = str(sub).zfill(3)
    print(f'Extracting from sub-{sub}/ses-{ses}/sub-{sub}_ses-{ses}_task-{task}_run-{run}', end='')

    epi_fn = get_epi(sub, ses, task, run, use_hp)
    if not os.path.exists(epi_fn):
        print('...doesnt exist, skipping')
        return None

    psc_fn = '_psc' if to_psc else ''
    hp_fn = '_hp' if use_hp else ''   # might wanna have the hp data handy
    output_fn = f'../derivatives/extracted_signals_thal_voxels/sub-{sub}/ses-{ses}/func/sub-{sub}_ses-{ses}_task-{task}_run-{run}_desc-{atlas}-signals{psc_fn}{hp_fn}.tsv'

    # Check the cache before touching the (large) EPI or computing PSC.
    if os.path.exists(output_fn) and not overwrite:
        print(f'{output_fn} already run, loading previous result...')
        return pd.read_csv(output_fn, sep='\t')

    # load atlas; some load_atlas definitions in this notebook return 0 when nothing is found
    roi_atlas = load_atlas(sub, atlas_name=atlas)
    if not roi_atlas:
        print(f'...no ROIs found for atlas-{atlas}, skipping')
        return None

    ## dont really need to convert to psc here
    epi = _make_psc(epi_fn) if to_psc else nib.load(epi_fn)
    # (n_voxels, n_volumes); np.prod replaces np.product (removed in NumPy 2)
    epi_flat = epi.get_fdata().reshape((np.prod(epi.shape[:3]), epi.shape[-1]))

    dfs = []
    for i, label in enumerate(roi_atlas.labels):
        print('.', end='')
        mask_flat = image.index_img(roi_atlas.maps, i).get_fdata().ravel()
        indexes = np.where(mask_flat > 0)[0]  # flat indices of the voxels inside the mask

        print(f'There are {np.count_nonzero(mask_flat)} voxels in region {label} for sub {sub}')
        # Index the voxel rows directly. The previous one-hot weighted average over the
        # whole volume gave the same values but scanned every voxel once per mask voxel.
        signal = pd.DataFrame(epi_flat[indexes, :].T,
                              columns=[f'{label}_{label_n}' for label_n in range(len(indexes))])
        signal.index.name = 'volume'
        dfs.append(signal)

    df = pd.concat(dfs, axis=1)
    os.makedirs(os.path.dirname(output_fn), exist_ok=True)
    df.to_csv(output_fn, sep='\t')
    print(output_fn)
    return df

In [None]:
# Per-voxel extraction from the MASSP thalamus masks, on the high-passed data.
# check_affines() opens every functional image of a subject, so evaluate it once per subject.
affines_ok = {}
for comb in all_combs:
    print(comb)
    sub = comb[0]
    if sub not in affines_ok:
        affines_ok[sub] = check_affines(sub)
    if affines_ok[sub]:
        do_extract(comb, atlas='MASSP', overwrite=True, to_psc=False, use_hp=True)
    else:
        print(f'Affines for sub {sub} not identical')

In [None]:
# Per-voxel extraction from the THAL masks, on the non-high-passed data.
# check_affines() opens every functional image of a subject, so evaluate it once per subject.
affines_ok = {}
for comb in all_combs:
    print(comb)
    sub = comb[0]
    if sub not in affines_ok:
        affines_ok[sub] = check_affines(sub)
    if affines_ok[sub]:
        do_extract(comb, atlas='THAL', overwrite=True, to_psc=False, use_hp=False)
    else:
        print(f'Affines for sub {sub} not identical')

In [None]:
# Scratch cell: inline, step-by-step copy of the per-voxel do_extract() for a single run.
atlas = 'MASSP'
to_run = [('002', 'sstmsit', 'msit', '1')]
use_hp=True
sub='002'
ses='sstmsit'
task='msit'
run='1'
to_psc = False
overwrite=True
sub = str(sub).zfill(3)
print(f'Extracting from sub-{sub}/ses-{ses}/sub-{sub}_ses-{ses}_task-{task}_run-{run}', end='')

epi_fn = get_epi(sub,ses,task,run,use_hp)
if not os.path.exists(epi_fn):
    # A cell cannot `return`; stop here with a clear error instead of failing inside nib.load.
    raise FileNotFoundError(f'{epi_fn} doesnt exist')

## dont really need to convert to psc here
if to_psc:
    epi = _make_psc(epi_fn)
    psc_fn = '_psc'
else:
    epi = nib.load(epi_fn)
    psc_fn = ''

# might wanna have the hp data handy
# (both variants now live in ..._thal_voxels; the non-hp path used to say "that_voxels")
hp_fn = '_hp' if use_hp else ''
output_fn = f'../derivatives/extracted_signals_thal_voxels/sub-{sub}/ses-{ses}/func/sub-{sub}_ses-{ses}_task-{task}_run-{run}_desc-{atlas}-signals{psc_fn}{hp_fn}.tsv'

if os.path.exists(output_fn) and not overwrite:
    print(f'{output_fn} already run, loading previous result...')
#     return pd.read_csv(output_fn, sep='\t')

# (n_voxels, n_volumes); np.prod replaces np.product (removed in NumPy 2)
epi_flat = epi.get_fdata().reshape((np.prod(epi.shape[:3]), epi.shape[-1]))

# load atlas (kept in its own name so `atlas` stays the atlas *name*)
roi_atlas = load_atlas(sub,atlas_name=atlas)

dfs = []
for i, label in enumerate(roi_atlas.labels):
    print('.', end='')
    mask = image.index_img(roi_atlas.maps, i)
    mask_flat = mask.get_fdata().ravel()
    indexes = np.where(mask_flat>0)[0] # indexes of voxel within mask. len(indexes) and np.count_nonzero(mask_flat) should be the same

    print(f'There are {np.count_nonzero(mask_flat)} voxels in region {label} for sub {sub}')
    # One column per voxel in the mask, named <label>_<j>. The previous inner loop
    # referenced an undefined `j` and reused the outer loop variable `i`.
    signal = pd.DataFrame(epi_flat[indexes, :].T,
                          columns=[label+'_'+str(j) for j in range(len(indexes))])
    signal.index.name = 'volume'
    dfs.append(signal)

# df = pd.concat(dfs, axis=1)
# #     output_fn = f'../derivatives/extracted_signals/sub-{sub}/ses-{ses}/func/sub-{sub}_ses-{ses}_task-{task}_run-{run}_desc-MASSP-signals.tsv'
# if not os.path.exists(os.path.dirname(output_fn)):
#     os.makedirs(os.path.dirname(output_fn))
# df.to_csv(output_fn, sep='\t')
# print(output_fn)
# return df

# 2. Use nilearn, and high-pass filter & remove confounds along the way
### This is much faster, but something very weird happens when multiple atlas maps overlap (e.g. when there is a mask for "M1" as well as for "rM1" and "lM1").
The M1 case itself isn't very troubling, but it suggests that something odd happens with overlapping maps/masks in general. Do we have overlap? If so, the manually coded version is perhaps safer.

In [None]:
## extract signals this way?
## stolen from nideconv
import pandas as pd
from nilearn import input_data
import nibabel as nb
from nilearn._utils import check_niimg
from nilearn import image
import numpy as np

def extract_timecourse_from_nii(atlas,
                                nii,
                                mask=None,
                                confounds=None,
                                atlas_type=None,
                                t_r=None,
                                low_pass=None,
                                high_pass=1./128,
                                to_psc=False,
                                *args,
                                **kwargs):
    """Extract one time course per atlas region using a nilearn masker.

    Parameters
    ----------
    atlas : object with ``maps`` and ``labels`` attributes
        ``maps`` is the atlas image; ``labels`` the region names.
    nii : image or file name
        Functional data, handed to the masker (or to ``_make_psc`` first).
    mask : image or file name, optional
        Passed to the masker as ``mask_img``.
    confounds : optional
        Passed to ``masker.fit_transform``.
    atlas_type : {'labels', 'prob'}, optional
        Inferred from the dimensionality of ``atlas.maps`` when None: a 3D
        image is treated as 'labels', anything else as 'prob'.
    t_r, low_pass, high_pass
        Passed to the masker. ``t_r`` also scales the time index; when None
        the index is in units of volumes.
    to_psc : bool
        Convert ``nii`` to percent signal change before extraction.
    *args, **kwargs
        Forwarded to the masker; ``standardize`` and ``detrend`` default to False.

    Returns
    -------
    pandas.DataFrame
        Rows indexed by 'time', columns indexed by 'roi'.
    """
    standardize = kwargs.pop('standardize', False)
    detrend = kwargs.pop('detrend', False)

    if atlas_type is None:
        maps = check_niimg(atlas.maps)
        atlas_type = 'labels' if len(maps.shape) == 3 else 'prob'

    if atlas_type == 'labels':
        masker_class = input_data.NiftiLabelsMasker
    else:
        masker_class = input_data.NiftiMapsMasker

    masker = masker_class(atlas.maps,
                          mask_img=mask,
                          standardize=standardize,
                          detrend=detrend,
                          t_r=t_r,
                          low_pass=low_pass,
                          high_pass=high_pass,
                          *args, **kwargs)

    data = _make_psc(nii) if to_psc else nii

    results = masker.fit_transform(data,
                                   confounds=confounds)

    # Work on a private list: the caller's atlas is left untouched, and dict views
    # (which cannot be sliced) are accepted as labels.
    labels = list(atlas.labels)
    # For weird atlases that have a label for the background
    if len(labels) == results.shape[1] + 1:
        labels = labels[1:]

    if t_r is None:
        t_r = 1
    # One time point per extracted row. Building the index from an integer range
    # avoids np.arange's off-by-one risk with a float step, and avoids re-opening `nii`.
    index = pd.Index(np.arange(results.shape[0]) * t_r,
                     name='time')

    columns = pd.Index(labels,
                       name='roi')

    return pd.DataFrame(results,
                        index=index,
                        columns=columns)


In [None]:
def exclude_map_from_atlas(atlas, map_name):
    """Remove one named map from an atlas, in place, and return the atlas.

    Parameters
    ----------
    atlas : object with ``maps`` and ``labels`` attributes
        ``maps`` is indexed along its last axis, in the same order as ``labels``.
    map_name : str
        Label of the map to drop.

    Returns
    -------
    The same ``atlas`` object. If ``map_name`` is not among the labels the
    atlas is returned unchanged (this used to raise an IndexError).
    """
    labels = list(atlas.labels)
    if map_name not in labels:
        return atlas

    indx = labels.index(map_name)
    indices = [x for x in range(atlas.maps.shape[-1]) if x != indx]

    atlas.maps = nilearn.image.index_img(atlas.maps, indices)
    atlas.labels = [labels[x] for x in indices]

    return atlas

def _make_psc(data):
    """Express a 4D image as percent signal change around each voxel's temporal mean.

    ``data`` is handed to nilearn's ``mean_img`` / ``math_img`` as given.
    Voxels whose mean is exactly 0 are divided by 1 instead, to avoid a
    division by zero.
    """
    mean_img = image.mean_img(data)

    # Replace 0s for numerical reasons.
    # get_fdata() replaces get_data(), which was removed from nibabel.
    mean_data = mean_img.get_fdata()
    mean_data[mean_data == 0] = 1
    denom = image.new_img_like(mean_img, mean_data)

    return image.math_img('data / denom[..., np.newaxis] * 100 - 100',
                          data=data, denom=denom)


def extract_signals_nilearn(comb, include_physio=True, space='T1w', overwrite=False):
    """Extract confound-cleaned, high-passed PSC signals for the MASSP and ATAG atlases.

    Parameters
    ----------
    comb : tuple
        ``(sub, ses, task, run)`` as they appear in the file names.
    include_physio : bool
        Add the 20 RETROICOR regressors to the confounds; when the RETROICOR
        file is missing, the first 20 aCompCor components are used instead.
    space : str
        Space entity of the brain-mask file name.
    overwrite : bool
        Recompute atlases whose output .tsv already exists.

    Returns
    -------
    0 when done, None when the EPI file does not exist.
    """
    sub, ses, task, run = comb
    epi_fn = get_epi(sub, ses, task, run)
    if not os.path.exists(epi_fn):
        print('...doesnt exist, skipping')
        return None

    # load confounds (file is read once and reused for aCompCor below)
    confounds_fn = f'../derivatives/fmriprep/fmriprep/sub-{sub}/ses-{ses}/func/sub-{sub}_ses-{ses}_task-{task}_run-{run}_desc-confounds_timeseries.tsv'
    all_confounds = pd.read_csv(confounds_fn, sep='\t')
    # back-fill missing values (.bfill() replaces the deprecated fillna(method='bfill'))
    confounds = all_confounds[['trans_x', 'trans_y', 'trans_z', 'rot_x', 'rot_y', 'rot_z', 'dvars', 'framewise_displacement']].bfill()

    # get retroicor
    if include_physio:
        retroicor_fn = f'../derivatives/retroicor/sub-{sub}/ses-{ses}/func/sub-{sub}_ses-{ses}_task-{task}_run-{run}_desc-retroicor_regressors.tsv'
        if not os.path.exists(retroicor_fn):
            ## take first 20 aCompCor components
            print("No retroicor found, including 20 a_comp_cor components")
            a_comp_cor = all_confounds[['a_comp_cor_' + str(x).zfill(2) for x in range(20)]]
            confounds = pd.concat([confounds, a_comp_cor], axis=1)
        else:
            retroicor = pd.read_csv(retroicor_fn, sep='\t', header=None).iloc[:,:20]  ## 20 components in total
            retroicor.columns = ['cardiac_' + str(x) for x in range(6)] + ['respiratory_' + str(x) for x in range(8)] + ['respiratoryxcardiac_' + str(x) for x in range(4)] + ['HRV', 'RVT']
            confounds = pd.concat([confounds, retroicor], axis=1)

    # get brain mask
    # NOTE(review): the mask of run-1 is used for every run — presumably deliberate; confirm.
    brain_mask = f'../derivatives/fmriprep/fmriprep/sub-{sub}/ses-{ses}/func/sub-{sub}_ses-{ses}_task-{task}_run-1_space-{space}_desc-brain_mask.nii.gz'

    for atlas_type in ['MASSP', 'ATAG']:
        output_fn = f'../derivatives/extracted_signals_nilearn/sub-{sub}/ses-{ses}/func/sub-{sub}_ses-{ses}_task-{task}_run-{run}_desc-{atlas_type}-signals.tsv'
        if os.path.exists(output_fn) and not overwrite:
            # Skip only this atlas; returning here would also skip the remaining atlases.
            continue

        subject_atlas = load_atlas(sub, atlas_name=atlas_type)

        if atlas_type == 'ATAG':
            subject_atlas = exclude_map_from_atlas(subject_atlas, 'M1')

        df = extract_timecourse_from_nii(subject_atlas, epi_fn, mask=brain_mask, confounds=confounds, high_pass=1/128., t_r=1.38, to_psc=True)
        os.makedirs(os.path.dirname(output_fn), exist_ok=True)
        df.to_csv(output_fn, sep='\t')

    return 0

In [None]:
include_physio = True
space = 'T1w'

# Every T1w-space BOLD run, parsed into (sub, ses, task, run) and restricted to the msit task.
all_runs = sorted(glob.glob('../derivatives/fmriprep/fmriprep/sub-*/ses-*/func/*space-T1w*_bold.nii.gz'))
# raw string: \d and \S are regex classes, not (invalid) Python string escapes
regex = re.compile(r'.*sub-(?P<sub>\d+)_ses-(?P<ses>\S+)_task-(?P<task>\S+)_run-(?P<run>\d)_space-T1w*')
all_combs = [regex.match(x).group('sub', 'ses', 'task', 'run') for x in all_runs]


all_combs = [x for x in all_combs if x[2] == 'msit']
all_combs

In [None]:
# Run the nilearn-based extraction for every selected run as 10 parallel joblib jobs,
# overwriting existing output.
extraction_tasks = (joblib.delayed(extract_signals_nilearn)(comb, overwrite=True) for comb in all_combs)
out = joblib.Parallel(n_jobs=10, verbose=True)(extraction_tasks)

# 3. Use pre-cleaned data, don't extract confounds during the process

In [None]:
# Notebook-level settings. NOTE(review): the extract_signals_nilearn(...) call in the
# cells below does not pass these, so the function's own defaults apply — confirm intent.
include_physio = True
space = 'T1w'

def exclude_map_from_atlas(atlas, map_name):
    """Remove one named map from an atlas, in place, and return the atlas.

    Parameters
    ----------
    atlas : object with ``maps`` and ``labels`` attributes
        ``maps`` is indexed along its last axis, in the same order as ``labels``.
    map_name : str
        Label of the map to drop.

    Returns
    -------
    The same ``atlas`` object. If ``map_name`` is not among the labels the
    atlas is returned unchanged (this used to raise an IndexError).
    """
    labels = list(atlas.labels)
    if map_name not in labels:
        return atlas

    indx = labels.index(map_name)
    indices = [x for x in range(atlas.maps.shape[-1]) if x != indx]

    atlas.maps = nilearn.image.index_img(atlas.maps, indices)
    atlas.labels = [labels[x] for x in indices]

    return atlas

def _make_psc(data):
    """Express a 4D image as percent signal change around each voxel's temporal mean.

    ``data`` is handed to nilearn's ``mean_img`` / ``math_img`` as given.
    Voxels whose mean is exactly 0 are divided by 1 instead, to avoid a
    division by zero.
    """
    mean_img = image.mean_img(data)

    # Replace 0s for numerical reasons.
    # get_fdata() replaces get_data(), which was removed from nibabel.
    mean_data = mean_img.get_fdata()
    mean_data[mean_data == 0] = 1
    denom = image.new_img_like(mean_img, mean_data)

    return image.math_img('data / denom[..., np.newaxis] * 100 - 100',
                          data=data, denom=denom)

# def get_epi_fn(sub,ses,task,run, base_dir='../derivatives/fmriprep/fmriprep')

def extract_signals_nilearn(comb, include_physio=True, space='T1w', overwrite=False, use_precleaned=False, use_confounds=True):
    """Extract MASSP and ATAG ROI signals, optionally from pre-cleaned data.

    Parameters
    ----------
    comb : tuple
        ``(sub, ses, task, run)`` as they appear in the file names.
    include_physio : bool
        With ``use_confounds``: add the 20 RETROICOR regressors, or the first
        20 aCompCor components when the RETROICOR file is missing.
    space : str
        Space entity of the brain-mask file name.
    overwrite : bool
        Recompute atlases whose output .tsv already exists.
    use_precleaned : bool
        Read the EPI from ``../derivatives/cleaned_func`` and write to the
        ``..._precleaned`` output folder.
    use_confounds : bool
        Regress confounds out during extraction and write to the
        ``..._cleaned`` output folder.

    Returns
    -------
    0 when done, None when the EPI file does not exist.

    Raises
    ------
    IOError
        If both ``use_precleaned`` and ``use_confounds`` are set.
    """
    if use_precleaned and use_confounds:
        raise IOError("Cannot both use precleaned data AND clean, that's a stupid idea!")

    sub, ses, task, run = comb
    output_base_dir = '../derivatives/extracted_signals_nilearn'
    if use_precleaned:
        base_dir = '../derivatives/cleaned_func'
        output_base_dir += '_precleaned'
    else:
        base_dir = '../derivatives/fmriprep/fmriprep'
    epi_fn = get_epi(sub, ses, task, run, base_dir=base_dir)
    if not os.path.exists(epi_fn):
        print('...doesnt exist, skipping')
        return None

    confounds = None
    if use_confounds:
        output_base_dir += '_cleaned'
        # load confounds (file is read once and reused for aCompCor below)
        confounds_fn = f'../derivatives/fmriprep/fmriprep/sub-{sub}/ses-{ses}/func/sub-{sub}_ses-{ses}_task-{task}_run-{run}_desc-confounds_timeseries.tsv'
        all_confounds = pd.read_csv(confounds_fn, sep='\t')
        # back-fill missing values (.bfill() replaces the deprecated fillna(method='bfill'))
        confounds = all_confounds[['trans_x', 'trans_y', 'trans_z', 'rot_x', 'rot_y', 'rot_z', 'dvars', 'framewise_displacement']].bfill()

        # get retroicor
        if include_physio:
            retroicor_fn = f'../derivatives/retroicor/sub-{sub}/ses-{ses}/func/sub-{sub}_ses-{ses}_task-{task}_run-{run}_desc-retroicor_regressors.tsv'
            if not os.path.exists(retroicor_fn):
                ## take first 20 aCompCor components
                print("No retroicor found, including 20 a_comp_cor components")
                a_comp_cor = all_confounds[['a_comp_cor_' + str(x).zfill(2) for x in range(20)]]
                confounds = pd.concat([confounds, a_comp_cor], axis=1)
            else:
                retroicor = pd.read_csv(retroicor_fn, sep='\t', header=None).iloc[:,:20]  ## 20 components in total
                retroicor.columns = ['cardiac_' + str(x) for x in range(6)] + ['respiratory_' + str(x) for x in range(8)] + ['respiratoryxcardiac_' + str(x) for x in range(4)] + ['HRV', 'RVT']
                confounds = pd.concat([confounds, retroicor], axis=1)

    # get brain mask
    # NOTE(review): the mask of run-1 is used for every run — presumably deliberate; confirm.
    brain_mask = f'../derivatives/fmriprep/fmriprep/sub-{sub}/ses-{ses}/func/sub-{sub}_ses-{ses}_task-{task}_run-1_space-{space}_desc-brain_mask.nii.gz'

    for atlas_type in ['MASSP', 'ATAG']:
        output_fn = os.path.join(output_base_dir, f'sub-{sub}/ses-{ses}/func/sub-{sub}_ses-{ses}_task-{task}_run-{run}_desc-{atlas_type}-signals.tsv')
        if os.path.exists(output_fn) and not overwrite:
            # Skip only this atlas; returning here would also skip the remaining atlases.
            continue

        subject_atlas = load_atlas(sub, atlas_name=atlas_type)
        if atlas_type == 'ATAG':
            subject_atlas = exclude_map_from_atlas(subject_atlas, 'M1')

        # The loaded confounds were previously never handed to the extraction, so the
        # "_cleaned" output was not actually cleaned. confounds is None when use_confounds=False.
        # NOTE(review): no high-pass / t_r / PSC is applied here, unlike the earlier variant — confirm.
        df = extract_timecourse_from_nii(subject_atlas, epi_fn, mask=brain_mask, confounds=confounds, high_pass=None)
        os.makedirs(os.path.dirname(output_fn), exist_ok=True)
        df.to_csv(output_fn, sep='\t')

    return 0


In [None]:
# Every T1w-space BOLD run, parsed into (sub, ses, task, run) and restricted to the rlsat session.
all_runs = sorted(glob.glob('../derivatives/fmriprep/fmriprep/sub-*/ses-*/func/*space-T1w*_bold.nii.gz'))
# raw string: \d and \S are regex classes, not (invalid) Python string escapes
regex = re.compile(r'.*sub-(?P<sub>\d+)_ses-(?P<ses>\S+)_task-(?P<task>\S+)_run-(?P<run>\d)_space-T1w*')
all_combs = [regex.match(x).group('sub', 'ses', 'task', 'run') for x in all_runs]

all_combs = [x for x in all_combs if x[1] == 'rlsat']

In [None]:
# Extract from the pre-cleaned data (no confound regression) as 10 parallel joblib jobs;
# existing output is kept.
precleaned_tasks = (
    joblib.delayed(extract_signals_nilearn)(comb, overwrite=False, use_precleaned=True, use_confounds=False)
    for comb in all_combs
)
out = joblib.Parallel(n_jobs=10, verbose=True)(precleaned_tasks)

## Clean niftis
1. High-pass
2. Remove confounds

In [None]:
# Notebook-level flag; presumably read by the cleaning code further down — TODO confirm.
include_physio=True
def high_pass(nii, verbose=False, mask=None):
    """High-pass filter an image at 1/128 Hz and add each voxel's temporal mean back.

    Parameters
    ----------
    nii : image object exposing ``header`` and ``get_fdata()``
        Time series to filter. The repetition time is read from ``header['pixdim'][4]``
        and the mean is taken over axis 3, so the data must be 4D.
    verbose : bool
        Accepted for interface compatibility; not used by this function.
    mask : optional
        Passed to ``NiftiMasker`` as its first positional argument. When None, the
        masker is constructed without a mask.

    Returns
    -------
    nib.Nifti1Image
        Filtered data plus the per-voxel mean, built with the affine of the
        inverse-transformed image and the header of ``nii``.
    """
    print('Highpass-filtering')
    repetition_time = nii.header['pixdim'][4]

    # Only hand a mask to the masker when the caller supplied one
    masker_args = () if mask is None else (mask,)
    masker = NiftiMasker(*masker_args, high_pass=1./128, t_r=repetition_time)

    # Fit the masker and filter, then map the result back to image space
    filtered_img = masker.inverse_transform(masker.fit_transform(nii))

    # The mean over time is broadcast along the 4th axis and added to every volume
    voxel_means = nii.get_fdata().mean(axis=3)[..., np.newaxis]
    restored = filtered_img.get_fdata() + voxel_means

    return nib.Nifti1Image(restored, filtered_img.affine, header=nii.header)


def do_high_pass(fn, overwrite=False):
    """High-pass filter one preprocessed BOLD file and store it under ``high_passed_func``.

    The result is written to
    ``../derivatives/high_passed_func/sub-<sub>/ses-<ses>/func/<basename of fn>``.
    That same path is used for the "already done" check and for creating the
    output directory, so the three can never disagree.

    Parameters
    ----------
    fn : str
        Path to a file named ``..sub-<digits>_ses-<ses>_task-<task>_run-<digit>_space-T1w_desc-preproc_bold..``.
        The brain mask is looked up by replacing ``preproc_bold`` with ``brain_mask`` in ``fn``.
    overwrite : bool
        When False and the output file already exists, nothing is done.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``fn`` does not match the expected file-name pattern.
    """
    # Raw string: '\d' / '\S' in a plain literal are invalid escape sequences
    regex = re.compile(r'.*sub-(?P<sub>\d+)_ses-(?P<ses>\S+)_task-(?P<task>\S+)_run-(?P<run>\d)_space-T1w_desc-preproc_bold.*')
    match = regex.match(fn)
    if match is None:
        raise ValueError(f'Cannot parse sub/ses/task/run from file name: {fn}')
    gd = match.groupdict()

    # Single source of truth for the output location
    hp_save_fn = '../derivatives/high_passed_func/sub-{}/ses-{}/func/{}'.format(gd['sub'], gd['ses'], os.path.basename(fn))

    # has this file been highpassed?
    if os.path.exists(hp_save_fn) and not overwrite:
        return

    # Inputs are only loaded once we know there is work to do
    brain_mask = nib.load(fn.replace('preproc_bold', 'brain_mask'))
    nii = nib.load(fn)
    hp_data = high_pass(nii, mask=brain_mask)

    os.makedirs(os.path.dirname(hp_save_fn), exist_ok=True)
    hp_data.to_filename(hp_save_fn)

        
def do_clean(fn, overwrite=False):
    """Regress confounds out of one BOLD file and store it under ``cleaned_func``.

    Confounds are the six motion parameters, ``dvars`` and ``framewise_displacement``
    from the fMRIPrep confounds table (missing values back-filled). When the
    module-level ``include_physio`` is true, the first 20 RETROICOR regressors are
    appended; if the RETROICOR file does not exist, aCompCor components 00-19 are
    appended instead.

    The result is written to
    ``../derivatives/cleaned_func/sub-<sub>/ses-<ses>/func/<basename of fn>``.
    That same path is used for the "already done" check and for creating the
    output directory.

    Parameters
    ----------
    fn : str
        Path to a file named ``..sub-<digits>_ses-<ses>_task-<task>_run-<digit>_space-T1w_desc-preproc_bold..``.
    overwrite : bool
        When False and the output file already exists, its path is printed and nothing else is done.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``fn`` does not match the expected file-name pattern.
    """
    # Raw string: '\d' / '\S' in a plain literal are invalid escape sequences
    regex = re.compile(r'.*sub-(?P<sub>\d+)_ses-(?P<ses>\S+)_task-(?P<task>\S+)_run-(?P<run>\d)_space-T1w_desc-preproc_bold.*')
    match = regex.match(fn)
    if match is None:
        raise ValueError(f'Cannot parse sub/ses/task/run from file name: {fn}')
    sub, ses, task, run = match.group('sub', 'ses', 'task', 'run')

    # Single source of truth for the output location
    cleaned_save_fn = '../derivatives/cleaned_func/sub-{}/ses-{}/func/{}'.format(sub, ses, os.path.basename(fn))

    # has this file been cleaned?
    if os.path.exists(cleaned_save_fn) and not overwrite:
        print(cleaned_save_fn)
        return

    nii = nib.load(fn)

    # load confounds (the table is read once and reused for the aCompCor fallback)
    confounds_fn = f'../derivatives/fmriprep/fmriprep/sub-{sub}/ses-{ses}/func/sub-{sub}_ses-{ses}_task-{task}_run-{run}_desc-confounds_timeseries.tsv'
    confounds_table = pd.read_csv(confounds_fn, sep='\t')
    # .bfill() replaces the deprecated fillna(method='bfill')
    confounds = confounds_table[['trans_x', 'trans_y', 'trans_z', 'rot_x', 'rot_y', 'rot_z', 'dvars', 'framewise_displacement']].bfill()

    # get retroicor
    if include_physio:
        retroicor_fn = f'../derivatives/retroicor/sub-{sub}/ses-{ses}/func/sub-{sub}_ses-{ses}_task-{task}_run-{run}_desc-retroicor_regressors.tsv'
        if not os.path.exists(retroicor_fn):
            ## take first 20 aCompCor components
            print("No retroicor found, including 20 a_comp_cor components")
            a_comp_cor = confounds_table[['a_comp_cor_' + str(x).zfill(2) for x in range(20)]]
            confounds = pd.concat([confounds, a_comp_cor], axis=1)
        else:
            retroicor = pd.read_csv(retroicor_fn, sep='\t', header=None).iloc[:, :20]  ## 20 components in total
            # 6 cardiac + 8 respiratory + 4 interaction + HRV + RVT = 20 column names
            retroicor.columns = ['cardiac_' + str(x) for x in range(6)] + ['respiratory_' + str(x) for x in range(8)] + ['respiratoryxcardiac_' + str(x) for x in range(4)] + ['HRV', 'RVT']
            confounds = pd.concat([confounds, retroicor], axis=1)

    cleaned_data = image.clean_img(nii, confounds=confounds, standardize=False, detrend=False)

    os.makedirs(os.path.dirname(cleaned_save_fn), exist_ok=True)
    cleaned_data.to_filename(cleaned_save_fn)

# BOLD runs in T1w space under the fmriprep folder (not consumed by the cleaning call below)
all_funcs = sorted(glob.glob('../derivatives/fmriprep/fmriprep/sub-*/ses*/func/*T1w*_bold.nii.gz'))
# all_funcs = [x for x in all_funcs if not 'sub-001' in x]
# all_funcs

# High-passed files whose path contains 'rlsat'
all_highpassed = [
    hp_fn
    for hp_fn in sorted(glob.glob('../derivatives/high_passed_func/sub*/ses*/func/*'))
    if 'rlsat' in hp_fn
]

# Clean every selected file on 20 workers, always overwriting existing output
cleaning_jobs = (joblib.delayed(do_clean)(hp_fn, overwrite=True) for hp_fn in all_highpassed)
out = joblib.Parallel(n_jobs=20, verbose=True)(cleaning_jobs)

# do_clean(all_highpassed[0])

# calculate number of voxels in the IFG

(Scott) This was needed to address reviewer comments: it compares the number of voxels in the IFG when using the HCP_MMP1 atlas versus the Harvard-Oxford atlas.

In [None]:
def calculate_voxel_numbers(to_run, atlas_name='MASSP', overwrite=False, to_psc=False, use_hp=False):
    """Count the non-zero voxels of the first IFG mask found in a subject's atlas.

    The atlas labels are scanned in order; the count is returned for the first label
    that is one of ``IFG-l``, ``IFG-r``, ``IFGhcp-l`` or ``IFGhcp-r``.

    The EPI itself is not loaded: only its existence is checked. Earlier versions
    read and reshaped the full 4D series without using it.

    Parameters
    ----------
    to_run : tuple
        ``(sub, ses, task, run)``; ``sub`` is zero-padded to three characters.
    atlas_name : str
        Forwarded to ``load_atlas``.
    overwrite : bool
        When False and the extracted-signals TSV for this run exists, that TSV is
        returned instead of a count (see the review note in the body).
    to_psc, use_hp : bool
        Only influence the name of the TSV that is checked (``_psc`` / ``_hp`` suffixes);
        ``use_hp`` also selects which EPI path is tested for existence.

    Returns
    -------
    int, pandas.DataFrame or None
        Voxel count; the cached signals table (see above); or None when the EPI
        file is missing, no atlas is available, or no IFG label is present.
    """
    sub, ses, task, run = to_run
    sub = str(sub).zfill(3)
    print(f'Extracting from sub-{sub}/ses-{ses}/sub-{sub}_ses-{ses}_task-{task}_run-{run}', end='')

    epi_fn = get_epi(sub, ses, task, run, use_hp)
    if not os.path.exists(epi_fn):
        print('...doesnt exist, skipping')
        return None

    # Suffixes follow the naming of the extracted-signals files
    psc_fn = '_psc' if to_psc else ''
    hp_fn = '_hp' if use_hp else ''
    output_fn = f'../derivatives/extracted_signals/sub-{sub}/ses-{ses}/func/sub-{sub}_ses-{ses}_task-{task}_run-{run}_desc-{atlas_name}-signals{psc_fn}{hp_fn}.tsv'

    # NOTE(review): this branch returns the extracted-signals table, not a voxel count,
    # and this function never writes output_fn. Kept for backward compatibility;
    # pass overwrite=True to always get a count.
    if os.path.exists(output_fn) and not overwrite:
        print(f'{output_fn} already run, loading previous result...')
        return pd.read_csv(output_fn, sep='\t')

    # load atlas
    atlas = load_atlas(sub, atlas_name=atlas_name)
    if atlas == 0:
        # load_atlas may signal "no ROIs found" by returning 0
        print('No atlas found! skipping')
        return None

    ifg_labels = ('IFG-l', 'IFG-r', 'IFGhcp-l', 'IFGhcp-r')
    for i, label in enumerate(atlas.labels):
        print('.', end='')
        if label not in ifg_labels:
            continue
        # Only the matching map is materialised
        n_voxels = np.count_nonzero(image.index_img(atlas.maps, i).get_fdata())
        print(f'There are {n_voxels} voxels in region {label}')
        return n_voxels

    return None



In [None]:
# Atlases and high-pass settings to sweep over
all_atlases = ['HCP_MMP1']
# all_atlases = ['CORT']
hp_options = [True]

overwrite = True
psc = False

# Collects one result per run whose subject passes the affine check
IFG_vox = []

for atlas_name, hp in itertools.product(all_atlases, hp_options):
    for i, comb in enumerate(all_combs):
        print(f'atlas-{atlas_name} hp-{hp}')
        print(comb)
        sub = comb[0]
        if not check_affines(sub):
            print(f'Affines for sub {sub} not identical')
            continue
        n_vox = calculate_voxel_numbers(comb, atlas_name=atlas_name, overwrite=overwrite, to_psc=psc, use_hp=hp)
        IFG_vox.append(n_vox)

In [None]:
# Harvard-Oxford atlas: hard-coded IFG voxel counts (37 values)
Harv_Ox_IFG = [
    1656, 1544, 1551, 1881, 1534, 1622, 1993, 1522, 2085, 1484,
    1817, 1772, 2320, 2221, 1910, 1496, 1549, 1867, 1968, 1997,
    1485, 1923, 1780, 1808, 1626, 2113, 1998, 1825, 1753, 2140,
    1786, 1925, 1764, 1759, 1510, 1836, 1520,
]

# Mean of the list
np.mean(Harv_Ox_IFG)
# 1792.972972972973

In [None]:
# HCP_MMP1 atlas: hard-coded IFG voxel counts (37 values)
HCP_MMP1_IFG = [
    2311, 2007, 2391, 2401, 2344, 1791, 2474, 1929, 2576, 1919,
    2752, 2240, 2777, 2626, 2819, 1849, 2363, 2389, 2677, 2836,
    2140, 2532, 2067, 2325, 2421, 2618, 2874, 2648, 2578, 2893,
    2536, 2457, 2626, 2068, 2080, 2318, 2247,
]

# Mean of the list
np.mean(HCP_MMP1_IFG)

# TMP TO RUN NORMAL EXTRACTION

In [None]:
import glob
import os
import re
import shutil
import random
import string
import warnings

import numpy as np
import pandas as pd

import nilearn
from nilearn import plotting, image
from nilearn.input_data import NiftiMasker
import nibabel as nib
from nipype.interfaces import ants
import nighres

import subprocess
import json
import multiprocessing as mp
from functools import partial
import joblib
from joblib import Parallel, delayed
import itertools

import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
def find_rois(sub, atlas_name='ATAG', space='T1w'):
    """Return a ``{label: mask file}`` dict for one subject and atlas.

    Parameters
    ----------
    sub : str
        Subject identifier as it appears after ``sub-`` in the mask paths.
    atlas_name : {'ATAG', 'MASSP'}
        Atlas to look up.
    space : str
        Only consulted for ``ATAG``: ``'MNI152NLin2009cAsym'`` or ``'mni'`` selects the
        shared MNI-space masks; any other value selects the subject's own masks.

    Returns
    -------
    dict
        Labels parsed from the file names, mapped to the file paths (sorted by path).
        Empty when no file matches the glob pattern.

    Raises
    ------
    ValueError
        If ``atlas_name`` is not supported, or a globbed file name does not follow
        the naming scheme the label is parsed from.
    """
    # Raw strings: '\S' in a plain literal is an invalid escape sequence
    if atlas_name == 'ATAG':
        if space == 'MNI152NLin2009cAsym' or space == 'mni':
            ### Rois in MNI09c-space
            mask_dir = '/home/Public/trondheim/sourcedata/masks/MNI152NLin2009cAsym_res-1p5'
            glob_pattern = mask_dir + '/space-*'
            label_pattern = r'.*space-(?P<space>[a-zA-Z0-9]+)_res-1p5_label-(?P<label>[a-zA-Z0-9]+)_probseg_def-img.nii.gz'
        else:
            glob_pattern = f'../derivatives/masks_atag_func/sub-{sub}/anat/sub-{sub}_*.nii.gz'
            label_pattern = r'.*space-(?P<space>[a-zA-Z0-9]+)_desc-mask-(?P<label>[a-zA-Z0-9]+).nii.gz'
    elif atlas_name == 'MASSP':
        glob_pattern = f'../derivatives/masks_massp_func/sub-{sub}/anat/sub-{sub}_*.nii.gz'
        label_pattern = r'.*space-(?P<space>[a-zA-Z0-9]+)_desc-mask-(?P<label>\S+).nii.gz'
    else:
        # Previously fell through to an UnboundLocalError on `names`
        raise ValueError(f"Unsupported atlas_name {atlas_name!r}; expected 'ATAG' or 'MASSP'")

    roi_dict = {}
    for fn in sorted(glob.glob(glob_pattern)):
        match = re.match(label_pattern, fn)
        if match is None:
            raise ValueError(f'Cannot parse an ROI label from mask file name: {fn}')
        roi_dict[match.group('label')] = fn
    return roi_dict

def load_atlas(sub, atlas_name='MASSP', space='T1w'):
    """Load all ROI masks of a subject's atlas into one concatenated image.

    Parameters are forwarded unchanged to ``find_rois``.

    Returns
    -------
    dict subclass or int
        A dict with attribute access holding ``maps`` (the mask files concatenated
        with ``image.concat_imgs``, in ``find_rois`` order) and ``labels`` (the
        matching dict keys view). When ``find_rois`` finds nothing, a warning is
        issued and the integer 0 is returned instead.
    """
    from nilearn import image

    class AttrDict(dict):
        # Sharing the instance dict with the mapping makes every key readable as an attribute
        def __init__(self, *args, **kwargs):
            super(AttrDict, self).__init__(*args, **kwargs)
            self.__dict__ = self

    roi_dict = find_rois(sub, atlas_name, space)
    if not roi_dict:
        warnings.warn(f'No ROIs found for sub-{sub} atlas-{atlas_name} space-{space}. Returning 0 to prevent error')
        return 0

    stacked_maps = image.concat_imgs(roi_dict.values())
    return AttrDict(maps=stacked_maps, labels=roi_dict.keys())

In [None]:
def get_epi(sub, ses, task, run, use_hp=False, base_dir='../derivatives/fmriprep/fmriprep'):
    """Build the path of a run's ``space-T1w_desc-preproc_bold.nii.gz`` file.

    With ``use_hp`` the path points into ``../derivatives/high_passed_func`` and
    ``base_dir`` is ignored; otherwise it points into ``base_dir``. The file is not
    checked for existence.
    """
    root = '../derivatives/high_passed_func' if use_hp else base_dir
    relative_fn = f'sub-{sub}/ses-{ses}/func/sub-{sub}_ses-{ses}_task-{task}_run-{run}_space-T1w_desc-preproc_bold.nii.gz'
    return os.path.join(root, relative_fn)

def _make_psc(data):
    """Express an image as percent signal change relative to its mean image.

    Each value becomes ``value / mean * 100 - 100``, where ``mean`` is the output of
    ``image.mean_img(data)`` broadcast along a trailing axis. Voxels whose mean is
    exactly 0 are divided by 1 instead.
    """
    mean_img = image.mean_img(data)

    # Swap zero means for 1 so the division below stays finite
    mean_data = mean_img.get_fdata()
    safe_mean = np.where(mean_data == 0, 1, mean_data)
    denom = image.new_img_like(mean_img, safe_mean)

    percent_change = image.math_img('data / denom[..., np.newaxis] * 100 - 100',
                                    data=data, denom=denom)
    return percent_change

def do_extract(to_run, atlas_name, overwrite=False, to_psc=False, use_hp=False):
    """Extract one mask-weighted mean time course per atlas ROI for a single run.

    For every label of the subject's atlas, the EPI is averaged over voxels with the
    ROI map as weights (``np.average``), giving one column per label and one row per
    volume. The table is written as a TSV under ``../derivatives/extracted_signals``.

    Parameters
    ----------
    to_run : tuple
        ``(sub, ses, task, run)``; ``sub`` is zero-padded to three characters.
    atlas_name : str
        Forwarded to ``load_atlas`` and used in the output file name.
    overwrite : bool
        When False and the output TSV exists, it is read back and returned.
    to_psc : bool
        Convert the EPI with ``_make_psc`` before extraction; adds ``_psc`` to the file name.
    use_hp : bool
        Read the EPI from the high-passed folder; adds ``_hp`` to the file name.

    Returns
    -------
    pandas.DataFrame or None
        The signals table, or None when the EPI file or the atlas is missing.
        NOTE(review): a freshly computed table is indexed by ``volume``, whereas a
        table read back from disk has ``volume`` as an ordinary column.
    """
    sub, ses, task, run = to_run
    sub = str(sub).zfill(3)
    print(f'Extracting from sub-{sub}/ses-{ses}/sub-{sub}_ses-{ses}_task-{task}_run-{run}', end='')

    epi_fn = get_epi(sub, ses, task, run, use_hp)
    if not os.path.exists(epi_fn):
        print('...doesnt exist, skipping')
        return None

    # load atlas
    atlas = load_atlas(sub, atlas_name=atlas_name)
    if atlas == 0:
        warnings.warn('No atlas found! skipping')
        return None

    # One template for both variants; the use_hp branch used to reference the
    # misspelled name `atasl_name` and raised NameError.
    psc_fn = '_psc' if to_psc else ''
    hp_fn = '_hp' if use_hp else ''
    output_fn = f'../derivatives/extracted_signals/sub-{sub}/ses-{ses}/func/sub-{sub}_ses-{ses}_task-{task}_run-{run}_desc-{atlas_name}-signals{psc_fn}{hp_fn}.tsv'
    print(output_fn)
    if os.path.exists(output_fn) and not overwrite:
        print(f'{output_fn} already run, loading previous result...')
        return pd.read_csv(output_fn, sep='\t')

    # The EPI is only loaded / converted once we know the result is not cached
    epi = _make_psc(epi_fn) if to_psc else nib.load(epi_fn)

    # load & reshape to (voxels, volumes); np.prod replaces np.product (removed in NumPy 2.0)
    epi_flat = epi.get_fdata().reshape((np.prod(epi.shape[:3]), epi.shape[-1]))

    dfs = []
    for i, label in enumerate(atlas.labels):
        print('.', end='')
        mask_flat = image.index_img(atlas.maps, i).get_fdata().ravel()
        signal = pd.DataFrame(np.average(epi_flat, weights=mask_flat, axis=0), columns=[label])
        signal.index.name = 'volume'
        dfs.append(signal)

    df = pd.concat(dfs, axis=1)
    os.makedirs(os.path.dirname(output_fn), exist_ok=True)
    df.to_csv(output_fn, sep='\t')
    print(output_fn)
    return df

In [None]:
# find all available functional runs, extract sub/ses/task/run info
all_runs = sorted(glob.glob('../derivatives/fmriprep/fmriprep/sub-*/ses-*/func/*space-T1w*_bold.nii.gz'))

# Raw string: '\d' and '\S' inside a plain literal are invalid escape sequences
# (SyntaxWarning on Python >= 3.12); the compiled pattern itself is unchanged.
regex = re.compile(r'.*sub-(?P<sub>\d+)_ses-(?P<ses>\S+)_task-(?P<task>\S+)_run-(?P<run>\d)_space-T1w*')

# One (sub, ses, task, run) tuple of strings per file, with the group order spelled out
all_combs = [regex.match(x).group('sub', 'ses', 'task', 'run') for x in all_runs]
# all_combs = [x for x in all_combs if x[0] in ['002','003','004','005','006','007','008','009','010','011']]
# all_combs = [x for x in all_combs if x[0] in ['012','013','014','015','016','017','018','019','020','021','022','023','024','025','026']]
#all_combs = [x for x in all_combs if x[0] in ['027','029','030','031','032']]

# Keep only runs whose session entity (second tuple element) is 'sstmsit'
# all_combs = [x for x in all_combs if x[1] == 'rlsat']
all_combs = [x for x in all_combs if x[1] == 'sstmsit']

# display the selection
all_combs

In [None]:
# find all available functional runs, extract sub/ses/task/run info
all_runs = sorted(glob.glob('../derivatives/fmriprep/fmriprep/sub-*/ses-*/func/*space-T1w*_bold.nii.gz'))

# Raw string: '\d' and '\S' inside a plain literal are invalid escape sequences
# (SyntaxWarning on Python >= 3.12); the compiled pattern itself is unchanged.
regex = re.compile(r'.*sub-(?P<sub>\d+)_ses-(?P<ses>\S+)_task-(?P<task>\S+)_run-(?P<run>\d)_space-T1w*')

# One (sub, ses, task, run) tuple of strings per file, with the group order spelled out
all_combs = [regex.match(x).group('sub', 'ses', 'task', 'run') for x in all_runs]
# all_combs = [x for x in all_combs if x[0] in ['002','003','004','005','006','007','008','009','010','011']]
# all_combs = [x for x in all_combs if x[0] in ['012','013','014','015','016','017','018','019','020','021','022','023','024','025','026']]
#all_combs = [x for x in all_combs if x[0] in ['027','029','030','031','032']]

# No session filter is active in this cell: every discovered run is kept
# all_combs = [x for x in all_combs if x[1] == 'rlsat']
# all_combs = [x for x in all_combs if x[1] == 'sstmsit']

# display the selection
all_combs

In [None]:
def check_affines(sub):
    """Tell whether all T1w-space preprocessed BOLD files of a subject share one affine.

    Parameters
    ----------
    sub : int or str
        Subject identifier; converted to ``str`` and zero-padded to three characters.

    Returns
    -------
    bool
        True when every file's affine is element-wise identical to the first one.
        Also True when no file is found (nothing to disagree); this case used to
        raise IndexError.
    """
    sub = str(sub).zfill(3)
    all_funcs = sorted(glob.glob(f'../derivatives/fmriprep/fmriprep/sub-{sub}/ses*/func/sub*_space-T1w_desc-preproc_bold.nii.gz'))
    all_affines = [nib.load(x).affine for x in all_funcs]
    if not all_affines:
        return True

    # Exact comparison against the first affine, as before
    reference = all_affines[0]
    return all(np.array_equal(reference, affine) for affine in all_affines[1:])

In [None]:
# Extract ATAG signals for every run whose subject has consistent affines
for i, comb in enumerate(all_combs):
    print(comb)
    sub = comb[0]
    if not check_affines(sub):
        print(f'Affines for sub {sub} not identical')
        continue
    do_extract(comb, atlas_name='ATAG', overwrite=False, to_psc=False, use_hp=False)