In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import gzip, shutil
import nibabel as nib

import torch.nn.functional as F
from torch import nn as nn
from torch.autograd import Variable
from torch.nn import MSELoss, SmoothL1Loss, L1Loss

### Data collection

In [None]:
def get_IDs(root='/nobackup/sc19rw/Train', grades=('HGG', 'LGG')):  # change root to correct path
  """Collect patient ids of the form "<grade>/<patient>/<patient>" under `root`.

  For each grade sub-directory of `root` (HGG first, then LGG by default),
  every entry that is itself a directory is treated as one patient folder.
  Plain files inside a grade directory are ignored.

  Args:
    root: directory that contains one sub-directory per grade.
    grades: grade sub-directory names, scanned in this order.

  Returns:
    List of "<grade>/<patient>/<patient>" strings. Within a grade the order
    is whatever os.listdir yields (not sorted).
  """
  ids = []
  for grade in grades:
    grade_dir = os.path.join(root, grade)
    ids.extend(
      grade + "/" + item + "/" + item
      for item in os.listdir(grade_dir)
      if os.path.isdir(os.path.join(grade_dir, item))
    )
  return ids

In [None]:
#MRI_ids = get_IDs()
#random.shuffle(MRI_ids)

In [None]:
#MRI_ids = np.array(MRI_ids)
#np.savez("/content/drive/My Drive/Project Msc/MRI_ids", MRI_ids)

In [None]:
# Load the saved patient-id order (presumably the list written by the commented-out np.savez cell).
# np.load on an .npz returns an NpzFile holding an open file; the `with` block closes it
# once 'arr_0' has been read into memory.
with np.load("/home/home01/sc19rw/MRI_ids.npz") as ids_archive:  # make sure you use the .npz!
    MRI_ids = ids_archive['arr_0']

In [None]:
import pandas as pd
import random


root = '/nobackup/sc19rw/Train/'

# Every path column is root + MRI_id + "<suffix>.nii"; MRI_id has the
# "<grade>/<patient>/<patient>" form produced by get_IDs, so these point at the
# original (un-normalised) per-modality files and the segmentation.
data = {'image_id': MRI_ids}
data.update({
    column: [root + MRI_id + suffix + ".nii" for MRI_id in MRI_ids]
    for column, suffix in (
        ('t1_path', "_t1"),
        ('t1ce_path', "_t1ce"),
        ('flair_path', "_flair"),
        ('t2_path', "_t2"),
        ('seg_path', "_seg"),
    )
})

data_df = pd.DataFrame(data, columns=['image_id', 't1_path', 't1ce_path', 'flair_path', 't2_path', 'seg_path'])

In [None]:
from torch.utils.data import Dataset
import torch

