In [6]:
import numpy as np
import cv2

# Height and width, in pixels, of the square patch fed to the network.
img_rows = img_cols = 96

# Half-extents of the centre crop taken around the image midpoint:
# delta_x is half of img_rows, delta_y is half of img_cols.
delta_x = img_rows // 2
delta_y = img_cols // 2

In [7]:
from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose
from keras.optimizers import Adam

def _conv_pair(tensor, filters):
    """Apply two 3x3, same-padded, ReLU-activated convolutions with `filters` channels each."""
    tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
    return Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)


def _up_merge(tensor, skip, filters):
    """Upsample `tensor` 2x with a transposed convolution and concatenate `skip` on the channel axis."""
    upsampled = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(tensor)
    return concatenate(inputs=[upsampled, skip], axis=3)


def get_unet():
    """Build and compile the U-Net that maps a 3-channel image to a 1-channel map.

    The encoder has four stages (32, 64, 128, 256 filters), each a pair of
    3x3 convolutions followed by 2x2 max pooling, then a 512-filter
    bottleneck. The decoder mirrors it: every stage upsamples 2x with a
    transposed convolution, concatenates the matching encoder output and
    applies another pair of 3x3 convolutions. Because of the four 2x
    poolings, img_rows and img_cols must be divisible by 16 for the skip
    connections to line up.

    The final 1x1 convolution uses ``tanh``: read_depth_image scales the
    regression targets to [-1, 1], which a ``sigmoid`` output (range 0..1)
    could never reach for the negative half of the range.

    Layers are created in the same order as before, so automatically
    generated layer names are unchanged.

    Returns:
        A compiled keras ``Model`` (Adam optimizer, lr=1e-5, MSE loss).
    """
    inputs = Input((img_rows, img_cols, 3))

    # Contracting path: keep each stage's output for the skip connections.
    conv1 = _conv_pair(inputs, 32)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = _conv_pair(pool1, 64)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = _conv_pair(pool2, 128)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = _conv_pair(pool3, 256)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    # Bottleneck.
    conv5 = _conv_pair(pool4, 512)

    # Expanding path: upsample, merge with the encoder stage of equal resolution, convolve.
    conv6 = _conv_pair(_up_merge(conv5, conv4, 256), 256)
    conv7 = _conv_pair(_up_merge(conv6, conv3, 128), 128)
    conv8 = _conv_pair(_up_merge(conv7, conv2, 64), 64)
    conv9 = _conv_pair(_up_merge(conv8, conv1, 32), 32)

    # tanh (not sigmoid) so the output range matches the [-1, 1] targets.
    conv10 = Conv2D(1, (1, 1), activation='tanh')(conv9)

    model = Model(inputs=[inputs], outputs=[conv10])

    model.compile(optimizer=Adam(lr=1e-5), loss='mse')

    return model

In [8]:
def preprocess(image):
    """Downscale an image to 20% of its size and cut out the central crop.

    The crop spans ``2 * delta_x`` rows and ``2 * delta_y`` columns around
    the midpoint of the resized image.

    Args:
        image: array accepted by ``cv2.resize``; axis 0 is treated as rows
            and axis 1 as columns.

    Returns:
        The central crop of the resized image.

    Raises:
        ValueError: if the resized image is smaller than the crop window.
            Previously the negative slice start wrapped around to the end
            of the array and a wrongly sized crop was returned silently.
    """
    image = cv2.resize(image, None, fx=0.2, fy=0.2, interpolation=cv2.INTER_LINEAR)
    row_center = image.shape[0] // 2
    col_center = image.shape[1] // 2
    if row_center < delta_x or col_center < delta_y:
        raise ValueError(
            "Resized image of shape %s is too small for a %dx%d centre crop"
            % (str(image.shape[:2]), 2 * delta_x, 2 * delta_y))
    return image[row_center - delta_x: row_center + delta_x,
                 col_center - delta_y: col_center + delta_y]


def read_rgb_image(filename):
    """Load a colour image, centre-crop it with preprocess() and rescale it.

    Pixel values are divided by 127.5 and shifted by -1, which maps 8-bit
    values 0..255 onto [-1, 1].

    NOTE(review): cv2.imread returns channels in BGR order despite this
    function's name; the network only needs the order to be consistent
    between training and inference.

    Args:
        filename: path of the image file to read.

    Returns:
        Floating-point array holding the cropped, rescaled image.

    Raises:
        IOError: if OpenCV cannot read ``filename``. cv2.imread reports
            failure by returning None instead of raising, so fail here
            with the offending path rather than deep inside cv2.resize.
    """
    image = cv2.imread(filename)
    if image is None:
        raise IOError("None image = " + str(filename))
    image = preprocess(image)
    image = np.true_divide(image, 127.5)
    image -= 1
    return image


def read_depth_image(filename):
    """Load a depth image and return its first channel scaled like the inputs.

    The file is read with cv2.imread, centre-cropped with preprocess(),
    reduced to channel 0 (kept as a trailing axis of length 1) and then
    divided by 127.5 and shifted by -1, mapping 8-bit values 0..255 onto
    [-1, 1].

    NOTE(review): assumes the depth value is stored in channel 0 of an
    8-bit image — confirm against how the dataset was written.

    Args:
        filename: path of the depth image file to read.

    Returns:
        Floating-point array of shape (rows, cols, 1).

    Raises:
        IOError: if OpenCV cannot read ``filename``. cv2.imread reports
            failure by returning None instead of raising, so fail here
            with the offending path rather than deep inside cv2.resize.
    """
    image = cv2.imread(filename)
    if image is None:
        raise IOError("None image = " + str(filename))
    image = preprocess(image)
    # Keep only the first channel, re-adding the channel axis for Keras.
    normal_depth = image[:, :, 0][:, :, np.newaxis]
    normal_depth = np.true_divide(normal_depth, 127.5)
    normal_depth -= 1
    return normal_depth


