In [1]:
import tensorflow as tf
import tensorflow.contrib.layers as layers
# Session setup: 8-thread intra-/inter-op pools, up to 8 logical CPU devices, and
# soft placement (TF may put an op on another device if the requested one can't run it).
config = tf.ConfigProto(
    intra_op_parallelism_threads=8,
    inter_op_parallelism_threads=8,
    allow_soft_placement=True,
    device_count={'CPU': 8},
)
sess = tf.InteractiveSession(config=config)

# Graph-level seed so random ops created later (e.g. model.sample) draw the same
# sequence on every fresh Restart & Run All. Must run before those ops are built.
tf.set_random_seed(0)

In [2]:
import tensorflow.contrib.distributions as ds

In [3]:
import numpy as np

In [4]:
# Critic network for a Donsker-Varadhan estimate of D_KL(P||Q); see KL_loss.
def KL_network(data_P, data_Q, name):
    """Build the critic T(x) used to estimate D_KL(P||Q).

    This does not compute the divergence itself. It runs a small MLP critic
    (two ReLU layers of 20 units, then a linear layer with 1 output) over the
    row-wise concatenation [data_P; data_Q] and returns the raw scores. Feed
    those scores to `KL_loss`, which splits them back into a P half and a Q
    half, so data_P and data_Q must contain the same number of rows.

    Args:
        data_P: samples from P; assumed shape (batch, features) -- TODO confirm
            for inputs that are not rank 2.
        data_Q: samples from Q, with the same trailing shape as data_P.
        name: variable scope that holds the critic's weights.

    Returns:
        outputs: critic scores, one row per input row, data_P rows first.
        these_vars: the trainable variables created under `name`, intended as
            the critic optimizer's var_list.
    """
    with tf.variable_scope(name) as scope:
        data_combined = tf.concat([data_P, data_Q], axis=0)

        lay = layers.relu(data_combined, 20)
        lay = layers.relu(lay, 20)
        outputs = layers.linear(lay, 1)

    # Filter on the fully resolved scope name plus '/' instead of
    # tf.get_collection(scope=name): that uses re.match, so 'KL' would also
    # match sibling scopes like 'KL_loss/...', and would match nothing if this
    # function were called inside an enclosing scope.
    prefix = scope.name + '/'
    these_vars = [v for v in tf.trainable_variables() if v.name.startswith(prefix)]

    return outputs, these_vars

In [5]:
def KL_loss(KL_output, name):
    """Negative Donsker-Varadhan lower bound on D_KL(P||Q), to be minimized.

    For any critic T,  D_KL(P||Q) >= E_P[T] - log E_Q[exp(T)].
    `KL_output` holds T evaluated on [samples of P; samples of Q] (as built by
    KL_network). The first half of the rows is treated as P and the rest as Q,
    so both halves must contain the same number of samples.

    Args:
        KL_output: critic scores; assumed shape (2 * batch, 1) -- TODO confirm
            if the critic ever gets more than one output unit.
        name: variable scope for the loss ops.

    Returns:
        Scalar tensor equal to -(E_P[T] - log E_Q[exp(T)]).
    """
    with tf.variable_scope(name):
        # Integer floor division keeps the split index an int32 tensor.
        batch_size_dyn = tf.shape(KL_output)[0] // 2

        T_P = KL_output[:batch_size_dyn, :]
        T_Q = KL_output[batch_size_dyn:, :]

        # log(mean(exp(T_Q))) == logsumexp(T_Q) - log(n_Q). The naive form
        # overflows to inf in float32 once any critic score exceeds ~88.7.
        n_Q = tf.cast(tf.shape(T_Q)[0], T_Q.dtype)
        log_mean_exp_Q = tf.reduce_logsumexp(T_Q, axis=0) - tf.math.log(n_Q)

        TF_loss = -(tf.reduce_mean(T_P, axis=0) - log_mean_exp_Q)

        # The reductions leave one value per critic output column; take the scalar.
        TF_loss = TF_loss[0]

    return TF_loss

In [6]:
# Trainable parameters of the Normal model fitted to the data.
with tf.variable_scope('model_params'):
    # Pass `name` by keyword: the second positional parameter of tf.Variable is
    # `trainable`, so tf.Variable(2.0, 'mu') leaves the variable named "Variable".
    mu = tf.Variable(2.0, name='mu')
    # NOTE(review): sigma is unconstrained, so nothing here prevents an
    # optimizer step from driving it to <= 0, which is not a valid Normal scale.
    sigma = tf.Variable(3.0, name='sigma')

# Trailing '/' so the regex prefix match cannot pick up sibling scopes whose
# names merely start with 'model_params'.
model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='model_params/')

