In [None]:
import tensorflow as tf

#### This is a sample walk-through of a GAN (generative adversarial network) trained on MNIST hand-written digits

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from tensorflow.examples.tutorials.mnist import input_data

In [None]:
# Load MNIST from the data directory used by the CNN chapter;
# one_hot=True requests one-hot encoded labels.
mnist = input_data.read_data_sets(
    "../03-Convolutional-Neural-Networks/MNIST_data/",
    one_hot=True,
)

In [None]:
# Display one training digit; images are stored as flat 784-vectors,
# so reshape back to 28x28 before plotting.
plt.imshow(np.reshape(mnist.train.images[12], (28, 28)), cmap='Greys')

In [None]:
# generator function
def generator(z, reuse=None):
    """Map a batch of noise vectors to flattened 28x28 images.

    Parameters
    ----------
    z : tf.Tensor
        Noise input. In this notebook it is fed from a ``[None, 100]`` placeholder.
    reuse : bool or None
        Forwarded to ``tf.variable_scope('gen', ...)``. Pass ``True`` to build
        the network again on top of the already-created ``gen`` variables.

    Returns
    -------
    tf.Tensor
        Output of a 784-unit dense layer with ``tanh`` activation, so values
        lie in (-1, 1).
    """
    # The 'gen' variable scope groups the generator's weights. This lets them
    # be selected apart from the discriminator's, and shared via `reuse`.
    with tf.variable_scope('gen', reuse=reuse):
        alpha = 0.01  # leaky-ReLU slope for negative inputs

        # Fixed: the original read `inputs-z` (a subtraction involving an
        # undefined name) instead of the keyword argument `inputs=z`.
        hidden1 = tf.layers.dense(inputs=z, units=128)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)  # leaky ReLU

        hidden2 = tf.layers.dense(inputs=hidden1, units=128)
        hidden2 = tf.maximum(alpha * hidden2, hidden2)  # leaky ReLU

        # tanh keeps outputs in the same [-1, 1] range the real images are
        # rescaled to before training.
        output = tf.layers.dense(hidden2, units=784, activation=tf.nn.tanh)

        return output

In [None]:
# discriminator function
def discriminator(X, reuse=None):
    """Score a batch of flattened images as real vs. generated.

    Parameters
    ----------
    X : tf.Tensor
        Batch of images. In this notebook it is either the ``[None, 784]``
        ``real_images`` placeholder or the generator's 784-unit output.
    reuse : bool or None
        Forwarded to ``tf.variable_scope('dis', ...)``. Pass ``True`` on the
        second call so real and generated images share the same weights.

    Returns
    -------
    (tf.Tensor, tf.Tensor)
        ``(output, logits)``: the sigmoid probability and the raw pre-sigmoid
        logit, one unit per example.
    """
    with tf.variable_scope('dis', reuse=reuse):
        alpha = 0.01  # leaky-ReLU slope for negative inputs

        hidden1 = tf.layers.dense(inputs=X, units=128)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)  # leaky ReLU

        # Fixed: the second dense layer was missing, so `hidden2` was read
        # before it was ever assigned.
        hidden2 = tf.layers.dense(inputs=hidden1, units=128)
        hidden2 = tf.maximum(alpha * hidden2, hidden2)  # leaky ReLU

        # Fixed: the logits layer was declared twice. The first copy created
        # an extra, unused set of 'dis' variables.
        logits = tf.layers.dense(hidden2, units=1)
        output = tf.sigmoid(logits)

        return output, logits

In [None]:
# Placeholders: flattened 28x28 real images, and the 100-dim noise fed to the generator.
real_images = tf.placeholder(dtype=tf.float32, shape=(None, 784))
z = tf.placeholder(dtype=tf.float32, shape=(None, 100))

In [None]:
# First generator call: reuse stays at its default, so the 'gen' variables are created here.
G = generator(z=z, reuse=None)

In [None]:
# First discriminator call (on real images) creates the 'dis' variables.
D_output_real, D_logits_real = discriminator(X=real_images, reuse=None)

In [None]:
# Second discriminator call (on generated images) shares the existing 'dis' weights.
D_output_fake, D_logits_fake = discriminator(X=G, reuse=True)

In [None]:
# losses

def loss_func(logits_in, labels_in):
    """Return the mean sigmoid cross-entropy between ``logits_in`` and ``labels_in``.

    Both arguments are tensors of the same shape. The element-wise losses are
    averaged over every element to give a scalar.
    """
    elementwise_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=labels_in, logits=logits_in)
    return tf.reduce_mean(elementwise_loss)

In [None]:
# Real images are labelled 0.9 rather than 1.0 (one-sided label smoothing),
# so the discriminator is not pushed toward over-confident predictions.
D_real_loss = loss_func(D_logits_real, tf.ones_like(D_logits_real) * 0.9)

