In [1]:
import os
import tensorflow as tf
import numpy as np
import cv2
import random
import scipy.misc
from utils import *

  from ._conv import register_converters as _register_converters


In [9]:
# Training schedule.
BATCH_SIZE = 64   # images per discriminator/generator step
EPOCH = 5000      # number of passes of the outer training loop

# Image geometry expected by the real-image placeholder: MNIST digits, single channel.
HEIGHT = 28
WIDTH = 28
CHANNEL = 1

In [16]:
def process_data(batch_size=None):
    """Load MNIST and return one random batch of training images plus the training-set size.

    Args:
        batch_size: number of images to sample (with replacement). Defaults to the
            module-level ``BATCH_SIZE`` when None.

    Returns:
        (image_batch, num_images):
            image_batch -- float32 array of shape [batch_size, H, W, 1], scaled to [-1, 1]
                so real images live in the same range as the generator's tanh output.
            num_images -- number of images in the MNIST training split.
    """
    # Local import: keras is only needed here, to fetch the dataset.
    from keras.datasets import mnist

    if batch_size is None:
        batch_size = BATCH_SIZE

    # Only the training images are used; labels and the test split are discarded.
    (X_train, _), _ = mnist.load_data()

    picks = np.random.randint(0, X_train.shape[0], size=batch_size)
    image_batch = X_train[picks].astype(np.float32)
    # Assumes uint8 pixels in [0, 255] (per keras mnist.load_data docs); map them to [-1, 1].
    # Without this, the discriminator sees real and fake images in wildly different ranges.
    image_batch = (image_batch - 127.5) / 127.5
    # Append the trailing channel axis expected by the [None, H, W, C] placeholder.
    image_batch = image_batch[..., np.newaxis]
    return image_batch, len(X_train)


In [6]:
def generator(input, random_dim, is_train, reuse=False):
    """Build the generator network mapping noise vectors to 28x28x1 images.

    Architecture:
        dense(1024) -> dense(7*7*128) + batch norm -> reshape to 7x7x128
        -> upsample x2 -> conv(64) -> upsample x2 -> conv(1)

    Args:
        input: float32 noise tensor, assumed shape [batch, random_dim].
        random_dim: length of each noise vector (rows of the first weight matrix).
        is_train: bool flag/tensor forwarded to batch norm's ``is_training``.
        reuse: when True, reuse the variables already created under the 'gen' scope.

    Returns:
        Tensor of shape [batch, 28, 28, 1] with a final tanh activation, so values are in (-1, 1).
    """
    with tf.variable_scope('gen') as scope:
        if reuse:
            scope.reuse_variables()

        # Fully connected projection of the noise vector: random_dim -> 1024.
        w1 = tf.get_variable('w1', shape=[random_dim, 1024], dtype=tf.float32,
                             initializer=tf.contrib.layers.xavier_initializer())
        b1 = tf.get_variable('b1', shape=[1024], dtype=tf.float32,
                             initializer=tf.constant_initializer(1.0))
        flat_conv1 = tf.add(tf.matmul(input, w1), b1, name='flat_conv1')
        act1 = tf.nn.tanh(flat_conv1, name='act1')

        # Project to a flattened 7x7x128 feature map. Layers get explicit names so that
        # reuse=True finds the same variables instead of auto-uniquified new ones.
        dense1 = tf.layers.dense(inputs=act1, units=128 * 7 * 7, activation=tf.nn.tanh,
                                 kernel_initializer=tf.contrib.layers.xavier_initializer(),
                                 bias_initializer=tf.ones_initializer(), name='dense1')
        bn1 = tf.contrib.layers.batch_norm(dense1, is_training=is_train, epsilon=1e-5, decay=0.9,
                                           updates_collections=None, scope='bn1')
        act2 = tf.nn.tanh(bn1, name='act2')

        conv1 = tf.reshape(act2, shape=[-1, 7, 7, 128], name='conv1')
        # The UpSampling2D layer must be *called* on a tensor; passing the layer object
        # itself to conv2d fails. 7x7 -> 14x14.
        up1 = tf.keras.layers.UpSampling2D(size=(2, 2))(conv1)

        conv2 = tf.layers.conv2d(inputs=up1, filters=64, kernel_size=[5, 5], padding="same",
                                 activation=tf.nn.tanh,
                                 kernel_initializer=tf.contrib.layers.xavier_initializer(),
                                 bias_initializer=tf.ones_initializer(), name='conv2')
        # 14x14 -> 28x28.
        up2 = tf.keras.layers.UpSampling2D(size=(2, 2))(conv2)
        conv3 = tf.layers.conv2d(inputs=up2, filters=1, kernel_size=[5, 5], padding="same",
                                 activation=tf.nn.tanh,
                                 kernel_initializer=tf.contrib.layers.xavier_initializer(),
                                 bias_initializer=tf.ones_initializer(), name='conv3')

        return conv3

