# 4. Permutation Test Results

This notebook organizes the permutation results:
1. Average lassopcr searchlight results over cv iterations
2. Calculate the permutation distribution
3. Run permutation tests: compare lassopcr-sl results against the permutation distribution
4. Save the mask for the significant voxels.

*Yiyu Wang 2022 Jan*


In [None]:
import glob
import os
import numpy as np
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt

# stats
from scipy import linalg, ndimage, stats
from scipy.stats import norm

# nifti handling
from nilearn.input_data import NiftiMasker
import nilearn.masking as masking
from nilearn.masking import apply_mask

from nilearn import image
from nilearn.image import new_img_like, load_img, get_data, concat_imgs, mean_img, math_img, index_img,threshold_img
from nilearn.reporting import get_clusters_table
from nilearn.glm import threshold_stats_img

# plotting modules
from nilearn import plotting
from nilearn.plotting import plot_stat_map, plot_img, show

# nilearn mask:
from nilearn.datasets import load_mni152_gm_mask,load_mni152_wm_mask

import warnings
warnings.filterwarnings('ignore')

In [None]:
# Background anatomical image passed as `bg_img` to the quick-look plots.
bg_img = 'masks/FSL_MNI152_T1_3mm_brain.nii.gz'

# load behavioral data
behavdata_dir = 'BehavData/'
zratings_files = glob.glob(behavdata_dir + 'AffVids_novel_interpolated_rating_zscored.csv')
if not zratings_files:
    # Fail with the missing path instead of an opaque IndexError on [0].
    raise FileNotFoundError(
        'No z-scored ratings file found in ' + behavdata_dir)
zratings = pd.read_csv(zratings_files[0], index_col=0).reset_index()

# subjects information
subjects = list(range(4, 20)) + [23, 25, 26, 28, 29]
# Zero-padded two-digit IDs derived from `subjects` so the two lists cannot drift apart.
subjects_str = [f'{subject:02d}' for subject in subjects]
Nsub = len(subjects)
print("subjects in this analysis:")
print(subjects_str)
print(f"**** n = {Nsub} *****" )

# analysis parameters
k_fold = 3
my_radius = 15
cluster_thre = 1
my_slices = [-24, -12, 0, 12, 21, 30, 42, 54]  # cut_coords for the z-slice plots

# stats alpha
alpha = 0.01
z_thre = stats.norm.ppf(1 - alpha / 2)  # two-sided z cut-off for `alpha`
print(z_thre)

# upper colour limit (vmax) for the stat-map plots
VMAX = 0.45

# input / output directories
res_dir = 'results/searchlight_lassopcr/'
avg_dir = 'results/searchlight_wholebrain/'
permutation_dir = 'results/permutation/'
permutation_test_res_dir = 'results/permutation_test/'

# makedirs also creates the missing parent ('results/') and is a no-op when
# the directory already exists.
os.makedirs(permutation_test_res_dir, exist_ok=True)
os.makedirs(avg_dir, exist_ok=True)
        
        

In [None]:
# Grey-matter mask from nilearn's MNI152 template (resolution=3).
gm_mask = load_mni152_gm_mask(
    resolution=3,
    threshold=0.2,
    n_iter=2,
)
# Analysis mask: voxels of the grey-matter image whose value exceeds 0.2.
mask = math_img('img > 0.2', img=gm_mask)
# Interactive viewer as a visual sanity check of the mask.
plotting.view_img(mask, threshold=None)


