In [1]:
import os
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
plt.style.use("ggplot")
%matplotlib inline

import cv2
from tqdm import tqdm_notebook, tnrange
from glob import glob
from itertools import chain
from skimage.io import imread, imshow, concatenate_images
from skimage.transform import resize
from skimage.morphology import label


from sklearn.model_selection import train_test_split
from skimage.color import rgb2gray
import tensorflow as tf
from tensorflow.keras import Input
from tensorflow.keras.models import Model, load_model, save_model
from tensorflow.keras.layers import Input, Activation, BatchNormalization, Dropout, Lambda, Conv2D, Conv2DTranspose, MaxPooling2D, Concatenate
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

from tensorflow.keras import backend as K
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint



In [18]:
# Set image dimensions
im_width = 256
im_height = 256

# Load FLAIR mask and MRI images from dataset.
# glob() returns paths in arbitrary, filesystem-dependent order; sorting keeps
# the file lists (and every split or plot derived from them) stable across runs.
mask_files = sorted(glob("lgg-mri-segmentation/kaggle_3m/*/*_mask*"))

# Each MRI path is its mask path with the '_mask' marker removed, so the two
# lists stay index-aligned.
mri_files = [mask_path.replace('_mask', '') for mask_path in mask_files]

['lgg-mri-segmentation/kaggle_3m/TCGA_CS_6667_20011105/TCGA_CS_6667_20011105_8.tif', 'lgg-mri-segmentation/kaggle_3m/TCGA_CS_6667_20011105/TCGA_CS_6667_20011105_9.tif', 'lgg-mri-segmentation/kaggle_3m/TCGA_CS_6667_20011105/TCGA_CS_6667_20011105_2.tif', 'lgg-mri-segmentation/kaggle_3m/TCGA_CS_6667_20011105/TCGA_CS_6667_20011105_3.tif', 'lgg-mri-segmentation/kaggle_3m/TCGA_CS_6667_20011105/TCGA_CS_6667_20011105_20.tif']


In [29]:
# Plot MRI images with FLAIR mask filter
def plot_masked_mri(rows, cols, mri_path_list, mask_path_list):
    """Show a rows x cols grid of MRI images with their masks overlaid.

    Panels are filled from the start of the two lists (index 0 first). If the
    lists hold fewer than rows * cols entries, only the available pairs are
    drawn instead of raising IndexError.

    Args:
        rows: Number of grid rows.
        cols: Number of grid columns.
        mri_path_list: Paths of the MRI images, index-aligned with
            mask_path_list.
        mask_path_list: Paths of the mask images.

    Raises:
        FileNotFoundError: If an MRI or mask image cannot be read.
    """
    # Never index past the end of the shorter list.
    n_panels = min(rows * cols, len(mri_path_list), len(mask_path_list))

    fig = plt.figure(figsize=(12, 12))
    for idx in range(n_panels):
        # Subplot positions are 1-based while the path lists are 0-based.
        fig.add_subplot(rows, cols, idx + 1)

        mri_path = mri_path_list[idx]
        mask_path = mask_path_list[idx]
        mri = cv2.imread(mri_path)
        mask = cv2.imread(mask_path)

        # cv2.imread signals an unreadable file by returning None, not by raising.
        if mri is None:
            raise FileNotFoundError(f"Could not read MRI image: {mri_path}")
        if mask is None:
            raise FileNotFoundError(f"Could not read mask image: {mask_path}")

        # OpenCV loads channels as BGR; matplotlib expects RGB.
        mri = cv2.cvtColor(mri, cv2.COLOR_BGR2RGB)
        plt.imshow(mri)
        # Draw the mask semi-transparently on top of the MRI.
        plt.imshow(mask, alpha=0.4)
    plt.show()

# Preview nine MRI/mask pairs laid out as a 3x3 grid
plot_masked_mri(rows=3, cols=3, mri_path_list=mri_files, mask_path_list=mask_files)

In [31]:
# Create DataFrame pairing each MRI path with its mask path
im_data = pd.DataFrame(data={'mri_files': mri_files, 'mask_files': mask_files})

# Fixed seed so the train/validation/test membership is identical on every run
SPLIT_SEED = 42

# Hold out 10% of the pairs; the remaining 90% is the training set
im_train, im_holdout = train_test_split(im_data, test_size=0.1, random_state=SPLIT_SEED)

# Split the holdout into validation (80% of it) and test (20% of it)
im_val, im_test = train_test_split(im_holdout, test_size=0.2, random_state=SPLIT_SEED)


