In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Read the MNIST data sets stored under ./mnist/data/; one_hot=True requests
# one-hot label vectors instead of integer class ids.
mnist = input_data.read_data_sets(
    './mnist/data/',
    one_hot=True,
)

Extracting ./mnist/data/train-images-idx3-ubyte.gz
Extracting ./mnist/data/train-labels-idx1-ubyte.gz
Extracting ./mnist/data/t10k-images-idx3-ubyte.gz
Extracting ./mnist/data/t10k-labels-idx1-ubyte.gz


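- learning_rate = 최적화 함수에서 사용할 학습률 (이 노트북에서는 따로 지정하지 않고 AdamOptimizer의 기본값을 사용)
- total_epoch = 전체 데이터를 학습할 총횟수
- batch_size = 미니배치로 한번에 학습할 데이터의 개수
- n_hidden = 은닉층의 뉴런 개수
- n_input = 입력값의 크기 (28 x 28 = 784)
- n_noise = 생성기의 입력으로 사용할 노이즈의 크기
- n_class = 레이블(클래스)의 개수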


In [2]:
# --- Training schedule ---
total_epoch = 100   # number of passes over the training set
batch_size = 100    # examples drawn per training step

# --- Tensor sizes ---
n_input = 28 * 28   # width of the image placeholder X (flattened 28x28 image)
n_class = 10        # width of the label placeholder Y
n_noise = 128       # width of the noise placeholder Z

# --- Network width ---
n_hidden = 256      # units in the hidden dense layer of both networks

In [3]:
# Graph inputs; the leading None leaves the batch dimension unspecified.
X = tf.placeholder(dtype=tf.float32, shape=[None, n_input])   # images fed to the discriminator
Y = tf.placeholder(dtype=tf.float32, shape=[None, n_class])   # labels shared by both networks
Z = tf.placeholder(dtype=tf.float32, shape=[None, n_noise])   # noise fed to the generator

In [4]:
def generator(noise, labels):
    """Build the label-conditioned generator network.

    All layers are created inside the 'generator' variable scope.

    Args:
        noise: Noise tensor (the placeholder `Z` at this notebook's call site).
        labels: Label tensor concatenated to `noise` along axis 1 (the
            placeholder `Y` at this notebook's call site).

    Returns:
        Tensor produced by the final dense layer: `n_input` units passed
        through a sigmoid.
    """
    with tf.variable_scope('generator'):
        # Conditioning: append the label vector to the noise vector.
        conditioned = tf.concat([noise, labels], axis=1)

        features = tf.layers.dense(conditioned, n_hidden, activation=tf.nn.relu)
        return tf.layers.dense(features, n_input, activation=tf.nn.sigmoid)

In [5]:
def discriminator(inputs, labels, reuse=None):
    """Build the label-conditioned discriminator network.

    All layers are created inside the 'discriminator' variable scope.

    Args:
        inputs: Tensor to be judged (the placeholder `X` or the generator
            output `G` at this notebook's call sites).
        labels: Label tensor concatenated to `inputs` along axis 1.
        reuse: When truthy, the scope is switched to variable-reuse mode
            before any layer is created, so the layers pick up the variables
            made by an earlier call instead of creating new ones.

    Returns:
        Tensor produced by a single-unit dense layer with no activation,
        i.e. a raw logit per example.
    """
    with tf.variable_scope('discriminator') as scope:
        if reuse:
            scope.reuse_variables()

        # Conditioning: append the label vector to the judged input.
        conditioned = tf.concat([inputs, labels], axis=1)

        features = tf.layers.dense(conditioned, n_hidden, activation=tf.nn.relu)
        logit = tf.layers.dense(features, 1, activation=None)

    return logit

In [6]:
def get_noise(batch_size, n_noise):
    """Draw a noise batch for the generator.

    Args:
        batch_size: Number of rows to sample.
        n_noise: Number of columns per row.

    Returns:
        NumPy array of shape (batch_size, n_noise) sampled uniformly from
        the half-open interval [-1, 1) using NumPy's global random state.
    """
    shape = (batch_size, n_noise)
    return np.random.uniform(low=-1.0, high=1.0, size=shape)

In [7]:
# Wire the two networks together. Order matters: the first discriminator call
# creates the variables and the second call (reuse=True) shares them.
G = generator(noise=Z, labels=Y)
D_real = discriminator(inputs=X, labels=Y)
D_gene = discriminator(inputs=G, labels=Y, reuse=True)

In [8]:
# Discriminator objective: push logits for real inputs towards label 1 and
# logits for generated inputs towards label 0.
xent_real = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.ones_like(D_real), logits=D_real)
loss_D_real = tf.reduce_mean(xent_real)

xent_gene = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.zeros_like(D_gene), logits=D_gene)
loss_D_gene = tf.reduce_mean(xent_gene)

# Total discriminator loss is the sum of the two batch means.
loss_D = loss_D_real + loss_D_gene

In [9]:
# Generator objective: the same logits for generated inputs are scored
# against label 1, so the loss falls as the discriminator is fooled.
xent_fooled = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.ones_like(D_gene), logits=D_gene)
loss_G = tf.reduce_mean(xent_fooled)

In [10]:
# Collect each network's trainable variables by variable-scope name so that
# each optimizer step updates only its own network.
trainable_key = tf.GraphKeys.TRAINABLE_VARIABLES
vars_D = tf.get_collection(trainable_key, scope='discriminator')
vars_G = tf.get_collection(trainable_key, scope='generator')

# One Adam optimizer per network, both left at the default settings.
optimizer_D = tf.train.AdamOptimizer()
train_D = optimizer_D.minimize(loss_D, var_list=vars_D)

optimizer_G = tf.train.AdamOptimizer()
train_G = optimizer_G.minimize(loss_G, var_list=vars_G)

In [12]:
# Train the conditional GAN and periodically save a comparison figure of
# test images (top row) and generated images (bottom row).
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# NOTE(review): the session is never closed in the visible cells; call
# sess.close() once no later cell needs it.

# Whole batches per epoch; a trailing partial batch is not counted.
total_batch = mnist.train.num_examples // batch_size
loss_val_D, loss_val_G = 0, 0

# savefig() does not create missing directories, so make sure the output
# folder exists before the first figure is written.
sample_dir = 'samples2'
os.makedirs(sample_dir, exist_ok=True)

for epoch in range(total_epoch):
    for _step in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        noise = get_noise(batch_size, n_noise)

        # One discriminator update followed by one generator update, both
        # using the same labels and the same noise batch.
        _, loss_val_D = sess.run([train_D, loss_D],
                                 feed_dict={X: batch_xs, Y: batch_ys, Z: noise})
        _, loss_val_G = sess.run([train_G, loss_G],
                                 feed_dict={Y: batch_ys, Z: noise})

    # Report the losses of the last batch of the epoch.
    print('Epoch:', '%4d' % epoch,
          'D loss:', '{:.4}'.format(loss_val_D),
          'G loss:', '{:.4}'.format(loss_val_G))

    # Save samples after the first epoch and after every 10th epoch.
    if epoch == 0 or (epoch + 1) % 10 == 0:
        sample_size = 10
        noise = get_noise(sample_size, n_noise)
        # Condition the generator on the labels of the first test examples so
        # each generated image can be compared with the test image above it.
        samples = sess.run(G, feed_dict={Y: mnist.test.labels[:sample_size],
                                         Z: noise})

        fig, ax = plt.subplots(2, sample_size, figsize=(sample_size, 2))

        for col in range(sample_size):
            ax[0][col].set_axis_off()
            ax[1][col].set_axis_off()

            ax[0][col].imshow(np.reshape(mnist.test.images[col], (28, 28)))
            ax[1][col].imshow(np.reshape(samples[col], (28, 28)))

        out_path = os.path.join(sample_dir,
                                '{}.png'.format(str(epoch).zfill(3)))
        plt.savefig(out_path, bbox_inches='tight')
        # Close the figure so figures do not accumulate across epochs.
        plt.close(fig)

print('finish!')

Epoch:    0 D loss: 0.004585 G loss: 7.773
Epoch:    1 D loss: 0.0308 G loss: 7.74
Epoch:    2 D loss: 0.01284 G loss: 8.905
Epoch:    3 D loss: 0.006581 G loss: 7.453
Epoch:    4 D loss: 0.01334 G loss: 7.014
Epoch:    5 D loss: 0.009981 G loss: 7.713
Epoch:    6 D loss: 0.03764 G loss: 6.429
Epoch:    7 D loss: 0.1532 G loss: 8.505
Epoch:    8 D loss: 0.06401 G loss: 6.99
Epoch:    9 D loss: 0.09657 G loss: 7.587
Epoch:   10 D loss: 0.1595 G loss: 6.574
Epoch:   11 D loss: 0.2381 G loss: 5.432
Epoch:   12 D loss: 0.2345 G loss: 4.763
Epoch:   13 D loss: 0.3427 G loss: 4.72
Epoch:   14 D loss: 0.3134 G loss: 4.56
Epoch:   15 D loss: 0.4904 G loss: 3.631
Epoch:   16 D loss: 0.5014 G loss: 4.203
Epoch:   17 D loss: 0.6498 G loss: 3.499
Epoch:   18 D loss: 0.5144 G loss: 3.638
Epoch:   19 D loss: 0.5789 G loss: 4.146
Epoch:   20 D loss: 0.4627 G loss: 2.965
Epoch:   21 D loss: 0.5027 G loss: 3.578
Epoch:   22 D loss: 0.8223 G loss: 2.846
Epoch:   23 D loss: 0.5248 G loss: 3.752
Epoch:   