In [None]:
%matplotlib inline

import os
import glob
import re
import fnmatch
import nilearn.plotting as nlp
import matplotlib.pyplot as plt
import nibabel as nb
import numpy as np
import pandas as pd
import seaborn as sns
import warnings

In [None]:
##### FMRI QA #####

In [None]:
### PROJECT-SPECIFIC VARIABLES TO CHANGE ###

# Root of the OpenfMRI-style dataset (holds models/, task_key.txt and <subject>/BOLD).
openfmri_dir = '/om/user/satra/projects/SAD/data/'
# Root of the first-level results, organised as <model>/<task>/<subject>/.
l1_dir = '/om/user/satra/projects/SAD/fmri_results/individual/'
# fnmatch/glob pattern that selects subject folders.
subj_prefix = 'SAD_???'
# FreeSurfer subjects directory; set to None if no freesurfer dir.
fs_dir = '/om/user/satra/projects/SAD/fsdata/'
# Task folders that are excluded from every QA step below.
skip_tasks = ['task%03d' % num for num in range(2, 7)]
outlier_threshold = 10    # percent

############################################

In [None]:
# Setting up models, tasks, subject lists, runs, # of vols, contrasts, etc.

# Model names are the folder names matching <openfmri_dir>/models/model*.
model_dirs = sorted(glob.glob(os.path.join(openfmri_dir, 'models', 'model*')))
models = [path.rsplit('/', 1)[-1] for path in model_dirs]

# Task ids are the first field of every line in task_key.txt.
task_key = os.path.join(openfmri_dir, 'task_key.txt')
tasks = []
with open(task_key, 'r') as f:
    for line in f:
        tasks.append(re.split(' |,\t', line)[0])

#task_contrasts = os.path.join(openfmri_dir, 'models', model, 'task_contrasts.txt')
#condition_key = os.path.join(openfmri_dir, 'models', model, 'condition_key.txt')

def get_tasks(model):
    """Return the sorted task folder names found under ``l1_dir/<model>``.

    Folders are matched with the glob ``task*``; names listed in the
    module-level ``skip_tasks`` are left out.
    """
    task_dirs = sorted(glob.glob(os.path.join(l1_dir, model, 'task*')))
    kept = []
    for task_dir in task_dirs:
        name = task_dir.rsplit('/', 1)[-1]
        if name not in skip_tasks:
            kept.append(name)
    return kept


def get_subjlist(model, task):
    """Return the sorted subject folder names under ``l1_dir/<model>/<task>``.

    Only entries matching the module-level fnmatch pattern ``subj_prefix``
    are returned.
    """
    entries = os.listdir(os.path.join(l1_dir, model, task))
    return sorted(fnmatch.filter(entries, subj_prefix))

def get_subjtask(model):
    """Build a subject-by-task availability table for ``model``.

    Returns a DataFrame indexed by subject id (sorted) with one column per
    task from ``get_tasks(model)``. A cell is 1 when the subject has results
    for that task and 0 otherwise.

    The table is built from the union of all subjects. Assigning Series
    columns one at a time to an initially empty frame would align every later
    column to the first task's subjects, silently dropping subjects that only
    took part in later tasks and leaving NaN (which is truthy) for the rest.
    """
    tasks = get_tasks(model)
    membership = dict(
        (task, pd.Series(1, index=get_subjlist(model, task))) for task in tasks
    )
    # columns=tasks keeps the task order; the index is the union of subjects.
    table = pd.DataFrame(membership, columns=tasks)
    return table.fillna(0).astype(int).sort_index()

def get_runs(task, subj):
    """Count the entries in ``<openfmri_dir>/<subj>/BOLD`` that belong to ``task``.

    An entry belongs to the task when the text before its first underscore
    equals ``task``.
    """
    bold_dir = os.path.join(openfmri_dir, subj, 'BOLD')
    return sum(1 for entry in os.listdir(bold_dir) if entry.split('_')[0] == task)

