In [None]:
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
import os

In [None]:
# Emit INFO-level TensorFlow log messages (TF 1.x `tf.logging` API), so Estimator
# training progress is visible in the cell output.
tf.logging.set_verbosity(tf.logging.INFO)

In [None]:
# Estimator checkpoint / summary directory, relative to the notebook's working directory.
# NOTE: the Estimator resumes from any checkpoint already saved here, so delete the
# directory for a from-scratch run.
MODEL_DIR = '../generated_output/GAN/'

In [None]:
# Training hyper-parameters: Adam learning rate (used by both the generator and
# discriminator optimizers), number of Estimator training steps, and examples per batch.
LEARNING_RATE, TRAINING_STEPS, BATCH_SIZE = 1e-4, 30000, 100

In [None]:
# Model dimensions.
IMAGE_DIM = 28 * 28  # a flattened 28x28 MNIST image
NOISE_DIM = 100      # length of the generator's input noise vector
# Hidden-layer widths for each network (one entry per hidden dense layer).
GEN_HIDDEN_DIM, DISC_HIDDEN_DIM = [256], [256]

In [None]:
class GAN():
    """Fully-connected vanilla GAN exposed as a `tf.estimator` model_fn plus an input_fn.

    The generator maps uniform noise in [-1, 1] to flattened images with values in
    (0, 1). The discriminator maps flattened images to the probability of being real.

    Args:
        image_dim: length of a flattened image vector.
        noise_dim: length of the generator's input noise vector.
        gen_hidden_dim: widths of the generator's hidden ReLU layers.
        disc_hidden_dim: widths of the discriminator's hidden ReLU layers.
    """

    # Added inside log() so a saturated discriminator output of exactly 0 or 1
    # gives a large finite loss instead of inf/NaN.
    _LOG_EPS = 1e-8

    def __init__(self, image_dim=IMAGE_DIM, noise_dim=NOISE_DIM, gen_hidden_dim=GEN_HIDDEN_DIM, disc_hidden_dim=DISC_HIDDEN_DIM):
        self.image_dim = image_dim
        self.noise_dim = noise_dim
        # Copy so the instance never aliases the module-level default lists.
        self.gen_hidden_dim = list(gen_hidden_dim)
        self.disc_hidden_dim = list(disc_hidden_dim)
        # Recorded by `batch()` for backward compatibility; the model itself no longer reads it.
        self.batch_size = None

    @staticmethod
    def _dense_stack(net, hidden_dims, output_dim):
        """Apply He-initialised ReLU dense layers of widths `hidden_dims`, then a sigmoid layer of width `output_dim`."""
        for units in hidden_dims:
            net = tf.layers.dense(net, units=units, activation=tf.nn.relu, kernel_initializer=tf.initializers.he_normal())
        return tf.layers.dense(net, output_dim, activation=tf.nn.sigmoid, kernel_initializer=tf.initializers.he_normal())

    def _disc_model(self, features):
        """Discriminator: flattened images -> probability of being real, shape [batch, 1].

        AUTO_REUSE makes the real-image and fake-image calls share one set of weights.
        """
        with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
            return self._dense_stack(features, self.disc_hidden_dim, 1)

    def _gen_model(self, features):
        """Generator: noise vectors -> flattened images of width `image_dim`, values in (0, 1)."""
        with tf.variable_scope('generator', reuse=tf.AUTO_REUSE):
            return self._dense_stack(features, self.gen_hidden_dim, self.image_dim)

    def gan_model_fn(self, features, labels, mode, params):
        """Estimator model_fn: sampling in PREDICT mode, one joint G/D update per step in TRAIN mode.

        Args:
            features: PREDICT - batch of noise vectors of width `noise_dim`;
                TRAIN - batch of real flattened images of width `image_dim`.
            labels: unused (the input_fn yields features only).
            mode: a `tf.estimator.ModeKeys` value; only PREDICT and TRAIN are supported.
            params: optional dict; `params['learning_rate']` overrides LEARNING_RATE.

        Returns:
            A `tf.estimator.EstimatorSpec`.

        Raises:
            AssertionError: for any mode other than PREDICT or TRAIN.
        """
        if mode == tf.estimator.ModeKeys.PREDICT:
            output_image = self._gen_model(features)
            return tf.estimator.EstimatorSpec(mode, predictions=output_image)

        assert mode == tf.estimator.ModeKeys.TRAIN, 'GAN.gan_model_fn supports only TRAIN and PREDICT modes'
        learning_rate = (params or {}).get('learning_rate', LEARNING_RATE)

        real_image = features
        # Take the noise batch size from the real batch at run time instead of relying on
        # `self.batch_size` having been set as a side effect of calling `batch()` first.
        batch_size = tf.shape(real_image)[0]
        fake_noise = tf.random.uniform(shape=[batch_size, self.noise_dim], minval=-1., maxval=1., dtype=tf.float32)
        fake_image = self._gen_model(fake_noise)
        disc_real = self._disc_model(real_image)
        disc_fake = self._disc_model(fake_image)

        eps = self._LOG_EPS
        disc_loss = -tf.reduce_mean(tf.log(disc_real + eps) + tf.log(1. - disc_fake + eps))
        # Non-saturating generator loss: maximise log D(G(z)).
        gen_loss = -tf.reduce_mean(tf.log(disc_fake + eps))

        # Fraction of generated images the discriminator labels as fake (target label 0).
        accuracy = tf.metrics.accuracy(labels=tf.zeros_like(disc_fake),
                                       predictions=tf.cast(disc_fake > 0.5, tf.float32),
                                       name='acc_op')
        tf.summary.scalar('accuracy', accuracy[1])
        tf.summary.scalar('loss_gen', gen_loss)
        tf.summary.scalar('loss_disc', disc_loss)

        disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
        gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
        optimizer_disc = tf.train.AdamOptimizer(learning_rate=learning_rate)
        optimizer_gen = tf.train.AdamOptimizer(learning_rate=learning_rate)
        disc_train_op = optimizer_disc.minimize(disc_loss, var_list=disc_vars)
        # Only ONE of the two updates advances the global step. Passing it to both made
        # every iteration count as 2 Estimator steps, halving the real training length.
        gen_train_op = optimizer_gen.minimize(gen_loss, var_list=gen_vars, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode, loss=gen_loss + disc_loss, train_op=tf.group(disc_train_op, gen_train_op))

    def batch(self, features, batch_size, is_training):
        """Build a `tf.data` input pipeline over an in-memory array.

        Also records `batch_size` on the instance (kept for backward compatibility).

        Args:
            features: array whose first axis indexes examples.
            batch_size: examples per batch.
            is_training: if truthy, shuffle and repeat indefinitely. Otherwise make a single
                pass in the original order, so predictions line up with their inputs.

        Returns:
            A batched `tf.data.Dataset`.
        """
        self.batch_size = batch_size
        dataset = tf.data.Dataset.from_tensor_slices(features)
        if is_training:
            return dataset.shuffle(features.shape[0]).repeat(count=None).batch(batch_size)
        return dataset.repeat(count=1).batch(batch_size)