In [7]:
def discriminator(input, is_train, reuse=False):
    """Build the WGAN critic that scores a batch of images.

    Architecture:
        conv(64, stride 2) + BN + tanh -> conv(128, stride 2) + BN + tanh
        -> flatten -> dense(1024, relu) -> dense(1, sigmoid)

    Args:
        input: image tensor, assumed shape [batch, H, W, C] (the real-image placeholder
            or the generator output).
        is_train: bool flag/tensor forwarded to batch norm's ``is_training``.
        reuse: when True, reuse the variables created by an earlier call under 'dis'.
            train() relies on this to score real and fake images with shared weights.

    Returns:
        Tensor of shape [batch, 1] holding sigmoid outputs.
    """
    with tf.variable_scope('dis') as scope:
        if reuse:
            scope.reuse_variables()

        # Convolution -> batch norm -> activation, repeated twice.
        conv1 = tf.layers.conv2d(input, 64, kernel_size=[5, 5], strides=[2, 2], padding="SAME",
                                 kernel_initializer=tf.contrib.layers.xavier_initializer(),
                                 name='conv1', bias_initializer=tf.ones_initializer())
        bn1 = tf.contrib.layers.batch_norm(conv1, is_training=is_train, epsilon=1e-5, decay=0.9,
                                           updates_collections=None, scope='bn1')
        # Feed the normalized output forward. Previously bn1 was computed but discarded,
        # and the call used an invalid `n=` keyword.
        act1 = tf.nn.tanh(bn1, name='act1')

        conv2 = tf.layers.conv2d(act1, 128, kernel_size=[5, 5], strides=[2, 2], padding="SAME",
                                 kernel_initializer=tf.contrib.layers.xavier_initializer(),
                                 name='conv2', bias_initializer=tf.ones_initializer())
        bn2 = tf.contrib.layers.batch_norm(conv2, is_training=is_train, epsilon=1e-5, decay=0.9,
                                           updates_collections=None, scope='bn2')
        act2 = tf.nn.tanh(bn2, name='act2')

        # Flatten the feature map for the fully connected head.
        dim = int(np.prod(act2.get_shape()[1:]))
        fc1 = tf.reshape(act2, shape=[-1, dim], name='fc1')

        # Explicit names are required: with reuse=True, unnamed tf.layers get uniquified
        # names (dense_2, ...) that do not exist yet, so the second call would fail.
        dense1 = tf.layers.dense(inputs=fc1, units=1024, activation=tf.nn.relu,
                                 kernel_initializer=tf.contrib.layers.xavier_initializer(),
                                 bias_initializer=tf.ones_initializer(), name='dense1')
        # NOTE(review): WGAN critics are normally linear (no sigmoid); the sigmoid bounds the
        # score and may weaken the Wasserstein estimate. Kept to preserve the output range.
        dense2 = tf.layers.dense(inputs=dense1, units=1, activation=tf.nn.sigmoid,
                                 kernel_initializer=tf.contrib.layers.xavier_initializer(),
                                 bias_initializer=tf.ones_initializer(), name='dense2')

    return dense2

