In [1]:
# VANILLA GAN ON TENSORFLOW
### This is my implementation of a generative adversarial network on tensorflow

In [2]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os

mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)

  from ._conv import register_converters as _register_converters


Extracting ./MNIST_data/train-images-idx3-ubyte.gz
Extracting ./MNIST_data/train-labels-idx1-ubyte.gz
Extracting ./MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ./MNIST_data/t10k-labels-idx1-ubyte.gz


## Architecture:
The GAN has a generator and a discriminator:
- The generator has two dense layers, and it outputs an image from a vector of uniform noise.
- The discriminator has two dense layers and it discriminates again images coming from the mnist database or those generated by the generator

In [3]:
def sample_z(m, n):
    return np.random.uniform(-1., 1., size=[m, n])

def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig

class Gan(object):
    DEFAULTS = {
    "batch_size": 128,
    "learning_rate": 5E-4,
    "dropout": 0.9,
    "lambda_l2_reg": 1E-5,
    "nonlinearity": tf.nn.elu,
    "squashing": tf.nn.sigmoid,
    "regularization": tf.contrib.layers.l2_regularizer,
    "mu": 0,
    "sigma": 1.,
    "initializer": tf.contrib.layers.xavier_initializer(),
        "log_dir": './log_conv_gan'
    }
    RESTORE_KEY = "to_restore"


    def __init__(self, architecture=[], d_hyperparams={}, log_dir='./log_conv_gan'):
        self.architecture = architecture
        self.__dict__.update(Gan.DEFAULTS, **d_hyperparams)
        self.log_dir = self.log_dir
        self.sesh = tf.Session()
        #TODO: decide if load a model or build a new one. For now, build it
        handles = self._build_graph()

      # In any case, make a collection of variables that are restore-able
        for handle in handles:
            tf.add_to_collection(Gan.RESTORE_KEY, handle)

        # initialize all globals of the session
        self.sesh.run(tf.global_variables_initializer())


        # Unpack the tuple of handles created by the builder
        (self.z, self.x_real, self.g_sample, self.D_loss, self.G_loss, 
             self.D_solver, self.G_solver, self.merged_summaries, self.global_step) = handles

        # Initialize the filewriter and write the graph (tensorboard)
        self.writer = tf.summary.FileWriter(log_dir, self.sesh.graph)

  
    def discriminator(self, x_in):
        # Discriminator gets a batch of images and returns a batch of probabilities.
        # The output has Identity activation_fn, so that the entropy in the cost
        # functions shall be computed with logits
        with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
            img_dim = self.architecture[1]
            x = tf.reshape(x_in, [-1, img_dim, img_dim, 1], name='d_in')
            x = tf.layers.conv2d(x, filters=32, kernel_size=[5, 5], 
                                 activation=self.nonlinearity,
                                 kernel_initializer=self.initializer,
                                 name='d_1'
                                )
            pool = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2, name='d_1_pool')
            
            x = tf.layers.conv2d(pool, filters=64, kernel_size=[5, 5],
                                activation=self.nonlinearity,
                                 kernel_initializer=self.initializer,
                                 name='d_2'
                                )
            pool = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2, name='d_2_pool')
            pool_flat = tf.reshape(pool, [-1, 1024], name='d_2_flatten')
            x = tf.layers.dense(inputs=pool_flat, units=1024, 
                                activation=self.nonlinearity,
                               kernel_initializer=self.initializer,
                               name='d_3')
            
            #dropout = tf.layers.dropout(x, rate=self.dropout, training=mode==tf.estimator.ModeKeys.TRAIN)
            
            d_logit = tf.layers.dense(inputs=x, units=1, 
                                activation=None,
                               kernel_initializer=self.initializer,
                               name='d_logit')
            return x, d_logit
    
    def generator(self, z):
        # Generator net gets a batch of noise vectors and returns a batch of images
        with tf.variable_scope('generator', reuse=tf.AUTO_REUSE):
            x = tf.reshape(z, [-1, 10, 10, 1])
            x = tf.layers.conv2d_transpose(inputs=x, filters=8,
                                            kernel_size=[5, 5],
                                            activation=self.nonlinearity,
                                            kernel_initializer=self.initializer,
                                            name='g_conv_1')
            x = tf.layers.conv2d_transpose(inputs=x, filters=8,
                                            kernel_size=[5, 5],
                                            activation=self.nonlinearity,
                                            kernel_initializer=self.initializer,
                                            )
            g = tf.contrib.layers.fully_connected(x, 784, activation_fn=self.squashing,
                              weights_initializer=self.initializer)
            
            return g
      
    def generate(self, z):
        return self.sesh.run([self.g_sample], feed_dict={self.z: z})
    
    
    def _build_graph(self):
        z = tf.placeholder(tf.float32, shape=[None, self.architecture[0]], name='z')
        x_real = tf.placeholder(tf.float32, shape=[None, 784], name='x_real')

        # These nodes are the output of the generator/discrimiator for the batch
        # we will use them to compute the losses
        d_real, d_logit_real = self.discriminator(x_real)
        g_sample = self.generator(z)
        d_fake, d_logit_fake = self.discriminator(g_sample)

        # The losses:
        # Intuitively, we want:
        
        #the DISCRIMINATOR to be good at predicting:
        # a) That the real images are real (hence the cross entropy against 1)
        D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logit_real, 
                                                                            labels=tf.ones_like(d_logit_real)),
                                    name='D_loss_real')
        # b) That the fake images are NOT real (hence the cross entropy agains 0)
        D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logit_fake,
                                                                            labels=tf.zeros_like(d_logit_fake)),
                                    name='D_loss_fake')
        D_loss = tf.add(D_loss_real, D_loss_fake, name='D_loss')
        
        # the Generator to be good at generating images that the discriminator (mis)predicts are real
        G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logit_fake,
                                                                       labels=tf.ones_like(d_logit_fake)),
                               name='G_loss')

                
        # The optimizers (simpler but darker than compute; clip, apply)
        # There are two optimizers, each running on a group of variables while the others are constant.
        
        # Since they will all be run the same number of times for each minibatch, only one of them 
        # gets to advance the global_step
        
        global_step = tf.Variable(0, name='global_step', trainable=False)
        
        # Discriminator minimize grads. This one carries the global step
        # The solver only has to 'train' (modify) the TRAINABLE variables within the scope of
        # 'discriminator'
        D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')                
        D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=D_vars, 
                                                     global_step=global_step)

        # Generator minimize grads
        # The solver only has to 'train' (modify) the TRAINABLE variables within the scope of
        # 'generator'
        G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
        G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=G_vars)

        
        # append to summary
        with tf.name_scope('summaries'):
            with tf.name_scope('loss'):
                tf.summary.scalar('D_fake_loss', tf.reduce_sum(D_loss_fake, name='Dl_fake_scalar'))
                tf.summary.scalar('D_real_loss', tf.reduce_sum(D_loss_real, name='Dl_real_scalar'))
                tf.summary.scalar('D_loss', tf.reduce_sum(D_loss, name='Dl_scalar'))
                tf.summary.scalar('G_loss', tf.reduce_sum(G_loss, name='Gl_scalar'))
        summaries = tf.summary.merge_all()
        
        # Return nodes that the model needs to access
        return z, x_real, g_sample, D_loss, G_loss, D_solver, G_solver, summaries, global_step
    
    def train(self, X, max_iter=np.inf, max_epochs=np.inf, cross_validate=True, verbose=True):
        try:
            while True:
                x_real, _ = X.train.next_batch(self.batch_size)
                z = sample_z(self.batch_size, self.architecture[0])

                # run a step of the optimizer training for the discriminator,
                # with generator fixed
                D_solver, D_loss, i = self.sesh.run([self.D_solver, self.D_loss, self.global_step], 
                                            feed_dict = {self.x_real: x_real, self.z: z})

                # run a step of the optimizer training for the generator,
                # with discriminator fixed
                summary, G_solver, G_loss = self.sesh.run([self.merged_summaries, self.G_solver, self.G_loss], 
                                            feed_dict = {self.x_real: x_real, self.z: z})

                # output, logging and interruption arithmetics
                
                if i % 10 == 0:  # Record summaries and test-set accuracy
                    self.writer.add_summary(summary, i)
                if i%1000 ==0  and verbose:
                    print('Round {}, g_loss {}, d_loss {}'.format(i, G_loss, D_loss))
                    samples = self.generate(sample_z(16, self.architecture[0]))
                    fig = plot(samples[0])
                    fig_fn = os.path.join(os.path.abspath(self.log_dir), '{:03d}.png'.format(i))
                    plt.savefig(fig_fn, bbox_inches='tight')
                    plt.show()

                if i>=max_iter or X.train.epochs_completed >= max_epochs:
                    print("final avg cost (@ step {} = epoch {})".format(
                        i, X.train.epochs_completed))
                    try:
                        self.writer.flush()
                        self.writer.close()
                    except(AttributeError):  # not logging
                        continue
                    break
                    
        except KeyboardInterrupt:
            print('Ended')

  
                    
