In [1]:
import tensorflow as tf
import tensorflow.contrib.layers as layers
# Session setup: 8 threads for intra-op and inter-op parallelism, soft device
# placement enabled, and a CPU device count of 8.
config = tf.ConfigProto(
    intra_op_parallelism_threads = 8,
    inter_op_parallelism_threads = 8,
    allow_soft_placement = True,
    device_count = {'CPU': 8},
)
sess = tf.InteractiveSession(config = config)

In [2]:
import numpy as np

In [3]:
# Small positive constant added to the sum of squared gradients before the
# square root in EM_loss, so the sqrt argument is always strictly positive.
eps = 1e-6

In [4]:
def EM_network(data, name, reuse = tf.AUTO_REUSE):
    """Build (or re-enter) the fully connected network living under scope `name`.

    The network consists of three ReLU layers (40, 40 and 20 units) followed
    by a single-unit linear layer, all created inside `tf.variable_scope(name)`.

    Args:
        data: input tensor handed to the first layer.
        name: variable-scope name under which the layers are created.
        reuse: forwarded to `tf.variable_scope`. With the default
            `tf.AUTO_REUSE`, variables are created if missing and fetched
            otherwise, so repeated calls with the same `name` can reuse them.

    Returns:
        A pair `(outputs, these_vars)`: the output of the final linear layer
        and the list of trainable variables registered under this scope.
    """
    with tf.variable_scope(name, reuse = reuse) as net_scope:
        lay = layers.relu(data, 40)
        lay = layers.relu(lay, 40)
        lay = layers.relu(lay, 20)
        outputs = layers.linear(lay, 1)

    # tf.get_collection filters with re.match, i.e. by name prefix anchored at
    # the start of the variable name. Using the fully qualified scope name keeps
    # the lookup working when this function is called inside an outer variable
    # scope, and the trailing "/" stops e.g. scope "foo" from also collecting
    # the variables of a sibling scope "foo_bar".
    # TRAINABLE_VARIABLES (instead of GLOBAL_VARIABLES) keeps non-trainable
    # variables that end up under the same prefix -- such as optimizer slot
    # variables created after a first build -- out of the returned list.
    these_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope = net_scope.name + "/")

    return outputs, these_vars

In [5]:
def EM_loss(data_P, data_Q, name, grad_pen_weight = 10.0):
    """Build the critic objective between two batches, with a gradient penalty.

    A single network (scope `name + "_EM_net"`) is applied to `data_P`, to
    `data_Q`, and to random per-sample mixtures of the two. The returned loss
    is `mean(T(P)) - mean(T(Q))` plus `grad_pen_weight` times the mean squared
    deviation of the network's input-gradient norm from 1 at the mixtures.

    Args:
        data_P: batch tensor of samples from the first distribution.
        data_Q: batch tensor of samples from the second distribution. It is
            combined elementwise with `data_P`, so both must have compatible
            shapes (in practice: the same batch size).
        name: prefix used to derive the network's variable-scope name.
        grad_pen_weight: multiplier applied to the gradient penalty term.
            Defaults to 10.0, the value that was previously hard-coded.

    Returns:
        A triple `(EM_loss_grad_pen, EM_vars, critic_diff)`:
        the scalar penalised loss to minimise, the network's variable list as
        returned by `EM_network`, and the unpenalised difference of means
        (reduced over axis 0 only, so it keeps the network's output dimension).
    """
    local_EM_network_name = name + "_EM_net"

    T_P, EM_vars = EM_network(data_P, name = local_EM_network_name)
    T_Q, _ = EM_network(data_Q, name = local_EM_network_name)

    # Named differently from the function itself to avoid shadowing EM_loss.
    critic_diff = tf.reduce_mean(T_P, axis = 0) - tf.reduce_mean(T_Q, axis = 0)

    # tf.shape already yields int32, so no cast is needed. One mixing
    # coefficient is drawn per sample and broadcast across the feature axis.
    batch_size_dyn = tf.shape(T_P)[0]
    rand = tf.random.uniform(shape = (batch_size_dyn, 1), minval = 0.0, maxval = 1.0)

    # Gradient penalty: evaluate the network at random convex combinations of
    # the two batches and differentiate its output w.r.t. those inputs.
    x_grad = tf.math.add(tf.math.multiply(rand, data_P),
                         tf.math.multiply(tf.math.subtract(1.0, rand), data_Q))
    x_grad_EM, _ = EM_network(x_grad, name = local_EM_network_name)
    grad = tf.gradients(x_grad_EM, x_grad)[0]

    # eps (module-level constant) keeps the sqrt argument strictly positive.
    grad_norm = tf.math.sqrt(tf.reduce_sum(tf.math.square(grad), axis = 1) + eps)

    grad_pen = tf.reduce_mean(tf.math.square(grad_norm - 1.0))

    # Element 0 of the difference of means is the scalar for the single-unit
    # network output.
    EM_loss_grad_pen = critic_diff[0] + grad_pen_weight * grad_pen

    return EM_loss_grad_pen, EM_vars, critic_diff

