In [134]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os

# Hide TF1 deprecation/version-mismatch WARNING messages (e.g. from input_data).
# The default TF1 level is WARN; INFO would make logging *more* verbose, so use ERROR.
tf.logging.set_verbosity(tf.logging.ERROR)

In [135]:
# Hyperparameters
# -- data / network sizes
mb_size = 32     # minibatch size for both D and G steps
X_dim = 784      # flattened image size (28 * 28)
z_dim = 64       # latent noise dimension fed to the generator
h_dim = 128      # hidden units in both the generator and the discriminator
# -- optimisation / equilibrium control
lr = 0.001       # Adam learning rate for both optimizers
m = 5            # NOTE(review): not referenced elsewhere in the visible notebook (sample_z's `m` is its own parameter)
lam = 0.001      # step size of the k_t update in the training loop
gamma = 0.5      # equilibrium target: the k_t update drives L(G(z)) towards gamma * L(x)
k_curr = 0       # current value of k_t; updated in place by the training loop

In [136]:
# Load MNIST from data/mnist. Labels are one-hot encoded but are not used by
# this model; only the images are consumed in the training loop.
mnist_dir = 'data/mnist'
mnist = input_data.read_data_sets(mnist_dir, one_hot=True)

Extracting data/mnist/train-images-idx3-ubyte.gz
Extracting data/mnist/train-labels-idx1-ubyte.gz
Extracting data/mnist/t10k-images-idx3-ubyte.gz
Extracting data/mnist/t10k-labels-idx1-ubyte.gz


In [137]:
# visualization
def plot(samples):
    """Draw a batch of flattened 28x28 images on a square grid.

    Args:
        samples: sequence/array of images, each reshapeable to (28, 28)
            (e.g. rows of length 784). 16 samples give the original 4x4
            layout; other batch sizes get a ceil(sqrt(n)) x ceil(sqrt(n))
            grid instead of raising IndexError past the 16th image.

    Returns:
        The matplotlib Figure. The caller is responsible for saving and
        closing it (plt.close) to avoid accumulating open figures.
    """
    # Square grid just big enough for every sample; at least 1x1 so an
    # empty batch still produces a valid (blank) figure.
    n_side = max(1, int(np.ceil(np.sqrt(len(samples)))))
    fig = plt.figure(figsize=(n_side, n_side))
    gs = gridspec.GridSpec(n_side, n_side)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        # Use explicit Axes on this figure rather than the pyplot "current axes".
        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(np.reshape(sample, (28, 28)), cmap='Greys_r')

    return fig

In [138]:
def xavier_init(size):
    """Return a random-normal initial value tensor with the given shape.

    The standard deviation is 1 / sqrt(fan_in / 2) == sqrt(2 / fan_in), where
    fan_in = size[0]. NOTE(review): despite the name, this is the He-style
    normal scale; Glorot/Xavier normal would use sqrt(2 / (fan_in + fan_out)).

    Args:
        size: weight shape as a list, e.g. [in_dim, out_dim].

    Returns:
        A tensor from tf.random_normal (mean 0, its default) with shape `size`.
    """
    fan_in = size[0]
    scale = 1. / tf.sqrt(fan_in / 2.)
    return tf.random_normal(stddev=scale, shape=size)

In [139]:
# Graph inputs
X = tf.placeholder(dtype=tf.float32, shape=[None, X_dim])  # flattened real MNIST images (28*28 = 784 pixels)
z = tf.placeholder(dtype=tf.float32, shape=[None, z_dim])  # generator noise, z_dim = 64
k = tf.placeholder(dtype=tf.float32)                       # k_t weight, fed from Python on each D step

In [140]:
# ********* Generator network: one hidden layer of h_dim (= 128) units
# Hidden layer z (z_dim) -> h_dim
G_W1 = tf.Variable(initial_value=xavier_init([z_dim, h_dim]))
G_b1 = tf.Variable(initial_value=tf.zeros([h_dim]))

# Output layer h_dim -> X_dim (one value per pixel)
G_W2 = tf.Variable(initial_value=xavier_init([h_dim, X_dim]))
G_b2 = tf.Variable(initial_value=tf.zeros([X_dim]))

# Generator
def G(z):
    """Map a batch of noise vectors to generated images.

    Args:
        z: noise tensor; expected as (batch, z_dim), matching the `z` placeholder.

    Returns:
        Tensor of shape (batch, X_dim) with per-pixel values squashed into (0, 1)
        by a sigmoid.
    """
    hidden = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    logits = tf.matmul(hidden, G_W2) + G_b2  # pre-sigmoid output (not a log-probability)
    return tf.nn.sigmoid(logits)

In [141]:
# ********* Discriminator network: autoencoder with one hidden layer of h_dim (= 128) units
# Encoder X_dim -> h_dim
D_W1 = tf.Variable(initial_value=xavier_init([X_dim, h_dim]))
D_b1 = tf.Variable(initial_value=tf.zeros([h_dim]))
# Decoder h_dim -> X_dim (reconstruction of the input image)
D_W2 = tf.Variable(initial_value=xavier_init([h_dim, X_dim]))
D_b2 = tf.Variable(initial_value=tf.zeros([X_dim]))

