# CONDITIONAL GAN

- imports

In [1]:
import tensorflow as tf
from tensorflow.contrib.layers import xavier_initializer as xav

In [2]:
import numpy as np

## Graph

In [9]:
# Start from an empty default graph so that re-running the cells below
# does not build on top of ops/variables left over from an earlier run.
tf.reset_default_graph()

- Placeholders
    - x\_ : batch of images
    - y\_ : labels of x\_ as one-hot vectors
    - z\_ : noise (a 100-d vector)

In [11]:
# Graph inputs; the leading None leaves the batch dimension unspecified.
# x_: 784 values per row (image batch)
x_ = tf.placeholder(tf.float32, shape=[None, 784], name='x')
# y_: 10 values per row (one-hot label)
y_ = tf.placeholder(tf.float32, shape=[None, 10], name='y')
# z_: 100 values per row (noise)
z_ = tf.placeholder(tf.float32, shape=[None, 100], name='z')

- discriminator params

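We concatenate (x, y) as the discriminator input and (z, y) as the generator input, so the first-layer weight matrices must accommodate the extra 10 label dimensions: 784 + 10 = 794 rows for the discriminator and 100 + 10 = 110 rows for the generator.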

In [12]:
# Discriminator weights: 794 -> 128 -> 1.
# 794 = 784 (x_) + 10 (y_), because D concatenates image and label.
# Weights use the Xavier initializer, biases start at zero.
dw1 = tf.get_variable('dw1', [794, 128], tf.float32, initializer=xav())
dw2 = tf.get_variable('dw2', [128, 1], tf.float32, initializer=xav())
db1 = tf.get_variable('db1', [128], tf.float32, initializer=tf.constant_initializer(0.))
db2 = tf.get_variable('db2', [1], tf.float32, initializer=tf.constant_initializer(0.))

- generator params

In [13]:
# Generator weights: 110 -> 128 -> 784.
# 110 = 100 (z_) + 10 (y_), because G concatenates noise and label.
# Weights use the Xavier initializer, biases start at zero.
gw1 = tf.get_variable('gw1', [110, 128], tf.float32, initializer=xav())
gw2 = tf.get_variable('gw2', [128, 784], tf.float32, initializer=xav())
gb1 = tf.get_variable('gb1', [128], tf.float32, initializer=tf.constant_initializer(0.))
gb2 = tf.get_variable('gb2', [784], tf.float32, initializer=tf.constant_initializer(0.))

- group params

In [15]:
# Per-network variable lists; passed as var_list to the optimizers so each
# training op only updates its own network's parameters.
theta_d, theta_g = [dw1, dw2, db1, db2], [gw1, gw2, gb1, gb2]

- build discriminator net

In [16]:
def D(x, y):
    """Discriminator: score an image batch conditioned on its labels.

    The image `x` and label `y` are joined along axis 1, passed through one
    ReLU hidden layer (dw1/db1) and a linear output layer (dw2/db2).

    Returns:
        (prob, logit): the sigmoid of the output and the raw output itself.
    """
    joint_input = tf.concat([x, y], axis=1)
    hidden = tf.nn.relu(tf.matmul(joint_input, dw1) + db1)
    logit = tf.matmul(hidden, dw2) + db2

    # the logit is returned as well because the losses are computed from logits
    return tf.nn.sigmoid(logit), logit

- build generator net

In [25]:
def G(z, y):
    """Generator: map noise `z` and label `y` to an image batch.

    Noise and label are joined along axis 1, passed through one ReLU hidden
    layer (gw1/gb1) and a linear output layer (gw2/gb2), then squashed with a
    sigmoid.

    Returns:
        The sigmoid-activated output, 784 values per row.
    """
    joint_input = tf.concat([z, y], axis=1)
    hidden = tf.nn.relu(tf.matmul(joint_input, gw1) + gb1)
    # pre-activation output of the last layer
    logit = tf.matmul(hidden, gw2) + gb2

    return tf.nn.sigmoid(logit)

- sample from generator

In [26]:
# Generated images for noise z_, conditioned on the labels y_.
g_sample = G(z_, y_)

- run D on real images

In [28]:
# Discriminator output on the real image batch x_ with its labels y_.
real_prob, real_logit = D(x_, y_)

- run D on fake images (sampled from G)

In [30]:
# Discriminator output on the generated batch, paired with the same labels y_
# that conditioned the generator. D reuses the same dw*/db* variables here.
fake_prob, fake_logit = D(g_sample, y_)

### Optimization

- real loss
    - probability ground truth = 1
    - predicted probability = real_prob
    - loss : cross entropy between ground truth and prediction

In [31]:
# Mean sigmoid cross-entropy of D's logits on real images against all-ones
# targets (computed from the logit, not from real_prob).
d_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logit, 
                                            labels=tf.ones_like(real_logit)))

- fake loss
    - ground truth = 0
    - prediction = fake_prob

In [32]:
# Mean sigmoid cross-entropy of D's logits on generated images against
# all-zeros targets (computed from the logit, not from fake_prob).
d_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logit,
                                            labels=tf.zeros_like(fake_logit)))

- discriminator loss, d_loss = real loss + fake loss

In [33]:
# Total discriminator objective: real-image term plus generated-image term.
d_loss = d_loss_real + d_loss_fake

- generator loss
    - ground truth = 1 (generated image should be realistic)
    - prediction : discriminator's prediction on how real it is 

In [34]:
# Generator objective: the same fake logits as d_loss_fake, but scored against
# all-ones targets, so the loss falls as D rates generated images as real.
g_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logit, 
                                            labels=tf.ones_like(fake_logit)))

- training operations

In [35]:
# One Adam optimizer (default hyper-parameters) per network. var_list restricts
# each update to that network's own variables, so the D step leaves G untouched
# and vice versa.
d_train_op = tf.train.AdamOptimizer().minimize(d_loss, var_list=theta_d)
g_train_op = tf.train.AdamOptimizer().minimize(g_loss, var_list=theta_g)

### MNIST

In [43]:
# Load MNIST from the data/ directory. one_hot=True requests one-hot labels,
# which is the form fed to the y_ placeholder during training.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('data/', one_hot=True)

Extracting data/train-images-idx3-ubyte.gz
Extracting data/train-labels-idx1-ubyte.gz
Extracting data/t10k-images-idx3-ubyte.gz
Extracting data/t10k-labels-idx1-ubyte.gz


## Training

- sample noise

In [37]:
def sample_z(n):
    """Draw `n` noise vectors of length 100, uniformly from [-1, 1)."""
    return np.random.uniform(low=-1., high=1., size=(n, 100))

- sample y

In [53]:
def sample_y(n, num_samples=16):
    """Build a batch of identical one-hot label rows for class `n`.

    Args:
        n: index of the label column to set to 1 (the table has 10 columns).
        num_samples: number of rows to produce. Defaults to 16, the value that
            was previously hard-coded, so existing `sample_y(n)` calls are
            unaffected.

    Returns:
        A float array of shape [num_samples, 10] with column `n` set to 1 and
        every other entry 0.
    """
    # local name chosen so it does not shadow the notebook-level `sample_y_`
    labels = np.zeros(shape=[num_samples, 10])
    labels[:, n] = 1
    return labels

- plot function

In [38]:
from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec

In [39]:
def plot(samples):
    """Draw `samples` on a 4x4 grid of subplots and return the figure.

    Each sample is reshaped to 28x28 and shown with the 'Greys_r' colormap,
    with axes and tick labels hidden. The grid has 16 cells, one per sample.
    """
    canvas = plt.figure(figsize=(4, 4))
    grid = gridspec.GridSpec(4, 4)
    grid.update(wspace=0.05, hspace=0.05)

    for cell, flat_image in enumerate(samples):
        axes = plt.subplot(grid[cell])
        # hide the frame and ticks of the subplot that was just made current
        plt.axis('off')
        axes.set_xticklabels([])
        axes.set_yticklabels([])
        axes.set_aspect('equal')

        pixels = flat_image.reshape(28, 28)
        plt.imshow(pixels, cmap='Greys_r')

    return canvas

In [57]:
def save_samples(samples, i, n):
    """Render `samples` with `plot` and write the figure to disk.

    The image is written to 'out/cgan/<n>/<i>.png', with `i` zero-padded to at
    least three digits. The target directory is created if it does not exist.

    Args:
        samples: images accepted by `plot`.
        i: training step, used for the file name.
        n: conditioning class, used as the sub-directory name.
    """
    import os  # local import: only this helper touches the filesystem layout

    path = 'out/cgan/{}/{}.png'.format(n, str(i).zfill(3))
    out_dir = os.path.dirname(path)
    # savefig does not create missing directories, so make sure it exists
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    fig = plot(samples)
    try:
        fig.savefig(path, bbox_inches='tight')
    finally:
        # release the figure even when saving fails, so figures do not pile up
        # over a long training run
        plt.close(fig)

- create session

In [67]:
# Open an interactive session and initialise every variable defined above.
# Re-running this cell re-initialises the weights, which discards any training
# done so far in the current kernel.
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

- begin training loop

In [69]:
# Step counter runs from 80000 up to (not including) 100000, i.e. 20000 updates.
# NOTE(review): the non-zero start suggests this cell continues an earlier run
# in the same kernel — confirm; after a fresh session init the weights are
# untrained regardless of the step number shown.
iterations = 80000 + 20000
batch_size = 128
# digit class used for the periodic preview images (one-hot column n)
n = 2
# fixed label batch for the previews: 16 rows, all class n
sample_y_ = sample_y(n)
for i in range(80000, iterations):
    if i%1000 == 0:
        # every 1000 steps: generate 16 images of class n from fresh noise,
        # before this step's parameter updates are applied
        samples = sess.run(g_sample, feed_dict= {
            z_ : sample_z(16),
            y_ : sample_y_
        })
        # save to file
        save_samples(samples, i, n)
    
    images, labels = mnist.train.next_batch(batch_size)
    
    # discriminator update: real images with their labels, plus generated
    # images conditioned on the same labels
    _, d_loss_value = sess.run([d_train_op, d_loss], feed_dict={
        x_ : images,
        y_ : labels,
        z_ : sample_z(batch_size)
    })
    
    # generator update: a new noise batch, conditioned on the same labels
    _, g_loss_value = sess.run([g_train_op, g_loss], feed_dict={
        z_ : sample_z(batch_size),
        y_ : labels
    })
    
    if i%1000 == 0:
        # log the losses of this step's single mini-batch
        print('[{}] D : [{}], G : [{}]'.
             format(i, d_loss_value, g_loss_value))

[80000] D : [0.8616932034492493], G : [1.9244520664215088]
[81000] D : [0.7964981198310852], G : [1.8869316577911377]
[82000] D : [0.6386680603027344], G : [2.070434093475342]
[83000] D : [0.7320753931999207], G : [1.9634184837341309]
[84000] D : [0.7591575384140015], G : [1.920872449874878]
[85000] D : [0.8043198585510254], G : [1.8156242370605469]
[86000] D : [0.7657531499862671], G : [1.8715001344680786]
[87000] D : [0.6911740899085999], G : [2.3547914028167725]
[88000] D : [0.6986382007598877], G : [2.0119125843048096]
[89000] D : [0.7594764232635498], G : [1.798782467842102]
[90000] D : [0.7464478015899658], G : [1.9600447416305542]
[91000] D : [0.734268069267273], G : [1.782613754272461]
[92000] D : [0.6961671710014343], G : [1.7710468769073486]
[93000] D : [0.7325962781906128], G : [1.960585355758667]
[94000] D : [0.7399799823760986], G : [2.3859081268310547]
[95000] D : [0.6773284673690796], G : [1.602203607559204]
[96000] D : [0.7625189423561096], G : [1.5817159414291382]
[970

- theme setup

In [70]:
from IPython.core.display import HTML
def css():
    """Read 'css/custom.css' and wrap its contents for notebook display.

    Returns:
        An IPython `HTML` object built from the file's text.

    Raises:
        IOError/OSError: if the stylesheet cannot be opened.
    """
    # the context manager closes the handle, which the bare open().read()
    # never did explicitly
    with open("css/custom.css", "r") as stylesheet:
        style = stylesheet.read()
    return HTML(style)
# Left as the cell's last expression so the notebook displays the returned HTML object.
css()