In [2]:
import numpy as np
import pandas as pd
import os
import scipy.ndimage
import matplotlib.pyplot as plt
import warnings

from keras.models import Model
from keras.layers import Input, Dense, Conv3D, MaxPooling3D, UpSampling3D, merge
from keras.optimizers import Adam
from keras import backend as K
import keras
from keras.callbacks import ModelCheckpoint
from keras.layers import SpatialDropout3D
from keras.layers import BatchNormalization
from keras.models import load_model

from batchgenerators.dataloading.data_loader import DataLoader
from batchviewer import view_batch
from batchgenerators.augmentations.crop_and_pad_augmentations import crop
from batchgenerators.augmentations.utils import pad_nd_image
from batchgenerators.utilities.data_splitting import get_split_deterministic
from batchgenerators.dataloading import MultiThreadedAugmenter
from batchgenerators.transforms.spatial_transforms import SpatialTransform_2, MirrorTransform
from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, GammaTransform
from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
from batchgenerators.transforms import Compose

# The sizes below are halved every time max pooling is applied; this can cause issues when concatenating two models
# NUM_SLIDES=slice_size
# IMG_HEIGHT=slice_size
# IMG_WIDTH=slice_size
# IMG_CHANNELS=1
# Directory and case identifier from which DataLoader3D_jfr.load_patient builds the path of the
# .npy volume it reads: os.path.join(datadir_path, name + '_clean.npy').
datadir_path = "./data/05dzgcM/"
name = '05dzgcM'

%matplotlib inline

In [3]:
class DataLoader3D_jfr(DataLoader):
    """Batch loader that cuts random 3D patches out of the volume returned by load_patient.

    The loaded array is indexed as (channel, x, y, z): every channel but the last is treated as image
    data and the last channel is treated as the segmentation.
    """

    def __init__(self, data, batch_size, patch_size, num_threads_in_multithreaded, seed_for_shuffle=42,
                 return_incomplete=False, shuffle=True, infinite=True):
        """
        data must be a list of patients as returned by get_list_of_patients (and split by get_split_deterministic)

        patch_size is the spatial size the returned batch will have

        The remaining arguments are forwarded unchanged to DataLoader.__init__.
        """
        super().__init__(data, batch_size, num_threads_in_multithreaded, seed_for_shuffle, return_incomplete, shuffle,
                         infinite)
        self.patch_size = patch_size
        # 0 means "not determined yet": generate_train_batch replaces it with the number of image
        # channels found in the first array it loads.
        self.num_modalities = 0
        self.indices = list(range(len(data)))

    @staticmethod
    def load_patient(patient):
        """Return (data, metadata) for one patient.

        NOTE(review): `patient` is currently ignored; the same file, built from the module-level
        `datadir_path` and `name`, is memory-mapped read-only for every call. Metadata is always an
        empty list.
        """
        data = np.load(os.path.join(datadir_path, name + '_clean.npy'), mmap_mode="r")
        metadata = []
        return data, metadata

    def generate_train_batch(self):
        """Assemble one batch of randomly cropped patches.

        Returns a dict with keys 'data' (batch, image channels, *patch_size), 'seg'
        (batch, 1, *patch_size), 'metadata' and 'names'.
        """
        # DataLoader has its own methods for selecting what patients to use next, see its Documentation
        idx = self.get_indices()
        patients_for_batch = [self._data[i] for i in idx]

        # seg always has exactly one channel. The image array is allocated lazily because its channel
        # count depends on the file: a hard-coded count of 0 would silently discard (or reject) the
        # image channels when they are written into the batch.
        seg = np.zeros((self.batch_size, 1, *self.patch_size), dtype=np.float32)
        data = None

        metadata = []
        patient_names = []

        # iterate over patients_for_batch and include them in the batch
        for i, j in enumerate(patients_for_batch):
            patient_data, patient_metadata = self.load_patient(j)

            # this will only pad patient_data if its shape is smaller than self.patch_size
            patient_data = pad_nd_image(patient_data, self.patch_size)

            image_channels = patient_data[:-1]
            seg_channel = patient_data[-1:]

            if data is None:
                self.num_modalities = image_channels.shape[0]
                data = np.zeros((self.batch_size, self.num_modalities, *self.patch_size), dtype=np.float32)

            # now random crop to self.patch_size
            # crop expects the data to be (b, c, x, y, z) but the arrays are (c, x, y, z), so a dummy
            # batch dimension is added for the call and stripped again below
            patient_data, patient_seg = crop(image_channels[None],
                                             seg_channel[None],
                                             self.patch_size,
                                             crop_type="random")

            data[i] = patient_data[0]
            seg[i] = patient_seg[0]

            metadata.append(patient_metadata)
            patient_names.append(j)

        if data is None:
            # no patient was selected for this batch; fall back to the last known channel count
            data = np.zeros((self.batch_size, self.num_modalities, *self.patch_size), dtype=np.float32)

        return {'data': data, 'seg': seg, 'metadata': metadata, 'names': patient_names}

