In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os

In [None]:
class GAN():
    """Training graph for a GAN assembled from a generator and a discriminator.

    Constructing an instance creates the two input placeholders and then
    wires up the losses and one Adam update op per network, using the
    TensorFlow 1.x graph-mode API (``tf.placeholder``, ``tf.train``).

    Parameters
    ----------
    generator_model : object
        Must provide ``generator(Z)`` (called on the noise placeholder) and a
        ``theta_list`` attribute, which is passed as ``var_list`` to the
        generator's optimizer. Although the default is ``None``, a real
        object is required: graph construction calls into it immediately.
    discriminator_model : object
        Must provide ``discriminator(x)`` (called once on the data
        placeholder and once on the generator output; its result is used as
        logits) and a ``theta_list`` attribute, passed as ``var_list`` to the
        discriminator's optimizer. Required for the same reason as above.
    training_rate : float, optional
        Learning rate given to both Adam optimizers. Defaults to ``1e-4``.
    x_dim : int, optional
        Width of the data placeholder ``X``. Defaults to ``784``.
    z_dim : int, optional
        Width of the noise placeholder ``Z``. Defaults to ``100``.

    Attributes
    ----------
    X, Z
        Placeholders of shape ``[None, x_dim]`` and ``[None, z_dim]``.
    G_sample
        Output of ``generator_model.generator(Z)``.
    D_solver, G_solver
        Adam ``minimize`` ops for the discriminator and generator losses.
    D_loss, G_loss
        The scalar losses those ops minimize.
    """

    def __init__(self, generator_model=None, discriminator_model=None,
                 training_rate=1e-4, x_dim=784, z_dim=100):
        self.generator_model = generator_model
        self.discriminator_model = discriminator_model
        self.training_rate = training_rate
        self.x_dim = x_dim
        self.z_dim = z_dim

        self._create_placeholder()
        self._create_train_graph()

    def sample_Z(self, m, n):
        """Return an ``(m, n)`` array of noise drawn uniformly from [-1, 1)."""
        return np.random.uniform(-1., 1., size=[m, n])

    def _create_placeholder(self):
        """Create the data (``X``) and noise (``Z``) input placeholders."""
        self.X = tf.placeholder(tf.float32, shape=[None, self.x_dim])
        self.Z = tf.placeholder(tf.float32, shape=[None, self.z_dim])

    def _create_train_graph(self):
        """Build the losses and the per-network Adam update ops."""
        G_sample = self.generator_model.generator(self.Z)
        D_logit_real = self.discriminator_model.discriminator(self.X)
        D_logit_fake = self.discriminator_model.discriminator(G_sample)

        # Discriminator: label real inputs 1 and generated inputs 0.
        D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
        D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
        D_loss = D_loss_real + D_loss_fake

        # Generator: same logits on generated inputs, but with target label 1.
        G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))

        # Honor the configured learning rate. Previously the optimizers were
        # created with no arguments, so ``training_rate`` was ignored and
        # Adam's own default rate was used instead.
        # ``var_list`` restricts each update to that network's own variables.
        D_solver = tf.train.AdamOptimizer(
            learning_rate=self.training_rate
        ).minimize(D_loss, var_list=self.discriminator_model.theta_list)
        G_solver = tf.train.AdamOptimizer(
            learning_rate=self.training_rate
        ).minimize(G_loss, var_list=self.generator_model.theta_list)

        self.G_sample = G_sample
        self.D_solver = D_solver
        self.G_solver = G_solver
        self.D_loss = D_loss
        self.G_loss = G_loss

In [None]:
# Load MNIST through the TF contrib dataset helper, using 'mnist/data' as the
# data directory. Labels are requested one-hot, but the training loop below
# discards them and uses only the images.
# NOTE(review): tf.contrib is a TensorFlow 1.x namespace — confirm the
# installed TensorFlow version still provides it.
mnist = tf.contrib.learn.datasets.mnist.read_data_sets(train_dir='mnist/data', one_hot=True)
# Number of examples (and noise vectors) per training step.
mb_size = 128
# Width of each noise vector passed to gan.sample_Z.
Z_dim = 100

# NOTE(review): GAN() is constructed without a generator or discriminator.
# GAN.__init__ builds the graph immediately and calls methods on both model
# objects, so with the None defaults this line raises AttributeError.
# TODO: pass the generator/discriminator model instances here.
gan = GAN()


def plot(samples):
    """Draw flattened 28x28 images on a square grid and return the figure.

    Parameters
    ----------
    samples : iterable of array-like
        Each item must support ``.reshape(28, 28)``.

    Returns
    -------
    The figure created by ``plt.figure``. The caller is responsible for
    showing and closing it.

    Notes
    -----
    Up to 16 samples are laid out on a 4x4 grid in a 4x4-inch figure. With
    more than 16 samples the grid (and figure size) grows to the smallest
    square that fits them; a fixed 4x4 grid would raise ``IndexError`` on
    the 17th sample.
    """
    samples = list(samples)

    # Smallest square grid, never below 4x4, that holds every sample.
    side = 4
    while side * side < len(samples):
        side += 1

    fig = plt.figure(figsize=(side, side))
    gs = gridspec.GridSpec(side, side)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        # Work on explicit axes rather than pyplot's implicit "current axes".
        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig

# Training driver: each iteration performs one discriminator update followed
# by one generator update. Every 1000th iteration additionally shows a figure
# of 16 generated samples (before that iteration's updates) and prints both
# losses (after them).
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for it in range(1000000):
        is_report_step = (it % 1000 == 0)

        if is_report_step:
            preview_feed = {gan.Z: gan.sample_Z(16, Z_dim)}
            samples = sess.run(gan.G_sample, feed_dict=preview_feed)
            fig = plot(samples)
            plt.show()
            # Release the figure so repeated previews do not accumulate.
            plt.close(fig)

        # The batch also carries labels; only the images are used.
        X_mb, _ = mnist.train.next_batch(mb_size)

        # Discriminator step: real minibatch plus a fresh noise batch.
        d_feed = {gan.X: X_mb, gan.Z: gan.sample_Z(mb_size, Z_dim)}
        _, D_loss_curr = sess.run([gan.D_solver, gan.D_loss], feed_dict=d_feed)

        # Generator step: a second, independently drawn noise batch.
        g_feed = {gan.Z: gan.sample_Z(mb_size, Z_dim)}
        _, G_loss_curr = sess.run([gan.G_solver, gan.G_loss], feed_dict=g_feed)

        if is_report_step:
            print(
                'Iter: {}'.format(it),
                'D loss: {:.4}'.format(D_loss_curr),
                'G_loss: {:.4}'.format(G_loss_curr),
                sep='\n',
            )
            print()