# Generated images are labelled 0. The labels are shaped from the *fake*
# logits. The original used D_logits_real, which made this loss depend on the
# real_images placeholder and on the real batch having the same size.
D_fake_loss = loss_func(D_logits_fake, tf.zeros_like(D_logits_fake))

In [None]:
# Total discriminator loss: error on generated images plus error on real images.
D_loss = D_fake_loss + D_real_loss

In [None]:
# The generator is rewarded when the discriminator labels its output as real (target 1).
G_loss = loss_func(logits_in=D_logits_fake, labels_in=tf.ones_like(D_logits_fake))

In [None]:
learning_rate = 1e-3  # shared by both Adam optimizers below

In [None]:
tvars = tf.trainable_variables()

# Split the trainable variables by the variable scope they were created in.
# Matching the 'dis/' / 'gen/' scope prefix, rather than any substring of the
# name, avoids catching an unrelated variable whose name merely contains
# "dis" or "gen".
d_vars = [var for var in tvars if var.name.startswith('dis/')]
g_vars = [var for var in tvars if var.name.startswith('gen/')]

In [None]:
# Adam step that updates only the discriminator's variables.
D_trainer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
    loss=D_loss, var_list=d_vars)

In [None]:
# Adam step that updates only the generator's variables.
G_trainer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
    loss=G_loss, var_list=g_vars)

In [None]:
# Inspect which variables the discriminator optimizer will update.
d_vars

In [None]:
batch_size = 100  # real images (and noise vectors) per training step

In [None]:
epochs = 30  # full passes over mnist.train; one generated sample is kept per epoch

In [None]:
# Initializer for every global variable that exists at this point. It is
# created after both optimizers are built, so any variables they added are included.
init = tf.global_variables_initializer()

In [None]:
samples = list()  # the training loop appends one generated image per epoch

In [None]:
# Train the GAN. For each minibatch, take one discriminator step on real and
# generated images, then one generator step on the same noise. After every
# epoch, keep one generated sample for later inspection.
num_batches = mnist.train.num_examples // batch_size  # loop-invariant, hoisted
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(epochs):
        for _ in range(num_batches):
            # Fixed typo: was `mnist.trian`, an AttributeError.
            batch = mnist.train.next_batch(batch_size)
            batch_images = batch[0].reshape((batch_size, 784))
            # Map pixels to [-1, 1] to match the generator's tanh output
            # (assumes input_data's default [0, 1] float scaling).
            batch_images = batch_images * 2 - 1
            # noise input for the generator
            batch_z = np.random.uniform(-1, 1, size=(batch_size, 100))
            _ = sess.run(D_trainer, feed_dict={real_images: batch_images, z: batch_z})
            _ = sess.run(G_trainer, feed_dict={z: batch_z})

        print("ON EPOCH {}".format(epoch))

        sample_z = np.random.uniform(-1, 1, size=(1, 100))
        # Evaluate the existing generator tensor `G`. Calling
        # generator(z, reuse=True) here added a new copy of the network to the
        # graph every epoch.
        gen_sample = sess.run(G, feed_dict={z: sample_z})

        # Fixed: was `samples_z.append`, a NameError.
        samples.append(gen_sample)
        

In [None]:
# Generated sample after the first epoch; the grey colormap matches the real digit shown earlier.
plt.imshow(samples[0].reshape(28, 28), cmap='Greys')

In [None]:
# Generated sample after the final epoch. samples[-1] stays correct if
# `epochs` changes; the hard-coded samples[29] assumed exactly 30 epochs.
plt.imshow(samples[-1].reshape(28, 28), cmap='Greys')

In [None]:
# Saver restricted to the generator's variables (the only ones needed for sampling).
saver = tf.train.Saver(g_vars)

In [None]:
# Restore trained generator weights and draw five new samples.
new_samples = []
with tf.Session() as sess:
    # NOTE(review): this checkpoint is not written anywhere in this notebook;
    # it must already exist on disk (e.g. from a separate, longer training run).
    saver.restore(sess, './models/500_epoch_models.ckpt')

    for _ in range(5):
        sample_z = np.random.uniform(-1, 1, size=(1, 100))
        # Reuse the generator tensor `G` instead of adding a fresh copy of the
        # network to the graph on every iteration.
        gen_sample = sess.run(G, feed_dict={z: sample_z})

        new_samples.append(gen_sample)

In [None]:
# Show one of the samples drawn from the restored generator (fixed the
# `new_samples[]` syntax error by selecting the first sample).
plt.imshow(new_samples[0].reshape(28, 28), cmap='Greys')