In [7]:
# Parametric model Q = Normal(mu, sigma). Normal sampling is reparameterized,
# so gradients reach mu and sigma through the samples drawn from it.
model = ds.Normal(
    loc=mu,
    scale=sigma,
)

In [8]:
# One batch of model samples, shape (50000, 1). The row count must equal the
# number of rows fed to `data_in`, because KL_loss splits the critic output in half.
samples_model = model.sample(sample_shape=(50000, 1))

In [9]:
# Observed data (samples of P): one feature per row, batch size left dynamic.
data_in = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='data_in')

In [10]:
# P = observed data, Q = model samples: the critic targets D_KL(data || model).
KL_output, KL_vars = KL_network(data_P=data_in, data_Q=samples_model, name='KL')

In [11]:
# Negated DV bound; -KL_lossval is the current KL estimate.
KL_lossval = KL_loss(KL_output, name='KL_loss')

In [12]:
# Critic step: minimize KL_lossval (i.e. tighten the DV lower bound) w.r.t. the
# critic weights only. NOTE(review): beta1=0.3 / beta2=0.5 are far below Adam's
# usual 0.9 / 0.999; presumably so the critic tracks the moving model quickly -- confirm.
train_KL = tf.train.AdamOptimizer(
    learning_rate=0.005,
    beta1=0.3,
    beta2=0.5,
).minimize(KL_lossval, var_list=KL_vars)

In [13]:
# Model step: minimize -KL_lossval, the critic's current KL estimate, w.r.t.
# mu and sigma only (gradients flow through samples_model).
train_model = tf.train.AdamOptimizer(
    learning_rate=0.05,
    beta1=0.9,
    beta2=0.999,
).minimize(-KL_lossval, var_list=model_vars)

In [14]:
# Initialize every variable, including the Adam slot variables created above.
tf.global_variables_initializer().run(session=sess)

In [15]:
# Target distribution P = N(0, 1): 50,000 draws, shape (50000,).
# A dedicated seeded generator keeps the data identical across Restart & Run All.
DATA_SEED = 0
data_train = np.random.RandomState(DATA_SEED).normal(loc=0, scale=1, size=50000)

In [16]:
# Reshape to a column (N, 1) to match the `data_in` placeholder. Unlike
# np.expand_dims, reshape(-1, 1) is idempotent, so re-running this cell does
# not grow the array to (N, 1, 1).
data_train = data_train.reshape(-1, 1)

In [17]:
# Warm-up: train the critic alone so its KL estimate is meaningful before the
# model parameters start moving. Every step is full-batch over data_train.
for batch in range(50):
    sess.run(fetches=train_KL, feed_dict={data_in: data_train})

In [18]:
# Alternate one critic step and one model step per epoch. The value logged as
# "loss" is the critic's KL estimate (the negated KL_lossval), measured after the
# critic step. Each sess.run draws a fresh batch from samples_model.
for epoch in range(100):
    sess.run(fetches=train_KL, feed_dict={data_in: data_train})
    cur_KL = -sess.run(fetches=KL_lossval, feed_dict={data_in: data_train})

    sess.run(fetches=train_model, feed_dict={data_in: data_train})

    # mu and sigma are plain variables, so one run reads both consistently.
    cur_mu, cur_sigma = sess.run(fetches=[mu, sigma])

    print("epoch {}: loss = {}".format(epoch, cur_KL))
    print("mu = {}, sigma = {}".format(cur_mu, cur_sigma))

epoch 0: loss = 0.8681875467300415
mu = 1.9500000476837158, sigma = 2.950000047683716
epoch 1: loss = 0.8681941628456116
mu = 1.900011420249939, sigma = 2.899944305419922
epoch 2: loss = 0.8428404331207275
mu = 1.8499565124511719, sigma = 2.8499794006347656
epoch 3: loss = 0.8198143243789673
mu = 1.8000038862228394, sigma = 2.799805164337158
epoch 4: loss = 0.7981064319610596
mu = 1.7500636577606201, sigma = 2.7494826316833496
epoch 5: loss = 0.7763665318489075
mu = 1.7000980377197266, sigma = 2.699099540710449
epoch 6: loss = 0.7622635960578918
mu = 1.650078296661377, sigma = 2.6485257148742676
epoch 7: loss = 0.7402007579803467
mu = 1.599927544593811, sigma = 2.5979161262512207
epoch 8: loss = 0.716325044631958
mu = 1.5497430562973022, sigma = 2.5470669269561768
epoch 9: loss = 0.6947574019432068
mu = 1.4993613958358765, sigma = 2.496091365814209
epoch 10: loss = 0.6784394383430481
mu = 1.448998212814331, sigma = 2.4448530673980713
epoch 11: loss = 0.6449662446975708
mu = 1.398416399