In [1]:
import os
import numpy as np

from skimage.transform import resize
from skimage.io import imsave

from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint
from keras import backend as K

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [20]:
K.set_image_data_format('channels_last')  # TF dimension ordering in this code

img_rows = 288
img_cols = 288

smooth = 1.

In [21]:
def dice_coef(y_true, y_pred):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


def dice_coef_loss(y_true, y_pred):
    return -dice_coef(y_true, y_pred)


def Specificity(y_true, y_pred):
    true_negatives = K.abs(y_pred)- K.abs(y_true)
    return ((true_negatives+smooth)/(y_pred+ smooth))

def Sensitivity(y_true, y_pred):
    y_true = K.flatten(y_true)
    y_pred = K.flatten(y_pred)
    return ((y_pred+smooth)/ (y_true+smooth))

def Jaccard_index(y_true,y_pred):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return ((intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) - intersection + smooth))

In [22]:
def get_unet():
    inputs = Input((img_rows, img_cols, 1))
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)

    up6 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv5), conv4], axis=3)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(up6)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)

    up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], axis=3)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(up7)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)

    up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], axis=3)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(up8)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)

    up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(up9)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)

    conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)

    model = Model(inputs=[inputs], outputs=[conv10])
    #keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
    #1e-5

    model.compile(optimizer=Adam(lr=1e-5, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False), loss=dice_coef_loss, metrics=[dice_coef, 'acc',Jaccard_index, Specificity, Sensitivity])

    return model

In [23]:
def preprocess(imgs):
    imgs_p = np.ndarray((imgs.shape[0], img_rows, img_cols, 1), dtype=np.uint8)
    for i in range(imgs.shape[0]):
        imgs_p[i] = resize(imgs[i], (img_cols, img_rows), preserve_range=True)

#     imgs_p = imgs_p[..., np.newaxis]

    return imgs_p

In [24]:
def load_train_data():
    imgs_train = np.load('npy-with/train-images.npy')
    imgs_mask_train = np.load('npy-with/train-masks.npy')
    return imgs_train, imgs_mask_train

def load_validation_data():
    imgs_validation = np.load('npy-with/validation-images.npy')
    imgs_mask_validation = np.load('npy-with/validation-masks.npy')
    return imgs_validation, imgs_mask_validation

# def load_test_data():
#     imgs_test = np.load('npy-with/test-images.npy')
#     imgs_id = np.load('npy-with/test-masks.npy')
#     return imgs_test, imgs_id

In [25]:
def train_and_predict():
    print('-'*30)
    print('Loading and preprocessing train data...')
    print('-'*30)
    imgs_train, imgs_mask_train = load_train_data()
    print(len(imgs_train))
    imgs_train = preprocess(imgs_train)
    imgs_mask_train = preprocess(imgs_mask_train)

    imgs_train = imgs_train.astype('float32')
    mean = np.mean(imgs_train)  # mean for data centering
    std = np.std(imgs_train)  # std for data normalization

    imgs_train -= mean
    imgs_train /= std

    imgs_mask_train = imgs_mask_train.astype('float32')
    imgs_mask_train /= 255.  # scale masks to [0, 1]

    
    print('-'*30)
    print('Loading and preprocessing validation data...')
    print('-'*30)
    imgs_validation, imgs_mask_validation = load_validation_data()
    print(len(imgs_validation))
    imgs_validation = preprocess(imgs_validation)
    imgs_mask_validation = preprocess(imgs_mask_validation)

    imgs_validation = imgs_validation.astype('float32')
    mean = np.mean(imgs_validation)  # mean for data centering
    std = np.std(imgs_validation)  # std for data normalization

    imgs_validation -= mean
    imgs_validation /= std

    imgs_mask_validation = imgs_mask_validation.astype('float32')
    imgs_mask_validation /= 255.  # scale masks to [0, 1]
    
    
    print('-'*30)
    print('Creating and compiling model...')
    print('-'*30)

    model = get_unet()
    model_checkpoint = ModelCheckpoint('weights-unet.h5', monitor='val_loss', save_best_only=True)

    print('-'*30)
    print('Fitting model...')
    print('-'*30)
    history = model.fit(imgs_train, imgs_mask_train, batch_size=10, nb_epoch=50, verbose=1, shuffle=True,
              validation_data=(imgs_validation, imgs_mask_validation),
              callbacks=[model_checkpoint])

#     print('-'*30)
#     print('Loading and preprocessing test data...')
#     print('-'*30)
#     imgs_test, imgs_id_test = load_test_data()
#     imgs_test = preprocess(imgs_test)

#     imgs_test = imgs_test.astype('float32')
#     imgs_test -= mean
#     imgs_test /= std

#     print('-'*30)
#     print('Loading saved weights...')
#     print('-'*30)
#     model.load_weights('weights.h5')

#     print('-'*30)
#     print('Predicting masks on test data...')
#     print('-'*30)
#     imgs_mask_test = model.predict(imgs_test, verbose=1)
#     np.save('imgs_mask_test.npy', imgs_mask_test)

#     print('-' * 30)
#     print('Saving predicted masks to files...')
#     print('-' * 30)
#     pred_dir = 'prediction'
#     if not os.path.exists(pred_dir):
#         os.mkdir(pred_dir)
#     for image, image_id in zip(imgs_mask_test, imgs_id_test):
#         image = (image[:, :, 0] * 255.).astype(np.uint8)
#         imsave(os.path.join(pred_dir, str(image_id) + '_pred.jpg'), image)

In [26]:
train_and_predict()

------------------------------
Loading and preprocessing train data...
------------------------------
2000


  warn("The default mode, 'constant', will be changed to 'reflect' in "


------------------------------
Loading and preprocessing validation data...
------------------------------
150
------------------------------
Creating and compiling model...
------------------------------
------------------------------
Fitting model...
------------------------------




Train on 2000 samples, validate on 150 samples
Epoch 1/50
  10/2000 [..............................] - ETA: 32:24 - loss: -0.0014 - dice_coef: 0.0014 - acc: 0.0176 - Jaccard_index: 6.8487e-04 - Specificity: 0.9998 - Sensitivity: 1.5110

KeyboardInterrupt: 