### Spin Test 
This script projects data to surface space and performs a spin test

#### Import Packages

In [74]:
import numpy as np
import pandas as pd
import netneurotools.stats as stats
from neuromaps.datasets import fetch_atlas
from nilearn.datasets import fetch_atlas_schaefer_2018
from netneurotools.datasets import fetch_schaefer2018
import neuromaps
import nibabel as nib
from neuromaps.parcellate import Parcellater
from neuromaps import transforms
from neuromaps import nulls
from scipy.stats import pearsonr, spearmanr
from nilearn.plotting import plot_surf
from nilearn.surface import load_surf_data
from nilearn import surface, plotting, datasets
from neuromaps import datasets, images, nulls, resampling, stats

#### Set Paths and Variables

In [83]:
dataset = 'longglucest_outputmeasures2'
atlas = 'Schaefer2018_1000Parcels_17Networks'
nmaps = ["NMDA", "mGluR5", "GABA"]
maps = ["cest", "NMDA", "mGluR5", "GABA"]
group = "HC"
n_permute = 10

# Set paths
inpath = "/Users/pecsok/Desktop/ImageData/PMACS_remote/data/nmaps/analyses/" + atlas
outpath = "/Users/pecsok/Desktop/ImageData/PMACS_remote/data/nmaps/analyses/" + atlas

# Read in data
receptor_df = pd.read_csv("/Users/pecsok/projects/Neuromaps/pecsok_pfns/neuromaps/results/receptor_data_scale1000_17.csv", sep=',')
means= pd.read_csv(inpath + '/means_subjectnmaps_' + dataset + '_' + atlas + '.csv', sep=',')
# Read in surface atlases
schaefer_cifti_path = '/Users/pecsok/neuromaps-data/atlases/fsLR/Schaefer2018_900Parcels_17Networks_order.dlabel.nii'
annot_rh_schaefer = '/Users/pecsok/neuromaps-data/atlases/fsaverage/lh.Schaefer2018_1000Parcels_17Networks_order.annot'
annot_lh_schaefer = '/Users/pecsok/neuromaps-data/atlases/fsaverage/rh.Schaefer2018_1000Parcels_17Networks_order.annot'
# Fetch atlas labels
schaefer = fetch_atlas_schaefer_2018(n_rois=1000, yeo_networks=17)
labels = schaefer['labels']  # These are the MNI coordinates of the parcels
labels = [label.decode('utf-8') for label in labels]

# Reformat data 
means["Parcel"] = means["Parcel"].str.replace(' NZMean', '', regex=False)
means = means[means['hstatus']== group]
means.set_index('Parcel', inplace = True)
receptor_df = receptor_df.reset_index(drop=True)
receptor_df.rename(columns={'GABAa': 'GABA'}, inplace=True)
receptor_df.index = pd.RangeIndex(start=1, stop=len(receptor_df) + 1, step=1)
print(means)
print(receptor_df)

                                 hstatus  CEST_avg      NMDA    mGluR5  \
Parcel                                                                   
17Networks_RH_ContA_Cingm_1           HC -0.239851 -0.515943 -0.290452   
17Networks_RH_ContA_Cingm_2           HC -0.255432 -0.662924 -0.352973   
17Networks_RH_ContB_PFCmp_1           HC -0.102085  0.665146  0.898630   
17Networks_RH_ContB_PFCmp_2           HC -0.779797  0.265863  0.639889   
17Networks_RH_ContB_PFCmp_3           HC -0.502373  0.326597  0.668602   
...                                  ...       ...       ...       ...   
17Networks_RH_VisPeri_ExStrSup_3      HC  0.613937  0.823610  1.587829   
17Networks_RH_VisPeri_ExStrSup_4      HC  0.733561  0.576520  0.683296   
17Networks_RH_VisPeri_ExStrSup_7      HC  0.403760  0.764860  1.305738   
17Networks_RH_VisPeri_StriCal_3       HC  1.342156 -0.013643  0.141399   
17Networks_RH_VisPeri_StriCal_5       HC  1.024538  0.177231  0.461898   

                                     

### Functions

In [117]:
def scramble(df, spin_df):
    """
    Reorders the rows of the input dataframe `df` based on the indices provided in `spin_df`.
    Each column of `spin_df` represents a permutation of row indices.
    Args:
    df (pd.DataFrame): The original dataframe with data.
    spin_df (pd.DataFrame): DataFrame where each column contains scrambled row indices (1-based).
    Returns:
    null_dfs (list): List of DataFrames with rows reordered according to scrambled indices.
    """
    null_dfs = []  # List to store the permuted DataFrames
    
    # Iterate over each permutation (column) in spin_df
    for col in spin_df.columns:   
        scrambled_indices = spin_df[col].values
        null_df = df.iloc[scrambled_indices] #.reset_index(drop=True)
        null_df.loc[:, 'Parcel'] = labels
        null_df.set_index('Parcel', inplace = True)
        # Append the permuted DataFrame to the list
        null_dfs.append(null_df)
    
    return null_dfs
