In [None]:
"""GPU setup"""
import os

# Default to GPU 3, but respect a CUDA_VISIBLE_DEVICES value that is already
# present in the environment so the notebook can run on other machines.
# This has to happen before TensorFlow initializes CUDA; keeping it in the
# first cell, ahead of `import tensorflow`, guarantees that.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")

In [None]:
"""Imports, define GMMN"""
import tensorflow as tf
tf.enable_eager_execution()  # TF 1.x: switch to eager mode; the cells below run ops immediately (Python loops over `data`, printing loss values)

from matplotlib import pyplot as plt

from utils.data import mnist_eager
from utils.math import compute_mmd
from utils.models import gen_conv_mnist, gen_fc_mnist, enc_fc_mnist, enc_conv_mnist
from utils.viz import random_sample_grid


# data / training configuration
batch_size = 256
train_steps = 1500  # step budget for each training loop below
dim_noise = 100     # size of the uniform latent vector fed to the generator
seed = 0

# Seed TF's global RNG before the dataset and the models are created, so noise
# draws and weight initialization repeat across kernel restarts.
# NOTE(review): GPU kernels (e.g. cuDNN) may still be nondeterministic, and any
# shuffling inside mnist_eager may use its own seed — confirm if exact
# reproducibility matters.
tf.set_random_seed(seed)

data = mnist_eager("data/mnist_train", batch_size)


def noise_fn(n_samples):
    """Sample generator input noise.

    Args:
        n_samples: number of latent vectors (Python int or scalar tensor).

    Returns:
        A `[n_samples, dim_noise]` tensor drawn uniformly from [-1, 1).
    """
    noise_shape = [n_samples, dim_noise]
    return tf.random_uniform(noise_shape, minval=-1, maxval=1)


conv = False
# Choose the generator architecture (convolutional vs. fully connected);
# either way it is built with batch norm enabled.
model = (gen_conv_mnist if conv else gen_fc_mnist)(use_bn=True)


def loss(real, generated, sigmas=(0.01, 0.03, 0.1, 0.3, 1.), eps=1e-12):
    """Multi-bandwidth MMD loss between a real and a generated batch.

    Args:
        real: batch of real images.
        generated: batch of generated images.
        sigmas: kernel bandwidths forwarded to `compute_mmd` (as a list).
        eps: floor applied before the square root.

    Returns:
        Scalar tensor: sqrt of `compute_mmd`'s result, clamped at `eps`.
    """
    # NOTE(review): compute_mmd presumably returns a squared-MMD estimate
    # (hence the sqrt) — confirm in utils.math.
    mmd_sq = compute_mmd(real, generated, list(sigmas))
    # d/dx sqrt(x) is infinite at 0, and an unbiased estimate can dip below 0
    # (NaN). Clamping keeps the loss and its gradient finite; values above
    # `eps` are unaffected.
    return tf.sqrt(tf.maximum(mmd_sq, eps))


# Adam for the GMMN generator; hyperparameters spelled out (TF 1.x defaults).
opt = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.9, beta2=0.999,
                             epsilon=1e-08)

In [None]:
"""Train"""
tf.keras.backend.set_learning_phase(1)
try:
    for step, (img_batch, _) in enumerate(data):
        # `step` starts at 0, so `>=` performs exactly `train_steps` updates.
        if step >= train_steps:
            break

        with tf.GradientTape() as tape:
            # One noise vector per real image, so both MMD batches match in size.
            gen_batch = model(noise_fn(tf.shape(img_batch)[0]))
            mmd = loss(img_batch, gen_batch)
        grads = tape.gradient(mmd, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))

        if step % 50 == 0:
            print("Step", step)
            print("Loss", mmd.numpy())
finally:
    # Restore inference mode even if training is interrupted or raises, so the
    # sampling cells do not run batch norm in training mode.
    tf.keras.backend.set_learning_phase(0)

In [None]:
"""Generate samples"""
# Sample the trained GMMN generator through `noise_fn` in a 4x4 layout.
# NOTE(review): presumably random_sample_grid draws and displays/returns the
# grid — see utils.viz for the exact behavior.
grid = random_sample_grid(model, noise_fn, grid_dims=(4, 4))

