Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 38 additions & 29 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@
import tensorlayer as tl

from glob import glob
from random import shuffle

from model import generator_simplified_api, discriminator_simplified_api
from utils import get_image
from model import generator, discriminator

# Defile TF Flags
# Define TF Flags
flags = tf.app.flags
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")
Expand Down Expand Up @@ -67,15 +66,15 @@ def main(_):
real_images = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.output_size, FLAGS.output_size, FLAGS.c_dim], name='real_images')

# Input noise into generator for training
net_g, g_logits = generator_simplified_api(z, is_train=True, reuse=False)
net_g = generator(z, is_train=True, reuse=False)

# Input real and generated fake images into discriminator for training
net_d, d_logits = discriminator_simplified_api(net_g.outputs, is_train=True, reuse=False)
net_d2, d2_logits = discriminator_simplified_api(real_images, is_train=True, reuse=True)
net_d, d_logits = discriminator(net_g.outputs, is_train=True, reuse=False)
_, d2_logits = discriminator(real_images, is_train=True, reuse=True)

# Input noise into generator for evaluation
# set is_train to False so that BatchNormLayer behave differently
net_g2, g2_logits = generator_simplified_api(z, is_train=False, reuse=True)
net_g2 = generator(z, is_train=False, reuse=True)

""" Define Training Operations """
# cost for updating discriminator and generator
Expand Down Expand Up @@ -111,51 +110,59 @@ def main(_):
net_g_name = os.path.join(save_dir, 'net_g.npz')
net_d_name = os.path.join(save_dir, 'net_d.npz')

data_files = glob(os.path.join("./data", FLAGS.dataset, "*.jpg"))
data_files = np.array(glob(os.path.join("./data", FLAGS.dataset, "*.jpg")))
num_files = len(data_files)
shuffle = True

# Mini-batch generator
def iterate_minibatches():
if shuffle:
indices = np.random.permutation(num_files)
for start_idx in range(0, num_files - FLAGS.batch_size + 1, FLAGS.batch_size):
if shuffle:
excerpt = indices[start_idx: start_idx + FLAGS.batch_size]
else:
excerpt = slice(start_idx, start_idx + FLAGS.batch_size)
# Get real images (more image augmentation functions at [http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html])
yield np.array([get_image(file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale = 0)
for file in data_files[excerpt]]).astype(np.float32)

batch_steps = min(num_files, FLAGS.train_size) // FLAGS.batch_size

sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)# sample_seed = np.random.uniform(low=-1, high=1, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
# sample noise
sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)

""" Training models """
iter_counter = 0
for epoch in range(FLAGS.epoch):

# Shuffle data
shuffle(data_files)

# Update sample files based on shuffled data
sample_files = data_files[0:FLAGS.sample_size]
sample = [get_image(sample_file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale = 0) for sample_file in sample_files]
sample_images = np.array(sample).astype(np.float32)
sample_images = next(iterate_minibatches())
print("[*] Sample images updated!")

steps = 0
for batch_images in iterate_minibatches():

# Load image data
batch_idxs = min(len(data_files), FLAGS.train_size) // FLAGS.batch_size

for idx in range(0, batch_idxs):
batch_files = data_files[idx*FLAGS.batch_size:(idx + 1) * FLAGS.batch_size]

# Get real images (more image augmentation functions at [http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html])
batch = [get_image(batch_file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale = 0) for batch_file in batch_files]
batch_images = np.array(batch).astype(np.float32)
batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
start_time = time.time()

# Updates the Discriminator(D)
errD, _ = sess.run([d_loss, d_optim], feed_dict={z: batch_z, real_images: batch_images })
errD, _ = sess.run([d_loss, d_optim], feed_dict={z: batch_z, real_images: batch_images})

# Updates the Generator(G)
# run generator twice to make sure that d_loss does not go to zero (different from paper)
for _ in range(2):
errG, _ = sess.run([g_loss, g_optim], feed_dict={z: batch_z})

end_time = time.time() - start_time
print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
% (epoch, FLAGS.epoch, idx, batch_idxs, time.time() - start_time, errD, errG))
% (epoch, FLAGS.epoch, steps, batch_steps, end_time, errD, errG))

iter_counter += 1
if np.mod(iter_counter, FLAGS.sample_step) == 0:
# Generate images
img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z : sample_seed, real_images: sample_images})
img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z: sample_seed, real_images: sample_images})
# Visualize generated images
tl.visualize.save_images(img, [8, 8], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, idx))
tl.visualize.save_images(img, [8, 8], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, steps))
print("[Sample] d_loss: %.8f, g_loss: %.8f" % (errD, errG))

if np.mod(iter_counter, FLAGS.save_step) == 0:
Expand All @@ -164,6 +171,8 @@ def main(_):
tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
print("[*] Saving checkpoints SUCCESS!")

steps += 1

sess.close()

Expand Down
26 changes: 12 additions & 14 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,29 @@
flags = tf.app.flags
FLAGS = flags.FLAGS

def generator_simplified_api(inputs, is_train=True, reuse=False):
def generator(inputs, is_train=True, reuse=False):
image_size = 64
s2, s4, s8, s16 = int(image_size/2), int(image_size/4), int(image_size/8), int(image_size/16)
gf_dim = 64 # Dimension of gen filters in first conv layer. [64]
c_dim = FLAGS.c_dim # n_color 3
w_init = tf.random_normal_initializer(stddev=0.02)
s16 = image_size // 16
gf_dim = 64 # Dimension of gen filters in first conv layer. [64]
c_dim = FLAGS.c_dim # n_color 3
w_init = tf.glorot_normal_initializer()
gamma_init = tf.random_normal_initializer(1., 0.02)

with tf.variable_scope("generator", reuse=reuse):

