In [None]:
# Colab-specific: switch into the course folder on the mounted Google Drive so
# the relative data paths used below (fMRI/..., model_RDMs/...) resolve.
# Adjust this path when running outside Colab.
%cd /content/drive/MyDrive/NMA_NeuroAI

In [None]:
!pip install nilearn rsatoolbox

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from nilearn.image import new_img_like
import nibabel as nib
import seaborn as sns
from nilearn import plotting
from rsatoolbox.inference import eval_fixed
from rsatoolbox.model import ModelFixed
from glob import glob
from rsatoolbox.util.searchlight import get_volume_searchlight, get_searchlight_RDMs, evaluate_models_searchlight
from rsatoolbox.rdm import RDMs
from rsatoolbox.vis import plot_model_comparison

In [None]:
def upper_tri(RDM):
    """Vectorise an RDM by collecting the entries strictly above its diagonal.

    Args:
        RDM (2D array): square representational dissimilarity matrix.

    Returns:
        1D array: the above-diagonal values, read row by row.
    """
    above_diagonal = np.triu_indices(RDM.shape[0], k=1)
    return RDM[above_diagonal]


def classwise_rdm(rdm_300x300, n_classes=10, n_per_class=30):
    """Average a sample-by-sample RDM into a class-by-class RDM.

    Rows and columns of the input must be ordered so that each class's samples
    are contiguous: class ``c`` occupies indices ``c * n_per_class`` up to
    ``(c + 1) * n_per_class - 1``. Entry ``[i, j]`` of the result is the mean of
    the ``(i, j)`` block of the input. Diagonal blocks include the matrix
    diagonal in that mean.

    Args:
        rdm_300x300 (2D array): square sample-level RDM with side
            ``n_classes * n_per_class`` (300 with the defaults).
        n_classes (int): number of classes.
        n_per_class (int): number of samples per class.

    Returns:
        2D float array of shape ``(n_classes, n_classes)``.

    Raises:
        ValueError: if the input is not square with side
            ``n_classes * n_per_class``. Mis-sized input is rejected rather
            than silently truncated or averaged over empty blocks.
    """
    rdm = np.asarray(rdm_300x300, dtype=float)
    side = n_classes * n_per_class
    if rdm.shape != (side, side):
        raise ValueError(
            f"expected a ({side}, {side}) RDM for {n_classes} classes x "
            f"{n_per_class} samples, got shape {rdm.shape}"
        )
    # Split each axis into (class, sample-within-class) and average over the
    # two sample axes, leaving one mean per (class_i, class_j) block.
    blocks = rdm.reshape(n_classes, n_per_class, n_classes, n_per_class)
    return blocks.mean(axis=(1, 3))

In [None]:
# Load the sample searchlight RDMs. The `with` block closes the .npz file once
# both arrays have been read into memory.
# allow_pickle=True is presumably needed because info_list is an object array;
# unpickling can execute code, so only load trusted files this way.
with np.load('fMRI/Sample_Neural_RDM.npz', allow_pickle=True) as neural_data:
    fmri_rdm = neural_data['rdm']
    info_list = neural_data['info_list']

# Keep only the first 10 conditions of every searchlight RDM.
# NOTE(review): presumably these are the ten digit stimuli -- confirm against
# the condition order used to build the file.
digit_rdm = fmri_rdm[:, :10, :10]

In [None]:
# Per-layer CNN RDMs, keyed by layer name. Each entry is later averaged by
# classwise_rdm, so it is presumably a sample-by-sample matrix (TODO confirm).
# The file is left open: later cells read .files and individual layers from it.
cnn_sample_rdms = np.load('model_RDMs/rdms_by_layer.npz')
cnn_sample_rdms.files

In [None]:
# Collapse each layer's sample-level RDM into a class-level RDM (block means
# per pair of classes), keyed by layer name.
cnn_rdms = {k: classwise_rdm(cnn_sample_rdms[k]) for k in cnn_sample_rdms.files}

In [None]:
# Wrap each layer's class-level RDM in a fixed rsatoolbox model named after the
# layer, in the same order as cnn_sample_rdms.files.
cnn_models = [ModelFixed(layer_name, cnn_rdms[layer_name])
              for layer_name in cnn_sample_rdms.files]

print('created the following models:')
for model in cnn_models:
    print(model.name)

In [None]:
# One (x, y, z) row per searchlight, taken from the second field of each
# info_list entry, then flattened to linear indices into the volume.
# NOTE(review): (64, 76, 64) is presumably the shape of fMRI/Digit.nii -- confirm.
centers_3d = np.array([[loc[0], loc[1], loc[2]]
                       for loc in (entry[1] for entry in info_list)])
centers = np.ravel_multi_index(centers_3d.T, dims=(64, 76, 64))

In [None]:
# Sanity check: the last two axes should be the 10 kept conditions; the first
# presumably indexes searchlights (one per entry of centers).
digit_rdm.shape

