### Setup Install

In [1]:
!pip install nibabel 
!pip install glob2
!pip install monai



### Setup Import

In [2]:
import glob
import os

import matplotlib.pyplot as plt
import monai.transforms as transforms
import nibabel as nib
import numpy as np
import torch
from torch.utils import data as torch_data
from torch.utils.data import Dataset, DataLoader

  "class": algorithms.Blowfish,


### Data Pre-processing
Classify each case's files by name into MRI modalities and segmentation, then save them as paired image/label datasets.

In [None]:
paths = './training_data'
# Sorted so that the `max_dir` cap always selects the same cases; glob order
# is otherwise arbitrary.
path = sorted(glob.glob(paths + '/*'))
types = ["flair", "t1", "t1ce", "t2"]
max_dir = 30
# Zero padding (before, after) for each of the three spatial axes.
# NOTE(review): presumably chosen to pad 240x240x155 volumes to 256^3 — confirm.
pad_width = ((8, 8), (8, 8), (50, 51))
train_dir = './train'
label_dir = './label'

# nib.save does not create missing parent folders.
os.makedirs(train_dir, exist_ok=True)
os.makedirs(label_dir, exist_ok=True)


def load_padded_volume(nifti_path):
    """Load one NIfTI file, zero-pad it with `pad_width` and divide by its maximum.

    Returns a `(nifti_file, volume)` pair: the nibabel image (its affine and
    header are reused when saving) and the padded, scaled array. A volume
    whose maximum is 0 is returned as zeros instead of being divided by zero
    (which would fill it with NaN).
    """
    nifti_file = nib.load(nifti_path)
    padded = np.pad(np.asarray(nifti_file.dataobj), pad_width, 'constant')
    peak = np.max(padded)
    if peak != 0:
        return nifti_file, padded / peak
    return nifti_file, padded.astype(np.float64)


for a in path[:max_dir]:
    # Start from empty lists for every case; otherwise each saved file would
    # also contain the volumes of all previously processed cases.
    arrays = []
    segmen = []
    nifti_file = None

    # Sorted so that the stacked channel order is identical for every case.
    for modality in sorted(glob.glob(a + '/*')):
        # The file kind is the token after the last underscore, without the
        # extension (e.g. "..._t1.nii.gz" -> "t1"); this does not depend on
        # how many underscores the parent folders contain.
        typefiles = modality.rsplit('_', 1)[-1].split('.')[0]
        nifti_file, dataPad = load_padded_volume(modality)
        if typefiles in types:
            arrays.append(dataPad)
        else:
            # Anything that is not a listed modality is treated as the label.
            # NOTE(review): the label volume is also divided by its maximum,
            # which rescales the label values — confirm this is intended.
            segmen.append(dataPad)

    if not arrays or not segmen:
        print('Skipping ' + a + ': no modality or no label volume found')
        continue

    case_id = a.rsplit('_', 1)[-1]

    # Affine and header are taken from the last file loaded for this case.
    newDataArrays = nib.Nifti1Image(np.asarray(arrays), nifti_file.affine, nifti_file.header)
    nib.save(newDataArrays, train_dir + '/BraTS2021_' + case_id + '.nii')

    newDataSegmen = nib.Nifti1Image(np.asarray(segmen), nifti_file.affine, nifti_file.header)
    nib.save(newDataSegmen, label_dir + '/BraTS2021_' + case_id + '.nii')

### Visualize Test 

In [None]:
# Sanity-check one raw (un-padded) T1 volume straight from the training data.
dataTest = './training_data/BraTS2021_00495/BraTS2021_00495_t1.nii.gz'
img1 = nib.load(dataTest)
# A bare expression in the middle of a cell is discarded, so print the shape.
print(np.shape(img1.dataobj))
plt.imshow(img1.dataobj[:,:,145])
plt.title('Raw T1, slice 145 along the last axis')
plt.show()

In [None]:
# Sanity-check the stacked, padded file written by the pre-processing cell.
dataTest = './train/BraTS2021_00495.nii'
img2 = nib.load(dataTest)
# A bare expression in the middle of a cell is discarded, so print the shape.
print(np.shape(img2.dataobj))
# Slice 195 = raw slice 145 shifted by the 50 leading slices of padding on the
# last axis; index 1 on the first axis selects the second stacked volume.
plt.imshow(img2.dataobj[1,:,:,195])
plt.title('Stacked volume, channel 1, slice 195 along the last axis')
plt.show()