z_dim = 100

gan = Gan(architecture=[z_dim, 28])

In [4]:
gan.train(mnist, max_epochs=10000)

ResourceExhaustedError: OOM when allocating tensor with shape[41472,32,24,24]
	 [[Node: gradients/discriminator_1/d_1_pool/MaxPool_grad/MaxPoolGrad = MaxPoolGrad[T=DT_FLOAT, data_format="NHWC", ksize=[1, 2, 2, 1], padding="VALID", strides=[1, 2, 2, 1], _device="/job:localhost/replica:0/task:0/device:GPU:0"](discriminator_1/d_1/Elu, discriminator_1/d_1_pool/MaxPool, gradients/discriminator_1/d_2/Conv2D_grad/tuple/control_dependency)]]
	 [[Node: Adam/update/_36 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_720_Adam/update", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Caused by op 'gradients/discriminator_1/d_1_pool/MaxPool_grad/MaxPoolGrad', defined at:
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/ipykernel/__main__.py", line 3, in <module>
    app.launch_new_instance()
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/ipykernel/kernelapp.py", line 477, in start
    ioloop.IOLoop.instance().start()
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/zmq/eventloop/ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/tornado/ioloop.py", line 888, in start
    handler_func(fd_obj, events)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 235, in dispatch_shell
    handler(stream, idents, msg)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/ipykernel/ipkernel.py", line 196, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/ipykernel/zmqshell.py", line 533, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2683, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2787, in run_ast_nodes
    if self.run_code(code, result):
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2847, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-3-7003457c89b1>", line 227, in <module>
    gan = Gan(architecture=[z_dim, 28])
  File "<ipython-input-3-7003457c89b1>", line 42, in __init__
    handles = self._build_graph()
  File "<ipython-input-3-7003457c89b1>", line 161, in _build_graph
    global_step=global_step)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/training/optimizer.py", line 343, in minimize
    grad_loss=grad_loss)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/training/optimizer.py", line 414, in compute_gradients
    colocate_gradients_with_ops=colocate_gradients_with_ops)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/gradients_impl.py", line 581, in gradients
    grad_scope, op, func_call, lambda: grad_fn(op, *out_grads))
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/gradients_impl.py", line 353, in _MaybeCompile
    return grad_fn()  # Exit early
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/gradients_impl.py", line 581, in <lambda>
    grad_scope, op, func_call, lambda: grad_fn(op, *out_grads))
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/nn_grad.py", line 555, in _MaxPoolGrad
    data_format=op.get_attr("data_format"))
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/gen_nn_ops.py", line 3083, in _max_pool_grad
    data_format=data_format, name=name)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
    op_def=op_def)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

