In [1]:
# Change the kernel's working directory to the project code folder
# (presumably so the `from dataset import ...` in the next cell resolves - confirm).
%cd /home/yuchen/mind-vis/code

/home/yuchen/mind-vis/code


In [2]:
from dataset import create_Kamitani_dataset, create_BOLD5000_dataset
from os.path import join as pjoin
import glob
import numpy as np
import pandas as pd
from nilearn.masking import apply_mask, unmask
from nilearn.plotting import plot_epi, plot_stat_map
from nilearn.image import load_img, index_img, iter_img
import matplotlib.pyplot as plt
import cortex
from PIL import Image
import json
import csv
import os
import torch
from torch.utils.data import Dataset

In [3]:
# Build both splits with create_Kamitani_dataset using its default arguments.
train,test = create_Kamitani_dataset()

In [4]:
# Display the shape of the fMRI array held by the test split.
test.fmri.shape

(250, 4656)

In [4]:
# Rebuild the datasets and keep only the test split; the train split is bound to `_` and not used.
_,test = create_Kamitani_dataset()

load test: 249 out of 250       
train_fmri
train_img
train_img_label_all
(3198, 1200)   
(3198, 1200)
constructing dataset

In [3]:
def identity(x):
    """Return the argument unchanged; used as the default no-op transform."""
    return x

def pad_to_patch_size(x, patch_size):
    """Pad the second axis of a 2-D array up to a multiple of `patch_size`.

    Padding is appended on the right of axis 1 using numpy's 'wrap' mode, so
    the added columns repeat the row's leading values.

    Args:
        x: 2-D numpy array.
        patch_size: positive integer the padded width must be divisible by.

    Returns:
        A new array with the same number of rows whose width is the smallest
        multiple of `patch_size` that is >= the input width.
    """
    assert x.ndim == 2
    # `(-n) % p` is 0 when n is already a multiple of p. The previous
    # `p - n % p` evaluated to p in that case and appended a whole extra patch.
    extra_cols = (-x.shape[1]) % patch_size
    return np.pad(x, ((0, 0), (0, extra_cols)), 'wrap')

def normalize(x, mean=None, std=None):
    """Shift `x` by `mean` and divide by `std`.

    When `mean` or `std` is None it is computed from `x` itself with
    `np.mean` / `np.std` over all elements.
    """
    if mean is None:
        mean = np.mean(x)
    if std is None:
        std = np.std(x)
    centered = x - mean
    scale = std * 1.0  # same divisor expression as before, kept for identical dtype handling
    return centered / scale

class Kamitani_dataset(Dataset):
    """Torch dataset pairing fMRI samples with stimulus images.

    Items are dicts holding the transformed fMRI sample under 'fmri' and the
    transformed image under 'image'. When the instance attribute
    `return_image_class_info` is set to True, each dict additionally carries
    'image_class', 'image_class_name' and 'naive_label', read from positions
    0, 1 and 2 of the matching `img_label` entry.
    """

    def __init__(self, fmri, image, img_label, fmri_transform=identity, image_transform=identity, num_voxels=0, num_per_sub=50):
        super().__init__()
        self.fmri = fmri
        # When image and fMRI counts differ, every image is repeated 35 times along axis 0.
        if len(image) != len(fmri):
            image = np.repeat(image, 35, axis=0)
        self.image = image
        self.fmri_transform = fmri_transform
        self.image_transform = image_transform
        self.num_voxels = num_voxels
        self.num_per_sub = num_per_sub
        # Split each label entry into its three positional components.
        self.img_class = [entry[0] for entry in img_label]
        self.img_class_name = [entry[1] for entry in img_label]
        self.naive_label = [entry[2] for entry in img_label]
        # Off by default; callers flip it on the instance to receive label info.
        self.return_image_class_info = False

    def __len__(self):
        """Number of samples, defined by the fMRI array."""
        return len(self.fmri)

    def __getitem__(self, index):
        """Return the sample dict for `index`."""
        voxels = self.fmri[index]
        # Past the end of the image array a zero image is substituted;
        # otherwise pixel values are divided by 255.
        picture = (
            np.zeros_like(self.image[0])
            if index >= len(self.image)
            else self.image[index] / 255.0
        )
        voxels = np.expand_dims(voxels, axis=0)  # add a leading axis of length 1

        label_info = {}
        if self.return_image_class_info:
            label_info = {
                'image_class': self.img_class[index],
                'image_class_name': self.img_class_name[index],
                'naive_label': torch.tensor(self.naive_label[index]),
            }

        sample = {'fmri': self.fmri_transform(voxels), 'image': self.image_transform(picture)}
        sample.update(label_info)
        return sample
            
