In [1]:
import numpy as np

import h5py
import time
import cv2

import matplotlib.pylab as plt

import keras.backend as K
from keras.utils import generic_utils
from keras.optimizers import Adam, SGD

from keras.models import Model
from keras.layers.core import Flatten, Dense, Dropout, Activation, Lambda, Reshape
from keras.layers.convolutional import Conv2D, Deconv2D, ZeroPadding2D, UpSampling2D
from keras.layers import Input, Concatenate
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.normalization import BatchNormalization
from keras.layers.pooling import MaxPooling2D
import keras.backend as K

%matplotlib inline

Using TensorFlow backend.


In [2]:
datasetpath = './output/datasetimages.hdf5'
patch_size = 32
batch_size = 16
epoch = 100

In [3]:
def normalization(X):
    return X / 127.5 - 1

def load_data(datasetpath):
    with h5py.File(datasetpath, "r") as hf:
        X_full_train = hf["train_data_raw"][:].astype(np.float32)
        X_full_train = normalization(X_full_train)
        X_sketch_train = hf["train_data_gen"][:].astype(np.float32)
        X_sketch_train = normalization(X_sketch_train)
        X_full_val = hf["val_data_raw"][:].astype(np.float32)
        X_full_val = normalization(X_full_val)
        X_sketch_val = hf["val_data_gen"][:].astype(np.float32)
        X_sketch_val = normalization(X_sketch_val)
        return X_full_train, X_sketch_train, X_full_val, X_sketch_val

In [4]:
def conv_block_unet(x, f, name, bn_axis, bn=True, strides=(2,2)):
    x = LeakyReLU(0.2)(x)
    x = Conv2D(f, (3,3), strides=strides, name=name, padding='same')(x)
    if bn: x = BatchNormalization(axis=bn_axis)(x)
    return x

def up_conv_block_unet(x, x2, f, name, bn_axis, bn=True, dropout=False):
    x = Activation('relu')(x)
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(f, (3,3), name=name, padding='same')(x)
    if bn: x = BatchNormalization(axis=bn_axis)(x)
    if dropout: x = Dropout(0.5)(x)
    x = Concatenate(axis=bn_axis)([x, x2])
    return x

In [5]:
def generator_unet_upsampling(img_shape, disc_img_shape, model_name="generator_unet_upsampling"):
    filters_num = 64
    axis_num = -1
    channels_num = img_shape[-1]
    min_s = min(img_shape[:-1])

    unet_input = Input(shape=img_shape, name="unet_input")

    conv_num = int(np.floor(np.log(min_s)/np.log(2)))
    list_filters_num = [filters_num*min(8, (2**i)) for i in range(conv_num)]

    # Encoder
    first_conv = Conv2D(list_filters_num[0], (3,3), strides=(2,2), name='unet_conv2D_1', padding='same')(unet_input)
    list_encoder = [first_conv]
    for i, f in enumerate(list_filters_num[1:]):
        name = 'unet_conv2D_' + str(i+2)
        conv = conv_block_unet(list_encoder[-1], f, name, axis_num)
        list_encoder.append(conv)

    # prepare decoder filters
    list_filters_num = list_filters_num[:-2][::-1]
    if len(list_filters_num) < conv_num-1:
        list_filters_num.append(filters_num)

    # Decoder
    first_up_conv = up_conv_block_unet(list_encoder[-1], list_encoder[-2],
                        list_filters_num[0], "unet_upconv2D_1", axis_num, dropout=True)
    list_decoder = [first_up_conv]
    for i, f in enumerate(list_filters_num[1:]):
        name = "unet_upconv2D_" + str(i+2)
        if i<2:
            d = True
        else:
            d = False
        up_conv = up_conv_block_unet(list_decoder[-1], list_encoder[-(i+3)], f, name, axis_num, dropout=d)
        list_decoder.append(up_conv)

    x = Activation('relu')(list_decoder[-1])
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(disc_img_shape[-1], (3,3), name="last_conv", padding='same')(x)
    x = Activation('tanh')(x)

    generator_unet = Model(inputs=[unet_input], outputs=[x])
    return generator_unet

