## Class NSD_Dataset

In [None]:
from torch.utils.data import Dataset
from src.file_utility import load_mask_from_nii, view_data
from src.file_utility import save_stuff, flatten_dict, embed_dict
import h5py
import numpy as np
from scipy.io import loadmat
class NSD_Dataset(Dataset):
    '''
    Single-subject NSD dataset pairing fMRI voxel responses with stimulus images.

    Each item is a dict holding:
        1. 'voxel_data', 'voxel_roi', 'voxel_idx' --> decoder --> reconstructed feature
        2. 'stim' (stimulus image) --> encoder --> feature
        3. 'stim_idx', the zero-based index of the stimulus in ``self.stim_data``
    Annotations for the images are not returned yet (need further processing).
    '''

    def __init__(self, subject=1, fMRI_transform=None, image_transform=None, view_3d=False, mode='train',
                 rootpath='/scratch/cl6707/Projects/neuro_interp/data/NSD/'):
        '''
        Load the ROI mask shape, voxel data, trial ordering and stimuli for one subject.

        Args:
            subject: subject number; used to build the file names and to select
                the per-subject entry of every voxel-data field.
            fMRI_transform: optional callable applied to the per-trial voxel data.
            image_transform: optional callable applied to the stimulus image.
            view_3d: if True, items carry the voxel data re-embedded into a
                volume via ``view_data`` instead of the raw per-trial entry.
            mode: accepted for interface compatibility; currently unused.
            rootpath: directory containing ``voxel_data_general_part{subject}.h5py``.
        '''
        self.rootpath = rootpath
        self.subject = subject
        self.fMRI_transform = fMRI_transform
        self.image_transform = image_transform
        self.view_3d = view_3d

        # NOTE(review): the locations below are hard-coded and ignore `rootpath`;
        # they coincide with it only for the default value. Confirm whether they
        # should be derived from `rootpath` instead.
        nsd_root = "/scratch/cl6707/Projects/neuro_interp/data/NSD/"
        stim_root = nsd_root + "nsddata_stimuli/stimuli/nsd/"
        beta_root = nsd_root + "nsddata_betas/ppdata/"  # kept for reference; not read here
        mask_root = nsd_root + "nsddata/ppdata/"

        # Only the shape of the ROI volume is kept; view_data() receives it in __getitem__.
        voxel_roi_full = load_mask_from_nii(mask_root + "subj%02d/func1pt8mm/roi/prf-visualrois.nii.gz" % subject)
        self.brain_nii_shape = voxel_roi_full.shape

        # Copy every HDF5 dataset into memory so the file can be closed right away;
        # the context manager closes it even if reading fails.
        with h5py.File(rootpath + f'voxel_data_general_part{subject}.h5py', 'r') as voxel_data_set:
            voxel_data_dict = embed_dict({k: np.copy(d) for k, d in voxel_data_set.items()})
        # Each field is indexed by the subject number as a string; keep only this subject's entry.
        self.meta_data = {k: v[str(subject)] for k, v in voxel_data_dict.items()}

        exp_design_file = nsd_root + "nsddata/experiments/nsd/nsd_expdesign.mat"
        exp_design = loadmat(exp_design_file)
        # zero-indexed ordering of indices (matlab-like to python-like)
        self.ordering = exp_design['masterordering'].flatten() - 1

        # load stimuli
        with h5py.File(stim_root + "S%d_stimuli_227.h5py" % subject, 'r') as image_data_set:
            self.stim_data = np.copy(image_data_set['stimuli'])

    def __len__(self):
        '''Return the number of entries in ``meta_data['voxel_data']``.'''
        return len(self.meta_data['voxel_data'])

    def __getitem__(self, idx):
        '''
        Return the sample for trial ``idx`` as a dict with keys
        'stim', 'voxel_data', 'voxel_roi', 'voxel_idx' and 'stim_idx'.
        '''
        # TODO: Annotation/categories for stimuli
        voxel_data = self.meta_data['voxel_data'][idx]
        voxel_roi = self.meta_data['voxel_roi']
        voxel_idx = self.meta_data['voxel_idx']

        # The stimulus is looked up through the trial ordering, never directly by
        # the trial index: `idx` counts trials, while `stim_data` is indexed by image.
        stim_idx = self.ordering[idx]
        stim = self.stim_data[stim_idx]

        # Apply transformation to voxel data and stimuli
        if self.view_3d:
            # NOTE(review): the volume is built from the untransformed voxel data,
            # so the result of fMRI_transform is not returned when view_3d is set.
            volume_voxel = np.nan_to_num(view_data(self.brain_nii_shape, voxel_idx, voxel_data))

        if self.fMRI_transform is not None:
            voxel_data = self.fMRI_transform(voxel_data)

        if self.image_transform is not None:
            stim = self.image_transform(stim)

        return_dict = {
            'stim': stim,
            'voxel_data': volume_voxel if self.view_3d else voxel_data,
            # 'voxel_mask': self.meta_data['voxel_mask'],  # currently not returned
            'voxel_roi': voxel_roi,
            'voxel_idx': voxel_idx,
            'stim_idx': stim_idx,
        }

        return return_dict

In [None]:
# Smoke test: build the subject-1 dataset in 3-D view mode, fetch the first
# sample and print the shape of every field it returns.
dataset = NSD_Dataset(view_3d=True, subject=1)
sample = dataset[0]
for field_name, field_value in sample.items():
    field_shape = field_value.shape
    print(field_name, field_shape)

#todo: Data Augmentation for fMRI data and image data?