In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.layers import Input, Conv2D, MaxPooling2D, concatenate, UpSampling2D
from tensorflow.python.keras.optimizers import Adadelta, Nadam
from tensorflow.python.keras.models import Model, load_model
from tensorflow.python.keras.utils import multi_gpu_model, plot_model
from tensorflow.python.keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.python.keras.preprocessing import image
import tensorflow as tf
from tensorflow.python.keras.losses import binary_crossentropy
from model import Unet

In [None]:
def dice_coeff(y_true, y_pred, smooth=1.):
    """Soft Dice coefficient between a ground-truth mask and a prediction.

    Both tensors are flattened, so the score is one global overlap ratio over
    the whole batch, not a per-sample average.

    Args:
        y_true: ground-truth mask tensor; presumably binary {0, 1} values
            (TODO confirm against masks.npy). May be an integer dtype.
        y_pred: predicted mask tensor with the same number of elements as
            ``y_true``; presumably probabilities in [0, 1].
        smooth: additive smoothing term in numerator and denominator. It keeps
            the ratio defined (equal to 1) when both masks are empty. Defaults
            to 1.0, the value previously hard-coded.

    Returns:
        Scalar tensor ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    # Targets can arrive as an integer dtype (e.g. uint8 masks). Multiplying
    # them with a float prediction would fail, so align the dtypes first.
    y_true_f = K.flatten(K.cast(y_true, K.dtype(y_pred)))
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_loss(y_true, y_pred):
    """Dice loss: one minus the soft Dice coefficient.

    Lower is better. It reaches 0 exactly when ``dice_coeff`` returns 1.
    """
    return 1 - dice_coeff(y_true, y_pred)

def total_loss(y_true, y_pred):
    """Combined segmentation loss: binary cross-entropy plus Dice loss.

    NOTE(review): Keras' ``binary_crossentropy`` presumably returns a
    per-position tensor while ``dice_loss`` is a scalar, so the sum broadcasts
    the Dice term onto every position before Keras reduces it. Confirm this
    weighting is intended.
    """
    bce_term = binary_crossentropy(y_true, y_pred)
    dice_term = dice_loss(y_true, y_pred)
    return bce_term + dice_term

In [None]:
# Load the pre-built image and mask arrays (paths relative to the notebook's
# working directory). Reloading from disk keeps this cell idempotent.
images = np.load('images.npy')
masks = np.load('masks.npy')
# Cast before scaling. `images /= 255.0` raises for integer arrays (e.g. uint8
# pixels) because NumPy cannot write a float result back into an int buffer.
# float32 also matches Keras' default float type and halves memory vs float64.
images = images.astype('float32') / 255.0
# NOTE(review): the Dice loss assumes mask values in [0, 1]. Confirm masks.npy
# is not stored as 0/255.
assert len(images) == len(masks), 'every image needs exactly one mask'

In [None]:
# Visual sanity check: one training image beside the first channel of its mask.
SAMPLE_IDX = 5335  # arbitrary example index; must be < len(images)
fig, (ax, ax1) = plt.subplots(nrows=1, ncols=2, figsize=(10, 10))

ax.set_title('Image')
ax.imshow(images[SAMPLE_IDX])

ax1.set_title('Mask')
ax1.imshow(masks[SAMPLE_IDX][:, :, 0])

In [None]:
# Build the U-Net on one device, then replicate it for data-parallel training.
# NOTE(review): Unet's (256, 64) arguments are defined in model.py. They are
# presumably the input size and base filter count; confirm there.
N_GPUS = 4
unet = Unet(256, 64)
p_unet = multi_gpu_model(unet, N_GPUS)
p_unet.compile(optimizer='adam', loss=dice_loss, metrics=[dice_coeff, 'accuracy'])


class TemplateModelCheckpoint(ModelCheckpoint):
    """ModelCheckpoint that writes the weights of the single-device template model.

    fit() hands callbacks the multi-GPU wrapper, whose saved weights do not
    load into a plain ``Unet``. The Keras docs for ``multi_gpu_model`` recommend
    saving the template model instead. The wrapper reuses the template's
    layers, so the template's weights are the ones being trained.
    """

    def __init__(self, template_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._template_model = template_model

    def set_model(self, model):
        # Ignore the wrapper fit() passes in; always save from the template.
        super().set_model(self._template_model)


tb = TensorBoard(log_dir='logs', write_graph=True)
# HDF5 will not create missing parent directories, so make sure `models/` exists.
os.makedirs('models', exist_ok=True)
# Pass real booleans. The previous string 'True' only worked because any
# non-empty string (even 'False') is truthy.
# Monitor the training loss, like EarlyStopping / ReduceLROnPlateau below.
# Pixel accuracy is a weak proxy for mask quality, and its log key
# ('acc' vs 'accuracy') differs between Keras versions.
mc = TemplateModelCheckpoint(unet, filepath='models/top_weights.h5', monitor='loss',
                             save_best_only=True, save_weights_only=True, verbose=1)
es = EarlyStopping(monitor='loss', patience=15, verbose=1)
rlr = ReduceLROnPlateau(monitor='loss')
callbacks = [tb, mc, es, rlr]

In [None]:
# Train the multi-GPU wrapper on the full in-memory arrays. `history` keeps the
# History object returned by fit(). Per the Keras docs, multi_gpu_model splits
# each batch of BATCH_SIZE samples across the replicas.
BATCH_SIZE = 32
EPOCHS = 40
history = p_unet.fit(
    images,
    masks,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=callbacks,
)

In [None]:
# Save the final weights from the single-device template `unet` rather than the
# multi-GPU wrapper, as the Keras multi_gpu_model docs recommend.
FINAL_WEIGHTS_PATH = 'unet_dice40_64.h5'
unet.save_weights(FINAL_WEIGHTS_PATH)