In [None]:
import cv2
import numpy as np
import tensorflow as tf
tf.compat.v1.disable_eager_execution() # faster

In [None]:
from tensorflow.keras.utils import Sequence
class DataGen(Sequence):
    """Batch loader for grayscale images and their alpha-channel targets.

    Each item is a pair ``(x, y)``:

    * ``x`` -- uint8 array of shape ``(batch, H, W, 1)``: the input images read
      in grayscale (``cv2.imread(path, 0)``).
    * ``y`` -- float32 array of shape ``(batch, H, W, 2)``: channel 1 is the
      target's alpha channel scaled to [0, 1], channel 0 is its complement.

    Every image of a batch is resized to one ``(width, height)`` drawn at
    random from ``shapes`` each time a batch is requested, so training sees
    several aspect ratios.

    Args:
        x: sequence of input image paths.
        y: sequence of target image paths, paired with ``x`` by position. Each
            target must have a 4th (alpha) channel.
        batch_size: samples per batch; the final batch may be smaller.
        shapes: optional ``(width, height)`` sizes to sample from; defaults to
            ``DEFAULT_SHAPES``.

    Raises:
        ValueError: if ``x`` and ``y`` differ in length, ``batch_size < 1`` or
            ``shapes`` is empty.
    """

    # (width, height) pairs, in cv2.resize's dsize order. Aspect ratios
    # 1:1, 4:3, 16:9, 2:1, each in both orientations; 256x256 is listed twice
    # so 1:1 gets the same sampling weight as the other ratios. Every side is
    # a multiple of 32, which LikeUnet's five 2x pool / 2x upsample stages need.
    DEFAULT_SHAPES = ((512, 384), (384, 512), (512, 288), (288, 512),
                      (512, 256), (256, 512), (256, 256), (256, 256))

    def __init__(self, x, y, batch_size, shapes=None):
        if len(x) != len(y):
            raise ValueError(
                "x and y must have the same length, got {} and {}".format(len(x), len(y)))
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
        self.x, self.y = x, y
        self.batch_size = batch_size
        self.shapes = list(self.DEFAULT_SHAPES if shapes is None else shapes)
        if not self.shapes:
            raise ValueError("shapes must not be empty")

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    @staticmethod
    def _imread(path, flags):
        """cv2.imread that raises instead of silently returning None."""
        img = cv2.imread(path, flags)
        if img is None:
            raise FileNotFoundError("could not read image: {}".format(path))
        return img

    @staticmethod
    def _to_unit_range(channel):
        """Convert an image channel to float32 in [0, 1] using its dtype's max."""
        if np.issubdtype(channel.dtype, np.integer):
            # 255 for 8-bit images, 65535 for 16-bit ones.
            return channel.astype(np.float32) / np.float32(np.iinfo(channel.dtype).max)
        return channel.astype(np.float32)

    def __getitem__(self, idx):
        start = idx * self.batch_size
        stop = start + self.batch_size
        x_batch = self.x[start:stop]
        y_batch = self.y[start:stop]

        # One size for the whole batch: np.array below needs equal shapes.
        shape = self.shapes[np.random.choice(len(self.shapes))]

        x_ret, y_ret = [], []
        for x_file_name, y_file_name in zip(x_batch, y_batch):
            x_img = cv2.resize(self._imread(x_file_name, 0), shape)
            y_raw = self._imread(y_file_name, cv2.IMREAD_UNCHANGED)
            if y_raw.ndim != 3 or y_raw.shape[2] < 4:
                raise ValueError(
                    "target image has no alpha channel: {}".format(y_file_name))
            y_img = cv2.resize(y_raw, shape)
            probs = self._to_unit_range(y_img[:, :, 3])
            x_ret.append(x_img[:, :, np.newaxis])
            # Two-class map: [background, foreground] = [1 - alpha, alpha].
            y_ret.append(np.stack([1.0 - probs, probs], axis=-1))

        return np.array(x_ret), np.array(y_ret)

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, ReLU, Conv2D, MaxPooling2D, \
                                    Conv2DTranspose, UpSampling2D, \
                                    Add, Concatenate, BatchNormalization
