In [None]:
"""GPU setup"""
import os
# Restrict CUDA-based libraries in this process to the GPU with index 3.
# This cell is placed before the cell that imports TensorFlow, so the variable
# is already in the environment when TensorFlow is loaded.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

In [None]:
"""Imports, define/initialize model"""
import tensorflow as tf
# Switch to eager mode immediately after the import, before any other
# TensorFlow call in this notebook. Later cells rely on it: they iterate over
# the dataset in Python, use `tf.GradientTape`, and call `.numpy()` on tensors.
tf.enable_eager_execution()

from matplotlib import pyplot as plt

from utils.data import mnist_eager
from utils.math import compute_mmd
from utils.models import gen_conv_mnist, gen_fc_mnist


# data
# Batch size passed to the data loader; also used as the number of samples
# drawn per round in the generation cell.
batch_size = 256
# The training loop stops once the 0-based step index exceeds this value.
train_steps = 1500
# Size of the second axis of the noise tensor fed to the model.
noise = 100

# NOTE(review): `mnist_eager` is a project helper. The training loop unpacks
# each element it yields into two values (image batch, ignored second item);
# presumably these are images and labels — confirm in utils.data.
data = mnist_eager("data/mnist_train", batch_size)


def noise_fn(shape):
    """Sample a tensor of the requested shape, uniform between -1 and 1.

    Args:
        shape: Shape of the tensor to sample, forwarded to `tf.random_uniform`.

    Returns:
        The tensor produced by `tf.random_uniform(shape, minval=-1, maxval=1)`.
    """
    low, high = -1, 1
    return tf.random_uniform(shape, minval=low, maxval=high)


# Pick the generator builder: `gen_conv_mnist` when `conv` is truthy,
# `gen_fc_mnist` otherwise. `use_bn=True` is passed in both cases.
conv = False
model = (gen_conv_mnist if conv else gen_fc_mnist)(use_bn=True)


def loss(imgs):
    """Return the square root of `compute_mmd` between `imgs` and generated samples.

    A noise tensor with one row per element along the first axis of `imgs` and
    `noise` columns is drawn with `noise_fn` and passed through `model`.

    Args:
        imgs: Batch of reference images; only its first-axis size is read here
            before it is handed to `compute_mmd`.

    Returns:
        `tf.sqrt` of the value returned by `compute_mmd`.
    """
    n_samples = tf.shape(imgs)[0]
    latent = noise_fn((n_samples, noise))
    generated = model(latent)
    # NOTE(review): presumably kernel bandwidths/scales for the MMD estimate —
    # confirm against utils.math.compute_mmd.
    scales = [0.01, 0.03, 0.1, 0.3, 1.]
    squared = compute_mmd(imgs, generated, scales)
    return tf.sqrt(squared)


# Adam optimizer with its default hyperparameters (no arguments are passed).
opt = tf.train.AdamOptimizer()

In [None]:
"""Train"""
# Put Keras layers into training mode for the duration of the loop.
tf.keras.backend.set_learning_phase(1)
try:
    for step, (img_batch, _) in enumerate(data):
        # Steps are 0-based and the loop stops once `step` exceeds
        # `train_steps`, so up to `train_steps + 1` updates are applied
        # (this also lets the final multiple-of-50 step be logged).
        if step > train_steps:
            break

        with tf.GradientTape() as tape:
            mmd = loss(img_batch)

        # Differentiate only with respect to trainable variables. Non-trainable
        # ones would just produce `None` gradients that the optimizer skips.
        # NOTE(review): assumes `model` is a tf.keras model/layer exposing
        # `trainable_variables`.
        train_vars = model.trainable_variables
        grads = tape.gradient(mmd, train_vars)
        opt.apply_gradients(zip(grads, train_vars))

        # Log every 50th step, including step 0.
        if not step % 50:
            print("Step", step)
            print("Loss", mmd)
finally:
    # Always switch back to inference mode, even when training raises or is
    # interrupted, so later cells do not run the model in training mode.
    tf.keras.backend.set_learning_phase(0)

In [None]:
"""Generate stuff forever"""
# Show generated samples one at a time. Press Enter for the next image, or
# type "q" followed by Enter to stop the loop without interrupting the kernel.
keep_going = True
while keep_going:
    # Use the shared `noise` width (the same one `loss` uses during training)
    # rather than a hard-coded 100, so changing `noise` cannot desynchronise
    # training and generation.
    imgs = model(noise_fn((batch_size, noise)))
    for thing in imgs:
        # Each sample is reshaped to a 32x32 image, i.e. the model output is
        # expected to hold 1024 values per sample. Pixel range is fixed to [0, 1].
        plt.imshow(thing.numpy().reshape((32, 32)), cmap="Greys_r", vmin=0, vmax=1)
        plt.show()
        if input().strip().lower() == "q":
            keep_going = False
            break