###################

def calculate_null_corrs(null_dfs, nmaps):
    """
    Calculates correlation values for each null df and stores in output df
    Args:
    null_dfs (list): permutation matrices with new indices
    spin_df (pd.DataFrame): DataFrame where each column contains scrambled row indices (1-based).
    Returns:
    null_dfs (list): List of DataFrames with rows reordered according to scrambled indices.
    """
    # Initialize an empty dataframe to store the results
    columns = ['shift_iteration']
    for col in nmaps:
        columns.append(f'{col}_r')
        columns.append(f'{col}_p')
    
    null_stats_df = pd.DataFrame(columns=columns)
    
    # Loop through each null dataframe (each shift iteration)
    for i, null_df in enumerate(null_dfs):
        result_row = {'permutation': i + 1}
        
        # For each nmaps column, calculate r and p values with respect to CESTavg
        for col in nmaps:
            r_value, p_value = pearsonr(null_df['CEST_avg'], null_df[col])
            result_row[f'{col}_r'] = r_value
            result_row[f'{col}_p'] = p_value
        
        # Convert result_row to a DataFrame and concatenate it
        result_row_df = pd.DataFrame([result_row])
        null_stats_df = pd.concat([null_stats_df, result_row_df], ignore_index=True)
    
    return null_stats_df
########################

def calculate_real_correlations(df, nmaps):
    """
    This function calculates pearson correlations between CEST_avg and nmap columns.
    Args:
    df (pd.DataFrame): df with cest and nmap parcelwise data
    nmaps (list): names of nmaps
    Returns: 
    Correlation values and p values for parametric associations.
    """
    real_corrs = {}
    for col in nmaps:
        r_value, _ = pearsonr(df['CEST_avg'], df[col])
        real_corrs[col] = r_value
    return real_corrs
#####################

def calculate_empirical_pvalues(real_corrs, null_stats_df, nmaps):
    """
    This function compares real correlations to the null distribution to calculate empirical p-values
    Args:
    real_corrs (pd.DataFrame): df with r and p values for cest-nmap associations
    null_stats_df (pd.DataFrame): df with r and p value for cest-nmap associations per iteration
    nmaps (list): names of nmaps
    Returns: 
    Empirical p values based on null model
    """    
    p_values = {}
    
    # Loop through each column in nmaps
    for col in nmaps:
        # Extract the null r-values for this nmaps column
        null_r_values = null_stats_df[f'{col}_r']
        
        # Calculate the empirical p-value: proportion of null r-values more extreme than the real r-value
        real_r = real_corrs[col]
        
        # Two-tailed test: count both large positive and negative extremes
        extreme_count = np.sum(np.abs(null_r_values) >= np.abs(real_r))
        
        # Compute the empirical p-value
        p_empirical = extreme_count / len(null_r_values)
        p_values[col] = round(p_empirical, 15)
    
    return p_values


### Generate permutations

In [120]:
#### For some reason, shape = (999,n) for S1000 parcellation. All others have correct number of rows.

# Convert cifti to gifti
schaefer_gifti = images.dlabel_to_gifti(schaefer_cifti_path) 
# Spin indices
spin_index = nulls.vazquez_rodriguez(data=None, atlas='fsLR', density='32k',
                                     n_perm=10, seed=1234, parcellation=schaefer_gifti)
print(spin_index.shape) 

pixdim[1,2,3] should be non-zero; setting 0 dims to 1


(900, 10)


In [122]:
# Convert fsavg annotations to gifti
schaefer_fsgifti = images.annot_to_gifti((annot_lh_schaefer, annot_rh_schaefer))
# Spin indices
spin_index = nulls.vazquez_rodriguez(data=None, atlas='fsaverage', density='164k', 
                                     n_perm=10000, seed=1234, parcellation=schaefer_fsgifti)
# Reformat spun indices for permutation test
spin_df = pd.DataFrame(spin_index)
spin_df.index = pd.RangeIndex(start=1, stop=len(spin_df) + 1, step=1)
print(spin_df)

      0     1     2     3     4     5     6     7     8     9     ...  9990  \
1      309   145   495   167   460   408   209   211   306   323  ...   408   
2      304   146   496   493   459   406    99   210   307   356  ...   407   
3      385   144   491    22   461   407   189   206   310   320  ...   432   
4      385   147   496   165   459   407   207   191   310   323  ...   409   
5      446   323   167   365   263   395   256    98   318    13  ...    56   
...    ...   ...   ...   ...   ...   ...   ...   ...   ...   ...  ...   ...   
996    612   785   799   904   678   592   536   886   976   874  ...   666   
997    615   913   799   906   673   584   550   884   975   832  ...   667   
998    567   918   980   736   670   577   548   756   557   831  ...   805   
999    613   912   799   863   668   578   547   946   858   768  ...   809   
1000   685   916   777   735   667   571   545   944   562   743  ...   840   

      9991  9992  9993  9994  9995  9996  9997  999

In [125]:
pd.options.mode.chained_assignment = None 

