# Simple GAN

- theme setup

In [25]:
from IPython.core.display import HTML
def css():
    style = open("css/custom.css", "r").read()
    return HTML(style)
css()

- imports

In [1]:
import tensorflow as tf
from tensorflow.contrib.layers import xavier_initializer as xav

In [2]:
import numpy as np

## Graph

In [3]:
# clear graph
tf.reset_default_graph()

- Placeholders

In [4]:
x_ = tf.placeholder(tf.float32, shape=[None, 784], name='x')
z_ = tf.placeholder(tf.float32, shape=[None, 100], name='z')

- discriminator params

In [5]:
dw1 = tf.get_variable('dw1', [784, 128], tf.float32, initializer=xav())
dw2 = tf.get_variable('dw2', [128, 1], tf.float32, initializer=xav())
db1 = tf.get_variable('db1', [128], tf.float32, initializer=tf.constant_initializer(0.))
db2 = tf.get_variable('db2', [1], tf.float32, initializer=tf.constant_initializer(0.))

- generator params

In [6]:
gw1 = tf.get_variable('gw1', [100, 128], tf.float32, initializer=xav())
gw2 = tf.get_variable('gw2', [128, 784], tf.float32, initializer=xav())
gb1 = tf.get_variable('gb1', [128], tf.float32, initializer=tf.constant_initializer(0.))
gb2 = tf.get_variable('gb2', [784], tf.float32, initializer=tf.constant_initializer(0.))

- group params

In [7]:
theta_d, theta_g = [dw1, dw2, db1, db2], [gw1, gw2, gb1, gb2]

- build disciminator net

In [8]:
def D(x):
    dh1 = tf.nn.relu(tf.matmul(x, dw1) + db1)
    logit = tf.matmul(dh1, dw2) + db2
    prob = tf.nn.sigmoid(logit)
    
    return prob, logit

- build generator net

In [9]:
def G(z):
    gh1 = tf.nn.relu(tf.matmul(z, gw1) + gb1)
    log_prob = tf.matmul(gh1, gw2) + gb2
    prob = tf.nn.sigmoid(log_prob)
    
    return prob

- sample from generator

In [10]:
g_sample = G(z_)

- run D on real images

In [11]:
real_prob, real_logit = D(x_)

- run D on fake images (sampled from G)

In [12]:
fake_prob, fake_logit = D(g_sample)

### Optimization

- real loss
    - probability ground truth = 1
    - predicted probability = real_prob
    - loss : cross entropy between ground truth and prediction

In [13]:
d_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logit, 
                                            labels=tf.ones_like(real_logit)))

- fake loss
    - ground truth = 0
    - prediction = fake_prob

In [14]:
d_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logit,
                                            labels=tf.zeros_like(fake_logit)))

- discriminator loss, d_loss = real loss + fake loss

In [15]:
d_loss = d_loss_real + d_loss_fake

- generator loss
    - ground truth = 1 (generated image should be realistic)
    - prediction : discriminator's prediction on how real it is 

In [16]:
g_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logit, 
                                            labels=tf.ones_like(fake_logit)))

- training operations

In [17]:
d_train_op = tf.train.AdamOptimizer().minimize(d_loss, var_list=theta_d)
g_train_op = tf.train.AdamOptimizer().minimize(g_loss, var_list=theta_g)

### MNIST

In [18]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('data/')

Extracting data/train-images-idx3-ubyte.gz
Extracting data/train-labels-idx1-ubyte.gz
Extracting data/t10k-images-idx3-ubyte.gz
Extracting data/t10k-labels-idx1-ubyte.gz


## Training

- sample noise

In [19]:
sample_z = lambda n : np.random.uniform(-1., 1., size=[n, 100])

- plot function

In [20]:
from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec

In [21]:
def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig

In [22]:
def save_samples(samples, i):
    fig = plot(samples)
    plt.savefig('out/simple-gan/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
    i += 1
    plt.close(fig)

- create session

In [23]:
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

- begin training loop

In [24]:
iterations = 1000000
batch_size = 128
for i in range(iterations):
    if i%1000 == 0:
        samples = sess.run(g_sample, feed_dict= {
            z_ : sample_z(16)
        })
        # save to file
        save_samples(samples, i)
    
    images, _ = mnist.train.next_batch(batch_size)
    
    _, d_loss_value = sess.run([d_train_op, d_loss], feed_dict={
        x_ : images,
        z_ : sample_z(batch_size)
    })
    
    _, g_loss_value = sess.run([g_train_op, g_loss], feed_dict={
        z_ : sample_z(batch_size)
    })
    
    if i%1000 == 0:
        print('at {} iteration, d_loss : [{}], g_loss : [{}]'.
             format(i, d_loss_value, g_loss_value))

at 0 iteration, d_loss : [1.5482125282287598], g_loss : [2.984132766723633]
at 1000 iteration, d_loss : [0.00873700249940157], g_loss : [8.086901664733887]
at 2000 iteration, d_loss : [0.04252957925200462], g_loss : [5.5967302322387695]
at 3000 iteration, d_loss : [0.05793968588113785], g_loss : [5.736297607421875]
at 4000 iteration, d_loss : [0.12238318473100662], g_loss : [5.387004852294922]
at 5000 iteration, d_loss : [0.22798436880111694], g_loss : [4.826971530914307]
at 6000 iteration, d_loss : [0.49156278371810913], g_loss : [4.45762825012207]
at 7000 iteration, d_loss : [0.5076054930686951], g_loss : [3.598268508911133]
at 8000 iteration, d_loss : [0.5653466582298279], g_loss : [3.3538084030151367]
at 9000 iteration, d_loss : [0.450133353471756], g_loss : [3.3437085151672363]
at 10000 iteration, d_loss : [0.4747970700263977], g_loss : [3.0847530364990234]
at 11000 iteration, d_loss : [0.6612855195999146], g_loss : [2.967733383178711]
at 12000 iteration, d_loss : [0.5793889760971

KeyboardInterrupt: 