In [None]:
"""GPU setup"""
# Expose only physical GPU 3 to this process. Set this before TensorFlow
# touches the GPU (TensorFlow is imported in the next cell), otherwise the
# device list has already been fixed.
import os

os.environ.update({"CUDA_VISIBLE_DEVICES": "3"})

In [None]:
"""Imports, define/initialize model"""
import tensorflow as tf

from utils.data import tfr_dataset_eager, parse_img_label_tfr
from utils.math import compute_mmd
from utils.models import gen_conv_mnist, gen_fc_mnist
from utils.viz import random_sample_grid, interpolate


# --- data configuration ---
# `svhn` switches between SVHN (3 channels, train + extra splits) and
# MNIST (1 channel). Every other setting below is derived from it.
svhn = False
batch_size = 256
dim_noise = 100  # length of each generator input code
train_steps = 5000 if svhn else 1500

if not svhn:
    train_files = ["/cache/tfrs/mnist_train.tfr"]
    test_files = ["/cache/tfrs/mnist_test.tfr"]  # not used by the cells shown here
    # Shape handed to the record parser: 32x32 with one channel.
    parse_fn = lambda x: parse_img_label_tfr(x, (32, 32, 1))
else:
    train_files = ["/cache/tfrs/svhn_train.tfr", "/cache/tfrs/svhn_extra.tfr"]
    test_files = ["/cache/tfrs/svhn_test.tfr"]  # not used by the cells shown here
    # Shape handed to the record parser: 32x32 with three channels.
    parse_fn = lambda x: parse_img_label_tfr(x, (32, 32, 3))

# NOTE(review): `shufrep` presumably sizes the shuffle buffer to roughly the
# number of training examples — confirm in utils.data.tfr_dataset_eager.
data = tfr_dataset_eager(
    train_files,
    batch_size,
    parse_fn,
    shufrep=600000 if svhn else 60000,
)


def noise_fn(n_samples):
    """Draw `n_samples` generator input codes.

    Args:
        n_samples: number of codes; a Python int or a scalar int tensor
            (`train` passes `tf.shape(batch)[0]`).

    Returns:
        Float tensor of shape [n_samples, dim_noise] with entries drawn
        uniformly from [-1, 1).
    """
    code_shape = [n_samples, dim_noise]
    return tf.random.uniform(code_shape, minval=-1, maxval=1)


# Generator architecture: convolutional if `conv`, otherwise fully connected.
# Both are built with use_bn=True and an output channel count matching the
# dataset (3 for SVHN, 1 for MNIST).
conv = False
model = (gen_conv_mnist if conv else gen_fc_mnist)(
    use_bn=True,
    channels=3 if svhn else 1,
)


def loss(real, generated, sigmas=(0.01, 0.03, 0.1, 0.3, 1.), eps=1e-12):
    """Square root of the MMD estimate between two batches.

    Args:
        real: batch of real samples. `train` flattens both arguments to
            [batch, features] before calling this.
        generated: batch of generated samples, same layout as `real`.
        sigmas: values forwarded (as a list) as the third argument of
            `compute_mmd` — presumably kernel bandwidths; confirm in
            utils.math. A tuple default avoids a shared mutable default.
        eps: lower bound applied to the MMD estimate before the square root.

    Returns:
        sqrt(max(compute_mmd(real, generated, sigmas), eps)).
    """
    mmd_sq = compute_mmd(real, generated, list(sigmas))
    # The estimate goes to 0 as the generator improves and can come out
    # slightly negative from rounding (or by construction, if the estimator
    # is unbiased). sqrt would then return NaN or have an infinite gradient
    # (d sqrt(x)/dx = 1 / (2 sqrt(x))), so clamp it first.
    return tf.sqrt(tf.maximum(mmd_sq, eps))


# Adam with its default settings; learning_rate=0.001 is the Keras default,
# written out so it is visible when tuning.
opt = tf.optimizers.Adam(learning_rate=0.001)

In [None]:
"""Check target loss"""
# Baseline: the loss between two independent batches of *real* images is
# roughly the best value the generator can hope to reach. Both batches are
# flattened to [batch, features] exactly as in `train`, so the numbers are
# directly comparable to the training log.
print("MMD for 15 pairs of real batches...")
data_iter = iter(data)
for _ in range(15):
    real_a = next(data_iter)[0]
    real_b = next(data_iter)[0]
    mmd = loss(tf.reshape(real_a, [tf.shape(real_a)[0], -1]),
               tf.reshape(real_b, [tf.shape(real_b)[0], -1]))
    print(mmd.numpy())

In [None]:
"""Train"""

@tf.function
def train(batch):
    """Run one generator update on a batch of real images.

    Draws one noise code per image in `batch` and generates images from
    them. It then computes `loss` between the flattened real and generated
    batches and applies one optimizer step to the model's trainable variables.

    Args:
        batch: real images; the leading axis is the batch axis.

    Returns:
        The loss value for this step, computed before the update is applied.
    """
    batch_dim = tf.shape(batch)[0]
    with tf.GradientTape() as tape:
        # Pass training=True explicitly rather than relying on the deprecated
        # global learning phase. This matters for the layers enabled by
        # use_bn=True (presumably Keras BatchNormalization — confirm in
        # utils.models).
        gen_batch = model(noise_fn(batch_dim), training=True)
        # Flatten both batches to [batch, features] before comparing them.
        mmd = loss(tf.reshape(batch, [batch_dim, -1]),
                   tf.reshape(gen_batch, [batch_dim, -1]))
    grads = tape.gradient(mmd, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))

    return mmd


# Training mode for any layer that reads the global Keras learning phase.
tf.keras.backend.set_learning_phase(1)
for step, (img_batch, _) in enumerate(data):
    # Stop at step `train_steps` so exactly `train_steps` updates run
    # (steps 0 .. train_steps - 1).
    if step >= train_steps:
        break

    mmd = train(img_batch)

    # Log every 50 steps; `.numpy()` prints the value rather than a tensor repr.
    if step % 50 == 0:
        print("Step", step)
        print("Loss", mmd.numpy())

# Back to inference mode for the sampling cells that follow.
tf.keras.backend.set_learning_phase(0)

In [None]:
"""Generate samples"""
# Set the global Keras learning phase to 0 (inference). This repeats the call
# at the end of training so the cell also works when re-run on its own.
tf.keras.backend.set_learning_phase(0)

# NOTE(review): `random_sample_grid` comes from utils.viz; presumably it draws
# codes with `noise_fn`, runs `model` on them and renders/returns an image
# grid — confirm its return value before relying on `grid`.
grid = random_sample_grid(model, noise_fn)

In [None]:
"""Interpolation behavior"""
# Draw two random codes and let utils.viz.interpolate walk between them with
# method="slerp" (presumably spherical linear interpolation), decoding each
# step with the trained generator.
codes = noise_fn(2)
a_code, b_code = codes[0], codes[1]
interpolate(a_code, b_code, gen=model, method="slerp")