In [10]:
from keras.layers import (
    Dense,
    Activation,
    BatchNormalization,
    Conv2D,
    Conv2DTranspose,
    Input,
    Add,
    Reshape,
    Concatenate,
    MaxPool1D,
    Dot,
    GlobalAveragePooling2D,
    Embedding
)
from keras.models import (
    Sequential, Model
)
from keras.activations import (
    softmax,
    tanh,
    sigmoid,
    relu
)
from keras import backend as K
from keras import optimizers
from keras.utils.generic_utils import get_custom_objects
from keras.backend.tensorflow_backend import set_session
import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from collections import OrderedDict
import time
from utils.utils import GroupNormalization

%matplotlib inline

In [11]:
INPUT_SHAPE = (128, 128, 3)
BATCH_SIZE = 1
DIS_LR = 0.0002
GEN_LR = 0.0002
DECAY = .99
X_PATH = './datasets/monet2photo/trainA'
Y_PATH = './datasets/monet2photo/trainB'

# Upsampling parameters
NUM_CONV_LAYERS = 2
INIT_FILTER = 16
INIT_LENGTH = 64
KERNEL_SIZE = (3, 3)
CONV_STRIDES = (2, 2)

# Residual Block parameters
NUM_REPETITIONS = 2
NUM_RES_BLOCKS = 2
RES_STRIDES = (1, 1)

# # WGAN parameters
GRADIENT_PENALTY_WEIGHT = 10
TRAINING_RATIO = 5  # the number of discriminator updates per generator update

In [12]:
def leakyReLu(input_):
    return relu(input_, alpha=.2)

get_custom_objects().update({'leakyReLu': Activation(leakyReLu)})


""" 
Plot the images inline using matplotlib.

This function takes in 8 pictures where 

X -> Original Picture X 
Y -> Original Picture Y
Y' -> generator_xy(X)
X' -> generator_yx(Y)
Y'' -> generator_xy(X')
X'' -> generator_yx(Y')
X_identical -> generator_yx(X)
Y_identical -> generator_xy(Y)

"""
def plot_images(*images):
    assert(len(images) == 8)
    image_names = ['X', 'Y', "Y'", "X'", "X''", "Y''", 'X_identical', 'Y_identical']
    
    # plt.rcParams['figure.figsize'] = [10, 10]
    
    for i, name in enumerate(image_names):
        image = ((images[i] + 1) / 2 * 255.).astype(int)
        plt.subplot(4, 2, i+1)
        plt.imshow(image)
        plt.title(name)
    
    plt.show()

