In [None]:
import pandas as pd
import os
import tensorflow as tf

import scipy
from scipy.io import loadmat
import re

import string
import imageio
import numpy as np
import matplotlib.pyplot as plt
from utils import *
import random
import time

import warnings
from pathlib import Path
import h5py
warnings.filterwarnings('ignore')

In [None]:
#loading skip thoughts vectors
data_path = './dataset'
df_train = pd.read_pickle(data_path + '/text2ImgData.pkl')

#replace original captions with skip thoughts vectors
# Every HDF5 file is opened in a `with` block so its handle is released as soon
# as the first dataset it contains has been copied into memory. Previously the
# three handles were left open and `f` was simply rebound.
filename = data_path + '/sample_caption_vectors.hdf5'
with h5py.File(filename, 'r') as f:
    a_group_key = list(f.keys())[0]
    train_caption = np.array(f[a_group_key])

filename = data_path + '/test_caption.hdf5'
with h5py.File(filename, 'r') as f:
    a_group_key = list(f.keys())[0]
    test_caption = np.array(f[a_group_key])

filename = data_path + '/sample_caption.hdf5'
with h5py.File(filename, 'r') as f:
    a_group_key = list(f.keys())[0]
    # list(...) iterates the dataset while the file is still open, so the
    # resulting rows stay usable after the handle is closed.
    sample_caption = list(f[a_group_key])

In [None]:
#skip thoughts index + RNN encoding
df_train = pd.read_pickle(data_path + '/train_rnn.pkl')

# Prefix every training caption with the running index i*10+j (row i of the
# frame, caption j of that row) and store the result back in the 'Captions' cell.
train_caption_col = df_train.columns.get_loc('Captions')
for i, captions in enumerate(df_train['Captions']):
    df_train.iloc[i, train_caption_col] = [
        np.hstack(([i * 10 + j], c)) for j, c in enumerate(captions)
    ]
df_train.to_pickle('dataset/train_caption.pkl')

# Test captions get their row number prepended; the frame is rebuilt with only
# the 'Captions' and 'ID' columns before being pickled.
df_test = pd.read_pickle(data_path + '/test_rnn.pkl')
test_caption_col = df_test.columns.get_loc('Captions')
captions = [
    np.hstack(([i], df_test.iloc[i, test_caption_col]))
    for i in range(len(df_test))
]

d = {'Captions': captions, 'ID': df_test['ID'].values}
df_test = pd.DataFrame(data=d)
df_test.to_pickle('dataset/test_caption.pkl')

sample_rnn = np.load('dataset/sample_rnn.npy')