In [4]:
def get_train_transform(patch_size):
    """Build the augmentation pipeline applied to training batches.

    patch_size is handed to SpatialTransform_2 as the output patch size; half of it (per axis) is
    passed as that transform's second positional argument.

    Returns a Compose wrapping, in order: spatial transform, mirroring, multiplicative brightness,
    gamma, inverted gamma, Gaussian noise and Gaussian blur. These are not necessarily the best
    transforms for the data at hand; they showcase what is available.
    """
    # rotation limit of +/- 15 degrees, expressed in radians, shared by all three axes
    max_angle = 15 / 360. * 2 * np.pi
    rotation_range = (-max_angle, max_angle)

    pipeline = [
        # SpatialTransform_2 runs first: it reduces the data to patch_size, which lowers the cost of
        # everything after it. None of the later transforms change the shape or move voxels, so they
        # introduce no border artifacts. Elastic deformation, rotation and scaling are each applied
        # with probability 0.1 per sample, i.e. 1 - (1 - 0.1) ** 3 = 27% of samples get at least one
        # of them; the rest are only cropped.
        SpatialTransform_2(
            patch_size, [dim // 2 for dim in patch_size],
            do_elastic_deform=True, deformation_scale=(0, 0.25),
            do_rotation=True,
            angle_x=rotation_range,
            angle_y=rotation_range,
            angle_z=rotation_range,
            do_scale=True, scale=(0.75, 1.25),
            border_mode_data='constant', border_cval_data=0,
            border_mode_seg='constant', border_cval_seg=0,
            order_seg=1, order_data=3,
            random_crop=True,
            p_el_per_sample=0.1, p_rot_per_sample=0.1, p_scale_per_sample=0.1
        ),
        # mirroring along all three spatial axes
        MirrorTransform(axes=(0, 1, 2)),
        # multiplicative brightness change for 15% of samples
        BrightnessMultiplicativeTransform((0.7, 1.5), per_channel=True, p_per_sample=0.15),
        # gamma: a nonlinear transformation of intensity values
        # (https://en.wikipedia.org/wiki/Gamma_correction)
        GammaTransform(gamma_range=(0.5, 2), invert_image=False, per_channel=True, p_per_sample=0.15),
        # same again, but on the inverted image (invert, apply gamma, invert back)
        GammaTransform(gamma_range=(0.5, 2), invert_image=True, per_channel=True, p_per_sample=0.15),
        # additive Gaussian noise
        GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.15),
        # blurring, to make the model more robust to blurry acquisitions
        GaussianBlurTransform(blur_sigma=(0.5, 1.5), different_sigma_per_channel=True,
                              p_per_channel=0.5, p_per_sample=0.15),
    ]

    return Compose(pipeline)

In [5]:
# Placeholder patient lists: DataLoader3D_jfr.load_patient ignores its argument and always reads
# the same file, so the value 1 is never used to locate data.
train = [1]
patients=[1]
num_threads_for_brats_example = 8

# train, val = get_split_deterministic(patients, fold=0, num_splits=5, random_state=42)
patch_size = (128, 128, 128)
batch_size = 2
# Sanity check: draw one un-augmented batch from a single-threaded loader.
dataloader = DataLoader3D_jfr(train, batch_size, patch_size, num_threads_in_multithreaded=1)
batch = next(dataloader)
# batch viewer can show up to 4d tensors. We can show only one sample, but that should be sufficient here
# view_batch(batch['data'][0], batch['seg'][0])

