main.py: 20 additions & 16 deletions
@@ -47,6 +47,9 @@
 FLAGS = flags.FLAGS

 def main(_):
+    assert np.sqrt(FLAGS.sample_size) % 1 == 0., 'Flag `sample_size` needs to be a perfect square'
+    num_tiles = int(np.sqrt(FLAGS.sample_size))
+
     # Print flags
     for flag, _ in FLAGS.__flags.items():
         print('"{}": {}'.format(flag, getattr(FLAGS, flag)))
@@ -62,8 +65,8 @@ def main(_):
     with tf.device("/gpu:0"):

         """ Define Models """
-        z = tf.placeholder(tf.float32, [FLAGS.batch_size, z_dim], name='z_noise')
-        real_images = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.output_size, FLAGS.output_size, FLAGS.c_dim], name='real_images')
+        z = tf.placeholder(tf.float32, [None, z_dim], name='z_noise')
+        real_images = tf.placeholder(tf.float32, [None, FLAGS.output_size, FLAGS.output_size, FLAGS.c_dim], name='real_images')

         # Input noise into generator for training
         net_g = generator(z, is_train=True, reuse=False)
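Note: switching the leading placeholder dimension to None lets one graph accept both FLAGS.batch_size training batches and FLAGS.sample_size monitoring batches. A self-contained sketch under TensorFlow 1.x (the stand-in op and sizes are illustrative, not from the repo):

import numpy as np
import tensorflow as tf  # assumes TensorFlow 1.x, as in this repo

z_dim = 100  # hypothetical noise dimensionality
z = tf.placeholder(tf.float32, [None, z_dim], name='z_noise')
out = tf.reduce_sum(z, axis=1)  # stand-in for the generator graph

with tf.Session() as sess:
    # The same placeholder now accepts either batch size.
    print(sess.run(out, feed_dict={z: np.zeros((32, z_dim), np.float32)}).shape)  # (32,)
    print(sess.run(out, feed_dict={z: np.zeros((64, z_dim), np.float32)}).shape)  # (64,)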
@@ -77,12 +80,11 @@ def main(_):
         net_g2 = generator(z, is_train=False, reuse=True)

         """ Define Training Operations """
-        # cost for updating discriminator and generator
         # discriminator: real images are labelled as 1
         d_loss_real = tl.cost.sigmoid_cross_entropy(d2_logits, tf.ones_like(d2_logits), name='dreal')

         # discriminator: images from generator (fake) are labelled as 0
         d_loss_fake = tl.cost.sigmoid_cross_entropy(d_logits, tf.zeros_like(d_logits), name='dfake')
+        # cost for updating discriminator
         d_loss = d_loss_real + d_loss_fake

         # generator: try to make the fake images look real (1)
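Note: the relabelled comments describe the standard GAN objective. An equivalent sketch in plain TensorFlow 1.x ops, assuming real_logits and fake_logits are the discriminator outputs for real and generated images (the repo itself uses tl.cost.sigmoid_cross_entropy):

import tensorflow as tf  # TF 1.x

def gan_losses(real_logits, fake_logits):
    # Discriminator: push real logits toward 1, fake logits toward 0.
    d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(real_logits), logits=real_logits))
    d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(fake_logits), logits=fake_logits))
    d_loss = d_loss_real + d_loss_fake
    # Generator: push fake logits toward 1 (make fakes look real).
    g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(fake_logits), logits=fake_logits))
    return d_loss, g_loss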
@@ -112,17 +114,16 @@

     data_files = np.array(glob(os.path.join("./data", FLAGS.dataset, "*.jpg")))
     num_files = len(data_files)
-    shuffle = True

     # Mini-batch generator
-    def iterate_minibatches():
+    def iterate_minibatches(batch_size, shuffle=True):
         if shuffle:
             indices = np.random.permutation(num_files)
-        for start_idx in range(0, num_files - FLAGS.batch_size + 1, FLAGS.batch_size):
+        for start_idx in range(0, num_files - batch_size + 1, batch_size):
             if shuffle:
-                excerpt = indices[start_idx: start_idx + FLAGS.batch_size]
+                excerpt = indices[start_idx: start_idx + batch_size]
             else:
-                excerpt = slice(start_idx, start_idx + FLAGS.batch_size)
+                excerpt = slice(start_idx, start_idx + batch_size)
             # Get real images (more image augmentation functions at http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html)
             yield np.array([get_image(file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale=0)
                             for file in data_files[excerpt]]).astype(np.float32)
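Note: parameterizing the batch size lets the same generator serve training batches and the larger sample batch. A self-contained sketch of the pattern on dummy data (names and sizes are illustrative, not from the repo):

import numpy as np

data = np.arange(10)  # stand-in for the file list

def iterate_minibatches(batch_size, shuffle=True):
    indices = np.random.permutation(len(data)) if shuffle else None
    for start in range(0, len(data) - batch_size + 1, batch_size):
        excerpt = indices[start:start + batch_size] if shuffle else slice(start, start + batch_size)
        yield data[excerpt]

print(next(iterate_minibatches(4)))       # one shuffled batch of 4
print(len(list(iterate_minibatches(3))))  # 3 full batches; the remainder is dropped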
@@ -136,13 +137,13 @@ def iterate_minibatches():
     iter_counter = 0
     for epoch in range(FLAGS.epoch):

-        sample_images = next(iterate_minibatches())
+        sample_images = next(iterate_minibatches(FLAGS.sample_size))
         print("[*] Sample images updated!")

         steps = 0
-        for batch_images in iterate_minibatches():
+        for batch_images in iterate_minibatches(FLAGS.batch_size):

-            batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
+            batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.batch_size, z_dim)).astype(np.float32)
             start_time = time.time()

             # Updates the Discriminator(D)
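Note: this fixes a shape bug: the per-step noise previously had FLAGS.sample_size rows but was fed alongside FLAGS.batch_size real images. A sketch of the two shapes, with hypothetical flag values:

import numpy as np

z_dim, batch_size, sample_size = 100, 32, 64   # hypothetical flag values
batch_images = np.zeros((batch_size, 64, 64, 3), np.float32)
old_batch_z = np.random.normal(size=(sample_size, z_dim))  # 64 noise rows vs 32 images: mismatched
new_batch_z = np.random.normal(size=(batch_size, z_dim))   # rows now match batch_images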
@@ -162,7 +163,7 @@ def iterate_minibatches():
             # Generate images
             img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z: sample_seed, real_images: sample_images})
             # Visualize generated images
-            tl.visualize.save_images(img, [8, 8], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, steps))
+            tl.visualize.save_images(img, [num_tiles, num_tiles], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, steps))
             print("[Sample] d_loss: %.8f, g_loss: %.8f" % (errD, errG))

             if np.mod(iter_counter, FLAGS.save_step) == 0:
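Note: with the perfect-square assertion above, the saved grid now tracks FLAGS.sample_size instead of a hard-coded 8x8. A minimal sketch with dummy images, assuming TensorLayer's tl.visualize.save_images as used in the diff:

import numpy as np
import tensorlayer as tl

sample_size, num_tiles = 16, 4                # hypothetical: 16 samples tile into a 4x4 grid
img = np.random.rand(sample_size, 64, 64, 3)  # dummy generator outputs in [0, 1]
tl.visualize.save_images(img, [num_tiles, num_tiles], 'samples.png')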
@@ -171,10 +172,13 @@ def iterate_minibatches():
                 tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
                 tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
                 print("[*] Saving checkpoints SUCCESS!")

             steps += 1

     sess.close()

 if __name__ == '__main__':
-    tf.app.run()
+    try:
+        tf.app.run()
+    except KeyboardInterrupt:
+        print('EXIT')
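Note: the try/except turns a Ctrl-C into a clean exit message instead of a traceback. A variant worth considering (a sketch, not part of this PR) also uses finally so the session is closed even when training is interrupted:

import tensorflow as tf  # TF 1.x

def main(_):
    sess = tf.Session()
    try:
        pass  # training loop would go here
    finally:
        sess.close()  # runs even if training is interrupted

if __name__ == '__main__':
    try:
        tf.app.run()
    except KeyboardInterrupt:
        print('EXIT')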