class BRATS_DATA(Dataset):
    """BRATS custom dataset compatible with torch.utils.data.DataLoader.

    Each item is built from one row of ``df`` and stacks the four MRI
    modalities along a new leading channel axis, in the order
    t1, t1ce, flair, t2.

    Args:
        df: DataFrame with columns 'image_id', 't1_path', 't1ce_path',
            'flair_path' and 't2_path'. A 'seg_path' column may be present
            but is not read here.
        transform: optional callable applied to the stacked input tensor
            before it is returned.
    """

    # Column order defines the channel order of the returned tensor.
    MODALITY_COLUMNS = ('t1_path', 't1ce_path', 'flair_path', 't2_path')

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(input_tensor, affines, MRI_id)`` for the row at position ``index``.

        input_tensor: float64 tensor of shape (4, *volume_shape), holding the
            data returned by each image's get_fdata(). The transform, if any,
            is applied to it.
        affines: list of the four images' affine matrices, in channel order.
        MRI_id: the row's 'image_id' value.
        """
        # Positional lookup, so the dataset works even if df's index is not 0..n-1.
        row = self.df.iloc[index]
        images = [nib.load(row[column]) for column in self.MODALITY_COLUMNS]
        affines = [image.affine for image in images]

        # np.stack adds the channel axis without assuming a particular volume shape.
        # A hard-coded reshape(1, 240, 240, 155) would silently scramble volumes
        # that have the same voxel count but a different layout.
        volumes = np.stack([image.get_fdata() for image in images], axis=0)
        input_tensor = torch.from_numpy(volumes)

        if self.transform is not None:
            input_tensor = self.transform(input_tensor)

        return input_tensor, affines, row['image_id']

    def __len__(self):
        return len(self.df)

In [None]:
# Wrap every row of data_df, re-indexed 0..n-1, in the dataset.
# The former data_df[:len(data_df)] slice selected all rows and is unnecessary.
dataset = BRATS_DATA(df=data_df.reset_index(drop=True))

In [None]:
# Data loader: one patient per batch, in data_df order, loaded in the main process.
# save_normalised's loop only works with batch size 1.
data_loader = torch.utils.data.DataLoader(dataset, num_workers=0, shuffle=False, batch_size=1)

In [None]:
import statistics 

def normalise_MRI(modality, dimx=240, dimy=240, dimz=155, verbose=True):  # ONE OFF THEN SAVE MRIs do this on full MRIS
  """Z-score normalise the non-zero voxels of one MRI modality, in place.

  Only voxels inside modality[:dimx, :dimy, :dimz] are considered. Voxels equal
  to 0 (presumably background outside the brain) are left untouched. Every
  non-zero voxel v becomes (v - mean) / stddev, where mean and the sample
  (n - 1) standard deviation are computed over the non-zero voxels only.
  Each modality is normalised independently.

  Args:
    modality: 3-D torch tensor. It is modified in place.
    dimx, dimy, dimz: extent of the region to normalise.
    verbose: if True, print torch.max of the normalised tensor.

  Returns:
    ``modality`` itself (the same tensor object), after normalisation.

  Raises:
    statistics.StatisticsError: if the region has fewer than two non-zero
      voxels, so no sample standard deviation exists.
  """
  # Slicing gives a view, so the masked assignment below writes straight into `modality`.
  region = modality[:dimx, :dimy, :dimz]
  mask = region != 0
  values = region[mask]
  if values.numel() < 2:
    raise statistics.StatisticsError(
      "normalise_MRI needs at least two non-zero voxels, got %d" % values.numel())

  # Statistics are computed in float64, close to the Python-float arithmetic of a
  # per-voxel statistics.mean/stdev. Tensor.std() defaults to the sample (n - 1)
  # estimator, matching statistics.stdev.
  mean = values.double().mean().item()
  stddev = values.double().std().item()
  region[mask] = ((values - mean) / stddev).to(region.dtype)

  if verbose:
    print(torch.max(modality))
  return modality

In [None]:
def save_normalised(loader=None, start_index=284, out_root='/nobackup/sc19rw/Train/'):
  """Normalise each patient's four modalities with normalise_MRI and save them as NIfTI files.

  For every batch yielded by ``loader`` whose position is >= ``start_index``,
  channel i of the single patient in the batch is passed through
  normalise_MRI. The result is written, with the affine of that modality, to
  ``out_root + MRI_ID + '_<modality>_norm.nii'``, where <modality> is t1, t1ce,
  flair or t2 for channels 0..3. The position of each processed batch is
  printed.

  Args:
    loader: iterable of (input_tensor, affines, MRI_ID) batches. If None, the
      module-level ``data_loader`` is used.
    start_index: first batch position to process. The default of 284 matches
      the previous hard-coded ``patient > 283`` skip (presumably used to
      resume an interrupted run).
    out_root: prefix prepended to each MRI_ID to build the output path.

  Raises:
    ValueError: if a processed batch holds more than one patient. Only
      batch_size=1 is supported.
  """
  if loader is None:
    loader = data_loader
  for patient, (input_tensor, affines, MRI_ID) in enumerate(loader):
    if patient < start_index:
      continue
    if input_tensor.shape[0] != 1:
      raise ValueError("save_normalised only supports batch_size=1, got %d" % input_tensor.shape[0])
    # The channel order matches the order in which BRATS_DATA stacks the modalities.
    for i, modality_name in enumerate(('t1', 't1ce', 'flair', 't2')):
      normalised = normalise_MRI(input_tensor[0][i])
      ni_img = nib.Nifti1Image(normalised.numpy(), affines[i].reshape(4, 4))
      nib.save(ni_img, out_root + str(MRI_ID[0]) + '_' + modality_name + '_norm.nii')
    print(patient)

In [None]:
# One-off preprocessing run: writes *_norm.nii files for each patient at loader
# position >= 284 (save_normalised skips positions <= 283) and prints every
# processed position, as seen in the output below.
save_normalised()

tensor(3.5815, dtype=torch.float64)
tensor(12.3698, dtype=torch.float64)
tensor(10.3593, dtype=torch.float64)
tensor(5.3255, dtype=torch.float64)
284


### Testing: reload the normalised volumes and inspect their maximum values

In [None]:
# Reload the same saved patient-id order used for normalisation.
# np.load on an .npz returns an NpzFile holding an open file; the `with` block closes it
# once 'arr_0' has been read into memory.
with np.load("/home/home01/sc19rw/MRI_ids.npz") as ids_archive:  # make sure you use the .npz!
    MRI_ids = ids_archive['arr_0']

In [None]:
import pandas as pd
import random



#random.shuffle(MRI_ids)
root = '/nobackup/sc19rw/Train/'

# Paths to the normalised modality files written by save_normalised().
data = {
    'image_id': MRI_ids,
    't1_path': [root + MRI_id + "_t1_norm"+ ".nii" for MRI_id in MRI_ids],
    't1ce_path': [root + MRI_id + "_t1ce_norm" + ".nii" for MRI_id in MRI_ids],
    'flair_path': [root + MRI_id + "_flair_norm" + ".nii" for MRI_id in MRI_ids],
    't2_path': [root + MRI_id + "_t2_norm" + ".nii" for MRI_id in MRI_ids],
    # save_normalised() only writes the t1/t1ce/flair/t2 *_norm.nii files; no
    # *_seg_norm.nii is ever produced. Point at the original segmentation file instead.
    'seg_path': [root + MRI_id + "_seg" + ".nii" for MRI_id in MRI_ids],
}

data_df = pd.DataFrame(data, columns=['image_id', 't1_path', 't1ce_path', 'flair_path', 't2_path', 'seg_path'])

In [None]:
from torch.utils.data import Dataset
import torch

class BRATS_DATA(Dataset):
    """BRATS custom dataset compatible with torch.utils.data.DataLoader.

    Each item is built from one row of ``df`` and stacks the four MRI
    modalities along a new leading channel axis, in the order
    t1, t1ce, flair, t2.

    Args:
        df: DataFrame with columns 'image_id', 't1_path', 't1ce_path',
            'flair_path' and 't2_path'. A 'seg_path' column may be present
            but is not read here.
        transform: optional callable applied to the stacked input tensor
            before it is returned.
    """

    # Column order defines the channel order of the returned tensor.
    MODALITY_COLUMNS = ('t1_path', 't1ce_path', 'flair_path', 't2_path')

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(input_tensor, affines, MRI_id)`` for the row at position ``index``.

        input_tensor: float64 tensor of shape (4, *volume_shape), holding the
            data returned by each image's get_fdata(). The transform, if any,
            is applied to it.
        affines: list of the four images' affine matrices, in channel order.
        MRI_id: the row's 'image_id' value.
        """
        # Positional lookup, so the dataset works even if df's index is not 0..n-1.
        row = self.df.iloc[index]
        images = [nib.load(row[column]) for column in self.MODALITY_COLUMNS]
        affines = [image.affine for image in images]

        # np.stack adds the channel axis without assuming a particular volume shape.
        # A hard-coded reshape(1, 240, 240, 155) would silently scramble volumes
        # that have the same voxel count but a different layout.
        volumes = np.stack([image.get_fdata() for image in images], axis=0)
        input_tensor = torch.from_numpy(volumes)

        if self.transform is not None:
            input_tensor = self.transform(input_tensor)

        return input_tensor, affines, row['image_id']

    def __len__(self):
        return len(self.df)

