In [None]:
"""GPU setup"""
import os

# Restrict CUDA to the device with index 2; this is set before TensorFlow is
# imported in the following cell so the framework only sees that GPU.
os.environ.update(CUDA_VISIBLE_DEVICES="2")

In [None]:
"""Imports, define GAN"""
import tensorflow as tf
# TF 1.x: eager execution is switched on right after the import, before any
# other TensorFlow call is made, so the rest of the notebook runs imperatively.
tf.enable_eager_execution()

from utils.data import tfr_dataset_eager, parse_img_label_tfr
from utils.math import compute_mmd
from utils.models import gen_conv_mnist, gen_fc_mnist, enc_conv_mnist, enc_fc_mnist
from utils.viz import random_sample_grid


# experiment configuration
svhn = False      # False -> MNIST record files, True -> SVHN record files
n_critic = 5      # discriminator updates per generator update in the training loop
clip = 0.01       # handed to the discriminator constructor
dim_noise = 100   # length of each latent noise vector
batch_size = 128  # this is a "half batch"!
# The training loop counts discriminator updates, so the base step budget
# is multiplied by n_critic.
train_steps = n_critic * (5000 if svhn else 1500)

# dataset-specific record files and parser (image shape differs in channels only)
if not svhn:
    train_files = ["/cache/tfrs/mnist_train.tfr"]
    test_files = ["/cache/tfrs/mnist_test.tfr"]
    parse_fn = lambda x: parse_img_label_tfr(x, (32, 32, 1))
else:
    train_files = ["/cache/tfrs/svhn_train.tfr", "/cache/tfrs/svhn_extra.tfr"]
    test_files = ["/cache/tfrs/svhn_test.tfr"]
    parse_fn = lambda x: parse_img_label_tfr(x, (32, 32, 3))

# NOTE(review): `shufrep` presumably sets the shuffle buffer size (and enables
# repeating) inside tfr_dataset_eager -- confirm in utils.data.
data = tfr_dataset_eager(
    train_files,
    batch_size,
    parse_fn,
    shufrep=600000 if svhn else 60000,
)

conv = False  # True -> convolutional builders, False -> fully connected builders
# Pick the (discriminator, generator) builder pair once, then construct both
# networks with the same arguments regardless of architecture.
build_disc, build_gen = (
    (enc_conv_mnist, gen_conv_mnist) if conv else (enc_fc_mnist, gen_fc_mnist)
)
# NOTE(review): the leading 1 is presumably the encoder's output size
# (a single critic score per image) -- confirm in utils.models.
discriminator = build_disc(1, use_bn=True, clip=clip)
generator = build_gen(use_bn=True, channels=3 if svhn else 1)


def noise_fn(n_samples):
    """Sample a batch of latent vectors for the generator.

    Args:
        n_samples: number of noise vectors to draw (first dimension of the result).

    Returns:
        Tensor of shape [n_samples, dim_noise] drawn uniformly between -1 and 1.
    """
    latent_shape = [n_samples, dim_noise]
    return tf.random_uniform(latent_shape, minval=-1, maxval=1)


# Generator and discriminator each get their own RMSProp optimizer with the
# same small learning rate.
gen_opt, disc_opt = [
    tf.train.RMSPropOptimizer(learning_rate=0.00005) for _ in range(2)
]

In [None]:
"""Train WGAN"""
# Keras learning phase 1 = training mode for the layers of both networks.
tf.keras.backend.set_learning_phase(1)
for step, (img_batch, _) in enumerate(data):
    # `step` counts discriminator updates starting at 0, so stop once exactly
    # `train_steps` of them have run (comparing with `>` ran one extra update).
    if step >= train_steps:
        break

    # --- discriminator (critic) update ---
    # For batchnorm to work better, the real and the generated images are fed
    # through the discriminator as two separate batches rather than one mixed
    # batch; both means enter a single loss, so one gradient covers both.
    batch_dim = tf.shape(img_batch)[0]
    # Generated outside the tape: this update needs no generator gradients.
    gen_batch = generator(noise_fn(batch_dim))
    with tf.GradientTape() as d_tape:
        d_loss_real = tf.reduce_mean(discriminator(img_batch))
        d_loss_fake = tf.reduce_mean(discriminator(gen_batch))
        # Minimising this raises the score on real images and lowers it on
        # generated ones.
        d_loss = -d_loss_real + d_loss_fake
    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
    disc_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))
    # NOTE(review): weight clipping is presumably enforced inside the
    # discriminator via the `clip` constructor argument -- confirm in utils.models.

    # --- generator update, once every `n_critic` discriminator updates ---
    if (step + 1) % n_critic == 0:
        # fresh generated batch for generator training
        # Only the generator's variables are watched, so the tape does not
        # track the discriminator's weights.
        with tf.GradientTape(watch_accessed_variables=False) as g_tape:
            for vari in generator.trainable_variables:
                g_tape.watch(vari)
            # Twice the half-batch size: no real images share this update.
            gen_only_batch = generator(noise_fn(2*batch_dim))
            g_loss = -tf.reduce_mean(discriminator(gen_only_batch))
        g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
        gen_opt.apply_gradients(zip(g_grads, generator.trainable_variables))

    # periodic progress report
    if step % 50 == 0:
        print("Step", step)
        print("Disc Loss", d_loss)

# Back to inference mode once training is finished.
tf.keras.backend.set_learning_phase(0)

In [None]:
"""Generate samples"""
# Set the Keras learning phase to 0 (inference) explicitly, so this cell also
# works when it is run on its own after training.
tf.keras.backend.set_learning_phase(0)
# NOTE(review): random_sample_grid presumably draws noise with `noise_fn`, runs
# it through `generator` and arranges the outputs in an image grid -- confirm
# in utils.viz.
grid = random_sample_grid(generator, noise_fn)