In [None]:
def compare_against_permutation(permutation_img, lassopcr_img, mask, n_perm, prob_thr = 0.05, cluster_thre=15):
    """Test observed searchlight scores against a permutation distribution.

    For every voxel inside ``mask`` the p-value is the number of volumes in
    ``permutation_img`` whose score is greater than or equal to the observed
    score, divided by ``n_perm``.  Voxels with ``p < prob_thr`` keep their
    observed score in the uncorrected map; all other voxels are set to 0.

    Parameters
    ----------
    permutation_img : Niimg-like
        Image holding one volume per permutation.
    lassopcr_img : Niimg-like
        Observed score map; either 3D, or 4D with exactly one volume.
    mask : Niimg-like
        Mask restricting the voxels that are tested.
    n_perm : int
        Denominator used to turn exceedance counts into p-values.
    prob_thr : float, optional
        p-value threshold for the uncorrected map; also forwarded to
        ``family_wise_error_permutation``.
    cluster_thre : int, optional
        ``cluster_threshold`` passed to ``threshold_stats_img`` for the
        uncorrected map.

    Returns
    -------
    unc_significance_map_img : Niimg
        Cluster-thresholded map of observed scores at voxels with
        ``p < prob_thr``.
    fwe_threshold : float
        Score threshold returned by ``family_wise_error_permutation``.
    fwe_corr_significance_map_img : Niimg
        3D image of ``lassopcr_img`` thresholded at ``fwe_threshold``.

    Raises
    ------
    Exception
        If ``lassopcr_img`` is 4D with more than one volume, or is neither
        3D nor 4D.
    """
    # Validate the observed image first so a malformed input fails before the
    # (expensive) masking and thresholding work.
    n_dims = len(np.shape(lassopcr_img))
    if n_dims == 4:
        # if has a fourth dimension, the fourth dimension should be 1
        if np.shape(lassopcr_img)[3] != 1:
            raise Exception("Check LassoPCR image 4th dimension!!")
    elif n_dims != 3:
        raise Exception('ERROR in lassopcr_img dimension!')

    # load mask for all the functional voxels
    masker = NiftiMasker(mask_img=mask, mask_strategy='epi',standardize=False)

    # extract data from mask; atleast_2d keeps the (volumes, voxels) layout
    # even if the masker hands back a 1-D array for a single 3D volume
    masked_permutation_scores = np.atleast_2d(masker.fit_transform(permutation_img))
    masked_lassopcr_scores = np.atleast_2d(masker.transform(lassopcr_img))

    # pval map: per-voxel fraction of permutation scores >= the observed score
    # (the single observed row broadcasts against every permutation row)
    n_func_voxel = np.shape(masked_permutation_scores)[1]
    exceedance_counts = np.sum(masked_permutation_scores >= masked_lassopcr_scores, axis=0)
    pvals_map = exceedance_counts / n_perm
    significant = pvals_map < prob_thr

    # extract r-values for significant voxels:
    masked_significant_scores = np.zeros(n_func_voxel)
    masked_significant_scores[significant] = masked_lassopcr_scores[0][significant]
    significance_map_img = masker.inverse_transform(masked_significant_scores)

    # height threshold 0 with no height control: only the cluster-extent
    # threshold is applied here
    significance_map_cluster_thr = threshold_stats_img(significance_map_img, mask_img= mask, threshold= 0,cluster_threshold=cluster_thre, height_control=None)
    unc_significance_map_img = new_img_like(mask, get_data(significance_map_cluster_thr[0]), affine=mask.affine)

    # fwe p correction
    fwe_threshold = family_wise_error_permutation(permutation_img, mask, prob_thr=prob_thr)

    # NOTE(review): threshold_img may threshold on absolute value by default -
    # confirm that strongly negative scores are meant to survive this step.
    fwe_thresholded_img = threshold_img(lassopcr_img, threshold = fwe_threshold, copy = True)
    if n_dims == 4:
        # drop the singleton 4th dimension so a 3D image is always returned
        fwe_corr_significance_map_img = index_img(fwe_thresholded_img, 0)
    else:
        fwe_corr_significance_map_img = fwe_thresholded_img

    return unc_significance_map_img, fwe_threshold, fwe_corr_significance_map_img


In [None]:
def family_wise_error_permutation(permutation_img, mask,prob_thr):
    """Return the score threshold used for the family-wise error correction.

    The permutation image is masked with ``mask`` and the threshold is the
    ``(1 - prob_thr) * 100``-th percentile of all masked values (every voxel
    of every volume pooled together).

    NOTE(review): the percentile is taken over the pooled values, not over the
    per-permutation maximum across voxels - confirm this is the intended
    family-wise correction.
    """
    null_scores = NiftiMasker(
        mask_img=mask, mask_strategy='epi', standardize=False
    ).fit_transform(permutation_img)
    upper_percentile = (1-prob_thr)*100
    return np.percentile(null_scores, upper_percentile)

def fdr_permutation(pvals_map, masked_lassopcr_scores, prob_thr=0.05):
    """Return a score threshold from a step-up comparison of p-values.

    ``pvals_map`` is compared element-wise against ``prob_thr`` times evenly
    spaced fractions in (0, 1).  When at least one p-value falls below its
    cut-off, the result is the ascending-sorted score at the last such
    position minus 1e-12; otherwise the result is 1.

    ``masked_lassopcr_scores`` must be 2-D: its second axis length sets the
    number of cut-offs.

    NOTE(review): ``pvals_map`` is used in its given order while the scores
    are sorted ascending, so positions in the two arrays do not refer to the
    same voxel - confirm the intended ordering before relying on this value.
    """
    ascending_scores = np.sort(masked_lassopcr_scores, axis=None)
    n_voxels = np.shape(masked_lassopcr_scores)[1]

    # midpoints of n_voxels equal-width bins on (0, 1), scaled by the alpha level
    half_step = .5 / n_voxels
    cutoffs = prob_thr * np.linspace(half_step, 1 - half_step, n_voxels)
    below_cutoff = pvals_map < cutoffs

    if not below_cutoff.any():
        return 1
    return np.ravel(ascending_scores)[below_cutoff][-1] - 1.e-12