In [123]:
# Generate list of null dfs
null_dfs = scramble(receptor_df, spin_df)
print(len(null_dfs))

10000


In [124]:
pd.options.mode.chained_assignment = 'warn'

In [114]:
# Keep only parcels in slice
means_cest = means[['hstatus', 'CEST_avg']]
merged_dfs = []
#print(means_cest)
for null_df in null_dfs:
    merged_df = means_cest.merge(null_df, on='Parcel', how='inner')
    merged_dfs.append(merged_df)
print(merged_dfs[0])

                                 hstatus  CEST_avg      NMDA    mGluR5  \
Parcel                                                                   
17Networks_RH_ContA_Cingm_1           HC -0.239851  0.849870 -0.210651   
17Networks_RH_ContA_Cingm_2           HC -0.255432 -1.850323 -0.732697   
17Networks_RH_ContB_PFCmp_1           HC -0.102085 -0.672774  0.373764   
17Networks_RH_ContB_PFCmp_2           HC -0.779797 -1.878063 -0.319689   
17Networks_RH_ContB_PFCmp_3           HC -0.502373 -0.857448  0.149048   
...                                  ...       ...       ...       ...   
17Networks_RH_VisPeri_ExStrSup_3      HC  0.613937 -0.369844  0.401474   
17Networks_RH_VisPeri_ExStrSup_4      HC  0.733561 -0.275643 -0.445693   
17Networks_RH_VisPeri_ExStrSup_7      HC  0.403760  1.614967  1.686016   
17Networks_RH_VisPeri_StriCal_3       HC  1.342156  0.367203  0.077556   
17Networks_RH_VisPeri_StriCal_5       HC  1.024538  1.029521  0.725854   

                                     

In [115]:
# Calculate stats of null distributions
null_stats_df = calculate_null_corrs(merged_dfs, nmaps)
print(null_stats_df)

  null_stats_df = pd.concat([null_stats_df, result_row_df], ignore_index=True)


    shift_iteration    NMDA_r    NMDA_p  mGluR5_r  mGluR5_p    GABA_r  \
0               NaN  0.170698  0.167253  0.033249  0.789389 -0.119807   
1               NaN -0.087050  0.483642  0.020436  0.869614 -0.006511   
2               NaN  0.264908  0.030279  0.247932  0.043081  0.206176   
3               NaN -0.024609  0.843304 -0.269574  0.027381  0.045665   
4               NaN -0.212196  0.084725 -0.268341  0.028122 -0.187891   
..              ...       ...       ...       ...       ...       ...   
995             NaN  0.202570  0.100177  0.058881  0.635994 -0.043174   
996             NaN -0.043952  0.723964  0.007469  0.952168 -0.143068   
997             NaN  0.060208  0.628392 -0.046377  0.709397  0.081111   
998             NaN -0.187085  0.129530 -0.122376  0.323863 -0.101366   
999             NaN  0.169433  0.170467  0.010760  0.931130  0.104330   

       GABA_p  permutation  
0    0.334200          1.0  
1    0.958296          2.0  
2    0.094150          3.0  
3    0.

In [119]:
real_corrs = calculate_real_correlations(means, nmaps)
empirical_p_values = calculate_empirical_pvalues(real_corrs, null_stats_df, nmaps)
print(real_corrs)
for col, p_val in empirical_p_values.items():
    print(f'{col}: {p_val:.15f}') 

{'NMDA': 0.503391785025024, 'mGluR5': 0.27255687743310303, 'GABA': 0.668868365303651}
NMDA: 0.000000000000000
mGluR5: 0.099000000000000
GABA: 0.000000000000000


In [None]:
# Now, input means df. 
# Rename parcels by parcel number and put rows in proper order.



receptor_df.index = labels
receptor_df.index.name = 'Parcel'

# Chop up receptor_df by map
NMDAmat = receptor_df[["NMDA"]]
GABAmat = receptor_df[["GABAa"]]
mGluR5mat = receptor_df[["mGluR5"]]

In [None]:
# CEST DATA
datapath = '/Users/pecsok/projects/Neuromaps/hansen_receptors/'
figpath = '/Users/pecsok/projects/GluCEST-fMRI/glucest-rsfmri/fmri_pipeline/parcellated_pipeline/figures'
cestavg_df = pd.DataFrame(index=range(1, 1001), columns=['CESTavg'])
cestavg_df['CESTavg'] = np.nan  # Initialize all values to 0
for _, row in cestdf.iterrows():
    # Extract parcel ID number using regex
    match = re.search(r'(\d+)', row['parcel'])
    if match:
        parcel_id = int(match.group(1))
        # Add CESTavg data to corresponding row in the new DataFrame
        if 1 <= parcel_id <= 1000:  # Ensure the parcel ID is within range
            cestavg_df.at[parcel_id, 'CESTavg'] = row['CESTavg']
            
for i, row in cestavg_df.iterrows():
    parcel_id = i
    value = row['CESTavg']
print(cestavg_df)

