In [None]:
"""Imports, define GAN"""
import tensorflow as tf

from utils.data import tfr_dataset_eager, parse_img_label_tfr
from utils.math import compute_mmd
from utils.models import gen_conv_mnist, gen_fc_mnist, enc_conv_mnist, enc_fc_mnist
from utils.viz import random_sample_grid, interpolate, imshow, img_grid_npy


# Hyperparameters
batch_size = 128  # this is a "half batch"!
train_steps = 5000
dim_code = 64  # first argument handed to the two encoder constructors
label_smoothing = 0.9  # multiplies the all-ones "real" targets during training
coeff = 10.  # weight of the cycle loss inside the generator objective


# SVHN: TFRecord locations and record parser (shape argument (32, 32, 3))
strain_files = ["/cache/tfrs/svhn_train.tfr"]
stest_files = ["/cache/tfrs/svhn_test.tfr"]


def sparse_fn(x):
    """Parse one SVHN record via parse_img_label_tfr with shape (32, 32, 3)."""
    return parse_img_label_tfr(x, (32, 32, 3))


# MNIST: TFRecord locations and record parser (shape argument (32, 32, 1))
mtrain_files = ["/cache/tfrs/mnist_train.tfr"]
mtest_files = ["/cache/tfrs/mnist_test.tfr"]


def mparse_fn(x):
    """Parse one MNIST record via parse_img_label_tfr with shape (32, 32, 1)."""
    return parse_img_label_tfr(x, (32, 32, 1))


# Training streams for both domains, built with identical batch size and
# shufrep setting, then zipped so each element is an (MNIST, SVHN) pair.
mnist_data = tfr_dataset_eager(
    mtrain_files, batch_size, mparse_fn, shufrep=60000)
svhn_data = tfr_dataset_eager(
    strain_files, batch_size, sparse_fn, shufrep=60000)

data = tf.data.Dataset.zip((mnist_data, svhn_data))


conv = True
# Choose the constructor family once (convolutional vs. fully connected);
# the networks below are then built identically for either choice.
if conv:
    make_encoder, make_generator = enc_conv_mnist, gen_conv_mnist
else:
    make_encoder, make_generator = enc_fc_mnist, gen_fc_mnist

# MNIST -> code -> SVHN, plus the discriminator that judges SVHN images.
m_to_c = make_encoder(dim_code, use_bn=True)
c_to_s = make_generator(use_bn=True, channels=3)
discriminator_s = make_encoder(1, use_bn=True)

# SVHN -> code -> MNIST, plus the discriminator that judges MNIST images.
s_to_c = make_encoder(dim_code, use_bn=True)
c_to_m = make_generator(use_bn=True, channels=1)
discriminator_m = make_encoder(1, use_bn=True)


def m_to_s(x):
    """Translate an MNIST batch to the SVHN domain (encode, then generate)."""
    return c_to_s(m_to_c(x))


def s_to_m(x):
    """Translate an SVHN batch to the MNIST domain (encode, then generate)."""
    return c_to_m(s_to_c(x))


def partial_loss(logits, lbls):
    """Return the mean sigmoid cross-entropy between ``logits`` and ``lbls``.

    Written with ``tf.nn.sigmoid_cross_entropy_with_logits`` followed by a
    mean over all elements. This yields the same scalar as the TF1 helper
    ``tf.losses.sigmoid_cross_entropy`` with its default unit weights, but
    does not depend on that TF1-only symbol, so it also works where
    ``tf.losses`` is the Keras losses namespace.

    Args:
        logits: raw (pre-sigmoid) discriminator outputs.
        lbls: target values with the same shape as ``logits``.

    Returns:
        Scalar tensor holding the element-wise loss averaged over all entries.
    """
    # The TF1 helper cast the labels to the logits dtype; keep that behaviour.
    lbls = tf.cast(lbls, logits.dtype)
    per_element = tf.nn.sigmoid_cross_entropy_with_logits(labels=lbls, logits=logits)
    return tf.reduce_mean(per_element)


def ls_loss(out, lbls):
    """Return the mean squared error between ``out`` and ``lbls`` as a scalar.

    Computed explicitly so the result is a single scalar regardless of which
    ``tf.losses`` namespace is installed: the TF1 helper reduces to a scalar,
    whereas the Keras function of the same name only averages over the last
    axis and returns one value per example.

    Args:
        out: network outputs.
        lbls: target values, broadcast-compatible with ``out``.

    Returns:
        Scalar tensor with the squared difference averaged over all elements.
    """
    return tf.reduce_mean(tf.square(out - lbls))


# One Adam optimizer per player, both with learning rate 2e-4 and first-moment
# decay 0.5. The Keras optimizer spells this argument ``beta_1``; ``beta1`` is
# the keyword of the older tf.train.AdamOptimizer and is not accepted here.
gen_opt = tf.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
disc_opt = tf.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)

In [None]:
tf.keras.backend.set_learning_phase(1)

