In [None]:
'''
Author: Zala Reppmann, zala.reppmann@charite.de

Aim:
Analyze voxel-wise task effect of the adapted ScanSTRESS task by performing permutation-based statistical testing 
on first-level contrast images across participants.

Input:
- First-level contrast (feat) images for each participant from the adapted ScanSTRESS task.

Output:
- Log file documenting script execution and key results.
- Brain mask (.nii) created by combining the minimum mask and resampled MNI brain mask.
- Permutation test results:
  - T-values image (.nii)
  - Corrected p-values image (.nii) (FDR corrected)
  - Thresholded t-values image (.nii) (significance mask)
- Site-specific permutation test results (per site, saved in NeuroVault-compatible format).

Steps:
   - Initialize logging to track script execution.
   - Load first-level contrast images for all participants.
   - Exclude participants based on predefined criteria (e.g., equipment malfunction, flawed data).
   - Load and resample the MNI brain mask to match the study-specific minimum mask.
   - Combine both masks to ensure accurate spatial alignment.
   - Perform a voxel-wise permutation test (2000 permutations) to evaluate task effects.
   - Apply False Discovery Rate (FDR) correction to control for multiple comparisons.
   - Generate voxel-wise t-values, corrected p-values, and thresholded t-values images.
   - Threshold t-values based on corrected p-values (alpha = 0.05).
   - Save T-values, corrected p-values, and thresholded t-values as NIfTI images.

   Additional site-wise task effect (not part of the manuscript; the images are available on NeuroVault):
   - Perform site-specific permutation tests for each study site.
   - Save separate t-values, corrected p-values, and thresholded t-values images for each site.

'''
# imports 
import logging
import numpy as np
import os
import nibabel as nib
import nilearn.image
from nilearn.image import load_img, concat_imgs, mean_img, new_img_like, resample_to_img
from nilearn import datasets
from nilearn.masking import apply_mask
from nilearn.input_data import NiftiMasker
from scipy.stats import ttest_1samp
from nilearn.mass_univariate import permuted_ols
from statsmodels.stats.multitest import multipletests

# start from a clean root logger: drop every handler that is currently attached to it
root_logger = logging.getLogger('')
root_logger.handlers = []

