# L2HMC with MOG target distribution using eager execution in TensorFlow

### Imports

In [None]:
import os
import sys
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from l2hmc_eager import dynamics_eager as l2hmc
from l2hmc_eager.neural_nets import *
from utils.distributions import GMM, gen_ring
from utils.jacobian import _map, jacobian

%autoreload 2

In [None]:
# Switch TensorFlow to eager execution; later cells call `.numpy()` directly
# on the tensors they get back.
tf.enable_eager_execution()
# Short alias for the contrib eager module (used later as `tfe.defun`).
tfe = tf.contrib.eager

In [None]:
# Re-binds the same `tfe` alias that the previous cell already set;
# redundant but harmless.
tfe = tf.contrib.eager

In [None]:
def train_one_iter(dynamics, x, optimizer,
                   loss_fn=l2hmc.compute_loss, global_step=None):
    """Run a single optimization step on `dynamics`.

    Args:
        dynamics: Object forwarded to `l2hmc.loss_and_grads`; its
            `trainable_variables` are the variables that get updated.
        x: Input forwarded to `l2hmc.loss_and_grads`.
        optimizer: Optimizer exposing `apply_gradients`.
        loss_fn: Loss callable forwarded to `l2hmc.loss_and_grads`.
        global_step: Optional step counter handed to `apply_gradients`.

    Returns:
        Tuple `(loss, out, accept_prob)` as produced by
        `l2hmc.loss_and_grads`; the gradients are consumed here and are
        not returned.
    """
    step_result = l2hmc.loss_and_grads(dynamics, x, loss_fn=loss_fn)
    loss_value, gradients, new_x, acceptance = step_result

    # Pair each gradient with the variable it belongs to before applying.
    grads_and_vars = zip(gradients, dynamics.trainable_variables)
    optimizer.apply_gradients(grads_and_vars, global_step=global_step)

    return loss_value, new_x, acceptance

In [None]:
def distribution_arr(x_dim, n_distributions):
    """Create array describing likelihood of drawing from distributions.

    Args:
        x_dim: Minimum length of the returned array (the notebook passes
            the dimensionality of the sample space here).
        n_distributions: Number of components that should carry
            (essentially) all of the probability mass.

    Returns:
        1-D `np.ndarray` of weights that sums to 1 up to float rounding.
        When `n_distributions >= x_dim` it holds `n_distributions` equal
        weights. Otherwise it holds `x_dim` entries: the first
        `n_distributions` share almost all of the mass equally and the
        remaining entries split a negligible remainder.
    """
    if n_distributions >= x_dim:
        pis = [1. / n_distributions] * n_distributions
        # Fold any floating-point residual into the first weight so the
        # total is exactly normalised. This also covers the
        # x_dim == n_distributions case, which previously rounded each
        # weight and could return weights summing to less than 1
        # (e.g. 3 * 0.333 = 0.999).
        pis[0] += 1 - sum(pis)
        return np.array(pis)

    # Fewer weighted components than entries: shave a tiny amount off the
    # main weights and hand that remainder to the padding entries so every
    # entry stays strictly positive.
    big_pi = (1.0 / n_distributions) - x_dim * 1E-16
    pis = n_distributions * [big_pi]
    small_pi = (1. - sum(pis)) / (x_dim - n_distributions)
    pis.extend((x_dim - n_distributions) * [small_pi])
    return np.array(pis)

### MoG Model

In [None]:
# Settings for the mixture-of-Gaussians target built below.
x_dim = 2 
num_distributions = 2
sigma = 0.05
axis = 0
centers = 1

# One row of `means` per component (x_dim rows in total). Even-indexed rows
# sit at +centers and odd-indexed rows at -centers along coordinate `axis`;
# every other coordinate stays 0.
means = np.zeros((x_dim, x_dim))
means[::2, axis] = centers
means[1::2, axis] = - centers

# Same isotropic covariance for every component, stacked x_dim times.
# NOTE(review): `sigma` multiplies the identity directly, so it acts as a
# variance here rather than a standard deviation - confirm against GMM.
cov_mtx = sigma * np.eye(x_dim)
sigmas = np.array([cov_mtx] * x_dim)

# Mixture weights for the components.
# NOTE(review): `means` and `sigmas` are sized by x_dim while `pis` has
# max(x_dim, num_distributions) entries; presumably GMM needs these lengths
# to agree, which holds for the values used here - verify before changing
# either setting.
pis = distribution_arr(x_dim, num_distributions)
mog_distribution = GMM(means, sigmas, pis)

In [None]:
# Energy function of the mixture; handed to the sampler as its
# negative log-likelihood.
mog_potential_fn = mog_distribution.get_energy_function()

# Use the notebook's `x_dim` rather than a hard-coded 2 so the sampler
# cannot drift out of sync with the target distribution's dimensionality.
# NOTE(review): `n_steps` / `eps` are presumably the number of leapfrog
# steps and the initial step size - confirm against l2hmc.Dynamics.
mog_dynamics = l2hmc.Dynamics(x_dim=x_dim,
                              minus_loglikelihood_fn=mog_potential_fn,
                              n_steps=2,
                              eps=0.5,
                              np_seed=1)