def _load_image_stack(labels, image_dir, tag):
    """Load the images named in `labels` from `image_dir` into a single array.

    The first image determines the per-image shape and dtype of the
    preallocated output. `tag` only appears in the progress line printed for
    each image.
    """
    first_path = os.path.join(image_dir, labels.iloc[0])
    with Image.open(first_path) as first_img:
        first_arr = np.array(first_img)

    total = len(labels)
    stack = np.empty((total, *first_arr.shape), dtype=first_arr.dtype)
    for i, label in enumerate(labels):
        print(f'load {tag}: {i} out of {total}    ', end='\r')
        with Image.open(os.path.join(image_dir, label)) as img:
            stack[i] = np.array(img)
    return stack


def create_Kamitani_dataset(path='../data/Kamitani/npz',  roi='VC', patch_size=16, fmri_transform=identity,
            image_transform=identity, subjects = ['sbj_1', 'sbj_2', 'sbj_3', 'sbj_4', 'sbj_5'],
            test_category=None, include_nonavg_test=False,
            basedir='/home/yuchen/dataset/fmri', image_dir='/home/yuchen/dataset/images_resized/',
            sub='01'):
    """Build train and test `Kamitani_dataset` objects from the betas_csv files.

    Reads `sub-{sub}_ResponseData.h5`, `sub-{sub}_VoxelMetadata.csv` and
    `sub-{sub}_StimulusMetadata.csv` from `<basedir>/betas_csv`, splits trials
    by the `trial_type` column ('train' / 'test'), keeps the voxels flagged in
    V1, V2, V3 or hV4, loads the stimulus images from `image_dir`, pads and
    normalises the fMRI arrays, and wraps everything in two datasets.

    Args:
        path, roi, subjects, test_category, include_nonavg_test: accepted for
            signature compatibility; this implementation does not read them.
        patch_size: passed to `pad_to_patch_size`.
        fmri_transform: fMRI transform for the train dataset (the test dataset
            always uses `torch.FloatTensor`).
        image_transform: a single transform used for both splits, or a list
            whose items 0 and 1 are the train and test transforms.
        basedir: directory containing the `betas_csv` folder.
        image_dir: directory the `stimulus` file names are resolved against.
        sub: subject id used in the data file names.

    Returns:
        Tuple `(train_dataset, test_dataset)`.
    """
    betas_csv_dir = pjoin(basedir, 'betas_csv')
    responses = pd.read_hdf(pjoin(betas_csv_dir, f'sub-{sub}_ResponseData.h5'))
    voxdata = pd.read_csv(pjoin(betas_csv_dir, f'sub-{sub}_VoxelMetadata.csv'))
    stimdata = pd.read_csv(pjoin(betas_csv_dir, f'sub-{sub}_StimulusMetadata.csv'))

    train_rows = stimdata[stimdata['trial_type'] == 'train']
    test_rows = stimdata[stimdata['trial_type'] == 'test']

    # NOTE(review): labels are capped at 6000 / 250 entries while the fMRI
    # selection below uses every train / test index - confirm the counts match.
    train_labels = train_rows['stimulus'].iloc[:6000]
    test_labels = test_rows['stimulus'].iloc[:250]

    in_roi = (voxdata['V1'] == 1) | (voxdata['V2'] == 1) | (voxdata['V3'] == 1) | (voxdata['hV4'] == 1)
    # The voxel_id values are used as positional row indices below.
    roi_idx = voxdata[in_roi]['voxel_id'].tolist()

    # `responses[...]` selects columns by label and `.iloc[...]` selects rows by
    # position, so rows are the ROI selection and columns the trial selection.
    # NOTE(review): pad_to_patch_size pads axis 1 and Kamitani_dataset indexes
    # samples on axis 0 - confirm whether a transpose is intended here.
    train_fmri = responses[train_rows.index.tolist()].iloc[roi_idx]
    test_fmri = responses[test_rows.index.tolist()].iloc[roi_idx]

    # Release the large source frames before the images are loaded.
    del responses, voxdata, stimdata

    train_img = _load_image_stack(train_labels, image_dir, 'train')
    test_img = _load_image_stack(test_labels, image_dir, 'test')
    print()

    train_fmri, test_fmri = train_fmri.to_numpy(), test_fmri.to_numpy()
    # NOTE(review): these are plain stimulus strings, whereas Kamitani_dataset
    # reads positions 0, 1 and 2 of every label entry - confirm the label format.
    train_img_label_all, test_img_label_all = train_labels.tolist(), test_labels.tolist()
    num_voxels = train_fmri.shape[-1]  # taken before padding, as before

    print('normalizing    ', end='\r')
    # Capture the train statistics BEFORE normalising. Previously the test
    # split was normalised with the mean/std of the already-normalised train
    # array (about 0 and 1), which left the test data effectively unscaled.
    train_fmri = pad_to_patch_size(train_fmri, patch_size)
    train_mean, train_std = np.mean(train_fmri), np.std(train_fmri)
    train_fmri = normalize(train_fmri, train_mean, train_std)
    test_fmri = normalize(pad_to_patch_size(test_fmri, patch_size), train_mean, train_std)

    print('constructing dataset', end = '\r')
    if isinstance(image_transform, list):
        train_image_tf, test_image_tf = image_transform[0], image_transform[1]
    else:
        train_image_tf = test_image_tf = image_transform

    return (Kamitani_dataset(train_fmri, train_img, train_img_label_all, fmri_transform, train_image_tf,
                             num_voxels, len(train_img_label_all)//5),
            Kamitani_dataset(test_fmri, test_img, test_img_label_all, torch.FloatTensor, test_image_tf,
                             num_voxels, len(test_img_label_all)//5))

In [22]:
# NOTE: this cell re-declares Kamitani_dataset, which is already defined in an
# earlier cell of this notebook; whichever definition runs last is the one in effect.
class Kamitani_dataset(Dataset):
    """Dataset yielding dicts with a transformed fMRI sample and its image.

    Setting `return_image_class_info = True` on an instance adds the
    'image_class', 'image_class_name' and 'naive_label' entries, which come
    from elements 0, 1 and 2 of each `img_label` item.
    """

    def __init__(self, fmri, image, img_label, fmri_transform=identity, image_transform=identity, num_voxels=0, num_per_sub=50):
        super().__init__()
        self.fmri, self.image = fmri, image
        if len(self.fmri) != len(self.image):
            # Counts differ: repeat each image 35 times along the first axis.
            self.image = np.repeat(self.image, 35, axis=0)
        self.fmri_transform, self.image_transform = fmri_transform, image_transform
        self.num_voxels, self.num_per_sub = num_voxels, num_per_sub
        self.img_class = [item[0] for item in img_label]
        self.img_class_name = [item[1] for item in img_label]
        self.naive_label = [item[2] for item in img_label]
        self.return_image_class_info = False

    def __len__(self):
        """Length equals the number of fMRI samples."""
        return len(self.fmri)

    def __getitem__(self, index):
        """Fetch one sample dict by index."""
        sample_fmri = self.fmri[index]
        if index >= len(self.image):
            # No image stored for this index: use zeros shaped like the first image.
            sample_img = np.zeros_like(self.image[0])
        else:
            sample_img = self.image[index] / 255.0
        sample_fmri = np.expand_dims(sample_fmri, axis=0)  # prepend an axis of length 1

        if not self.return_image_class_info:
            return {'fmri': self.fmri_transform(sample_fmri), 'image': self.image_transform(sample_img)}

        cls = self.img_class[index]
        cls_name = self.img_class_name[index]
        naive = torch.tensor(self.naive_label[index])
        return {'fmri': self.fmri_transform(sample_fmri), 'image': self.image_transform(sample_img),
                'image_class': cls, 'image_class_name': cls_name, 'naive_label': naive}