In [27]:
class CycleGAN:
    def __init__(self):
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        self.session = tf.Session(config=config)
        set_session(self.session)
        
        '''
        with tf.device('/device:GPU:0'):
            self.generator_xy = self.generator()
            self.generator_yx = self.generator()
        
        with tf.device('/device:GPU:1'):
            self.discriminator_x = self.discriminator()
            self.discriminator_y = self.discriminator()
        
            self.xy_Dataset = self.buildDataset()
        
            X, Y     = Input(INPUT_SHAPE)   , Input(INPUT_SHAPE)
            X_, Y_   = self.generator_yx(Y) , self.generator_xy(X)
            X__, Y__ = self.generator_yx(Y_), self.generator_xy(X_)
            X_identity, Y_identity = self.generator_yx(X), self.generator_xy(Y)

            adam_dis = optimizers.Adam(lr=DIS_LR)
            adam_gen = optimizers.Adam(lr=GEN_LR)
        
            self.discriminator_x.compile(loss='mse', optimizer=adam_dis, metrics=['accuracy'])
            self.discriminator_y.compile(loss='mse', optimizer=adam_dis, metrics=['accuracy'])

            self.discriminator_x.trainable = False
            self.discriminator_y.trainable = False

            X_valid, Y_valid = self.discriminator_x(X_), self.discriminator_y(Y_)
        
        with tf.device('/device:GPU:0'):
            # TODO: Figure out the weights of the losses
            self.generators = Model(
                    inputs=[X, Y], 
                    outputs=[X_valid, Y_valid, X_, Y_, X__, Y__, X_identity, Y_identity]
                )

            # The paper suggests using L1 norm for the last four loss functions, try out different settings if it doesn't work
            self.generators.compile(
                loss=['mse']*8,
                loss_weights=[1, 1, 0, 0, 10, 10, 1, 1],
                optimizer=adam_gen
            )
        '''

    def buildDataset(self, x_path = X_PATH, y_path = Y_PATH):        
        x_Dataset = tf.data.Dataset.list_files( x_path + '/*.jpg')
        y_Dataset = tf.data.Dataset.list_files( y_path + '/*.jpg')

        x_images = x_Dataset.map(lambda x: tf.image.resize_images(tf.image.decode_jpeg(tf.read_file(x), channels = INPUT_SHAPE[2]), [INPUT_SHAPE[0], INPUT_SHAPE[1]]))
        y_images = y_Dataset.map(lambda x: tf.image.resize_images(tf.image.decode_jpeg(tf.read_file(x), channels = INPUT_SHAPE[2]), [INPUT_SHAPE[0], INPUT_SHAPE[1]]))

        xy_images = tf.data.Dataset.zip((x_images, y_images))
        xy_Dataset = xy_images.batch(BATCH_SIZE)
        return xy_Dataset

    def discriminator(self, x, y):
        # X is the image waiting to be judged
        # Y is the ground truth as a numpy array
        # Batch Size = None, Height, Width, Channels
        _, init_len, _, _ = K.int_shape(x)
        init_filters = INIT_FILTER
        x_feature = self.conv2d_layer(x)
        y_feature = self.conv2d_layer(y)

        # linear product of curr and y
        linearProductOutput = Dot(axes = 1)([x_feature, y_feature])
        linearOutput = Dense(init_filters*8)(x_feature)
        
        curr = Add()([linearOutput, linearProductOutput])
        return Model(inputs=[x, y], outputs=[curr])
    
    def conv2d_layer(self, x):
        init_filters = INIT_FILTER
        # 6 layers of resblock
        curr = self.d_resblock(x, init_filters)
        curr = self.d_resblock(curr, init_filters*2)
        curr = self.d_resblock(curr, init_filters*4)
        curr = self.d_resblock(curr, init_filters*8)
        curr = self.d_resblock(curr, init_filters*8, False)
        curr = Activation('relu')(curr)
        
        # apply global sum pooling
