In [1]:
import time
import numpy as np
import rsatoolbox as rsa
from numba import njit

In [2]:
# Synthetic stand-in for the real dataset: uniform random values, drawn without a
# fixed seed, so every run produces different numbers.
# Axes are (18 subjects, 1800 observations, 29696 vertices); as float64 this
# single array needs roughly 7.7 GB of RAM.
dsets = np.random.rand(18,1800,29696)
# 1000 searchlights; each row holds 100 distinct vertex indices in [0, 29696).
lh_rois = np.array([np.random.choice(29696,100,replace=False) for _ in range(1000)]) # Searchlight vertices

# Word (condition) and presentation (repetition) labels for the 1800 observations:
# every word spans 6 consecutive observations, which are numbered 0-5.
trial_idx = np.arange(1800)
words = trial_idx // 6
presentation = list(trial_idx % 6)  # kept as a list; reused as an obs descriptor
cv_descriptor = np.array(presentation)
# Example searchlight (subject 1, ROI 1), shaped (observations, vertices).
# Indexed exactly as sl_new does, so the array is C-contiguous: numba compiles one
# specialization per array layout, and the warm-up call to fastCross should
# compile the same specialization that the timed benchmark later uses.
measurements = dsets[1][:, lh_rois[1]]

In [3]:
# Boolean selector for the strict upper triangle of an (n_conditions x n_conditions)
# RDM, flattened row-major so it can index `rdm.flatten()` inside fastCross.
# The square matrix is built explicitly rather than relying on np.triu
# broadcasting a 1-D vector of ones into a 2-D mask.
n_conditions = len(np.unique(words))
rdm_mask = np.triu(np.ones((n_conditions, n_conditions), dtype=bool), 1).flatten()
@njit(nogil=True)
def fastCross(measurements,words,cv_descriptor,rdm_mask):
    """Cross-validated squared-distance RDM (crossnobis form, identity noise).

    For each fold in ``cv_descriptor`` the per-word mean pattern of that fold is
    paired with the per-word mean pattern of all remaining observations. One
    RDM is computed per fold and the fold RDMs are averaged.

    Parameters
    ----------
    measurements : 2-D float array, shape (n_observations, n_channels)
    words : 1-D array, condition label of every observation
    cv_descriptor : 1-D array, fold label of every observation
    rdm_mask : 1-D bool array, length n_words**2
        Selects the strict upper triangle of the row-major flattened RDM; it
        must contain exactly n_words*(n_words-1)/2 True entries.

    Returns
    -------
    1-D float array of length n_words*(n_words-1)/2 with the fold-averaged
    dissimilarities, each divided by the number of channels.
    """
    n_chan = measurements.shape[1]
    noise = np.eye(n_chan)  # identity precision matrix: no noise whitening
    cv_folds = np.unique(cv_descriptor)
    unique_words = np.unique(words)
    n_words = len(unique_words)
    rdm_len = n_words * (n_words - 1) // 2
    rdms = np.zeros((len(cv_folds), rdm_len))
    for i, k in enumerate(cv_folds):
        in_fold = cv_descriptor == k
        # Average observations by word, separately inside and outside the fold
        mean_train = np.zeros((n_words, n_chan))
        mean_test = np.zeros((n_words, n_chan))
        for j in range(n_words):
            is_word = words == unique_words[j]
            rows_train = measurements[in_fold & is_word]
            rows_test = measurements[~in_fold & is_word]
            mean_train[j] = np.sum(rows_train, 0) / rows_train.shape[0]
            # Each group is normalised by its OWN row count; dividing the test
            # sum by anything else rescales every distance by a constant.
            mean_test[j] = np.sum(rows_test, 0) / rows_test.shape[0]
        # Crossnobis: d(a, b) = K[a, a] + K[b, b] - K[a, b] - K[b, a]
        kernel = mean_train @ noise @ mean_test.T
        diag = np.diag(kernel)
        rdm = np.expand_dims(diag, 0) + np.expand_dims(diag, 1)\
            - kernel - kernel.T
        rdms[i] = rdm.flatten()[rdm_mask] / np.double(n_chan)
    # Average the per-fold RDMs
    return np.sum(rdms, 0) / rdms.shape[0]
# First call to the @njit-compiled function: numba compiles fastCross for these
# argument types here, so this call includes compilation time. The returned RDM
# vector is displayed as the cell output.
fastCross(measurements,words,cv_descriptor,rdm_mask)

array([ 3.68760544e-05,  7.75989883e-05, -4.00529531e-05, ...,
        2.10627683e-04,  3.61619696e-05, -1.90621169e-04])

In [4]:
# New
def sl_new(i):
    """Run ``fastCross`` on searchlight ``i`` for every subject in ``dsets``.

    Parameters
    ----------
    i : int
        Row index into ``lh_rois`` selecting the searchlight's vertex indices.

    Returns
    -------
    np.ndarray
        The ``fastCross`` results stacked with one row per subject.
    """
    vertex_idx = lh_rois[i]
    per_subject = [
        # Restrict to the first 29696 columns, then pick the searchlight's vertices
        fastCross(subj_data[:, :29696][:, vertex_idx], words, cv_descriptor, rdm_mask)
        for subj_data in dsets
    ]
    return np.array(per_subject)
# Old
# One rsatoolbox Dataset per subject with every vertex kept as a channel;
# sl_old selects a searchlight's channels from these through the 'vertices'
# channel descriptor.
ds = [
    rsa.data.dataset.Dataset(
        measurements=subj_data[:, :29696],
        descriptors={'subj': subj},
        obs_descriptors={'words': words, 'presentation': presentation},
        channel_descriptors={'vertices': np.arange(29696)},
    )
    for subj, subj_data in enumerate(dsets)
]
def sl_old(i):
    """Run rsatoolbox's crossnobis ``calc_rdm`` on searchlight ``i`` for every subject.

    Parameters
    ----------
    i : int
        Row index into ``lh_rois`` selecting the searchlight's vertex indices.

    Returns
    -------
    The ``dissimilarities`` attribute of the per-subject RDMs after they have
    been concatenated with ``rsa.rdm.rdms.concat``.
    """
    vertex_idx = lh_rois[i]
    per_subject = [
        rsa.rdm.calc_rdm(
            subj_ds.subset_channel('vertices', vertex_idx),
            descriptor='words',
            method='crossnobis',
            cv_descriptor='presentation',
        )
        for subj_ds in ds
    ]
    return rsa.rdm.rdms.concat(per_subject).dissimilarities

In [5]:
# Time the numba implementation on the first 10 searchlights.
# perf_counter is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() is wall-clock and can jump if the system clock changes.
start = time.perf_counter()
tmp1 = [sl_new(i) for i in range(10)]
print(f'Elapsed Time:{time.perf_counter() - start:.1f} seconds')

Elapsed Time:4.6 seconds


In [6]:
# Time the rsatoolbox implementation on the same 10 searchlights.
# perf_counter is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() is wall-clock and can jump if the system clock changes.
start = time.perf_counter()
tmp2 = [sl_old(i) for i in range(10)]
print(f'Elapsed Time:{time.perf_counter() - start:.1f} seconds')

Elapsed Time:19.9 seconds


In [7]:
# Compare the two implementations over all searchlights and subjects.
# NOTE: Pearson correlation ignores scale and offset, so a value of 1.0 only
# shows that the two results are linearly related. Compare magnitudes as well
# (e.g. with np.allclose) before treating them as numerically identical.
np.corrcoef(np.array(tmp1).flatten(),np.array(tmp2).flatten())

array([[1., 1.],
       [1., 1.]])