...which was originally created as op 'discriminator_1/d_1_pool/MaxPool', defined at:
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
[elided 20 identical lines from previous traceback]
  File "<ipython-input-3-7003457c89b1>", line 42, in __init__
    handles = self._build_graph()
  File "<ipython-input-3-7003457c89b1>", line 126, in _build_graph
    d_fake, d_logit_fake = self.discriminator(g_sample)
  File "<ipython-input-3-7003457c89b1>", line 72, in discriminator
    pool = tf.layers.max_pooling2d(inputs=x, pool_size=[2, 2], strides=2, name='d_1_pool')
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/layers/pooling.py", line 416, in max_pooling2d
    return layer.apply(inputs)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/layers/base.py", line 671, in apply
    return self.__call__(inputs, *args, **kwargs)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/layers/base.py", line 575, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/layers/pooling.py", line 266, in call
    data_format=utils.convert_data_format(self.data_format, 4))
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/nn_ops.py", line 1958, in max_pool
    name=name)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/gen_nn_ops.py", line 2806, in _max_pool
    data_format=data_format, name=name)
  File "/home/earneodo/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)

ResourceExhaustedError (see above for traceback): OOM when allocating tensor with shape[41472,32,24,24]
	 [[Node: gradients/discriminator_1/d_1_pool/MaxPool_grad/MaxPoolGrad = MaxPoolGrad[T=DT_FLOAT, data_format="NHWC", ksize=[1, 2, 2, 1], padding="VALID", strides=[1, 2, 2, 1], _device="/job:localhost/replica:0/task:0/device:GPU:0"](discriminator_1/d_1/Elu, discriminator_1/d_1_pool/MaxPool, gradients/discriminator_1/d_2/Conv2D_grad/tuple/control_dependency)]]
	 [[Node: Adam/update/_36 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_720_Adam/update", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]


In [None]:
(mode==tf.estimator.ModeKeys.TRAIN)

In [None]:
### Use the trained gan to generate a batch of images

In [None]:
samples = gan.generate(sample_z(16, gan.architecture[0]))
plot(samples[0])

In [None]:
D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
D_vars

In [None]:
x_real, _ = mnist.train.next_batch(gan.batch_size)
z = sample_z(gan.batch_size, gan.architecture[0])

D_solver, D_loss, i = gan.sesh.run([gan.D_solver, gan.D_loss, gan.global_step], 
                            feed_dict = {gan.x_real: x_real, gan.z: z})