In [None]:
"""GPU setup"""
import os
# Expose only physical GPU 3 to this process. CUDA reads this variable when it
# initialises, so it has to be set before TensorFlow touches the GPU.
os.environ.update(CUDA_VISIBLE_DEVICES="3")

In [None]:
"""Imports, define/initialize model"""
import tensorflow as tf
# TF 1.x: switch to eager execution so ops run immediately and return concrete
# values (the training loop below relies on GradientTape and .numpy()).
# TensorFlow requires this at program startup, before other TF APIs are used.
tf.enable_eager_execution()

from matplotlib import pyplot as plt

from utils.data import mnist_eager
from utils.math import compute_mmd


# data
batch_size = 256    # images per training batch; also samples drawn per generated batch
train_steps = 1500  # step limit for the training loop
noise = 100         # length of each generator input (latent noise) vector
seed = 0            # global TensorFlow seed, for repeatable runs

# With eager execution enabled this seeds the eager context, so weight
# initialisation and noise sampling repeat from run to run. It is set before
# the dataset and the model are built so both pick it up.
# NOTE(review): whether it also fixes data shuffling depends on how
# mnist_eager seeds its pipeline, and some GPU kernels remain
# nondeterministic; confirm if bit-exact reruns matter.
tf.set_random_seed(seed)

# Presumably an iterable of batched (images, labels) pairs, since the training
# loop unpacks each element that way. Confirm in utils.data.
data = mnist_eager("data/mnist_train", batch_size)


def noise_fn(shape):
    """Draw generator input noise.

    Args:
        shape: shape of the noise tensor to sample, e.g. ``(batch, noise)``.

    Returns:
        A tensor of ``shape`` with entries drawn uniformly from [-1, 1)
        (TensorFlow's default float32 dtype).
    """
    low, high = -1, 1
    return tf.random_uniform(shape, minval=low, maxval=high)


# Generator MLP: three ReLU hidden layers of growing width, then a sigmoid
# output layer of 784 units (a flattened 28x28 image, matching the reshape
# used when plotting samples), so every output lies in (0, 1).
model = tf.keras.Sequential(
    [tf.keras.layers.Dense(units, tf.nn.relu) for units in (256, 512, 1024)]
    + [tf.keras.layers.Dense(784, tf.nn.sigmoid)]
)


def loss(imgs):
    """MMD-based generator loss for one batch of real images.

    Samples one noise vector per image in ``imgs``, runs them through
    ``model``, and compares the generated batch with ``imgs`` using
    ``compute_mmd`` with the parameter list ``[0.1]``.

    Args:
        imgs: batch of real images. The batch size is read from dimension 0.
            Presumably flattened to 784 features to match the model output;
            confirm against utils.data.mnist_eager.

    Returns:
        The square root of the value returned by ``compute_mmd``.
    """
    n_real = tf.shape(imgs)[0]
    fake = model(noise_fn((n_real, noise)))
    # NOTE(review): presumably compute_mmd returns a squared-MMD estimate
    # (hence the sqrt) and 0.1 is a kernel bandwidth; confirm in utils.math.
    # The gradient of sqrt grows without bound as its input approaches 0.
    return tf.sqrt(compute_mmd(imgs, fake, [0.1]))


# Adam with TF 1.x's documented default learning rate written out explicitly.
opt = tf.train.AdamOptimizer(learning_rate=0.001)

In [None]:
"""Train"""
# One optimizer update per batch drawn from `data`. The `>=` check stops after
# exactly `train_steps` updates (steps 0 .. train_steps - 1).
for step, (img_batch, _) in enumerate(data):
    if step >= train_steps:
        break

    with tf.GradientTape() as tape:
        mmd = loss(img_batch)
    # Differentiate with respect to the trainable weights only (for this
    # all-Dense model that is every weight) and apply the Adam update.
    grads = tape.gradient(mmd, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))

    if step % 50 == 0:
        print("Step", step)
        # .numpy() prints the plain value rather than the EagerTensor repr.
        print("Loss", mmd.numpy())


In [None]:
"""Generate stuff forever"""
# Show generated digits one at a time: press Enter for the next image, or type
# "q" and press Enter to stop. A fresh batch is sampled whenever one runs out.
keep_generating = True
while keep_generating:
    # Sample with the configured latent size `noise`, the same size the
    # training loss feeds the model.
    imgs = model(noise_fn((batch_size, noise)))
    for thing in imgs:
        fig, ax = plt.subplots()
        # The model emits 784 sigmoid outputs -> a 28x28 image in [0, 1].
        # Pinning the colour range keeps brightness comparable across samples.
        ax.imshow(thing.numpy().reshape((28, 28)), cmap="Greys_r", vmin=0, vmax=1)
        plt.show()
        plt.close(fig)  # don't accumulate open figures in an unbounded loop
        if input().strip().lower() == "q":
            keep_generating = False
            break