In [1]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, Input, ReLU, Softmax
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Conv2DTranspose, Concatenate
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint


In [2]:
import cv2
import numpy as np


In [3]:
SIZE = 256
N = 1000

In [4]:
X = np.zeros((N, SIZE, SIZE))
A = np.zeros((N, SIZE, SIZE))
B = np.zeros((N, SIZE, SIZE))
for i in range(1, N + 1):
    X[i-1] = cv2.imread('data/bnw/' + str(i + 7) + '.jpeg', 0)
    A[i-1] = cv2.imread('data/a/' + str(i + 7) + '.jpeg', 0)
    B[i-1] = cv2.imread('data/b/' + str(i + 7) + '.jpeg', 0)



In [16]:
bins_list = [0, 64, 96, 112, 120, 128, 136, 144, 160, 192, 256]
num_bins = len(bins_list) - 1
bins = np.array(bins_list)
a_binned = np.digitize(A, bins) - 1
b_binned = np.digitize(B, bins) - 1
y_binned = a_binned * num_bins + b_binned
print(y_binned.shape)

(1000, 256, 256)


In [6]:
X = X.reshape(N, SIZE, SIZE, 1)
y = np.stack((A, B)).transpose((1, 2, 3, 0))

In [7]:
# Normalize the data
X = X / 255
y = y / 255

In [8]:
# inputs = Input((SIZE, SIZE, 1))
# # TODO: add batch normalization
# conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
# conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
# pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
#
# conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
# conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
# pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
#
# conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
# conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
# pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
#
# conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
# conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
# drop4 = Dropout(0.5)(conv4)
#
# up5 = Conv2D(256, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size = (2,2))(conv4))
# merge5 = concatenate([conv3, up5], axis = 3)
# conv5 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge5)
# conv5 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
#
# up6 = Conv2D(128, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size = (2,2))(conv5))
# merge6 = concatenate([conv2, up6], axis = 3)
# conv6 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
# conv6 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
#
# up7 = Conv2D(64, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size = (2,2))(conv6))
# merge7 = concatenate([conv1, up7], axis = 3)
# conv7 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
# conv7 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)
#
# conv8 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)
# conv8 = Conv2D(2, 1, activation='sigmoid')(conv8)
#
# model = Model(inputs, conv8)
# model.summary()
#
# model.compile(optimizer=Adam(learning_rate=0.1), loss='mae')

Following are utility functions for creating a U-net Model

In [9]:
def conv_stack(input_layer, filters):
    conv1 = Conv2D(filters, 3, padding='same', kernel_initializer='he_normal')(input_layer)
    batch_norm1 = BatchNormalization()(conv1)
    relu1 = ReLU()(batch_norm1)

    conv2 = Conv2D(filters, 3, padding='same', kernel_initializer='he_normal')(relu1)
    batch_norm2 = BatchNormalization()(conv2)
    relu2 = ReLU()(batch_norm2)

    return relu2

In [10]:
def encoder_block(input_layer, filters):
    conv = conv_stack(input_layer, filters)
    max_pool = MaxPooling2D(pool_size=(2, 2))(conv)

    return conv, max_pool

In [11]:
def decoder_block(input_layer, skip_layer, filters):
    up = Conv2DTranspose(filters, 2, strides=2, padding='same')(input_layer)
    conc = Concatenate()([up, skip_layer])
    dec = conv_stack(conc, filters)

    return dec

In [12]:
def get_model(size, init_filters):
    inputs = Input((size, size, 1))

    conv1, max_pool1 = encoder_block(inputs, init_filters)
    conv2, max_pool2 = encoder_block(max_pool1, init_filters * 2)
    conv3, max_pool3 = encoder_block(max_pool2, init_filters * 4)

    middle_block = conv_stack(max_pool3, init_filters * 8)

    decoder1 = decoder_block(middle_block, conv3, init_filters * 4)
    decoder2 = decoder_block(decoder1, conv2, init_filters * 2)
    decoder3 = decoder_block(decoder2, conv1, init_filters)

    # soft = Softmax(axis=1)(decoder3)
    output_a = Conv2D(num_bins, 1, padding='same', activation='softmax')(decoder3)
    output_b = Conv2D(num_bins, 1, padding='same', activation='softmax')(decoder3)

    model = Model(inputs, [output_a, output_b])
    return model

In [13]:
model = get_model(SIZE, 64)
model.summary()
model.compile(optimizer=Adam(), loss='sparse_categorical_crossentropy')

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 256, 256, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 256, 256, 64  640         ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 256, 256, 64  256        ['conv2d[0][0]']                 
 alization)                     )                                                             

In [14]:
callbacks = [
    ModelCheckpoint("aero-color2.h5", save_best_only=True, save_weights_only=True, monitor='loss')
]

weight_dict_a = {0: 0.0003617706298828125, 1: 0.009904220581054687, 2: 0.046289077758789064, 3: 0.05516780090332031, 4: 0.2594354248046875,
                 5: 0.4471910400390625, 6: 0.09607235717773438, 7: 0.061067153930664066, 8: 0.02186639404296875, 9: 0.0026447601318359375}