In [None]:
def train_generator(df, batch_size, aug_dict, mri_color='rgb', mask_color='grayscale', mri_save='image', mask_save='mask', save_to=None, target_size=(256, 256), seed=1):
    """Yield normalized (mri, mask) batches read from the paths listed in df.

    Two ImageDataGenerator instances are built from the same aug_dict and both
    flows receive the same seed (presumably so the random transforms applied
    to an image and to its mask line up; verify against the Keras docs).

    Args:
        df: DataFrame with "mri_files" and "mask_files" path columns.
        batch_size: Number of images per yielded batch.
        aug_dict: Keyword arguments forwarded to ImageDataGenerator.
        mri_color: color_mode used when loading the MRI images.
        mask_color: color_mode used when loading the masks.
        mri_save: save_prefix for MRI images when save_to is set.
        mask_save: save_prefix for masks when save_to is set.
        save_to: Optional directory passed as save_to_dir to both flows.
        target_size: Size every image is loaded at.
        seed: Seed shared by both flows.

    Yields:
        The (mri, mask) tuple returned by normalized_diagnosis for each batch.
    """
    def _flow(column, color_mode, prefix):
        # One generator per stream, configured identically apart from the
        # column, colour mode and save prefix.
        return ImageDataGenerator(**aug_dict).flow_from_dataframe(
            df,
            x_col=column,
            class_mode=None,
            color_mode=color_mode,
            target_size=target_size,
            batch_size=batch_size,
            save_to_dir=save_to,
            save_prefix=prefix,
            seed=seed,
        )

    mri_stream = _flow("mri_files", mri_color, mri_save)
    mask_stream = _flow("mask_files", mask_color, mask_save)

    # Walk both streams in lock-step, one batch from each per iteration.
    for mri_batch, mask_batch in zip(mri_stream, mask_stream):
        mri_batch, mask_batch = normalized_diagnosis(mri_batch, mask_batch)
        yield (mri_batch, mask_batch)

In [None]:
def normalized_diagnosis(mri, mask):
    """Scale an MRI/mask pair from the 0-255 range to 0-1 and binarise the mask.

    New arrays are returned and the inputs are left untouched, so integer
    arrays (for example uint8 images) are accepted and the caller's data is
    never modified behind its back.

    Args:
        mri: Array-like of MRI pixel values in the 0-255 range.
        mask: Array-like of mask pixel values in the 0-255 range.

    Returns:
        Tuple (mri, mask): mri divided by 255, and mask as a float array that
        is 1 where the scaled value exceeds 0.5 and 0 elsewhere.
    """
    # True division allocates a float result instead of writing into the input.
    mri = np.asarray(mri) / 255
    mask = np.asarray(mask) / 255

    # Threshold at 0.5, keeping the float dtype produced by the division.
    mask = (mask > 0.5).astype(mask.dtype)
    return (mri, mask)

In [None]:
def dice_coefficient(actual, pred, smooth_factor=1.0):
    """Compute the smoothed Dice coefficient between two tensors.

    Dice = (2 * sum(actual * pred) + s) / (sum(actual) + sum(pred) + s)

    Args:
        actual: Ground-truth tensor.
        pred: Predicted tensor, with the same number of elements as actual.
        smooth_factor: Value s added to numerator and denominator so the ratio
            stays defined when both tensors are all zeros. It has a default so
            the function can be called with just (actual, pred), as dice_loss
            does.

    Returns:
        Scalar Dice coefficient.
    """
    actual_flatten = K.flatten(actual)
    pred_flatten = K.flatten(pred)

    # The overlap is the element-wise product summed over every element.
    # (Passing pred as the second argument of K.sum would be read as `axis`.)
    intersect = K.sum(actual_flatten * pred_flatten)
    total = K.sum(actual_flatten) + K.sum(pred_flatten)

    return (2 * intersect + smooth_factor) / (total + smooth_factor)


def dice_loss(actual, pred):
    """Return the negated Dice coefficient, so a better overlap gives a lower loss."""
    return -dice_coefficient(actual, pred)

def intersect_over_union(actual, pred, smooth_factor=1.0):
    """Compute the smoothed intersection-over-union (Jaccard index).

    IoU = (sum(actual * pred) + s) / (sum(actual + pred) - sum(actual * pred) + s)

    Args:
        actual: Ground-truth tensor.
        pred: Predicted tensor, broadcast-compatible with actual.
        smooth_factor: Value s added to numerator and denominator so the ratio
            stays defined when both tensors are all zeros. It has a default so
            the function can be called with just (actual, pred), as jaccard
            does.

    Returns:
        Scalar IoU value.
    """
    intersect = K.sum(actual * pred)
    # sum(actual + pred) counts the overlap twice; it is removed once below.
    total = K.sum(actual + pred)
    return (intersect + smooth_factor) / (total - intersect + smooth_factor)


