# Things to keep in mind!!!!
- Patients no. 109 (no FLAIR), **123 (no T1w)** and 709 (no FLAIR) should be excluded!
- Choice: **zero-mean normalization & 3D normalization** (following Sangwook's advice)
- For the 1st trial, a single modality will be used: **"T1w"**
- Data augmentation: 4*3 per patient (4 patches for each of 3 augmentations: flipped, high-sigma deformation, very weak Gaussian noise).
> In total, **7592 images** will be used as the train set (= (4*3+1)*(585-1)).

# 0. Load required libraries & csv files.

In [None]:
# Install required packages
! pip install natsort

In [None]:
import pandas as pd 
import matplotlib.pyplot as plt

import os 
import cv2
import numpy as np
import glob
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

from torch.utils.data import Dataset #, DataLoader
import torch
import SimpleITK as sitk

from skimage.transform import resize
# to sort file names by its order
from natsort import natsorted
%matplotlib inline

# To get a center of mass for an image (Later, will get patch images from the com.)
from skimage import filters
from skimage.measure import regionprops

# Global variables
# Modality sub-directory names inside each patient folder (load_imgs globs one directory per name).
# The order here is also the channel order of the arrays built by PreprocessedImage2Dver.
modalities = ['FLAIR', 'T1w', 'T1wCE', 'T2w']

In [None]:
# Load y: training labels (BraTS21ID, MGMT_value) and the submission template, which lists the test patients.
csv_dir = '../input/rsna-miccai-brain-tumor-radiogenomic-classification'
# csv_dir = '../png_data'  # local copy holding the same two CSV files
train_labels = pd.read_csv(os.path.join(csv_dir, 'train_labels.csv'))
preds_labels = pd.read_csv(os.path.join(csv_dir, 'sample_submission.csv'))

print(f'Number of patients = {len(train_labels)}')
train_labels.head()

# 1. Create "DataGenerator"

In [None]:
# Reference: https://www.kaggle.com/ayuraj/brain-tumor-eda-and-interactive-viz-with-w-b
def _min_max_to_uint8(data):
    """Linearly rescale *data* to 0-255 and cast it to uint8.

    A constant image (e.g. an all-black slice) has zero intensity range. It is
    mapped to all zeros; dividing by zero would instead produce NaN and an
    undefined uint8 cast.
    """
    data = np.asarray(data, dtype=np.float64)
    data = data - data.min()
    peak = data.max()
    if peak > 0:
        data = data / peak
    return (data * 255).astype(np.uint8)


def ReadMRI(path, voi_lut = True, fix_monochrome = True):
    """Read one DICOM slice and return it as a uint8 image min-max scaled to 0-255.

    Args:
        path: path to a single DICOM file.
        voi_lut: if True, apply the VOI LUT stored in the file (if any) before scaling.
        fix_monochrome: if True, invert MONOCHROME1 images (which store the lowest
            value as white) so that higher values are brighter.

    Returns:
        uint8 array of the slice pixels, scaled to 0-255 (all zeros for a constant slice).
    """
    # Original from: https://www.kaggle.com/raddar/convert-dicom-to-np-array-the-correct-way
    # dcmread is the supported reader; read_file is a deprecated alias removed in pydicom 3.0.
    dicom = pydicom.dcmread(path)

    # VOI LUT (if provided by the DICOM device) transforms raw DICOM data to a
    # "human-friendly" view.
    if voi_lut:
        data = apply_voi_lut(dicom.pixel_array, dicom)
    else:
        data = dicom.pixel_array

    # Depending on this value, the image may look inverted - fix that.
    if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
        data = np.amax(data) - data

    return _min_max_to_uint8(data)

def load_imgs(idx, train=True,
              data_root='../input/rsna-miccai-brain-tumor-radiogenomic-classification'):
    """Load every modality series of one patient as stacked 2D slices.

    Args:
        idx: zero-padded patient folder name, e.g. '00002'.
        train: read from the 'train' split if True, else from 'test'.
        data_root: competition data directory containing the 'train'/'test' folders.

    Returns:
        dict mapping each name in ``modalities`` to an array of shape
        (n_slices, H, W) holding the uint8 slices from ReadMRI, in natural file order.
        A modality with no files (pt no. 109/709 have no FLAIR, 123 has no T1w)
        gets a single blank (1, 256, 256) slice instead.
    """
    split = 'train' if train else 'test'
    images_idx = {}
    for modal in modalities:
        pattern = os.path.join(data_root, split, idx, modal, '*')
        # glob returns files in arbitrary order; natsort puts e.g. Image-2 before Image-10.
        file_path_list = natsorted(glob.glob(pattern))

        if not file_path_list:
            # Missing/empty modality directory: the loop below would never run, so the
            # placeholder must be created here.
            print(f'Warning: no {modal} files for patient {idx}; using a blank slice.')
            images_idx[modal] = np.zeros((1, 256, 256), dtype=np.uint8)
            continue

        images_idx[modal] = np.array([ReadMRI(file_path) for file_path in file_path_list])

    return images_idx

In [None]:
# Sanity check on the first patient in train_labels (pt no.00000): load all modalities
# (raw, not registered) and also get a rough feel for how long load_imgs takes.
# NOTE: IPython's %%time must be the very first line of a cell, so it cannot follow comments.
idx = str(train_labels.BraTS21ID[0]).zfill(5)
print(idx)
imgs = load_imgs(idx)
img_ = imgs[modalities[0]]
print(img_.shape)

# Mean intensity projection over the slices, one panel per modality (row-major order).
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
for panel, modal in zip(ax.flat, modalities):
    panel.imshow(imgs[modal].mean(axis=0), cmap='gray')
    panel.set_title(modal)
plt.show()

In [None]:
# 1. Create a data loader with N4BiasFieldCorrectionImageFilter
class PreprocessedImage2Dver(Dataset):
    """Per-patient dataset of N4-corrected, z-score normalised central MRI slices.

    Each item stacks ``2 * num_slices_from_center`` slices around the middle of every
    modality series, resized to ``dim``, into an array of shape
    ``(2 * num_slices_from_center, *dim, n_modals)`` (channels last).

    Args:
        list_BraTS21ID: patient IDs (zero-padded to 5 digits when loading). Indexed
            positionally, so a pandas Series from ``train_test_split`` (which keeps its
            original, non-contiguous index) is handled correctly.
        list_labels: labels aligned with ``list_BraTS21ID``; ``None`` for the test set.
        dim: (height, width) every slice is resized to.
        n_modals: number of modalities used, taken from the front of ``modalities``.
        n_classes: stored but currently unused.
        num_slices_from_center: number of slices taken on each side of the central slice.
        is_train: ignored; train/test mode is inferred from whether labels are given.
        transform: optional callable applied to the image array before it is returned.
    """

    def __init__(self, list_BraTS21ID, list_labels=None,
                 dim=(512, 512), n_modals=len(modalities), n_classes=2,
                 num_slices_from_center:int=3,
                 is_train=True, transform = None): # For single pt.
        self.dim = dim
        # np.asarray drops any pandas index so that __getitem__ is positional.
        self.list_BraTS21ID = np.asarray(list_BraTS21ID)
        self.list_labels = None if list_labels is None else np.asarray(list_labels)
        self.is_train = (list_labels is not None)
        self.n_modals = n_modals # number of modalities
        self.n_classes = n_classes
        self.num_slices_from_center = num_slices_from_center
        self.transform = transform

    def __len__(self):
        return len(self.list_BraTS21ID)

    def __getitem__(self, index):
        X = self.__data_generation(self.list_BraTS21ID[index])
        if self.transform is not None:
            X = self.transform(X)

        if self.is_train:
            return np.array(X), np.array(self.list_labels[index])
        return np.array(X)

    @staticmethod
    def _n4_bias_correct(slice_2d, n_fitting_levels=4, max_iter=100):
        """Return the float32 2D array *slice_2d* with its bias field removed by N4.

        ref: https://www.kaggle.com/josepc/rsna-effnet
        """
        # N4 operates on real-valued images, hence the float32 input.
        input_image = sitk.Cast(sitk.GetImageFromArray(slice_2d), sitk.sitkFloat32)
        # Foreground mask as an unsigned-integer label image. The 0.1 threshold presumably
        # relies on skimage's resize rescaling uint8 input to [0, 1] - confirm if that changes.
        # (An alternative is sitk.OtsuThreshold(input_image, 0, 1, 200).)
        mask_image = sitk.Cast(sitk.GetImageFromArray((slice_2d > 0.1).astype(np.uint8)),
                               sitk.sitkUInt8)
        corrector = sitk.N4BiasFieldCorrectionImageFilter()
        # One maximum-iteration count per fitting level (assumed to match the filter's
        # fitting-level count - confirm against the SimpleITK docs).
        corrector.SetMaximumNumberOfIterations([max_iter] * n_fitting_levels)
        return sitk.GetArrayFromImage(corrector.Execute(input_image, mask_image))

    def __data_generation(self, BraTS21ID_temp):
        n_side = self.num_slices_from_center
        new_imgs = np.zeros((n_side * 2, *self.dim, self.n_modals))

        idx = str(BraTS21ID_temp).zfill(5)
        imgs = load_imgs(idx, train=self.is_train)  # dict: modality name -> (n_slices, H, W)

        for index_modal, modal in enumerate(modalities[:self.n_modals]):
            img_modal = imgs[modal]  # e.g. shape (288, 256, 256)

            if img_modal.shape[0] < n_side * 2:
                # Too few slices (or a missing series): leave this channel as zeros and
                # continue with the remaining modalities instead of dropping them all.
                print('The number of slice is smaller than total number of slices!!')
                continue

            center = img_modal.shape[0] // 2
            # 1. Remove radio-frequency inhomogeneity with N4 on each resized slice.
            corrected = np.array([
                self._n4_bias_correct(
                    np.array(resize(img_modal[slice_i, :, :], self.dim), dtype='float32'))
                for slice_i in range(center - n_side, center + n_side)
            ])

            # 2. Zero-mean normalisation over the whole slice stack of this modality
            #    (normalisation type does affect the models' performance).
            std_3D = np.std(corrected)
            corrected = corrected - np.mean(corrected)
            if std_3D > 0:
                corrected = corrected / std_3D

            new_imgs[:, :, :, index_modal] = corrected

        return new_imgs  # shape = (num_slices_from_center*2, *dim, n_modals): 4D, not 3D.

## Check a sample preprocessed image
### - num_slices_from_center:int=3
### - for "pt no.00002" 
### - for 2D slice image at center

In [None]:
# Stratified 80/20 train/validation split of the labelled patients (random_state = 42).
from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(
    train_labels.BraTS21ID,
    train_labels.MGMT_value,
    test_size=.2,
    random_state=42,
    stratify=train_labels.MGMT_value,
)

dim = (256, 256)
train_set = PreprocessedImage2Dver(X_train, y_train, dim=dim)
val_set = PreprocessedImage2Dver(X_val, y_val, dim=dim)
test_set = PreprocessedImage2Dver(preds_labels.BraTS21ID, dim=dim)

# Image array (element [0]; [1] is the label) of training item 1 - intended to be pt no.00002, verify.
sample_preprocessed_imgs = train_set[1][0]

# Middle slice of the stack (index 3 when num_slices_from_center is 3).
center_slice_idx = sample_preprocessed_imgs.shape[0] // 2
img_at_center = sample_preprocessed_imgs[center_slice_idx]

# One panel per modality channel of the central slice.
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
for channel, panel in enumerate(ax.flat):
    panel.imshow(img_at_center[:, :, channel], cmap='gray')
    panel.set_title(modalities[channel])
plt.show()

In [None]:
# Print a 50 x 30 window (rows 100-149, cols 100-129) of the central slice's channel 0 (FLAIR)
# to inspect the normalised intensity values.
print(img_at_center[100:150, 100:130, 0])

# 2. Do "Data augmentation" 

## 1) Horizontal flipped

In [None]:
def do_hflip_2D(img, axis=1):
    """Mirror an image along *axis* (default: a left-right flip of a 2D image).

    For an (H, W) or channels-last (H, W, C) array, axis 1 is the width axis, so
    the default gives a horizontal flip. For a stack of slices shaped (N, H, W),
    pass ``axis=-1`` instead (axis 1 would flip it upside down).

    Returns a view of *img* (``np.flip`` does not copy).
    """
    return np.flip(img, axis=axis)

## 2) Gaussian noise applied
#### - The standard deviation of the **"sample_T1w"** image is about 0.79
####   (>> print(np.std(sample_T1w)) # 0.78577...),
####   so Gaussian noise with mean=0 & std=0.1 is used in the example below, to keep the noise as small as possible (the augmentation pipeline uses the function default, std=0.05).
#### - Open question: how should the mean & std of the Gaussian noise be chosen?

In [None]:
def apply_gaussian_noise(img, mean:float=0., std:float=.05, rng=None):
    """Return *img* plus i.i.d. Gaussian noise (the input is not modified).

    Args:
        img: array of any shape.
        mean: mean of the noise distribution.
        std: standard deviation of the noise distribution.
        rng: optional ``numpy.random.Generator`` for reproducible augmentation;
            when None, the global ``np.random`` state is used.

    Returns:
        float array with the same shape as *img*.
    """
    sampler = np.random if rng is None else rng
    gaussian_noise = sampler.normal(mean, std, np.shape(img))
    return img + gaussian_noise

## 3) Deformed (high sigma)

In [None]:
# ref : https://www.hj-chung.com/post/elastic-distortion/
def do_elastic_distortion(img, rows, cols, sigma=200., alpha=.5):
    """Warp a 2D image with a small random per-pixel displacement field.

    A displacement vector is drawn from Uniform(-1, 1) for every pixel and smoothed
    with a 7x7 Gaussian kernel of standard deviation *sigma*. It is then normalised
    to unit length and scaled by *alpha*, so each output pixel is bilinearly sampled
    (cv2.remap) from a point exactly *alpha* pixels away.

    Args:
        img: 2D image; rows/cols should match its shape.
        rows, cols: size of the displacement field (and of the output).
        sigma: Gaussian-blur standard deviation. With the fixed 7x7 kernel, large
            values make the blur close to a box filter.
        alpha: displacement magnitude in pixels.

    Returns:
        the warped image of size (rows, cols).
    """
    # Random displacement field, sampled from Unif(-1, 1).
    dx = cv2.GaussianBlur(np.random.uniform(-1, 1, (rows, cols)), (7, 7), sigma)
    dy = cv2.GaussianBlur(np.random.uniform(-1, 1, (rows, cols)), (7, 7), sigma)

    # Normalise each vector to length alpha; a zero-length vector (which would give
    # NaN map entries) leaves its pixel in place.
    norm = np.sqrt(dx ** 2 + dy ** 2)
    scale = np.divide(alpha, norm, out=np.zeros_like(norm), where=norm > 0)

    indy, indx = np.indices((rows, cols), dtype=np.float32)
    map_x = (indx + dx * scale).astype(np.float32)
    map_y = (indy + dy * scale).astype(np.float32)

    return cv2.remap(img, map_x, map_y, cv2.INTER_LINEAR)

In [None]:
# Visual check of the three augmentations on the central T1w slice (channel 1 = 'T1w').
sample_T1w = img_at_center[:,:,1]

# Example1. Flipped image
hflipped_sample_T1w = do_hflip_2D(sample_T1w)
# Example2. Gaussian noise applied image
noised_sample_T1w = apply_gaussian_noise(sample_T1w, std=0.1)
# Example3. Distorted image (elastically deformed image)
distorted_sample_T1w = do_elastic_distortion(sample_T1w, sample_T1w.shape[0], sample_T1w.shape[1])

# One 3x2 figure: the original on the left of each row, the augmented version on the right.
augmented_panels = [
    (hflipped_sample_T1w, 'Flipped T1w image'),
    (noised_sample_T1w, 'Gaussian noise applied T1w image'),
    (distorted_sample_T1w, 'Deformed(Distorted) T1w image'),
]
plt.figure(figsize=(10, 15))
for row, (augmented, title) in enumerate(augmented_panels):
    plt.subplot(3, 2, 2 * row + 1)
    plt.imshow(sample_T1w, cmap='gray')
    plt.title('Original T1w image')

    plt.subplot(3, 2, 2 * row + 2)
    plt.imshow(augmented, cmap='gray')
    plt.title(title)
plt.show()

# 3) Get 2D patch images (4 patches per image)
#### - 4 overlapping patches around the center-of-mass point, each extended by a margin towards one corner.
#### - 2D patch size (h x w) = 100 x 100 (80+20, 80+20)

In [None]:
# TODO(review): an earlier note said "should check whether shape[0] & shape[1] are even or odd";
# unclear how parity matters here - confirm.
def get_center_of_mass_coordinates(img):
    """Return the (row, col) centroid of the Otsu foreground of a 2D image, truncated to int.

    The image is thresholded with Otsu's method. The (unweighted) centroid of the
    foreground mask is returned as a tuple, so it can be unpacked, indexed or
    reused. If no pixel lies above the threshold (e.g. a blank slice), the
    geometric center of the image is returned instead of raising IndexError.
    """
    threshold_value = filters.threshold_otsu(img)
    labeled_foreground = (img > threshold_value).astype(int)
    properties = regionprops(labeled_foreground, img)
    if not properties:
        return (img.shape[0] // 2, img.shape[1] // 2)
    return tuple(int(x) for x in properties[0].centroid)  # float -> int

def _clamp_center(c, reach, size):
    """Clamp coordinate *c* into [reach, size - reach] so that c +/- reach stays inside [0, size]."""
    return max(reach, min(c, size - reach))


def get_2D_patches_per_vertex(img,
                              central_patch_size:tuple=(80,80), margin_size:tuple=(20,20),
                              center=None):
    """Cut four overlapping patches around the center of a 2D image.

    Every patch contains the central ``central_patch_size`` window plus a margin
    extending towards one corner. With the defaults each patch is 100 x 100.

    Args:
        img: 2D array (H, W).
        central_patch_size: (height, width) of the shared central window.
        margin_size: (height, width) of the margin added towards each corner.
        center: optional (row, col) to crop around; defaults to the image's
            center of mass (get_center_of_mass_coordinates).

    Returns:
        (left_top, right_top, left_bottom, right_bottom) patches (views of *img*).
        The center is shifted inwards when needed so that patches never wrap
        around or run off the image border.
    """
    orig_h, orig_w = img.shape
    if center is None:
        center = get_center_of_mass_coordinates(img)
    com_h, com_w = (int(c) for c in center)

    half_h, half_w = central_patch_size[0] // 2, central_patch_size[1] // 2
    margin_h, margin_w = margin_size
    com_h = _clamp_center(com_h, half_h + margin_h, orig_h)
    com_w = _clamp_center(com_w, half_w + margin_w, orig_w)

    top = slice(com_h - half_h - margin_h, com_h + half_h)
    bottom = slice(com_h - half_h, com_h + half_h + margin_h)
    left = slice(com_w - half_w - margin_w, com_w + half_w)
    right = slice(com_w - half_w, com_w + half_w + margin_w)

    patches = (img[top, left], img[top, right], img[bottom, left], img[bottom, right])

    # Warn if any patch has an unexpected size (odd central size or an image that is too small).
    expected = (central_patch_size[0] + margin_h, central_patch_size[1] + margin_w)
    if any(patch.shape != expected for patch in patches):
        print(f'Check the size of patch & margin! Patch shapes {[p.shape for p in patches]} '
              f'do not match {expected[0]} x {expected[1]}!')

    return patches

In [None]:
# Example: crop the four corner patches around the center of mass of sample_T1w.
left_top_patch, right_top_patch, left_bottom_patch, right_bottom_patch = get_2D_patches_per_vertex(sample_T1w)

# Full slice first, for reference.
plt.figure(figsize=(10, 10))
plt.imshow(sample_T1w, cmap='gray')
plt.show()

corner_panels = (
    (left_top_patch, 'Left top cropped image'),
    (right_top_patch, 'Right top cropped image'),
    (left_bottom_patch, 'Left bottom cropped image'),
    (right_bottom_patch, 'Right bottom cropped image'),
)
plt.figure(figsize=(10, 10))
for position, (patch, caption) in enumerate(corner_panels, start=1):
    plt.subplot(2, 2, position)
    plt.imshow(patch, cmap='gray')
    plt.title(caption)
plt.show()

## 4) Combine above augmentation procedures

In [None]:
def get_4x3_augmented_2D_imgs(img, *, noise_std=.05, sigma=200, alpha=.5):
    """Build the 4 x 3 augmented patch set for one 2D image.

    Three augmented copies (horizontal flip, additive Gaussian noise, elastic
    deformation) are each cut into four corner patches by get_2D_patches_per_vertex.

    Args:
        img: 2D image.
        noise_std: standard deviation of the additive Gaussian noise.
        sigma: blur standard deviation passed to do_elastic_distortion.
        alpha: displacement strength passed to do_elastic_distortion.

    Returns:
        dict with
          'original_img': the input image itself (not cropped),
          'hflipped_imgs' / 'gaussian_noised_imgs' / 'deformed_imgs': lists of four
          patches ordered [left top, right top, left bottom, right bottom].
    """
    # NOTE(review): 'original_img' is stored whole while the augmented entries are
    # patches - confirm this is intended.
    hflipped_img = do_hflip_2D(img)
    gaussian_noise_img = apply_gaussian_noise(img, std=noise_std)
    deformed_img = do_elastic_distortion(img, img.shape[0], img.shape[1], sigma=sigma, alpha=alpha)

    return {'original_img': img,
            'hflipped_imgs': list(get_2D_patches_per_vertex(hflipped_img)),
            'gaussian_noised_imgs': list(get_2D_patches_per_vertex(gaussian_noise_img)),
            'deformed_imgs': list(get_2D_patches_per_vertex(deformed_img))}

def plot_cropped_imgs_together(imgs:list):
    """Show four corner patches side by side in a 2x2 grid.

    Args:
        imgs: sequence ordered (left top, right top, left bottom, right bottom),
            as returned by get_2D_patches_per_vertex.
    """
    captions = ('Left top cropped image', 'Right top cropped image',
                'Left bottom cropped image', 'Right bottom cropped image')

    plt.figure(figsize=(10,10))
    for slot, caption in enumerate(captions):
        plt.subplot(2, 2, slot + 1)
        plt.imshow(imgs[slot], cmap='gray')
        plt.title(caption)

    plt.show()

In [None]:
augmented_imgs = get_4x3_augmented_2D_imgs(sample_T1w)

hflipped_imgs = augmented_imgs['hflipped_imgs']
gaussian_noised_imgs = augmented_imgs['gaussian_noised_imgs']
deformed_imgs = augmented_imgs['deformed_imgs']

# Show each augmentation's four corner patches, separated by blank lines.
sections = [
    ('Horizontally flipped images', hflipped_imgs),
    ('Gaussian noise applied images', gaussian_noised_imgs),
    ('Deformed images', deformed_imgs),
]
for n_shown, (heading, patches) in enumerate(sections):
    if n_shown > 0:
        print('\n')
    print(heading)
    plot_cropped_imgs_together(patches)