# Every training step consumes two (MNIST, SVHN) batch pairs: one for the
# discriminator update and a fresh one for the generator update. Both are
# pulled from one dedicated iterator, so `step` counts optimizer updates
# (exactly `train_steps` of them) instead of advancing with each batch drawn.
batch_iter = iter(data)
for step in range(train_steps):
    try:
        (m_batch, _), (s_batch, _) = next(batch_iter)
        (m_batch2, _), (s_batch2, _) = next(batch_iter)
    except StopIteration:
        # The dataset ran out before `train_steps` updates; stop cleanly.
        break

    # ---- discriminator update ----
    # For batchnorm to work better, we feed only real images, then only
    # generated ones and then average the losses.
    # The generated batches are produced outside the tape: only discriminator
    # variables are updated here, so no gradient through the generators is
    # required.
    batch_dim_m = tf.shape(m_batch)[0]
    batch_dim_s = tf.shape(s_batch)[0]

    gen_m_batch = s_to_m(s_batch)
    gen_s_batch = m_to_s(m_batch)

    # Generated MNIST images come from the SVHN batch (and vice versa), so
    # the "fake" label tensors take their size from the other domain's batch.
    real_labels_m = label_smoothing * tf.ones([batch_dim_m, 1])
    gen_labels_m = tf.zeros([batch_dim_s, 1])
    real_labels_s = label_smoothing * tf.ones([batch_dim_s, 1])
    gen_labels_s = tf.zeros([batch_dim_m, 1])

    with tf.GradientTape() as d_tape:
        d_m_loss_real = partial_loss(discriminator_m(m_batch), real_labels_m)
        d_m_loss_fake = partial_loss(discriminator_m(gen_m_batch), gen_labels_m)
        d_m_loss = 0.5 * (d_m_loss_real + d_m_loss_fake)

        d_s_loss_real = partial_loss(discriminator_s(s_batch), real_labels_s)
        d_s_loss_fake = partial_loss(discriminator_s(gen_s_batch), gen_labels_s)
        d_s_loss = 0.5 * (d_s_loss_real + d_s_loss_fake)

        d_loss = d_m_loss + d_s_loss

    # Collected after the forward passes, as in the original ordering.
    d_vars = discriminator_m.trainable_variables + discriminator_s.trainable_variables
    d_grads = d_tape.gradient(d_loss, d_vars)
    disc_opt.apply_gradients(zip(d_grads, d_vars))

    # ---- generator update (on the fresh batch pair) ----
    g_vars = (m_to_c.trainable_variables + c_to_s.trainable_variables
              + s_to_c.trainable_variables + c_to_m.trainable_variables)
    # Automatic watching is disabled and only the encoder/generator variables
    # are registered, so the discriminator variables touched in the forward
    # pass are not tracked by this tape.
    with tf.GradientTape(watch_accessed_variables=False) as g_tape:
        for vari in g_vars:
            g_tape.watch(vari)

        gen_only_s_batch = m_to_s(m_batch2)
        gen_only_m_batch = s_to_m(s_batch2)

        # Adversarial term: generated images should be scored as "real".
        g_m_loss = partial_loss(discriminator_m(gen_only_m_batch),
                                label_smoothing * tf.ones([tf.shape(s_batch2)[0], 1]))
        g_s_loss = partial_loss(discriminator_s(gen_only_s_batch),
                                label_smoothing * tf.ones([tf.shape(m_batch2)[0], 1]))
        g_loss = g_m_loss + g_s_loss

        # Cycle term: translating there and back should reproduce the input
        # (mean absolute error in both directions).
        back_to_m_batch = s_to_m(gen_only_s_batch)
        back_to_s_batch = m_to_s(gen_only_m_batch)
        cyc_loss = (tf.reduce_mean(tf.abs(m_batch2 - back_to_m_batch))
                    + tf.reduce_mean(tf.abs(s_batch2 - back_to_s_batch)))
        wow_loss = g_loss + coeff * cyc_loss
    g_grads = g_tape.gradient(wow_loss, g_vars)
    gen_opt.apply_gradients(zip(g_grads, g_vars))

    if not step % 50:
        print("Step", step)
        print("Gen Loss", g_loss)
        print("Disc Loss", d_loss)
        print("Cycle Loss", cyc_loss)

tf.keras.backend.set_learning_phase(0)

In [None]:
tf.keras.backend.set_learning_phase(0)

# Test streams with 7*7 examples per batch, matching the 7 x 7 grids below.
test_datam = tfr_dataset_eager(mtest_files, 7*7, mparse_fn)
test_datas = tfr_dataset_eager(stest_files, 7*7, sparse_fn)
test_data = tf.data.Dataset.zip((test_datam, test_datas))


def _show_grid(batch):
    """Build a 7 x 7 grid from ``batch`` with img_grid_npy, show it, return it."""
    grid = img_grid_npy(batch.numpy(), 7, 7, normalize=False)
    imshow(grid)
    return grid


for (m_batch, _), (s_batch, _) in test_data:
    # Translate each domain into the other, then back again.
    gen_s, gen_m = m_to_s(m_batch), s_to_m(s_batch)
    back_to_m, back_to_s = s_to_m(gen_s), m_to_s(gen_m)

    print("Input batches")
    grid_s_orig = _show_grid(s_batch)
    grid_m_orig = _show_grid(m_batch)

    print("Conversion results")
    grid_m = _show_grid(gen_m)
    grid_s = _show_grid(gen_s)

    print("Back to original")
    grid_s_back = _show_grid(back_to_s)
    grid_m_back = _show_grid(back_to_m)

    # Block until the user presses Enter before moving to the next batch.
    input()