# Discriminator
def D(X):
    """Autoencoder "discriminator": scalar reconstruction loss for a batch.

    Args:
        X: image tensor; expected as (batch, X_dim), e.g. the `X` placeholder
            or the generator output.

    Returns:
        Scalar tensor: the mean over the batch of each sample's summed squared
        reconstruction error, sum_j (X_ij - recon_ij) ** 2.
    """
    hidden = tf.nn.relu(tf.matmul(X, D_W1) + D_b1)
    recon = tf.matmul(hidden, D_W2) + D_b2  # linear output, no squashing
    per_sample_err = tf.reduce_sum((X - recon) ** 2, axis=1)
    return tf.reduce_mean(per_sample_err)

In [142]:
# Sample the generator's input noise uniformly from [-1, 1), one row per vector.
def sample_z(m, n):
    """Return an ``(m, n)`` float array of i.i.d. Uniform[-1, 1) noise.

    Args:
        m: number of noise vectors (rows), e.g. the minibatch size or the
            number of images to draw for a sample grid.
        n: length of each vector (the noise dimension, ``z_dim``).
    """
    low, high = -1., 1.
    return np.random.uniform(low=low, high=high, size=(m, n))

In [143]:
# ********* Generation, losses, optimizers and session initialisation
G_sample = G(z)

# D(.) is a reconstruction loss, so these are L(x) and L(G(z)).
D_real = D(X)
D_fake = D(G_sample)

# BEGAN-style objectives: D lowers L(x) while raising L(G(z)) weighted by k_t;
# G lowers the reconstruction loss of its own samples.
D_loss = D_real - k * D_fake
G_loss = D_fake

# Each optimizer only updates its own network's parameters.
theta_D = [D_W1, D_W2, D_b1, D_b2]
theta_G = [G_W1, G_W2, G_b1, G_b2]

optim_D = tf.train.AdamOptimizer(learning_rate=lr).minimize(D_loss, var_list=theta_D)
optim_G = tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss, var_list=theta_G)

sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

In [145]:
# ********* Training and Testing

if not os.path.exists('out/'):
    os.makedirs('out/')  # folder for the periodic sample-grid PNGs

i = 0              # index of the next PNG written to out/
num_epoch = 10000  # number of training *iterations* (one minibatch each), not full passes over MNIST

for epoch in range(num_epoch):  # `epoch` counts minibatch iterations
    # Real images for this iteration; labels are unused.
    batch_xs, _ = mnist.train.next_batch(mb_size)

    # Discriminator step: minimise L(x) - k_t * L(G(z)); also fetch L(x).
    _, D_real_curr = sess.run(
        [optim_D, D_real],
        feed_dict={X: batch_xs, z: sample_z(mb_size, z_dim), k: k_curr}
    )

    # Generator step: minimise L(G(z)). Only z is needed; D_fake does not depend on X.
    _, D_fake_curr = sess.run(
        [optim_G, D_fake],
        feed_dict={z: sample_z(mb_size, z_dim)}
    )

    # Proportional control of k_t towards gamma * L(x) == L(G(z)).
    # BEGAN constrains k_t to [0, 1]; without the clip k can go negative, which
    # turns the D objective into *minimising* the reconstruction loss of fakes.
    k_curr = float(np.clip(k_curr + lam * (gamma * D_real_curr - D_fake_curr), 0., 1.))

    if epoch % 100 == 0:  # logging / sample-saving period
        # Global convergence measure: L(x) + |gamma * L(x) - L(G(z))|.
        measure = D_real_curr + np.abs(gamma * D_real_curr - D_fake_curr)

        print('Iter-{}; Convergence measure: {:.4}'
              .format(epoch, measure))

        samples = sess.run(G_sample, feed_dict={z: sample_z(16, z_dim)})

        fig = plot(samples)
        fig.savefig('out/{}.png'
                    .format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        plt.close(fig)  # release the figure so memory does not grow over the loop

Iter-0; Convergence measure: 30.89
Epoch: 0000 D_loss: 28.07 G_loss: 11.21
Iter-100; Convergence measure: 26.97
Epoch: 0001 D_loss: 26.21 G_loss: 12.34
Iter-200; Convergence measure: 24.98
Epoch: 0002 D_loss: 24.87 G_loss: 12.55
Iter-300; Convergence measure: 29.28
Epoch: 0003 D_loss: 26.36 G_loss: 16.11
Iter-400; Convergence measure: 30.72
Epoch: 0004 D_loss: 29.38 G_loss: 13.36
Iter-500; Convergence measure: 28.53
Epoch: 0005 D_loss: 27.78 G_loss: 13.14
Iter-600; Convergence measure: 32.77
Epoch: 0006 D_loss: 28.3 G_loss: 18.62
Iter-700; Convergence measure: 29.19
Epoch: 0007 D_loss: 26.87 G_loss: 11.11
Iter-800; Convergence measure: 24.22
Epoch: 0008 D_loss: 23.97 G_loss: 12.23
Iter-900; Convergence measure: 26.66
Epoch: 0009 D_loss: 26.39 G_loss: 12.93
Iter-1000; Convergence measure: 28.86
Epoch: 0010 D_loss: 26.92 G_loss: 15.4
Iter-1100; Convergence measure: 25.82
Epoch: 0011 D_loss: 25.37 G_loss: 13.14
Iter-1200; Convergence measure: 24.42
Epoch: 0012 D_loss: 23.31 G_loss: 12.77