def _triple_conv(x, filters, name):
    """Sum of parallel 3x3, 1x3 and 3x1 convs (named ``name-1..3``), then BN + ReLU."""
    a = Conv2D(filters, 3, padding='same', name=name + '-1')(x)
    b = Conv2D(filters, (1, 3), padding='same', name=name + '-2')(x)
    c = Conv2D(filters, (3, 1), padding='same', name=name + '-3')(x)
    x = Add()([a, b, c])
    x = BatchNormalization()(x)
    return ReLU()(x)


def _encoder_block(x, filters, name):
    """Triple conv followed by 2x max-pooling (halves H and W)."""
    x = _triple_conv(x, filters, name)
    return MaxPooling2D(2)(x)


def _decoder_block(x, filters, name, skip=None):
    """Optional skip concat, triple conv, transposed conv + BN + ReLU, 2x bilinear upsample."""
    if skip is not None:
        # Skip tensor first, matching the original channel order.
        x = Concatenate()([skip, x])
    x = _triple_conv(x, filters, name)
    x = Conv2DTranspose(filters, 3, padding='same')(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    return UpSampling2D(2, interpolation='bilinear')(x)


def LikeUnet():
    """Build the U-Net-like segmentation model.

    Input: single-channel image of any height/width, ``(None, None, 1)``.
    Output: per-pixel 2-channel softmax (same H, W as the input).

    There are five 2x max-pool stages, undone by five 2x upsampling stages
    whose outputs are concatenated with the matching encoder features. The
    input height and width must therefore be divisible by 32.

    Layers are created in the same order and with the same explicit names
    (``conv1-1`` ... ``conv10-3``) as the original unrolled version. This keeps
    auto-generated layer names and weight order stable for existing checkpoints.
    """
    im = Input(shape=(None, None, 1))
    pre = Conv2D(8, 3, padding='same', activation='relu')(im)

    # Encoder: each stage halves the spatial size.
    conv1 = _encoder_block(pre, 16, 'conv1')    # 1/2
    conv2 = _encoder_block(conv1, 32, 'conv2')  # 1/4
    conv3 = _encoder_block(conv2, 32, 'conv3')  # 1/8
    conv4 = _encoder_block(conv3, 32, 'conv4')  # 1/16
    conv5 = _encoder_block(conv4, 64, 'conv5')  # 1/32

    # Decoder: each stage doubles the spatial size after fusing a skip.
    x = _decoder_block(conv5, 64, 'conv6')              # -> 1/16
    x = _decoder_block(x, 32, 'conv7', skip=conv4)      # -> 1/8
    x = _decoder_block(x, 32, 'conv8', skip=conv3)      # -> 1/4
    x = _decoder_block(x, 32, 'conv9', skip=conv2)      # -> 1/2
    x = _decoder_block(x, 16, 'conv10', skip=conv1)     # -> 1/1

    x = Conv2D(8, 3, padding='same', activation='relu')(x)
    out = Conv2D(2, 1, padding='same', activation='softmax')(x)
    return Model(inputs=im, outputs=out)

In [None]:
def _read_path_list(path):
    """Return the non-blank lines of ``path`` with line terminators removed.

    ``rstrip`` (rather than slicing off the last character) never eats a real
    character when the final line has no trailing newline. Blank lines are
    skipped so they do not turn into empty image paths.
    """
    with open(path, "r") as f:
        return [line.rstrip("\r\n") for line in f if line.strip()]


x_train = _read_path_list("x_train.txt")
y_train = _read_path_list("y_train.txt")
x_val = _read_path_list("x_val.txt")
y_val = _read_path_list("y_val.txt")

# Inputs and targets are paired by position, so the lists must line up.
assert len(x_train) == len(y_train), "x_train.txt / y_train.txt length mismatch"
assert len(x_val) == len(y_val), "x_val.txt / y_val.txt length mismatch"

In [None]:
# Batch size shared by the training and validation generators.
BATCH_SIZE = 60
trainGen = DataGen(x_train, y_train, batch_size=BATCH_SIZE)
valGen = DataGen(x_val, y_val, batch_size=BATCH_SIZE)

In [None]:
from tensorflow.keras.metrics import MeanIoU
class mIoU(MeanIoU):
    """Mean IoU for one-hot / per-class-probability tensors.

    ``MeanIoU`` expects integer class ids. Both ``y_true`` and ``y_pred`` are
    therefore collapsed with ``argmax`` over the last axis before the
    confusion matrix is updated.

    Args:
        num_classes: number of classes (size of the last axis).
        name: metric name; defaults to ``"mIoU"`` so Keras logs it as
            ``mIoU`` / ``val_mIoU``.
        dtype: optional dtype forwarded to ``MeanIoU``.
    """

    def __init__(self, num_classes, name="mIoU", dtype=None):
        super().__init__(num_classes=num_classes, name=name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Overriding update_state (not __call__) covers both paths Keras uses:
        # Metric.__call__ delegates here, and compiled metrics may call
        # update_state directly, which would skip a __call__ override.
        y_true = tf.argmax(y_true, axis=-1)
        y_pred = tf.argmax(y_pred, axis=-1)
        return super().update_state(y_true, y_pred, sample_weight=sample_weight)

In [None]:
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import load_model, model_from_json

# Optional path to an .h5 weights file (e.g. one written by the checkpoint
# callback) to resume from. Leave as None to train from scratch.
RESUME_WEIGHTS = None

model = LikeUnet()
# Save the architecture next to the weights-only checkpoints so a run can be
# reloaded later with model_from_json(...) + load_weights(...).
with open('dark-pre.json', 'w') as f:
    f.write(model.to_json())
if RESUME_WEIGHTS is not None:
    model.load_weights(RESUME_WEIGHTS)

# The targets are complementary 2-channel maps and the output is a 2-way
# softmax. Binary cross-entropy averaged over the two channels therefore
# equals categorical cross-entropy, up to epsilon clipping.
model.compile(optimizer=Adam(amsgrad=True),
              loss="binary_crossentropy",
              metrics=[mIoU(num_classes=2)])

In [None]:
import os
from datetime import datetime
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint

# os.path.join instead of hard-coded "\\" so the paths also work off Windows.
logdir = os.path.join("logs", "dark-pre", datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = TensorBoard(log_dir=logdir, profile_batch=0)

# Make sure the checkpoint directory exists before the first save.
os.makedirs("models", exist_ok=True)
filepath = os.path.join("models", "dark-pre-{epoch:02d}-{val_mIoU:.4f}.h5")
# Higher mIoU is better. Keras's mode='auto' infers the direction from the
# metric name and would treat 'val_mIoU' as something to minimise, so be
# explicit (this matters as soon as save_best_only is turned on).
checkpoint_callback = ModelCheckpoint(filepath, monitor='val_mIoU', mode='max',
                                      save_weights_only=True)

In [None]:
# Smoke test: train briefly on the smaller validation set to debug the pipeline.
# model.fit_generator(generator=valGen, epochs=10, workers=8, shuffle=True)
# model.save_weights('test.h5')

In [None]:
# Train for 20 epochs; batches are loaded by 8 workers and their order is
# shuffled each epoch. After each epoch the model is evaluated on valGen, and
# both callbacks (TensorBoard logging, weight checkpoints) run.
model.fit_generator(
    trainGen,
    epochs=20,
    validation_data=valGen,
    callbacks=[tensorboard_callback, checkpoint_callback],
    shuffle=True,
    workers=8,
)