### Setup Install

In [None]:
# Install the third-party packages used by this notebook (notebook shell magics).
# NOTE(review): the import cell uses the standard-library `glob`, not `glob2`,
# so the second install looks unnecessary - confirm before removing it.
!pip install nibabel 
!pip install glob2

### Setup Import

In [None]:
import nibabel as nib
import numpy as np
import glob
import matplotlib.pyplot as plt
import torch 
from torch.utils import data as torch_data

### Data Pre-processing
Stack each subject's four modalities (flair, t1, t1ce, t2) into one image file and save its segmentation as the paired label file.

In [None]:
# Pre-process every subject folder under `paths`:
#   * each modality listed in `types` is zero-padded, scaled by its maximum and
#     stacked into ./train/BraTS2021_<id>.nii
#   * every other file in the folder (the segmentation) goes through the same
#     steps and is stacked into ./label/BraTS2021_<id>.nii
import os

paths = './training_data'
# Sorted so that subjects are processed in a reproducible order.
path = sorted(glob.glob(paths+'/*'))
types = ["flair", "t1", "t1ce","t2"]

# Zero-padding applied to the three axes of every volume.
# NOTE(review): presumably sized for 240x240x155 inputs (-> 256^3) - confirm.
pad_width = ((8, 8), (8, 8), (50, 51))


def _load_padded(nifti_image):
    """Return the image data zero-padded by `pad_width` and divided by its maximum."""
    padded = np.pad(np.asarray(nifti_image.dataobj), pad_width, 'constant')
    peak = np.max(padded)
    # An all-zero volume would otherwise become NaN (0 / 0); divide by 1 instead.
    return padded / (peak if peak != 0 else 1)


# nib.save does not create missing directories.
os.makedirs('./train', exist_ok=True)
os.makedirs('./label', exist_ok=True)

arrays = []
segmen = []

for a in path:
    # Start from empty lists for every subject; otherwise each saved file would
    # also contain the volumes of all previously processed subjects.
    arrays = []
    segmen = []
    nifti_file = None

    # Sorted so that the channel order inside the stacked file is deterministic.
    for modality in sorted(glob.glob(a+'/*')):
        # "<subject>_<modality>.nii.gz" -> "<modality>"; taken from the file name
        # only, so the result does not depend on underscores in the parent path.
        typefiles = os.path.basename(modality).split('.')[0].split('_')[-1]
        nifti_file = nib.load(modality)
        dataPad = _load_padded(nifti_file)
        if typefiles in types:
            arrays.append(dataPad)
        else:
            # NOTE(review): label volumes are scaled by their maximum exactly like
            # the images, so integer class ids become fractions - confirm intended.
            segmen.append(dataPad)

    if not arrays or not segmen:
        # Nothing usable (not a subject folder, or images/labels missing): skip
        # it so that ./train and ./label stay paired one-to-one.
        print('Skipping ' + a + ': missing image or label volumes')
        continue

    subject_id = os.path.basename(a).split('_')[-1]

    # The affine and header of the last loaded file are reused for both outputs.
    # NOTE(review): that header's stored data type may differ from the float
    # arrays written here - confirm nibabel's on-save casting is acceptable.
    newDataArrays = nib.Nifti1Image(np.asarray(arrays), nifti_file.affine, nifti_file.header)
    nib.save(newDataArrays, './train/BraTS2021_'+ subject_id + '.nii')

    newDataSegmen = nib.Nifti1Image(np.asarray(segmen), nifti_file.affine, nifti_file.header)
    nib.save(newDataSegmen, './label/BraTS2021_'+ subject_id + '.nii')

### Visualize Test 

In [None]:
# Sanity check on one raw input file: report its shape and show the slice at
# index 145 along the last axis.
dataTest = './training_data/BraTS2021_00495/BraTS2021_00495_t1.nii.gz'
img1 = nib.load(dataTest)
# A bare expression in the middle of a cell is evaluated and discarded, so the
# shape is printed explicitly.
print(np.shape(img1.dataobj))
plt.imshow(img1.dataobj[:,:,145])