### Data Dictionary

In [None]:
image_train = sorted(glob.glob('./train/*'))
label_train = sorted(glob.glob('./label/*'))

# zip() silently truncates to the shorter list and pairs purely by position,
# so verify that both folders hold the same file names before pairing.
assert len(image_train) == len(label_train), (
    f"{len(image_train)} image files but {len(label_train)} label files")
mismatched_pairs = [
    (image_path, label_path)
    for image_path, label_path in zip(image_train, label_train)
    if os.path.basename(image_path) != os.path.basename(label_path)
]
assert not mismatched_pairs, f"image/label file names differ: {mismatched_pairs[:3]}"

# One record per case: path of the stacked image file and path of its label file.
trainDict = [
    {"images": image_path, "label": label_path}
    for image_path, label_path in zip(image_train, label_train)
]

In [None]:
# Show how many pairs were found and a short preview rather than the full list.
print(f"{len(trainDict)} image/label pairs")
trainDict[:3]

### Datasets Class

In [None]:
def LoadNifti(path):
    """Read the NIfTI file at `path` and return its voxel data as a NumPy array.

    Only the data array is returned; the affine and header of the file are
    not used.
    """
    return np.asarray(nib.load(path).dataobj)

In [None]:
class DatasetsMRI(torch_data.Dataset):
    """Map-style dataset over image/label NIfTI path records.

    `data_root` is an indexable collection whose items are mappings with an
    ``'images'`` path and a ``'label'`` path; both files are read with
    `LoadNifti` on access. `transform`, when truthy, is called with a dict
    holding the two loaded arrays and must return a dict with the same keys.
    """

    def __init__(self, data_root, transform=None):
        super().__init__()
        self.transform = transform
        self.data_root = data_root

    def __len__(self):
        """Number of records in `data_root`."""
        return len(self.data_root)

    def __getitem__(self, index:int) -> tuple:
        """Return ``(images, label, index)`` for the record at `index`."""
        record = self.data_root[index]
        sample = {
            "images": LoadNifti(record['images']),
            "label": LoadNifti(record['label']),
        }
        if self.transform:
            sample = self.transform(sample)
        return sample['images'], sample['label'], index

In [None]:
# Both wrappers index the full `trainDict` and apply no transform.
testDatasets = DatasetsMRI(data_root=trainDict)
trainDatasets = DatasetsMRI(data_root=trainDict)

### Class Test

In [None]:
# Smoke test: fetch the first sample and look at one slice of its first channel.
first_sample = testDatasets[0]
img, lbl, idx = first_sample
print(np.shape(img))
first_channel = img[0]
plt.imshow(first_channel[:, :, 100])

### Data Loader

In [None]:
def train_transform():
    """Build the training pipeline applied to ``{'images', 'label'}`` dicts.

    Active steps, in order: `RandGaussianSmoothd` on ``'images'`` (prob=0.3,
    sigma range 0.5-1.5 on each axis), then `EnsureTyped` on both keys.
    """
    # Disabled options, kept for reference only (not part of the pipeline):
    #   spatial:   transforms.RandFlipd(keys="images", prob=1, spatial_axis=0 | 1 | 2)
    #   intensity: transforms.RandGaussianNoised(keys='image', prob=0.15, mean=0.0, std=0.2)
    #              transforms.RandAdjustContrastd(keys='image', prob=0.15, gamma=(0.7, 1.3))
    # NOTE(review): the two intensity options use keys='image', whereas the
    # dataset dict key is 'images' — fix the key before enabling them.
    smoothing = transforms.RandGaussianSmoothd(
        keys='images',
        prob=0.3,
        sigma_x=(0.5, 1.5),
        sigma_y=(0.5, 1.5),
        sigma_z=(0.5, 1.5),
    )
    typing_step = transforms.EnsureTyped(keys=["images", 'label'])
    return transforms.Compose([smoothing, typing_step])
def test_transform():
    
    infer_transform = [transforms.EnsureTyped(keys=["images", 'label'])]
    return transforms.Compose(infer_transform)

