In [None]:
"""GPU setup"""
import os
# Restrict CUDA to device index 2. This cell runs before TensorFlow is
# imported in the next cell, so the setting is in place when TF starts up.
os.environ.update(CUDA_VISIBLE_DEVICES="2")

In [None]:
"""Imports, define GAN"""
import tensorflow as tf
# Eager mode is switched on immediately after the TensorFlow import and
# before the project helpers below are imported. Keep this ordering: the
# rest of the notebook (Python-side loops over the dataset, GradientTape)
# relies on eager execution.
tf.enable_eager_execution()

# Project-local helpers: dataset iterator, model builders and visualisation.
from utils.data import mnist_eager
# NOTE(review): compute_mmd is not referenced in the visible cells.
from utils.math import compute_mmd
from utils.models import gen_conv_mnist, gen_fc_mnist, enc_conv_mnist, enc_fc_mnist
from utils.viz import random_sample_grid


# Training hyperparameters
n_critic = 5      # discriminator updates performed per generator update
dim_noise = 100   # size of the noise vector fed to the generator
penalty = 10      # weight of the gradient-penalty term in the critic loss
batch_size = 128  # this is a "half batch"!

# The training loop counts discriminator steps, so the desired number of
# generator updates (1500) is scaled by n_critic to get the loop length.
train_steps = 1500 * n_critic

# Iterable consumed by the training loop; each element is unpacked there as
# (image batch, second item that the loop ignores).
data = mnist_eager("data/mnist_train", batch_size)

# Architecture switch: True selects the convolutional pair, False the fully
# connected pair. In both cases the discriminator is built first, with
# use_bn=False, and the generator second, with use_bn=True.
# NOTE(review): the leading `1` is presumably the critic's output size - confirm
# against utils.models.
conv = False
if not conv:
    discriminator = enc_fc_mnist(1, use_bn=False)
    generator = gen_fc_mnist(use_bn=True)
else:
    discriminator = enc_conv_mnist(1, use_bn=False)
    generator = gen_conv_mnist(use_bn=True)


def noise_fn(n_samples):
    """Draw generator input noise.

    Args:
        n_samples: number of noise vectors to draw (Python int or scalar
            integer tensor).

    Returns:
        Tensor of shape [n_samples, dim_noise] with entries sampled
        uniformly between -1 and 1.
    """
    noise_shape = [n_samples, dim_noise]
    return tf.random_uniform(noise_shape, minval=-1, maxval=1)


def partial_loss(logits, lbls):
    """Sigmoid cross-entropy of `logits` against the targets `lbls`.

    Thin wrapper around tf.losses.sigmoid_cross_entropy, passing `lbls` as
    its `multi_class_labels` argument.
    NOTE(review): not called anywhere in the visible cells; the training
    loop uses a Wasserstein loss instead.
    """
    xent = tf.losses.sigmoid_cross_entropy(logits=logits,
                                           multi_class_labels=lbls)
    return xent


# Two independent Adam optimizers with identical settings: the first one
# created is for the generator, the second for the discriminator.
gen_opt, disc_opt = (
    tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.0, beta2=0.9)
    for _ in range(2)
)

In [None]:
"""Train WGAN"""
# Run the models in training mode for the duration of the loop. The
# try/finally guarantees the learning phase is switched back to inference
# even if training is interrupted (e.g. a manual kernel interrupt), so the
# sampling cell below never runs the generator in training mode by accident.
tf.keras.backend.set_learning_phase(1)
try:
    for step, (img_batch, _) in enumerate(data):
        # `>=` so that exactly `train_steps` critic iterations are executed
        # (steps 0 .. train_steps - 1); `>` would run one extra critic step.
        if step >= train_steps:
            break

        # Real and generated images are passed through the discriminator
        # separately (never concatenated into one mixed batch). The generated
        # batch is created outside the tape: the critic update does not need
        # gradients with respect to the generator.
        batch_dim = tf.shape(img_batch)[0]
        gen_batch = generator(noise_fn(batch_dim))

        # One interpolation weight per sample, broadcast over all remaining
        # image axes. The shape is derived from the rank of the batch so that
        # both [batch, H, W, C] and flattened [batch, D] images are handled.
        interp_shape = [batch_dim] + [1] * (len(img_batch.shape) - 1)
        interps = tf.random_uniform(interp_shape, minval=0, maxval=1)
        interp_batch = interps * img_batch + (1 - interps) * gen_batch

        with tf.GradientTape() as d_tape:
            d_loss_real = tf.reduce_mean(discriminator(img_batch))
            d_loss_fake = tf.reduce_mean(discriminator(gen_batch))

            # Inner tape: gradient of the critic output with respect to the
            # interpolated inputs. It is evaluated inside d_tape so that the
            # penalty itself can be differentiated w.r.t. the critic weights.
            with tf.GradientTape(watch_accessed_variables=False) as p_tape:
                p_tape.watch(interp_batch)
                # Sum (not mean) over the batch: each sample only influences
                # its own critic output, so the gradient of the sum is the
                # per-sample gradient. A mean would scale every gradient by
                # 1 / batch size and push the norms towards batch size, not 1.
                d_out_interp = tf.reduce_sum(discriminator(interp_batch))
            interp_grads = tf.reshape(
                p_tape.gradient(d_out_interp, interp_batch), [batch_dim, -1])
            # Epsilon inside the square root keeps the derivative of the norm
            # finite when a sample's gradient is exactly zero (tf.norm yields
            # NaN gradients in that case).
            grad_norms = tf.sqrt(
                tf.reduce_sum(tf.square(interp_grads), axis=1) + 1e-12)
            diff1 = tf.squared_difference(grad_norms, 1)

            # Wasserstein critic loss plus weighted gradient penalty.
            d_loss = (-d_loss_real + d_loss_fake
                      + penalty * tf.reduce_mean(diff1))
        d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
        disc_opt.apply_gradients(
            zip(d_grads, discriminator.trainable_variables))

        # One generator update after every n_critic discriminator updates.
        if not (step + 1) % n_critic:
            # Fresh generated batch (twice the critic batch) for the generator
            # update. Only the generator variables are watched, so gradients
            # are not tracked for the discriminator weights.
            with tf.GradientTape(watch_accessed_variables=False) as g_tape:
                for vari in generator.trainable_variables:
                    g_tape.watch(vari)
                gen_only_batch = generator(noise_fn(2*batch_dim))
                g_loss = -tf.reduce_mean(discriminator(gen_only_batch))
            g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
            gen_opt.apply_gradients(
                zip(g_grads, generator.trainable_variables))

        if not step % 50:
            print("Step", step)
            # g_loss does not exist until the first generator update, so it
            # is not printed here.
            #print("Gen Loss", g_loss)
            print("Disc Loss", d_loss)
finally:
    # Back to inference mode for everything that follows.
    tf.keras.backend.set_learning_phase(0)

In [None]:
"""Generate samples"""
# Build a 4 x 4 grid of samples from the generator, using noise_fn to draw
# its inputs; the helper's return value is kept in `grid`.
grid = random_sample_grid(
    generator,
    noise_fn,
    grid_dims=(4, 4),
)