In [None]:
def transform_image(input_image, shape):
    """Square an image, then randomly flip and crop it to `shape`.

    Args:
        input_image: decoded image tensor; it is given the static shape
            [None, None, 3] after conversion to float32.
        shape: target [height, width, channels] passed to `tf.random_crop`.

    Returns:
        A float32 tensor of size `shape`, rescaled as ``2 * x - 1``
        (i.e. [-1, 1], assuming the float conversion yields values in [0, 1]).
    """
    img = tf.image.convert_image_dtype(input_image, tf.float32)
    img.set_shape([None, None, 3])

    # Crop/pad to a square whose side is the shorter of the two image sides.
    side = tf.minimum(tf.shape(img)[0], tf.shape(img)[1])
    img = tf.image.resize_image_with_crop_or_pad(img, side, side)

    # Resize to 76/64 of the target size so the random crop has some slack.
    enlarged_size = [shape[0] * 76 // 64, shape[1] * 76 // 64]
    img = tf.image.resize_images(img, size=enlarged_size)
    img = tf.image.random_flip_left_right(img)
    img = tf.random_crop(img, shape)
    return img * 2 - 1.

def _load_training_image(image_path, shape):
    """Read the file at `data_path + image_path`, decode it as a 3-channel image
    and augment it to `shape` with `transform_image`."""
    imagefile = tf.read_file(data_path + image_path)
    image = tf.image.decode_image(imagefile, channels=3)
    return transform_image(image, shape)

def training_data_generator(caption, image_path):
    """Map one (caption, image_path) record to a (64x64x3 image, caption) pair.

    Args:
        caption: caption tensor, returned unchanged.
        image_path: path appended to the module-level `data_path`.

    Returns:
        Tuple ``(image, caption)``.
    """
    return _load_training_image(image_path, [64, 64, 3]), caption

def training_data_generator_hr(caption, image_path):
    """Map one (caption, image_path) record to a (256x256x3 image, caption) pair.

    Identical to `training_data_generator` except for the target image size.

    Args:
        caption: caption tensor, returned unchanged.
        image_path: path appended to the module-level `data_path`.

    Returns:
        Tuple ``(image, caption)``.
    """
    return _load_training_image(image_path, [256, 256, 3]), caption

def data_iterator(filenames, batch_size, data_generator):
    """Build an initializable, endlessly repeating tf.data iterator over training data.

    The pickle at `filenames` must provide the columns 'Captions' and
    'ImagePath'. For every row, column 0 of the caption array is used as a row
    index into the module-level `train_caption` array; the remaining caption
    columns are concatenated with columns 2400 onward of the looked-up rows.
    (Presumably column 0 is the index written when the caption pickle was
    built and `train_caption` holds the skip-thoughts vectors -- confirm.)

    Args:
        filenames: path of the pickled DataFrame to read.
        batch_size: number of records per batch.
        data_generator: function mapping ``(caption, image_path)`` to the
            tensors of one training example; applied with `Dataset.map`.

    Returns:
        Tuple ``(iterator, output_types, output_shapes)`` of the batched dataset.
    """
    df = pd.read_pickle(filenames)
    captions = df['Captions'].values
    paths = np.array(df['ImagePath'].tolist())

    caption = []
    for image_captions in captions:
        rows = np.array(image_captions)  # converted once instead of twice per row
        vector_index = rows[:, 0].astype(np.int32)
        caption.append(np.hstack((rows[:, 1:], train_caption[vector_index][:, 2400:])))
    caption = np.array(caption)

    # Each entry of `caption` must correspond to the same entry of `paths`.
    assert caption.shape[0] == paths.shape[0]

    dataset = tf.data.Dataset.from_tensor_slices((caption, paths))

    # Shuffle the lightweight (caption, path) records over the whole dataset
    # BEFORE repeating and decoding. The previous order (map -> repeat ->
    # shuffle(2 * batch_size)) only permuted records within a tiny window, so
    # batches stayed close to file order and consecutive epochs were mixed.
    dataset = dataset.shuffle(max(len(paths), 1))
    dataset = dataset.repeat()
    dataset = dataset.map(data_generator, num_parallel_calls=4)
    dataset = dataset.batch(batch_size)

    iterator = dataset.make_initializable_iterator()
    output_types = dataset.output_types
    output_shapes = dataset.output_shapes

    return iterator, output_types, output_shapes

In [None]:
class Generator:
    """Text-conditioned generator built under the variable scope 'generator_1'.

    A dense layer turns the text embedding into a sampled condition vector,
    which is concatenated with the noise `z`, projected to an
    (s/16 x s/16) feature map and upsampled by four stride-2 transposed
    convolutions to an image of side ``s = hparas['LR_SIZE']``.

    Attributes set by the constructor:
        outputs: tanh-activated tensor of shape [batch, s, s, 3].
        KLloss: scalar regulariser returned by `_generate_condition`.
    """

    def __init__(self, noise_z, text, training_phase, hparas, reuse):
        # Graph inputs and configuration.
        self.z = noise_z
        self.text = text
        self.batch_size = tf.shape(text)[0]
        self.train = training_phase
        self.hparas = hparas
        self.reuse = reuse
        self.gf_dim = 128

        # Output side length and its successive halvings.
        side = self.hparas['LR_SIZE']
        self.s = side
        self.s2, self.s4, self.s8, self.s16 = side // 2, side // 4, side // 8, side // 16

        self.fc_initializer = tf.random_normal_initializer(0, 0.02)
        self.deconv2d_initializer = tf.random_normal_initializer(0, 0.02)
        self.gamma_initializer = tf.random_normal_initializer(1, 0.02)
        self._build_model()

    def _generate_condition(self, embed):
        """Sample a TEXT_DIM condition vector from `embed`.

        A dense layer yields 2 * TEXT_DIM values that are split into a mean and
        a log-sigma half; the sample is ``mean + exp(log_sigma) * epsilon``.

        Returns:
            Tuple ``(c, KLloss)`` of the sampled condition and the mean of the
            per-element KL term.
        """
        text_dim = self.hparas['TEXT_DIM']
        stats = tf.layers.dense(embed, text_dim * 2, activation=tf.nn.leaky_relu,
                                kernel_initializer=self.fc_initializer)
        mean, log_sigma = stats[:, :text_dim], stats[:, text_dim:]
        epsilon = tf.truncated_normal(tf.shape(mean))
        c = mean + tf.exp(log_sigma) * epsilon
        kl_terms = -log_sigma + 0.5 * (-1 + tf.exp(2. * log_sigma) + tf.square(mean))
        return c, tf.reduce_mean(kl_terms)

    def _deconv_bn_relu(self, inputs, W, out_shape):
        """Stride-2 transposed convolution with kernel `W`, then batch norm and ReLU."""
        upsampled = tf.nn.conv2d_transpose(inputs, W, output_shape=out_shape,
                                           strides=[1,2,2,1], padding='SAME')
        normalized = tf.layers.batch_normalization(upsampled, training=self.train, momentum=0.9,
                                                   epsilon=1e-5,
                                                   gamma_initializer=self.gamma_initializer)
        return tf.nn.relu(normalized)

    def _build_model(self):
        """Create the generator graph and set `self.outputs` / `self.KLloss`."""
        gf = self.gf_dim
        with tf.variable_scope('generator_1', reuse=self.reuse):
            c, self.KLloss = self._generate_condition(self.text)

            # Project [condition, noise] to the coarsest feature map.
            net = tf.layers.dense(tf.concat([c, self.z], axis=1),
                                  self.s16 * self.s16 * gf * 8,
                                  kernel_initializer=self.fc_initializer)
            net = tf.reshape(net, [-1, self.s16, self.s16, gf * 8])
            net = tf.layers.batch_normalization(net, training=self.train, momentum=0.9, epsilon=1e-5,
                                                gamma_initializer=self.gamma_initializer)
            net = tf.nn.relu(net)

            # (variable name, output side, output channels, input channels)
            upsample_plan = (
                ("weights1", self.s8, gf * 4, gf * 8),
                ("weights2", self.s4, gf * 2, gf * 4),
                ("weights3", self.s2, gf, gf * 2),
            )
            for var_name, side, out_ch, in_ch in upsample_plan:
                kernel = tf.get_variable(var_name, shape=[4, 4, out_ch, in_ch],
                                         initializer=self.deconv2d_initializer)
                net = self._deconv_bn_relu(net, kernel, [self.batch_size, side, side, out_ch])

            # Final upsampling straight to 3 channels, without batch norm.
            rgb_kernel = tf.get_variable("weights4", shape=[4, 4, 3, gf],
                                         initializer=self.deconv2d_initializer)
            net = tf.nn.conv2d_transpose(net, rgb_kernel,
                                         output_shape=[self.batch_size, self.s, self.s, 3],
                                         strides=[1,2,2,1], padding='SAME')
            self.outputs = tf.nn.tanh(net)

In [None]:
class Discriminator:
    """Text-conditioned discriminator built under the variable scope 'discriminator_1'.

    The image is reduced by four stride-2 convolutions, concatenated along the
    channel axis with the spatially tiled text code and reduced to one logit
    map stored in `self.logits`.
    """

    def __init__(self, image, text, training_phase, hparas, reuse):
        # Graph inputs and configuration.
        self.image = image
        self.text = text
        self.train = training_phase
        self.hparas = hparas
        self.reuse = reuse
        self.df_dim = 64

        self.fc_initializer = tf.random_normal_initializer(0, 0.02)
        self.conv2d_initializer = tf.truncated_normal_initializer(0, 0.02)
        self.gamma_initializer = tf.random_normal_initializer(1, 0.02)

        # Side length from the hyper-parameters and its successive halvings.
        side = self.hparas['LR_SIZE']
        self.s = side
        self.s2, self.s4, self.s8, self.s16 = side // 2, side // 4, side // 8, side // 16
        self._build_model()

    def _conv_bn_leaky(self, inputs, out_dim, ks, ss):
        """Bias-free SAME convolution followed by batch norm and leaky ReLU."""
        conv = tf.layers.conv2d(inputs, out_dim, kernel_size=ks, strides=ss,
                                padding='SAME', use_bias=False,
                                kernel_initializer=self.conv2d_initializer)
        normalized = tf.layers.batch_normalization(conv, training=self.train, momentum=0.9,
                                                   epsilon=1e-5,
                                                   gamma_initializer=self.gamma_initializer)
        return tf.nn.leaky_relu(normalized)

    def _build_model(self):
        """Create the discriminator graph and set `self.logits`."""
        with tf.variable_scope('discriminator_1', reuse=self.reuse):
            # Compress the text to TEXT_DIM and tile it over an s16 x s16 grid.
            text_code = tf.layers.dense(self.text, self.hparas['TEXT_DIM'], activation=tf.nn.leaky_relu,
                                        kernel_initializer=self.fc_initializer)
            text_code = tf.expand_dims(tf.expand_dims(text_code, 1), 1)
            text_code = tf.tile(text_code, [1, self.s16, self.s16, 1])

            # First convolution has no batch norm; the next three do.
            features = tf.layers.conv2d(self.image, self.df_dim , kernel_size=[4, 4], strides=[2, 2],
                                        padding='SAME', activation=tf.nn.leaky_relu,
                                        use_bias=False, kernel_initializer=self.conv2d_initializer)
            for multiplier in (2, 4, 8):
                features = self._conv_bn_leaky(features, self.df_dim * multiplier, [4, 4], [2, 2])

            joint = tf.concat([features, text_code], axis=3)
            joint = self._conv_bn_leaky(joint,  self.df_dim * 8, [1, 1], [1, 1])

            # Kernel and stride both equal s16, collapsing an s16 x s16 map to one value.
            self.logits = tf.layers.conv2d(joint, 1, kernel_size=[self.s16, self.s16],
                                           strides=[self.s16, self.s16],
                                           padding='SAME', use_bias=False,
                                           kernel_initializer=self.conv2d_initializer)

In [None]:
class Generator_Complex:
    """Generator variant using residual blocks and nearest-neighbour upsampling.

    It is built under the same variable scope name ('generator_1') as
    `Generator` and exposes the same attributes: `outputs` (tanh image tensor
    whose spatial size is LR_SIZE x LR_SIZE with 3 channels) and `KLloss`.
    """

    def __init__(self, noise_z, text, training_phase, hparas, reuse):
        """Store the inputs and hyper-parameters, then build the graph."""
        self.z = noise_z
        self.text = text
        self.batch_size = tf.shape(text)[0]
        self.train = training_phase
        self.hparas = hparas
        self.gf_dim = 128
        self.reuse = reuse
        self.s = self.hparas['LR_SIZE']
        self.s2, self.s4, self.s8, self.s16 = self.s // 2, self.s // 4, self.s // 8, self.s // 16
        self.fc_initializer = tf.random_normal_initializer(0, 0.02)
        # NOTE(review): deconv2d_initializer is not referenced anywhere in this class.
        self.deconv2d_initializer = tf.random_normal_initializer(0, 0.02)
        self.gamma_initializer = tf.random_normal_initializer(1, 0.02)
        self._build_model()

    def _generate_condition(self, embed):
        """Sample a TEXT_DIM condition vector from `embed`.

        A dense layer yields 2 * TEXT_DIM values split into a mean and a
        log-sigma half. Returns ``(c, KLloss)`` where
        ``c = mean + exp(log_sigma) * epsilon`` and `KLloss` is the mean of the
        per-element KL term.
        """
        conditions = tf.layers.dense(embed, self.hparas['TEXT_DIM'] * 2, activation=tf.nn.leaky_relu,
                                     kernel_initializer=self.fc_initializer)
        mean = conditions[:, :self.hparas['TEXT_DIM']]
        log_sigma = conditions[:, self.hparas['TEXT_DIM']:]
        epsilon = tf.truncated_normal(tf.shape(mean))
        stddev = tf.exp(log_sigma)
        c = mean + stddev * epsilon
        KLloss = -log_sigma + 0.5 * (-1 + tf.exp(2. * log_sigma) + tf.square(mean))
        KLloss = tf.reduce_mean(KLloss)
        return c, KLloss
    
    def _conv_bn_relu(self, inputs, out_shape, ks, ss):
        """Bias-free SAME convolution with `out_shape` filters, then batch norm and ReLU.

        No kernel_initializer is passed, so the layer's default initializer is used.
        """
        node = tf.layers.conv2d(inputs, filters=out_shape , kernel_size=ks, strides=ss, padding='SAME', use_bias=False)
        node = tf.layers.batch_normalization(node, training=self.train, momentum=0.9, epsilon=1e-5, 
                                             gamma_initializer=self.gamma_initializer)
        node = tf.nn.relu(node)
        return node
    
    def _conv_bn(self, inputs, out_shape, ks, ss):
        """Same as `_conv_bn_relu` but without the final activation."""
        node = tf.layers.conv2d(inputs, filters = out_shape , kernel_size=ks, strides=ss, padding='SAME', use_bias=False)
        node = tf.layers.batch_normalization(node, training=self.train, momentum=0.9, epsilon=1e-5, 
                                             gamma_initializer=self.gamma_initializer)
        return node

    def _build_model(self):
        """Create the generator graph and set `self.outputs` / `self.KLloss`."""
        with tf.variable_scope('generator_1', reuse=self.reuse): 
            c, self.KLloss = self._generate_condition(self.text)
            
            # Project [condition, noise]; batch norm is applied to the flat
            # vector before it is reshaped to an s16 x s16 feature map.
            z_text_concat = tf.concat([c, self.z], axis=1)
            node1_0 = tf.layers.dense(z_text_concat, self.s16 * self.s16 * self.gf_dim * 8,
                                      kernel_initializer=self.fc_initializer)
            node1_0 = tf.layers.batch_normalization(node1_0, training=self.train, momentum=0.9, epsilon=1e-5, 
                                             gamma_initializer=self.gamma_initializer)
            node1_0 = tf.reshape(node1_0, [-1, self.s16, self.s16, self.gf_dim * 8])
            
            # Residual branch 1: 1x1 -> 3x3 -> 3x3 back to gf_dim * 8 channels.
            node1_1 = self._conv_bn_relu(node1_0, self.gf_dim * 2, [1, 1], [1, 1])
            node1_1 = self._conv_bn_relu(node1_1, self.gf_dim * 2, [3, 3], [1, 1])
            node1_1 = self._conv_bn(node1_1, self.gf_dim * 8, [3, 3], [1, 1])            
                        
            node1 = tf.add(node1_0, node1_1)
            node1 = tf.nn.relu(node1)
            
            # Upsample to s8, then residual branch 2 at gf_dim * 4 channels.
            node2_0 = tf.image.resize_nearest_neighbor(node1, size=[self.s8, self.s8])
            node2_0 = self._conv_bn(node2_0, self.gf_dim * 4, [3, 3], [1, 1])  
            node2_1 = self._conv_bn_relu(node2_0, self.gf_dim * 1, [1, 1], [1, 1])
            node2_1 = self._conv_bn_relu(node2_1, self.gf_dim * 1, [3, 3], [1, 1])
            node2_1 = self._conv_bn(node2_1, self.gf_dim * 4, [3, 3], [1, 1])  
            node2 = tf.add(node2_0, node2_1)
            node2 = tf.nn.relu(node2)
            
            # Plain resize + conv steps: s4, s2, then full size s.
            output_tensor = tf.image.resize_nearest_neighbor(node2, size=[self.s4, self.s4])
            output_tensor = self._conv_bn_relu(output_tensor, self.gf_dim * 2, [3, 3], [1, 1])
            output_tensor = tf.image.resize_nearest_neighbor(output_tensor, size=[self.s2, self.s2])         
            output_tensor = self._conv_bn_relu(output_tensor, self.gf_dim, [3, 3], [1, 1])
   
            output_tensor = tf.image.resize_nearest_neighbor(output_tensor, size=[self.s, self.s])
            output_tensor = tf.layers.conv2d(output_tensor, 3, kernel_size=[3, 3], strides=[1, 1], 
                                             padding ='SAME', use_bias=False)
            output_tensor = tf.nn.tanh(output_tensor)
            
            self.outputs = output_tensor

In [None]:
class Discriminator_Complex:
    """Discriminator variant with one residual block before the text is merged.

    It is built under the same variable scope name ('discriminator_1') as
    `Discriminator` and stores its result in `self.logits`.
    """

    def __init__(self, image, text, training_phase, hparas, reuse):
        """Store the inputs and hyper-parameters, then build the graph."""
        self.image = image
        self.text = text
        self.train = training_phase
        self.hparas = hparas
        self.df_dim = 64 
        self.reuse = reuse
        self.fc_initializer = tf.random_normal_initializer(0, 0.02)
        self.conv2d_initializer = tf.truncated_normal_initializer(0, 0.02)
        self.gamma_initializer = tf.random_normal_initializer(1, 0.02)
        self.s = self.hparas['LR_SIZE']
        self.s2, self.s4, self.s8, self.s16 = self.s // 2, self.s // 4, self.s // 8, self.s // 16
        self._build_model()
        
    def _conv_bn_leaky(self, inputs, out_dim, ks, ss):
        """Bias-free SAME convolution followed by batch norm and leaky ReLU."""
        node = tf.layers.conv2d(inputs, out_dim, kernel_size=ks, strides=ss,
                                padding='SAME', use_bias=False, kernel_initializer=self.conv2d_initializer)
        node = tf.layers.batch_normalization(node, training=self.train, momentum=0.9, epsilon=1e-5, 
                                             gamma_initializer=self.gamma_initializer)
        node = tf.nn.leaky_relu(node)
        return node
    
    def _conv_bn(self, inputs, out_dim, ks, ss):
        """Same as `_conv_bn_leaky` but without the final activation."""
        node = tf.layers.conv2d(inputs, out_dim, kernel_size=ks, strides=ss,
                                padding='SAME', use_bias=False, kernel_initializer=self.conv2d_initializer)
        node = tf.layers.batch_normalization(node, training=self.train, momentum=0.9, epsilon=1e-5, 
                                             gamma_initializer=self.gamma_initializer)
        return node

    def _build_model(self):
        """Create the discriminator graph and set `self.logits`."""
        with tf.variable_scope('discriminator_1', reuse=self.reuse):
            # Compress the text to TEXT_DIM and tile it over an s16 x s16 grid.
            text_decode = tf.layers.dense(self.text, self.hparas['TEXT_DIM'], activation=tf.nn.leaky_relu,
                                          kernel_initializer=self.fc_initializer)
            text_decode = tf.expand_dims(tf.expand_dims(text_decode, 1), 1)
            text_decode = tf.tile(text_decode, [1, self.s16, self.s16, 1]) 
            
            # Four stride-2 convolutions on the image.
            # assumes self.image is LR_SIZE x LR_SIZE so the result matches the
            # s16 x s16 text grid in the concat below -- TODO confirm
            node1_0 = tf.layers.conv2d(self.image, self.df_dim , kernel_size=[4, 4], strides=[2, 2],
                                       padding='SAME', use_bias=False, kernel_initializer=self.conv2d_initializer)
            node1_0 = tf.nn.leaky_relu(node1_0)
            
            node1_0 = self._conv_bn_leaky(node1_0, self.df_dim * 2, [4, 4], [2, 2])
            # NOTE(review): the next two layers use _conv_bn, so there is no
            # activation between them; the plain Discriminator class uses
            # _conv_bn_leaky for the corresponding layers. Confirm this is intended.
            node1_0 = self._conv_bn(node1_0, self.df_dim * 4, [4, 4], [2, 2])
            node1_0 = self._conv_bn(node1_0, self.df_dim * 8, [4, 4], [2, 2])
            
            # Residual branch: 1x1 -> 3x3 -> 3x3 back to df_dim * 8 channels.
            node1_1 = self._conv_bn_leaky(node1_0, self.df_dim * 2, [1, 1], [1, 1])
            node1_1 = self._conv_bn_leaky(node1_1, self.df_dim * 2, [3, 3], [1, 1])
            node1_1 = self._conv_bn(node1_1, self.df_dim * 8, [3, 3], [1, 1])
                        
            node1 = tf.add(node1_0, node1_1)
            node1 = tf.nn.leaky_relu(node1)
            
            # Merge image features with the tiled text along the channel axis.
            concat_decode = tf.concat([node1, text_decode], axis=3)
            
            concat_decode = self._conv_bn_leaky(concat_decode, self.df_dim * 8, [1, 1], [1, 1])
            # Kernel and stride both equal s16, collapsing an s16 x s16 map to one value.
            concat_decode = tf.layers.conv2d(concat_decode, 1, kernel_size=[self.s16, self.s16], strides=[self.s16, self.s16],
                                            padding='SAME', use_bias=False, kernel_initializer=self.conv2d_initializer)

            self.logits = concat_decode
            

In [None]:
class Generator_2:
    """Image-to-image generator built under the variable scope 'generator_2'.

    The input image is encoded by two stride-2 convolutions, merged with a
    sampled text condition tiled over an s4 x s4 grid, passed through four
    residual blocks and upsampled in four steps to side ``4 * LR_SIZE``.

    Attributes set by the constructor:
        outputs: tanh-activated 3-channel image tensor.
        KLloss: scalar regulariser returned by `_generate_condition`.
    """

    def __init__(self, image, text, training_phase, hparas, reuse):
        # Graph inputs and configuration.
        self.image = image
        self.text = text
        self.batch_size = tf.shape(text)[0]
        self.train = training_phase
        self.hparas = hparas
        self.reuse = reuse
        self.gf_dim = 128

        # Input side length and its successive halvings.
        side = self.hparas['LR_SIZE']
        self.s = side
        self.s2, self.s4, self.s8, self.s16 = side // 2, side // 4, side // 8, side // 16

        self.fc_initializer = tf.random_normal_initializer(0, 0.02)
        self.conv2d_initializer = tf.truncated_normal_initializer(0, 0.02)
        self.gamma_initializer = tf.random_normal_initializer(1, 0.02)
        self._build_model()

    def _generate_condition(self, inputs):
        """Sample a TEXT_DIM condition vector from `inputs`.

        A dense layer yields 2 * TEXT_DIM values that are split into a mean and
        a log-sigma half; the sample is ``mean + exp(log_sigma) * epsilon``.

        Returns:
            Tuple ``(c, KLloss)`` of the sampled condition and the mean of the
            per-element KL term.
        """
        text_dim = self.hparas['TEXT_DIM']
        stats = tf.layers.dense(inputs, text_dim * 2, activation=tf.nn.leaky_relu,
                                kernel_initializer=self.fc_initializer)
        mean, log_sigma = stats[:, :text_dim], stats[:, text_dim:]
        epsilon = tf.truncated_normal(tf.shape(mean))
        c = mean + tf.exp(log_sigma) * epsilon
        kl_terms = -log_sigma + 0.5 * (-1 + tf.exp(2. * log_sigma) + tf.square(mean))
        return c, tf.reduce_mean(kl_terms)

    def _conv_bn_relu(self, inputs, out_dim, ks, ss):
        """Bias-free SAME convolution followed by batch norm and an activation.

        Despite the method name, the activation applied is `tf.nn.leaky_relu`.
        """
        conv = tf.layers.conv2d(inputs, out_dim, kernel_size=ks, strides=ss,
                                padding='SAME', use_bias=False,
                                kernel_initializer=self.conv2d_initializer)
        normalized = tf.layers.batch_normalization(conv, training=self.train, momentum=0.9,
                                                   epsilon=1e-5,
                                                   gamma_initializer=self.gamma_initializer)
        return tf.nn.leaky_relu(normalized)

    def _residual_block(self, inputs):
        """Two 3x3 conv + batch-norm stages added back onto `inputs`, then ReLU.

        Both convolutions produce gf_dim * 4 channels, so `inputs` must have
        that many channels for the addition to be valid.
        """
        branch = self._conv_bn_relu(inputs, self.gf_dim * 4, [3, 3], [1, 1])
        branch = tf.layers.conv2d(branch, self.gf_dim * 4, kernel_size=[3, 3], strides=[1, 1],
                                  padding='SAME', use_bias=False,
                                  kernel_initializer=self.conv2d_initializer)
        branch = tf.layers.batch_normalization(branch, training=self.train, momentum=0.9,
                                               epsilon=1e-5,
                                               gamma_initializer=self.gamma_initializer)
        return tf.nn.relu(branch + inputs)

    def _build_model(self):
        """Create the generator graph and set `self.outputs` / `self.KLloss`."""
        gf = self.gf_dim
        with tf.variable_scope('generator_2', reuse=self.reuse):

            # Encode the image: one full-resolution conv, then two stride-2 stages.
            encoded = tf.layers.conv2d(self.image, gf, kernel_size=[3, 3], strides=[1, 1],
                                       padding='SAME', use_bias=False,
                                       kernel_initializer=self.conv2d_initializer)
            encoded = tf.nn.relu(encoded)
            for channels in (gf * 2, gf * 4):
                encoded = self._conv_bn_relu(encoded, channels, [4, 4], [2, 2])

            # Sample the text condition and tile it over an s4 x s4 grid.
            c, self.KLloss = self._generate_condition(self.text)
            c = tf.expand_dims(tf.expand_dims(c, 1), 1)
            c = tf.tile(c, [1, self.s4, self.s4, 1])

            net = tf.concat([encoded, c], axis=3)
            net = self._conv_bn_relu(net, gf * 4, [3, 3], [1, 1])

            for _ in range(4):
                net = self._residual_block(net)

            # (target side, output channels) of each resize + conv step.
            upsample_plan = (
                (self.s2, gf * 2),
                (self.s, gf),
                (self.s * 2, gf // 2),
                (self.s * 4, gf // 4),
            )
            for side, channels in upsample_plan:
                net = tf.image.resize_nearest_neighbor(net, [side, side])
                net = self._conv_bn_relu(net, channels, [3, 3], [1, 1])

            net = tf.layers.conv2d(net, 3, kernel_size=[3, 3], strides=[1, 1],
                                   padding='SAME', use_bias=False,
                                   kernel_initializer=self.conv2d_initializer)
            self.outputs = tf.nn.tanh(net)

In [None]:
class Discriminator_2:
    """Text-conditioned discriminator built under the variable scope 'discriminator_2'.

    The image passes through six stride-2 convolutions and a residual block,
    is concatenated with the tiled text code and reduced to one logit map
    stored in `self.logits`.
    """

    def __init__(self, image, text, training_phase, hparas, reuse):
        # Graph inputs and configuration.
        self.image = image
        self.text = text
        self.train = training_phase
        self.hparas = hparas
        self.reuse = reuse
        self.df_dim = 64

        self.fc_initializer = tf.random_normal_initializer(0, 0.02)
        self.conv2d_initializer = tf.truncated_normal_initializer(0, 0.02)
        self.gamma_initializer = tf.random_normal_initializer(1, 0.02)

        # Side length from the hyper-parameters and its successive halvings.
        side = self.hparas['LR_SIZE']
        self.s = side
        self.s2, self.s4, self.s8, self.s16 = side // 2, side // 4, side // 8, side // 16
        self._build_model()

    def _conv_bn_leaky(self, inputs, out_dim, ks, ss):
        """Bias-free SAME convolution followed by batch norm and leaky ReLU."""
        conv = tf.layers.conv2d(inputs, out_dim, kernel_size=ks, strides=ss,
                                padding='SAME', use_bias=False,
                                kernel_initializer=self.conv2d_initializer)
        normalized = tf.layers.batch_normalization(conv, training=self.train, momentum=0.9,
                                                   epsilon=1e-5,
                                                   gamma_initializer=self.gamma_initializer)
        return tf.nn.leaky_relu(normalized)

    def _build_model(self):
        """Create the discriminator graph and set `self.logits`."""
        df = self.df_dim
        with tf.variable_scope('discriminator_2', reuse=self.reuse):
            # Compress the text to TEXT_DIM and tile it over an s16 x s16 grid.
            text_code = tf.layers.dense(self.text, self.hparas['TEXT_DIM'], activation=tf.nn.leaky_relu,
                                        kernel_initializer=self.fc_initializer)
            text_code = tf.expand_dims(tf.expand_dims(text_code, 1), 1)
            text_code = tf.tile(text_code, [1, self.s16, self.s16, 1])

            # First convolution has no batch norm; five more stride-2 stages follow.
            features = tf.layers.conv2d(self.image, df, kernel_size=[4, 4], strides=[2, 2],
                                        padding='SAME', activation=tf.nn.leaky_relu,
                                        use_bias=False, kernel_initializer=self.conv2d_initializer)
            for multiplier in (2, 4, 8, 16, 32):
                features = self._conv_bn_leaky(features, df * multiplier, [4, 4], [2, 2])

            # Two 1x1 stages bring the channel count down to df * 8; the second
            # one has batch norm but no activation.
            features = self._conv_bn_leaky(features, df * 16, [1, 1], [1, 1])
            features = tf.layers.conv2d(features, df * 8, kernel_size=[1, 1], strides=[1, 1],
                                        padding='SAME', use_bias=False,
                                        kernel_initializer=self.conv2d_initializer)
            features = tf.layers.batch_normalization(features, training=self.train, momentum=0.9,
                                                     epsilon=1e-5,
                                                     gamma_initializer=self.gamma_initializer)

            # Residual branch: 1x1 -> 3x3 -> 3x3 back to df * 8 channels.
            residual = self._conv_bn_leaky(features, df * 2, [1, 1], [1, 1])
            residual = self._conv_bn_leaky(residual, df * 2, [3, 3], [1, 1])
            residual = tf.layers.conv2d(residual, df * 8, kernel_size=[3, 3], strides=[1, 1],
                                        padding='SAME', use_bias=False,
                                        kernel_initializer=self.conv2d_initializer)
            residual = tf.layers.batch_normalization(residual, training=self.train, momentum=0.9,
                                                     epsilon=1e-5,
                                                     gamma_initializer=self.gamma_initializer)
            merged = tf.nn.leaky_relu(features + residual)

            joint = tf.concat([merged, text_code], axis=3)
            joint = self._conv_bn_leaky(joint, df * 8, [1, 1], [1, 1])

            # Kernel and stride both equal s16, collapsing an s16 x s16 map to one value.
            self.logits = tf.layers.conv2d(joint, 1, kernel_size=[self.s16, self.s16],
                                           strides=[self.s16, self.s16],
                                           padding='SAME', use_bias=False,
                                           kernel_initializer=self.conv2d_initializer)

In [None]:
class GAN:
    """Build, train and run the two-stage text-to-image GAN.

    Constructing an instance creates a TensorFlow session, the data iterator,
    the input placeholders, the networks for the requested stage, the losses,
    the optimizers (training only) and the checkpoint savers, then initializes
    all variables and optionally restores a checkpoint.

    Args:
        hparas: dict of hyper-parameters. Keys read by this class are
            'BATCH_SIZE', 'IMAGE_SIZE', 'EMBED_DIM', 'Z_DIM', 'LR', 'BETA',
            'N_EPOCH', 'DECAY_EVERY', 'LR_DECAY' and 'N_SAMPLE'.
        training_phase: True builds generator + discriminators + losses +
            optimizers; False builds only the generator(s) for inference.
        stage: 1 for the first-stage networks (Generator / Discriminator);
            2 for the second-stage networks (Generator_2 / Discriminator_2)
            stacked on a stage-1 Generator built with training_phase=False.
        dataset_path: directory that contains 'test_caption.pkl' (only used
            in inference mode).
        ckpt_path: prefix prepended verbatim to checkpoint file names, so it
            should end with a path separator (e.g. './checkpoint/').
        inference_path: directory into which inference() writes PNG files.
        recover: epoch number of the checkpoint to restore, or None to start
            from freshly initialized variables.
    """

    def __init__(self,
                 hparas,
                 training_phase,
                 stage,
                 dataset_path,
                 ckpt_path,
                 inference_path,
                 recover=None):
        self.hparas = hparas
        self.train = training_phase
        self.stage = stage
        self.dataset_path = dataset_path  # directory holding test_caption.pkl
        self.ckpt_path = ckpt_path
        self.sample_path = './samples'
        # Honour the caller-supplied directory (it used to be silently replaced
        # by a hard-coded './inference').
        self.inference_path = inference_path

        # The order below matters: each step uses attributes set by earlier ones.
        self._get_session()  # get session
        self._get_train_data_iter()  # initialize and get data iterator
        self._input_layer()  # define input placeholder
        self._get_inference()  # build generator and discriminator
        self._get_loss()  # define gan loss
        self._get_var_with_name()  # get variables for each part of model
        self._optimize()  # define optimizer
        self._init_vars()
        self._get_saver()
        self.recover = recover

        if recover is not None:
            self._load_checkpoint(recover)

    def _get_train_data_iter(self):
        """Create the dataset iterator, initialize it and expose `next_element`."""
        batch_size = self.hparas['BATCH_SIZE']
        if self.train:  # training data iterator
            # Stage 1 and stage 2 differ only in the generator function handed
            # to data_iterator.
            if self.stage == 1:
                generator_fn = training_data_generator
            else:
                generator_fn = training_data_generator_hr
            # NOTE(review): the training pickle path is hard-coded and does not
            # use self.dataset_path, unlike the test branch below.
            iterator, _, _ = data_iterator('dataset/train_caption.pkl', batch_size, generator_fn)
            self.iterator_train = iterator
        else:  # testing data iterator
            iterator, _, _ = data_iterator_test(
                self.dataset_path + '/test_caption.pkl', batch_size)
            self.iterator_test = iterator

        self.next_element = iterator.get_next()
        self.sess.run(iterator.initializer)

    def _input_layer(self):
        """Define the placeholders fed at every session run."""
        if self.train:
            # Real images are only needed when the discriminator is built.
            self.real_image = tf.placeholder(dtype=tf.float32, shape=[
                  None, self.hparas['IMAGE_SIZE'][0],
                  self.hparas['IMAGE_SIZE'][1], self.hparas['IMAGE_SIZE'][2]
                ], name='real_image')
        self.caption = tf.placeholder(dtype=tf.float32, shape=[None, self.hparas['EMBED_DIM']], name='caption')
        self.z_noise = tf.placeholder(tf.float32, [None, self.hparas['Z_DIM']], name='z_noise')

    def _get_inference(self):
        """Instantiate the generator/discriminator networks for the current mode and stage."""
        if self.train:
            if self.stage == 1:
                # GAN training
                # generating image
                self.generator = Generator(self.z_noise, self.caption, training_phase=True,
                                           hparas=self.hparas, reuse=False)

                # discriminize
                # fake image with matched text
                self.fake_discriminator = Discriminator(self.generator.outputs, self.caption, training_phase=True,
                                                        hparas=self.hparas, reuse=False)

                # real image with real text
                self.real_discriminator = Discriminator(self.real_image, self.caption, training_phase=True,
                                                        hparas=self.hparas, reuse=True)

                # real image with mismatched text (images are shuffled within
                # the batch while the captions keep their order)
                self.wrong_discriminator = Discriminator(tf.random_shuffle(self.real_image), self.caption,
                                                         training_phase=True, hparas=self.hparas, reuse=True)

            # stage 2
            else:
                # The stage-1 generator is built in inference mode; only the
                # stage-2 networks are built with training_phase=True.
                self.generator = Generator(self.z_noise, self.caption, training_phase=False,
                                           hparas=self.hparas, reuse=False)

                self.generator_2 = Generator_2(self.generator.outputs, self.caption, training_phase=True,
                                               hparas=self.hparas, reuse=False)

                self.fake_discriminator = Discriminator_2(self.generator_2.outputs, self.caption, training_phase=True,
                                                          hparas=self.hparas, reuse=False)

                self.real_discriminator = Discriminator_2(self.real_image, self.caption, training_phase=True,
                                                          hparas=self.hparas, reuse=True)

                self.wrong_discriminator = Discriminator_2(tf.random_shuffle(self.real_image), self.caption,
                                                           training_phase=True, hparas=self.hparas, reuse=True)

        else:  # inference mode
            self.generate_image_net = Generator(
                self.z_noise,
                self.caption,
                training_phase=False,
                hparas=self.hparas,
                reuse=False)
            if self.stage == 2:
                self.generate_image_net_2 = Generator_2(
                    self.generate_image_net.outputs,
                    self.caption,
                    training_phase=False,
                    hparas=self.hparas,
                    reuse=False)

    def _get_loss(self):
        """Define `d_loss` and `g_loss` (training mode only)."""
        if not self.train:
            return

        def _mean_bce(logits, make_labels, name):
            # Mean sigmoid cross-entropy of `logits` against all-ones or all-zeros labels.
            return tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits, labels=make_labels(logits), name=name))

        # Discriminator targets: real pair -> 1, generated pair -> 0, mismatched pair -> 0.
        d_loss1 = _mean_bce(self.real_discriminator.logits, tf.ones_like, 'd_loss1')
        d_loss2 = _mean_bce(self.fake_discriminator.logits, tf.zeros_like, 'd_loss2')
        d_loss3 = _mean_bce(self.wrong_discriminator.logits, tf.zeros_like, 'd_loss3')
        self.d_loss = d_loss1 + (d_loss2 + d_loss3) / 2

        # The generator being trained contributes its KL term with weight 2.
        trained_generator = self.generator if self.stage == 1 else self.generator_2
        self.g_loss = _mean_bce(self.fake_discriminator.logits, tf.ones_like,
                                'g_loss') + 2 * trained_generator.KLloss

    def _optimize(self):
        """Create the learning-rate variable and the Adam update ops (training mode only)."""
        if not self.train:
            return

        with tf.variable_scope('learning_rate'):
            self.lr_var = tf.Variable(self.hparas['LR'], trainable=False)

        discriminator_optimizer = tf.train.AdamOptimizer(self.lr_var, beta1=self.hparas['BETA'])
        generator_optimizer = tf.train.AdamOptimizer(self.lr_var, beta1=self.hparas['BETA'])

        if self.stage == 1:
            d_scope, d_vars = 'discriminator_1', self.discrim_vars
            g_scope, g_vars = 'generator_1', self.generator_vars
        else:
            d_scope, d_vars = 'discriminator_2', self.discrim_vars_2
            g_scope, g_vars = 'generator_2', self.generator_vars_2

        # Run the UPDATE_OPS collected under each scope before the matching
        # optimizer step.
        d_update = tf.get_collection(tf.GraphKeys.UPDATE_OPS, d_scope)
        with tf.control_dependencies(d_update):
            self.d_optim = discriminator_optimizer.minimize(self.d_loss, var_list=d_vars)
        g_update = tf.get_collection(tf.GraphKeys.UPDATE_OPS, g_scope)
        with tf.control_dependencies(g_update):
            self.g_optim = generator_optimizer.minimize(self.g_loss, var_list=g_vars)

    @staticmethod
    def _pool_captions(caption_batch, n_pick=4, probe_dims=100):
        """Collapse the captions of every sample into one vector.

        For each sample, `n_pick` captions are drawn without replacement and
        the non-empty ones are averaged. A caption counts as empty when the sum
        of its first `probe_dims` entries is zero. If every drawn caption is
        empty, all captions of the sample are scanned and the last non-empty
        one is used; if there is none the row stays all-zero.

        Args:
            caption_batch: array indexed as [sample, caption, feature].
            n_pick: number of captions drawn per sample.
            probe_dims: number of leading features summed for the emptiness test.

        Returns:
            float64 array of shape (n_samples, n_features).
        """
        pooled = np.zeros([caption_batch.shape[0], caption_batch.shape[2]])
        for i, captions in enumerate(caption_batch):
            picked = np.random.choice(captions.shape[0], n_pick, replace=False)
            total = 0.0
            count = 0
            for k in picked:
                if np.sum(captions[k][:probe_dims]) != 0:
                    total += captions[k]
                    count += 1

            if count != 0:
                pooled[i] = total / count
            else:
                # No break on purpose: the last non-empty caption wins.
                for k in range(captions.shape[0]):
                    if np.sum(captions[k][:probe_dims]) != 0:
                        pooled[i] = captions[k]
        return pooled

    def training(self):
        """Run the training loop for hparas['N_EPOCH'] epochs.

        Every step runs one discriminator update followed by one generator
        update on the same batch and noise. Every 5th epoch a checkpoint is
        written and a sample grid is rendered.
        """
        for _epoch in range(self.hparas['N_EPOCH']):
            # Step-wise learning-rate decay.
            if _epoch != 0 and (_epoch % self.hparas['DECAY_EVERY'] == 0):
                new_lr_decay = self.hparas['LR_DECAY'] ** (_epoch // self.hparas['DECAY_EVERY'])
                self.sess.run(tf.assign(self.lr_var, self.hparas['LR'] * new_lr_decay))
                print("new lr %f" % (self.hparas['LR'] * new_lr_decay))

            n_batch_epoch = int(self.hparas['N_SAMPLE'] / self.hparas['BATCH_SIZE'])

            for _step in range(n_batch_epoch):
                step_time = time.time()

                image_batch, caption_batch = self.sess.run(self.next_element)
                caption_batch = self._pool_captions(caption_batch)

                batch_size = image_batch.shape[0]
                b_z = np.random.normal(loc=0.0, scale=1.0, size=(batch_size, self.hparas['Z_DIM'])).astype(np.float32)
                feed = {self.real_image: image_batch, self.caption: caption_batch, self.z_noise: b_z}

                # update discriminator
                self.discriminator_error, _ = self.sess.run([self.d_loss, self.d_optim], feed_dict=feed)

                # update generator
                self.generator_error, _ = self.sess.run([self.g_loss, self.g_optim], feed_dict=feed)

                if _step % 50 == 0:
                    print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4fs, d_loss: %.3f, g_loss: %.3f"
                          % (_epoch, self.hparas['N_EPOCH'], _step, n_batch_epoch,
                             time.time() - step_time, self.discriminator_error, self.generator_error))

            if (_epoch + 1) % 5 == 0:
                # When resuming, offset the epoch number by the restored epoch
                # so file names keep increasing.
                if self.recover is None:
                    _epoch_total = _epoch
                else:
                    _epoch_total = self.recover + 1 + _epoch
                self._save_checkpoint(_epoch_total)
                self._sample_visiualize(_epoch_total)

    def inference(self, n_batches=30):
        """Generate one image per test caption and write it to `inference_path`.

        Args:
            n_batches: number of batches pulled from the test iterator.

        The noise vector is all zeros, so the output depends only on the
        caption and the restored weights.
        """
        if self.stage == 1:
            output_net = self.generate_image_net
        else:
            output_net = self.generate_image_net_2

        for _ in range(n_batches):
            caption, idx = self.sess.run(self.next_element)
            batch_size = caption.shape[0]
            z_seed = np.zeros((batch_size, self.hparas['Z_DIM'])).astype(np.float32)

            img_gen = self.sess.run(output_net.outputs, feed_dict={self.caption: caption,
                                                                   self.z_noise: z_seed})
            img_gen = (img_gen + 1) / 2
            for i in range(batch_size):
                # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2; this
                # call needs an older SciPy.
                scipy.misc.imsave(self.inference_path + '/inference_{:04d}.png'.format(idx[i]), img_gen[i])

    def _init_vars(self):
        """Run the global variable initializer."""
        self.sess.run(tf.global_variables_initializer())

    def _get_session(self):
        """Create the TensorFlow session used by every other method."""
        self.sess = tf.Session()

    def _get_saver(self):
        """Create one Saver per network that exists in the current mode/stage."""
        # The stage-1 generator is always present.
        self.g_saver = tf.train.Saver(var_list=self.generator_vars, max_to_keep=40)
        if self.stage == 1:
            if self.train:
                self.d_saver = tf.train.Saver(var_list=self.discrim_vars, max_to_keep=40)
        else:
            self.g_saver_2 = tf.train.Saver(var_list=self.generator_vars_2, max_to_keep=40)
            if self.train:
                self.d_saver_2 = tf.train.Saver(var_list=self.discrim_vars_2, max_to_keep=40)

    def _sample_visiualize(self, epoch):
        """Render the fixed sample captions with fresh noise and save them as an image grid."""
        ni = int(np.ceil(np.sqrt(self.hparas['BATCH_SIZE'])))
        sample_size = self.hparas['BATCH_SIZE']
        # Each sample sentence is repeated (sample_size // ni) times; the
        # module-level sample_rnn / sample_caption arrays supply the embeddings.
        sample_sentence = [
            np.hstack((sample_rnn[i], sample_caption[i][2400:]))
            for i in range(len(sample_rnn))
            for _ in range(sample_size // ni)
        ]

        sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, self.hparas['Z_DIM'])).astype(np.float32)

        if self.stage == 1:
            output_net = self.generator
        else:
            output_net = self.generator_2
        img_gen = self.sess.run(output_net.outputs, feed_dict={self.caption: sample_sentence,
                                                               self.z_noise: sample_seed})
        img_gen = (img_gen + 1) / 2
        # NOTE(review): img_gen is a whole batch here while the target shape is
        # (64, 64, 3); confirm that resize treats this as intended.
        img_gen = skimage.transform.resize(img_gen, (64, 64, 3), mode='constant')
        save_images(img_gen, [ni, ni], self.sample_path + '/train_{:02d}.png'.format(epoch))

    def _get_var_with_name(self):
        """Collect, per network scope, the trainable variables plus the moving statistics."""
        t_vars = tf.trainable_variables()

        def _scope_vars(scope):
            # Trainable variables whose name contains `scope`, followed by the
            # moving_mean / moving_variance variables found under that scope.
            moving = [g for g in tf.global_variables(scope)
                      if 'moving_mean' in g.name or 'moving_variance' in g.name]
            return [var for var in t_vars if scope in var.name] + moving

        self.generator_vars = _scope_vars('generator_1')
        self.discrim_vars = _scope_vars('discriminator_1')
        self.generator_vars_2 = _scope_vars('generator_2')
        self.discrim_vars_2 = _scope_vars('discriminator_2')

    def _ckpt_file(self, prefix, epoch):
        """Return the checkpoint file name `<ckpt_path><prefix><epoch>.ckpt`."""
        return self.ckpt_path + prefix + str(epoch) + '.ckpt'

    def _load_checkpoint(self, recover):
        """Restore the weights saved at epoch `recover`.

        The stage-1 generator is always restored. In stage 1 the discriminator
        is restored as well when training. In stage 2 the stage-2 networks are
        restored only if a stage-2 generator checkpoint exists for that epoch.
        """
        self.g_saver.restore(self.sess, self._ckpt_file('g_model_', recover))
        if self.stage == 1:
            if self.train:
                self.d_saver.restore(self.sess, self._ckpt_file('d_model_', recover))
        elif tf.train.checkpoint_exists(self._ckpt_file('g_model_2_', recover)):
            self.g_saver_2.restore(self.sess, self._ckpt_file('g_model_2_', recover))
            if self.train:
                self.d_saver_2.restore(self.sess, self._ckpt_file('d_model_2_', recover))

        print('-----success restored checkpoint--------')

    def _save_checkpoint(self, epoch):
        """Write the checkpoints of the current stage, tagged with `epoch`."""
        self.g_saver.save(self.sess, self._ckpt_file('g_model_', epoch))
        if self.stage == 1:
            self.d_saver.save(self.sess, self._ckpt_file('d_model_', epoch))
        else:
            self.g_saver_2.save(self.sess, self._ckpt_file('g_model_2_', epoch))
            self.d_saver_2.save(self.sess, self._ckpt_file('d_model_2_', epoch))
        print('-----success saved checkpoint--------')