def get_num_vols(task, subj, run):
    """Return the size of the 4th dimension of a run's ``bold.nii.gz``.

    ``run`` is zero-based; the run folder on disk is ``<task>_run<run+1>``
    with the run number zero-padded to three digits.
    """
    run_dir = '%s_run%03d' % (task, run + 1)
    img = nb.load(os.path.join(openfmri_dir, subj, 'BOLD', run_dir, 'bold.nii.gz'))
    return img.shape[3]

def get_contrasts(model, task):
    """Return the contrast names defined for ``task`` under ``model``.

    Reads two files from ``<openfmri_dir>/models/<model>/``:

    * ``task_contrasts.txt`` - the second field of every line whose first
      field equals ``task`` is taken as a contrast name;
    * ``condition_key.txt`` - the third field of every matching line yields
      an additional ``<condition>_gt_rest`` contrast.

    The names from ``task_contrasts.txt`` come first, followed by the
    ``*_gt_rest`` names, each group in file order.

    The file paths are derived from ``model`` here; relying on module-level
    ``task_contrasts`` / ``condition_key`` variables raised NameError because
    those are never defined.
    """
    model_dir = os.path.join(openfmri_dir, 'models', model)
    task_contrasts = os.path.join(model_dir, 'task_contrasts.txt')
    condition_key = os.path.join(model_dir, 'condition_key.txt')

    contrasts = []
    with open(task_contrasts, 'r') as f:
        for line in f:
            fields = re.split(' |,\t', line)
            if fields[0] == task:
                contrasts.append(fields[1])
    with open(condition_key, 'r') as f:
        for line in f:
            fields = re.split(' |,\t', line)
            if fields[0] == task:
                # strip() removes the trailing newline of the last field
                contrasts.append(fields[2].strip() + '_gt_rest')
    return contrasts

In [None]:
# REGISTRATION

def plot_mean2anat_overlay(subj, subj_path, fs_dir, title, fig=None, ax=None):
    """Overlay the median-functional brain mask on the subject's T1 image.

    subj      : subject id, used to locate the anatomical image.
    subj_path : first-level results folder of the subject; the mask is read
                from ``qa/mask/mean2anat/median_brain_mask.nii.gz`` below it.
    fs_dir    : FreeSurfer subjects dir (``<subj>/mri/T1.mgz`` is the
                background); when None, ``anatomy/T1_001.nii.gz`` from the
                dataset folder is used instead.
    title     : title passed to the plot.
    fig, ax   : existing figure/axes to draw into. When ``ax`` is None a new
                5x2 figure is created, shown and closed by this function.
    """
    if fs_dir is None:
        bg = os.path.join(openfmri_dir, subj, 'anatomy', 'T1_001.nii.gz')
    else:
        bg = os.path.join(fs_dir, subj, 'mri', 'T1.mgz')
    mask = os.path.join(subj_path, 'qa', 'mask', 'mean2anat', 'median_brain_mask.nii.gz')

    standalone = ax is None
    if standalone:
        fig = plt.figure(figsize=(5, 2))
        ax = fig.gca()
    display = nlp.plot_roi(roi_img=mask, bg_img=bg, black_bg=False, alpha=0.3,
                           draw_cross=False, annotate=False,
                           figure=fig, axes=ax, title=title)
    if standalone:
        plt.show()
        display.close()

def plot_anat(subj_path, title, fig=None, ax=None):
    """Plot the warped anatomical image at fixed cut coordinates (0, -13, 20).

    The image is read from ``qa/anat2target/output_warped_image.nii.gz``
    below ``subj_path``. When ``ax`` is None a new 5x2 figure is created,
    shown and closed here; otherwise the plot is drawn into the given
    ``fig``/``ax`` and left open for the caller.
    """
    anat2target = os.path.join(subj_path, 'qa', 'anat2target', 'output_warped_image.nii.gz')

    standalone = ax is None
    if standalone:
        fig = plt.figure(figsize=(5, 2))
        ax = fig.gca()
    display = nlp.plot_anat(anat_img=anat2target, cut_coords=(0, -13, 20), annotate=False,
                            draw_cross=False, title=title, figure=fig, axes=ax)
    if standalone:
        plt.show()
        display.close()
    