net_in = InputLayer(inputs, name='g/in')
net_h0 = DenseLayer(net_in, n_units=gf_dim*8*s16*s16, W_init=w_init,
net_h0 = DenseLayer(net_in, n_units=(gf_dim * 8 * s16 * s16), W_init=w_init,
act = tf.identity, name='g/h0/lin')
net_h0 = ReshapeLayer(net_h0, shape=[-1, s16, s16, gf_dim*8], name='g/h0/reshape')
net_h0 = BatchNormLayer(net_h0, act=tf.nn.relu, is_train=is_train,
gamma_init=gamma_init, name='g/h0/batch_norm')

net_h1 = DeConv2d(net_h0, gf_dim*4, (5, 5), strides=(2, 2),
net_h1 = DeConv2d(net_h0, gf_dim * 4, (5, 5), strides=(2, 2),
padding='SAME', act=None, W_init=w_init, name='g/h1/decon2d')
net_h1 = BatchNormLayer(net_h1, act=tf.nn.relu, is_train=is_train,
gamma_init=gamma_init, name='g/h1/batch_norm')

net_h2 = DeConv2d(net_h1, gf_dim*2, (5, 5), strides=(2, 2),
net_h2 = DeConv2d(net_h1, gf_dim * 2, (5, 5), strides=(2, 2),
padding='SAME', act=None, W_init=w_init, name='g/h2/decon2d')
net_h2 = BatchNormLayer(net_h2, act=tf.nn.relu, is_train=is_train,
gamma_init=gamma_init, name='g/h2/batch_norm')
Expand All @@ -47,14 +47,12 @@ def generator_simplified_api(inputs, is_train=True, reuse=False):

net_h4 = DeConv2d(net_h3, c_dim, (5, 5), strides=(2, 2),
padding='SAME', act=None, W_init=w_init, name='g/h4/decon2d')
logits = net_h4.outputs
net_h4.outputs = tf.nn.tanh(net_h4.outputs)
return net_h4, logits
return net_h4

def discriminator_simplified_api(inputs, is_train=True, reuse=False):
df_dim = 64 # Dimension of discrim filters in first conv layer. [64]
c_dim = FLAGS.c_dim # n_color 3
w_init = tf.random_normal_initializer(stddev=0.02)
def discriminator(inputs, is_train=True, reuse=False):
df_dim = 64 # Dimension of discrim filters in first conv layer. [64]
w_init = tf.glorot_normal_initializer()
gamma_init = tf.random_normal_initializer(1., 0.02)

with tf.variable_scope("discriminator", reuse=reuse):
Expand Down
15 changes: 7 additions & 8 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from random import shuffle

import scipy.misc
import imageio as io
import numpy as np

def center_crop(x, crop_h, crop_w=None, resize_w=64):
Expand All @@ -18,27 +17,27 @@ def merge(images, size):
for idx, image in enumerate(images):
i = idx % size[1]
j = idx // size[1]
img[j*h:j*h+h, i*w:i*w+w, :] = image
img[j * h: j * h + h, i * w: i * w + w, :] = image
return img

def transform(image, npx=64, is_crop=True, resize_w=64):
    """Prepare one image array for training.

    Optionally center-crops `image` to `npx` x `npx` (resized to
    `resize_w` x `resize_w`), then rescales pixel values from the
    uint8 range [0, 255] into [-1, 1] to match a tanh generator output.

    NOTE(review): the pasted diff carried both the pre- and post-change
    return lines; the second was unreachable dead code computing the
    identical value, so only one return is kept here.
    """
    if is_crop:
        cropped_image = center_crop(image, npx, resize_w=resize_w)
    else:
        cropped_image = image
    # 127.5 is half of 255: x/127.5 - 1 maps [0, 255] -> [-1, 1].
    return (np.array(cropped_image) / 127.5) - 1.

def inverse_transform(images):
    """Map generator outputs from [-1, 1] back to [0, 1] for visualization.

    Inverse of the scaling applied in `transform`.

    NOTE(review): the pasted diff carried duplicate pre-/post-change
    return lines computing the same value; the unreachable duplicate
    is dropped here.
    """
    return (images + 1.) / 2.

# NOTE(review): this span is a raw diff paste — each branch contains both the
# old (scipy.misc.imread) line and its imageio replacement; in the merged file
# only the `io.imread` lines remain. Also note the apparent semantic change:
# the old grayscale path used scipy's flatten=True (a 2-D luminance image),
# while the new one calls numpy's .flatten(), producing a 1-D vector —
# presumably unintended; verify against callers before relying on it.
def imread(path, is_grayscale = False):
if (is_grayscale):
# old line: scipy grayscale read (flatten=True -> 2-D image)
return scipy.misc.imread(path, flatten = True).astype(np.float)
# new line: imageio read, then numpy .flatten() -> 1-D array (differs from old)
return io.imread(path).astype(np.float).flatten()
else:
# old line: scipy color read
return scipy.misc.imread(path).astype(np.float)
# new line: imageio color read
return io.imread(path).astype(np.float)

# NOTE(review): diff paste — the old scipy.misc.imsave line is immediately
# followed by its imageio replacement; only the `io.imsave` line belongs in
# the merged file. Writes the tiled grid produced by `merge` to `path`.
def imsave(images, size, path):
# old line: scipy writer (scipy.misc.imsave was removed in SciPy >= 1.2)
return scipy.misc.imsave(path, merge(images, size))
# new line: imageio writer
return io.imsave(path, merge(images, size))

def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale = False):
    """Load the image at `image_path`, then crop/resize and normalise it.

    Convenience wrapper combining `imread` and `transform`; returns an
    array scaled to [-1, 1] ready to feed the network.
    """
    raw = imread(image_path, is_grayscale)
    return transform(raw, image_size, is_crop, resize_w)
Expand Down