4 changes: 4 additions & 0 deletions .gitignore
@@ -10,3 +10,7 @@ dist
docs/_build
tensorlayer.egg-info
tensorlayer/__pycache__

.vscode/*
data/*
samples/*
4 changes: 2 additions & 2 deletions README.md
@@ -13,8 +13,8 @@ Looking for Text to Image Synthesis? [click here](https://github.com/zsdonghao/
## Prerequisites

- Python 2.7 or Python 3.3+
- [TensorFlow==1.0+](https://www.tensorflow.org/)
- [TensorLayer==1.4+](https://github.com/zsdonghao/tensorlayer)
- [TensorFlow==1.10.0+](https://www.tensorflow.org/)
- [TensorLayer==1.10.1+](https://github.com/tensorlayer/tensorlayer)


## Usage
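Since the PR raises the minimum TensorFlow and TensorLayer versions, a quick sanity check such as the following can confirm an environment meets the new floor (an illustrative sketch, not part of this PR; it assumes both packages expose __version__):

import tensorflow as tf
import tensorlayer as tl
from distutils.version import LooseVersion

# Verify the bumped prerequisites from the README.
assert LooseVersion(tf.__version__) >= LooseVersion("1.10.0"), tf.__version__
assert LooseVersion(tl.__version__) >= LooseVersion("1.10.1"), tl.__version__
print("TensorFlow", tf.__version__, "| TensorLayer", tl.__version__)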
103 changes: 65 additions & 38 deletions main.py
@@ -1,25 +1,37 @@
import os, pprint, time
""" TensorLayer implementation of Deep Convolutional Generative Adversarial Network (DCGAN).
Using deep convolutional generative adversarial networks (DCGAN)
to generate face images from a noise distribution.
References:
- Generative Adversarial Nets.
Goodfellow et al. arXiv: 1406.2661.
- Unsupervised Representation Learning with Deep Convolutional
Generative Adversarial Networks. A Radford, L Metz, S Chintala.
arXiv: 1511.06434.
Links:
- [GAN Paper](https://arxiv.org/pdf/1406.2661.pdf)
- [DCGAN Paper](https://arxiv.org/abs/1511.06434)
Usage:
- See README.md
"""
import os
import time

import numpy as np
import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *

from glob import glob
from random import shuffle
from model import *
from utils import *

pp = pprint.PrettyPrinter()

"""
TensorLayer implementation of DCGAN to generate face image.
from model import generator_simplified_api, discriminator_simplified_api
from utils import get_image

Usage : see README.md
"""
# Define TF Flags
flags = tf.app.flags
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
flags.DEFINE_float("train_size", np.inf, "The size of train images [np.inf]")
flags.DEFINE_integer("batch_size", 64, "The number of batch images [64]")
flags.DEFINE_integer("image_size", 108, "The size of image to use (will be center cropped) [108]")
flags.DEFINE_integer("output_size", 64, "The size of the output images to produce [64]")
@@ -36,57 +48,65 @@
FLAGS = flags.FLAGS

def main(_):
pp.pprint(flags.FLAGS.__flags)
# Print flags
for flag, _ in FLAGS.__flags.items():
print('"{}": {}'.format(flag, getattr(FLAGS, flag)))
print("--------------------")

# Configure checkpoint/samples dir
tl.files.exists_or_mkdir(FLAGS.checkpoint_dir)
tl.files.exists_or_mkdir(FLAGS.sample_dir)

z_dim = 100
z_dim = 100 # noise dim

# Construct graph on GPU
with tf.device("/gpu:0"):
##========================= DEFINE MODEL ===========================##

""" Define Models """
z = tf.placeholder(tf.float32, [FLAGS.batch_size, z_dim], name='z_noise')
real_images = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.output_size, FLAGS.output_size, FLAGS.c_dim], name='real_images')

# z --> generator for training
# Input noise into generator for training
net_g, g_logits = generator_simplified_api(z, is_train=True, reuse=False)
# generated fake images --> discriminator

# Input real and generated fake images into discriminator for training
net_d, d_logits = discriminator_simplified_api(net_g.outputs, is_train=True, reuse=False)
# real images --> discriminator
net_d2, d2_logits = discriminator_simplified_api(real_images, is_train=True, reuse=True)
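# reuse=True shares the discriminator variables between the two calls above,
# so a single set of weights scores both the generated and the real batch.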
# sample_z --> generator for evaluation, set is_train to False
# so that BatchNormLayer behave differently

# Input noise into generator for evaluation
# set is_train to False so that BatchNormLayer behaves differently
net_g2, g2_logits = generator_simplified_api(z, is_train=False, reuse=True)
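# With is_train=False the BatchNormLayer uses its moving mean/variance
# rather than per-batch statistics, so net_g2 yields deterministic
# samples for a fixed noise input.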

##========================= DEFINE TRAIN OPS =======================##
""" Define Training Operations """
# cost for updating discriminator and generator
# discriminator: real images are labelled as 1
d_loss_real = tl.cost.sigmoid_cross_entropy(d2_logits, tf.ones_like(d2_logits), name='dreal')

# discriminator: images from generator (fake) are labelled as 0
d_loss_fake = tl.cost.sigmoid_cross_entropy(d_logits, tf.zeros_like(d_logits), name='dfake')
d_loss = d_loss_real + d_loss_fake

# generator: try to make the fake images look real (1)
g_loss = tl.cost.sigmoid_cross_entropy(d_logits, tf.ones_like(d_logits), name='gfake')

g_vars = tl.layers.get_variables_with_name('generator', True, True)
d_vars = tl.layers.get_variables_with_name('discriminator', True, True)

net_g.print_params(False)
print("---------------")
net_d.print_params(False)

# optimizers for updating discriminator and generator
# Define optimizers for updating discriminator and generator
d_optim = tf.train.AdamOptimizer(FLAGS.learning_rate, beta1=FLAGS.beta1) \
.minimize(d_loss, var_list=d_vars)
g_optim = tf.train.AdamOptimizer(FLAGS.learning_rate, beta1=FLAGS.beta1) \
.minimize(g_loss, var_list=g_vars)

# Init Session
sess = tf.InteractiveSession()
tl.layers.initialize_global_variables(sess)
sess.run(tf.global_variables_initializer())

model_dir = "%s_%s_%s" % (FLAGS.dataset, FLAGS.batch_size, FLAGS.output_size)
save_dir = os.path.join(FLAGS.checkpoint_dir, model_dir)
tl.files.exists_or_mkdir(FLAGS.sample_dir)
tl.files.exists_or_mkdir(save_dir)

# load the latest checkpoints
net_g_name = os.path.join(save_dir, 'net_g.npz')
net_d_name = os.path.join(save_dir, 'net_d.npz')
@@ -95,50 +115,57 @@ def main(_):

sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
# alternative: sample_seed = np.random.uniform(low=-1, high=1, size=(FLAGS.sample_size, z_dim)).astype(np.float32)

##========================= TRAIN MODELS ================================##
""" Training models """
iter_counter = 0
for epoch in range(FLAGS.epoch):
## shuffle data

# Shuffle data
shuffle(data_files)

## update sample files based on shuffled data
# Update sample files based on shuffled data
sample_files = data_files[0:FLAGS.sample_size]
sample = [get_image(sample_file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale = 0) for sample_file in sample_files]
sample_images = np.array(sample).astype(np.float32)
print("[*] Sample images updated!")

## load image data
# Load image data
batch_idxs = min(len(data_files), FLAGS.train_size) // FLAGS.batch_size

for idx in range(0, batch_idxs):
batch_files = data_files[idx*FLAGS.batch_size:(idx+1)*FLAGS.batch_size]
## get real images
# more image augmentation functions in http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html
batch_files = data_files[idx*FLAGS.batch_size:(idx + 1) * FLAGS.batch_size]

# Get real images (more image augmentation functions at http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html)
batch = [get_image(batch_file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale = 0) for batch_file in batch_files]
batch_images = np.array(batch).astype(np.float32)
batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32) # batch_z = np.random.uniform(low=-1, high=1, size=(FLAGS.batch_size, z_dim)).astype(np.float32)
batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.batch_size, z_dim)).astype(np.float32)  # one noise vector per batch image
start_time = time.time()
# updates the discriminator

# Updates the Discriminator (D)
errD, _ = sess.run([d_loss, d_optim], feed_dict={z: batch_z, real_images: batch_images })
# updates the generator, run generator twice to make sure that d_loss does not go to zero (difference from paper)

# Updates the Generator (G)
# run generator twice to make sure that d_loss does not go to zero (different from paper)
for _ in range(2):
errG, _ = sess.run([g_loss, g_optim], feed_dict={z: batch_z})
print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
% (epoch, FLAGS.epoch, idx, batch_idxs, time.time() - start_time, errD, errG))

iter_counter += 1
if np.mod(iter_counter, FLAGS.sample_step) == 0:
# generate and visualize generated images
# Generate images
img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z : sample_seed, real_images: sample_images})
# Visualize generated images
tl.visualize.save_images(img, [8, 8], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, idx))
print("[Sample] d_loss: %.8f, g_loss: %.8f" % (errD, errG))

if np.mod(iter_counter, FLAGS.save_step) == 0:
# save current network parameters
# Save current network parameters
print("[*] Saving checkpoints...")
tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
print("[*] Saving checkpoints SUCCESS!")

sess.close()

if __name__ == '__main__':
tf.app.run()
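For reference, the three tl.cost.sigmoid_cross_entropy terms defined above amount to the standard non-saturating GAN objective. Writing D for the sigmoid of the discriminator logits and G for the generator, labelling real images 1 and generated images 0 gives

\mathcal{L}_D = -\mathbb{E}_{x \sim p_{\mathrm{data}}}\left[\log D(x)\right] - \mathbb{E}_{z \sim \mathcal{N}(0,I)}\left[\log\left(1 - D(G(z))\right)\right]

\mathcal{L}_G = -\mathbb{E}_{z \sim \mathcal{N}(0,I)}\left[\log D(G(z))\right]

This is a restatement of what the code already computes, not behavior added by the PR; the non-saturating generator loss, which labels fakes as 1, is what the "try to make the fake images look real" comment describes.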
41 changes: 23 additions & 18 deletions model.py
@@ -1,7 +1,14 @@

import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *
from tensorlayer.layers import (
InputLayer,
DenseLayer,
DeConv2d,
ReshapeLayer,
BatchNormLayer,
Conv2d,
FlattenLayer
)

flags = tf.app.flags
FLAGS = flags.FLAGS
@@ -11,11 +18,10 @@ def generator_simplified_api(inputs, is_train=True, reuse=False):
s2, s4, s8, s16 = int(image_size/2), int(image_size/4), int(image_size/8), int(image_size/16)
gf_dim = 64 # Dimension of gen filters in first conv layer. [64]
c_dim = FLAGS.c_dim # n_color 3
batch_size = FLAGS.batch_size # 64
w_init = tf.random_normal_initializer(stddev=0.02)
gamma_init = tf.random_normal_initializer(1., 0.02)

with tf.variable_scope("generator", reuse=reuse):
tl.layers.set_name_reuse(reuse)

net_in = InputLayer(inputs, name='g/in')
net_h0 = DenseLayer(net_in, n_units=gf_dim*8*s16*s16, W_init=w_init,
@@ -24,53 +30,52 @@
net_h0 = BatchNormLayer(net_h0, act=tf.nn.relu, is_train=is_train,
gamma_init=gamma_init, name='g/h0/batch_norm')

net_h1 = DeConv2d(net_h0, gf_dim*4, (5, 5), out_size=(s8, s8), strides=(2, 2),
padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h1/decon2d')
net_h1 = DeConv2d(net_h0, gf_dim*4, (5, 5), strides=(2, 2),
padding='SAME', act=None, W_init=w_init, name='g/h1/decon2d')
net_h1 = BatchNormLayer(net_h1, act=tf.nn.relu, is_train=is_train,
gamma_init=gamma_init, name='g/h1/batch_norm')

net_h2 = DeConv2d(net_h1, gf_dim*2, (5, 5), out_size=(s4, s4), strides=(2, 2),
padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h2/decon2d')
net_h2 = DeConv2d(net_h1, gf_dim*2, (5, 5), strides=(2, 2),
padding='SAME', act=None, W_init=w_init, name='g/h2/decon2d')
net_h2 = BatchNormLayer(net_h2, act=tf.nn.relu, is_train=is_train,
gamma_init=gamma_init, name='g/h2/batch_norm')

net_h3 = DeConv2d(net_h2, gf_dim, (5, 5), out_size=(s2, s2), strides=(2, 2),
padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h3/decon2d')
net_h3 = DeConv2d(net_h2, gf_dim, (5, 5), strides=(2, 2),
padding='SAME', act=None, W_init=w_init, name='g/h3/decon2d')
net_h3 = BatchNormLayer(net_h3, act=tf.nn.relu, is_train=is_train,
gamma_init=gamma_init, name='g/h3/batch_norm')

net_h4 = DeConv2d(net_h3, c_dim, (5, 5), out_size=(image_size, image_size), strides=(2, 2),
padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h4/decon2d')
net_h4 = DeConv2d(net_h3, c_dim, (5, 5), strides=(2, 2),
padding='SAME', act=None, W_init=w_init, name='g/h4/decon2d')
logits = net_h4.outputs
net_h4.outputs = tf.nn.tanh(net_h4.outputs)
return net_h4, logits

def discriminator_simplified_api(inputs, is_train=True, reuse=False):
df_dim = 64 # Dimension of discrim filters in first conv layer. [64]
c_dim = FLAGS.c_dim # n_color 3
batch_size = FLAGS.batch_size # 64
w_init = tf.random_normal_initializer(stddev=0.02)
gamma_init = tf.random_normal_initializer(1., 0.02)

with tf.variable_scope("discriminator", reuse=reuse):
tl.layers.set_name_reuse(reuse)

net_in = InputLayer(inputs, name='d/in')
net_h0 = Conv2d(net_in, df_dim, (5, 5), (2, 2), act=lambda x: tl.act.lrelu(x, 0.2),
net_h0 = Conv2d(net_in, df_dim, (5, 5), (2, 2), act=tf.nn.leaky_relu,
padding='SAME', W_init=w_init, name='d/h0/conv2d')

net_h1 = Conv2d(net_h0, df_dim*2, (5, 5), (2, 2), act=None,
padding='SAME', W_init=w_init, name='d/h1/conv2d')
net_h1 = BatchNormLayer(net_h1, act=lambda x: tl.act.lrelu(x, 0.2),
net_h1 = BatchNormLayer(net_h1, act=tf.nn.leaky_relu,
is_train=is_train, gamma_init=gamma_init, name='d/h1/batch_norm')

net_h2 = Conv2d(net_h1, df_dim*4, (5, 5), (2, 2), act=None,
padding='SAME', W_init=w_init, name='d/h2/conv2d')
net_h2 = BatchNormLayer(net_h2, act=lambda x: tl.act.lrelu(x, 0.2),
net_h2 = BatchNormLayer(net_h2, act=tf.nn.leaky_relu,
is_train=is_train, gamma_init=gamma_init, name='d/h2/batch_norm')

net_h3 = Conv2d(net_h2, df_dim*8, (5, 5), (2, 2), act=None,
padding='SAME', W_init=w_init, name='d/h3/conv2d')
net_h3 = BatchNormLayer(net_h3, act=lambda x: tl.act.lrelu(x, 0.2),
net_h3 = BatchNormLayer(net_h3, act=tf.nn.leaky_relu,
is_train=is_train, gamma_init=gamma_init, name='d/h3/batch_norm')

net_h4 = FlattenLayer(net_h3, name='d/h4/flatten')
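The out_size and batch_size arguments dropped from DeConv2d above are recoverable from the halving chain s2..s16 computed at the top of generator_simplified_api: each stride-(2, 2) deconvolution doubles the feature map, which is presumably why the newer TensorLayer API can infer them. A quick illustrative check of that arithmetic with the default output_size of 64:

# Spatial sizes in generator_simplified_api (illustrative, default output_size=64).
image_size = 64
s2, s4, s8, s16 = image_size // 2, image_size // 4, image_size // 8, image_size // 16
# The dense layer reshapes to s16 x s16; four stride-2 DeConv2d layers
# then double the feature map back up to image_size.
sizes = [s16, s8, s4, s2, image_size]
assert sizes == [4, 8, 16, 32, 64]
print(" -> ".join(str(s) for s in sizes))  # 4 -> 8 -> 16 -> 32 -> 64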
29 changes: 0 additions & 29 deletions tensorlayer/__init__.py

This file was deleted.

Binary file removed tensorlayer/__pycache__/__init__.cpython-34.pyc
Binary file removed tensorlayer/__pycache__/__init__.cpython-35.pyc
Binary file removed tensorlayer/__pycache__/_logging.cpython-34.pyc
Binary file removed tensorlayer/__pycache__/_logging.cpython-35.pyc
Binary file removed tensorlayer/__pycache__/activation.cpython-34.pyc
Binary file removed tensorlayer/__pycache__/activation.cpython-35.pyc
Binary file removed tensorlayer/__pycache__/cost.cpython-34.pyc
Binary file removed tensorlayer/__pycache__/cost.cpython-35.pyc
Binary file removed tensorlayer/__pycache__/distributed.cpython-34.pyc
Binary file removed tensorlayer/__pycache__/distributed.cpython-35.pyc
Binary file removed tensorlayer/__pycache__/files.cpython-34.pyc
Binary file removed tensorlayer/__pycache__/files.cpython-35.pyc
Binary file removed tensorlayer/__pycache__/iterate.cpython-34.pyc
Binary file removed tensorlayer/__pycache__/iterate.cpython-35.pyc
Binary file removed tensorlayer/__pycache__/nlp.cpython-34.pyc
Binary file removed tensorlayer/__pycache__/nlp.cpython-35.pyc
Binary file removed tensorlayer/__pycache__/prepro.cpython-34.pyc
Binary file removed tensorlayer/__pycache__/prepro.cpython-35.pyc
Binary file removed tensorlayer/__pycache__/rein.cpython-34.pyc
Binary file removed tensorlayer/__pycache__/rein.cpython-35.pyc
Binary file removed tensorlayer/__pycache__/utils.cpython-34.pyc
Binary file removed tensorlayer/__pycache__/utils.cpython-35.pyc
Binary file removed tensorlayer/__pycache__/visualize.cpython-34.pyc
Binary file removed tensorlayer/__pycache__/visualize.cpython-35.pyc
16 changes: 0 additions & 16 deletions tensorlayer/_logging.py

This file was deleted.