In [None]:
def get_lasso(train_cat, test_cat, score_name):
    """Load the cv-averaged searchlight score image for one train/test pair.

    Looks in ``avg_dir`` for the ``kfold3`` average file of ``score_name`` and
    returns whatever matches, concatenated into a single image.
    """
    matching_files = glob.glob(
        avg_dir + f'kfold3_searchlight_{score_name}_train_{train_cat}_test_{test_cat}_avg.nii.gz'
    )
    return concat_imgs(matching_files)

def get_permutation(train_cat, test_cat, score_name, save=True):
    """Build the permutation distribution image for one train/test pair.

    The permutation maps found in ``permutation_dir`` are concatenated, and the
    observed cv-averaged map from ``avg_dir`` is added as the final volume(s)
    so the observed score is part of its own null distribution.

    Parameters
    ----------
    train_cat, test_cat : str
        Training / testing category names used in the file names.
    score_name : str
        Score type used in the file names (e.g. ``'pearsonr'``); applied to
        both the permutation files and the observed map.
    save : bool, optional
        When True, write the concatenated image to
        ``permutation_test_res_dir``.

    Returns
    -------
    permutation_img : Niimg
        4D image of permutation volumes followed by the observed volume.
    n_perm : int
        Number of volumes in ``permutation_img``.

    Raises
    ------
    FileNotFoundError
        If no permutation file or no observed average file is found.
    """
    # Observed map: use the same score type as the permutation maps.
    lassopcr_file_list = glob.glob(avg_dir + f'kfold3_searchlight_{score_name}_train_{train_cat}_test_{test_cat}_avg.nii.gz')
    permut_file_path = permutation_dir + f'cv*_kfold3_searchlight_{score_name}_train_{train_cat}_test_{test_cat}_cv_*.nii.gz'
    # sorted: glob order is arbitrary; keep the saved volume order reproducible
    permutation_file_list = sorted(glob.glob(permut_file_path))

    if not permutation_file_list:
        raise FileNotFoundError(f'No permutation files match {permut_file_path}')
    if not lassopcr_file_list:
        raise FileNotFoundError(
            f'No averaged {score_name} image found in {avg_dir} '
            f'for train {train_cat} / test {test_cat}')

    # include scores in permutation distribution (extend keeps the list flat)
    permutation_file_list.extend(lassopcr_file_list)
    permutation_img = concat_imgs(permutation_file_list)

    # Count volumes rather than files so the p-value denominator always
    # matches the number of scores actually compared.
    n_perm = permutation_img.shape[-1]
    print("number of permutations: ",n_perm)

    if save:
        nib.save(permutation_img, permutation_test_res_dir + f'train_{train_cat}_test_{test_cat}_permutation_distribution_{score_name}.nii.gz')

    return permutation_img, n_perm
    

In [None]:
# Switch for the "calculate the average" cells: when True they recompute and
# save the cv-averaged images.
average_cv = False
# Switch for running the permutation step.
# NOTE(review): `calculate` is not referenced by any cell visible here -
# confirm whether it is still needed.
calculate = True

# Situation General Model:

In [None]:
# Average the per-fold searchlight maps of the Situation General model.

if average_cv:
    train_cat_list = ['Situation_General']
    test_cat_list = ['Situation_General','Heights','Social','Spiders']
    score_list = ['rmse','pearsonr']

    # plot settings shared by every quick-look figure in this cell
    avg_plot_kwargs = dict(display_mode='z', cut_coords=my_slices, threshold=0.0,
                           bg_img=bg_img, colorbar=True)

    for score_name in score_list:
        for train_cat in train_cat_list:
            for test_cat in test_cat_list:
                # every cv-fold result file for this score / train / test combination
                file_list = glob.glob(
                    res_dir + f'cv*_kfold3_searchlight_{score_name}_train_{train_cat}_test_{test_cat}*.nii.gz'
                )
                cv_avg_img = mean_img(file_list)

                nib.save(
                    cv_avg_img,
                    avg_dir + f'kfold3_searchlight_{score_name}_train_{train_cat}_test_{test_cat}_avg.nii.gz',
                )

                # a quick visual check for the saved files
                plotting.plot_stat_map(
                    cv_avg_img,
                    title= f'cv average image: SG-{test_cat}',
                    **avg_plot_kwargs,
                )



In [None]:
# Permutation test for the Situation General model, one test per test category.
train_cat_list = ['Situation_General']
test_cat_list = ['Situation_General','Heights','Social','Spiders']
score_list = ['pearsonr']

prob_thr = 0.05

