diff --git a/main.py b/main.py index 6cd8fd1..0fb3c56 100755 --- a/main.py +++ b/main.py @@ -47,6 +47,9 @@ FLAGS = flags.FLAGS def main(_): + assert np.sqrt(FLAGS.sample_size) % 1 == 0., 'Flag `sample_size` needs to be a perfect square' + num_tiles = int(np.sqrt(FLAGS.sample_size)) + # Print flags for flag, _ in FLAGS.__flags.items(): print('"{}": {}'.format(flag, getattr(FLAGS, flag))) @@ -62,8 +65,8 @@ def main(_): with tf.device("/gpu:0"): """ Define Models """ - z = tf.placeholder(tf.float32, [FLAGS.batch_size, z_dim], name='z_noise') - real_images = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.output_size, FLAGS.output_size, FLAGS.c_dim], name='real_images') + z = tf.placeholder(tf.float32, [None, z_dim], name='z_noise') + real_images = tf.placeholder(tf.float32, [None, FLAGS.output_size, FLAGS.output_size, FLAGS.c_dim], name='real_images') # Input noise into generator for training net_g = generator(z, is_train=True, reuse=False) @@ -77,12 +80,11 @@ def main(_): net_g2 = generator(z, is_train=False, reuse=True) """ Define Training Operations """ - # cost for updating discriminator and generator # discriminator: real images are labelled as 1 d_loss_real = tl.cost.sigmoid_cross_entropy(d2_logits, tf.ones_like(d2_logits), name='dreal') - # discriminator: images from generator (fake) are labelled as 0 d_loss_fake = tl.cost.sigmoid_cross_entropy(d_logits, tf.zeros_like(d_logits), name='dfake') + # cost for updating discriminator d_loss = d_loss_real + d_loss_fake # generator: try to make the the fake images look real (1) @@ -112,17 +114,16 @@ def main(_): data_files = np.array(glob(os.path.join("./data", FLAGS.dataset, "*.jpg"))) num_files = len(data_files) - shuffle = True # Mini-batch generator - def iterate_minibatches(): + def iterate_minibatches(batch_size, shuffle=True): if shuffle: indices = np.random.permutation(num_files) - for start_idx in range(0, num_files - FLAGS.batch_size + 1, FLAGS.batch_size): + for start_idx in range(0, num_files - batch_size + 1, batch_size): if shuffle: - excerpt = indices[start_idx: start_idx + FLAGS.batch_size] + excerpt = indices[start_idx: start_idx + batch_size] else: - excerpt = slice(start_idx, start_idx + FLAGS.batch_size) + excerpt = slice(start_idx, start_idx + batch_size) # Get real images (more image augmentation functions at [http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html]) yield np.array([get_image(file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale = 0) for file in data_files[excerpt]]).astype(np.float32) @@ -136,13 +137,13 @@ def iterate_minibatches(): iter_counter = 0 for epoch in range(FLAGS.epoch): - sample_images = next(iterate_minibatches()) + sample_images = next(iterate_minibatches(FLAGS.sample_size)) print("[*] Sample images updated!") steps = 0 - for batch_images in iterate_minibatches(): + for batch_images in iterate_minibatches(FLAGS.batch_size): - batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32) + batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.batch_size, z_dim)).astype(np.float32) start_time = time.time() # Updates the Discriminator(D) @@ -162,7 +163,7 @@ def iterate_minibatches(): # Generate images img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z: sample_seed, real_images: sample_images}) # Visualize generated images - tl.visualize.save_images(img, [8, 8], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, steps)) + tl.visualize.save_images(img, [num_tiles, num_tiles], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, steps)) print("[Sample] d_loss: %.8f, g_loss: %.8f" % (errD, errG)) if np.mod(iter_counter, FLAGS.save_step) == 0: @@ -171,10 +172,13 @@ def iterate_minibatches(): tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess) tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess) print("[*] Saving checkpoints SUCCESS!") - + steps += 1 - + sess.close() if __name__ == '__main__': - tf.app.run() + try: + tf.app.run() + except KeyboardInterrupt: + print('EXIT')