# first collect the spatial shape of every case (shape[1:] drops the leading axis); the largest one,
# but never less than patch_size, becomes the shape of the batches fed to the augmentation pipeline
shapes = [DataLoader3D_jfr.load_patient(i)[0].shape[1:] for i in patients]
max_shape = np.max(shapes, 0)
max_shape = np.max((max_shape, patch_size), 0)

# we create a new instance of DataLoader. This one will return batches of shape max_shape. Cropping/padding is
# now done by SpatialTransform. If we do it this way we avoid border artifacts (the entire brain of all cases will
# be in the batch and SpatialTransform will use zeros which is exactly what we have outside the brain)
# this is viable here but not viable if you work with different data. If you work for example with CT scans that
# can be up to 500x500x500 voxels large then you should do this differently. There, instead of using max_shape you
# should estimate what shape you need to extract so that subsequent SpatialTransform does not introduce border
# artifacts
dataloader_train = DataLoader3D_jfr(train, batch_size, max_shape, num_threads_for_brats_example)

# during training I like to run a validation from time to time to see where I am standing. This is not a correct
# validation because just like training this is patch-based but it's good enough. We don't do augmentation for the
# validation, so patch_size is used as shape target here
# dataloader_validation = DataLoader3D_jfr(val, batch_size, patch_size, max(1, num_threads_for_brats_example // 2))

tr_transforms = get_train_transform(patch_size)

# finally we can create multithreaded transforms that we can actually use for training
# we don't pin memory here because this is pytorch specific.
tr_gen = MultiThreadedAugmenter(dataloader_train, tr_transforms, num_processes=num_threads_for_brats_example,
                                num_cached_per_queue=3,
                                seeds=None, pin_memory=False)
#     # we need fewer processes for validation because we don't apply transformations
#     val_gen = MultiThreadedAugmenter(dataloader_validation, None,
#                                      num_processes=max(1, num_threads_for_brats_example // 2), num_cached_per_queue=1,
#                                      seeds=None,
#                                      pin_memory=False)

#     # let's start the MultiThreadedAugmenter. This is not necessary but allows them to start generating training
#     # batches while other things run in the main thread
#     tr_gen.restart()
#     val_gen.restart()

#     # now if this was a network training you would run epochs like this (remember tr_gen and val_gen generate
#     # infinite examples! Don't do "for batch in tr_gen:"!!!):
#     num_batches_per_epoch = 10
#     num_validation_batches_per_epoch = 3
#     num_epochs = 5
#     # let's run this to get a time on how long it takes
#     time_per_epoch = []
#     start = time()
#     for epoch in range(num_epochs):
#         start_epoch = time()
#         for b in range(num_batches_per_epoch):
#             batch = next(tr_gen)
#             # do network training here with this batch

#         for b in range(num_validation_batches_per_epoch):
#             batch = next(val_gen)
#             # run validation here
#         end_epoch = time()
#         time_per_epoch.append(end_epoch - start_epoch)
#     end = time()
#     total_time = end - start
#     print("Running %d epochs took a total of %.2f seconds with time per epoch being %s" %
#           (num_epochs, total_time, str(time_per_epoch)))

# if you notice that you have CPU usage issues, reduce the probability with which the spatial transformations are
# applied in get_train_transform. SpatialTransform is the most expensive transform

# Visualize a few augmented examples with batchviewer. view_batch is imported unconditionally at the top
# of the notebook, so the else branch only runs if that name has been rebound to None.
if view_batch is not None:
    for _ in range(4):
        batch = next(tr_gen)
        view_batch(batch['data'][0], batch['seg'][0])
else:
    print("Cannot visualize batches, install batchviewer first. It's a nice and handy tool. You can get it here: "
          "https://github.com/FabianIsensee/BatchViewer")

In [None]:
# Build U-Net model
#
# The input dimensions were previously only present as commented-out constants, so this cell raised
# a NameError on a fresh kernel. They are now taken from patch_size, the shape of the patches the
# augmentation pipeline produces, with a single image channel.
NUM_SLIDES, IMG_HEIGHT, IMG_WIDTH = (int(dim) for dim in patch_size)
IMG_CHANNELS = 1

