diff --git a/main.py b/main.py index b7e6985..6cd8fd1 100755 --- a/main.py +++ b/main.py @@ -21,12 +21,11 @@ import tensorlayer as tl from glob import glob -from random import shuffle -from model import generator_simplified_api, discriminator_simplified_api from utils import get_image +from model import generator, discriminator -# Defile TF Flags +# Define TF Flags flags = tf.app.flags flags.DEFINE_integer("epoch", 25, "Epoch to train [25]") flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]") @@ -67,15 +66,15 @@ def main(_): real_images = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.output_size, FLAGS.output_size, FLAGS.c_dim], name='real_images') # Input noise into generator for training - net_g, g_logits = generator_simplified_api(z, is_train=True, reuse=False) + net_g = generator(z, is_train=True, reuse=False) # Input real and generated fake images into discriminator for training - net_d, d_logits = discriminator_simplified_api(net_g.outputs, is_train=True, reuse=False) - net_d2, d2_logits = discriminator_simplified_api(real_images, is_train=True, reuse=True) + net_d, d_logits = discriminator(net_g.outputs, is_train=True, reuse=False) + _, d2_logits = discriminator(real_images, is_train=True, reuse=True) # Input noise into generator for evaluation # set is_train to False so that BatchNormLayer behave differently - net_g2, g2_logits = generator_simplified_api(z, is_train=False, reuse=True) + net_g2 = generator(z, is_train=False, reuse=True) """ Define Training Operations """ # cost for updating discriminator and generator @@ -111,51 +110,59 @@ def main(_): net_g_name = os.path.join(save_dir, 'net_g.npz') net_d_name = os.path.join(save_dir, 'net_d.npz') - data_files = glob(os.path.join("./data", FLAGS.dataset, "*.jpg")) + data_files = np.array(glob(os.path.join("./data", FLAGS.dataset, "*.jpg"))) + num_files = len(data_files) + shuffle = True + + # Mini-batch generator + def iterate_minibatches(): + if shuffle: + indices = 
np.random.permutation(num_files) + for start_idx in range(0, num_files - FLAGS.batch_size + 1, FLAGS.batch_size): + if shuffle: + excerpt = indices[start_idx: start_idx + FLAGS.batch_size] + else: + excerpt = slice(start_idx, start_idx + FLAGS.batch_size) + # Get real images (more image augmentation functions at [http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html]) + yield np.array([get_image(file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale = 0) + for file in data_files[excerpt]]).astype(np.float32) + + batch_steps = min(num_files, FLAGS.train_size) // FLAGS.batch_size - sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)# sample_seed = np.random.uniform(low=-1, high=1, size=(FLAGS.sample_size, z_dim)).astype(np.float32) + # sample noise + sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32) """ Training models """ iter_counter = 0 for epoch in range(FLAGS.epoch): - # Shuffle data - shuffle(data_files) - - # Update sample files based on shuffled data - sample_files = data_files[0:FLAGS.sample_size] - sample = [get_image(sample_file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale = 0) for sample_file in sample_files] - sample_images = np.array(sample).astype(np.float32) + sample_images = next(iterate_minibatches()) print("[*] Sample images updated!") + + steps = 0 + for batch_images in iterate_minibatches(): - # Load image data - batch_idxs = min(len(data_files), FLAGS.train_size) // FLAGS.batch_size - - for idx in range(0, batch_idxs): - batch_files = data_files[idx*FLAGS.batch_size:(idx + 1) * FLAGS.batch_size] - - # Get real images (more image augmentation functions at [http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html]) - batch = [get_image(batch_file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale = 0) for batch_file in 
batch_files] - batch_images = np.array(batch).astype(np.float32) batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32) start_time = time.time() # Updates the Discriminator(D) - errD, _ = sess.run([d_loss, d_optim], feed_dict={z: batch_z, real_images: batch_images }) + errD, _ = sess.run([d_loss, d_optim], feed_dict={z: batch_z, real_images: batch_images}) # Updates the Generator(G) # run generator twice to make sure that d_loss does not go to zero (different from paper) for _ in range(2): errG, _ = sess.run([g_loss, g_optim], feed_dict={z: batch_z}) + + end_time = time.time() - start_time print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \ - % (epoch, FLAGS.epoch, idx, batch_idxs, time.time() - start_time, errD, errG)) + % (epoch, FLAGS.epoch, steps, batch_steps, end_time, errD, errG)) iter_counter += 1 if np.mod(iter_counter, FLAGS.sample_step) == 0: # Generate images - img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z : sample_seed, real_images: sample_images}) + img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z: sample_seed, real_images: sample_images}) # Visualize generated images - tl.visualize.save_images(img, [8, 8], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, idx)) + tl.visualize.save_images(img, [8, 8], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, steps)) print("[Sample] d_loss: %.8f, g_loss: %.8f" % (errD, errG)) if np.mod(iter_counter, FLAGS.save_step) == 0: @@ -164,6 +171,8 @@ def main(_): tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess) tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess) print("[*] Saving checkpoints SUCCESS!") + + steps += 1 sess.close() diff --git a/model.py b/model.py index 5970d80..2f839bc 100755 --- a/model.py +++ b/model.py @@ -13,29 +13,29 @@ flags = tf.app.flags FLAGS = flags.FLAGS -def generator_simplified_api(inputs, is_train=True, reuse=False): 
+def generator(inputs, is_train=True, reuse=False): image_size = 64 - s2, s4, s8, s16 = int(image_size/2), int(image_size/4), int(image_size/8), int(image_size/16) - gf_dim = 64 # Dimension of gen filters in first conv layer. [64] - c_dim = FLAGS.c_dim # n_color 3 - w_init = tf.random_normal_initializer(stddev=0.02) + s16 = image_size // 16 + gf_dim = 64 # Dimension of gen filters in first conv layer. [64] + c_dim = FLAGS.c_dim # n_color 3 + w_init = tf.glorot_normal_initializer() gamma_init = tf.random_normal_initializer(1., 0.02) with tf.variable_scope("generator", reuse=reuse): net_in = InputLayer(inputs, name='g/in') - net_h0 = DenseLayer(net_in, n_units=gf_dim*8*s16*s16, W_init=w_init, + net_h0 = DenseLayer(net_in, n_units=(gf_dim * 8 * s16 * s16), W_init=w_init, act = tf.identity, name='g/h0/lin') net_h0 = ReshapeLayer(net_h0, shape=[-1, s16, s16, gf_dim*8], name='g/h0/reshape') net_h0 = BatchNormLayer(net_h0, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='g/h0/batch_norm') - net_h1 = DeConv2d(net_h0, gf_dim*4, (5, 5), strides=(2, 2), + net_h1 = DeConv2d(net_h0, gf_dim * 4, (5, 5), strides=(2, 2), padding='SAME', act=None, W_init=w_init, name='g/h1/decon2d') net_h1 = BatchNormLayer(net_h1, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='g/h1/batch_norm') - net_h2 = DeConv2d(net_h1, gf_dim*2, (5, 5), strides=(2, 2), + net_h2 = DeConv2d(net_h1, gf_dim * 2, (5, 5), strides=(2, 2), padding='SAME', act=None, W_init=w_init, name='g/h2/decon2d') net_h2 = BatchNormLayer(net_h2, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='g/h2/batch_norm') @@ -47,14 +47,12 @@ def generator_simplified_api(inputs, is_train=True, reuse=False): net_h4 = DeConv2d(net_h3, c_dim, (5, 5), strides=(2, 2), padding='SAME', act=None, W_init=w_init, name='g/h4/decon2d') - logits = net_h4.outputs net_h4.outputs = tf.nn.tanh(net_h4.outputs) - return net_h4, logits + return net_h4 -def discriminator_simplified_api(inputs, is_train=True, 
reuse=False): - df_dim = 64 # Dimension of discrim filters in first conv layer. [64] - c_dim = FLAGS.c_dim # n_color 3 - w_init = tf.random_normal_initializer(stddev=0.02) +def discriminator(inputs, is_train=True, reuse=False): + df_dim = 64 # Dimension of discrim filters in first conv layer. [64] + w_init = tf.glorot_normal_initializer() gamma_init = tf.random_normal_initializer(1., 0.02) with tf.variable_scope("discriminator", reuse=reuse): diff --git a/utils.py b/utils.py index db13e9d..7601bf5 100755 --- a/utils.py +++ b/utils.py @@ -1,6 +1,5 @@ -from random import shuffle - import scipy.misc +import imageio as io import numpy as np def center_crop(x, crop_h, crop_w=None, resize_w=64): @@ -18,7 +17,7 @@ def merge(images, size): for idx, image in enumerate(images): i = idx % size[1] j = idx // size[1] - img[j*h:j*h+h, i*w:i*w+w, :] = image + img[j * h: j * h + h, i * w: i * w + w, :] = image return img def transform(image, npx=64, is_crop=True, resize_w=64): @@ -26,19 +25,19 @@ def transform(image, npx=64, is_crop=True, resize_w=64): cropped_image = center_crop(image, npx, resize_w=resize_w) else: cropped_image = image - return np.array(cropped_image)/127.5 - 1. + return (np.array(cropped_image) / 127.5) - 1. def inverse_transform(images): - return (images+1.)/2. + return (images + 1.) / 2. def imread(path, is_grayscale = False): if (is_grayscale): - return scipy.misc.imread(path, flatten = True).astype(np.float) + return io.imread(path, pilmode="F").astype(np.float) else: - return scipy.misc.imread(path).astype(np.float) + return io.imread(path).astype(np.float) def imsave(images, size, path): - return scipy.misc.imsave(path, merge(images, size)) + return io.imsave(path, merge(images, size)) def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale = False): return transform(imread(image_path, is_grayscale), image_size, is_crop, resize_w)