def read_sample(img_filenames):
    """Read one training pair described by a mapping of file paths.

    Args:
        img_filenames: mapping with an "image" entry (path passed to
            read_rgb_image) and a "depth" entry (path passed to
            read_depth_image).

    Returns:
        Tuple ``(image, depth)`` of the two loaded arrays.
    """
    return (
        read_rgb_image(img_filenames["image"]),
        read_depth_image(img_filenames["depth"]),
    )

In [9]:
def image_generator(data, read_sample, shuffle=False):
    """Lazily yield ``(image, depth)`` pairs for every entry of ``data``.

    Args:
        data: sequence of sample descriptors; each one is handed to
            ``read_sample``.
        read_sample: callable returning a two-element result for one
            descriptor.
        shuffle: when true, ``data`` is shuffled IN PLACE with
            ``np.random.shuffle`` before iteration starts (the caller's
            sequence is reordered).

    Yields:
        Two-element tuples, one per entry, loaded on demand.
    """
    if shuffle:
        np.random.shuffle(data)
    for loaded in map(read_sample, data):
        rgb, depth = loaded
        yield rgb, depth


def batch_generator(img_generator, batch_size=32):
    """Endlessly yield fixed-size ``(inputs, targets)`` batches.

    Each pass calls ``img_generator()`` to obtain a fresh iterator of
    ``(image, depth)`` pairs and groups them into batches of exactly
    ``batch_size``. Samples left over at the end of a pass (fewer than
    ``batch_size``) are discarded, then a new pass begins.

    Args:
        img_generator: zero-argument callable returning an iterable of
            ``(image, depth)`` pairs.
        batch_size: number of samples per batch; must be at least 1.

    Yields:
        Tuple ``(np.array(images), np.array(depths))``.

    Raises:
        ValueError: if ``batch_size`` is below 1, or if a full pass yields
            fewer than ``batch_size`` samples. Previously both situations
            made the generator spin forever without producing anything.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))
    while True:
        cur_batch_x = []
        cur_batch_y = []
        produced_batch = False
        for image, depth in img_generator():
            cur_batch_x.append(image)
            cur_batch_y.append(depth)
            if len(cur_batch_x) == batch_size:
                yield (np.array(cur_batch_x), np.array(cur_batch_y))
                produced_batch = True
                cur_batch_x = []
                cur_batch_y = []
        if not produced_batch:
            # Without this guard the outer loop would restart indefinitely.
            raise ValueError(
                "img_generator produced fewer than batch_size=%d samples"
                % batch_size)

In [10]:
from data import load_not_none, load_test_not_none
from keras.callbacks import ModelCheckpoint

def train():
    """Load the data, build the U-Net and fit it with best-weights checkpointing.

    Training and validation batches of 8 samples are produced by
    batch_generator over image_generator/read_sample; both data lists are
    reshuffled at the start of every pass. During fitting, 'weights.h5' is
    overwritten whenever the validation loss improves, and the final model
    is saved to ``model_filename`` afterwards.
    """
    print('-' * 30)
    print('Loading and preprocessing train data...')
    print('-' * 30)
    train_data = load_not_none()
    test_data = load_test_not_none()

    img_generator = lambda: image_generator(train_data, read_sample, shuffle=True)
    train_generator = batch_generator(img_generator, 8)

    test_img_gen = lambda: image_generator(test_data, read_sample, shuffle=True)
    test_generator = batch_generator(test_img_gen, 8)

    print('-' * 30)
    print('Creating and compiling model...')
    print('-' * 30)
    model = get_unet()
    # summary() prints the table itself and returns None, so wrapping it in
    # print(str(...)) only appended a stray "None" line to the output.
    model.summary()

    print('-' * 30)
    print('Fitting model...')
    print('-' * 30)

    model_checkpoint = ModelCheckpoint('weights.h5', monitor='val_loss', save_best_only=True)

    # NOTE(review): the step counts are hard-coded; presumably they correspond
    # to the dataset sizes divided by the batch size of 8 — confirm.
    model.fit_generator(generator=train_generator, steps_per_epoch=11300, verbose=1, epochs=10,
                        callbacks=[model_checkpoint], validation_data=test_generator, validation_steps=5336)
    # NOTE(review): model_filename is not defined in any visible cell; make sure
    # it exists before calling train(), otherwise this raises NameError only
    # after the whole training run has finished.
    model.save(model_filename)


train()

------------------------------
Loading and preprocessing train data...
------------------------------
------------------------------
Creating and compiling model...
------------------------------
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_2 (InputLayer)             (None, 96, 96, 3)     0                                            
____________________________________________________________________________________________________
conv2d_20 (Conv2D)               (None, 96, 96, 32)    896         input_2[0][0]                    
____________________________________________________________________________________________________
conv2d_21 (Conv2D)               (None, 96, 96, 32)    9248        conv2d_20[0][0]                  
_________________________________________________________________________________________________

KeyboardInterrupt: 