In [None]:
"""GPU setup"""
import os

# Expose only physical GPU 2 to this process. CUDA reads this variable when it
# initialises, so it must be set before TensorFlow touches the GPU — i.e. this
# cell has to run before the imports cell.
os.environ.update({"CUDA_VISIBLE_DEVICES": "2"})

In [None]:
"""Imports"""
import tensorflow as tf
import numpy as np
tf.enable_eager_execution()

from matplotlib import pyplot as plt

from utils.data import mnist_eager
from utils.math import compute_mmd


# data
batch_size = 128   # images per training batch (also samples per viewer round)
train_steps = 1500 # number of GAN training iterations (see training loop)
dim_noise = 100    # length of the generator's input noise vector

# Seed TensorFlow and NumPy before any model/noise/data is created so weight
# initialisation, sampled noise and (if it draws on these global seeds) the
# input pipeline are reproducible across Restart & Run All.
SEED = 42
tf.set_random_seed(SEED)
np.random.seed(SEED)

# Project helper; the training loop unpacks each element as (images, labels).
data = mnist_eager("data/mnist_train", batch_size)

In [None]:
"""Define GAN"""
# Discriminator: ReLU MLP (512 -> 256) ending in a single linear unit, i.e. it
# outputs an unnormalised logit (fed to a sigmoid cross-entropy loss later).
discriminator = tf.keras.Sequential(
    [tf.keras.layers.Dense(width, tf.nn.relu) for width in (512, 256)]
    + [tf.keras.layers.Dense(1)])

# Generator: ReLU MLP (256 -> 512 -> 1024) ending in 784 sigmoid units, so every
# output value lies in (0, 1); 784 = 28 * 28 flattened image pixels.
generator = tf.keras.Sequential(
    [tf.keras.layers.Dense(width, tf.nn.relu) for width in (256, 512, 1024)]
    + [tf.keras.layers.Dense(784, tf.nn.sigmoid)])


def noise_fn(shape):
    """Sample generator input noise.

    Args:
        shape: output shape, e.g. ``[batch, dim_noise]`` (list or tuple; elements
            may be Python ints or scalar tensors).

    Returns:
        A float32 tensor of ``shape`` with values drawn uniformly from [-1, 1).
    """
    low, high = -1, 1
    return tf.random_uniform(shape, minval=low, maxval=high)


def partial_loss(logits, lbls):
    """Sigmoid cross-entropy between discriminator logits and 0/1 targets.

    Used for both GAN players: the discriminator trains on real=1 / fake=0
    labels, and the generator trains with all-ones labels on its own samples.

    Args:
        logits: raw (pre-sigmoid) discriminator outputs.
        lbls: target labels with the same shape as ``logits``.

    Returns:
        Scalar loss tensor (averaged over elements under the default
        ``tf.losses`` reduction).
    """
    loss = tf.losses.sigmoid_cross_entropy(logits=logits, multi_class_labels=lbls)
    return loss


# One Adam optimizer per network (TF default hyperparameters), so each keeps its
# own moment estimates for the variables it updates.
gen_opt, disc_opt = tf.train.AdamOptimizer(), tf.train.AdamOptimizer()

In [None]:
"""Train GAN"""
# TODO only "watch" variables of interest in each tape
for step, (img_batch, _) in enumerate(data):
    # `step` is 0-based, so stopping once it reaches `train_steps` performs
    # exactly `train_steps` update rounds.
    if step >= train_steps:
        break

    # --- Discriminator update ---
    # Real images are labelled 1, generated images 0. The fake batch is produced
    # outside the tape because only discriminator gradients are needed here.
    # The batch size is read from the actual batch so a smaller batch still
    # lines up with its labels.
    batch_dim = tf.shape(img_batch)[0]
    gen_batch = generator(noise_fn([batch_dim, dim_noise]))
    full_batch = tf.concat([img_batch, gen_batch], axis=0)
    full_labels = tf.concat([tf.ones([batch_dim, 1]), tf.zeros([batch_dim, 1])], axis=0)
    with tf.GradientTape() as d_tape:
        d_loss = partial_loss(discriminator(full_batch), full_labels)
    # Differentiate only w.r.t. trainable weights; non-trainable state (e.g.
    # normalisation statistics) would otherwise receive None gradients.
    d_vars = discriminator.trainable_variables
    d_grads = d_tape.gradient(d_loss, d_vars)
    disc_opt.apply_gradients(zip(d_grads, d_vars))

    # --- Generator update ---
    # Fresh noise, generated inside the tape so gradients flow into the
    # generator. Its samples are scored against "real" (1) labels, i.e. the
    # non-saturating generator loss.
    with tf.GradientTape() as g_tape:
        gen_batch = generator(noise_fn([batch_dim, dim_noise]))
        g_loss = partial_loss(discriminator(gen_batch), tf.ones([batch_dim, 1]))
    g_vars = generator.trainable_variables
    g_grads = g_tape.gradient(g_loss, g_vars)
    gen_opt.apply_gradients(zip(g_grads, g_vars))

    if step % 50 == 0:
        # .numpy() prints the scalar value instead of the EagerTensor repr.
        print("Step", step)
        print("Gen Loss", g_loss.numpy())
        print("Disc Loss", d_loss.numpy())


In [None]:
"""Generate samples until the user quits"""
# Show generated digits one at a time: press Enter for the next sample, or type
# "q" and press Enter to stop. Without an exit, the loop could only be ended by
# interrupting the kernel.
keep_going = True
while keep_going:
    # A fresh batch of `batch_size` samples from new noise each round.
    imgs = generator(noise_fn((batch_size, dim_noise)))
    for thing in imgs:
        # Generator outputs are 784 sigmoid values in (0, 1); reshape to 28x28
        # and pin the grey scale to [0, 1] so brightness is comparable.
        plt.imshow(thing.numpy().reshape((28, 28)), cmap="Greys_r", vmin=0, vmax=1)
        plt.show()
        if input().strip().lower() == "q":
            keep_going = False
            break