In [1]:
import os

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

from modules.Utils import get_file_names

In [2]:
class FeTADataSet(Dataset):
    """Dataset of FeTA MRI volumes paired with their masks.

    Participant metadata is read from ``<path>/participants.tsv`` (tab-separated,
    with at least ``participant_id`` and ``Pathology`` columns). The image and mask
    file paths of each participant are looked up in the mapping returned by
    ``get_file_names(path)``: element 0 is the image path, element 1 the mask path.

    Args:
        path: Root directory of the dataset.
        train: Only used when ``pathology`` is neither "Pathological" nor
            "Neurotypical": True keeps the first ``count_train`` rows of
            participants.tsv, False keeps the remaining rows.
        transform: Optional callable applied to the image. It receives a
            ``torch.Tensor`` shaped ``(1, *volume.shape)`` (leading singleton
            channel axis); that leading axis is squeezed away afterwards.
        pathology: "Pathological" or "Neurotypical" keeps only that group, taken
            from ALL participants (NOTE: the train/test split is not applied in
            that case). Any other value (default "all") keeps every group and
            applies the train/test split.
        count_train: Number of leading rows of participants.tsv used for training.
    """

    def __init__(self, path="feta_2.1", train=True, transform=None, pathology="all",
                 count_train=70):
        self.__path_base = path
        self.__train = train
        self.__transform = transform

        self.meta_data = pd.read_csv(os.path.join(self.__path_base, "participants.tsv"), sep="\t")
        self.__paths_file = get_file_names(self.__path_base)

        # NOTE: sub-007 and sub-009 might have bad image quality; they are currently
        # kept. Filter them out of self.meta_data here if they need to be excluded.

        if pathology in ("Pathological", "Neurotypical"):
            self.meta_data = self.meta_data[self.meta_data.Pathology == pathology]
        elif self.__train:
            self.meta_data = self.meta_data.iloc[:count_train]
        else:
            self.meta_data = self.meta_data.iloc[count_train:]

        # Renumber rows 0..n-1 so the positional index handed to __getitem__
        # (e.g. by a DataLoader sampler) always addresses a valid row, whichever
        # selection was applied above.
        self.meta_data = self.meta_data.reset_index(drop=True)

    def __getitem__(self, index):
        """Load the ``index``-th selected participant's image volume and mask.

        Returns:
            ``(mri_image, mri_mask)``. ``mri_mask`` is the array returned by
            ``get_fdata()``. ``mri_image`` is the same kind of array when no
            transform is set, otherwise the ``torch.Tensor`` returned by the
            transform with its leading channel axis removed.
        """
        participant_id = self.meta_data["participant_id"].iloc[index]
        data = self.__paths_file[participant_id]
        path_image, path_mask = data[0], data[1]

        mri_image = nib.load(path_image).get_fdata()
        mri_mask = nib.load(path_mask).get_fdata()

        if self.__transform is not None:
            # Give the transform a 4-D input with a singleton channel axis, derived
            # from the actual volume shape rather than a hard-coded 256^3.
            mri_image = torch.tensor(mri_image).unsqueeze(0)
            mri_image = self.__transform(mri_image)
            mri_image = mri_image.squeeze(0)

        return mri_image, mri_mask

    def __len__(self):
        """Number of participants selected for this dataset."""
        return len(self.meta_data)

In [3]:
train_dataset = FeTADataSet(train=True)
test_dataset = FeTADataSet(train=False)

# Shuffle only the training data; evaluation order stays fixed so per-sample
# test results are reproducible across runs.
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=4,
                          shuffle=True,
                          num_workers=2)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=2,
                         shuffle=False,
                         num_workers=2)

In [None]:
# Each dataset indexing call reads both NIfTI files from disk: index once and unpack.
images, masks = train_dataset[0]

In [None]:
# Voxel values along the last axis through (128, 128) of the first training image.
print(images[128, 128])

In [None]:
# Mask values along the last axis through (128, 128) of the first training mask.
print(masks[128, 128])

In [None]:
# Participant table (tab-separated); os.path.join with one argument was a no-op.
meta_data = pd.read_csv("../Code/feta_2.1/participants.tsv", sep="\t")

In [None]:
import torchio as tio

# Z-normalization whose statistics are computed over the mask produced by
# torchio's ZNormalization.mean masking method.
znorm_transform = tio.ZNormalization(masking_method=tio.ZNormalization.mean)

# Reuse the transform defined above instead of constructing an identical second one.
transform_ = transforms.Compose([znorm_transform])

train_dataset_transform = FeTADataSet(train=True, transform=transform_)


In [None]:
# Index once: every dataset access reloads (and re-normalizes) the volume.
images_t, masks_t = train_dataset_transform[0]

In [None]:
# Length of the normalized image's profile along the last axis through (128, 128).
print(images_t[128, 128].shape)

In [None]:
# Length of the mask's profile along the last axis through (128, 128).
print(masks_t[128, 128].shape)

In [2]:
# Participant table (tab-separated); os.path.join with one argument was a no-op.
meta_data = pd.read_csv("../Code/feta_2.1/participants.tsv", sep="\t")
meta_data