def jaccard(actual, pred):
    """Return the negated IoU of the flattened tensors, for use as a loss."""
    actual_flatten = K.flatten(actual)
    pred_flatten = K.flatten(pred)
    return -intersect_over_union(actual_flatten, pred_flatten)

In [36]:
def encoder(inputs, filters):
    """Apply two 3x3 'same' convolutions with ReLU, then 2x2 max pooling.

    Args:
        inputs: Tensor fed to the first convolution.
        filters: Number of filters used by both convolutions.

    Returns:
        The max-pooled tensor. Only the pooled output is returned; the
        pre-pooling activations are not exposed.
    """
    features = inputs
    # Two identical conv + ReLU stages.
    for _ in range(2):
        features = Conv2D(filters, 3, padding='same')(features)
        features = Activation('relu')(features)

    # 2x2 window with stride 2.
    return MaxPooling2D(pool_size=(2, 2), strides=2)(features)

In [35]:
def decoder(inputs, skip_features, filters):
    """Upsample inputs 2x, merge with skip_features, and refine with two convs.

    Args:
        inputs: Tensor coming from the previous (deeper) stage.
        skip_features: Tensor from the contracting path; it is resized to the
            upsampled height/width before concatenation.
        filters: Number of filters for the transposed conv and both convs.

    Returns:
        Tensor with `filters` channels and twice the height/width of `inputs`.
    """
    # Stride-2 transposed convolution doubles the spatial resolution.
    upsampled = Conv2DTranspose(filters, (2, 2), strides=2, padding='same')(inputs)

    # Match the skip tensor to the upsampled height/width so the two can be
    # concatenated along the channel axis.
    skip_features = tf.image.resize(skip_features, size=(upsampled.shape[1], upsampled.shape[2]))
    merged = Concatenate()([upsampled, skip_features])

    # Both convolutions use 'same' padding: a 'valid' 3x3 conv would trim two
    # pixels per stage, and the final output would no longer match the size
    # of the input image (and therefore of the masks).
    features = merged
    for _ in range(2):
        features = Conv2D(filters, 3, padding='same')(features)
        features = Activation('relu')(features)

    return features

In [37]:
def unet_model(input_size = (256, 256, 3)):
    """Build the U-Net style segmentation model.

    Four encoder stages (64 -> 512 filters) feed a 1024-filter bottleneck,
    followed by four decoder stages (512 -> 64 filters). Each decoder stage
    receives the output of the encoder stage with the same filter count as
    its skip connection. A 1x1 sigmoid convolution produces the single-channel
    output.

    Args:
        input_size: Shape passed to Input (default (256, 256, 3)).

    Returns:
        A Keras Model named 'U-Net'.
    """
    inp = Input(input_size)

    # Contracting path
    cont1 = encoder(inp, 64)
    cont2 = encoder(cont1, 128)
    cont3 = encoder(cont2, 256)
    cont4 = encoder(cont3, 512)

    # Bottleneck: two 3x3 conv + ReLU stages at the lowest resolution
    conv_bn = Conv2D(1024, 3, padding='same')(cont4)
    act_bn = Activation('relu')(conv_bn)
    conv_bn = Conv2D(1024, 3, padding='same')(act_bn)
    act_bn = Activation('relu')(conv_bn)

    # Expanding path: skips are consumed in reverse order, pairing every
    # decoder with the encoder output of matching filter count. The last
    # stage must use cont1 (64 filters), not cont3.
    exp1 = decoder(act_bn, cont4, 512)
    exp2 = decoder(exp1, cont3, 256)
    exp3 = decoder(exp2, cont2, 128)
    exp4 = decoder(exp3, cont1, 64)

    # One sigmoid-activated output channel
    out = Conv2D(1, 1, padding='same', activation='sigmoid')(exp4)
    model = Model(inputs=inp, outputs=out, name='U-Net')

    return model


In [None]:
# Augmentation settings forwarded to train_generator for the training data
train_generator_param = {
    'rotation_range': 0.2,
    'width_shift_range': 0.05,
    'height_shift_range': 0.05,
    'shear_range': 0.05,
    'zoom_range': 0.05,
    'horizontal_flip': True,
    'fill_mode': 'nearest',
}

# Training batches are augmented; the test generator gets an empty settings dict
train_gen = train_generator(im_train, 32, train_generator_param, target_size=(im_height, im_width))
test_gen = train_generator(im_test, 32, {}, target_size=(im_height, im_width))