In [6]:
def DCGAN_discriminator(img_shape, disc_img_shape, patch_num, model_name='DCGAN_discriminator'):
    disc_raw_img_shape = (disc_img_shape[0], disc_img_shape[1], img_shape[-1])
    list_input = [Input(shape=disc_img_shape, name='disc_input_'+str(i)) for i in range(patch_num)]
    list_raw_input = [Input(shape=disc_raw_img_shape, name='disc_raw_input_'+str(i)) for i in range(patch_num)]

    axis_num = -1
    filters_num = 64
    conv_num = int(np.floor(np.log(disc_img_shape[1])/np.log(2)))
    list_filters = [filters_num*min(8, (2**i)) for i in range(conv_num)]
    print("DCGAN_First Conv")
    generated_patch_input = Input(shape=disc_img_shape, name='discriminator_input')
    xg = Conv2D(list_filters[0], (3,3), strides=(2,2), name='disc_conv2d_1', padding='same')(generated_patch_input)
    xg = BatchNormalization(axis=axis_num)(xg)
    xg = LeakyReLU(0.2)(xg)

    print("DCGAN_First Raw Conv")
    raw_patch_input = Input(shape=disc_raw_img_shape, name='discriminator_raw_input')
    xr = Conv2D(list_filters[0], (3,3), strides=(2,2), name='raw_disc_conv2d_1', padding='same')(raw_patch_input)
    xr = BatchNormalization(axis=axis_num)(xr)
    xr = LeakyReLU(0.2)(xr)

    print("DCGAN_Next Conv")
    for i, f in enumerate(list_filters[1:]):
        name = 'disc_conv2d_' + str(i+2)
        x = Concatenate(axis=axis_num)([xg, xr])
        x = Conv2D(f, (3,3), strides=(2,2), name=name, padding='same')(x)
        x = BatchNormalization(axis=axis_num)(x)
        x = LeakyReLU(0.2)(x)
    x_flat = Flatten()(x)
    x = Dense(2, activation='softmax', name='disc_dense')(x_flat)
    
    PatchGAN = Model(inputs=[generated_patch_input, raw_patch_input], outputs=[x], name='PatchGAN')
    print("DCGAN_PATCH_COMPLETE")
    x = [PatchGAN([list_input[i], list_raw_input[i]]) for i in range(patch_num)]
    if len(x)>1:
        x = Concatenate(axis=axis_num)(x)
    else:
        x = x[0]
    x_out = Dense(2, activation='softmax', name='disc_output')(x)
    discriminator_model = Model(inputs=(list_input+list_raw_input), outputs=[x_out], name=model_name)
    print("DCGAN_END")
    return discriminator_model

