# fMRI Pre-Processing Notebook

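This notebook pre-processes fMRI data from a BIDS dataset with a nipype workflow: motion correction, slice-timing correction, boundary-based coregistration to the skull-stripped anatomical image, spatial normalization, and smoothing.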

### Import modules

In [None]:
from os.path import join as opj

from nipype import Node, Workflow, Function
from nipype.algorithms.misc import Gzip, Gunzip
from nipype.interfaces.io import SelectFiles, DataSink
from nipype.interfaces.utility import IdentityInterface
from nipype.interfaces import spm, fsl
from nipype.algorithms.rapidart import ArtifactDetect

from bids.layout import BIDSLayout

### Main parameters

In [None]:
# Locations -- both paths are placeholders and need to be set before running.
INPT_PTH = r".../data/ds003548"  # root of the input dataset
RSLT_PTH = r".../data/preprocessed"  # root for the working directory and the results
SINK_DIR = 'datasink'  # DataSink container folder

# Entities to process. Only the first subject and run are enabled; the
# remaining IDs ('02'-'16' for subjects, '2'-'5' for runs) were left disabled.
SUBJECT_LIST = ['01']
TASK_LIST = ['emotionalfaces']
RUN_LIST = ['1']

# Kernel width handed to the smoothing node
FWHM = 8

### Measurement information collection

In [None]:
from copy import deepcopy

def collect_subjects_info(root_dir: str, subjects: list = None, tasks: list = None, runs: list = None) -> dict:
    """Collect acquisition parameters for every requested BOLD run.

    For each subject/task/run combination the file name
    ``sub-<subject>_task-<task>_run-<run>_bold.nii.gz`` inside
    ``<root_dir>/sub-<subject>/func`` is assembled and its metadata is read
    through ``BIDSLayout.get_metadata``.

    Args:
        root_dir: Root directory handed to ``BIDSLayout``.
        subjects: Subject labels; ``None`` uses ``layout.get_subjects()``.
        tasks: Task labels; ``None`` uses ``layout.get_tasks()``.
        runs: Run labels; ``None`` uses ``layout.get_runs()``.
            An empty sequence for any of the three drops that entity from the
            generated path/file name.

    Returns:
        A dict keyed by the assembled file path. Each value is a dict with:
        ``TR`` (the ``RepetitionTime`` metadata entry), ``SliceTiming`` (the
        metadata list as-is), ``SliceOrder`` (0-based slice indices sorted by
        ascending slice time; slices sharing a time each appear exactly once,
        in index order) and ``SliceNum`` (number of ``SliceTiming`` entries).

    Raises:
        KeyError: If the metadata lacks ``RepetitionTime`` or ``SliceTiming``.
    """
    layout = BIDSLayout(root_dir)

    # No explicit selection -> take everything the layout reports.
    subjects = layout.get_subjects() if subjects is None else subjects
    tasks = layout.get_tasks() if tasks is None else tasks
    runs = layout.get_runs() if runs is None else runs

    # An empty selection is replaced by a single None placeholder so the loops
    # below still run once, simply omitting that entity from the name.
    subjects = subjects or [None]
    tasks = tasks or [None]
    runs = runs or [None]

    info_dict = {}
    for subject in subjects:
        subject_name = '' if subject is None else f'sub-{subject}'
        subject_dir = [root_dir] if subject is None else [root_dir, subject_name]
        subject_dir.append('func')

        for task in tasks:
            task_name = subject_name if task is None else f'{subject_name}_task-{task}'

            for run in runs:
                run_name = task_name if run is None else f'{task_name}_run-{run}'
                file_pth = opj(*subject_dir, f'{run_name}_bold.nii.gz')

                metadata = layout.get_metadata(file_pth)
                repetition_time = metadata['RepetitionTime']
                slice_timing = metadata['SliceTiming']

                # Stable argsort of the slice times. Looking each sorted value
                # up with list.index() would return the same (first) index for
                # all slices that share an acquisition time, producing
                # duplicates and dropping slices from the order.
                # NOTE(review): these indices are 0-based; nipype documents
                # spm.SliceTiming's slice_order as 1-based -- confirm before
                # feeding this list to SPM unchanged.
                slice_order = sorted(range(len(slice_timing)), key=slice_timing.__getitem__)

                info_dict[file_pth] = dict(
                    TR=repetition_time,
                    SliceTiming=slice_timing,
                    SliceOrder=slice_order,
                    SliceNum=len(slice_timing),
                )

    return info_dict

### Pre-Processing Nodes

#### Functional Image Pre-Processing

In [None]:
# --- Motion correction (FSL MCFLIRT) ---
motion_corr = Node(
    name='motion_corr',
    interface=fsl.MCFLIRT(
        mean_vol=True,
        save_plots=True,
        interpolation='spline',
        output_type='NIFTI',
    ),
)

# --- Slice-timing correction (SPM) ---
# TODO: repetition time, slice count, reference slice and acquisition time are
# still hard-coded below; only slice_order is supplied later through a
# workflow connection.
slice_timing = Node(
    name='slice_timing',
    interface=spm.SliceTiming(
        time_repetition=2.,
        num_slices=35,
        ref_slice=17,
        time_acquisition=2. - 2. / 35.,
    ),
)

# --- Boundary-based coregistration (FSL FLIRT, cost 'bbr') ---
# The node is registered under the name 'coreg_pre', not the variable name.
# NOTE(review): the schedule path is tied to an FSL install under /usr/local/fsl.
coreg_bbr_mtx = Node(
    name='coreg_pre',
    interface=fsl.FLIRT(
        dof=6,
        cost='bbr',
        schedule='/usr/local/fsl/etc/flirtsch/bbr.sch',
        output_type='NIFTI_GZ',
    ),
)

# --- Resampling with FLIRT (apply_isoxfm=4; apply_xfm left disabled) ---
coreg_warp = Node(
    name='coreg_warp',
    interface=fsl.FLIRT(
        interp='spline',
        apply_isoxfm=4,
        output_type='NIFTI',
    ),
)

# --- Smoothing (FSL Smooth) with the notebook-wide FWHM ---
smooth = Node(
    name='smooth',
    interface=fsl.Smooth(
        fwhm=FWHM,
        output_type='NIFTI_GZ',
    ),
)

#### Anatomical Image Pre-Processing

In [None]:
# SPM12 tissue probability map; the same file is reused as the normalization template.
tiss_template = r'/usr/local/spm/spm12/tpm/TPM.nii'
norm_template = tiss_template

# Skull-stripping of the anatomical image (FSL BET)
skull_strip = Node(
    name="skull_strip",
    interface=fsl.BET(frac=0.5, robust=True, output_type='NIFTI'),
)

# Tissue definitions for the segmentation. Every entry has the form
# ((template, class index), 2, <per-class flag pair>, (False, False)); only the
# flag pair differs between classes. Field meanings follow nipype's NewSegment
# `tissues` input -- see its documentation.
GM, WM, CSF, BONE, SOFT, AIR = (
    ((tiss_template, class_idx), 2, class_flags, (False, False))
    for class_idx, class_flags in enumerate(
        [
            (True, True),    # GM
            (True, True),    # WM
            (True, False),   # CSF
            (False, False),  # BONE
            (False, False),  # SOFT
            (False, False),  # AIR
        ],
        start=1,
    )
)

# NOTE(review): AIR is defined above but is not part of the tissue list below
# -- confirm this omission is intended.
segmentation = Node(
    name='segmentation',
    interface=spm.NewSegment(tissues=[GM, WM, CSF, BONE, SOFT]),
)

# Threshold the WM probability image at 0.5; '-bin' is passed as an extra argument.
threshold_WM = Node(
    name="threshold_WM",
    interface=fsl.Threshold(thresh=0.5, args='-bin', output_type='NIFTI_GZ'),
)

# Spatial normalization (SPM Normalize12) against the template above
normalize = Node(
    name='normalize',
    interface=spm.Normalize12(tpm=norm_template),
)

### Data IO Handling

In [None]:
def get_io_objects(run_list, subject_list, task_list):
    """Build the input/output nodes of the pre-processing workflow.

    Args:
        run_list: Values iterated over as ``zrun_id``.
        subject_list: Values iterated over as ``subject_id``.
        task_list: Values iterated over as ``task_id``.

    Returns:
        Tuple ``(infosource, selectfiles, datasink)``:
        an ``IdentityInterface`` node carrying the three iterables, a
        ``SelectFiles`` node resolving the anatomical and functional image
        below ``INPT_PTH``, and a ``DataSink`` node writing to
        ``RSLT_PTH``/``SINK_DIR`` with file-name substitutions applied.
    """
    # Identity node that fans out over subject / task / run
    id_fields = ['subject_id', 'task_id', 'zrun_id']
    infosource = Node(IdentityInterface(fields=id_fields), name="infosource")
    infosource.iterables = list(zip(id_fields, (subject_list, task_list, run_list)))

    # {}-style templates, relative to the dataset root
    templates = {
        'anat': 'sub-{subject_id}/anat/sub-{subject_id}_T1w.nii.gz',
        'func': 'sub-{subject_id}/func/sub-{subject_id}_task-{task_id}_run-{zrun_id}_bold.nii.gz',
    }
    selectfiles = Node(
        SelectFiles(templates, base_directory=INPT_PTH, sort_filelist=True),
        name='selectfiles',
    )

    # Output-path rewrites applied by the DataSink, in this order.
    path_substitutions = [
        ('_task_id_', '/task-'),
        ('_subject_id_', 'sub-'),
        ('_zrun_id_', '/run-'),
        ('_fwhm_', 'fwhm-'),
        ('_roi', ''),
        ('_mcf', ''),
        ('_st', ''),
        ('_flirt', ''),
        ('_smooth', ''),
        ('.nii_mean_reg', '_mean'),
        ('.nii.par', '.par'),
        (f'fwhm-{FWHM}/', f'fwhm-{FWHM}-'),
    ]

    # DataSink collects the outputs worth keeping in one output folder
    datasink = Node(
        DataSink(
            base_directory=RSLT_PTH,
            container=SINK_DIR,
            substitutions=path_substitutions,
        ),
        name="datasink",
    )

    return infosource, selectfiles, datasink

### Main Process Workflow

In [None]:
def get_slice_timing(path_id, subject_info):
    """Return the 'SliceTiming' entry that *subject_info* holds for *path_id*."""
    run_info = subject_info[path_id]
    return run_info['SliceTiming']
# Function node wrapping get_slice_timing; the result is exposed as 'output'.
get_st_node = Node(
    name='get_st',
    interface=Function(
        function=get_slice_timing,
        input_names=['path_id', 'subject_info'],
        output_names=['output'],
    ),
)

def get_slice_order(path_id, subject_info):
    """Return the 'SliceOrder' entry that *subject_info* holds for *path_id*."""
    run_info = subject_info[path_id]
    return run_info['SliceOrder']
# Function node wrapping get_slice_order; the result is exposed as 'output'.
get_so_node = Node(
    name='get_so',
    interface=Function(
        function=get_slice_order,
        input_names=['path_id', 'subject_info'],
        output_names=['output'],
    ),
)

In [None]:
import os
import shutil

run_list = RUN_LIST
task_list = TASK_LIST


def get_WM(data):
    """Return the first element of the second entry of *data*."""
    return data[1][0]


# One workflow is built and run per subject.
# NOTE(review): the module-level processing nodes are shared by every
# iteration's workflow -- confirm nipype tolerates this when more than one
# subject is enabled.
for subject in SUBJECT_LIST:
    subject_list = [subject]
    infosource, selectfiles, datasink = get_io_objects(run_list, subject_list, task_list)
    subjectinfo = collect_subjects_info(INPT_PTH, subject_list, task_list, run_list)

    # The slice-order lookup is keyed by the functional file path.
    # NOTE(review): this assumes the path emitted by selectfiles equals the key
    # built in collect_subjects_info -- confirm for relative INPT_PTH values.
    get_so_node.inputs.subject_info = subjectinfo

    # infosource -> selectfiles: forward the three identifiers unchanged
    input_tuple = (
        infosource,
        selectfiles,
        [(field, field) for field in ('subject_id', 'task_id', 'zrun_id')],
    )

    anatomical_links = [
        (selectfiles, skull_strip, [('anat', 'in_file')]),
        (skull_strip, segmentation, [('out_file', 'channel_files')]),
        (segmentation, threshold_WM, [(('native_class_images', get_WM), 'in_file')]),
    ]

    functional_links = [
        (selectfiles, motion_corr, [('func', 'in_file')]),
        (motion_corr, slice_timing, [('out_file', 'in_files')]),
        (selectfiles, get_so_node, [('func', 'path_id')]),
        (get_so_node, slice_timing, [('output', 'slice_order')]),
    ]

    coregistration_links = [
        (threshold_WM, coreg_bbr_mtx, [('out_file', 'wm_seg')]),
        (skull_strip, coreg_bbr_mtx, [('out_file', 'reference')]),
        (motion_corr, coreg_bbr_mtx, [('mean_img', 'in_file')]),

        (coreg_bbr_mtx, coreg_warp, [('out_matrix_file', 'in_matrix_file')]),
        (skull_strip, coreg_warp, [('out_file', 'reference')]),
        (slice_timing, coreg_warp, [('timecorrected_files', 'in_file')]),
    ]

    normalization_links = [
        (coreg_warp, normalize, [('out_file', 'apply_to_files')]),
        (skull_strip, normalize, [('out_file', 'image_to_align')]),
    ]

    smoothing_links = [
        (normalize, smooth, [('normalized_files', 'in_file')]),
    ]

    # Only the normalized and the smoothed functional images are stored.
    # Sinking the intermediate outputs (motion-corrected series and mean image,
    # slice-time-corrected series, skull-stripped / segmented / thresholded
    # anatomy, coregistration matrix and image, normalized anatomy) is disabled.
    datasink_links = [
        (normalize, datasink, [('normalized_files', 'preproc.@normalized_func')]),
        (smooth, datasink, [('smoothed_file', 'preproc.@smoothed')]),
    ]

    preproc = Workflow(name='preprocessing', base_dir=opj(RSLT_PTH, "workingdir"))
    preproc.connect(
        [input_tuple]
        + anatomical_links
        + functional_links
        + coregistration_links
        + normalization_links
        + smoothing_links
        + datasink_links
    )

    preproc.run(plugin='MultiProc', plugin_args={'n_procs': 3})

    # Optional cleanup (disabled): delete every entry whose name contains
    # "subject_id" below <RSLT_PTH>/workingdir/preprocessing using
    # shutil.rmtree(..., ignore_errors=True).