# plot settings shared by the uncorrected and corrected quick-look figures
perm_plot_kwargs = dict(display_mode='z', cut_coords=my_slices, threshold=0.0,
                        bg_img=bg_img, colorbar=True, vmax=VMAX)

for score_name in score_list:
    print(score_name)
    for train_cat in train_cat_list:
        for test_cat in test_cat_list:
            print("train: ", train_cat)
            print("test: ", test_cat)

            # observed (cv-averaged) score image
            lassopcr_img = get_lasso(train_cat, test_cat, score_name)

            # load permutation distributions:
            permutation_img, n_perm = get_permutation(train_cat, test_cat, score_name)
            print("number of permutations: ",n_perm)

            # run permutation test:
            test_outputs = compare_against_permutation(
                permutation_img, lassopcr_img, mask, n_perm, prob_thr=prob_thr
            )
            unc_significance_map_img, fwe_threshold, fwe_corr_significance_map_img = test_outputs
            print("family wise correction threshold: ", fwe_threshold)

            nib.save(
                fwe_corr_significance_map_img,
                permutation_test_res_dir + f'train_{train_cat}_test_{test_cat}_significant_{score_name}_fwe.nii.gz',
            )

            # quick visual check to make sure the image is not empty
            plotting.plot_stat_map(
                unc_significance_map_img,
                title= f'uncorrected {score_name}: {train_cat} Searchlight lassopcr {test_cat} permutation test',
                **perm_plot_kwargs,
            )
            plotting.plot_stat_map(
                fwe_corr_significance_map_img,
                title= f'fwe_corrected {score_name}: {train_cat} Searchlight lassopcr {test_cat} permutation test',
                **perm_plot_kwargs,
            )
            
            
            
            

## Situation Dependent Model: 

In [None]:
# Average the per-fold searchlight maps of the Situation Dependent models
# (each model is tested on the category it was trained on).

if average_cv:
    train_cat_list = ['Heights','Social','Spiders']
    score_list = ['rmse','pearsonr']

    # plot settings shared by every quick-look figure in this cell
    avg_plot_kwargs = dict(display_mode='z', cut_coords=my_slices, threshold=0.0,
                           bg_img=bg_img, colorbar=True)

    for score_name in score_list:
        for train_cat in train_cat_list:
            test_cat = train_cat

            # every cv-fold result file for this score / category combination
            file_list = glob.glob(
                res_dir + f'cv*_kfold3_searchlight_{score_name}_train_{train_cat}_test_{test_cat}*.nii.gz'
            )
            cv_avg_img = mean_img(file_list)

            nib.save(
                cv_avg_img,
                avg_dir + f'kfold3_searchlight_{score_name}_train_{train_cat}_test_{test_cat}_avg.nii.gz',
            )

            # a quick visual check for the saved files
            plotting.plot_stat_map(
                cv_avg_img,
                title= f'cv average image: SS-{test_cat}',
                **avg_plot_kwargs,
            )



In [None]:
# run permutation:
# Situation Dependent models: each model is tested on its own training category.
train_cat_list = ['Heights','Social','Spiders']

score_list = ['pearsonr']

prob_thr = 0.05

for score_name in score_list:
    print(score_name)
    for train_cat in train_cat_list:
        test_cat = train_cat
        print("train: ", train_cat)
        print("test: ", test_cat)

        # observed (cv-averaged) score image
        lassopcr_img = get_lasso(train_cat, test_cat, score_name)

        # load permutation distributions:
        permutation_img, n_perm = get_permutation(train_cat, test_cat, score_name)
        print("number of permutations: ",n_perm)

        # run permutation test:
        unc_significance_map_img, fwe_threshold, fwe_corr_significance_map_img = compare_against_permutation(permutation_img,lassopcr_img,mask,n_perm, prob_thr=prob_thr)
        print("family wise correction threshold: ", fwe_threshold)

        nib.save(fwe_corr_significance_map_img, permutation_test_res_dir + f'train_{train_cat}_test_{test_cat}_significant_{score_name}_fwe.nii.gz')

        # quick visual check to make sure the image is not empty
        plotting.plot_stat_map(unc_significance_map_img, display_mode='z', cut_coords=my_slices,threshold=0.0,bg_img=bg_img, title= f'uncorrected {score_name}: {train_cat} Searchlight lassopcr {test_cat} permutation test', colorbar=True, vmax = VMAX)
        plotting.plot_stat_map(fwe_corr_significance_map_img, display_mode='z', cut_coords=my_slices,threshold=0.0,bg_img=bg_img, title= f'fwe_corrected {score_name}: {train_cat} Searchlight lassopcr {test_cat} permutation test', colorbar=True, vmax = VMAX)
            
            
            
            