In [None]:
import numpy as np
import nibabel as nib
import cv2
import matplotlib.pyplot as plt 
import pandas as pd
from keras.utils import to_categorical
import torch 
from torch import optim,nn 
import torch.nn.functional as F
from torchvision import datasets,transforms as T
from torch.utils.data import DataLoader,TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

In [None]:
# Root folder of the dataset. Paths are built below by plain string
# concatenation (DATA_DIR + "imagesTr/..."), so the trailing slash is required.
HOME_DIR = 'Task01_BrainTumour/'
# Directory the loaders read from; currently the dataset root itself.
DATA_DIR = HOME_DIR

def load_case(image_nifty_file, label_nifty_file):
    """Read one scan and its segmentation mask from a pair of NIfTI files.

    Args:
        image_nifty_file: path of the image volume.
        label_nifty_file: path of the matching label volume.

    Returns:
        Tuple ``(image, label)`` of NumPy arrays holding the voxel data that
        nibabel's ``get_fdata()`` returns for each file.
    """
    volumes = []
    # Image first, then label -- same read order as the return order.
    for path in (image_nifty_file, label_nifty_file):
        volumes.append(np.array(nib.load(path).get_fdata()))

    image, label = volumes
    return image, label

In [None]:
def get_labeled_image(image, label,is_categorical = False):
    """Build a 3-channel overlay of the segmentation on the first image channel.

    Args:
        image: 4-D array indexed as ``[x, y, z, channel]``; only channel 0 is
            used.
        label: volume of class indices, or -- when ``is_categorical`` is True --
            an already one-hot volume whose last axis starts with the
            background class and is followed by at least three more classes.
        is_categorical: pass True to skip the one-hot encoding step.

    Returns:
        A new array shaped like ``label[:, :, :, 1:]`` after encoding.
        Background voxels hold the min-max normalized (0-255) intensity in
        channels 0-2; every non-background voxel gets 255 added in the channel
        of its class.
    """
    if is_categorical:
        one_hot = label
    else:
        one_hot = to_categorical(label, num_classes = 4).astype(np.uint8)

    # Rescale channel 0 to the 0-255 range (float32 output).
    scaled = cv2.normalize(image[:,:,:,0], None, alpha = 0, beta = 255,
                           norm_type = cv2.NORM_MINMAX, dtype = cv2.CV_32F)

    background = one_hot[:,:,:,0]
    foreground = one_hot[:,:,:,1:]

    # The overlay takes the dtype of the one-hot label, so the float
    # intensities are cast on assignment.
    overlay = np.zeros_like(foreground)
    masked = scaled * background  # zero wherever a non-background class is set
    for channel in range(3):
        overlay[:,:,:,channel] = masked

    # Paint each labeled voxel at full intensity in its class channel.
    return overlay + foreground*255

In [None]:
# Load the first training case and turn it into a color overlay for a quick
# visual sanity check. Note that `image` is rebound to the overlay here.
image,label = load_case(DATA_DIR + "imagesTr/BRATS_001.nii.gz", DATA_DIR + "labelsTr/BRATS_001.nii.gz")
image = get_labeled_image(image,label)

# Display index 77 along the third axis of the overlay.
plt.imshow(image[:,:,77])

In [None]:
def get_sub_volume(image,label,orig_x = 240,orig_y = 240,orig_z = 155,
                  output_x = 160, output_y = 160, output_z = 16,num_classes = 4,max_tries = 1000,background_threshold= 0.95):
    """Randomly crop a patch whose background fraction is below a threshold.

    Args:
        image: 4-D array indexed as ``[x, y, z, channel]``.
        label: 3-D array of class indices aligned with ``image``.
        orig_x, orig_y, orig_z: extent of the full volume along each axis.
        output_x, output_y, output_z: extent of the patch to extract.
        num_classes: number of classes used for the one-hot encoding.
        max_tries: maximum number of random crops to attempt.
        background_threshold: a crop is accepted only if the fraction of
            class-0 voxels is strictly below this value.

    Returns:
        ``(X, y)`` for the first accepted crop: ``X`` is a copy of the image
        patch with the channel axis moved to the front, ``y`` is the one-hot
        label patch with the class axis moved to the front and the background
        class removed. If no crop is accepted within ``max_tries`` attempts a
        message is printed and ``None`` is returned implicitly.
    """
    tries = 0

    while tries < max_tries:

        # Draw the patch corner so that the whole patch fits in the volume.
        corner_x = np.random.randint(0, orig_x-output_x+1)
        corner_y = np.random.randint(0, orig_y-output_y+1)
        corner_z = np.random.randint(0, orig_z-output_z+1)

        crop = (slice(corner_x, corner_x + output_x),
                slice(corner_y, corner_y + output_y),
                slice(corner_z, corner_z + output_z))

        one_hot = to_categorical(label[crop], num_classes)

        # Share of voxels in the crop that belong to class 0 (background).
        bgrd_ratio = np.sum(one_hot[:,:,:,0])/(output_x*output_y*output_z)

        tries += 1

        if not bgrd_ratio < background_threshold:
            continue

        # Channels/classes first; drop the background class from the target.
        X = np.moveaxis(np.copy(image[crop + (slice(None),)]), 3, 0)
        y = np.moveaxis(one_hot, 3, 0)[1:, :, :, :]

        return X,y

    print(f"Tried {tries} times to find a sub-volume. Giving up...")