#         _, height, _, _ = K.int_shape(curr)
#         curr = GlobalAveragePooling2D(curr) # 2D tensor with shape: (batch_size, channels)
        curr = K.sum(curr, axis=(1,2))
        return curr
        
        
    
    def d_resblock(self, x, filters, size_change = True):
        # resblock for the descriminator
        # if the input and output shape are different, set size_change to true
        curr = GroupNormalization()(x) # ? axis=?
        curr = Activation('relu')(curr)
        
        if size_change:
            temp_x = Conv2D(filters, KERNEL_SIZE, strides=CONV_STRIDES, padding='same', kernel_initializer='truncated_normal')(curr)
        else:
            temp_x = curr
        
        curr = GroupNormalization()(temp_x)
        curr = Activation('relu')(curr)
        curr = Conv2D(filters, KERNEL_SIZE, strides=(1,1), padding='same', kernel_initializer='truncated_normal')(curr)
        curr = GroupNormalization()(curr)
        curr = Activation('relu')(curr)
        curr = Conv2D(filters, KERNEL_SIZE, strides=(1,1), padding='same', kernel_initializer='truncated_normal')(curr)
        return  Add()([temp_x, curr])
    
    # X is the output from last layer
    # Y is the ground truth as a numpy array
    def generator(self, x, y_dict):
        _, init_len, _, _ = K.int_shape(x)
        init_filters = INIT_FILTER
        
        curr = self.conv_res(x, y_dict, init_filters  , init_len)
        curr = self.conv_res(curr, y_dict, init_filters*2, init_len//2)
        curr = self.conv_res(curr, y_dict, init_filters*4, init_len//4)
        # curr = self.conv_res(curr, y_dict, init_filters*8, init_len//8)
        
        # curr = self.deconv_res(curr, y_dict, init_filters*4  , init_len//8)
        curr = self.deconv_res(curr, y_dict, init_filters*2, init_len//4)
        curr = self.deconv_res(curr, y_dict, init_filters*1, init_len//2)
        curr = self._addNonLocalBlock(curr)
        curr = self.deconv_res(curr, y_dict, init_filters//2, init_len)
        curr = self.deconv_res(curr, y_dict, init_filters//4, init_len*2)

        return Model(inputs=[x] + list(y_dict.values()), outputs=[curr])
    
    
    def conv_res(self, x, y_dict, filters, length):
        curr = BatchNormalization(axis=3)(x)
        curr = Activation('leakyReLu')(curr)
        curr = Conv2D(filters, KERNEL_SIZE, strides=CONV_STRIDES, padding='same', kernel_initializer='truncated_normal')(curr)
        curr = self._addResBlock(curr, y_dict, filters, length//2)
        return curr
    
    def deconv_res(self, x, y_dict, filters, length):
        curr = BatchNormalization(axis=3)(x)
        curr = Activation('leakyReLu')(curr)
        curr = Conv2DTranspose(filters, KERNEL_SIZE, strides=CONV_STRIDES, padding='same', kernel_initializer='truncated_normal')(curr)
        curr = self._addResBlock(curr, y_dict, filters, length)
        return curr

    # ResBlock w repetition=2
    def _addResBlock(self, x, y_dict, filters_in, x_len):
        curr = x
        
        for _ in range(NUM_REPETITIONS):
            curr = Concatenate(axis=3)([x, y_dict[x_len]])
            curr = BatchNormalization(axis=3)(x)
            curr = Activation('leakyReLu')(curr)
            curr = Conv2D(filters_in, KERNEL_SIZE, strides=(1, 1), padding='same', kernel_initializer='truncated_normal')(curr)
        
        return Add()([curr, x])
    
    # Embedded Guassaian NonLocal Block
    def _addNonLocalBlock(self, x, compression=2):
        _, dim1, dim2, channels = K.int_shape(x)
        intermediate_dim = channels // 2
        
        # theta 
        theta = Conv2D(intermediate_dim, KERNEL_SIZE, strides=(1, 1), padding='same', kernel_initializer='truncated_normal')(x)
        theta = Reshape((-1, intermediate_dim))(theta)
        # phi 
        phi = Conv2D(intermediate_dim, KERNEL_SIZE, strides=(1, 1), padding='same', kernel_initializer='truncated_normal')(x)
        phi = Reshape((-1, intermediate_dim))(phi)
        phi = MaxPool1D(compression)(phi)
        
        # f
        f = Dot(axes=2)([theta, phi])
        f = Activation('softmax')(f)

        # g
        g = Conv2D(intermediate_dim, KERNEL_SIZE, strides=(1, 1), padding='same', kernel_initializer='truncated_normal')(x)
        g = Reshape((-1, intermediate_dim))(g)
        g = MaxPool1D(compression)(g)

        out = Dot(axes=(2, 1))([f, g])
        out = Reshape((dim1, dim2, intermediate_dim))(out)
        out = Conv2D(channels, KERNEL_SIZE, strides=(1, 1), padding='same', kernel_initializer='truncated_normal')(out)

        # residual connection
        return Add()([x, out])

    
    # should label generated samples -1 and real samples 1
    def wasserstein_loss(y_true, y_pred):
        return K.mean(y_true * y_pred)

    # need to generate random weighted-averages of real and generated samples, to feed the discriminator
    # and use for the gradient norm penalty.
    def gradient_penalty_loss(y_true, y_pred, averaged_samples, gradient_penalty_weight):
        gradients = K.gradients(y_pred, averaged_samples)[0]
        gradients_sqr = K.square(gradients)
        gradients_sqr_sum = K.sum(gradients_sqr, axis=np.arange(1, len(gradients_sqr.shape)))
        gradient_l2_norm = K.sqrt(gradients_sqr_sum)
        gradient_penalty = gradient_penalty_weight * K.square(1 - gradient_l2_norm)
        return K.mean(gradient_penalty)
        
    '''
    def _addDeconvBlock(self, model, activations, filters, kernel_size=KERNEL_SIZE, strides=CONV_STRIDES):
        model.add(BatchNormalization(axis=3))
        model.add(Activation(activation_func(activations)))
        
        model.add(UpSampling2D(size=2))
        model.add(Conv2D(filters=filters, kernel_size=kernel_size, strides=(1,1), padding='same', kernel_initializer='truncated_normal'))
        #model.add(Conv2DTranspose(filters=filters, kernel_size=kernel_size, strides=strides, padding='same'))

    def _addConvBlock(self, model, activations, filters, kernel_size, strides, input_layer=False):
        if not input_layer:
            model.add(BatchNormalization(axis=3))
            model.add(Activation(activation_func(activations)))
            model.add(Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same', kernel_initializer='truncated_normal'))
        else:
            model.add(Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same', input_shape=INPUT_SHAPE, kernel_initializer='truncated_normal'))
    
    
    def train(self):
        # TODO: implements training process
        valid = np.ones((BATCH_SIZE, 1)) * .9
        fake  = np.zeros((BATCH_SIZE, 1))
        
        # self.discriminator_x.summary()
        # self.discriminator_y.summary()
        # self.generators.summary()
        
        for epoch in range(0,10):
            iterator = self.xy_Dataset.make_initializable_iterator()
            (x_next, y_next) = iterator.get_next()
            self.session.run(iterator.initializer)
            batch_i = 0
            
            while True:
                try:
                    # x_train, y_train = np.random.normal(size=[BATCH_SIZE, 256, 256, 3]), np.random.normal(size=[BATCH_SIZE, 256, 256, 3])
                    x_train, y_train = self.session.run([x_next, y_next])
                    
                    if x_train.shape[0] != BATCH_SIZE:
                        break
                    
                    x_train = (x_train / 255.0 - .5) * 2
                    y_train = (y_train / 255.0 - .5) * 2
                    
                    with tf.device('/device:GPU:1'):
                        x_valid, y_valid, x_, y_, x__, y__, x_identity, y_identity = self.generators.predict([x_train, y_train])
                        
                        d_x_real_loss = self.discriminator_x.train_on_batch(x_train, valid)
                        d_x_fake_loss = self.discriminator_x.train_on_batch(x_, fake)
                        d_x_loss = 0.5 * np.add(d_x_real_loss, d_x_fake_loss)

                        d_y_real_loss = self.discriminator_y.train_on_batch(y_train, valid)
                        d_y_fake_loss = self.discriminator_y.train_on_batch(y_, fake)
                        d_y_loss = 0.5 * np.add(d_y_real_loss, d_y_fake_loss)

                        # Total disciminator loss
                        d_loss = 0.5 * np.add(d_x_loss, d_y_loss)
                        
                    
                        # Total generator loss
                        g_loss = self.generators.train_on_batch([x_train, y_train],
                                                                [valid, valid,
                                                                 x_train, y_train,
                                                                 x_train, y_train,
                                                                 x_train, y_train])

                        if batch_i % 10 == 0:
                            plot_images(
                                x_train[0], y_train[0],
                                y_[0], x_[0],
                                x__[0], y__[0],
                                x_identity[0], y_identity[0]
                            )

                            print(
                                'Epoch: ', epoch,
                                'Batch: ', batch_i,
                                'Loss of Discriminator: ', d_loss, 
                                'Loss of Generator G: ', g_loss
                            )

                        batch_i += 1
                    
                
                except tf.errors.OutOfRangeError:
                    print('epoch ' + str( epoch) + ' end.')
                    break
    '''

    def test(self, x_test, y_test):
        # TODO: implements evaluation 
        pass

In [29]:
x = Input(shape=INPUT_SHAPE)
y = Input(shape=INPUT_SHAPE)
lens = [int(K.int_shape(y)[1] // 2**i) for i in [-2, -1, 0, 1, 2, 3, 4]]
y_dict = {i : Input(shape=(i, i, 3)) for i in lens}
y_dict = OrderedDict(sorted(y_dict.items(), key=lambda t: t[0]))
cycleGAN = CycleGAN()
#generator = cycleGAN.generator(x, y_dict)
d = cycleGAN.discriminator(x, y)

#Xs = np.random.normal(size=(10, 128, 128, 3))
#Ys = np.random.normal(size=(10, 128, 128, 3))
#y_keys = list(y_dict.keys())

#y_in = [np.array([cv2.resize(Y, (i, i)) for Y in Ys]) for i in y_keys]

#generator.summary()
d.summary()

AttributeError: 'NoneType' object has no attribute '_inbound_nodes'

In [None]:
generator.predict([Xs]+y_in).shape

In [None]:
generator.compile('adam', 'mse')

In [None]:
generator.fit([Xs]+y_in, np.random.normal(size=(10, 256, 256, 4)))