#def dice_coefficient():

In [None]:
# check registration of mean (median) functional to structural

# One figure per subject with one panel per task, showing the functional
# brain mask on top of the anatomical image.
for model in models:
    print('********** %s **********' % model)
    subj_df = get_subjtask(model)
    num_tasks = subj_df.shape[1]
    for subj in subj_df.index:
        # squeeze=False always returns a 2-D axes array, also for one task
        fig, axes = plt.subplots(1, num_tasks, figsize=(4 * num_tasks, 2), squeeze=False)
        for task_ax, task in zip(axes[0], subj_df.columns):
            # Compare with 1 explicitly: a missing subject/task cell may be
            # NaN, which is truthy and would trigger a plot of absent files.
            if subj_df.loc[subj, task] == 1:
                plot_mean2anat_overlay(subj, os.path.join(l1_dir, model, task, subj),
                                       fs_dir, title=subj + '-' + task, fig=fig, ax=task_ax)
            else:
                task_ax.axis('off')   # no results for this task: leave the panel blank
        plt.show()
        plt.close(fig)

In [None]:
# Subject ids flagged for functional-to-structural coregistration problems.
# NOTE(review): presumably picked by eye from the overlays above; the list is
# not referenced again in the visible part of this notebook - confirm.
coregistration_outliers = ['SAD_049', 'SAD_P42']

In [None]:
# check registration of structural to MNI template

# One figure per subject with one panel per task, showing the anatomical
# image after warping to the target template.
for model in models:
    print('********** %s **********' % model)
    subj_df = get_subjtask(model)
    num_tasks = subj_df.shape[1]
    for subj in subj_df.index:
        # squeeze=False always returns a 2-D axes array, also for one task
        fig, axes = plt.subplots(1, num_tasks, figsize=(4 * num_tasks, 2), squeeze=False)
        for task_ax, task in zip(axes[0], subj_df.columns):
            # Compare with 1 explicitly: a missing subject/task cell may be
            # NaN, which is truthy and would trigger a plot of absent files.
            if subj_df.loc[subj, task] == 1:
                plot_anat(os.path.join(l1_dir, model, task, subj),
                          title=subj + '-' + task, fig=fig, ax=task_ax)
            else:
                task_ax.axis('off')   # no results for this task: leave the panel blank
        plt.show()
        plt.close(fig)

In [None]:
# Subject ids flagged for problems in the structural-to-template normalization.
# NOTE(review): presumably picked by eye from the plots above; the list is
# not referenced again in the visible part of this notebook - confirm.
normalization_outliers = ['SAD_049']

In [None]:
# OUTLIERS

def count_outliers(subj_path, run):
    """Return the number of values listed in a run's ART outlier file.

    ``run`` is zero-based; the file read is
    ``<subj_path>/qa/art/run<run+1>_art.bold_dtype_mcf_outliers.txt``.
    Warnings raised while parsing are suppressed.
    """
    fname = 'run%02d_art.bold_dtype_mcf_outliers.txt' % (run + 1)
    outlier_file = os.path.join(subj_path, 'qa', 'art', fname)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        outlier_idx = np.genfromtxt(outlier_file)
    # product of the shape: 1 for a single (0-d) value, n for n values
    return np.prod(np.shape(outlier_idx))

