# Data-generation notebook: BraTS

In [None]:
### Base ###
import os
import numpy as np 
import torch 
import torch.nn as nn
from torch.optim import Adam
import fnmatch
from torch.utils.data import TensorDataset, DataLoader
import itertools
import math
from sklearn.decomposition import PCA

### Visualization ###
#import seaborn as sns
#sns.set(color_codes=True)
import matplotlib.cm as cm
import matplotlib.mlab as mlab
import matplotlib.pyplot as plt
from matplotlib import rc
# Render figure text with LaTeX in a serif (Palatino) font.
# NOTE: matplotlib's usetex mode needs a working LaTeX installation.
plt.rcParams.update({
    'text.usetex': True,
    'font.family': 'serif',
    'font.serif': ['Palatino'],
})

print('Numpy version : ', np.version.version)

In [None]:
# Root folder holding the BraTS inputs and the generated outputs. Defaults to the
# original author's local path; export BRATS_DATA_DIR to run on another machine.
DATA_DIR = os.environ.get(
    'BRATS_DATA_DIR', '/Users/paul.vernhet/Workroom/Software/Data/MICCAI_dataset')

In [None]:
### Advanced ##
# CONDA_DEFAULT_ENV only exists inside an activated conda environment; fall back to
# a placeholder rather than aborting the cell with a KeyError.
print('Conda environment : ', os.environ.get('CONDA_DEFAULT_ENV', '<none>'))
from scipy.ndimage import gaussian_filter
import PIL.Image as pimg
import nibabel as nib
from multiprocessing import Pool
import itk
from scipy import ndimage as nd

# 1. Normalize 3D MR images (histogram matching to a reference volume)

In [None]:
# Build the list of (subject_id, input_path, output_path) jobs for every t1ce
# volume of the training (HGG/LGG) and validation sets.
args = []

path_to_dataset__in = os.path.join(DATA_DIR, '1_brats_2020')
path_to_dataset__out = os.path.join(DATA_DIR, '2_t1ce_normalized_ppm')
grades = ['HGG', 'LGG']
img_type = 't1ce'

# Reference volume whose histogram every subject is matched to.
path_to_file__ref = os.path.join(DATA_DIR, '1_brats_2020/reference/3_colin27_t1_tal_lin_brain.nii')
# path_to_file__ref = os.path.join(DATA_DIR, '1_brats_2020/HGG/BraTS19_2013_2_1/BraTS19_2013_2_1_t1ce.nii.gz')

# makedirs(exist_ok=True) creates missing parents and is safe to re-run.
os.makedirs(path_to_dataset__out, exist_ok=True)

### TRAIN ###
t = ('1_training', 'train')
path_to_t = os.path.join(path_to_dataset__out, t[1])
os.makedirs(path_to_t, exist_ok=True)
for grade in grades:
    path_to_subjects = os.path.join(path_to_dataset__in, t[0], grade)
    # os.listdir order is arbitrary; sort so the job list is reproducible.
    subject_ids = sorted(elt for elt in os.listdir(path_to_subjects) if elt.startswith('BraTS'))
    for subject_id in subject_ids:
        path_to_file__in = os.path.join(path_to_subjects, subject_id, subject_id + '_%s.nii.gz' % img_type)
        # subject_id[8:] drops the 8-character prefix (e.g. 'BraTS19_').
        path_to_file__out = os.path.join(path_to_t, grade.lower() + '_' + subject_id[8:] + '_' + img_type)
        args.append((subject_id, path_to_file__in, path_to_file__out + '.nii.gz'))

### TEST ###
t = ('2_validation', 'test')
path_to_t = os.path.join(path_to_dataset__out, t[1])
os.makedirs(path_to_t, exist_ok=True)
path_to_subjects = os.path.join(path_to_dataset__in, t[0])
subject_ids = sorted(elt for elt in os.listdir(path_to_subjects) if elt.startswith('BraTS'))
for subject_id in subject_ids:
    path_to_file__in = os.path.join(path_to_subjects, subject_id, subject_id + '_%s.nii.gz' % img_type)
    path_to_file__out = os.path.join(path_to_t, subject_id[8:] + '_' + img_type)
    args.append((subject_id, path_to_file__in, path_to_file__out + '.nii.gz'))

### RUN ###
# Every volume is handled as a 3-D single-precision float image.
PixelType = itk.F
ImageType = itk.Image[itk.F, 3]

# Histogram-matching reference: loaded once, shared by every job.
# (Earlier attempts also rescaled/cast it with RescaleIntensityImageFilter / cast_image_filter.)
img_ref = itk.imread(path_to_file__ref, PixelType)

def launch(args, output_range=(0.0, 255.0)):
    """Histogram-match one subject volume to ``img_ref``, rescale it, and save it.

    Parameters
    ----------
    args : tuple
        ``(subject_id, path_in, path_out)``: an id printed for progress, the input
        volume path, and the output volume path.
    output_range : tuple of float, optional
        ``(min, max)`` intensity range of the written volume; ``(0.0, 255.0)`` by
        default.
    """
    subject_id, path_in, path_out = args
    print(subject_id)
    img_in = itk.imread(path_in, PixelType)
    img_out = itk.HistogramMatchingImageFilter(img_in, img_ref)
    # Without explicit bounds, RescaleIntensityImageFilter maps onto the full range
    # of the output pixel type (NumericTraits min/max), i.e. roughly +/-3.4e38 for
    # float images. Pin the range so saved volumes have bounded, uint8-compatible values.
    out_min, out_max = output_range
    img_out = itk.RescaleIntensityImageFilter(
        img_out, OutputMinimum=out_min, OutputMaximum=out_max)
    itk.imwrite(img_out, path_out)