In [None]:
# Wrap every row of the normalised-path data_df, re-indexed 0..n-1, in the dataset.
# The former data_df[:len(data_df)] slice selected all rows and is unnecessary.
dataset = BRATS_DATA(df=data_df.reset_index(drop=True))

In [None]:
# Data loader: one patient per batch, in data_df order, loaded in the main process.
data_loader = torch.utils.data.DataLoader(dataset, num_workers=0, shuffle=False, batch_size=1)

In [None]:
# Sanity check on the normalised files: for each batch print the MRI_ID tuple,
# then the largest value across all four stacked modalities.
for patient, (input_tensor, affines, MRI_ID) in enumerate(data_loader):
  print(MRI_ID, torch.max(input_tensor), sep="\n")

('HGG/Brats18_TCIA02_117_1/Brats18_TCIA02_117_1',)
tensor(7.6134, dtype=torch.float64)
('LGG/Brats18_TCIA10_625_1/Brats18_TCIA10_625_1',)
tensor(11.9287, dtype=torch.float64)
('HGG/Brats18_TCIA08_280_1/Brats18_TCIA08_280_1',)
tensor(14.8424, dtype=torch.float64)
('HGG/Brats18_CBICA_BFB_1/Brats18_CBICA_BFB_1',)
tensor(12.6173, dtype=torch.float64)
('HGG/Brats18_TCIA08_242_1/Brats18_TCIA08_242_1',)
tensor(13.9873, dtype=torch.float64)
('HGG/Brats18_TCIA02_608_1/Brats18_TCIA02_608_1',)
tensor(12.9000, dtype=torch.float64)
('HGG/Brats18_CBICA_AQR_1/Brats18_CBICA_AQR_1',)
tensor(11.9386, dtype=torch.float64)
('HGG/Brats18_TCIA01_460_1/Brats18_TCIA01_460_1',)
tensor(16.9064, dtype=torch.float64)
('HGG/Brats18_TCIA02_471_1/Brats18_TCIA02_471_1',)
tensor(12.0586, dtype=torch.float64)
('HGG/Brats18_2013_7_1/Brats18_2013_7_1',)
tensor(11.5081, dtype=torch.float64)
('HGG/Brats18_CBICA_AVV_1/Brats18_CBICA_AVV_1',)
tensor(7.7101, dtype=torch.float64)
('HGG/Brats18_CBICA_ABM_1/Brats18_CBICA_ABM_1',)