In [None]:
"""GPU setup"""
import os

# Restrict CUDA to device index 2. This is set here, in the first cell, so it
# is already in the environment when the next cell imports TensorFlow.
os.environ.update(CUDA_VISIBLE_DEVICES="2")

In [None]:
"""Imports, define GAN"""
from itertools import product

import tensorflow as tf
import numpy as np

from utils.cppn import make_grid, format_batch
from utils.data import tfr_dataset_eager, parse_img_label_tfr
from utils.math import compute_mmd
from utils.models import enc_conv_mnist, enc_fc_mnist
from utils.viz import random_sample_grid, img_grid_npy, imshow


# --- Data / training configuration ---
svhn = True  # True -> SVHN files, 3-channel parsing; False -> MNIST files, 1 channel
# this is a "half batch"! The discriminator step uses this many real plus this
# many generated images; the generator step uses twice this many generated images.
batch_size = 256
train_steps = 5000 if svhn else 1500
dim_noise = 100
label_smoothing = 0.9

# TFRecord file lists and the image shape handed to the parser, per dataset.
train_files, test_files, _img_shape = (
    (["/cache/tfrs/svhn_train.tfr", "/cache/tfrs/svhn_extra.tfr"],
     ["/cache/tfrs/svhn_test.tfr"],
     (32, 32, 3))
    if svhn else
    (["/cache/tfrs/mnist_train.tfr"],
     ["/cache/tfrs/mnist_test.tfr"],
     (32, 32, 1))
)


def parse_fn(x):
    """Forward `x` to `parse_img_label_tfr` together with the selected image shape."""
    return parse_img_label_tfr(x, _img_shape)


# NOTE(review): `shufrep` is presumably a shuffle(-and-repeat) buffer size
# matched to the dataset size; confirm against utils.data.tfr_dataset_eager.
_shufrep = 600000 if svhn else 60000
data = tfr_dataset_eager(train_files, batch_size, parse_fn, shufrep=_shufrep)

# Discriminator: pick the convolutional or the fully connected encoder.
# NOTE(review): the positional `1` is presumably the number of output units
# (one real/fake score per image); confirm in utils.models.
conv = False
discriminator = (enc_conv_mnist if conv else enc_fc_mnist)(1, use_bn=True)

# Fixed coordinate input shared by every generated image.
# NOTE(review): presumably a 32x32 grid of 2-D positions; confirm in utils.cppn.make_grid.
poses = tf.constant(make_grid(32, 2), dtype=tf.float32)


# Generator (CPPN): a small dense network applied to each (position, noise) row.
# Activations per layer: sine, Gaussian exp(-x^2), then a steepened sigmoid so
# the output lies in (0, 1). Output width is 3 for SVHN and 1 for MNIST.
cppn = tf.keras.Sequential()
cppn.add(tf.keras.layers.Dense(12, tf.sin))
# cppn.add(tf.keras.layers.Dense(6, tf.sin))  # extra sine layer, currently disabled
cppn.add(tf.keras.layers.Dense(12, lambda x: tf.exp(-x**2)))
cppn.add(tf.keras.layers.Dense(3 if svhn else 1, lambda x: tf.nn.sigmoid(5*x)))
# Each input row has 2 + dim_noise features.
cppn.build((None, 2+dim_noise))


def noise_fn(n_samples):
    """Draw `n_samples` noise vectors of length `dim_noise`, uniform in [-1, 1)."""
    shape = [n_samples, dim_noise]
    return tf.random.uniform(shape, minval=-1, maxval=1)


# Objective: mean squared error between discriminator outputs and target labels.
# The cross-entropy-on-logits alternative is kept here for reference:
#loss = tf.losses.BinaryCrossentropy(from_logits=True)
loss = tf.losses.MeanSquaredError()

# Generator and discriminator each get their own Adam instance with the same
# hyperparameters.
_adam_kwargs = dict(learning_rate=0.0002, beta_1=0.5)
gen_opt = tf.optimizers.Adam(**_adam_kwargs)
disc_opt = tf.optimizers.Adam(**_adam_kwargs)

In [None]:
"""Train GAN"""

