In [81]:
import glob
import numpy as np
import pandas as pd
import nibabel as nib
import seaborn as sns
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset

In [2]:
# use seaborn's "whitegrid" theme for every figure drawn below
sns.set_style(style="whitegrid")

In [3]:
# config: dataset locations, given relative to the notebook's working directory

# root folder searched for the per-run "*events*" tables
TARGETS_PATH = "data/ds001246/"
# root folder searched for the preprocessed image runs
FEATURES_PATH = "data/ds001246/derivatives/preproc-spm/output/"

# tab-separated stimulus/category tables for the two splits
TEST_CATEGORIES_PATH = "data/ds001246/stimulus_ImageNetTest.csv"
TRAIN_CATEGORIES_PATH = "data/ds001246/stimulus_ImageNetTraining.csv"

In [102]:
class GODData(Dataset):
    """Torch dataset pairing image runs of one ds001246 session with category labels.

    Every run found under ``FEATURES_PATH`` is loaded with nibabel, trimmed of
    its leading/trailing rest volumes and cut into samples made of
    ``VOLUMES_PER_SAMPLE`` consecutive volumes. The label of a sample is taken
    from column ``2`` of the split's category table, matched on the ``stim_id``
    column of the run's events table.

    Attributes:
        features: float tensor of shape ``(N, C, D, W, H)`` with
            ``C == VOLUMES_PER_SAMPLE``.
        targets: long tensor of shape ``(M, 1)`` holding category labels.
            ``M`` is expected to equal ``N``; this is not enforced here.
    """

    FEATURES_PATH = "data/ds001246/derivatives/preproc-spm/output/"
    TARGETS_PATH = "data/ds001246/"
    TRAIN_CATEGORIES_PATH = "data/ds001246/stimulus_ImageNetTraining.csv"
    TEST_CATEGORIES_PATH = "data/ds001246/stimulus_ImageNetTest.csv"

    # number of consecutive volumes stacked into the channel axis of one sample
    VOLUMES_PER_SAMPLE = 3

    def __init__(self, subject="01", session_id="01", task="perception", train=True):
        """Load every run of one session into memory.

        Args:
            subject: subject identifier used in the ``sub-<subject>`` folder name.
            session_id: numeric suffix of the session folder name.
            task: task prefix of the session folder name.
            train: selects the ``Training`` session and category table when
                true, the ``Test`` ones otherwise.

        Raises:
            ValueError: if no runs are found, or if the number of feature runs
                differs from the number of events tables.
        """
        session = f"{task}{'Training' if train else 'Test'}{session_id}"

        feature_runs = sorted(glob.glob(f"{self.FEATURES_PATH}/sub-{subject}/ses-{session}/func/*"))
        target_runs = sorted(glob.glob(f"{self.TARGETS_PATH}/sub-{subject}/ses-{session}/func/*events*"))

        categories = pd.read_csv(self.TRAIN_CATEGORIES_PATH if train else self.TEST_CATEGORIES_PATH, sep="\t", header=None)

        # zip() would silently drop unpaired runs and np.vstack() would fail
        # with an unrelated message on an empty list, so check up front.
        if not feature_runs or len(feature_runs) != len(target_runs):
            raise ValueError(
                f"expected matching, non-empty run lists for sub-{subject}/ses-{session}: "
                f"found {len(feature_runs)} feature run(s) and {len(target_runs)} events file(s)"
            )

        features = []
        targets = []

        for f_run, t_run in zip(feature_runs, target_runs):
            features_run = nib.load(f_run).get_fdata()
            targets_run = pd.read_csv(t_run, sep="\t")

            # remove resting states
            features_run_pp = features_run[:, :, :, 8:-2]
            targets_run_pp = targets_run[targets_run["event_type"] != "rest"]

            # regroup volumes into (N, C, D, W, H)
            features_run_pp = self._group_volumes(features_run_pp, self.VOLUMES_PER_SAMPLE)

            # extract category labels
            # NOTE(review): the inner merge drops events whose stim_id has no
            # match in the category table - confirm label/sample counts agree.
            targets_run_pp = targets_run_pp.merge(categories, left_on="stim_id", right_on=1)[2]
            targets_run_pp = targets_run_pp.to_numpy().reshape(-1, 1)

            features.append(features_run_pp)
            targets.append(targets_run_pp)

        features = np.vstack(features)
        targets = np.vstack(targets)

        # convert and store as tensors
        self.features = torch.from_numpy(features).float()
        self.targets = torch.from_numpy(targets).long()

    @staticmethod
    def _group_volumes(run, volumes_per_sample=3):
        """Cut a time-last 4-D run into samples of consecutive volumes.

        A plain ``reshape`` on a time-last array interleaves voxels and time
        points, so the time axis is moved to the front first.

        Args:
            run: array whose last axis is time.
                Assumes the first three axes are (X, Y, Z) with Z the slice
                axis, e.g. (64, 64, 50, T) - TODO confirm against the data.
            volumes_per_sample: number of consecutive volumes per sample.

        Returns:
            Array of shape ``(T // volumes_per_sample, volumes_per_sample, Z, X, Y)``.

        Raises:
            ValueError: from ``reshape`` when T is not a multiple of
                ``volumes_per_sample``.
        """
        # (X, Y, Z, T) -> (T, Z, X, Y): one whole volume per leading index
        volumes = np.transpose(run, (3, 2, 0, 1))
        return volumes.reshape(-1, volumes_per_sample, *volumes.shape[1:])

    def __len__(self):
        """Return the number of samples."""
        return len(self.features)

    def __getitem__(self, index):
        """Return the ``(feature, target)`` pair stored at ``index``."""
        feature = self.features[index]
        target = self.targets[index]
        return feature, target

In [103]:
# build the dataset for subject 01, perception training session 01
data = GODData(
    subject="01",
    session_id="01",
    task="perception",
    train=True,
)

2766534.015976
3665924.055985
4409515.002581
4086273.005433
3950228.020813
1855672.0159
3623556.012791
4254680.00174
4373894.026002
3063599.003234