Unnamed: 0,participant_id,Pathology,Gestational age
0,sub-001,Pathological,27.9
1,sub-002,Pathological,28.2
2,sub-003,Pathological,27.4
3,sub-004,Pathological,25.5
4,sub-005,Pathological,22.6
...,...,...,...
75,sub-076,Neurotypical,23.2
76,sub-077,Pathological,26.9
77,sub-078,Pathological,24.0
78,sub-079,Neurotypical,29.1


In [41]:
# The first 40 participants (sub-001 - sub-040) were reconstructed with mialSRTK,
# the rest with simpleIRTK.
N_MIAL_SRTK = 40
mial_srtk = meta_data.iloc[:N_MIAL_SRTK]
simple_irtk = meta_data.iloc[N_MIAL_SRTK:]

In [67]:
AGE_CUTOFF = 28  # gestational weeks; separates the younger and older age bands


def strata_indices(subset, age_cutoff=AGE_CUTOFF):
    """Return row-index lists for the four (age band, pathology) strata of ``subset``.

    Order: (age <= cutoff, Neurotypical), (age <= cutoff, Pathological),
           (age >  cutoff, Neurotypical), (age >  cutoff, Pathological).
    Rows whose age or pathology matches none of these end up in no list.
    """
    age = subset["Gestational age"]
    pathology = subset["Pathology"]
    return [
        subset[age_band & (pathology == group)].index.to_list()
        for age_band in (age <= age_cutoff, age > age_cutoff)
        for group in ("Neurotypical", "Pathological")
    ]


index1, index2, index3, index4 = strata_indices(mial_srtk)
index5, index6, index7, index8 = strata_indices(simple_irtk)

In [105]:
# Per-stratum (n_train, n_val) counts, in index1..index8 order. Within each stratum
# the first n_train rows go to train, the next n_val to validation and the rest to test.
strata = [index1, index2, index3, index4, index5, index6, index7, index8]
split_sizes = [(5, 1), (16, 2), (6, 1), (3, 1), (5, 1), (13, 2), (9, 1), (3, 1)]

train, validation, test = [], [], []
for stratum, (n_train, n_val) in zip(strata, split_sizes):
    train.extend(stratum[:n_train])
    validation.extend(stratum[n_train:n_train + n_val])
    test.extend(stratum[n_train + n_val:])

# Sanity checks: the splits are pairwise disjoint and together cover every stratum.
assert not set(train) & set(validation)
assert not set(train) & set(test)
assert not set(validation) & set(test)
assert len(train) + len(validation) + len(test) == sum(len(s) for s in strata)

In [34]:
### Dataset Information ###
""" 
There are 80 MRI images from 80 subjects, with gestational ages ranging from 20 to 35 weeks.
Subjects are either Pathological or Neurotypical.
The first 40 MRI images (sub-001 to sub-040) were reconstructed with the mialSRTK method.
The remaining 40 MRI images (sub-041 to sub-080) were reconstructed with the simpleIRTK method.

mialSRTK reconstruction:
    * Gestational age <= 28 weeks (the cutoff of 28 was chosen intuitively to keep gestational ages diverse and the age distribution smooth)
        - Neurotypical: 7 MRI images.   [train:5, val:1, test:1]
        - Pathological: 20 MRI images.  [train:16, val:2, test:2]
    * Gestational age > 28
        - Neurotypical: 8 MRI images.   [train:6, val:1, test:1]
        - Pathological: 5 MRI images.   [train:3, val:1, test:1]
        
simpleIRTK reconstruction:
    * Gestational age <=28
        - Neurotypical: 7 MRI images.   [train:5, val:1, test:1]
        - Pathological: 17 MRI images.  [train:13, val:2, test:2]
    * Gestational age > 28
        - Neurotypical: 11 MRI images.  [train:9, val:1, test:1]
        - Pathological: 5 MRI images.   [train:3, val:1, test:1]      
"""



' \nThere are 80 MRI images of 80 subjects. Gestational ages of subjects ranges 20 weeks to 35 weeks. \nThere are Pathological and Neurotypical subjects.\nFirst 40 MRI images (sub-001 - sub-040) constructed by mialSRTK method. \nOther 40 MRI images (sub-041 - sub-080) constructed by simpleIRTK method.\n\nmialSRTK reconstruction:\n    * Gestational age <=28 (28 choosed intuitively for diversity gestational weeks and smoother age disturbition)\n        - Neurotypical: 7 MRI images.   [train:5, val:1, test:1]\n        - Pathological: 20 MRI images.  [train:16, val:2, test:2]\n    * Gestational age > 28\n        - Neurotypical: 8 MRI images.   [train:6, val:1, test:1]\n        - Pathological: 5 MRI images.   [train:3, val:1, test:1]\n        \nsimpleIRTK reconstruction:\n    * Gestational age <=28\n        - Neurotypical: 7 MRI images.   [train:5, val:1, test:1]\n        - Pathological: 17 MRI images.  [train:13, val:2, test:2]\n    * Gestational age > 28\n        - Neurotypical: 11 MRI im