# Run the jobs sequentially. The parallel alternative (Pool is imported above) is:
#     with Pool(os.cpu_count()) as pool:
#         pool.map(launch, args)
for job in args:
    launch(job)

# 2. Extract a 2-D slice and store it in PNG format

In [None]:
path_to_dataset = '/Users/alexandre.bone/Workspace/2020_MICCAI/2_datasets/2_t1ce_normalized_colin'
t = 'train'
filename = 'hgg_2013_2_1_t1ce.nii.gz'

# Voxels cropped from each side of the slice (u: first axis, v: second axis)
# before resizing. A tighter crop (45, 45, 35, 20) was tried previously.
margin_u1, margin_u2 = 35, 35
margin_v1, margin_v2 = 25, 10

# Target size of the resized slice (64x64 was tried previously).
img_size = (128, 128)

def reshape(image_data, size=None):
    """Resample an array to a target shape using linear (order-1) interpolation.

    Parameters
    ----------
    image_data : numpy.ndarray
        Array to resample; ``size`` must give one target length per axis.
    size : tuple of int, optional
        Target shape. Defaults to the module-level ``img_size``, read at call time.

    Returns
    -------
    numpy.ndarray
        ``image_data`` zoomed so that its shape equals ``size``.
    """
    if size is None:
        size = img_size
    # Per-axis zoom factor = target length / current length.
    zoom_factors = [w / float(f) for w, f in zip(size, image_data.shape)]
    # Use scipy.ndimage.zoom directly; the scipy.ndimage.interpolation namespace is deprecated.
    return nd.zoom(image_data, zoom=zoom_factors, order=1)

path_to_file = os.path.join(path_to_dataset, t, filename)
# nibabel's get_data() no longer works in nibabel 5; np.asanyarray(img.dataobj)
# is the documented replacement and keeps the stored dtype (get_fdata() would
# always return float64).
img = np.asanyarray(nib.load(path_to_file).dataobj)
# Crop slice 80 along the third axis. Explicit end indices keep a zero margin
# from producing an empty slice (x[a:-0] is empty).
img = img[margin_u1:img.shape[0] - margin_u2, margin_v1:img.shape[1] - margin_v2, 80]
# Flip both in-plane axes, then transpose, for display orientation.
img = np.transpose(img[::-1, ::-1])
print(img.shape)
img = reshape(img)
print(img.shape)

figsize = 5
plt.figure(figsize=(figsize, figsize))
plt.imshow(img, cmap='gray')
plt.axis('off')
plt.show()

# plt.hist(img.reshape(-1, 1))
# plt.show()

In [None]:
path_to_dataset__in = '/Users/alexandre.bone/Workspace/2020_MICCAI/2_datasets/2_t1ce_normalized_colin'
path_to_dataset__out = '/Users/alexandre.bone/Workspace/2020_MICCAI/2_datasets/7_t1ce_normalized_colin_slice80'

# Index, along the third axis, of the slice exported from every volume.
slice_id = 80

# Crop margins in voxels; a tighter (45, 45, 35, 20) crop was tried previously.
margin_u1, margin_u2, margin_v1, margin_v2 = 35, 35, 25, 10

img_size = (128, 128)

# Export one cropped, resized slice per volume as a PNG, and save the mean
# training slice as average.png.


def _to_uint8(array):
    """Round to the nearest integer, clip into [0, 255], and cast to uint8."""
    # Rounding (rather than truncating after clipping to [tol, 255 - tol]) keeps
    # a value of 255 at 255 instead of turning it into 254.
    return np.clip(np.rint(array), 0, 255).astype('uint8')


os.makedirs(path_to_dataset__out, exist_ok=True)

nii_ext = '.nii.gz'
# reshape() returns arrays of shape img_size, so a float64 accumulator of that
# shape matches every slice without relying on `img` left over from an earlier cell.
img_average = np.zeros(img_size)
for t in ['train', 'test']:
    path_to_t__in = os.path.join(path_to_dataset__in, t)
    path_to_t__out = os.path.join(path_to_dataset__out, t)
    os.makedirs(path_to_t__out, exist_ok=True)

    # Sort: the k index written into each output name must not depend on the
    # arbitrary order returned by os.listdir.
    filenames = sorted(elt for elt in os.listdir(path_to_t__in) if elt.endswith(nii_ext))
    for k, filename in enumerate(filenames):
        path_to_file__in = os.path.join(path_to_t__in, filename)
        path_to_file__out = os.path.join(
            path_to_t__out, ('%03d_' % k) + filename[:-len(nii_ext)] + '.png')

        # get_data() no longer works in nibabel 5; this keeps the stored dtype.
        img = np.asanyarray(nib.load(path_to_file__in).dataobj)
        img = img[margin_u1:img.shape[0] - margin_u2, margin_v1:img.shape[1] - margin_v2, slice_id]
        img = reshape(np.transpose(img[::-1, ::-1]))

        if t == 'train':
            img_average += img / float(len(filenames))

        # PIL cannot write float ('F') images as PNG; quantize (no-op for uint8 input).
        pimg.fromarray(_to_uint8(img)).save(path_to_file__out)

pimg.fromarray(_to_uint8(img_average)).save(os.path.join(path_to_dataset__out, 'average.png'))

In [None]:
path_to_file = '/Users/alexandre.bone/Softwares/deepshape/examples/brains/data_128/cn/s0074.npy'

# Load the stored array and keep index 60 along its third axis for display.
img = np.load(path_to_file)[:, :, 60]
plt.imshow(img)
plt.show()

In [None]:
# Display the shape of the slice extracted above.
np.shape(img)