In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import time, itertools

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config = config)

batch_size = 100
num_epoch = 10
lr = 0.0002

  from ._conv import register_converters as _register_converters


In [2]:
def lrelu(x, leak=0.2, name='lrelu'):
    return tf.maximum(x, x*leak)

In [3]:
def discriminator(x, isTrain=True, reuse=False):
    with tf.variable_scope('discriminator', reuse=reuse):
        conv1 = tf.layers.conv2d(x, 128, [4, 4], strides=(2, 2), padding='same')
        lrelu1 = lrelu(conv1)
        
        conv2 = tf.layers.conv2d(lrelu1, 256, [4, 4], strides=(2, 2), padding='same')
        lrelu2 = lrelu(tf.layers.batch_normalization(conv2, training=isTrain))
        
        conv3 = tf.layers.conv2d(lrelu2, 512, [4, 4], strides=(2, 2), padding='same')
        lrelu3 = lrelu(tf.layers.batch_normalization(conv3, training=isTrain))
        
        conv4 = tf.layers.conv2d(lrelu3, 1024, [4, 4], strides=(2, 2), padding='same')
        lrelu4 = lrelu(tf.layers.batch_normalization(conv4, training=isTrain))
        
        conv5 = tf.layers.conv2d(lrelu4, 1, [4, 4], strides=(2, 2), padding='valid')
        out = tf.nn.sigmoid(conv5)
        return out, conv5

In [4]:
def generator(z, isTrain=True, reuse=False):
    with tf.variable_scope('generator', reuse=reuse):
        conv1 = tf.layers.conv2d_transpose(z, 1024, [4, 4], strides=(1, 1), padding='valid')
        lrelu1 = lrelu(tf.layers.batch_normalization(conv1, training=isTrain))
        
        conv2 = tf.layers.conv2d_transpose(lrelu1, 512, [4, 4], strides=(2, 2), padding='same')
        lrelu2 = lrelu(tf.layers.batch_normalization(conv2, training=isTrain))
        
        conv3 = tf.layers.conv2d_transpose(lrelu2, 256, [4, 4], strides=(2, 2), padding='same')
        lrelu3 = lrelu(tf.layers.batch_normalization(conv3, training=isTrain))
        
        conv4 = tf.layers.conv2d_transpose(lrelu3, 128, [4, 4], strides=(2, 2), padding='same')
        lrelu4 = lrelu(tf.layers.batch_normalization(conv4, training=isTrain))
        
        conv5 = tf.layers.conv2d_transpose(lrelu4, 1, [4, 4], strides=(2, 2), padding='same')
        out = tf.nn.tanh(conv5)
        return out

In [5]:
X = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1])
Z = tf.placeholder(dtype=tf.float32, shape=[None, 1, 1, 100])
isTrain = tf.placeholder(dtype=tf.bool)

G_sample = generator(Z, isTrain)

D_real, D_real_logit = discriminator(X, isTrain)
D_fake, D_fake_logit = discriminator(G_sample, isTrain, True)

# sigmoid_cross_entropy_with_logits会给logits做softmax，所以我们需要在discriminator中返回logits
D_real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logit,
                                                                     labels=tf.ones_like(D_real_logit)))
D_fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logit,
                                                                     labels=tf.zeros_like(D_fake_logit)))
D_loss = D_real_loss + D_fake_loss

G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logit,
                                                                labels=tf.ones_like(D_fake_logit)))

T_vars = tf.trainable_variables()
D_vars = [var for var in T_vars if var.name.startswith('discriminator')]
G_vars = [var for var in T_vars if var.name.startswith('generator')]

D_optimizer = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(D_loss, var_list=D_vars)
G_optimizer = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(G_loss, var_list=G_vars)

tf.summary.scalar('D_loss', D_loss)
tf.summary.scalar('G_loss', G_loss)

<tf.Tensor 'G_loss:0' shape=() dtype=string>

In [6]:
log_dir = '../summary/graph'
merged = tf.summary.merge_all()
train_writer = tf.summary.FileWriter(log_dir + '/train', sess.graph)