In [None]:
def train(version='mnist_wgan', sample_path=None):
    """Build the WGAN graph and train it, checkpointing and sampling periodically.

    Each inner step runs 5 critic updates, each preceded by weight clipping to
    [-0.01, 0.01], followed by 1 generator update. Both use RMSProp with lr 2e-4.

    Every 500 epochs a checkpoint is saved under ./model/<version>. Every 50 epochs an
    8x8 grid of generated samples is written via save_images (from utils) and the latest
    losses are printed.

    Args:
        version: checkpoint sub-directory name under ./model. Training resumes from the
            latest checkpoint found there, if any.
        sample_path: directory for sample image grids. Defaults to './' + version.
    """
    random_dim = 100
    if sample_path is None:
        sample_path = './' + version
    model_dir = './model/' + version
    # CUDA_VISIBLE_DEVICES may be unset; indexing os.environ directly would raise KeyError.
    print(os.environ.get('CUDA_VISIBLE_DEVICES'))

    with tf.variable_scope('input'):
        # Real and fake image placeholders.
        real_image = tf.placeholder(tf.float32, shape=[None, HEIGHT, WIDTH, CHANNEL], name='real_image')
        random_input = tf.placeholder(tf.float32, shape=[None, random_dim], name='rand_input')
        is_train = tf.placeholder(tf.bool, name='is_train')

    fake_image = generator(random_input, random_dim, is_train)
    real_result = discriminator(real_image, is_train)
    # Second critic call shares weights with the first.
    fake_result = discriminator(fake_image, is_train, reuse=True)

    # WGAN objectives: the critic widens the gap E[D(real)] - E[D(fake)];
    # the generator raises E[D(fake)].
    d_loss = tf.reduce_mean(fake_result) - tf.reduce_mean(real_result)
    g_loss = -tf.reduce_mean(fake_result)

    t_vars = tf.trainable_variables()
    # Match on the scope prefix rather than a substring anywhere in the name.
    d_vars = [var for var in t_vars if var.name.startswith('dis/')]
    g_vars = [var for var in t_vars if var.name.startswith('gen/')]
    trainer_d = tf.train.RMSPropOptimizer(learning_rate=2e-4).minimize(d_loss, var_list=d_vars)
    trainer_g = tf.train.RMSPropOptimizer(learning_rate=2e-4).minimize(g_loss, var_list=g_vars)
    # WGAN weight clipping keeps the critic (approximately) Lipschitz.
    d_clip = [v.assign(tf.clip_by_value(v, -0.01, 0.01)) for v in d_vars]

    batch_size = BATCH_SIZE
    image_batch, samples_num = process_data()
    # process_data returns a NumPy array, not a tensor, so feed it directly
    # (sess.run on it fails). Reshape/cast to the placeholder layout in case the
    # channel axis is missing.
    train_image = np.reshape(image_batch, [-1, HEIGHT, WIDTH, CHANNEL]).astype(np.float32)
    # NOTE(review): process_data samples a single batch, so every step below trains on
    # the same images. Consider sampling a fresh batch per step.

    batch_num = samples_num // batch_size
    sess = tf.Session()
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    # Continue training from the most recent checkpoint when one exists.
    ckpt = tf.train.latest_checkpoint(model_dir)
    if ckpt:
        saver.restore(sess, ckpt)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    print('total training sample num:%d' % samples_num)
    print('batch size: %d, batch num per epoch: %d, epoch num: %d' % (batch_size, batch_num, EPOCH))
    print('start training...')

    d_iters = 5
    g_iters = 1
    # Defined up front so the periodic report works even if batch_num == 0.
    dLoss = gLoss = float('nan')
    try:
        for i in range(EPOCH):
            for _batch in range(batch_num):
                train_noise = np.random.uniform(-1.0, 1.0, size=[batch_size, random_dim]).astype(np.float32)

                # Update the critic several times per generator step.
                for _d_step in range(d_iters):
                    sess.run(d_clip)
                    _, dLoss = sess.run([trainer_d, d_loss],
                                        feed_dict={random_input: train_noise, real_image: train_image, is_train: True})

                # Update the generator.
                for _g_step in range(g_iters):
                    _, gLoss = sess.run([trainer_g, g_loss],
                                        feed_dict={random_input: train_noise, is_train: True})

            # Save a checkpoint every 500 epochs.
            if i % 500 == 0:
                os.makedirs(model_dir, exist_ok=True)
                saver.save(sess, model_dir + '/' + str(i))
            # Save a sample grid every 50 epochs.
            if i % 50 == 0:
                os.makedirs(sample_path, exist_ok=True)
                sample_noise = np.random.uniform(-1.0, 1.0, size=[batch_size, random_dim]).astype(np.float32)
                imgtest = sess.run(fake_image, feed_dict={random_input: sample_noise, is_train: False})
                save_images(imgtest, [8, 8], sample_path + '/epoch' + str(i) + '.jpg')
                print('train:[%d],d_loss:%f,g_loss:%f' % (i, dLoss, gLoss))
    finally:
        # Previously coord.join sat outside the function (NameError) and the session leaked.
        coord.request_stop()
        coord.join(threads)
        sess.close()