weight_dict_b = {0: 0.0003795623779296875, 1: 0.0158834228515625, 2: 0.03734930419921875, 3: 0.04284243774414063, 4: 0.1585091552734375,
                 5: 0.31331358337402343, 6: 0.15075779724121094, 7: 0.17564840698242187, 8: 0.0960227813720703, 9: 0.009293548583984375}

In [15]:
model.fit(X, [a_binned, b_binned], epochs=5, verbose=1, batch_size=4, callbacks=callbacks, class_weight=[weight_dict_a, weight_dict_b])

AttributeError: 'list' object has no attribute 'keys'

In [None]:
a_hats, b_hats = model.predict(X[0].reshape(1, SIZE, SIZE, 1))

In [None]:
a_hats = np.argmax(a_hats, axis=3)
b_hats = np.argmax(b_hats, axis=3)

In [None]:
# Real Image
L = X[0].reshape((SIZE, SIZE)) * 255
a = y[0, :, :, 0] * 255
b = y[0, :, :, 1] * 255
lab = np.array([L, a, b]).transpose((1, 2, 0)).astype('uint8')
img = cv2.cvtColor(lab, cv2.COLOR_Lab2BGR)
cv2.imwrite('001.jpeg', img)

True

In [None]:
# Real Image Binned
L = X[0].reshape((SIZE, SIZE)) * 255
a2 = y_binned[0, :, :, 0] * bin_size + bin_size // 2
b2 = y_binned[0, :, :, 1] * bin_size + bin_size // 2
lab = np.array([L, a2, b2]).transpose((1, 2, 0)).astype('uint8')
img = cv2.cvtColor(lab, cv2.COLOR_Lab2BGR)
cv2.imwrite('002.jpeg', img)

True

In [None]:
L = X[0].reshape((SIZE, SIZE)) * 255
a_hat = a_hats[0, :, :]
b_hat = b_hats[0, :, :]
a_val = bins[a_hat]
b_hat = bins[b_hat]
lab = np.array([L, a_hat, b_hat]).transpose((1, 2, 0)).astype('uint8')
img = cv2.cvtColor(lab, cv2.COLOR_Lab2BGR)
cv2.imwrite('003.jpeg', img)

True

In [None]:
cv2.imwrite('a.jpeg', a_hat)
cv2.imwrite('b.jpeg', b_hat)

True

In [None]:
# unique, counts = np.unique(a_binned, return_counts=True)
# counts = counts / counts.sum()
# weights_a = counts.sum() / counts / num_bins
# print(dict(zip(unique, weights_a)))

# unique, counts = np.unique(b_binned, return_counts=True)
# counts = counts / counts.sum()
# weights_b = counts.sum() / counts / num_bins
# print(dict(zip(unique, weights_b)))


{0: 0.0003617706298828125, 1: 0.009904220581054687, 2: 0.046289077758789064, 3: 0.05516780090332031, 4: 0.2594354248046875, 5: 0.4471910400390625, 6: 0.09607235717773438, 7: 0.061067153930664066, 8: 0.02186639404296875, 9: 0.0026447601318359375}
{0: 0.0003795623779296875, 1: 0.0158834228515625, 2: 0.03734930419921875, 3: 0.04284243774414063, 4: 0.1585091552734375, 5: 0.31331358337402343, 6: 0.15075779724121094, 7: 0.17564840698242187, 8: 0.0960227813720703, 9: 0.009293548583984375}


In [None]:
# # bins_list = [0, 64, 96, 112, 120, 128, 136, 144, 160, 192, 256]
# # num_bins = len(bins_list) - 1
# bins = np.arange(17) * 16
# a_binned = np.digitize(A, bins) - 1
# b_binned = np.digitize(B, bins) - 1
# y_binned = np.stack((a_binned, b_binned)).transpose((1, 2, 3, 0))

# unique, counts = np.unique(a_binned, return_counts=True)
# counts = counts / counts.sum()
# # weights_a = counts.sum() / counts / num_bins
# print(dict(zip(unique, counts)))

# unique, counts = np.unique(b_binned, return_counts=True)
# counts = counts / counts.sum()
# # weights_b = counts.sum() / counts / num_bins
# print(dict(zip(unique, counts)))


{2: 0.0002000732421875, 3: 0.0001616973876953125, 4: 0.001061920166015625, 5: 0.008842300415039063, 6: 0.046289077758789064, 7: 0.3146032257080078, 8: 0.5432633972167968, 9: 0.061067153930664066, 10: 0.015013397216796875, 11: 0.006852996826171875, 12: 0.0026249542236328124, 13: 1.9805908203125e-05}
{2: 1.9989013671875e-06, 3: 0.0003775634765625, 4: 0.0032924346923828124, 5: 0.012590988159179688, 6: 0.03734930419921875, 7: 0.20135159301757813, 8: 0.4640713806152344, 9: 0.17564840698242187, 10: 0.07318998718261718, 11: 0.022832794189453123, 12: 0.008207687377929688, 13: 0.001085845947265625, 14: 1.52587890625e-08}
