batch training implemented
saikatbsk committed Apr 8, 2017
1 parent 5ebbaf7 commit c87007c
Showing 6 changed files with 186 additions and 84 deletions.
38 changes: 0 additions & 38 deletions Readme.md
@@ -2,47 +2,9 @@

Adversarial training, first proposed by Ian Goodfellow in his [NIPS-2014 paper](https://arxiv.org/abs/1406.2661), is a way to train two neural networks simultaneously. The first network is the Discriminator, denoted D(Y); it takes an input (e.g. an image) and outputs a scalar indicating whether the image Y looks "natural" or not. The output of D(Y) can be a score turned into a probability using a softmax function: the probability is close to 1 if the input is a real face image and close to 0 otherwise. The second network is the Generator, denoted G(Z), where Z is usually a vector randomly sampled from a simple distribution (e.g. a Gaussian). The role of the generator is to produce fake images that look as natural as possible, so as to fool D. During training, D is shown a real image and adjusts its parameters to output a probability close to 1. It is then shown an image generated by G and adjusts its parameters to make its output D(G(Z)) small (following the gradient of a predefined loss function). G, in turn, trains itself to produce more natural-looking images in order to fool D; it does this by following the gradient of D with respect to its input Y for each sample it produces.
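
The two objectives described above boil down to a pair of cross-entropy losses. The snippet below is a minimal illustrative sketch, not the code in `dcgan.py`; `d_real_logits` and `d_fake_logits` are placeholders for the discriminator's raw scores on real and generated images (TensorFlow 1.x API, as used in this project).

```python
import tensorflow as tf

def gan_losses(d_real_logits, d_fake_logits):
    # D is pushed towards D(Y) -> 1 on real images and D(G(Z)) -> 0 on fakes.
    d_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_real_logits, labels=tf.ones_like(d_real_logits)) +
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_fake_logits, labels=tf.zeros_like(d_fake_logits)))
    # G is pushed the opposite way: it wants D(G(Z)) -> 1.
    g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_fake_logits, labels=tf.ones_like(d_fake_logits)))
    return d_loss, g_loss
```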

*TODO:*
1. Implement batch training.
2. Accept gray images as input.

*Other GAN implementations:*

- https://github.com/goodfeli/adversarial: Theano GAN implementation released by the authors of the GAN paper.
- https://github.com/Newmu/dcgan_code: Theano DCGAN implementation released by the authors of the DCGAN paper.
- https://github.com/carpedm20/DCGAN-tensorflow: Unofficial TensorFlow DCGAN implementation.
- https://github.com/openai/improved-gan: Code for OpenAI's "Improved Techniques for Training GANs" paper.

#### Output:

Training is still incomplete. Here are some results after 2000 steps. (I expect far better output after 10000 steps.)

![output](images/generated_image.jpg)

#### Dataset:

A dataset of face (old) images can be downloaded from this Google Drive link: https://drive.google.com/open?id=0B_uiWs-gNj4wbmF4em1hQ3BHQWM

#### Requirements:

```
sudo pip3 install numpy scipy tensorflow
```

#### Train:

```
python3 main.py --data_dir <path/to/data/directory>
```

#### Generate: (requires checkpoints)

```
python3 main.py --nois_train --latest_ckpt <checkpoint_index>
```

Here `--nois_train` is the TensorFlow flags way of setting the `is_train` boolean to False. For example:

```
python3 main.py --nois_train --latest_ckpt 2000
```
Binary file added data/indian_celebrities_male.tfrecords
Binary file not shown.
32 changes: 32 additions & 0 deletions helper/create_tfrecords.py
@@ -0,0 +1,32 @@
import skimage.io as io
import os
from glob import glob
import tensorflow as tf

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

image_files = glob(os.path.join('../data/indian_celebrities_male', '*.jpg'))
tfrecords_filename = '../data/indian_celebrities_male.tfrecords'

writer = tf.python_io.TFRecordWriter(tfrecords_filename)

for img_path in image_files:
    img = io.imread(img_path)

    height = img.shape[0]
    width = img.shape[1]

    # Store the raw pixel buffer along with height/width so the reader
    # can reshape it later.
    img_raw = img.tostring()

    example = tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'image_raw': _bytes_feature(img_raw)}))

    writer.write(example.SerializeToString())

writer.close()
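
As a quick sanity check (illustrative snippet, not part of this commit), the examples written above can be counted back with the TF 1.x record iterator:

```python
import tensorflow as tf

# Count the serialized examples in the .tfrecords file written above.
count = sum(1 for _ in tf.python_io.tf_record_iterator('../data/indian_celebrities_male.tfrecords'))
print('%d examples written' % count)
```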
54 changes: 39 additions & 15 deletions main.py
@@ -5,26 +5,50 @@
from dcgan import DCGAN

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('log_dir', 'checkpoints', """Path to write logs and checkpoints""")
tf.app.flags.DEFINE_string('images_dir', 'images', """Path to save generated images""")
tf.app.flags.DEFINE_string('data_dir', 'data', """Path to data directory""")
tf.app.flags.DEFINE_integer('max_itr', 10000, """Maximum number of iterations""")
tf.app.flags.DEFINE_integer('latest_ckpt', 0, """Latest checkpoint timestamp to load""")
tf.app.flags.DEFINE_boolean('is_train', True, """False for generating only""")
tf.app.flags.DEFINE_boolean('is_grayscale', False, """True for grayscale images [not yet implemented]""")
tf.app.flags.DEFINE_string('log_dir', 'checkpoints', """Path to write logs and checkpoints""")
tf.app.flags.DEFINE_string('images_dir', 'images', """Path to save generated images""")
tf.app.flags.DEFINE_string('data_dir', 'data', """Path to data directory""")
tf.app.flags.DEFINE_integer('max_itr', 10000, """Maximum number of iterations""")
tf.app.flags.DEFINE_integer('latest_ckpt', 0, """Latest checkpoint timestamp to load""")
tf.app.flags.DEFINE_boolean('is_train', True, """False for generating only""")
tf.app.flags.DEFINE_boolean('is_grayscale', False, """True for grayscale images""")
tf.app.flags.DEFINE_integer('num_examples_per_epoch_for_train', 300, """number of examples for train""")

CROP_IMAGE_SIZE = 96

def read_decode(batch_size, f_size):
    # Queue up every .tfrecords file found in the data directory.
    files = [os.path.join(FLAGS.data_dir, f) for f in os.listdir(FLAGS.data_dir) if f.endswith('.tfrecords')]
    fqueue = tf.train.string_input_producer(files)
    reader = tf.TFRecordReader()
    _, serialized = reader.read(fqueue)
    features = tf.parse_single_example(serialized, features={
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'image_raw': tf.FixedLenFeature([], tf.string)})

    image = tf.cast(tf.decode_raw(features['image_raw'], tf.uint8), tf.float32)
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)

    image = tf.reshape(image, [height, width, 3])
    image = tf.image.resize_image_with_crop_or_pad(image, CROP_IMAGE_SIZE, CROP_IMAGE_SIZE)
    #image = tf.image.random_flip_left_right(image)

    # Shuffle single examples into batches; the final resize and rescale
    # map pixel values into [-1, 1] at the generator's output resolution.
    min_queue_examples = FLAGS.num_examples_per_epoch_for_train
    images = tf.train.shuffle_batch(
        [image],
        batch_size=batch_size,
        capacity=min_queue_examples + 3 * batch_size,
        min_after_dequeue=min_queue_examples)
    tf.summary.image('images', images)
    return tf.subtract(tf.div(tf.image.resize_images(images, [f_size * 2 ** 4, f_size * 2 ** 4]), 127.5), 1.0)
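
Because `read_decode` builds a queue-based pipeline (`string_input_producer` plus `shuffle_batch`), batches only start flowing once queue runners are running inside a session. A minimal consumption sketch, assuming the default flags above (illustrative only, not the training loop in this file):

```python
images = read_decode(batch_size=4, f_size=6)

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        batch = sess.run(images)  # one shuffled batch of 4 images scaled to [-1, 1]
    finally:
        coord.request_stop()
        coord.join(threads)
```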

def main(_):
dcgan = DCGAN(batch_size=128, f_size=6, z_dim=40,
dcgan = DCGAN(batch_size=4, f_size=6, z_dim=40,
gdepth1=216, gdepth2=144, gdepth3=96, gdepth4=64,
ddepth1=64, ddepth2=96, ddepth3=144, ddepth4=216)

"""
Batch training not implemented yet. Keep exactly 128 images inside data_dir.
"""
input_images = input_data(FLAGS.data_dir,
input_height=128, input_width=128,
resize_height=96, resize_width=96,
is_grayscale=FLAGS.is_grayscale)
input_images = read_decode(dcgan.batch_size, dcgan.f_size)

train_op = dcgan.build(input_images, feature_matching=True)

115 changes: 115 additions & 0 deletions notebooks/view_data.ipynb

Large diffs are not rendered by default.

31 changes: 0 additions & 31 deletions utils.py

This file was deleted.
