In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import math
from skimage.io import imsave
import os
import shutil

In [2]:
# def test():
#   z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")
#   x_generated, _ = build_generator(z_prior)
#   chkpt_fname = tf.train.latest_checkpoint(output_path)

#   init = tf.global_variables_initializer()
#   sess = tf.Session()
#   saver = tf.train.Saver()
#   sess.run(init)
#   saver.restore(sess, chkpt_fname)
#   z_test_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
#   x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_test_value})
  
#   show_result(x_gen_val, os.path.join(output_path, "test_result.jpg"))

def weight_bias(shape, name):
  w = tf.Variable(tf.truncated_normal(shape=shape, stddev=1.0 / math.sqrt(shape[0])), name=name % 'w')
  b = tf.Variable(tf.constant(0.1, shape=[shape[1]]), name=name % 'b')
  
  return w, b

def generator(z_prior):
  w1, b1 = weight_bias([z_size, h1_size], name='g_%s1')
  w2, b2 = weight_bias([h1_size, h2_size], name='g_%s2')
  w3, b3 = weight_bias([h2_size, img_size], name='g_%s3')
  g_params = [w1, b1, w2, b2, w3, b3]
  
  h1 = tf.nn.relu(tf.matmul(z_prior, w1) + b1)
  h2 = tf.nn.relu(tf.matmul(h1, w2) + b2)
  h3 = tf.nn.tanh(tf.matmul(h2, w3) + b3)
  x_generated = h3
  
  return x_generated, g_params

def discriminator(x_data, x_generated, keep_prob):
  x_in = tf.concat([x_data, x_generated], 0)
  w1 = tf.Variable(tf.truncated_normal([img_size, h2_size], stddev=0.1), name="d_w1", dtype=tf.float32)
  b1 = tf.Variable(tf.zeros([h2_size]), name="d_b1", dtype=tf.float32)
  h1 = tf.nn.dropout(tf.nn.relu(tf.matmul(x_in, w1) + b1), keep_prob)
  w2 = tf.Variable(tf.truncated_normal([h2_size, h1_size], stddev=0.1), name="d_w2", dtype=tf.float32)
  b2 = tf.Variable(tf.zeros([h1_size]), name="d_b2", dtype=tf.float32)
  h2 = tf.nn.dropout(tf.nn.relu(tf.matmul(h1, w2) + b2), keep_prob)
  w3 = tf.Variable(tf.truncated_normal([h1_size, 1], stddev=0.1), name="d_w3", dtype=tf.float32)
  b3 = tf.Variable(tf.zeros([1]), name="d_b3", dtype=tf.float32)
  h3 = tf.matmul(h2, w3) + b3
  y_data = tf.nn.sigmoid(tf.slice(h3, [0, 0], [batch_size, -1], name=None))
  y_generated = tf.nn.sigmoid(tf.slice(h3, [batch_size, 0], [-1, -1], name=None))
  d_params = [w1, b1, w2, b2, w3, b3]
  return y_data, y_generated, d_params

def show_result(batch_res, fname, grid_size=(8, 8), grid_pad=5):
  batch_res = 0.5 * batch_res.reshape((batch_res.shape[0], img_height, img_width)) + 0.5
  img_h, img_w = batch_res.shape[1], batch_res.shape[2]
  grid_h = img_h * grid_size[0] + grid_pad * (grid_size[0] - 1)
  grid_w = img_w * grid_size[1] + grid_pad * (grid_size[1] - 1)
  img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)
  
  for i, res in enumerate(batch_res):
    if i >= grid_size[0] * grid_size[1]:
      break
    img = (res) * 255
    img = img.astype(np.uint8)
    row = (i // grid_size[0]) * (img_h + grid_pad)
    col = (i % grid_size[1]) * (img_w + grid_pad)
    img_grid[row:row + img_h, col:col + img_w] = img
    
  imsave(fname, img_grid)

### Load data

In [3]:
data_path = os.path.expanduser('~/Datasets/mnist')
mnist = input_data.read_data_sets(data_path, one_hot=True)

Extracting /Users/v-shmyhlo/Datasets/mnist/train-images-idx3-ubyte.gz
Extracting /Users/v-shmyhlo/Datasets/mnist/train-labels-idx1-ubyte.gz
Extracting /Users/v-shmyhlo/Datasets/mnist/t10k-images-idx3-ubyte.gz
Extracting /Users/v-shmyhlo/Datasets/mnist/t10k-labels-idx1-ubyte.gz


### Build a graph

In [4]:
img_height = 28
img_width = 28
img_size = img_height * img_width
h1_size = 150
h2_size = 300
z_size = 100
batch_size = 256

x_data = tf.placeholder(tf.float32, [batch_size, img_size], name='x_data')
z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name='z_prior')
learning_rate = tf.placeholder(tf.float32, name='learning_rate')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
global_step = tf.Variable(0, name='global_step', trainable=False)

