# Generative Adversarial Networks (GANs)

This notebook attempts to recreate GAN examples from several blogs found online. Note: all comments marked with '###' are my own; the rest come from the code graciously borrowed from the tutorials.

## I. Augustinus Kristiadi's GAN

This GAN was written by Augustinus Kristiadi and demonstrated at https://wiseodd.github.io/techblog/2016/09/17/gan-tensorflow/.

The exact implementation can be found at https://github.com/wiseodd/generative-models/blob/master/GAN/vanilla_gan/gan_tensorflow.py

In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import time
import os


def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)


def sample_Z(m, n):
    return np.random.uniform(-1., 1., size=[m, n])


def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig
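
Despite its name, `xavier_init` scales the weights by the fan-in alone: 1/sqrt(in_dim/2) equals sqrt(2/in_dim), which is the standard deviation usually associated with He initialization (designed for ReLU layers like the ones used here). Glorot/Xavier normal initialization instead uses sqrt(2/(fan_in + fan_out)). The NumPy sketch below compares the three scales for the layer shapes in this notebook; the helper names are illustrative and not part of the tutorial.

```python
import numpy as np

def tutorial_std(size):
    # Scale used by xavier_init above: 1 / sqrt(in_dim / 2)
    return 1.0 / np.sqrt(size[0] / 2.0)

def he_std(size):
    # He-style normal init: sqrt(2 / fan_in)
    return np.sqrt(2.0 / size[0])

def glorot_std(size):
    # Glorot/Xavier normal init: sqrt(2 / (fan_in + fan_out))
    fan_in, fan_out = size
    return np.sqrt(2.0 / (fan_in + fan_out))

# Weight shapes of D_W1, D_W2, G_W1, G_W2
for shape in ([784, 128], [128, 1], [100, 128], [128, 784]):
    print(shape,
          round(tutorial_std(shape), 4),
          round(he_std(shape), 4),
          round(glorot_std(shape), 4))
```

For every shape the first two columns agree, which shows the function matches the fan-in-only (He) scale rather than the Glorot one.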

In [7]:
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

Extracting ../../MNIST_data/train-images-idx3-ubyte.gz
Extracting ../../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../MNIST_data/t10k-labels-idx1-ubyte.gz
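
`tensorflow.examples.tutorials.mnist` shipped only with TensorFlow 1.x and is not available in TensorFlow 2. Where it is missing, the training loop only needs something with a `next_batch` method that returns flattened 784-pixel images scaled to [0, 1] plus one-hot labels. The sketch below is a hypothetical NumPy stand-in (`MiniBatcher` is not a real library class); it assumes the raw uint8 arrays of shape (N, 28, 28) are loaded elsewhere, for example with `tf.keras.datasets.mnist.load_data()`.

```python
import numpy as np

class MiniBatcher:
    """Hypothetical stand-in for mnist.train.next_batch from TF 1.x input_data."""

    def __init__(self, images, labels, num_classes=10, seed=0):
        # Flatten (N, 28, 28) uint8 images to (N, 784) float32 values in [0, 1]
        self.images = images.reshape(len(images), -1).astype(np.float32) / 255.0
        # One-hot encode integer labels, mirroring one_hot=True
        self.labels = np.eye(num_classes, dtype=np.float32)[labels]
        self.rng = np.random.default_rng(seed)
        self._order = self.rng.permutation(len(self.images))
        self._pos = 0

    def next_batch(self, batch_size):
        # Start a fresh shuffled pass when too few examples remain
        # (any leftover partial batch is simply dropped)
        if self._pos + batch_size > len(self.images):
            self._order = self.rng.permutation(len(self.images))
            self._pos = 0
        idx = self._order[self._pos:self._pos + batch_size]
        self._pos += batch_size
        return self.images[idx], self.labels[idx]
```

Under that assumption, `(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()` followed by `train = MiniBatcher(x_train, y_train)` would let the loop call `train.next_batch(mb_size)` in place of `mnist.train.next_batch(mb_size)`. Keras returns all 60,000 training images, whereas `read_data_sets` holds 5,000 of them out for validation.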


In [3]:
# Discriminator Net
X = tf.placeholder(tf.float32, shape=[None, 784])

D_W1 = tf.Variable(xavier_init([784, 128]))
D_b1 = tf.Variable(tf.zeros(shape=[128]))

D_W2 = tf.Variable(xavier_init([128, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))

theta_D = [D_W1, D_W2, D_b1, D_b2]

In [4]:
# Generator Net
Z = tf.placeholder(tf.float32, shape=[None, 100])

G_W1 = tf.Variable(xavier_init([100, 128]))
G_b1 = tf.Variable(tf.zeros(shape=[128]))

G_W2 = tf.Variable(xavier_init([128, 784]))
G_b2 = tf.Variable(tf.zeros(shape=[784]))

theta_G = [G_W1, G_W2, G_b1, G_b2]
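
Both networks are single-hidden-layer perceptrons with 128 hidden units, so their sizes follow directly from the variable shapes in the two cells above. The helper `count_params` below is illustrative only; it multiplies out each weight and bias shape and sums the results.

```python
def count_params(shapes):
    # Total number of scalars across all weight and bias tensors
    total = 0
    for shape in shapes:
        n = 1
        for dim in shape:
            n *= dim
        total += n
    return total

# Shapes of D_W1, D_b1, D_W2, D_b2 and of G_W1, G_b1, G_W2, G_b2
D_shapes = [(784, 128), (128,), (128, 1), (1,)]
G_shapes = [(100, 128), (128,), (128, 784), (784,)]

print(count_params(D_shapes))  # 100609
print(count_params(G_shapes))  # 114064
```

Most of each network's parameters sit in the single 784 × 128 (or 128 × 784) weight matrix that touches the image space.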

In [5]:
def generator(z):
    G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
    G_prob = tf.nn.sigmoid(G_log_prob)

    return G_prob


def discriminator(x):
    D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    D_logit = tf.matmul(D_h1, D_W2) + D_b2
    D_prob = tf.nn.sigmoid(D_logit)

    return D_prob, D_logit
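
To make the tensor shapes explicit, here is a NumPy re-implementation of the same two forward passes using freshly initialized (untrained) weights. A batch of 100-dimensional noise vectors becomes a batch of 784-pixel images in (0, 1), and the discriminator maps each image to a single logit and its sigmoid probability. The names `np_generator` and `np_discriminator` are hypothetical and exist only for this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)

def init(shape):
    # Same scale as xavier_init above: std = 1 / sqrt(fan_in / 2)
    return rng.normal(0.0, 1.0 / np.sqrt(shape[0] / 2.0), size=shape)

G_params = (init((100, 128)), np.zeros(128), init((128, 784)), np.zeros(784))
D_params = (init((784, 128)), np.zeros(128), init((128, 1)), np.zeros(1))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def np_generator(z, W1, b1, W2, b2):
    h1 = np.maximum(z @ W1 + b1, 0.0)  # ReLU hidden layer, shape (m, 128)
    return sigmoid(h1 @ W2 + b2)       # pixel intensities, shape (m, 784)

def np_discriminator(x, W1, b1, W2, b2):
    h1 = np.maximum(x @ W1 + b1, 0.0)  # ReLU hidden layer, shape (m, 128)
    logit = h1 @ W2 + b2               # unnormalized score, shape (m, 1)
    return sigmoid(logit), logit

z = rng.uniform(-1.0, 1.0, size=(16, 100))  # like sample_Z(16, 100)
fake = np_generator(z, *G_params)
d_prob, d_logit = np_discriminator(fake, *D_params)
print(fake.shape, d_prob.shape)
```

Returning the logit alongside the probability is what allows the losses in the next cell to be computed directly from logits.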

In [6]:
G_sample = generator(Z)
D_real, D_logit_real = discriminator(X)
D_fake, D_logit_fake = discriminator(G_sample)

# D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
# G_loss = -tf.reduce_mean(tf.log(D_fake))

# Alternative losses:
# -------------------
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))

D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)

mb_size = 128
Z_dim = 100
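
The active losses are the commented-out ones rewritten in terms of logits. For a logit l and label y, `tf.nn.sigmoid_cross_entropy_with_logits` computes -[y log sigmoid(l) + (1 - y) log(1 - sigmoid(l))], which its documentation evaluates in the stable form max(l, 0) - l*y + log(1 + exp(-|l|)). With y = 1 on real logits and y = 0 on fake logits, and both batches of size `mb_size`, `D_loss` equals -mean(log D(x) + log(1 - D(G(z)))); with y = 1 on fake logits, `G_loss` equals -mean(log D(G(z))), the non-saturating generator loss. The NumPy check below uses random stand-in logits and illustrative function names; it also shows why the logit form is safer when the sigmoid saturates.

```python
import numpy as np

def stable_bce_with_logits(logits, labels):
    # Stable form documented for tf.nn.sigmoid_cross_entropy_with_logits
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

def naive_bce(logits, labels):
    # Direct formula: -[y log(sigmoid) + (1 - y) log(1 - sigmoid)]
    p = 1.0 / (1.0 + np.exp(-logits))
    return -(labels * np.log(p) + (1 - labels) * np.log(1 - p))

rng = np.random.default_rng(0)
logit_real = rng.normal(size=128)  # stand-in for D_logit_real
logit_fake = rng.normal(size=128)  # stand-in for D_logit_fake
D_real = 1 / (1 + np.exp(-logit_real))
D_fake = 1 / (1 + np.exp(-logit_fake))

# Logit-based losses, as in the cell above
D_loss = (stable_bce_with_logits(logit_real, np.ones(128)).mean()
          + stable_bce_with_logits(logit_fake, np.zeros(128)).mean())
G_loss = stable_bce_with_logits(logit_fake, np.ones(128)).mean()

# Probability-based losses, as in the commented-out lines
D_loss_direct = -np.mean(np.log(D_real) + np.log(1 - D_fake))
G_loss_direct = -np.mean(np.log(D_fake))

print(np.isclose(D_loss, D_loss_direct), np.isclose(G_loss, G_loss_direct))

# A strongly negative logit: the stable form stays finite,
# while log(sigmoid(-800)) would be log(0) = -inf
print(stable_bce_with_logits(np.array([-800.0]), np.array([1.0])))
```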

In [9]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())

if not os.path.exists('_output/'):
    os.makedirs('_output/')

i = 0
start = time.time()

for it in range(1000000):
    if it % 1000 == 0:
        samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})

        fig = plot(samples)
        plt.savefig('_output/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        plt.close(fig)

    X_mb, _ = mnist.train.next_batch(mb_size)

    _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})

    if it % 1000 == 0:
        print('Iter: {}'.format(it))
        print('Elapsed Time: {:.2f} mins'.format((time.time() - start) / 60))
        print('D loss: {:.4}'.format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
        print()

Iter: 0
Elapsed Time: 0.01 mins
D loss: 1.89
G_loss: 1.635

Iter: 1000
Elapsed Time: 0.15 mins
D loss: 0.01807
G_loss: 8.811

Iter: 2000
Elapsed Time: 0.30 mins
D loss: 0.01422
G_loss: 5.458

Iter: 3000
Elapsed Time: 0.44 mins
D loss: 0.1594
G_loss: 4.536

Iter: 4000
Elapsed Time: 0.59 mins
D loss: 0.1236
G_loss: 4.847

Iter: 5000
Elapsed Time: 0.74 mins
D loss: 0.183
G_loss: 5.523

Iter: 6000
Elapsed Time: 0.88 mins
D loss: 0.3142
G_loss: 4.697

Iter: 7000
Elapsed Time: 1.03 mins
D loss: 0.3163
G_loss: 5.381

Iter: 8000
Elapsed Time: 1.18 mins
D loss: 0.3023
G_loss: 4.33

Iter: 9000
Elapsed Time: 1.32 mins
D loss: 0.4493
G_loss: 4.126

Iter: 10000
Elapsed Time: 1.47 mins
D loss: 0.7047
G_loss: 3.415

Iter: 11000
Elapsed Time: 1.61 mins
D loss: 0.5081
G_loss: 2.751

Iter: 12000
Elapsed Time: 1.76 mins
D loss: 0.652
G_loss: 2.728

Iter: 13000
Elapsed Time: 1.90 mins
D loss: 0.5715
G_loss: 2.899

Iter: 14000
Elapsed Time: 2.05 mins
D loss: 0.7607
G_loss: 2.853

Iter: 15000
Elapsed Time: 

Iter: 124000
Elapsed Time: 18.12 mins
D loss: 0.4422
G_loss: 2.282

Iter: 125000
Elapsed Time: 18.26 mins
D loss: 0.5391
G_loss: 2.75

Iter: 126000
Elapsed Time: 18.41 mins
D loss: 0.4859
G_loss: 2.373

Iter: 127000
Elapsed Time: 18.56 mins
D loss: 0.4911
G_loss: 3.114

Iter: 128000
Elapsed Time: 18.70 mins
D loss: 0.5057
G_loss: 2.602

Iter: 129000
Elapsed Time: 18.85 mins
D loss: 0.4329
G_loss: 2.426

Iter: 130000
Elapsed Time: 18.99 mins
D loss: 0.5981
G_loss: 2.653

Iter: 131000
Elapsed Time: 19.14 mins
D loss: 0.4396
G_loss: 2.576

Iter: 132000
Elapsed Time: 19.29 mins
D loss: 0.4552
G_loss: 2.924

Iter: 133000
Elapsed Time: 19.43 mins
D loss: 0.4194
G_loss: 2.671

Iter: 134000
Elapsed Time: 19.56 mins
D loss: 0.4681
G_loss: 2.881

Iter: 135000
Elapsed Time: 19.70 mins
D loss: 0.4462
G_loss: 2.372

Iter: 136000
Elapsed Time: 19.84 mins
D loss: 0.4736
G_loss: 2.5

Iter: 137000
Elapsed Time: 19.99 mins
D loss: 0.6483
G_loss: 2.825

Iter: 138000
Elapsed Time: 20.16 mins
D loss: 0.417

Iter: 245000
Elapsed Time: 35.78 mins
D loss: 0.3697
G_loss: 2.804

Iter: 246000
Elapsed Time: 35.92 mins
D loss: 0.3394
G_loss: 2.972

Iter: 247000
Elapsed Time: 36.07 mins
D loss: 0.4186
G_loss: 2.83

Iter: 248000
Elapsed Time: 36.21 mins
D loss: 0.3148
G_loss: 2.819

Iter: 249000
Elapsed Time: 36.36 mins
D loss: 0.3764
G_loss: 2.639

Iter: 250000
Elapsed Time: 36.50 mins
D loss: 0.3471
G_loss: 2.83

Iter: 251000
Elapsed Time: 36.65 mins
D loss: 0.479
G_loss: 2.806

Iter: 252000
Elapsed Time: 36.79 mins
D loss: 0.3013
G_loss: 3.01

Iter: 253000
Elapsed Time: 36.94 mins
D loss: 0.4545
G_loss: 2.893

Iter: 254000
Elapsed Time: 37.08 mins
D loss: 0.4001
G_loss: 2.571

Iter: 255000
Elapsed Time: 37.23 mins
D loss: 0.285
G_loss: 2.627

Iter: 256000
Elapsed Time: 37.37 mins
D loss: 0.4104
G_loss: 3.141

Iter: 257000
Elapsed Time: 37.52 mins
D loss: 0.3637
G_loss: 3.098

Iter: 258000
Elapsed Time: 37.66 mins
D loss: 0.3173
G_loss: 2.866

Iter: 259000
Elapsed Time: 37.81 mins
D loss: 0.3069


Iter: 366000
Elapsed Time: 53.46 mins
D loss: 0.2596
G_loss: 2.971

Iter: 367000
Elapsed Time: 53.61 mins
D loss: 0.2696
G_loss: 2.961

Iter: 368000
Elapsed Time: 53.75 mins
D loss: 0.3655
G_loss: 2.971

Iter: 369000
Elapsed Time: 53.90 mins
D loss: 0.2845
G_loss: 3.272

Iter: 370000
Elapsed Time: 54.04 mins
D loss: 0.2252
G_loss: 3.314

Iter: 371000
Elapsed Time: 54.18 mins
D loss: 0.3442
G_loss: 3.404

Iter: 372000
Elapsed Time: 54.33 mins
D loss: 0.1726
G_loss: 3.224

Iter: 373000
Elapsed Time: 54.47 mins
D loss: 0.1665
G_loss: 3.28

Iter: 374000
Elapsed Time: 54.62 mins
D loss: 0.2529
G_loss: 3.312

Iter: 375000
Elapsed Time: 54.76 mins
D loss: 0.201
G_loss: 3.301

Iter: 376000
Elapsed Time: 54.91 mins
D loss: 0.2987
G_loss: 3.469

Iter: 377000
Elapsed Time: 55.06 mins
D loss: 0.2196
G_loss: 2.806

Iter: 378000
Elapsed Time: 55.20 mins
D loss: 0.3097
G_loss: 2.804

Iter: 379000
Elapsed Time: 55.35 mins
D loss: 0.2279
G_loss: 3.159

Iter: 380000
Elapsed Time: 55.49 mins
D loss: 0.30

Iter: 487000
Elapsed Time: 71.10 mins
D loss: 0.2215
G_loss: 3.576

Iter: 488000
Elapsed Time: 71.24 mins
D loss: 0.1889
G_loss: 3.639

Iter: 489000
Elapsed Time: 71.39 mins
D loss: 0.1962
G_loss: 2.939

Iter: 490000
Elapsed Time: 71.53 mins
D loss: 0.1434
G_loss: 3.482

Iter: 491000
Elapsed Time: 71.68 mins
D loss: 0.27
G_loss: 3.244

Iter: 492000
Elapsed Time: 71.82 mins
D loss: 0.1893
G_loss: 3.199

Iter: 493000
Elapsed Time: 71.97 mins
D loss: 0.2652
G_loss: 3.655

Iter: 494000
Elapsed Time: 72.11 mins
D loss: 0.2589
G_loss: 3.65

Iter: 495000
Elapsed Time: 72.26 mins
D loss: 0.1979
G_loss: 3.703

Iter: 496000
Elapsed Time: 72.40 mins
D loss: 0.2683
G_loss: 3.673

Iter: 497000
Elapsed Time: 72.55 mins
D loss: 0.1869
G_loss: 3.74

Iter: 498000
Elapsed Time: 72.70 mins
D loss: 0.1877
G_loss: 3.021

Iter: 499000
Elapsed Time: 72.84 mins
D loss: 0.1897
G_loss: 3.541

Iter: 500000
Elapsed Time: 72.99 mins
D loss: 0.2102
G_loss: 3.772

Iter: 501000
Elapsed Time: 73.13 mins
D loss: 0.172


Iter: 608000
Elapsed Time: 88.76 mins
D loss: 0.3265
G_loss: 3.577

Iter: 609000
Elapsed Time: 88.91 mins
D loss: 0.1868
G_loss: 3.591

Iter: 610000
Elapsed Time: 89.05 mins
D loss: 0.3096
G_loss: 3.432

Iter: 611000
Elapsed Time: 89.20 mins
D loss: 0.364
G_loss: 3.669

Iter: 612000
Elapsed Time: 89.35 mins
D loss: 0.1709
G_loss: 3.553

Iter: 613000
Elapsed Time: 89.49 mins
D loss: 0.1983
G_loss: 3.285

Iter: 614000
Elapsed Time: 89.64 mins
D loss: 0.1259
G_loss: 3.209

Iter: 615000
Elapsed Time: 89.78 mins
D loss: 0.1717
G_loss: 3.799

Iter: 616000
Elapsed Time: 89.92 mins
D loss: 0.2383
G_loss: 3.482

Iter: 617000
Elapsed Time: 90.07 mins
D loss: 0.2499
G_loss: 3.857

Iter: 618000
Elapsed Time: 90.22 mins
D loss: 0.1761
G_loss: 3.382

Iter: 619000
Elapsed Time: 90.36 mins
D loss: 0.1022
G_loss: 3.747

Iter: 620000
Elapsed Time: 90.50 mins
D loss: 0.1842
G_loss: 3.782

Iter: 621000
Elapsed Time: 90.65 mins
D loss: 0.1848
G_loss: 3.454

Iter: 622000
Elapsed Time: 90.79 mins
D loss: 0.1

Iter: 729000
Elapsed Time: 106.43 mins
D loss: 0.2668
G_loss: 3.682

Iter: 730000
Elapsed Time: 106.58 mins
D loss: 0.2066
G_loss: 3.925

Iter: 731000
Elapsed Time: 106.73 mins
D loss: 0.2322
G_loss: 3.702

Iter: 732000
Elapsed Time: 106.87 mins
D loss: 0.1233
G_loss: 4.071

Iter: 733000
Elapsed Time: 107.02 mins
D loss: 0.1439
G_loss: 3.252

Iter: 734000
Elapsed Time: 107.16 mins
D loss: 0.1476
G_loss: 4.219

Iter: 735000
Elapsed Time: 107.31 mins
D loss: 0.08445
G_loss: 4.153

Iter: 736000
Elapsed Time: 107.45 mins
D loss: 0.1815
G_loss: 3.729

Iter: 737000
Elapsed Time: 107.60 mins
D loss: 0.201
G_loss: 4.173

Iter: 738000
Elapsed Time: 107.74 mins
D loss: 0.2094
G_loss: 3.857

Iter: 739000
Elapsed Time: 107.89 mins
D loss: 0.2293
G_loss: 3.754

Iter: 740000
Elapsed Time: 108.03 mins
D loss: 0.2102
G_loss: 4.21

Iter: 741000
Elapsed Time: 108.18 mins
D loss: 0.159
G_loss: 3.98

Iter: 742000
Elapsed Time: 108.32 mins
D loss: 0.1019
G_loss: 4.298

Iter: 743000
Elapsed Time: 108.47 min

Iter: 848000
Elapsed Time: 123.70 mins
D loss: 0.2198
G_loss: 3.447

Iter: 849000
Elapsed Time: 123.85 mins
D loss: 0.2844
G_loss: 3.621

Iter: 850000
Elapsed Time: 123.99 mins
D loss: 0.1968
G_loss: 3.194

Iter: 851000
Elapsed Time: 124.14 mins
D loss: 0.3541
G_loss: 3.198

Iter: 852000
Elapsed Time: 124.46 mins
D loss: 0.1893
G_loss: 3.443

Iter: 853000
Elapsed Time: 124.61 mins
D loss: 0.2176
G_loss: 3.474

Iter: 854000
Elapsed Time: 124.75 mins
D loss: 0.1599
G_loss: 3.274

Iter: 855000
Elapsed Time: 124.90 mins
D loss: 0.3016
G_loss: 3.426

Iter: 856000
Elapsed Time: 125.04 mins
D loss: 0.2627
G_loss: 3.262

Iter: 857000
Elapsed Time: 125.19 mins
D loss: 0.2199
G_loss: 3.463

Iter: 858000
Elapsed Time: 125.33 mins
D loss: 0.278
G_loss: 3.066

Iter: 859000
Elapsed Time: 125.48 mins
D loss: 0.2237
G_loss: 3.347

Iter: 860000
Elapsed Time: 125.62 mins
D loss: 0.3106
G_loss: 3.044

Iter: 861000
Elapsed Time: 125.77 mins
D loss: 0.2148
G_loss: 3.252

Iter: 862000
Elapsed Time: 125.91 m

Iter: 968000
Elapsed Time: 141.29 mins
D loss: 0.2391
G_loss: 3.529

Iter: 969000
Elapsed Time: 141.43 mins
D loss: 0.1883
G_loss: 3.554

Iter: 970000
Elapsed Time: 141.58 mins
D loss: 0.2657
G_loss: 3.713

Iter: 971000
Elapsed Time: 141.72 mins
D loss: 0.2346
G_loss: 3.782

Iter: 972000
Elapsed Time: 141.87 mins
D loss: 0.1658
G_loss: 4.154

Iter: 973000
Elapsed Time: 142.01 mins
D loss: 0.2722
G_loss: 3.481

Iter: 974000
Elapsed Time: 142.16 mins
D loss: 0.1704
G_loss: 3.778

Iter: 975000
Elapsed Time: 142.31 mins
D loss: 0.1229
G_loss: 3.725

Iter: 976000
Elapsed Time: 142.45 mins
D loss: 0.216
G_loss: 3.672

Iter: 977000
Elapsed Time: 142.60 mins
D loss: 0.1534
G_loss: 3.864

Iter: 978000
Elapsed Time: 142.74 mins
D loss: 0.1422
G_loss: 3.616

Iter: 979000
Elapsed Time: 142.89 mins
D loss: 0.1714
G_loss: 3.392

Iter: 980000
Elapsed Time: 143.03 mins
D loss: 0.08585
G_loss: 3.634

Iter: 981000
Elapsed Time: 143.18 mins
D loss: 0.114
G_loss: 3.808

Iter: 982000
Elapsed Time: 143.32 m
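
One way to read these logs: if the discriminator could not tell real from generated images and output D = 0.5 everywhere, `D_loss` would be 2 ln 2 (about 1.386) and `G_loss` would be ln 2 (about 0.693). For most of this run `D_loss` stays well below that and `G_loss` well above it, so the discriminator keeps assigning low probability to generated samples. Because `G_loss` is the mean of -log D(G(z)), exp(-G_loss) gives the geometric mean of D(G(z)). A quick check of these reference values:

```python
import math

d_equilibrium = -(math.log(0.5) + math.log(1 - 0.5))  # D_loss when D(x) = D(G(z)) = 0.5
g_equilibrium = -math.log(0.5)                        # G_loss when D(G(z)) = 0.5
print(round(d_equilibrium, 3), round(g_equilibrium, 3))  # 1.386 0.693

# Geometric mean of D(G(z)) implied by a logged G_loss of about 3.5
print(round(math.exp(-3.5), 3))  # 0.03
```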

### Results

Interestingly, the GAN learned to produce recognizable digits very quickly (within the first couple hundred epochs). However, it soon settled on 1 as its favorite digit and optimized itself to draw mostly 1s, a classic case of mode collapse.

10
![title](_output/010.png)

20
![title](_output/020.png)

50
![title](_output/050.png)


100
![title](_output/100.png)

150
![title](_output/150.png)

200
![title](_output/200.png)

300
![title](_output/300.png)

400
![title](_output/400.png)

500
![title](_output/500.png)

600
![title](_output/600.png)

700
![title](_output/700.png)

800
![title](_output/800.png)

900
![title](_output/900.png)

999
![title](_output/999.png)