In [None]:
# Sanity check on the last model's RDM. The expression was previously left
# incomplete (`.rdm.`), a SyntaxError that aborted Run All.
# NOTE(review): rsatoolbox presumably stores it as the vectorised upper
# triangle, i.e. 45 entries for 10 classes -- confirm.
cnn_models[-1].rdm.shape

In [None]:
# Wrap the per-searchlight digit RDMs in one rsatoolbox RDMs object, tagging
# each RDM with the flat voxel index of its searchlight centre.
# NOTE(review): '1-corr' assumes the stored RDMs are 1 - correlation distances;
# confirm against how Sample_Neural_RDM.npz was produced.
searchlight_RDMs = RDMs(
    dissimilarities=digit_rdm,
    dissimilarity_measure='1-corr',
    rdm_descriptors={'voxel_index': centers}
)

In [None]:
# Number of searchlight RDMs; expected to equal len(centers).
searchlight_RDMs.n_rdm

In [None]:
# Compare every CNN model with every searchlight RDM using eval_fixed and the
# 'spearman' comparison method; n_jobs=3 sets the number of parallel jobs.
eval_results = evaluate_models_searchlight(searchlight_RDMs, cnn_models, eval_fixed, method='spearman', n_jobs=3)

In [None]:
# Get the evaluation scores for each searchlight voxel.
# Each result presumably holds one score per CNN model; flatten() turns it into
# a 1D row, so eval_score is (n_searchlights, n_models).
eval_score = np.array([e.evaluations.flatten() for e in eval_results])
eval_score.shape

In [None]:
# Scatter each model's per-searchlight scores back into a full brain volume:
# voxels that are not searchlight centres stay 0. The result has one
# (x, y, z) volume per model.
x, y, z = (64, 76, 64)
n_models = eval_score.shape[1]
voxel_index = np.asarray(searchlight_RDMs.rdm_descriptors['voxel_index'])

RDM_brain = np.zeros((n_models, x * y * z))
RDM_brain[:, voxel_index] = eval_score.T
RDM_brain = RDM_brain.reshape((n_models, x, y, z))

In [None]:
# Distribution of searchlight correlations, one histogram (+ KDE) per CNN layer.
# eval_score is (n_searchlights, n_models), so each column is plotted
# separately. The deprecated sns.distplot cannot draw a 2-D array with a single
# colour and raised whenever there was more than one model.
fig, ax = plt.subplots(figsize=(8, 5))
for i_layer, layer_name in enumerate(cnn_sample_rdms.files):
    sns.histplot(eval_score[:, i_layer], kde=True, stat='density',
                 element='step', fill=False, label=layer_name, ax=ax)
ax.set_title('Distributions of correlations', size=18)
ax.set_ylabel('Density', size=18)
ax.set_xlabel('Spearman correlation', size=18)
ax.legend(title='layer')
sns.despine(ax=ax)
plt.show()

In [None]:
# Reference NIfTI image; new_img_like uses it as the template for the
# searchlight score maps in the plotting cell.
# NOTE(review): its volume shape presumably matches the hard-coded (64, 76, 64).
tmp_img = nib.load('fMRI/Digit.nii')

In [None]:
import matplotlib.colors
def RDMcolormapObject(direction=1):
    """
    Build the matplotlib colormap used for RSA matrices and brain plots.

    Args:
        direction (int): 1 (default) runs blue -> turquoise -> gray -> red ->
            yellow; 0 uses the same colours in reverse order.

    Returns:
        matplotlib.colors.LinearSegmentedColormap: the interpolated colormap.

    Raises:
        ValueError: if direction is neither 0 nor 1.
    """
    if direction not in (0, 1):
        raise ValueError('Direction needs to be 0 or 1')
    colours = ['blue', 'turquoise', 'gray', 'red', 'yellow']
    if direction == 0:
        colours.reverse()
    return matplotlib.colors.LinearSegmentedColormap.from_list("", colours)


In [None]:
# For each CNN layer, plot the searchlight voxels whose score reaches that
# layer's 99th percentile.
# Loop-invariant plotting settings:
cmap = RDMcolormapObject()
coords = range(-20, 40, 5)  # axial (z) slice positions passed to nilearn
for i, layer_name in enumerate(cnn_sample_rdms.files):
    plot_img = new_img_like(tmp_img, RDM_brain[i])
    # eval_score is (n_searchlights, n_models): take column i so the percentile
    # is computed over this layer's scores across searchlights. Row i would be
    # the scores of all models at one searchlight.
    threshold = np.percentile(eval_score[:, i], 99)

    fig = plt.figure(figsize=(12, 3))
    # The return value is not kept: binding it to `display` would shadow
    # IPython's display() for the rest of the notebook.
    plotting.plot_stat_map(
        plot_img, colorbar=True, cut_coords=coords, threshold=threshold,
        display_mode='z', draw_cross=False, figure=fig,
        title=f'CNN ({layer_name})', cmap=cmap,
        black_bg=False, annotate=False)
    plt.show()