In [None]:
# Computes, per model, the fraction of outlier volumes in every run
# (subjects x "<task>-RunNN" columns), draws it as a heatmap whose colour
# scale tops out at the configured outlier threshold, and keeps the table in
# outlier_dfs[model].
outlier_fraction_limit = outlier_threshold / 100.0   # outlier_threshold is in percent
outlier_dfs = {}
for model in models:
    print('********** %s **********' % model)
    task_frames = []
    for task in get_tasks(model):
        print('********** %s **********' % task)
        subjects = get_subjlist(model, task)
        subj_runs = dict((subj, get_runs(task, subj)) for subj in subjects)
        max_run = max(subj_runs.values()) if subj_runs else 0
        # NaN (not 0) marks runs a subject does not have, so a missing run is
        # not mistaken for a run without outliers.
        task_info = np.full((len(subjects), max_run), np.nan)
        columns = ['%s-Run%02d' % (task, run) for run in range(1, max_run + 1)]
        for idx, subj in enumerate(subjects):
            subj_path = os.path.join(l1_dir, model, task, subj)
            for run in range(subj_runs[subj]):
                num_outliers = count_outliers(subj_path, run)
                num_vols = get_num_vols(task, subj, run)
                task_info[idx, run] = float(num_outliers) / num_vols
        task_frames.append(pd.DataFrame(task_info, index=subjects, columns=columns))
    # Concatenate once; subjects are aligned across tasks by index.
    df = pd.concat(task_frames, axis=1) if task_frames else pd.DataFrame()
    sns.set(context="poster", font="monospace")
    f, ax = plt.subplots(figsize=(20, 20))
    sns.heatmap(df, linewidths=0, annot=True, vmax=outlier_fraction_limit, vmin=0, ax=ax)
    f.tight_layout()
    outlier_dfs[model] = df

In [None]:
# Per model: subjects with at least one run whose outlier fraction exceeds
# the configured threshold (outlier_threshold is given in percent).
outlier_fraction_limit = outlier_threshold / 100.0
outlier_names = {}
for model in models:
    fractions = outlier_dfs[model]
    exceeds = (fractions > outlier_fraction_limit).any(axis=1)
    outlier_names[model] = fractions.index[exceeds.values].tolist()
outlier_names

In [None]:
# MASKS

def plot_mask(subj_path, run):
    """Show a run's brain mask over the median functional image.

    ``run`` is zero-based; the mask is ``qa/mask/run<run+1>_mask.nii.gz`` and
    the background is ``mean/median.nii.gz``, both below ``subj_path``.
    The plot uses 15 cuts in 'y' display mode and is closed after showing.
    """
    mask_name = 'run%02d_mask.nii.gz' % (run + 1)
    mask = os.path.join(subj_path, 'qa', 'mask', mask_name)
    bg = os.path.join(subj_path, 'mean', 'median.nii.gz')

    display = nlp.plot_roi(roi_img=mask, bg_img=bg, black_bg=False, alpha=0.3,
                           display_mode='y', cut_coords=15)
    plt.show()
    display.close()

In [None]:
# quick visual check for holes in masks

# Walk model -> task -> subject -> run and show every run mask.
for model in models:
    print('********** %s **********' % model)
    tasks = get_tasks(model)
    print(tasks)
    for task in tasks:
        print('********** %s **********' % task)
        for subj in get_subjlist(model, task):
            print(subj)
            subj_path = os.path.join(l1_dir, model, task, subj)
            for run in range(get_runs(task, subj)):
                print('run%02d' % (run + 1))
                plot_mask(subj_path, run)

In [None]:
# ZSTATS

def plot_stat_map(subj_path, contrast_num, title):
    """Show axial slices of one z-statistic map, thresholded at 2.3.

    ``contrast_num`` is zero-based; the file read is
    ``<subj_path>/zstats/mni/zstat<contrast_num+1>.nii.gz``. Twelve cuts
    evenly spaced from -40 to 70 are drawn on a black background and the
    display is closed after showing.
    """
    zstat_name = 'zstat%02d.nii.gz' % (contrast_num + 1)
    zstat = os.path.join(subj_path, 'zstats', 'mni', zstat_name)
    axial_cuts = np.linspace(-40, 70, 12)
    fig = plt.figure(figsize=(10, 1.5))
    display = nlp.plot_stat_map(stat_map_img=zstat, display_mode='z', threshold=2.3,
                                figure=fig, black_bg=True, cut_coords=axial_cuts,
                                title=title, annotate=False)
    plt.show()
    display.close()

In [None]:
# visual check of zstats in MNI space (displays slices in axial view)

# For every model/task, go through the contrasts in order and show each
# subject's z-map for that contrast.
for model in models:
    print('********** %s **********' % model)
    for task in get_tasks(model):
        print('********** %s **********' % task)
        contrasts = get_contrasts(model, task)
        for contrast_num, contrast_name in enumerate(contrasts):
            print(contrast_name)
            for subj in get_subjlist(model, task):
                subj_path = os.path.join(l1_dir, model, task, subj)
                plot_stat_map(subj_path, contrast_num, title=subj)