In [None]:
# Sanity check on one pre-processed file: report its shape and show one slice
# of the volume stored at index 1 of the first axis. The pre-processing cell
# pads 50 voxels in front of the last axis, so index 195 here lines up with
# index 145 of the unpadded input.
dataTest = './train/BraTS2021_00495.nii'
img2 = nib.load(dataTest)
# A bare expression in the middle of a cell is evaluated and discarded, so the
# shape is printed explicitly.
print(np.shape(img2.dataobj))
plt.imshow(img2.dataobj[1,:,:,195])

### Dictionary

In [None]:
# List the pre-processed image and label files. Both listings are sorted so
# that entry i of one lines up with entry i of the other; the pre-processing
# cell writes each subject under the same file name in both folders.
image_train = sorted(glob.glob('./train/*'))
label_train = sorted(glob.glob('./label/*'))

# zip() would silently drop the surplus files and could pair an image with the
# wrong label, so a count mismatch is reported instead.
if len(image_train) != len(label_train):
    raise ValueError(
        f"Found {len(image_train)} files in ./train but "
        f"{len(label_train)} in ./label; every image needs exactly one label."
    )

# One record per subject: path of the stacked image and path of its label.
trainDict = [
    {"images": image_path, "label": label_path}
    for image_path, label_path in zip(image_train, label_train)
]

### Datasets Class

In [None]:
class DatasetsMRI(torch_data.Dataset):
    """Map-style dataset over paired image / label NIfTI files.

    Args:
        data_root: Indexable, sized collection of records. Each record is a
            mapping with an ``"images"`` entry and a ``"label"`` entry holding
            file paths (the layout of ``trainDict``).
        transform: Optional callable applied to the loaded image tensor.
        target_transform: Optional callable applied to the loaded label tensor.
    """

    def __init__(self, data_root, transform=None, target_transform=None):
        super().__init__()
        self.data_root = data_root
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of records in ``data_root``."""
        return len(self.data_root)

    @staticmethod
    def _load_volume(file_path):
        """Read the NIfTI file at *file_path* into a float32 tensor."""
        # astype() returns a fresh, writable array, which torch.from_numpy needs.
        voxels = np.asarray(nib.load(file_path).dataobj).astype(np.float32)
        return torch.from_numpy(voxels)

    def __getitem__(self, index: int) -> tuple:
        """Load and return the ``(image, label)`` pair stored at *index*.

        Raises whatever ``data_root[index]`` raises for an invalid index, and
        ``KeyError`` if the record lacks the ``"images"`` or ``"label"`` entry.
        """
        record = self.data_root[index]
        image = self._load_volume(record["images"])
        label = self._load_volume(record["label"])

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

In [None]:
# Build the dataset over the image/label path records in trainDict. No
# transforms are passed, so transform and target_transform default to None.
testDatasets = DatasetsMRI(trainDict)

### Mixed Code (unpairing)

In [None]:
# Scale the volume by its own maximum, then report the resulting shape and peak.
# NOTE(review): `n2_img` is not defined in the visible cells - presumably a
# loaded nibabel image; confirm.
volume = n2_img.get_fdata()
normalized = volume.astype(np.float32) / volume.max()
print(normalized.shape, np.max(normalized), sep="\n")

In [None]:
# Pick one sample and draw, at index 60 of the last axis, every channel of its
# image (4 channels, grayscale) and then of its label (3 channels).
# NOTE(review): `val_ds` is not defined in the visible cells - confirm its origin.
val_data_example = val_ds[2]

# (dictionary key / figure name, channel count, figure size, extra imshow options)
panels = (
    ("image", 4, (24, 6), {"cmap": "gray"}),
    ("label", 3, (18, 6), {}),
)

for key, n_channels, fig_size, imshow_options in panels:
    print(f"{key} shape: {val_data_example[key].shape}")
    plt.figure(key, fig_size)
    for channel in range(n_channels):
        plt.subplot(1, n_channels, channel + 1)
        plt.title(f"{key} channel {channel}")
        plt.imshow(
            val_data_example[key][channel, :, :, 60].detach().cpu(),
            **imshow_options,
        )
    plt.show()