In [None]:
def standardize(image):
    """Standardize every (channel, slice) plane to zero mean and unit variance.

    Args:
        image: 4-D array indexed as ``[channel, x, y, z]``.

    Returns:
        A new float array of the same shape. Each ``image[c, :, :, z]`` plane
        has its mean subtracted and is divided by its standard deviation;
        planes whose standard deviation is zero are left as zeros. The input
        is not modified.
    """
    result = np.zeros(image.shape)

    for c, volume in enumerate(image):
        for z in range(volume.shape[2]):
            plane = volume[:,:,z]
            deviation = plane - np.mean(plane)
            spread = np.std(deviation)

            # A constant plane has nothing to scale; keep the zero fill.
            if spread != 0:
                result[c,:,:,z] = deviation/spread

    return result

In [None]:
def create_dataset(DATA_DIR,num_patches = 85,case_id_high = 480):
    """Sample random training cases and cut one standardized patch from each.

    Args:
        DATA_DIR: dataset root, ending with a path separator, that contains
            ``imagesTr/`` and ``labelsTr/`` with files named
            ``BRATS_<zero-padded id>.nii.gz``.
        num_patches: number of case ids to draw (with replacement); this is
            the maximum number of patches returned.
        case_id_high: exclusive upper bound of the drawn case ids, i.e. ids
            are sampled from ``[1, case_id_high)``.

    Returns:
        ``(image_data, label_data)``: two parallel lists holding the
        standardized image patch and the label patch of every case that
        produced one. Cases for which ``get_sub_volume`` finds no patch, or
        whose patch extraction/standardization fails, are skipped, so the
        lists can be shorter than ``num_patches``.
    """
    case_ids = np.random.randint(low = 1, high = case_id_high, size = num_patches)
    image_data = []
    label_data = []

    for case_id in case_ids:
        # Ids are zero-padded to three digits in the file names.
        file_name = f"{int(case_id):03d}.nii.gz"

        image,label = load_case(DATA_DIR + "imagesTr/BRATS_" + file_name,
                                DATA_DIR + "labelsTr/BRATS_" + file_name)

        try:
            sub_volume = get_sub_volume(image,label)
            if sub_volume is None:
                # No patch under the background threshold was found.
                continue
            X,y = sub_volume
            X = standardize(X)
        except Exception as exc:
            # Stay best-effort, but report the failure and let
            # KeyboardInterrupt / SystemExit propagate.
            print(f"Skipping BRATS_{file_name}: {exc!r}")
            continue

        image_data.append(X)
        label_data.append(y)

    return image_data,label_data
    

In [None]:
# Draw 150 random cases; the returned lists may be shorter because cases
# without a usable patch are skipped inside create_dataset.
image_data,label_data = create_dataset(DATA_DIR,num_patches = 150)

In [None]:
# Collapse each list of equally shaped patches into one float32 ndarray before
# handing it to torch: building a tensor directly from a Python list of
# ndarrays converts element by element and is extremely slow.
# The image tensor is bound to `images_data` (plural); `image_data` stays the
# original list.
images_data = torch.from_numpy(np.asarray(image_data, dtype = np.float32))
label_data = torch.from_numpy(np.asarray(label_data, dtype = np.float32))

In [None]:
# Swap axes 2 and 4 of the (patch, channel, x, y, z) tensors.
# The image tensor lives in `images_data` (plural) -- `image_data` is still the
# Python list returned by create_dataset, and `images_data` is what gets saved.
# NOTE: re-running this cell swaps the axes back again.
images_data = torch.transpose(images_data,4,2)
label_data = torch.transpose(label_data,4,2)

In [None]:
# Persist the image and label tensors to the current working directory.
torch.save(images_data,'X.pt')
torch.save(label_data,'y.pt')