In [6]:
# Graph inputs for the two sample batches: float32, one column, any batch size.
data_P_in = tf.placeholder(dtype = tf.float32, shape = [None, 1], name = 'data_P_in')
data_Q_in = tf.placeholder(dtype = tf.float32, shape = [None, 1], name = 'data_Q_in')

In [7]:
# Build the loss graph: penalised loss to minimise, the network's variables,
# and the unpenalised difference of means (reported as `critic` during training).
EM_lossval, EM_vars, critic = EM_loss(
    data_P = data_P_in,
    data_Q = data_Q_in,
    name = 'EM_loss',
)

In [8]:
# Adam (beta1 = beta2 = 0.9) minimising the penalised loss; only the variables
# returned by EM_loss are updated.
EM_optimizer = tf.train.AdamOptimizer(learning_rate = 0.005, beta1 = 0.9, beta2 = 0.9)
train_EM = EM_optimizer.minimize(EM_lossval, var_list = EM_vars)
# Alternative kept for reference: plain gradient descent at the same learning rate.
#train_EM = tf.train.GradientDescentOptimizer(learning_rate = 0.005).minimize(EM_lossval, var_list = EM_vars)

In [36]:
# Training samples: Q ~ N(mean 1, std 1) and P ~ N(mean 2, std 2).
# A seeded, local RandomState makes the data identical on every run without
# touching NumPy's global random state.
N_SAMPLES = 5000
DATA_SEED = 0
_data_rng = np.random.RandomState(DATA_SEED)

data_Q_train = _data_rng.normal(loc = 1, scale = 1, size = N_SAMPLES)
data_P_train = _data_rng.normal(loc = 2, scale = 2, size = N_SAMPLES)

In [37]:
# Turn each sample vector into a column of shape (n, 1) to match the
# [None, 1] placeholders. On a 1-D array reshape(-1, 1) gives the same result
# as np.expand_dims(..., axis = 1), but it leaves an (n, 1) array unchanged,
# so re-running this cell no longer stacks on an extra axis.
data_Q_train = data_Q_train.reshape(-1, 1)
data_P_train = data_P_train.reshape(-1, 1)

In [38]:
# Initialise every global variable in the graph (network weights and the
# optimizer's own variables) before training.
init_op = tf.global_variables_initializer()
sess.run(init_op)

In [39]:
# The complete training set is fed on every step, so the feed is built once.
full_batch_feed = {data_P_in: data_P_train, data_Q_in: data_Q_train}

for epoch in range(1000):
    # One optimizer step, then re-evaluate the unpenalised critic difference.
    sess.run([train_EM], feed_dict = full_batch_feed)
    loss = sess.run(critic, feed_dict = full_batch_feed)

    # Report every 200th epoch.
    if epoch % 200 == 0:
        print("epoch {}: loss = {}".format(epoch, loss))

epoch 0: loss = [0.3300354]
epoch 200: loss = [-1.0564537]
epoch 400: loss = [-1.0475273]
epoch 600: loss = [-1.0427332]
epoch 800: loss = [-1.0726366]