In [None]:
"""Define GAN"""
# The discriminator follows the same conv / fully-connected choice as the
# generator, with batch norm. NOTE(review): the positional `1` is presumably
# the output size (one logit per image, fed to a sigmoid loss) — confirm.
discriminator = (enc_conv_mnist if conv else enc_fc_mnist)(1, use_bn=True)
# NOTE(review): `generator` is the very same object as the GMMN `model` trained
# above, so the GAN starts from the GMMN weights and overwrites them.
# Confirm this is intended rather than a fresh network.
generator = model
# Target value used for "real" labels in the GAN training loop.
label_smoothing = 0.9


def partial_loss(logits, lbls):
    """Sigmoid cross-entropy between discriminator outputs and target labels.

    Args:
        logits: pre-sigmoid discriminator outputs.
        lbls: target probabilities, same shape as `logits`.

    Returns:
        Scalar loss tensor (default `tf.losses` reduction with unit weights).
    """
    return tf.losses.sigmoid_cross_entropy(logits=logits,
                                           multi_class_labels=lbls)


# Two independent Adam instances with identical settings (lr 2e-4, beta1 0.5),
# so the generator and discriminator keep separate moment estimates.
gen_opt, disc_opt = (
    tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5) for _ in range(2)
)

In [None]:
"""Train GAN"""
tf.keras.backend.set_learning_phase(1)
try:
    for step, (img_batch, _) in enumerate(data):
        # `step` starts at 0, so `>=` performs exactly `train_steps` updates.
        if step >= train_steps:
            break

        # --- Discriminator update ---
        # For batch norm to work better, real and generated images are fed in
        # separate calls and the two losses averaged, instead of concatenating
        # them into one mixed batch.
        batch_dim = tf.shape(img_batch)[0]
        # Generated outside d_tape: this step only updates the discriminator.
        gen_batch = generator(noise_fn(batch_dim))
        # One-sided label smoothing: real targets are `label_smoothing`,
        # fake targets stay at 0.
        real_labels = label_smoothing * tf.ones([batch_dim, 1])
        gen_labels = tf.zeros([batch_dim, 1])
        with tf.GradientTape() as d_tape:
            d_loss_real = partial_loss(discriminator(img_batch), real_labels)
            d_loss_fake = partial_loss(discriminator(gen_batch), gen_labels)
            d_loss = 0.5 * (d_loss_real + d_loss_fake)
        d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
        disc_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))

        # --- Generator update ---
        # Fresh noise, 2x the batch so the generator step uses as many samples
        # as the discriminator saw above. Only generator variables are watched,
        # so no gradients are tracked for the discriminator's weights.
        with tf.GradientTape(watch_accessed_variables=False) as g_tape:
            for vari in generator.trainable_variables:
                g_tape.watch(vari)
            gen_only_batch = generator(noise_fn(2 * batch_dim))
            # Generator targets are 1.0, not `label_smoothing`: smoothing is
            # meant only for the discriminator's real targets. A 0.9 target
            # would make this sigmoid cross-entropy minimal at D(G(z)) = 0.9
            # and push confident "real" outputs back down.
            g_loss = partial_loss(discriminator(gen_only_batch),
                                  tf.ones([2 * batch_dim, 1]))
        g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
        gen_opt.apply_gradients(zip(g_grads, generator.trainable_variables))

        if step % 50 == 0:
            print("Step", step)
            print("Gen Loss", g_loss.numpy())
            print("Disc Loss", d_loss.numpy())
finally:
    # Restore inference mode even if training is interrupted or raises, so the
    # sampling cell below does not run batch norm in training mode.
    tf.keras.backend.set_learning_phase(0)

In [None]:
"""Generate samples"""
# Sample the GAN generator in a 4x4 layout. `generator` is the same object as
# `model`, so this reflects GMMN training followed by GAN training.
# NOTE(review): presumably random_sample_grid draws and displays/returns the
# grid — see utils.viz for the exact behavior.
grid = random_sample_grid(generator, noise_fn, grid_dims=(4, 4))