In [None]:
import numpy as np
import keras
from PIL import Image
from IPython.display import display

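# Train models

For each image, fit a small MLP that maps a pixel's (row, column) position to its RGB colour, previewing the reconstruction inline during training.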

In [11]:
from keras.layers import Dense, Input
import os
from itertools import product


def get_training_data(img):
    """Turn an image array into (position, colour) training pairs, one per pixel.

    Pixels are enumerated in row-major order (row index outer, column index
    inner), so a model's predictions can be reshaped straight back to
    ``(height, width, channels)``.

    Args:
        img: array indexed as ``img[row, col]`` -- presumably the ``uint8``
            ``(H, W, 3)`` RGB array produced by ``np.array(PIL.Image)``;
            TODO confirm for other callers.

    Returns:
        x_data: float array of shape ``(H * W, 2)`` holding
            ``(row / H - 0.5, col / W - 0.5)``, i.e. coordinates in ``[-0.5, 0.5)``.
        y_data: the pixel values of ``img`` divided by 255, one row per pixel.
    """
    height, width = img.shape[:2]
    # Same ordering as itertools.product(range(H), range(W)), without building
    # a Python list of H*W tuples.
    indices = np.indices((height, width)).reshape(2, -1).T
    rgb = img[indices[:, 0], indices[:, 1]]
    # Normalise each axis by its own length. Using shape[0] for both axes
    # truncated wide images and raised IndexError on tall ones.
    pos = indices / np.array([height, width], dtype=float) - 0.5

    x_data = pos
    y_data = rgb / 255.0

    return x_data, y_data


def train(filename, w=224, epochs=1000, preview_every=50):
    """Fit a position -> RGB MLP to one image and save the trained model.

    The image is converted to RGB and resized to ``w x w``. Each pixel becomes a
    training pair via ``get_training_data``. The network is a stack of 9 Dense
    layers: 8 hidden ReLU layers of 20 units, then a 3-unit sigmoid output.
    Every ``preview_every`` epochs the current reconstruction is displayed
    inline. The full model is written to ``models/<basename>.h5``. (Note: the
    ``test`` function in the Generate Image section loads
    ``models/<basename>_190816.h5`` by default.)

    Args:
        filename: path to an image file readable by PIL.
        w: side length in pixels the image is resized to before training.
        epochs: number of training epochs.
        preview_every: show a reconstruction whenever ``epoch % preview_every == 0``
            (epochs are counted from 0, so the first epoch is always shown).
    """
    print(filename)

    name, _ = os.path.splitext(os.path.basename(filename))

    # Context manager closes the file handle that Image.open keeps open.
    with Image.open(filename) as src:
        img = np.array(src.convert("RGB").resize((w, w), Image.BILINEAR))

    x_data, y_data = get_training_data(img)

    model = keras.models.Sequential()
    model.add(Dense(20, activation='relu', input_shape=(2, )))  # Dense layer 1
    for _ in range(7):  # Dense layers 2-8
        model.add(Dense(20, activation='relu'))
    model.add(Dense(3, activation='sigmoid'))  # Dense layer 9: RGB in [0, 1]

    model.compile(optimizer='adam', loss='mse')

    class PreviewCallback(keras.callbacks.Callback):
        """Display the model's current reconstruction every `preview_every` epochs."""

        def on_epoch_end(self, epoch, logs=None):
            if epoch % preview_every == 0:
                pred = self.model.predict(x_data, batch_size=1024)
                pred = pred.reshape((w, w, 3))
                display(Image.fromarray((pred * 255.).astype(np.uint8)))

    model.fit(x_data, y_data, epochs=epochs, shuffle=True, batch_size=1024,
              callbacks=[PreviewCallback()])

    # model.save does not create missing parent directories.
    os.makedirs('models', exist_ok=True)
    model.save('models/{}.h5'.format(name))

In [15]:
from glob import glob

# glob('./images/*') also matches sub-directories such as ./images/background,
# which Image.open cannot read, so train only on regular files.
# The layer-skipping render is not called here, for two reasons:
# - `test` is defined further down (Generate Image section), so calling it
#   here raises NameError on Restart & Run All.
# - By default `test` loads models/<name>_190816.h5, not the models/<name>.h5
#   that train() just wrote.
# To render these models, call test(filename, [3, 7], resize=4) after that
# section has run.
for filename in sorted(glob('./images/*')):
    if os.path.isfile(filename):
        train(filename)

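# Generate images

Rebuild each trained network with skippable hidden layers. Then render it on an upscaled sampling grid with selected layers bypassed.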

In [5]:
import keras.backend as K
from keras.layers.merge import Add
from keras.layers import Dense, Lambda
from keras.models import Model
from itertools import product

def test(filename, drop_layers, resize=1, suffix='_190816'):
    """Render a trained position -> RGB MLP with selected hidden layers bypassed.

    Rebuilds the 9-Dense-layer network (Dense 2->20, seven Dense 20->20, and
    Dense 20->3 sigmoid) as a functional model. Each of the seven 20->20
    layers outputs ``x * rate + Dense(x) * (1 - rate)``, where ``rate`` is one
    column of a second model input. A rate of 1 therefore passes the layer's
    input through unchanged.

    Weights are loaded from ``models/<basename><suffix>.h5``. The model is
    sampled on a square grid of side ``224 * resize`` over ``[-0.5, 0.5)``, and
    the result is displayed and saved to ``./results/<basename><suffix>.png``.

    Args:
        filename: image path; only its basename (without extension) is used to
            locate the model file and name the output image.
        drop_layers: 0-based indices of Dense layers to bypass. Only 1..7 (the
            20->20 hidden layers) can be skipped. Layer 0 (input) and layer 8
            (output) change width.
        resize: integer upscaling factor of the 224x224 sampling grid.
        suffix: appended to the basename for both the model file and the
            result image. Defaults to the previously hard-coded ``'_190816'``.

    Raises:
        ValueError: if an index in ``drop_layers`` is not in 1..7.
    """
    # `Input` and `os` were otherwise only imported in the training cell;
    # import them here so this cell does not depend on that one.
    import os
    from keras.layers import Input

    n_dense = 9                            # Dense layers; one rate column each
    skippable = range(1, n_dense - 1)      # 1..7: the 20->20 hidden layers
    for i in drop_layers:
        # Columns 0 and 8 exist in the rate input but are never read, so
        # reject them instead of silently rendering an unmodified image.
        if i not in skippable:
            raise ValueError(
                "drop_layers index {} cannot be skipped; expected 1..{}".format(
                    i, n_dense - 2))

    name, _ = os.path.splitext(os.path.basename(filename))
    w = 224  # base grid size; matches the default training resolution

    input_pos = Input(shape=(2, ))
    input_drop = Input(shape=(n_dense, ))

    def blend(index):
        """Return a Lambda function mixing a layer's input and output by rate column `index`."""
        def _blend(args):
            x, y, rate_ = args
            rate = rate_[:, index:index+1]  # keep 2-D so it broadcasts over units
            return x * rate + y * (1 - rate)
        return _blend

    def layer(x, input_drop, output_length, index):
        """Dense+ReLU whose output is blended with its input; rate 1 skips the layer."""
        y = Dense(output_length, activation='relu')(x)
        return Lambda(blend(index))([x, y, input_drop])

    # Dense layers are created in the same order as in training so that
    # load_weights matches them positionally.
    x = Dense(20, activation='relu')(input_pos)
    for index in skippable:
        x = layer(x, input_drop, 20, index)
    x = Dense(3, activation='sigmoid')(x)

    model = Model(inputs=[input_pos, input_drop], outputs=[x])
    model.load_weights('models/{}{}.h5'.format(name, suffix))

    s = w * resize
    # Row-major grid of (row, col) positions, same order as itertools.product.
    pos = np.indices((s, s)).reshape(2, -1).T / s - 0.5
    drop_rate = np.zeros((len(pos), n_dense))

    for i in drop_layers:
        print("drop_index:", i)
        drop_rate[:, i] = 1.0

    img = model.predict([pos, drop_rate], batch_size=1024)
    img = img.reshape((s, s, 3))
    img = Image.fromarray((img * 255.).astype(np.uint8))

    display(img)
    # PIL's save does not create missing parent directories.
    os.makedirs('./results', exist_ok=True)
    img.save('./results/{}{}.png'.format(name, suffix))



In [16]:
from glob import glob

# Render every background image's model with Dense layers 3 and 7 bypassed,
# sampled on a 4x-upscaled grid.
background_paths = sorted(glob('./images/background/*'))
for filename in background_paths:
    test(filename, [3, 7], resize=4)