#--------------
# setup logging
#--------------
# file output: INFO and above, timestamps formatted down to the second
logging.basicConfig(
    filename='/path/OBS_bayes_taskeffect_vox_logfile.log',
    level=logging.INFO,
    format='%(asctime)s %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)

# console output: same message layout as the log file, INFO and above
# (this formatter sets no datefmt, so logging's default timestamp format is used here)
formatter = logging.Formatter('%(asctime)s %(levelname)s: %(message)s')
console = logging.StreamHandler()
console.setFormatter(formatter)
console.setLevel(logging.INFO)
root_logger.addHandler(console)

#---------------------------
# define paths and load data
#---------------------------
# data: every entry of this directory is treated as one participant directory
basepath = '/path/results_fmri'

# excluded IDs
# expects list of strings containing participant IDs that are to be excluded from the analysis;
# the entries are compared against the directory names exactly as they appear in basepath
excluded_participant_ids = []

# filter directories
# os.listdir returns entries in arbitrary order; sorting keeps the volume order of the
# concatenated image and the logged ID list identical from run to run
all_participant_dirs = sorted(os.listdir(basepath))
participant_ids = [pid for pid in all_participant_dirs if pid not in excluded_participant_ids]
logging.info(f"Number of participants included: {len(participant_ids)}")
logging.info(f"Participant IDs: {', '.join(participant_ids)}")

# path and file name temps ({0} is the participant ID without the 'sub-' prefix)
file_template = os.path.join(basepath, 'sub-{0}', 'func', 'task-stress', 'sub-{0}_task-stress_feature-taskStress1_taskcontrast-stress_stat-effect_statmap.nii.gz')

# collect 1st level contrast images; entries without a 'sub-' prefix or without a
# contrast image are skipped, but recorded so that the skip is visible in the log
image_files = []
skipped_ids = []
for pid in participant_ids:
    if 'sub-' not in pid:
        skipped_ids.append(pid)
        continue
    contrast_path = file_template.format(pid.replace('sub-', ''))
    if os.path.exists(contrast_path):
        image_files.append(contrast_path)
    else:
        skipped_ids.append(pid)

if skipped_ids:
    logging.warning(f"Skipped {len(skipped_ids)} entries without a first-level contrast image: {', '.join(skipped_ids)}")
if not image_files:
    raise ValueError("No valid image files found for the included participants.")
logging.info(f"Number of first-level contrast images found: {len(image_files)}")

# concat 1st level contrast images into one 4D image (one volume per participant)
all_images = concat_imgs(image_files)
logging.info("All images concatenated")

#--------------------
# creating brain mask
#--------------------
# study-specific minimum mask; its grid is the target grid of the combined mask
min_mask = load_img("/path/minimum_mask_no_outliers.nii.gz")
print("min mask dimensions:", min_mask.shape)

# MNI152 brain mask template, requested at resolution=2
MNI_mask = datasets.load_mni152_brain_mask(resolution=2)
print("MNI mask dimensions:", MNI_mask.shape)

# bring the MNI mask onto the grid of min_mask ((99, 117, 95) -> (97, 115, 97))
# using nearest-neighbour interpolation
resampled_MNI_mask = resample_to_img(MNI_mask, min_mask, interpolation='nearest')

# a voxel is kept only if it is > 0 in the minimum mask AND in the resampled MNI mask
min_mask_data = min_mask.get_fdata()
resampled_MNI_mask_data = resampled_MNI_mask.get_fdata()
brain_voxels = (min_mask_data > 0) & (resampled_MNI_mask_data > 0)

# wrap the boolean array in an image that shares min_mask's geometry
mask_img = new_img_like(min_mask, brain_voxels)
print("brain mask (mask_img) dimensions:", mask_img.shape)

logging.info("Brain mask created")

#-----------------------------------------------------
# Permutation test for test against 0 + FDR correction
#-----------------------------------------------------
# analysis settings, defined once so that testing, thresholding and reporting cannot drift apart
n_perm = 2000     # number of permutations
alpha = 0.05      # significance level applied to the corrected p-values
random_seed = 0   # fixed seed so that repeated runs give identical permutation results

# masking 4D data with NiftiMasker (restricts the data to the voxels inside mask_img)
nifti_masker = NiftiMasker(mask_img=mask_img)
masked_data = nifti_masker.fit_transform(all_images)

# prep design matrix for second-level analysis
# simple one-sample test, design matrix with a constant term for each participant;
# the number of rows is taken from the masked data so both always match
n_subjects = masked_data.shape[0]
design_matrix = np.ones((n_subjects, 1))

# run permuted OLS
results = permuted_ols(design_matrix, masked_data, n_perm=n_perm, two_sided_test=True,
                       random_state=random_seed, n_jobs=1, output_type='dict')
logging.info("Permutation test completed")
logging.info(f"Results type: {type(results)}, Results contents: {results.keys()}")

# extract t-values and log p-values
# NOTE(review): according to the nilearn documentation 'logp_max_t' holds p-values that are
# already family-wise-error corrected via the maximum statistic, so the FDR step below is
# applied on top of corrected p-values. Kept unchanged to preserve the pipeline - confirm
# that this is intended.
t_values = results['t']
log_p_values = results['logp_max_t']

# convert log p-values to p-values (values are stored as -log10(p))
p_values = 10 ** -log_p_values

# flatten p-values for FDR correction
p_values_flat = p_values.flatten()

# FDR correction (Benjamini-Hochberg); only the adjusted p-values are used further on
_, p_values_corrected_flat, _, _ = multipletests(p_values_flat, alpha=alpha, method='fdr_bh')

# reshape corrected p-values back to original shape
p_values_corrected = p_values_corrected_flat.reshape(p_values.shape)

# convert t-values and corrected p-values back to image space
# NOTE(review): voxels outside the mask presumably come back as 0 from inverse_transform,
# i.e. they read as p = 0 in the saved p-value image - keep in mind when viewing it.
t_values_img = nifti_masker.inverse_transform(t_values)
p_values_corrected_img = nifti_masker.inverse_transform(p_values_corrected)

# threshold the t-values based on corrected p-values (non-significant voxels are set to 0)
is_significant = p_values_corrected <= alpha
thresholded_t_values = t_values.copy()
thresholded_t_values[p_values_corrected > alpha] = 0
thresholded_t_values_img = nifti_masker.inverse_transform(thresholded_t_values)

# output for confirmation
print("T-values image dimensions:", t_values_img.shape)
print("Corrected p-values image dimensions:", p_values_corrected_img.shape)
print("Thresholded t-values image dimensions:", thresholded_t_values_img.shape)

# determine the absolute t-value threshold: smallest |t| among the significant voxels
significant_t_values = t_values[is_significant]
if significant_t_values.size > 0:
    threshold_t_value = np.min(np.abs(significant_t_values))
else:
    threshold_t_value = None

# report the t-value threshold through logging so that this key result also ends up in the log file
if threshold_t_value is not None:
    logging.info(f"The absolute t-value threshold for significance after FDR correction is: {threshold_t_value}")
else:
    logging.info("No significant t-values found after FDR correction")

# save images
t_values_img.to_filename('/path/stress_effect_tvals_vox.nii')
p_values_corrected_img.to_filename('/path/stress_effect_pvals_corrected_vox.nii')
thresholded_t_values_img.to_filename('/path/stress_effect_thresholded_tvals_vox.nii')

logging.info("T-values, corrected p-values, and thresholded t-values images saved")

#----------------------
# Site-wise task effect # not reported in the manuscript but images can be found on neurovault
#----------------------
# analysis settings, defined once outside the loop so every site is tested identically
n_perm = 2000     # number of permutations
alpha = 0.05      # significance level applied to the corrected p-values
random_seed = 0   # fixed seed so that repeated runs give identical permutation results

# Categorize participants by site
sites = {'1': 'Berlin', '2': 'Mainz', '3': 'Nijmegen', '4': 'TelAviv', '5': 'Warsaw'}
participants_by_site = {site: [] for site in sites.values()}

for pid in participant_ids:
    # the site code is read from a fixed position of the directory name; entries that are not
    # 'sub-...' directories or are too short to carry a code are skipped instead of raising
    # an IndexError
    if not pid.startswith('sub-') or len(pid) <= 5:
        continue
    site_code = pid[5]
    site_name = sites.get(site_code)
    if site_name:
        participants_by_site[site_name].append(pid)

# run the one-sample permutation test separately for every site
for site, pids in participants_by_site.items():
    logging.info(f"Number of participants at {site}: {len(pids)}")

    # load images for current site (participants without a contrast image are skipped)
    candidate_files = [file_template.format(pid.replace('sub-', '')) for pid in pids]
    image_files = [path for path in candidate_files if os.path.exists(path)]
    if not image_files:
        logging.info(f"No valid image files found for site {site}")
        continue
    if len(image_files) < 2:
        # a one-sample t-statistic is undefined for a single observation
        logging.warning(f"Only {len(image_files)} image found for site {site}; permutation test skipped")
        continue
    logging.info(f"Number of images used for site {site}: {len(image_files)}")

    all_images = concat_imgs(image_files)

    # masking data (the masker is re-fitted with the same mask for the current site's images)
    masked_data = nifti_masker.fit_transform(all_images)

    # Statistical analysis: design matrix with a constant term for each participant
    n_subjects = masked_data.shape[0]
    design_matrix = np.ones((n_subjects, 1))

    # run permuted OLS
    results = permuted_ols(design_matrix, masked_data, n_perm=n_perm, two_sided_test=True,
                           random_state=random_seed, n_jobs=1, output_type='dict')
    logging.info("Permutation test completed")
    logging.info(f"Results type: {type(results)}, Results contents: {results.keys()}")

    # extract t-vals, log p-vals
    # NOTE(review): according to the nilearn documentation 'logp_max_t' is already
    # family-wise-error corrected (maximum statistic); the FDR step below is applied on top
    # of it - confirm that this is intended.
    t_values = results['t']
    log_p_values = results['logp_max_t']

    # convert log p-values to p-values (values are stored as -log10(p))
    p_values = 10 ** -log_p_values

    # flatten p-values for FDR correction
    p_values_flat = p_values.flatten()

    # FDR correction (Benjamini-Hochberg); only the adjusted p-values are used further on
    _, p_values_corrected_flat, _, _ = multipletests(p_values_flat, alpha=alpha, method='fdr_bh')

    # reshape corrected p-values back to original shape
    p_values_corrected = p_values_corrected_flat.reshape(p_values.shape)

    # convert t-values and corrected p-values back to image space
    t_values_img = nifti_masker.inverse_transform(t_values)
    p_values_corrected_img = nifti_masker.inverse_transform(p_values_corrected)

    # threshold t-values based on corrected p-values (non-significant voxels are set to 0)
    thresholded_t_values = t_values.copy()
    thresholded_t_values[p_values_corrected > alpha] = 0
    thresholded_t_values_img = nifti_masker.inverse_transform(thresholded_t_values)

    # determine absolute t-val threshold: smallest |t| among the significant voxels
    significant_t_values = t_values[p_values_corrected <= alpha]
    if significant_t_values.size > 0:
        threshold_t_value = np.min(np.abs(significant_t_values))
    else:
        threshold_t_value = None

    # report t-value threshold through logging so that it also ends up in the log file
    if threshold_t_value is not None:
        logging.info(f"The absolute t-value threshold for significance after FDR correction for site {site} is: {threshold_t_value}")
    else:
        logging.info(f"No significant t-values found after FDR correction for site {site}")

    # Save images
    t_output_path = f'/path/stress_effect_tvals_vox_{site}.nii'
    t_values_img.to_filename(t_output_path)

    p_output_path = f'/path/stress_effect_pvals_corrected_vox_{site}.nii'
    p_values_corrected_img.to_filename(p_output_path)

    thresholded_t_output_path = f'/path/stress_effect_thresholded_tvals_vox_{site}.nii'
    thresholded_t_values_img.to_filename(thresholded_t_output_path)

    logging.info(f"T-values, corrected p-values, and thresholded t-values images saved for {site}")

logging.info("Script executed")