In [None]:
#Stage I Parameters
def get_hparas():
    """Return a freshly built dict of Stage I hyper-parameters (64x64x3 images)."""
    # Sizes of the tensors flowing through the networks.
    dims = {
        'EMBED_DIM': 2600,  # word embedding dimension
        'TEXT_DIM': 128,  # text embedding dimension
        'Z_DIM': 100,  # random noise z dimension
        'IMAGE_SIZE': [64, 64, 3],  # render image size
        'LR_SIZE': 64,  # presumably the low-resolution image side; confirm against its users
    }
    # Optimizer settings and training schedule.
    schedule = {
        'BATCH_SIZE': 64,
        'LR': 0.0002,  # initial learning rate
        'DECAY_EVERY': 50,  # epochs between learning-rate decays
        'LR_DECAY': 0.5,  # factor applied at every decay
        'BETA': 0.5,  # AdamOptimizer parameter
        'N_EPOCH': 600,
        'N_SAMPLE': 7370,  # divided by BATCH_SIZE to get the steps per epoch
    }
    hparas = dict(dims)
    hparas.update(schedule)
    return hparas

# Locations handed to GAN(...). The checkpoint prefix keeps its trailing slash
# because checkpoint file names are appended to it by plain string concatenation;
# the inference directory has none because '/inference_XXXX.png' is appended to it.
checkpoint_path, inference_path = './checkpoint/', './inference'

