-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
shahviraj
committed
Feb 26, 2018
1 parent
d9777a3
commit 1747434
Showing
17 changed files
with
1,561 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2018 Viraj Shah | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
## Solving Linear Inverse Problems using GANs | ||
|
||
Code for the paper: [Solving Linear Inverse Problems Using GAN Priors: An Algorithm with Provable Guarantees](https://arxiv.org/abs/1802.08406). | ||
|
||
#### Requirements: | ||
--- | ||
To run this code, you require Python 2.7, Tensorflow 1.0.1 (preferably with GPU support), Scipy and PyPNG. | ||
|
||
Pip installation can be done by ```$ pip install -r requirements.txt``` | ||
|
||
### Instructions | ||
--- | ||
|
||
1. Clone the repository, and run all the commands from the parent directory, ```pgdgan/```. | ||
|
||
2. Download the datasets with the script*: | ||
```shell | ||
$ ./setup/download_data.sh | ||
``` | ||
3. To train the DCGAN on celebA from scratch, please visit https://github.com/carpedm20/DCGAN-tensorflow, and follow the instructions. | ||
Else, pretrained GAN model is available, courtesy [Bora et al.](https://github.com/AshishBora/csgm) | ||
To download it, please run the following script*: | ||
```shell | ||
$ ./setup/download_models.sh | ||
``` | ||
Make sure the model is located at ```./models/celebA_64_64``` | ||
4. Run following to run the experiment: | ||
```shell | ||
$ python pgdgan.py | ||
``` | ||
You can also use the script available in ```./exp_scripts``` | ||
|
||
|
||
\* replicated from https://github.com/AshishBora/csgm . |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
# pylint: disable = C0103, C0111, C0301, R0914 | ||
|
||
"""Model definitions for celebA | ||
This file is partially based on | ||
https://github.com/carpedm20/DCGAN-tensorflow/blob/master/main.py | ||
https://github.com/carpedm20/DCGAN-tensorflow/blob/master/model.py | ||
They come with the following license: https://github.com/carpedm20/DCGAN-tensorflow/blob/master/LICENSE | ||
""" | ||
|
||
import tensorflow as tf | ||
import ops | ||
|
||
|
||
class Hparams(object): | ||
def __init__(self): | ||
self.c_dim = 3 | ||
self.z_dim = 100 | ||
self.gf_dim = 64 | ||
self.df_dim = 64 | ||
self.gfc_dim = 1024 | ||
self.dfc_dim = 1024 | ||
self.batch_size = 64 | ||
|
||
|
||
def generator(hparams, z, train, reuse): | ||
|
||
if reuse: | ||
tf.get_variable_scope().reuse_variables() | ||
|
||
output_size = 64 | ||
s = output_size | ||
s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16) | ||
|
||
g_bn0 = ops.batch_norm(name='g_bn0') | ||
g_bn1 = ops.batch_norm(name='g_bn1') | ||
g_bn2 = ops.batch_norm(name='g_bn2') | ||
g_bn3 = ops.batch_norm(name='g_bn3') | ||
|
||
# project `z` and reshape | ||
h0 = tf.reshape(ops.linear(z, hparams.gf_dim*8*s16*s16, 'g_h0_lin'), [-1, s16, s16, hparams.gf_dim * 8]) | ||
h0 = tf.nn.relu(g_bn0(h0, train=train)) | ||
|
||
h1 = ops.deconv2d(h0, [hparams.batch_size, s8, s8, hparams.gf_dim*4], name='g_h1') | ||
h1 = tf.nn.relu(g_bn1(h1, train=train)) | ||
|
||
h2 = ops.deconv2d(h1, [hparams.batch_size, s4, s4, hparams.gf_dim*2], name='g_h2') | ||
h2 = tf.nn.relu(g_bn2(h2, train=train)) | ||
|
||
h3 = ops.deconv2d(h2, [hparams.batch_size, s2, s2, hparams.gf_dim*1], name='g_h3') | ||
h3 = tf.nn.relu(g_bn3(h3, train=train)) | ||
|
||
h4 = ops.deconv2d(h3, [hparams.batch_size, s, s, hparams.c_dim], name='g_h4') | ||
x_gen = tf.nn.tanh(h4) | ||
|
||
return x_gen | ||
|
||
|
||
def discriminator(hparams, x, train, reuse): | ||
|
||
if reuse: | ||
tf.get_variable_scope().reuse_variables() | ||
|
||
d_bn1 = ops.batch_norm(name='d_bn1') | ||
d_bn2 = ops.batch_norm(name='d_bn2') | ||
d_bn3 = ops.batch_norm(name='d_bn3') | ||
|
||
h0 = ops.lrelu(ops.conv2d(x, hparams.df_dim, name='d_h0_conv')) | ||
|
||
h1 = ops.conv2d(h0, hparams.df_dim*2, name='d_h1_conv') | ||
h1 = ops.lrelu(d_bn1(h1, train=train)) | ||
|
||
h2 = ops.conv2d(h1, hparams.df_dim*4, name='d_h2_conv') | ||
h2 = ops.lrelu(d_bn2(h2, train=train)) | ||
|
||
h3 = ops.conv2d(h2, hparams.df_dim*8, name='d_h3_conv') | ||
h3 = ops.lrelu(d_bn3(h3, train=train)) | ||
|
||
h4 = ops.linear(tf.reshape(h3, [hparams.batch_size, -1]), 1, 'd_h3_lin') | ||
|
||
d_logit = h4 | ||
d = tf.nn.sigmoid(d_logit) | ||
|
||
return d, d_logit | ||
|
||
|
||
def gen_restore_vars(): | ||
restore_vars = ['g_bn0/beta', | ||
'g_bn0/gamma', | ||
'g_bn0/moving_mean', | ||
'g_bn0/moving_variance', | ||
'g_bn1/beta', | ||
'g_bn1/gamma', | ||
'g_bn1/moving_mean', | ||
'g_bn1/moving_variance', | ||
'g_bn2/beta', | ||
'g_bn2/gamma', | ||
'g_bn2/moving_mean', | ||
'g_bn2/moving_variance', | ||
'g_bn3/beta', | ||
'g_bn3/gamma', | ||
'g_bn3/moving_mean', | ||
'g_bn3/moving_variance', | ||
'g_h0_lin/Matrix', | ||
'g_h0_lin/bias', | ||
'g_h1/biases', | ||
'g_h1/w', | ||
'g_h2/biases', | ||
'g_h2/w', | ||
'g_h3/biases', | ||
'g_h3/w', | ||
'g_h4/biases', | ||
'g_h4/w'] | ||
return restore_vars | ||
|
||
|
||
|
||
def discrim_restore_vars(): | ||
restore_vars = ['d_bn1/beta', | ||
'd_bn1/gamma', | ||
'd_bn1/moving_mean', | ||
'd_bn1/moving_variance', | ||
'd_bn2/beta', | ||
'd_bn2/gamma', | ||
'd_bn2/moving_mean', | ||
'd_bn2/moving_variance', | ||
'd_bn3/beta', | ||
'd_bn3/gamma', | ||
'd_bn3/moving_mean', | ||
'd_bn3/moving_variance', | ||
'd_h0_conv/biases', | ||
'd_h0_conv/w', | ||
'd_h1_conv/biases', | ||
'd_h1_conv/w', | ||
'd_h2_conv/biases', | ||
'd_h2_conv/w', | ||
'd_h3_conv/biases', | ||
'd_h3_conv/w', | ||
'd_h3_lin/Matrix', | ||
'd_h3_lin/bias'] | ||
return restore_vars |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
# pylint: disable = C0103, C0111, C0301, R0914 | ||
|
||
"""Model definitions for celebA | ||
This file is partially based on | ||
https://github.com/carpedm20/DCGAN-tensorflow/blob/master/main.py | ||
https://github.com/carpedm20/DCGAN-tensorflow/blob/master/model.py | ||
They come with the following license: https://github.com/carpedm20/DCGAN-tensorflow/blob/master/LICENSE | ||
""" | ||
|
||
import tensorflow as tf | ||
import ops | ||
|
||
|
||
class Hparams(object): | ||
def __init__(self): | ||
self.c_dim = 3 | ||
self.z_dim = 100 | ||
self.gf_dim = 64 | ||
self.df_dim = 64 | ||
self.gfc_dim = 1024 | ||
self.dfc_dim = 1024 | ||
self.batch_size = 64 | ||
|
||
|
||
def generator(hparams, z, scope_name, train, reuse): | ||
|
||
with tf.variable_scope(scope_name) as scope: | ||
if reuse: | ||
scope.reuse_variables() | ||
|
||
output_size = 64 | ||
s = output_size | ||
s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16) | ||
|
||
g_bn0 = ops.batch_norm(name='g_bn0') | ||
g_bn1 = ops.batch_norm(name='g_bn1') | ||
g_bn2 = ops.batch_norm(name='g_bn2') | ||
g_bn3 = ops.batch_norm(name='g_bn3') | ||
|
||
# project `z` and reshape | ||
h0 = tf.reshape(ops.linear(z, hparams.gf_dim*8*s16*s16, 'g_h0_lin'), [-1, s16, s16, hparams.gf_dim * 8]) | ||
h0 = tf.nn.relu(g_bn0(h0, train=train)) | ||
|
||
h1 = ops.deconv2d(h0, [hparams.batch_size, s8, s8, hparams.gf_dim*4], name='g_h1') | ||
h1 = tf.nn.relu(g_bn1(h1, train=train)) | ||
|
||
h2 = ops.deconv2d(h1, [hparams.batch_size, s4, s4, hparams.gf_dim*2], name='g_h2') | ||
h2 = tf.nn.relu(g_bn2(h2, train=train)) | ||
|
||
h3 = ops.deconv2d(h2, [hparams.batch_size, s2, s2, hparams.gf_dim*1], name='g_h3') | ||
h3 = tf.nn.relu(g_bn3(h3, train=train)) | ||
|
||
h4 = ops.deconv2d(h3, [hparams.batch_size, s, s, hparams.c_dim], name='g_h4') | ||
x_gen = tf.nn.tanh(h4) | ||
|
||
return x_gen | ||
|
||
|
||
def discriminator(hparams, x, scope_name, train, reuse): | ||
|
||
with tf.variable_scope(scope_name) as scope: | ||
if reuse: | ||
scope.reuse_variables() | ||
|
||
d_bn1 = ops.batch_norm(name='d_bn1') | ||
d_bn2 = ops.batch_norm(name='d_bn2') | ||
d_bn3 = ops.batch_norm(name='d_bn3') | ||
|
||
h0 = ops.lrelu(ops.conv2d(x, hparams.df_dim, name='d_h0_conv')) | ||
|
||
h1 = ops.conv2d(h0, hparams.df_dim*2, name='d_h1_conv') | ||
h1 = ops.lrelu(d_bn1(h1, train=train)) | ||
|
||
h2 = ops.conv2d(h1, hparams.df_dim*4, name='d_h2_conv') | ||
h2 = ops.lrelu(d_bn2(h2, train=train)) | ||
|
||
h3 = ops.conv2d(h2, hparams.df_dim*8, name='d_h3_conv') | ||
h3 = ops.lrelu(d_bn3(h3, train=train)) | ||
|
||
h4 = ops.linear(tf.reshape(h3, [hparams.batch_size, -1]), 1, 'd_h3_lin') | ||
|
||
d_logit = h4 | ||
d = tf.nn.sigmoid(d_logit) | ||
|
||
return d, d_logit | ||
|
||
|
||
def gen_restore_vars(): | ||
restore_vars = ['gen/g_bn0/beta', | ||
'gen/g_bn0/gamma', | ||
'gen/g_bn0/moving_mean', | ||
'gen/g_bn0/moving_variance', | ||
'gen/g_bn1/beta', | ||
'gen/g_bn1/gamma', | ||
'gen/g_bn1/moving_mean', | ||
'gen/g_bn1/moving_variance', | ||
'gen/g_bn2/beta', | ||
'gen/g_bn2/gamma', | ||
'gen/g_bn2/moving_mean', | ||
'gen/g_bn2/moving_variance', | ||
'gen/g_bn3/beta', | ||
'gen/g_bn3/gamma', | ||
'gen/g_bn3/moving_mean', | ||
'gen/g_bn3/moving_variance', | ||
'gen/g_h0_lin/Matrix', | ||
'gen/g_h0_lin/bias', | ||
'gen/g_h1/biases', | ||
'gen/g_h1/w', | ||
'gen/g_h2/biases', | ||
'gen/g_h2/w', | ||
'gen/g_h3/biases', | ||
'gen/g_h3/w', | ||
'gen/g_h4/biases', | ||
'gen/g_h4/w'] | ||
return restore_vars |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
"""Ops used in the DCGAN model | ||
File based on : https://github.com/carpedm20/DCGAN-tensorflow/blob/master/ops.py | ||
It comes with the following license: https://github.com/carpedm20/DCGAN-tensorflow/blob/master/LICENSE | ||
""" | ||
# pylint: disable = C0103, C0111, C0301, R0913, R0903 | ||
|
||
import tensorflow as tf | ||
|
||
class batch_norm(object): | ||
def __init__(self, epsilon=1e-5, momentum=0.9, name="batch_norm"): | ||
with tf.variable_scope(name): | ||
self.epsilon = epsilon | ||
self.momentum = momentum | ||
self.name = name | ||
|
||
def __call__(self, x, train=True): | ||
return tf.contrib.layers.batch_norm(x, | ||
decay=self.momentum, | ||
updates_collections=None, | ||
epsilon=self.epsilon, | ||
scale=True, | ||
is_training=train, | ||
scope=self.name) | ||
|
||
|
||
def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): | ||
shape = input_.get_shape().as_list() | ||
|
||
with tf.variable_scope(scope or "Linear"): | ||
matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, | ||
tf.random_normal_initializer(stddev=stddev)) | ||
bias = tf.get_variable("bias", [output_size], | ||
initializer=tf.constant_initializer(bias_start)) | ||
if with_w: | ||
return tf.matmul(input_, matrix) + bias, matrix, bias | ||
else: | ||
return tf.matmul(input_, matrix) + bias | ||
|
||
|
||
def conv_cond_concat(x, y): | ||
"""Concatenate conditioning vector on feature map axis.""" | ||
x_shapes = x.get_shape() | ||
y_shapes = y.get_shape() | ||
return tf.concat([x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3) | ||
|
||
|
||
def deconv2d(input_, output_shape, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="deconv2d", with_w=False): | ||
with tf.variable_scope(name): | ||
# filter : [height, width, output_channels, in_channels] | ||
w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]], | ||
initializer=tf.random_normal_initializer(stddev=stddev)) | ||
deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape, | ||
strides=[1, d_h, d_w, 1]) | ||
biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0)) | ||
deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape()) | ||
|
||
if with_w: | ||
return deconv, w, biases | ||
else: | ||
return deconv | ||
|
||
|
||
def lrelu(x, leak=0.2, name='lrelu'): | ||
return tf.maximum(x, leak*x, name=name) | ||
|
||
|
||
def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="conv2d"): | ||
with tf.variable_scope(name): | ||
w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], | ||
initializer=tf.truncated_normal_initializer(stddev=stddev)) | ||
conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME') | ||
|
||
biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) | ||
conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) | ||
return conv |
Oops, something went wrong.