# The encoder halves every spatial dimension five times and the decoder doubles it five times, so
# each dimension must be divisible by 2 ** 5 for the skip-connection concatenations to line up.
_N_POOLINGS = 5
if any(dim % (2 ** _N_POOLINGS) for dim in (NUM_SLIDES, IMG_HEIGHT, IMG_WIDTH)):
    raise ValueError("Each input dimension must be divisible by %d, got %s"
                     % (2 ** _N_POOLINGS, (NUM_SLIDES, IMG_HEIGHT, IMG_WIDTH)))


def _double_conv(tensor, n_filters):
    """Conv3D -> SpatialDropout3D(0.3) -> Conv3D; both convolutions are 3x3x3, ReLU, 'same' padding."""
    tensor = keras.layers.Conv3D(n_filters, kernel_size=(3, 3, 3), activation='relu', padding='same')(tensor)
    tensor = keras.layers.SpatialDropout3D(0.3)(tensor)
    return keras.layers.Conv3D(n_filters, (3, 3, 3), activation='relu', padding='same')(tensor)


def _upsample_and_merge(tensor, skip, n_filters):
    """Double the spatial size with a strided transposed convolution, then concatenate the skip tensor."""
    tensor = keras.layers.Conv3DTranspose(n_filters, (2, 2, 2), strides=(2, 2, 2), padding='same')(tensor)
    # concatenation is along the last axis (axis 4 of the 5D tensors), where this model keeps its channels
    return keras.layers.concatenate([tensor, skip], axis=-1)


# NOTE(review): the input is declared with channels last, whereas DataLoader3D_jfr builds batches as
# (batch, channel, x, y, z). Batches presumably need their axes moved before being fed to this
# model; confirm against the training loop.
inputs = keras.layers.Input((NUM_SLIDES, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
#s = keras.layers.Lambda(lambda x: x / 3095)(inputs)

# ---- encoder ----
c1 = _double_conv(inputs, 16)
#GlobalAveragePooling3D -- try this as well
p1 = keras.layers.MaxPooling3D((2, 2, 2))(c1)
#p1 = BatchNormalization()(p1)

c2 = _double_conv(p1, 32)
p2 = keras.layers.MaxPooling3D((2, 2, 2))(c2)
#p2 = BatchNormalization()(p2)

c3 = _double_conv(p2, 64)
p3 = keras.layers.MaxPooling3D((2, 2, 2))(c3)
#p3 = BatchNormalization()(p3)

c4 = _double_conv(p3, 128)
p4 = keras.layers.MaxPooling3D(pool_size=(2, 2, 2))(c4)
#p4 = BatchNormalization()(p4)

c5 = _double_conv(p4, 256)
p5 = keras.layers.MaxPooling3D(pool_size=(2, 2, 2))(c5)
#p5 = BatchNormalization()(p5)

# ---- bottleneck ----
c55 = _double_conv(p5, 512)
#c55 = BatchNormalization()(c55)

# ---- decoder ----
u66 = _upsample_and_merge(c55, c5, 256)
c66 = _double_conv(u66, 256)
#c66 = BatchNormalization()(c66)

u6 = _upsample_and_merge(c66, c4, 128)
c6 = _double_conv(u6, 128)
#c6 = BatchNormalization()(c6)

u7 = _upsample_and_merge(c6, c3, 64)
c7 = _double_conv(u7, 64)
#c7 = BatchNormalization()(c7)

u8 = _upsample_and_merge(c7, c2, 32)
c8 = _double_conv(u8, 32)
#c8 = BatchNormalization()(c8)

u9 = _upsample_and_merge(c8, c1, 16)
c9 = _double_conv(u9, 16)
#c9 = BatchNormalization()(c9)

# single-channel sigmoid output from a 1x1x1 convolution
outputs = keras.layers.Conv3D(1, (1, 1, 1), activation='sigmoid')(c9)

model = keras.Model(inputs=[inputs], outputs=[outputs])
# NOTE(review): dice_coef_loss, dice_coef and recall_metric are not defined in the visible part of
# this notebook; they must be defined in an earlier cell before this one runs.
model.compile(optimizer='adam', loss=dice_coef_loss, metrics=[dice_coef, recall_metric])
model.summary()