@tf.function
def train(batch, g_loss):
    """Run one discriminator update followed by one generator update.

    Args:
        batch: Batch of real images passed to `discriminator`. Its leading
            dimension N sets how many images are generated for the
            discriminator step; the generator step generates 2*N.
        g_loss: Ignored. The value is overwritten before it is ever read; the
            parameter is kept only so existing call sites keep working.

    Returns:
        Tuple `(g_loss, d_loss)`: the generator and discriminator losses
        computed with the module-level `loss` during this step.
    """
    batch_dim = tf.shape(batch)[0]
    n_channels = 3 if svhn else 1

    # ---- Discriminator step ----
    # The generated images are produced outside the tape: this step only
    # differentiates with respect to the discriminator's variables.
    noise = noise_fn(batch_dim)
    combined_batch = format_batch(noise, poses)
    gen_batch = cppn(combined_batch)
    gen_batch = tf.reshape(gen_batch, [batch_dim, 32, 32, n_channels])

    real_labels = label_smoothing * tf.ones([batch_dim, 1])
    gen_labels = tf.zeros([batch_dim, 1])
    with tf.GradientTape() as d_tape:
        # For batchnorm to work better, real and generated images are NOT
        # mixed into one batch: each goes through the discriminator in its
        # own call and the two losses are averaged.
        d_loss_real = loss(real_labels, discriminator(batch))
        d_loss_fake = loss(gen_labels, discriminator(gen_batch))
        d_loss = 0.5 * (d_loss_real + d_loss_fake)
    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
    disc_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))

    # ---- Generator step ----
    # Fresh generated batch, twice the size of the real batch. Only the
    # generator's variables are watched, so the discriminator is not updated.
    with tf.GradientTape(watch_accessed_variables=False) as g_tape:
        g_tape.watch(cppn.trainable_variables)

        noise2 = noise_fn(2 * batch_dim)
        combined_batch2 = format_batch(noise2, poses)
        gen_only_batch = cppn(combined_batch2)
        gen_only_batch = tf.reshape(gen_only_batch, [2 * batch_dim, 32, 32, n_channels])
        # The generator is trained towards the same smoothed "real" target
        # that the discriminator uses for real images.
        g_loss = loss(label_smoothing * tf.ones([2 * batch_dim, 1]),
                      discriminator(gen_only_batch))
    g_grads = g_tape.gradient(g_loss, cppn.trainable_variables)
    gen_opt.apply_gradients(zip(g_grads, cppn.trainable_variables))

    return g_loss, d_loss


# Switch Keras to training mode for the duration of the loop.
tf.keras.backend.set_learning_phase(1)

# Placeholder passed into the first train() call.
g_loss = tf.constant(0.5, dtype=tf.float32)

for step, (img_batch, _) in enumerate(data):
    # NOTE(review): with `>` the loop performs train_steps + 1 updates
    # (steps 0..train_steps inclusive); confirm this is intended.
    if step > train_steps:
        break

    g_loss, d_loss = train(img_batch, g_loss)

    # Report progress every 50th step, including step 0.
    if step % 50 == 0:
        for tag, value in (("Step", step), ("Gen Loss", g_loss), ("Disc Loss", d_loss)):
            print(tag, value)

# Back to inference mode for the sampling cells below.
tf.keras.backend.set_learning_phase(0)

In [None]:
"""Generate samples"""

tf.keras.backend.set_learning_phase(0)

grid_side = 7
n_images = grid_side * grid_side

noise = noise_fn(n_images)

# Build one CPPN input row per (image, grid position), image-major:
#   - `poses` is tiled so the full position grid repeats once per image;
#   - each noise vector is repeated consecutively, once per position, so all
#     rows of one image share the same noise.
# `tf.repeat` replaces a bare `repeat` helper that was never imported or
# defined in this notebook (NameError on a fresh kernel).
# NOTE(review): the training cell builds this input with
# format_batch(noise, poses), presumably the same layout; confirm in utils.cppn.
pose_batch = tf.tile(poses, [n_images, 1])
noise_batch = tf.repeat(noise, repeats=tf.shape(poses)[0], axis=0)
combined_batch = tf.concat([pose_batch, noise_batch], axis=1)

gen_batch = cppn(combined_batch)
gen_batch = tf.reshape(gen_batch, [n_images, 32, 32, -1])

# Tile the generated images into a grid_side x grid_side picture and show it.
grid = img_grid_npy(gen_batch.numpy(), grid_side, grid_side, normalize=False)
imshow(grid)

In [None]:
tf.keras.backend.set_learning_phase(0)

# Render a single noise vector on an hr x hr position grid instead of the
# 32 x 32 grid used elsewhere in the notebook.
hr = 1024
noise = noise_fn(1)
# NOTE(review): the extra make_grid arguments (-256, 255) presumably set the
# coordinate range of the grid; confirm in utils.cppn.make_grid.
poseshr = tf.constant(make_grid(hr, 2, -256, 255), dtype=tf.float32)

combined_batch = format_batch(noise, poseshr)

# One output row per grid position; fold the rows back into an image with
# however many channels the generator emits.
gen_batch = tf.reshape(cppn(combined_batch), [hr, hr, -1])

imshow(gen_batch, figsize=(20, 20))