In [None]:
# TSNR (from Satra's script)

def read_stats(fname):
    """Pair the values in ``fname`` with the ROI names of its summary file.

    The summary file is found by cutting ``fname`` at ``'_aparc'`` and
    appending ``'_summary.stats'``; its fifth column supplies the ROI names.
    Returns a dict mapping ROI name -> value from ``fname``.
    """
    prefix = fname.split('_aparc')[0]
    summary = np.genfromtxt(prefix + '_summary.stats', dtype=object)
    roi_names = summary[:, 4]
    values = np.genfromtxt(fname)
    return dict(zip(roi_names, values))

import re

def get_data(fl, subj_regex):
    """Collect per-subject ROI values into one DataFrame.

    fl         : iterable of file names accepted by ``read_stats``.
    subj_regex : regular expression searched in each file name; the matched
                 text becomes the row label (subject id).

    Returns a DataFrame with one row per file and one column per ROI, keeping
    only the ROIs that have a value for every subject. An empty ``fl`` yields
    an empty DataFrame.

    Raises ValueError when ``subj_regex`` does not match a file name.
    """
    frames = []
    for name in fl:
        match = re.search(subj_regex, name)
        if match is None:
            raise ValueError('subject pattern %r not found in %r' % (subj_regex, name))
        frames.append(pd.DataFrame(read_stats(name), index=[match.group()]))
    if not frames:
        return pd.DataFrame()
    # Concatenate once; ROIs missing for any subject become NaN and are dropped.
    return pd.concat(frames).dropna(axis=1)

In [None]:
# Per task: correlate every subject's ROI-wise TSNR profile with every other
# subject's and count how many correlations fall below 0.8. The counts are
# then divided by the number of subjects and shown as a heatmap.
tsnr_columns = []
for task in get_tasks('model001'):
    print(task)
    fl = sorted(glob.glob(os.path.join(l1_dir, 'model001', task, subj_prefix,
                                       'qa', 'tsnr', 'run01_aparc+aseg_warped_avgwf.txt')))
    if not fl:
        continue   # no TSNR files for this task
    data = get_data(fl, 'SAD_...')
    # Keep only ROIs whose median TSNR across subjects exceeds 20 and use the
    # trimmed table for the correlation; the trimmed table used to be
    # computed and then ignored in favour of the untrimmed data.
    keep = (data.median(axis=0) > 20).values
    data_trimmed = data.loc[:, keep]
    # atleast_2d: corrcoef returns a scalar when there is a single subject
    corr = np.atleast_2d(np.corrcoef(data_trimmed.values))
    tsnr_columns.append(pd.Series((corr < 0.8).sum(axis=0), index=data.index, name=task))
df = pd.concat(tsnr_columns, axis=1) if tsnr_columns else pd.DataFrame()
# NOTE(review): divides by the number of subjects over all tasks; if tasks
# have different subject sets a per-task count may be intended - confirm.
df = df / df.shape[0]
sns.set(context="poster", font="monospace")
f, ax = plt.subplots(figsize=(20, 20))
sns.heatmap(df, linewidths=0, annot=True, ax=ax)
f.tight_layout()

In [None]:
# Histogram (30 bins) of all values in the TSNR dissimilarity table.
tsnr_values = df.values.ravel()
plt.hist(tsnr_values, 30)
plt.xlabel('% TSNR dissimilarity')
ph = plt.ylabel('Number of participants')   # assignment keeps the cell output quiet

In [None]:
# Subjects whose TSNR dissimilarity reaches 0.4 in at least one task.
is_tsnr_outlier = (df >= .4).any(axis=1)
tsnr_outliers = df.index[is_tsnr_outlier.values].tolist()
tsnr_outliers

In [None]:
# Subjects that appear both in the model001 outlier-volume list and in the
# TSNR outlier list.
np.intersect1d(outlier_names['model001'], tsnr_outliers)