x_generated, g_params = generator(z_prior)
y_data, y_generated, d_params = discriminator(x_data, x_generated, keep_prob)

d_loss = -(tf.log(y_data) + tf.log(1 - y_generated))
g_loss = -tf.log(y_generated)

optimizer = tf.train.AdamOptimizer(learning_rate)

d_train = optimizer.minimize(d_loss, var_list=d_params)
g_train = optimizer.minimize(g_loss, var_list=g_params)

init = tf.global_variables_initializer()
saver = tf.train.Saver()

### Train

In [None]:
restore = True
output_path = 'output'
model_name = os.path.join(output_path,'model')
max_epoch = 20
# steps = math.floor(mnist.train.num_examples / batch_size)
steps = 3000
log_interval = 200
d_train_interval = 1
g_train_interval = 1
kp = 0.75
lr = 0.0001

with tf.Session() as sess:
  if restore:
    chkpt_fname = tf.train.latest_checkpoint(output_path)
    saver.restore(sess, chkpt_fname)
  else:
    sess.run(init)
    
    if os.path.exists(output_path):
      shutil.rmtree(output_path)
    os.mkdir(output_path)

  z_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)

  for i in range(sess.run(global_step), max_epoch):
    for j in range(steps):
      x_train, _ = mnist.train.next_batch(batch_size)
      x_train = 2 * x_train.astype(np.float32) - 1
      z_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
      
      if j % d_train_interval == 0:
        sess.run(d_train,
                 feed_dict={x_data: x_train, z_prior: z_value, keep_prob: kp, learning_rate: lr})
      
      if j % g_train_interval == 0:
        sess.run(g_train,
                 feed_dict={x_data: x_train, z_prior: z_value, keep_prob: kp, learning_rate: lr})
        
      if j % log_interval == 0:
        g_l, d_l = sess.run([tf.reduce_mean(d_loss), tf.reduce_mean(g_loss)], 
                            feed_dict={x_data: mnist.validation.images[:batch_size], 
                                       z_prior: z_value, 
                                       keep_prob: 1})
        print('epoch: %d/%d, iteration: %d/%d, loss: %f (g_loss: %f, d_loss: %f)' % 
              (i, max_epoch, j, steps, g_l + d_l, g_l, d_l))

    x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_sample_val})
    show_result(x_gen_val, os.path.join(output_path, 'sample%s.jpg' % i))
    z_random_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
    x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_random_sample_val})
    show_result(x_gen_val, os.path.join(output_path, 'random_sample%s.jpg' % i))
    sess.run(tf.assign(global_step, i + 1))
    save_path = saver.save(sess, model_name, global_step=global_step)
    print('model saved: %s', save_path)

output/model-10
INFO:tensorflow:Restoring parameters from output/model-10
epoch: 10/20, iteration: 0/3000, loss: 16.176888 (g_loss: 14.402351, d_loss: 1.774536)
epoch: 10/20, iteration: 200/3000, loss: 16.041840 (g_loss: 14.263771, d_loss: 1.778069)
epoch: 10/20, iteration: 400/3000, loss: 16.600319 (g_loss: 14.743519, d_loss: 1.856800)
epoch: 10/20, iteration: 600/3000, loss: 16.164673 (g_loss: 14.401324, d_loss: 1.763348)
epoch: 10/20, iteration: 800/3000, loss: 16.019434 (g_loss: 14.198654, d_loss: 1.820780)
epoch: 10/20, iteration: 1000/3000, loss: 15.920187 (g_loss: 14.133940, d_loss: 1.786248)
epoch: 10/20, iteration: 1200/3000, loss: 16.137709 (g_loss: 14.297401, d_loss: 1.840307)
epoch: 10/20, iteration: 1400/3000, loss: 14.976318 (g_loss: 12.914227, d_loss: 2.062092)
epoch: 10/20, iteration: 1600/3000, loss: 13.588550 (g_loss: 11.722174, d_loss: 1.866376)
epoch: 10/20, iteration: 1800/3000, loss: 14.024271 (g_loss: 12.248671, d_loss: 1.775601)
epoch: 10/20, iteration: 2000/300