In [None]:
# Training-loop settings. `eval_iters` is defined here but not used in
# this cell.
train_iters = 1000
eval_iters = 20 
n_samples = 200
record_loss_every = 10 
save_steps = 100 

# Step counter shared by the learning-rate schedule, the optimizer, the
# checkpoint and the summaries; it is set to 1 before training starts.
global_step = tf.train.get_or_create_global_step()
global_step.assign(1)
# Base rate 1e-3, multiplied by 0.96 every 1000 steps (staircase schedule).
learning_rate = tf.train.exponential_decay(1e-3, global_step, 
                                           1000, 0.96, staircase=True)
optimizer = tf.train.AdamOptimizer(learning_rate)
# Track optimizer state, model variables and the step counter together so
# a saved checkpoint captures all three.
checkpointer = tf.train.Checkpoint(optimizer=optimizer,
                                   dynamics=mog_dynamics,
                                   global_step=global_step)

# Summaries and checkpoints are both written under this directory.
log_dir = '../../tf_eager_log/mog_model/run_3/'
summary_writer = tf.contrib.summary.create_file_writer(log_dir)
# Disabled restore / loss-function selection logic, kept for reference:
# if restore:
#     latest_path = tf.train.latest_checkpoint(train_dir)
#     checkpointer.restore(latest_path)
#     print("Restored latest checkpoint at path:\"{}\"".format(latest_path))
#     sys.stdout.flush()
# if not restore:
#     if use_defun:
#         loss_fn = tfe.function(l2hmc.compute_loss)
#     else:
# Wrap the loss with `tfe.defun` before handing it to the training step.
loss_fn = tfe.defun(l2hmc.compute_loss)
# Initial batch drawn from a standard normal; each iteration feeds the
# samples returned by the previous step back in as the next input.
samples = tf.random_normal(shape=[n_samples, x_dim])
for i in range(1, train_iters + 1):
    loss, samples, accept_prob = train_one_iter(
        mog_dynamics,
        samples,
        optimizer,
        loss_fn=loss_fn,
        global_step=global_step
    )

    # Periodically print progress and log the loss as a scalar summary.
    if i % record_loss_every == 0:
        print("Iteration {}, loss {:.4f}, x_accept_prob {:.4f}".format(
            i, loss.numpy(), accept_prob.numpy().mean()
        ))
        with summary_writer.as_default():
            with tf.contrib.summary.always_record_summaries():
                _ = tf.contrib.summary.scalar("Training loss", 
                                              loss, 
                                              step=global_step)

    # Periodically save a checkpoint with prefix "ckpt" inside `log_dir`.
    if i % save_steps == 0:
        saved_path = checkpointer.save(file_prefix=os.path.join(log_dir,
                                                                "ckpt"))
        print(f"Saved checkpoint to: {saved_path}")

print("Training complete.")
sys.stdout.flush()

In [None]:
# Run 100 transitions of the dynamics from a fresh standard-normal batch,
# recording the batch before each transition (so the first entry is the
# initial draw).
_samples = tf.random_normal(shape=[n_samples, x_dim])
samples_history = []
for i in range(100):
    samples_history.append(_samples.numpy())
    # Only the fourth value returned by `apply_transition` is kept as the
    # next state. NOTE(review): presumably that is the post-accept/reject
    # position - confirm against l2hmc.Dynamics.apply_transition.
    _, _, _, _samples = mog_dynamics.apply_transition(_samples)
# Stack the per-step batches into a single array (step index first).
samples_history = np.array(samples_history)

In [None]:
# Display the shape of the recorded history; expected to be
# (100, n_samples, x_dim) if each transition keeps the batch shape - TODO confirm.
samples_history.shape

In [None]:
target_samples

In [None]:
target_samples = mog_distribution.get_samples(500)
fig, ax = plt.subplots()
ax.plot(target_samples[:,0], target_samples[:,1], color='C0', alpha=0.5, marker='o', ls='')
ax.plot(samples_history[:, 0, 0], samples_history[:, 0, 1], color='C1', alpha=0.75, ls='-')
plt.show()

In [None]:
# Draw 10 samples directly from the target distribution.
# NOTE: this rebinds `samples`, which the training cell used for the
# chain state.
samples = mog_distribution.get_samples(10)

In [None]:
# Display the current value of `samples` as the cell output.
samples

In [None]:
# Scatter the first two coordinates of `samples`; the chain-trajectory
# overlay is left disabled below.
fig, ax = plt.subplots()
ax.plot(samples[:,0], samples[:,1], color='C0', marker='o', ls='')
#ax.plot(samples_history[:, 0, 0], samples_history[:, 0, 1], color='C1', alpha=0.75, ls='-')
plt.show()