In [7]:
def DCGAN(generator, discriminator, img_shape, patch_size):
    raw_input = Input(shape=img_shape, name='DCGAN_input')
    genarated_image = generator(raw_input)

    h, w = img_shape[:-1]
    ph, pw = patch_size, patch_size

    list_row_idx = [(i*ph, (i+1)*ph) for i in range(h//ph)]
    list_col_idx = [(i*pw, (i+1)*pw) for i in range(w//pw)]

    list_gen_patch = []
    list_raw_patch = []
    for row_idx in list_row_idx:
        for col_idx in list_col_idx:
            raw_patch = Lambda(lambda z: z[:, row_idx[0]:row_idx[1], col_idx[0]:col_idx[1], :])(raw_input)
            list_raw_patch.append(raw_patch)
            x_patch = Lambda(lambda z: z[:, row_idx[0]:row_idx[1], col_idx[0]:col_idx[1], :])(genarated_image)
            list_gen_patch.append(x_patch)

    DCGAN_output = discriminator(list_gen_patch+list_raw_patch)

    DCGAN = Model(inputs=[raw_input],
                  outputs=[genarated_image, DCGAN_output],
                  name='DCGAN')

    return DCGAN

In [8]:
def load_generator(img_shape, disc_img_shape):
    model = generator_unet_upsampling(img_shape, disc_img_shape)
    return model

def load_DCGAN_discriminator(img_shape, disc_img_shape, patch_num):
    model = DCGAN_discriminator(img_shape, disc_img_shape, patch_num)
    return model

def load_DCGAN(generator, discriminator, img_shape, patch_size):
    model = DCGAN(generator, discriminator, img_shape, patch_size)
    return model

In [9]:
def l1_loss(y_true, y_pred):
    return K.sum(K.abs(y_pred - y_true), axis=-1)

def inverse_normalization(X):
    return (X + 1.) / 2.

def to3d(X):
    if X.shape[-1]==3: return X
    b = X.transpose(3,1,2,0)
    c = np.array([b[0],b[0],b[0]])
    return c.transpose(3,1,2,0)

In [10]:
def plot_generated_batch(X_proc, X_raw, generator_model, batch_size, suffix):

    X_gen = generator_model.predict(X_raw)
    X_raw = inverse_normalization(X_raw)
    X_proc = inverse_normalization(X_proc)
    X_gen = inverse_normalization(X_gen)

    Xs = to3d(X_raw[:5])
    Xg = to3d(X_gen[:5])
    Xr = to3d(X_proc[:5])
    Xs = np.concatenate(Xs, axis=1)
    Xg = np.concatenate(Xg, axis=1)
    Xr = np.concatenate(Xr, axis=1)
    XX = np.concatenate((Xs,Xg,Xr),axis = 0)
    plt.imshow(XX)
    plt.axis('off')
    plt.savefig("output/batch_"+suffix+"/x_0.jpg")
    plt.clf()
    plt.close()

In [11]:
def extract_patches(X, patch_size):
    list_X = []
    list_row_idx = [(i*patch_size, (i+1)*patch_size) for i in range(X.shape[1] // patch_size)]
    list_col_idx = [(i*patch_size, (i+1)*patch_size) for i in range(X.shape[2] // patch_size)]
    for row_idx in list_row_idx:
        for col_idx in list_col_idx:
            list_X.append(X[:, row_idx[0]:row_idx[1], col_idx[0]:col_idx[1], :])
    return list_X

def get_disc_batch(procImage, rawImage, generator_model, batch_counter, patch_size):
    if batch_counter % 2 == 0:
        # produce an output
        X_disc = generator_model.predict(rawImage)
        y_disc = np.zeros((X_disc.shape[0], 2), dtype=np.uint8)
        y_disc[:, 0] = 1
    else:
        X_disc = procImage
        y_disc = np.zeros((X_disc.shape[0], 2), dtype=np.uint8)

    X_disc = extract_patches(X_disc, patch_size)
    return X_disc, y_disc

In [12]:
def train():
    print("load data")
    rawImage, procImage, rawImage_val, procImage_val = load_data(datasetpath)

    img_shape = rawImage.shape[-3:]
    patch_num = (img_shape[0] // patch_size) * (img_shape[1] // patch_size)
    disc_img_shape = (patch_size, patch_size, procImage.shape[-1])

    print("train")
    opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    print("load generator model")
    generator_model = load_generator(img_shape, disc_img_shape)
    print("load discriminator model")
    discriminator_model = load_DCGAN_discriminator(img_shape, disc_img_shape, patch_num)
    
    generator_model.compile(loss='mae', optimizer=opt_discriminator)
    discriminator_model.trainable = False
    
    DCGAN_model = load_DCGAN(generator_model, discriminator_model, img_shape, patch_size)
    
    loss = [l1_loss, 'binary_crossentropy']
    loss_weights = [1E1, 1]
    DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)
    
    discriminator_model.trainable = True
    discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

    # start training
    print('start training')
    for e in range(epoch):

        starttime = time.time()
        perm = np.random.permutation(rawImage.shape[0])
        X_procImage = procImage[perm]
        X_rawImage  = rawImage[perm]
        X_procImageIter = [X_procImage[i:i+batch_size] for i in range(0, rawImage.shape[0], batch_size)]
        X_rawImageIter  = [X_rawImage[i:i+batch_size] for i in range(0, rawImage.shape[0], batch_size)]
        b_it = 0
        progbar = generic_utils.Progbar(len(X_procImageIter)*batch_size)
        for (X_proc_batch, X_raw_batch) in zip(X_procImageIter, X_rawImageIter):
            b_it += 1
            X_disc, y_disc = get_disc_batch(X_proc_batch, X_raw_batch, generator_model, b_it, patch_size)
            raw_disc, _ = get_disc_batch(X_raw_batch, X_raw_batch, generator_model, 1, patch_size)
            x_disc = X_disc + raw_disc
            # update the discriminator
            disc_loss = discriminator_model.train_on_batch(x_disc, y_disc)

            # create a batch to feed the generator model
            idx = np.random.choice(procImage.shape[0], batch_size)
            X_gen_target, X_gen = procImage[idx], rawImage[idx]
            y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
            y_gen[:, 1] = 1

            # Freeze the discriminator
            discriminator_model.trainable = False
            gen_loss = DCGAN_model.train_on_batch(X_gen, [X_gen_target, y_gen])
            # Unfreeze the discriminator
            discriminator_model.trainable = True

            progbar.add(batch_size, values=[
                ("D logloss", disc_loss),
                ("G tot", gen_loss[0]),
                ("G L1", gen_loss[1]),
                ("G logloss", gen_loss[2])
            ])

            # save images for visualization
            if b_it % (procImage.shape[0]//batch_size//2) == 0:
                plot_generated_batch(X_proc_batch, X_raw_batch, generator_model, batch_size, "training")
                idx = np.random.choice(procImage_val.shape[0], batch_size)
                X_gen_target, X_gen = procImage_val[idx], rawImage_val[idx]
                plot_generated_batch(X_gen_target, X_gen, generator_model, batch_size, "validation")
        print()
        print('Epoch %s/%s, Time: %s' % (e + 1, epoch, time.time() - starttime))
    print("END_TRAINING")
  
    imgs = generator_model.predict(rawImage_val)
    for i,img in enumerate(imgs):
        print(i)
        pimg = inverse_normalization(procImage_val[i])
        plt.imshow(pimg)
        plt.axis('off')
        plt.savefig("output/testData/as"+str(i)+".jpg")
        plt.clf()
        plt.close()
        
        img = inverse_normalization(img)
        plt.imshow(img)
        plt.axis('off')
        plt.savefig("output/testData/"+str(i)+".jpg")
        plt.clf()
        plt.close()
    
    t_imgs = generator_model.predict(rawImage)
    for i in range(50):
        print(i)
        img = inverse_normalization(t_imgs[i])
        plt.imshow(img)
        plt.axis('off')
        plt.savefig("output/T_Data/as"+str(i)+".jpg")
        plt.clf()
        plt.close()
        
    #save_data
    generator_model.save("output/generator_data/generator_model.h5")
    print("END_MAIN")

In [13]:
train()

load data
train
Instructions for updating:
Colocations handled automatically by placer.
load generator model
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
load discriminator model
DCGAN_First Conv
DCGAN_First Raw Conv
DCGAN_Next Conv
DCGAN_PATCH_COMPLETE
DCGAN_END
start training
Instructions for updating:
Use tf.cast instead.

Epoch 1/100, Time: 36.84327816963196

Epoch 2/100, Time: 14.925901174545288

Epoch 3/100, Time: 14.919252157211304

Epoch 4/100, Time: 14.902195930480957

Epoch 5/100, Time: 14.923876762390137

Epoch 6/100, Time: 14.940673589706421

Epoch 7/100, Time: 15.029691219329834

Epoch 8/100, Time: 14.96021294593811

Epoch 9/100, Time: 14.944020509719849

Epoch 10/100, Time: 14.95285701751709

Epoch 11/100, Time: 14.893035411834717

Epoch 12/100, Time: 15.03966498374939

Epoch 13/100, Time: 14.983184814453125

Epoch 14/100, Time: 14.975003242492676

Epoch 15/100, Time: 14.909649133682251

Epoch 16/100, T


Epoch 43/100, Time: 14.893331289291382

Epoch 44/100, Time: 14.902809858322144

Epoch 45/100, Time: 14.871660947799683

Epoch 46/100, Time: 14.931531429290771

Epoch 47/100, Time: 14.893040180206299

Epoch 48/100, Time: 14.88383674621582

Epoch 49/100, Time: 14.910606384277344

Epoch 50/100, Time: 14.949322700500488

Epoch 51/100, Time: 14.902356624603271

Epoch 52/100, Time: 14.977920532226562

Epoch 53/100, Time: 14.965080976486206

Epoch 54/100, Time: 14.949914693832397

Epoch 55/100, Time: 14.922184705734253

Epoch 56/100, Time: 14.944376468658447

Epoch 57/100, Time: 14.951767444610596

Epoch 58/100, Time: 15.126511096954346

Epoch 59/100, Time: 15.021540641784668

Epoch 60/100, Time: 15.109288215637207

Epoch 61/100, Time: 15.18982720375061

Epoch 62/100, Time: 15.044922828674316

Epoch 63/100, Time: 15.061362028121948

Epoch 64/100, Time: 15.050915479660034

Epoch 65/100, Time: 15.033265113830566

Epoch 66/100, Time: 15.137677907943726

Epoch 67/100, Time: 15.043028831481934

E


Epoch 92/100, Time: 15.174200296401978

Epoch 93/100, Time: 15.172986268997192

Epoch 94/100, Time: 15.039780378341675

Epoch 95/100, Time: 15.022215127944946

Epoch 96/100, Time: 15.220649480819702

Epoch 97/100, Time: 15.128490686416626

Epoch 98/100, Time: 15.044795513153076

Epoch 99/100, Time: 15.106876850128174

Epoch 100/100, Time: 15.048355102539062
END_TRAINING
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
END_MAIN
