In [None]:
import os
import pickle
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import backend as K
from keras.models import *
from keras.layers import *
from datagenerator import *
from metrics import *
from keras.callbacks import CSVLogger
np.set_printoptions(precision=3, suppress=True)

In [None]:
def build_unet(inputs, ker_init, dropout):
    """Assemble a U-Net style encoder/decoder on top of ``inputs``.

    The encoder runs four stages of two 3x3 ReLU convolutions (32, 64, 128
    and 256 filters), each followed by 2x2 max pooling. The bottleneck is two
    3x3 convolutions with 512 filters followed by dropout. The decoder mirrors
    the encoder: each stage upsamples by 2, applies a 2x2 convolution,
    concatenates the matching encoder output on axis 3 and applies two 3x3
    convolutions. A final 1x1 convolution with softmax yields 4 channels.

    Args:
        inputs: Keras input tensor. Concatenation happens on axis 3, so a 4-D
            tensor is expected (assumed channels-last -- TODO confirm).
            Because there are four 2x2 poolings, height and width should be
            multiples of 16 so upsampled maps match the skip connections.
        ker_init: Kernel initializer passed to every 3x3 and 2x2 convolution
            (the final 1x1 convolution uses the layer default).
        dropout: Rate of the single Dropout layer after the bottleneck.

    Returns:
        A ``Model`` mapping ``inputs`` to the 4-channel softmax output.
    """

    def conv_pair(tensor, filters):
        # Two stacked 3x3 'same' convolutions with ReLU.
        for _ in range(2):
            tensor = Conv2D(filters, 3, activation='relu', padding='same',
                            kernel_initializer=ker_init)(tensor)
        return tensor

    def up_stage(tensor, skip, filters):
        # Upsample x2, shrink channels with a 2x2 conv, then fuse with the skip.
        reduce_conv = Conv2D(filters, 2, activation='relu', padding='same',
                             kernel_initializer=ker_init)
        upsampled = reduce_conv(UpSampling2D(size=(2, 2))(tensor))
        fused = concatenate([skip, upsampled], axis=3)
        return conv_pair(fused, filters)

    # Contracting path: remember each stage's output for the skip connections.
    skips = []
    features = inputs
    for filters in (32, 64, 128, 256):
        features = conv_pair(features, filters)
        skips.append(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)

    # Bottleneck.
    features = conv_pair(features, 512)
    features = Dropout(dropout)(features)

    # Expanding path: consume the skips from deepest to shallowest.
    for filters in (256, 128, 64, 32):
        features = up_stage(features, skips.pop(), filters)

    class_probs = Conv2D(4, (1, 1), activation='softmax')(features)
    return Model(inputs=inputs, outputs=class_probs)

In [None]:
# Build and compile the segmentation network.
# IMG_SIZE is not defined in this notebook; presumably it is provided by the
# star-import of `datagenerator` -- TODO confirm.
input_layer = Input((IMG_SIZE, IMG_SIZE, 2))
model = build_unet(input_layer, 'he_normal', 0.2)
model.compile(loss='categorical_crossentropy',
              optimizer=keras.optimizers.Adam(learning_rate=0.001),
              metrics=['accuracy',
                       # The model outputs per-class softmax probabilities and
                       # categorical_crossentropy takes one-hot style targets.
                       # Plain MeanIoU expects integer class ids, so it would
                       # score the raw 0/1 and probability values instead of
                       # classes. OneHotMeanIoU takes the argmax over the
                       # class axis first. The name is pinned to the old
                       # default so log/history keys do not change.
                       tf.keras.metrics.OneHotMeanIoU(num_classes=4,
                                                      name='mean_io_u'),
                       # The remaining metric functions are not defined here;
                       # presumably they come from `from metrics import *`.
                       dice_coef,
                       precision,
                       sensitivity,
                       specificity,
                       dice_coef_necrotic,
                       dice_coef_edema,
                       dice_coef_enhancing])

In [None]:
def dir_to_ids(dir_list):
    """Strip the directory part from every path in ``dir_list``.

    Each entry is cut after its last ``'/'``; entries without a ``'/'`` are
    returned unchanged, and an entry ending in ``'/'`` becomes ``''``.

    Args:
        dir_list: Sequence of path strings.

    Returns:
        A new list with the trailing path component of each entry, in the
        same order as the input.
    """
    # rpartition yields (head, sep, tail); tail is the whole string when
    # no '/' is present.
    return [path.rpartition('/')[2] for path in dir_list]

In [None]:
# Collect the sample ids for each split from the entries of its directory.
# os.listdir returns entries in arbitrary order, so the listing is sorted to
# make the id order (and anything derived from it) reproducible across runs
# and machines.
train_ids = dir_to_ids(sorted(os.listdir('data/train/')))
valid_ids = dir_to_ids(sorted(os.listdir('data/valid/')))

# DataGenerator is not defined in this notebook; presumably it comes from
# `from datagenerator import *` and yields (inputs, one-hot masks) batches
# for the given ids -- TODO confirm.
train_generator = DataGenerator('data/train/', train_ids)
valid_generator = DataGenerator('data/valid/', valid_ids)

In [None]:
# Create every output directory before training starts, so a missing folder
# fails fast (or not at all) instead of after the epoch has finished.
for out_dir in ('log', 'model', 'history'):
    os.makedirs(out_dir, exist_ok=True)

# Per-epoch metrics are written to a CSV log; append=False overwrites any
# previous log file.
csv_logger = CSVLogger('log/training.log', separator=',', append=False)
# Multiply the learning rate by 0.2 after 2 epochs without val_loss
# improvement, never going below 1e-6.
callbacks = [keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                               factor=0.2,
                                               patience=2,
                                               min_lr=0.000001,
                                               verbose=1), csv_logger]
# NOTE(review): this runs after the model has already been built and
# compiled, so it does not give the model a fresh state; presumably it was
# meant to run before build_unet -- confirm and move if so.
K.clear_session()
# NOTE(review): steps_per_epoch is the number of ids, not the number of
# batches; verify this matches what DataGenerator yields per epoch.
history = model.fit(train_generator,
                    epochs=1,
                    steps_per_epoch=len(train_ids),
                    callbacks=callbacks,
                    validation_data=valid_generator)
model.save('model/unet_v1.h5')
# Use a context manager so the pickle file is flushed and closed even if
# serialization raises.
with open('history/history.pkl', 'wb') as history_file:
    pickle.dump(history, history_file)