In [None]:
# Stage I training: start from an empty default graph, build the stage-1 GAN
# with freshly initialized weights and run the full training loop.
tf.reset_default_graph()

gan = GAN(
    hparas=get_hparas(),
    stage=1,
    training_phase=True,
    dataset_path=data_path,
    ckpt_path=checkpoint_path,
    inference_path=inference_path)  # recover keeps its default (None): nothing is restored
gan.training()

In [None]:
#Stage II Parameters
def get_hparas_2():
    """Return a freshly built dict of Stage II hyper-parameters (256x256x3 images)."""
    # Sizes of the tensors flowing through the networks.
    dims = {
        'EMBED_DIM': 2600,  # word embedding dimension
        'TEXT_DIM': 128,  # text embedding dimension
        'Z_DIM': 100,  # random noise z dimension
        'IMAGE_SIZE': [256, 256, 3],  # render image size
        'LR_SIZE': 64,  # presumably the low-resolution image side; confirm against its users
    }
    # Optimizer settings and training schedule.
    schedule = {
        'BATCH_SIZE': 64,
        'LR': 0.0002,  # initial learning rate
        'DECAY_EVERY': 100,  # epochs between learning-rate decays
        'LR_DECAY': 0.5,  # factor applied at every decay
        'BETA': 0.5,  # AdamOptimizer parameter
        'N_EPOCH': 600,
        'N_SAMPLE': 7370,  # divided by BATCH_SIZE to get the steps per epoch
    }
    hparas = dict(dims)
    hparas.update(schedule)
    return hparas