In [None]:
def get_train_loader(case_names, batch_size=1, num_workers=1, pin_memory=True):
    """Return a shuffling DataLoader over `case_names` with the training transforms.

    Args:
        case_names: indexable collection of ``{'images': path, 'label': path}``
            records, passed to `DatasetsMRI` as `data_root`.
        batch_size: samples per batch (default 1, the previously fixed value).
        num_workers: loader worker processes (default 1, as before).
        pin_memory: forwarded to `DataLoader` (default True, as before).

    Returns:
        A `DataLoader` with ``shuffle=True`` whose batches come from
        `DatasetsMRI.__getitem__`.
    """
    train_dataset = DatasetsMRI(
        data_root=case_names,
        transform=train_transform())

    return DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                      num_workers=num_workers, pin_memory=pin_memory)

def get_test_loader(case_names, batch_size=1, num_workers=1, pin_memory=True):
    """Return a non-shuffling DataLoader over `case_names` with the test transforms.

    Args:
        case_names: indexable collection of ``{'images': path, 'label': path}``
            records, passed to `DatasetsMRI` as `data_root`.
        batch_size: samples per batch (default 1, the previously fixed value).
        num_workers: loader worker processes (default 1, as before).
        pin_memory: forwarded to `DataLoader` (default True, as before).

    Returns:
        A `DataLoader` with ``shuffle=False`` whose batches come from
        `DatasetsMRI.__getitem__`.
    """
    test_dataset = DatasetsMRI(
        data_root=case_names,
        transform=test_transform())

    return DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                      num_workers=num_workers, pin_memory=pin_memory)

In [None]:
# Hold out 30% of the cases (rounded down) for testing.
dataset_size = len(trainDict)
test_size = int(0.3 * dataset_size)
train_size = dataset_size - test_size

# Seeded generator so the train/test partition is identical on every run;
# without it random_split draws from the global RNG and the split changes.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = torch.utils.data.random_split(
    trainDict, [train_size, test_size], generator=split_generator)
assert len(train_dataset) + len(test_dataset) == dataset_size

train_loader = get_train_loader(train_dataset)
test_loader = get_test_loader(test_dataset)
print(f"test cases: {test_size}")
print(f"train cases: {train_size}")

In [None]:
# Test use RandFlip
# NOTE(review): train_transform() does not currently apply RandFlipd, so this
# cell shows the pipeline without flipping unless those lines are enabled.
# This also rebinds `train_dataset`, the name the split cell assigns.
train_dataset = DatasetsMRI(trainDict, train_transform())
flip_sample = train_dataset[0]
img, lbl, idx = flip_sample
print(np.shape(img))
plt.imshow(img[0, :, 100, :])

In [None]:
# Test unused RandFlip
# Wraps the full `trainDict` with the training transforms and shows the slice
# at index 100 of the third axis for channel 0 of the first sample.
augmentations = train_transform()
train_dataset = DatasetsMRI(data_root=trainDict, transform=augmentations)
img, lbl, idx = train_dataset[0]
print(np.shape(img))
channel_zero = img[0]
plt.imshow(channel_zero[:, 100, :])

### Other Code (unpairing)

In [None]:
# NOTE(review): `n2_img` is not defined in the cells visible here; it must be
# loaded before this cell can run — confirm where it comes from.
# Scale the volume by its maximum so the largest value becomes 1.
n2_data = n2_img.get_fdata()
normalized = n2_data.astype(np.float32) / n2_data.max()
print(normalized.shape)
print(np.max(normalized))

In [None]:
# Pick one sample from `val_ds` and check its 4 image channels.
# NOTE(review): `val_ds` is not defined in the cells visible here (the original
# comment calls it a DecathlonDataset) — confirm where it is created.
val_data_example = val_ds[2]
image_volume = val_data_example["image"]
print(f"image shape: {image_volume.shape}")
plt.figure("image", (24, 6))
for channel in range(4):
    plt.subplot(1, 4, channel + 1)
    plt.title(f"image channel {channel}")
    image_slice = image_volume[channel, :, :, 60]
    plt.imshow(image_slice.detach().cpu(), cmap="gray")
plt.show()

# Also show the 3 label channels that belong to the same sample.
label_volume = val_data_example["label"]
print(f"label shape: {label_volume.shape}")
plt.figure("label", (18, 6))
for channel in range(3):
    plt.subplot(1, 3, channel + 1)
    plt.title(f"label channel {channel}")
    label_slice = label_volume[channel, :, :, 60]
    plt.imshow(label_slice.detach().cpu())
plt.show()