# **Demo Script 2**


This script is the testing portion of the trained-model analysis. Specifically, it loads the test fNIRS and fMRI data, as well as the model trained in Demo 1. The trained model is then applied to the test data and its performance is evaluated.


In [19]:
import scipy.io as sio
import pickle

import numpy as np
from scipy.stats import pearsonr
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from nltools.stats import phase_randomize
from statsmodels.stats.multitest import multipletests

## **Test the model on all of Run 2 Data**

In [20]:
## Test a model on each subject in Run 2

# Load fNIRS test data
fnirs_file2 = '../data/fnirs/clip2_noUncertain_zhbo_TxRxS.mat'
test_fNIRS = sio.loadmat(fnirs_file2)['zhbo_TxRxS'].transpose((2, 0, 1))  # reshape from TxRxS to SxTxR
# get shape
print(f'test_fNIRS shape is {test_fNIRS.shape}')

# Impute missing (NaN) fNIRS values: each NaN at a given (time, channel) is replaced by the
# mean of the non-missing SUBJECTS at that same (time, channel) -- i.e. the average is taken
# across subjects (axis 0 of SxTxR), not across channels. Only NaNs are replaced.
fnirs_subject_mean_1xTxR = np.nanmean(test_fNIRS, axis=0, keepdims=True)
test_fNIRS = np.where(np.isnan(test_fNIRS), fnirs_subject_mean_1xTxR, test_fNIRS)

# A (time, channel) missing for every subject has no mean to borrow and stays NaN; report it
n_still_missing = int(np.isnan(test_fNIRS).sum())
if n_still_missing:
    print(f'WARNING: {n_still_missing} fNIRS values are missing for all subjects and remain NaN')

# load fMRI test data
fmri_file2 = '../data/fmri/clip2_undenoised_withcartoon_fillmean_TxRxS.mat'
test_fMRI = sio.loadmat(fmri_file2)['bold_TxRxS'].transpose((2, 0, 1))  # reshape from TxRxS to SxTxR
test_fMRI_mean = np.mean(test_fMRI, axis=0)  # already imputed, so we use np.mean instead of np.nanmean
# get shape
print(f'test_fMRI shape is {test_fMRI.shape}')

# Load the model trained in Demo 1: (fNIRS PCA, fMRI PCA, linear regression between PC spaces).
# NOTE: pickle.load can execute arbitrary code -- only load model files you created or trust.
file_path = '../models/sherlock_run1model.pickle'
with open(file_path, 'rb') as f:
    X_pca, y_pca, lr_reg = pickle.load(f)

# get pc loadings for X and y
X_pca_loadings = X_pca.components_
y_pca_loadings = y_pca.components_

print("Model loaded...")


test_fNIRS shape is (29, 1030, 20)
test_fMRI shape is (17, 1030, 122)
Model loaded...


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [21]:
# Run every test subject's fNIRS through the trained pipeline
# (fNIRS PCA -> linear regression -> inverse fMRI PCA), collect the predicted
# fMRI time courses, and average them into a group-level prediction.
model_performance = []
model_predicted_tc = []
all_predicted_tc = []

for subj, curr_test_fNIRS in enumerate(test_fNIRS):
    # project this subject's TxR fNIRS onto the training fNIRS components
    fNIRS_test_pc = X_pca.transform(curr_test_fNIRS)
    # map fNIRS component scores to fMRI component scores
    model_predicted_pc = lr_reg.predict(X=fNIRS_test_pc)
    # reconstruct region-level fMRI time courses from the predicted components
    model_predicted_tc = y_pca.inverse_transform(model_predicted_pc)
    all_predicted_tc.append(model_predicted_tc)

all_predicted_tc = np.stack(all_predicted_tc, axis=0)  # SxTxR
mean_predicted_tc = all_predicted_tc.mean(axis=0)  # TxR

In [22]:
# Model performance: Pearson r between the group-mean predicted and group-mean observed
# fMRI time course, computed separately for every brain region (column).
n_regions = test_fMRI_mean.shape[1]
r_values = np.array(
    [pearsonr(mean_predicted_tc[:, region], test_fMRI_mean[:, region])[0] for region in range(n_regions)],
    dtype=float,
)

# summarize across regions with the median
median_model_performance = np.median(r_values)

print(median_model_performance)

0.2092278266068456


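## **Permutation Test**

For each ROI, the observed group-mean fMRI time course is phase-randomized `nIter` times, and each surrogate is correlated with the predicted group-mean time course. This yields a null distribution of correlations. The one-sided p-value is $p = \frac{\#\{r_{null} \ge r_{obs}\} + 1}{N + 1}$, where $N$ is the number of surrogates.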

In [23]:
# set number of iterations
nIter = 1000


# initialize var for storing null distrib and pval
null_grp_grp_corr_RxN = []
grp_grp_corr_pval_Rx1 = []


# run permutation
true_RxT = test_fMRI_mean.T
pred_RxT = mean_predicted_tc.T
for roi in range(true_RxT.shape[0]):

    print(f'ROI {roi+1}: ')

    # get TC for current roi
    true_T = true_RxT[roi]   # 1d
    m1_pred_T = pred_RxT[roi]   # 1d

    # initialize var for storing null distrib of curr roi
    curr_null_grp_grp_corr_N = []

    # run n iters
    for i in range(nIter):

        if i % 100 == 0:
            print(f'    iter {i+1}')

        # phase rand true TC
        permuted_true_T = phase_randomize(true_T, random_state=None)

        # T [corr] T => scalar
        m1_null_grp_grp_corr = pearsonr(m1_pred_T, permuted_true_T)[0]
        curr_null_grp_grp_corr_N.append(m1_null_grp_grp_corr)

        
    # append curr roi null distrib to 2d array RxN and pval to 1d array Rx1
    null_grp_grp_corr_RxN.append(curr_null_grp_grp_corr_N)
    grp_grp_corr_pval_Rx1.append(
        (np.sum(curr_null_grp_grp_corr_N >= r_values[roi]) + 1) / (len(curr_null_grp_grp_corr_N) + 1)
    )


# convert list to np array
null_grp_grp_corr_RxN = np.array(null_grp_grp_corr_RxN)
grp_grp_corr_pval_Rx1 = np.array(grp_grp_corr_pval_Rx1)

ROI 1: 
    iter 1
    iter 101
    iter 201
    iter 301
    iter 401
    iter 501
    iter 601
    iter 701
    iter 801
    iter 901
ROI 2: 
    iter 1
    iter 101
    iter 201
    iter 301
    iter 401
    iter 501
    iter 601
    iter 701
    iter 801
    iter 901
ROI 3: 
    iter 1
    iter 101
    iter 201
    iter 301
    iter 401
    iter 501
    iter 601
    iter 701
    iter 801
    iter 901
ROI 4: 
    iter 1
    iter 101
    iter 201
    iter 301
    iter 401
    iter 501
    iter 601
    iter 701
    iter 801
    iter 901
ROI 5: 
    iter 1
    iter 101
    iter 201
    iter 301
    iter 401
    iter 501
    iter 601
    iter 701
    iter 801
    iter 901
ROI 6: 
    iter 1
    iter 101
    iter 201
    iter 301
    iter 401
    iter 501
    iter 601
    iter 701
    iter 801
    iter 901
ROI 7: 
    iter 1
    iter 101
    iter 201
    iter 301
    iter 401
    iter 501
    iter 601
    iter 701
    iter 801
    iter 901
ROI 8: 
    iter 1
    iter 101
    iter 201
   

In [24]:
# FDR correction (Benjamini-Hochberg) across ROIs
FDR_ALPHA = 0.05

# get significance mask for rois, corrected pval
sigmask, pval_corrected, _, _ = multipletests(grp_grp_corr_pval_Rx1, alpha=FDR_ALPHA, method='fdr_bh')

# get a list of significant rois (1-based), derived from the same reject mask used below so
# the reported ROI list and the masked correlations can never disagree at the alpha boundary
sigroi = np.flatnonzero(sigmask) + 1
print(f'A total of {len(sigroi)} significant ROIs.\nIndices: {sigroi}')

# mask corr array: keep r for significant ROIs, set every non-sig roi's corr to 0
grp_grp_corr_pmasked_Rx1 = np.where(sigmask, np.ravel(r_values), 0.0)

A total of 93 significant ROIs.
Indices: [  4   5   6  12  13  14  15  16  17  18  20  21  22  23  24  25  26  27
  28  29  30  31  32  33  34  35  36  37  38  39  40  41  42  43  45  46
  47  48  49  50  51  52  53  54  57  59  60  61  62  68  69  70  71  72
  73  74  76  78  79  80  81  82  83  84  86  87  88  89  90  91  92  93
  95  96  97  98  99 100 102 103 104 105 106 107 108 109 110 111 112 114
 116 120 122]