sess.run(tf.global_variables_initializer())
np.random.seed(int(time.time()))

In [7]:
mnist = input_data.read_data_sets('../data/MNIST_data', one_hot=True, reshape=[])
train_set = tf.image.resize_images(mnist.train.images, [64, 64]).eval(session=sess)
train_set = (train_set - 0.5) / 0.5

Extracting ../data/MNIST_data/train-images-idx3-ubyte.gz
Extracting ../data/MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../data/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../data/MNIST_data/t10k-labels-idx1-ubyte.gz


In [8]:
def plot(datas):
    plt.figure(figsize=(30, 30))
    for i in range(len(datas)):
        plt.subplot(16, 8, i + 1)
        plt.imshow(np.reshape(datas[i], (64, 64)), cmap='gray')
        plt.axis('off')
    plt.subplots_adjust(wspace =0, hspace =0)
    plt.show()

In [9]:
fixed_z = np.random.normal(0, 1, (25, 1, 1, 100))
def show_result(num_epoch, show=False, save=False, path='result-test.png'):
    test_images = sess.run(G_sample, feed_dict={Z:fixed_z, isTrain:False})
    
    size_figure_grid = 5
    fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))
    for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)
        
    for k in range(size_figure_grid * size_figure_grid):
        i = k // size_figure_grid
        j = k % size_figure_grid
        ax[i, j].cla()
        ax[i, j].imshow(np.reshape(test_images[k], (64, 64)), cmap='gray')
        
    label = 'Epoch {}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')

    if save:
        plt.savefig(path)
    if show:
        plt.show()
    else:
        plt.close()

In [14]:
for i in range(num_epoch):
    G_losses = []
    D_losses = []
    for iter in range(mnist.train.num_examples // batch_size):
        X_batch = train_set[iter*batch_size:(iter+1)*batch_size]
        Z_batch = np.random.normal(0, 1, (batch_size, 1, 1, 100))
        _, D_loss_curr = sess.run([D_optimizer, D_loss], feed_dict={X: X_batch, Z: Z_batch, isTrain: True})
        
        Z_batch = np.random.normal(0, 1, (batch_size, 1, 1, 100))
        _, G_loss_curr = sess.run([G_optimizer, G_loss], feed_dict={Z: Z_batch, isTrain: True})
        Z_batch = np.random.normal(0, 1, (batch_size, 1, 1, 100))
        _, G_loss_curr = sess.run([G_optimizer, G_loss], feed_dict={Z: Z_batch, isTrain: True})
        
        G_losses.append(G_loss_curr)
        D_losses.append(D_loss_curr)
        
    Z_batch = np.random.normal(0, 1, (batch_size, 1, 1, 100))
    print('i {}, D_loss {} G_loss {}'.format(i, np.mean(D_losses), np.mean(G_losses)))
    fixed_p = 'Fixed_results_stesha' + str(i + 1) + '.png'
    show_result((i + 1), save=True, path=fixed_p)
    
    summary_str = sess.run(merged, feed_dict={X: X_batch, Z: Z_batch, isTrain: False})
    train_writer.add_summary(summary_str, i)
train_writer.close()

i 0, D_loss 0.5232036709785461 G_loss 3.4085400104522705
i 1, D_loss 0.5357067584991455 G_loss 3.3828465938568115
i 2, D_loss 0.1861976534128189 G_loss 6.957078456878662
i 3, D_loss 0.44939014315605164 G_loss 3.9451053142547607
i 4, D_loss 0.41956207156181335 G_loss 3.8183064460754395
i 5, D_loss 0.4272448718547821 G_loss 4.014602184295654
i 6, D_loss 0.3636464476585388 G_loss 4.720570087432861
i 7, D_loss 0.4718671143054962 G_loss 3.672618865966797
i 8, D_loss 0.4818500280380249 G_loss 3.832390785217285
i 9, D_loss 0.2419673502445221 G_loss 5.054630279541016