In [None]:
# Stage II training: start from an empty default graph and build the stage-2 GAN.
# recover=594 makes the constructor restore the checkpoint files tagged 594
# (the stage-1 generator always, the stage-2 networks only if such files exist).
tf.reset_default_graph()

gan = GAN(
    hparas=get_hparas_2(),
    stage=2,
    training_phase=True,
    dataset_path=data_path,
    ckpt_path=checkpoint_path,
    inference_path=inference_path,
    recover=594)
gan.training()

In [None]:
def data_iterator_test(filenames, batch_size):
    """Build an initializable iterator over the test captions.

    Args:
        filenames: path of the pickled DataFrame with 'Captions' and 'ID' columns.
        batch_size: number of (caption, id) pairs per batch.

    Returns:
        (iterator, output_types, output_shapes) of the batched, endlessly
        repeating dataset. The iterator still has to be initialized by the caller.
    """
    frame = pd.read_pickle(filenames)

    def _full_embedding(row):
        # row[0] indexes the module-level `test_caption` array and row[1:] is
        # the stored encoding; the two parts are concatenated.
        # NOTE(review): the meaning of the 2400 offset is not visible here — confirm.
        row = np.array(row)
        lookup = row[0].astype(np.int32)
        return np.hstack((row[1:], test_caption[lookup][2400:]))

    caption = np.asarray([_full_embedding(row) for row in frame['Captions'].values])
    index = np.asarray(frame['ID'].values)

    dataset = (tf.data.Dataset
               .from_tensor_slices((caption, index))
               .repeat()
               .batch(batch_size))

    iterator = dataset.make_initializable_iterator()
    return iterator, dataset.output_types, dataset.output_shapes