Skip to content

Commit

Permalink
[release] TF1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
zsdonghao committed Apr 11, 2017
1 parent fceb8cb commit abd96f4
Show file tree
Hide file tree
Showing 29 changed files with 4,016 additions and 1,191 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ TensorFlow / TensorLayer implementation of [Deep Convolutional Generative Advers
## Prerequisites

- Python 2.7 or Python 3.3+
- [TensorFlow==0.10.0 or higher](https://www.tensorflow.org/)
- [TensorLayer==1.2.6 or higher](https://github.com/zsdonghao/tensorlayer) (already in this repo)
- [TensorFlow==1.0+](https://www.tensorflow.org/)
- [TensorLayer==1.4+](https://github.com/zsdonghao/tensorlayer)


## Usage
Expand All @@ -25,4 +25,6 @@ To train a model with downloaded dataset:

$ python main.py

## Result

![alt tag](result.png)
353 changes: 56 additions & 297 deletions main.py

Large diffs are not rendered by default.

89 changes: 89 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@

import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *


flags = tf.app.flags
FLAGS = flags.FLAGS



def generator_simplified_api(inputs, is_train=True, reuse=False):
image_size = 64
s2, s4, s8, s16 = int(image_size/2), int(image_size/4), int(image_size/8), int(image_size/16)
gf_dim = 64 # Dimension of gen filters in first conv layer. [64]
c_dim = FLAGS.c_dim # n_color 3
batch_size = FLAGS.batch_size # 64

w_init = tf.random_normal_initializer(stddev=0.02)
gamma_init = tf.random_normal_initializer(1., 0.02)

with tf.variable_scope("generator", reuse=reuse):
tl.layers.set_name_reuse(reuse)

net_in = InputLayer(inputs, name='g/in')
net_h0 = DenseLayer(net_in, n_units=gf_dim*8*s16*s16, W_init=w_init,
act = tf.identity, name='g/h0/lin')
net_h0 = ReshapeLayer(net_h0, shape=[-1, s16, s16, gf_dim*8], name='g/h0/reshape')
net_h0 = BatchNormLayer(net_h0, act=tf.nn.relu, is_train=is_train,
gamma_init=gamma_init, name='g/h0/batch_norm')

net_h1 = DeConv2d(net_h0, gf_dim*4, (5, 5), out_size=(s8, s8), strides=(2, 2),
padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h1/decon2d')
net_h1 = BatchNormLayer(net_h1, act=tf.nn.relu, is_train=is_train,
gamma_init=gamma_init, name='g/h1/batch_norm')

net_h2 = DeConv2d(net_h1, gf_dim*2, (5, 5), out_size=(s4, s4), strides=(2, 2),
padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h2/decon2d')
net_h2 = BatchNormLayer(net_h2, act=tf.nn.relu, is_train=is_train,
gamma_init=gamma_init, name='g/h2/batch_norm')

net_h3 = DeConv2d(net_h2, gf_dim, (5, 5), out_size=(s2, s2), strides=(2, 2),
padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h3/decon2d')
net_h3 = BatchNormLayer(net_h3, act=tf.nn.relu, is_train=is_train,
gamma_init=gamma_init, name='g/h3/batch_norm')

net_h4 = DeConv2d(net_h3, c_dim, (5, 5), out_size=(image_size, image_size), strides=(2, 2),
padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h4/decon2d')
logits = net_h4.outputs
net_h4.outputs = tf.nn.tanh(net_h4.outputs)
return net_h4, logits


def discriminator_simplified_api(inputs, is_train=True, reuse=False):
df_dim = 64 # Dimension of discrim filters in first conv layer. [64]
c_dim = FLAGS.c_dim # n_color 3
batch_size = FLAGS.batch_size # 64

w_init = tf.random_normal_initializer(stddev=0.02)
gamma_init = tf.random_normal_initializer(1., 0.02)

with tf.variable_scope("discriminator", reuse=reuse):
tl.layers.set_name_reuse(reuse)

net_in = InputLayer(inputs, name='d/in')
net_h0 = Conv2d(net_in, df_dim, (5, 5), (2, 2), act=lambda x: tl.act.lrelu(x, 0.2),
padding='SAME', W_init=w_init, name='d/h0/conv2d')

net_h1 = Conv2d(net_h0, df_dim*2, (5, 5), (2, 2), act=None,
padding='SAME', W_init=w_init, name='d/h1/conv2d')
net_h1 = BatchNormLayer(net_h1, act=lambda x: tl.act.lrelu(x, 0.2),
is_train=is_train, gamma_init=gamma_init, name='d/h1/batch_norm')

net_h2 = Conv2d(net_h1, df_dim*4, (5, 5), (2, 2), act=None,
padding='SAME', W_init=w_init, name='d/h2/conv2d')
net_h2 = BatchNormLayer(net_h2, act=lambda x: tl.act.lrelu(x, 0.2),
is_train=is_train, gamma_init=gamma_init, name='d/h2/batch_norm')

net_h3 = Conv2d(net_h2, df_dim*8, (5, 5), (2, 2), act=None,
padding='SAME', W_init=w_init, name='d/h3/conv2d')
net_h3 = BatchNormLayer(net_h3, act=lambda x: tl.act.lrelu(x, 0.2),
is_train=is_train, gamma_init=gamma_init, name='d/h3/batch_norm')

net_h4 = FlattenLayer(net_h3, name='d/h4/flatten')
net_h4 = DenseLayer(net_h4, n_units=1, act=tf.identity,
W_init = w_init, name='d/h4/lin_sigmoid')
logits = net_h4.outputs
net_h4.outputs = tf.nn.sigmoid(net_h4.outputs)
return net_h4, logits
5 changes: 4 additions & 1 deletion tensorlayer/__init__.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,7 @@
from . import rein


__version__ = "1.2.3"
__version__ = "1.4.2"

global_flag = {}
global_dict = {}
Binary file added tensorlayer/__pycache__/__init__.cpython-35.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/activation.cpython-35.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/cost.cpython-35.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/files.cpython-35.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/iterate.cpython-35.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/layers.cpython-35.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/nlp.cpython-35.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/ops.cpython-35.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/prepro.cpython-35.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/rein.cpython-35.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/utils.cpython-35.pyc
Binary file not shown.
Binary file added tensorlayer/__pycache__/visualize.cpython-35.pyc
Binary file not shown.
18 changes: 10 additions & 8 deletions tensorlayer/activation.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,13 @@ def pixel_wise_softmax(output, name='pixel_wise_softmax'):
- `tf.reverse <https://www.tensorflow.org/versions/master/api_docs/python/array_ops.html#reverse>`_
"""
with tf.name_scope(name) as scope:
exp_map = tf.exp(output)
if output.get_shape().ndims == 4: # 2d image
evidence = tf.add(exp_map, tf.reverse(exp_map, [False, False, False, True]))
elif output.get_shape().ndims == 5: # 3d image
evidence = tf.add(exp_map, tf.reverse(exp_map, [False, False, False, False, True]))
else:
raise Exception("output parameters should be 2d or 3d image, not %s" % str(output._shape))
return tf.div(exp_map, evidence)
return tf.nn.softmax(output)
## old implementation
# exp_map = tf.exp(output)
# if output.get_shape().ndims == 4: # 2d image
# evidence = tf.add(exp_map, tf.reverse(exp_map, [False, False, False, True]))
# elif output.get_shape().ndims == 5: # 3d image
# evidence = tf.add(exp_map, tf.reverse(exp_map, [False, False, False, False, True]))
# else:
# raise Exception("output parameters should be 2d or 3d image, not %s" % str(output._shape))
# return tf.div(exp_map, evidence)
Loading

0 comments on commit abd96f4

Please sign in to comment.