In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

import numpy as np
import pandas as pd
import nibabel as nib
import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
# better plots
sns.set_style("whitegrid")

In [3]:
SUBJECT = "sub-01"
SESSION = "perceptionTraining01"
DATA_PATH = f"data/ds001246/derivatives/preproc-spm/output/{SUBJECT}/ses-{SESSION}/func"
TARGETS_PATH = f"data/ds001246/{SUBJECT}/ses-{SESSION}/func"
TASK = "task-perception_run-01"
RUN = "01"

In [4]:
features = nib.load(f"{DATA_PATH}/{SUBJECT}_ses-{SESSION}_task-perception_run-{RUN}_bold_preproc.nii.gz").get_fdata()
targets = pd.read_csv(f"{TARGETS_PATH}/{SUBJECT}_ses-{SESSION}_task-perception_run-{RUN}_events.tsv", sep="\t")
target_categories = pd.read_csv("")

# features preprocessing
# targets.loc[0, "duration"] = 24

In [5]:
# preprocess features (remove resting states)
pp_features = features[:, :, :, 8:-2]

# preprocess targets (remove resting states)
pp_targets = targets[targets["event_type"] != "rest"]

In [6]:
pp_targets

Unnamed: 0,onset,duration,trial_no,event_type,stim_id,response_time
1,33,9,2,stimulus,2766534.0,0.0
2,42,9,3,stimulus,1970164.0,0.0
3,51,9,4,stimulus,4376876.0,0.0
4,60,9,5,stimulus,4225987.0,0.0
5,69,9,6,stimulus,3079230.0,0.0
6,78,9,7,stimulus,3079230.0,79.03421
7,87,9,8,stimulus,3394916.0,0.0
8,96,9,9,stimulus,1877134.0,0.0
9,105,9,10,stimulus,3924679.0,0.0
10,114,9,11,stimulus,3602883.0,0.0


In [6]:
# make dataset
data_features = torch.tensor(pp_features.reshape(55, 3, 50, 64, 64), dtype=torch.float32)
data_targets = torch.tensor(pp_targets["category_index"].to_numpy())

KeyError: 'category_index'

In [None]:
# 3D CNN model
class CNN3D(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.C1 = self._get_conv_layer(3, 16, (3, 5, 5))
        self.C2 = self._get_conv_layer(16, 16, (3, 5, 5))
        self.S1 = nn.MaxPool3d((3, 3, 3), stride=(2, 2, 2))
        self.C3 = self._get_conv_layer(16, 32, (3, 5, 5))
        self.C4 = self._get_conv_layer(32, 64, (3, 5, 5))
        self.S2 = nn.MaxPool3d((3, 3, 3), stride=(2, 2, 2))
        self.C5 = self._get_conv_layer(64, 64, (3, 5, 5))
        self.C6 = self._get_conv_layer(64, 8, (3, 3, 3))
        self.S3 = nn.MaxPool3d((3, 3, 3), stride=(2, 2, 2))

        self.FC = nn.Linear(8, num_classes)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def _get_conv_layer(self, in_channels, out_channels, kernel_size):
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.C1(x)
        x = self.C2(x)
        x = self.S1(x)
        x = self.C3(x)
        x = self.C4(x)
        x = self.S2(x)
        x = self.C5(x)
        x = self.C6(x)
        x = self.S3(x)

        x = x.view(x.size(0), -1)

        x = self.FC(x)
        x = self.relu(x)
        x = self.softmax(x)

        return x

In [None]:
model = CNN3D(num_classes=150)

In [None]:
sample_preds = model(data_features)

In [None]:
sample_preds.shape

torch.Size([55, 150])