In [None]:
def random_image_plot(estimator, input_fn, seed=None, num_images=1):
    """Sample images from the generator using uniform noise in [-1, 1] and display each one.

    Args:
        estimator: estimator whose PREDICT output for each noise row is a flattened square image.
        input_fn: callable `(features, is_training=..., batch_size=...)` returning a dataset,
            e.g. `GAN.batch`.
        seed: seed for the noise. None draws fresh noise on every call.
        num_images: number of images to sample and plot.
    """
    # A local RandomState gives the same samples as seeding the global RNG would,
    # without reseeding numpy's global generator as a side effect.
    rng = np.random.RandomState(seed)
    random_noise = rng.uniform(-1., 1., size=[num_images, NOISE_DIM]).astype(np.float32)
    predictions = estimator.predict(
        input_fn=lambda: input_fn(random_noise, is_training=False, batch_size=num_images))
    for idx, flat_image in enumerate(predictions):
        # Infer the square side length (28 for 784-pixel MNIST); reshape raises if not square.
        side = int(round(flat_image.size ** 0.5))
        fig, ax = plt.subplots()
        ax.imshow(flat_image.reshape([side, side]), cmap='gray_r')
        ax.set_title('Generated sample %d' % idx)
        ax.axis('off')
        plt.show()

In [None]:
# Build the GAN with the dimensions from the configuration cell (spelled out explicitly).
gan_model = GAN(
    image_dim=IMAGE_DIM,
    noise_dim=NOISE_DIM,
    gen_hidden_dim=GEN_HIDDEN_DIM,
    disc_hidden_dim=DISC_HIDDEN_DIM,
)

In [None]:
# Estimator wrapping the GAN model_fn. A fixed graph-level seed makes weight
# initialisation, noise sampling and dataset shuffling repeatable across
# Restart & Run All (GPU kernels may still add some nondeterminism).
gan_estimator = tf.estimator.Estimator(
    model_fn=gan_model.gan_model_fn,
    model_dir=MODEL_DIR,
    config=tf.estimator.RunConfig(tf_random_seed=42),
)

In [None]:
# Load the MNIST training images (labels and test split are unused), scale to [0, 1]
# and flatten. Casting to float32 first avoids a full-size float64 intermediate.
(mnist_train_images, _), _ = tf.keras.datasets.mnist.load_data()
x_train = (mnist_train_images.astype(np.float32) / 255.).reshape([-1, IMAGE_DIM])
assert x_train.dtype == np.float32 and x_train.shape[1] == IMAGE_DIM
assert x_train.min() >= 0. and x_train.max() <= 1.

In [None]:
# `max_steps` (not `steps`) makes this cell idempotent: training stops once the global
# step saved in MODEL_DIR reaches TRAINING_STEPS. With `steps`, re-running would add
# another TRAINING_STEPS on top of any existing checkpoint.
gan_estimator.train(input_fn=lambda: gan_model.batch(x_train, BATCH_SIZE, is_training=True), max_steps=TRAINING_STEPS)

In [None]:
# Fixed seed so the displayed sample is the same on every run of the notebook.
random_image_plot(gan_estimator, gan_model.batch, seed=0)