In [1]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, Input, ReLU, Softmax
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Conv2DTranspose, Concatenate
from tensorflow.keras.optimizers import Adam


In [3]:
import cv2
import numpy as np


In [4]:
# SIZE: side length, in pixels, of the square images and of the model input.
# N: number of samples loaded for training.
SIZE, N = 256, 1

In [5]:
# Sanity-check one grayscale image: it must be readable and already SIZE x SIZE,
# because the loading cell copies images into SIZE x SIZE slots without resizing.
# cv2.imread returns None (it does not raise) when a file cannot be read.
p = cv2.imread('data/bnw/1.jpeg', cv2.IMREAD_GRAYSCALE)
assert p is not None, 'could not read data/bnw/1.jpeg'
assert p.shape == (SIZE, SIZE), p.shape
p.shape

(256, 256)

In [6]:
# Load N samples: the grayscale input plane from data/bnw and the two target
# planes from data/a and data/b, each read as a (SIZE, SIZE) grayscale image.
# NOTE(review): file numbering starts at IMAGE_ID_OFFSET + 1 (8.jpeg for the first
# sample), while the sanity-check cell reads 1.jpeg — confirm this is intended.
IMAGE_ID_OFFSET = 7
X = np.zeros((N, SIZE, SIZE))
A = np.zeros((N, SIZE, SIZE))
B = np.zeros((N, SIZE, SIZE))
for i in range(1, N + 1):
    image_id = str(i + IMAGE_ID_OFFSET)
    for planes, folder in ((X, 'bnw'), (A, 'a'), (B, 'b')):
        path = 'data/' + folder + '/' + image_id + '.jpeg'
        pixels = cv2.imread(path, 0)
        # cv2.imread signals a missing/unreadable file by returning None instead
        # of raising; fail here with the offending path rather than storing
        # garbage in the training arrays.
        if pixels is None:
            raise FileNotFoundError('could not read image: ' + path)
        planes[i - 1] = pixels
  


In [7]:
# Give the inputs an explicit trailing channel axis -> (N, SIZE, SIZE, 1), and
# pair the targets channels-last -> (N, SIZE, SIZE, 2) with A in channel 0, B in channel 1.
X = X.reshape((N, SIZE, SIZE, 1))
y = np.stack((A, B), axis=-1)

In [8]:
# Scale the 8-bit pixel values of inputs and targets into [0, 1].
# Not idempotent: re-running this cell without re-running the loading cells
# divides by 255 a second time.
X, y = X / 255, y / 255

In [76]:
# NOTE(review): earlier, commented-out draft of the U-Net; the model actually
# trained below is built by get_model(). Kept for reference only — uncommenting
# it as-is would fail because UpSampling2D and concatenate are not imported in
# this notebook. Also note drop4 is computed but never used (up5 is built from
# conv4), and the compile line uses learning_rate=0.1 with an MAE loss.
# inputs = Input((SIZE, SIZE, 1))
# # TODO: add batch normalization
# conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
# conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
# pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
#
# conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
# conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
# pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
#
# conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
# conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
# pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
#
# conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
# conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
# drop4 = Dropout(0.5)(conv4)
#
# up5 = Conv2D(256, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size = (2,2))(conv4))
# merge5 = concatenate([conv3, up5], axis = 3)
# conv5 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge5)
# conv5 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
#
# up6 = Conv2D(128, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size = (2,2))(conv5))
# merge6 = concatenate([conv2, up6], axis = 3)
# conv6 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
# conv6 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
#
# up7 = Conv2D(64, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size = (2,2))(conv6))
# merge7 = concatenate([conv1, up7], axis = 3)
# conv7 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
# conv7 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)
#
# conv8 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)
# conv8 = Conv2D(2, 1, activation='sigmoid')(conv8)
#
# model = Model(inputs, conv8)
# model.summary()
#
# model.compile(optimizer=Adam(learning_rate=0.1), loss='mae')

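The following helper functions build the U-Net model: `conv_stack` (two conv–batch-norm–ReLU stages), `encoder_block` and `decoder_block` (one downsampling and one upsampling step), and `get_model`, which assembles them.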

In [77]:
def conv_stack(input_layer, filters):
    """Apply two (3x3 Conv2D -> BatchNormalization -> ReLU) stages.

    Args:
        input_layer: Keras tensor to transform.
        filters: number of output channels of both convolutions.

    Returns:
        The Keras tensor produced by the second ReLU. 'same' padding keeps the
        spatial size of ``input_layer``.
    """
    features = input_layer
    for _stage in range(2):
        # The convolution is left linear; the nonlinearity is applied after
        # batch normalization.
        features = Conv2D(filters, 3, padding='same', kernel_initializer='he_normal')(features)
        features = BatchNormalization()(features)
        features = ReLU()(features)
    return features

In [78]:
def encoder_block(input_layer, filters):
    """One contracting U-Net step.

    Returns a ``(skip, downsampled)`` pair: ``skip`` is the ``conv_stack``
    output (to be fed to the matching decoder block) and ``downsampled`` is
    that same tensor after 2x2 max pooling.
    """
    skip = conv_stack(input_layer, filters)
    return skip, MaxPooling2D(pool_size=(2, 2))(skip)

In [79]:
def decoder_block(input_layer, skip_layer, filters):
    """One expanding U-Net step.

    Upsamples ``input_layer`` with a 2x2, stride-2 transposed convolution,
    concatenates the result (upsampled tensor first) with ``skip_layer`` along
    the last axis, and refines the merged tensor with ``conv_stack``.

    ``skip_layer`` must match the spatial size of the upsampled tensor,
    otherwise ``Concatenate`` raises.
    """
    upsampled = Conv2DTranspose(filters, 2, strides=2, padding='same')(input_layer)
    merged = Concatenate()([upsampled, skip_layer])
    return conv_stack(merged, filters)

In [80]:
def get_model(size, init_filters, depth=3):
    """Build a U-Net mapping a 1-channel image to 2 sigmoid output channels.

    Args:
        size: height and width of the square input, in pixels, or None for a
            variable-size input. When given, it must be divisible by
            ``2 ** depth`` so every upsampled tensor lines up with its skip
            connection.
        init_filters: filter count of the first encoder stage; each deeper
            stage doubles it, and the bottleneck uses
            ``init_filters * 2 ** depth``.
        depth: number of encoder/decoder pairs. The default of 3 reproduces
            the original architecture.

    Returns:
        An uncompiled Keras ``Model`` with input shape ``(size, size, 1)`` and
        output shape ``(size, size, 2)``.

    Raises:
        ValueError: if ``depth`` < 1, or ``size`` is not divisible by
            ``2 ** depth``.
    """
    if depth < 1:
        raise ValueError('depth must be at least 1, got {!r}'.format(depth))
    # Fail early with a clear message instead of a shape mismatch deep inside
    # a Concatenate layer.
    if size is not None and size % (2 ** depth) != 0:
        raise ValueError('size must be divisible by {} for depth {}, got {!r}'.format(
            2 ** depth, depth, size))

    inputs = Input((size, size, 1))

    # Contracting path: keep each pre-pooling tensor for the skip connections.
    skips = []
    x = inputs
    for level in range(depth):
        skip, x = encoder_block(x, init_filters * 2 ** level)
        skips.append(skip)

    x = conv_stack(x, init_filters * 2 ** depth)

    # Expanding path: consume the skip connections deepest-first.
    for level in reversed(range(depth)):
        x = decoder_block(x, skips[level], init_filters * 2 ** level)

    # 1x1 convolution down to two channels; sigmoid keeps outputs in [0, 1].
    outputs = Conv2D(2, 1, padding='same', activation='sigmoid')(x)

    return Model(inputs, outputs)

In [81]:
# Build the U-Net with 64 filters in its first encoder stage, print the layer
# table, then compile it with Adam and a mean-squared-error loss.
model = get_model(size=SIZE, init_filters=64)
model.summary()
model.compile(loss='mse', optimizer=Adam())

Model: "model_4"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            [(None, 256, 256, 1) 0                                            
__________________________________________________________________________________________________
conv2d_60 (Conv2D)              (None, 256, 256, 64) 640         input_5[0][0]                    
__________________________________________________________________________________________________
batch_normalization_56 (BatchNo (None, 256, 256, 64) 256         conv2d_60[0][0]                  
__________________________________________________________________________________________________
re_lu_57 (ReLU)                 (None, 256, 256, 64) 0           batch_normalization_56[0][0]     
____________________________________________________________________________________________

In [82]:
# Train on all loaded samples (no validation split).
# NOTE(review): the saved output below shows 3 epochs, so it predates epochs=5.
model.fit(x=X, y=y, epochs=5, verbose=1)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.callbacks.History at 0x2ad846811c8>

In [83]:
# Predict on the training inputs themselves; no held-out split is made above.
y_hat = model.predict(x=X)



In [9]:
# Reassemble the first sample's input plane and target planes as an 8-bit
# Lab image, convert it to BGR, and save it as 001.jpeg.
L = X[0].reshape((SIZE, SIZE)) * 255
a = y[0, :, :, 0] * 255
b = y[0, :, :, 1] * 255
# Round before the uint8 cast: astype alone truncates, so a value such as
# 254.99999 from the /255 * 255 round trip would become 254.
lab = np.clip(np.rint(np.stack([L, a, b], axis=-1)), 0, 255).astype(np.uint8)
img = cv2.cvtColor(lab, cv2.COLOR_Lab2BGR)
# cv2.imwrite reports failure by returning False rather than raising.
written = cv2.imwrite('001.jpeg', img)
assert written, 'cv2.imwrite failed for 001.jpeg'
written

True

In [87]:
# Combine the first sample's input plane with the predicted a/b planes, convert
# the 8-bit Lab image to BGR, and save it as 002.jpeg.
# L is recomputed here so this cell does not depend on state left by another cell.
L = X[0].reshape((SIZE, SIZE)) * 255
a_hat = y_hat[0, :, :, 0] * 255
b_hat = y_hat[0, :, :, 1] * 255
# Round instead of truncating: astype('uint8') alone floors every positive
# value, biasing the predicted channels downward.
lab = np.clip(np.rint(np.stack([L, a_hat, b_hat], axis=-1)), 0, 255).astype(np.uint8)
img = cv2.cvtColor(lab, cv2.COLOR_Lab2BGR)
# cv2.imwrite reports failure by returning False rather than raising.
written = cv2.imwrite('002.jpeg', img)
assert written, 'cv2.imwrite failed for 002.jpeg'
written

True