# Toy Distributions for L2HMC

## Setup

### Imports

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Global imports
import sys
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import horovod.tensorflow as hvd
hvd.init()

from collections import namedtuple

# append parent directory to `sys.path`
# to load from modules in `../l2hmc-qcd/`
module_path = os.path.join('..')
if module_path not in sys.path:
    sys.path.append(module_path)

# Local imports
from utils.attr_dict import AttrDict
from utils.training_utils import train_dynamics
from dynamics.config import DynamicsConfig
from dynamics.base_dynamics import BaseDynamics
from dynamics.generic_dynamics import GenericDynamics
from network.config import LearningRateConfig
from config import (State, NetWeights, MonteCarloStates,
                    BASE_DIR, BIN_DIR, TF_FLOAT)

from utils.distributions import (plot_samples2D, contour_potential,
                                 two_moons_potential, sin_potential,
                                 sin_potential1, sin_potential2)

#sns.set_palette('bright')

In [None]:
# Apply seaborn's notebook defaults, then layer the custom "molokai"
# matplotlib style on top *only if* it is installed. The style file lives in
# a user-specific location, so machines without it fall back to seaborn's
# defaults instead of failing with an OSError.
sns.set(context="notebook")
MOLOKAI_STYLE = os.path.expanduser(
    "~/.config/matplotlib/stylelib/molokai.mplstyle"
)
if os.path.isfile(MOLOKAI_STYLE):
    plt.style.use(MOLOKAI_STYLE)
else:
    print(f'Style file not found, using seaborn defaults: {MOLOKAI_STYLE}')

## Helper functions

In [None]:
# Quick visual check (presumably of the active style's color cycle): plot
# nine scaled copies of sin(x), amplitudes 0..8, over [0, 2*pi) on one axes.
x = np.arange(0, 2*np.pi, 0.02)
y_arr = [i * np.sin(x) for i in range(9)]
fig, ax = plt.subplots()
for y in y_arr:
    _ = ax.plot(x, y)

In [None]:
from network.config import NetworkConfig, LearningRateConfig

def identity(x):
    """Return `x` unchanged.

    Passed as the no-op `normalizer` when constructing `GenericDynamics`.
    """
    unchanged = x
    return unchanged

def get_dynamics(flags, potential_fn=None, name=None):
    """Return a `GenericDynamics` object, initialized from `flags`.

    Args:
        flags: `AttrDict` of run parameters. Reads `eps`, `num_steps`,
            `aux_weight`, `hmc`, `eps_fixed`, `model_type`, `units`,
            `dropout_prob`, `network_name`, `activation_fn`, `lr_init`,
            `decay_steps`, `decay_rate`, `warmup_steps` and, if present,
            `loss_scale` (defaults to 0.1).
        potential_fn: Potential function of the target distribution.
            Defaults to the notebook-global `POTENTIAL_FN`.
        name: Name given to the dynamics object. Defaults to the
            notebook-global `MODEL_TYPE`.

    Returns:
        The constructed `GenericDynamics` instance.
    """
    # Fall back to the notebook globals so existing `get_dynamics(flags)`
    # calls behave as before; they are resolved at call time.
    if potential_fn is None:
        potential_fn = POTENTIAL_FN
    if name is None:
        name = MODEL_TYPE

    config = DynamicsConfig(eps=flags.eps,
                            num_steps=flags.num_steps,
                            aux_weight=flags.aux_weight,
                            # Honor the `loss_scale` flag instead of a
                            # hard-coded constant; 0.1 is kept as default.
                            loss_scale=flags.get('loss_scale', 0.1),
                            hmc=flags.hmc,
                            eps_fixed=flags.eps_fixed,
                            model_type=flags.model_type)

    net_config = NetworkConfig(units=flags.units,
                               dropout_prob=flags.dropout_prob,
                               name=flags.network_name,
                               activation_fn=flags.activation_fn)

    lr_config = LearningRateConfig(flags.lr_init,
                                   decay_steps=flags.decay_steps,
                                   decay_rate=flags.decay_rate,
                                   warmup_steps=flags.warmup_steps)

    return GenericDynamics(params=flags,
                           config=config,
                           lr_config=lr_config,
                           normalizer=identity,
                           network_config=net_config,
                           potential_fn=potential_fn,
                           name=name)

In [None]:
import utils.file_io as io
from utils.distributions import contour_potential

#%matplotlib inline

def plot_chains(dirs, x_arr, potential_fn, label=None, cmap='rainbow',
                num_chains=4, therm=1000, lim=5):
    """Plot individual chains' samples on top of the target potential.

    One figure is drawn per chain and saved to
    `<dirs.log_dir>/figures/trained_samples_chain{chain}.png`.

    Args:
        dirs: Object with a `log_dir` attribute; figures go to
            `<log_dir>/figures` (created if missing).
        x_arr: Samples indexed as `x_arr[step, chain, coord]`; only the
            first two coordinates are plotted. Anything accepted by
            `tf.convert_to_tensor` (e.g. a list of per-step tensors).
        potential_fn: Potential whose contours are drawn underneath.
        label: Label for the sample markers (default: 'l2hmc samples').
        cmap: Colormap for the potential contours.
        num_chains: Maximum number of chains to plot; clamped to the number
            of chains present in `x_arr`.
        therm: Number of initial steps discarded as thermalization.
        lim: Half-width of the square plotting window.
    """
    figs_dir = os.path.join(dirs.log_dir, 'figures')
    io.check_else_make_dir(figs_dir)

    x_arr = tf.convert_to_tensor(x_arr).numpy()
    if label is None:
        label = 'l2hmc samples'

    # Avoid an IndexError when fewer chains are available than requested.
    num_chains = min(num_chains, x_arr.shape[1])
    for chain in range(num_chains):
        fig, ax = plt.subplots()
        xy = np.array((x_arr[therm:, chain, 0], x_arr[therm:, chain, 1]))
        # Contours of the potential that was passed in (not a global).
        _ = contour_potential(potential_fn, ax=ax, cmap=cmap,
                              xlim=lim, ylim=lim)
        _ = ax.plot(*xy, alpha=0.3, mew=0.9, ls='', marker='+',
                    color='white', label=label)
        _ = ax.set_xlim((-lim, lim))
        _ = ax.set_ylim((-lim, lim))
        out_file = os.path.join(figs_dir, f'trained_samples_chain{chain}.png')
        print(f'Saving figure to: {out_file}')
        # Save this chain's figure explicitly rather than whatever is current.
        _ = fig.savefig(out_file, dpi=400, bbox_inches='tight')
        plt.show()

In [None]:
def plot_density_estimation(potential_fn, x_l2hmc, x_hmc,
                            title=None, cmap=None, num_plots=5):
    """Compare KDEs of L2HMC and HMC samples with the true potential.

    For each plotted chain a row of three panels is drawn: contours of
    `potential_fn`, a KDE of that chain's L2HMC samples, and a KDE of that
    chain's HMC samples.

    Args:
        potential_fn: Potential of the target distribution.
        x_l2hmc: L2HMC samples indexed as `x[step, chain, coord]`.
        x_hmc: HMC samples with the same indexing.
        title: Title of the potential panel.
        cmap: Colormap used for all panels.
        num_plots: Number of chains to plot; clamped to the number of chains
            available in both sample arrays.

    Returns:
        `(fig, axes)` of the last figure drawn, or `(None, None)` when no
        figure was drawn.
    """
    def _format_arr(x):
        x = np.array(x)
        n = x.shape[0]
        therm = 2 * n // 10  # Drop first 20% of samples (thermalization)
        return x[therm:]

    x_l2hmc = _format_arr(x_l2hmc)
    x_hmc = _format_arr(x_hmc)

    # Never index a chain that does not exist in either array.
    num_plots = min(num_plots, x_l2hmc.shape[1], x_hmc.shape[1])

    # Defined up front so `num_plots == 0` returns cleanly instead of
    # raising UnboundLocalError.
    fig, axes = None, None
    for idx in range(num_plots):
        fig, axes = plt.subplots(ncols=3, figsize=(12, 4))
        _ = contour_potential(potential_fn, title=title, ax=axes[0], cmap=cmap)
        _ = sns.kdeplot(x_l2hmc[:, idx, 0], x_l2hmc[:, idx, 1],
                        shade=True, cmap=cmap, ax=axes[1])
        _ = sns.kdeplot(x_hmc[:, idx, 0], x_hmc[:, idx, 1],
                        shade=True, cmap=cmap, ax=axes[2])
        _ = axes[1].set_title('L2HMC samples')
        _ = axes[2].set_title('HMC samples')
        plt.tight_layout()

    return fig, axes

### Plot examples of (toy) target distributions:

In [None]:
# Plot the four analytic toy potentials on a 2x2 grid and register each one
# in `potentials_dict` (name -> potential function) so the target can be
# selected by name later.
fig, axes = plt.subplots(nrows=2, ncols=2)
axes = axes.flatten()
names = ['two_moons', 'sin', 'sin_hard', 'sin_harder']
potentials = [two_moons_potential, sin_potential, sin_potential1, sin_potential2]
potentials_dict = dict(zip(names, potentials))

# Assigning to a `.facecolor` attribute has no visual effect; the setter
# methods are what actually change the background color.
fig.set_facecolor('#1c1c1c')
for name, p_fn, ax in zip(names, potentials, axes):
    _ = contour_potential(p_fn, ax, title=name, cmap='rocket')
    ax.set_facecolor('#1c1c1c')

plt.grid(True)
plt.show()

In [None]:
from utils.distributions import GaussianFunnel, RoughWell

# Gaussian funnel target: build its energy function and plot its contours.
funnel = GaussianFunnel()
funnel_potential = funnel.get_energy_function()
fig, ax = plt.subplots()
_ = contour_potential(funnel_potential, ax=ax, title='Gaussian Funnel Potential')
plt.show()

# 2D "rough well" targets, easy and hard variants (same dim and eps).
rough_well = RoughWell(dim=2, eps=0.1, easy=True)
rough_well_hard = RoughWell(dim=2, eps=0.1, easy=False)

rw_potential = rough_well.get_energy_function()
rwh_potential = rough_well_hard.get_energy_function()

fig, axes = plt.subplots(nrows=1, ncols=2)
axes = axes.flatten()
ax0 = contour_potential(rw_potential, ax=axes[0], cmap='rocket', title='Rough Well Potential')
ax1 = contour_potential(rwh_potential, ax=axes[1], cmap='rocket', title='Rough Well Potential (Hard)')
ax0.set_aspect('equal')
ax1.set_aspect('equal')

# Make these targets selectable by name as well.
potentials_dict.update({
    'funnel': funnel_potential,
    'rough_well': rw_potential,
    'rough_well_hard': rwh_potential,
})

In [None]:
from utils.distributions import GaussianMixtureModel, meshgrid
import tensorflow_probability as tfp

# Short alias for the TFP distributions module (used by `make_gmm_model`).
tfd = tfp.distributions
# Keras' default float dtype; `make_gmm_model` casts its parameters to it.
# NOTE(review): if `TF_FLOAT` differs from this dtype, potentials evaluated on
# `TF_FLOAT` inputs may hit a dtype mismatch — confirm they agree.
floatx = tf.keras.backend.floatx()


def make_gmm_model(mus, sigmas, pis):
    """Build a diagonal-covariance Gaussian mixture and its potential.

    Args:
        mus: Per-component means (one location vector per component).
        sigmas: Per-component diagonal scales, matching `mus`.
        pis: Mixture weights, one per component.

    Returns:
        `(gmm, potential)` where `gmm` is a `tfd.Mixture` and
        `potential(x)` returns `-gmm.log_prob(x)`.
    """
    # Cast every parameter list to the Keras default float dtype.
    loc_t = tf.convert_to_tensor(mus, dtype=floatx)
    scale_t = tf.convert_to_tensor(sigmas, dtype=floatx)
    weight_t = tf.convert_to_tensor(pis, dtype=floatx)

    component_dists = [
        tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
        for loc, scale in zip(loc_t, scale_t)
    ]
    mixture_cat = tfd.Categorical(probs=weight_t)
    gmm = tfd.Mixture(cat=mixture_cat, components=component_dists)

    def potential(x):
        return -1. * gmm.log_prob(x)

    return gmm, potential

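Gaussian mixture models:

 1. 2-component mixture:
 $$x \sim p(x) \equiv \frac{1}{2}\mathcal{N}(\vec{x}_{0}, \Sigma_{0}) + \frac{1}{2}\mathcal{N}(\vec{x}_{1}, \Sigma_{1})$$
 2. $5\times 5$ lattice of equally weighted Gaussians, centered on the integer grid points $\vec{x}_{ij}$, $i, j \in \{-2, \dots, 2\}$:
 $$x \sim p(x) \equiv \frac{1}{25}\sum_{i,j}\mathcal{N}(\vec{x}_{ij}, \Sigma_{ij})$$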

In [None]:
# ==== Mixture of two components
mus = [(-1., 0), (1., 0)]
sigmas = [0.1 * np.ones(2) for _ in range(len(mus))]
pis = len(mus) * [1. / len(mus)]

gmm, gmm_potential = make_gmm_model(mus, sigmas, pis)


# ==== 4x4 Lattice of Gaussians
# xy locations of each component
mus = [(-2, -2), (-2, -1), (-2, +0), (-2, +1), (-2, +2),
       (-1, -2), (-1, -1), (-1, +0), (-1, +1), (-1, +2),
       (+0, -2), (+0, -1), (+0, +0), (+0, +1), (+0, +2),
       (+1, -2), (+1, -1), (+1, +0), (+1, +1), (+1, +2),
       (+2, -2), (+2, -1), (+2, +0), (+2, +1), (+2, +2)]

sigmas = [0.1 * np.ones(2) for _ in range(len(mus))]
pis = len(mus) * [1. / len(mus)]

gmm_latt, gmm_latt_potential = make_gmm_model(mus, sigmas, pis)


potentials_dict.update({
    'gmm': gmm_potential,
    'lattice_of_gaussians': gmm_latt_potential,
})


# ==== Plot contours of both potentials
fig, axes = plt.subplots(ncols=2)
axes = axes.flatten()

ax0 = contour_potential(gmm_potential, ax=axes[0],
                        cmap='rocket', xlim=1.5, ylim=0.5,
                        title='Gaussian Mixture Model')

ax1 = contour_potential(gmm_latt_potential, ax=axes[1],
                        cmap='rocket', xlim=2.7777775, ylim=2.75,
                        title='Lattice of Gaussians')
    
_ = [ax.set_aspect('equal') for ax in axes]

## Define parameters of the model and target distribution

In [None]:
import utils.file_io as io
import datetime

# Root directory for all run logs, resolved relative to the current working
# directory.
LOGS_DIR = os.path.abspath('../../logs')

# DEFINE THE TARGET DISTRIBUTION
# `MODEL_TYPE` must be a key of `potentials_dict` (filled in the cells above).
MODEL_TYPE = 'two_moons'
POTENTIAL_FN = potentials_dict[MODEL_TYPE]

now = datetime.datetime.now()
date_str = now.strftime('%Y-%m-%d')
timestamp = now.strftime('%Y-%m-%d-%H%M%S')

# Logs go to `<LOGS_DIR>/<MODEL_TYPE>/<date>`; if that directory already
# exists, a timestamped subdirectory is used instead.
log_dir = os.path.join(LOGS_DIR, f'{MODEL_TYPE}', date_str)
if os.path.isdir(log_dir):
    log_dir = os.path.join(log_dir, timestamp)
    
io.check_else_make_dir(log_dir)


# Run parameters shared by training and inference.
flags = AttrDict({
    'profiler': False,
    'xdim': 2,
    'eps': 0.01,
    'aux_weight': 0.,
    'loss_scale': 0.1,
    'batch_size': 256,
    'num_steps': 10,
    'beta_init': 1.,
    'beta_final': 1.,
    'compile': True,
    'hmc_steps': 0,
    'lr_init': 1e-3,
    'train_steps': 5000,
    'clip_val': 1.0,
    'decay_rate': 0.96,
    'save_steps': 1000,
    'logging_steps': 100,
    'warmup_steps': 1000,
    'print_steps': 1,
    'units': [128, 128],
    'hmc': False,
    'eps_fixed': False,
    'model_type': MODEL_TYPE,
    'network_name': 'GenericNetwork',
    'dropout_prob': 0.,
    'activation_fn': tf.nn.relu,
    'log_dir': log_dir,
})

# Learning-rate decay interval: one fifth of the training steps.
flags.decay_steps = flags.train_steps // 5
#flags.warmup_steps = flags.train_steps // 10
#flags.warmup_steps = flags.train_steps // 10

## Train

### Start by training HMC to find optimal step-size $\varepsilon$ and thermalized config $x_{\mathrm{therm}}$ 

In [None]:
from network.config import NetworkConfig, LearningRateConfig

# Enable the HMC pre-training stage below.
flags.hmc_steps = 1000
flags.restore = False

# Random initial configuration of shape (batch_size, xdim).
x_shape = (flags.batch_size, flags.xdim)
x = tf.random.normal(shape=x_shape, dtype=TF_FLOAT)

net_config = NetworkConfig(units=flags.units,
                           dropout_prob=flags.dropout_prob,
                           name=flags.network_name,
                           activation_fn=flags.activation_fn)

lr_config = LearningRateConfig(flags.lr_init,
                     decay_steps=flags.decay_steps,
                     decay_rate=flags.decay_rate,
                     warmup_steps=flags.warmup_steps)

# TRAIN HMC
if flags.hmc_steps > 0:
    # Copy of `flags` so the HMC run can override settings independently.
    hmc_flags = AttrDict({k: v for k, v in flags.items()})
    #hmc_flags.train_steps = hmc_flags.pop('hmc_steps')
    # NOTE(review): `flags.hmc_steps` only toggles this stage; the number of
    # HMC training steps is hard-coded to 5000 here — confirm intended.
    hmc_flags.train_steps = 5000
    hmc_flags.logging_steps = hmc_flags.train_steps // 20
    hmc_flags.beta_final = hmc_flags.beta_init
    hmc_flags.compile = True
    hmc_config = DynamicsConfig(eps=hmc_flags.eps,
                                num_steps=hmc_flags.num_steps,
                                hmc=True,
                                eps_fixed=flags.eps_fixed,
                                model_type=MODEL_TYPE)
    hmc_dynamics = GenericDynamics(params=hmc_flags,
                                   config=hmc_config,
                                   lr_config=lr_config,
                                   network_config=net_config,
                                   potential_fn=POTENTIAL_FN,
                                   name=MODEL_TYPE)
    hmc_dirs = io.setup_directories(hmc_flags, 'training_hmc')
    # `x` is replaced by the configuration returned from the HMC run.
    x, train_data = train_dynamics(hmc_dynamics, hmc_flags, dirs=hmc_dirs, x=x)
    
    output_dir = os.path.join(hmc_dirs.train_dir, 'outputs')
    train_data.save_data(output_dir)
    # NOTE(review): the step size found by HMC is not copied back into
    # `flags.eps`, despite the section heading — confirm intended.
    #flags.eps = hmc_dynamics.eps.numpy()
    #flags.eps = hmc_dynamics.eps.numpy()

### Create `GenericDynamics` object

In [None]:
# Build the trainable L2HMC dynamics for the target selected above.
dynamics = get_dynamics(flags)

In [None]:
# Print the architecture of the position network `xnet`.
# NOTE(review): Keras `summary()` prints the table itself and appears to
# return None, so `s` is likely None — confirm.
s = dynamics.xnet.summary()

In [None]:
# Display the optimizer attached to `dynamics` (the trailing `._` attribute
# access did not exist and raised AttributeError).
dynamics.optimizer

In [None]:
# Display whatever `xnet.summary()` returned.
s

In [None]:
# Render the `xnet` layer graph, including tensor shapes.
tf.keras.utils.plot_model(dynamics.xnet, show_shapes=True)

### Train L2HMC sampler using HMC sampler as starting point

In [None]:
# Do not restore from a previous checkpoint; train from scratch.
flags.restore = False

In [None]:
dirs = io.setup_directories(flags)
# Start L2HMC training from the configuration `x` produced by the HMC
# pre-training cell (or its random initial `x` if HMC was skipped), rather
# than re-sampling fresh noise, so HMC actually serves as the starting point.
# NOTE: relies on the HMC cell having been run just before this one.
flags.train_steps = 2000
x, train_data = train_dynamics(dynamics, flags, dirs=dirs, x=x)

In [None]:
#dynamicspath = os.path.join(dirs.log_dir, 'training', 'dynamics.h5')
dynamicspath1 = os.path.join(dirs.log_dir, 'training', 'dynamics')
#print(f'Saving `dynamics` to: {dynamicspath}')
print(f'Saving `dynamics` to: {dynamicspath1}')

#dynamics.save(dynamicspath)
dynamics.save(dynamicspath1)
      

In [None]:
# Save `xnet` and `vnet` twice: once to `.h5` files and once to
# extension-less paths, so both serialized formats can be compared on reload.
xnetpath = os.path.join(dirs.log_dir, 'training', 'dynamics_xnet.h5')
vnetpath = os.path.join(dirs.log_dir, 'training', 'dynamics_vnet.h5')
print(f'Saving `dynamics.xnet` to : {xnetpath}')
print(f'Saving `dynamics.vnet` to : {vnetpath}')
dynamics.xnet.save(xnetpath)
dynamics.vnet.save(vnetpath)


xnetpath1 = os.path.join(dirs.log_dir, 'training', 'dynamics_xnet1')
vnetpath1 = os.path.join(dirs.log_dir, 'training', 'dynamics_vnet1')
print(f'Saving `dynamics.xnet` to : {xnetpath1}')
print(f'Saving `dynamics.vnet` to : {vnetpath1}')
dynamics.xnet.save(xnetpath1)
dynamics.vnet.save(vnetpath1)

In [None]:
# Reload both saved variants: `*_copy` from `.h5`, `*_copy1` from the
# extension-less paths.
xnet_copy = tf.keras.models.load_model(xnetpath)
vnet_copy = tf.keras.models.load_model(vnetpath)

xnet_copy1 = tf.keras.models.load_model(xnetpath1)
vnet_copy1 = tf.keras.models.load_model(vnetpath1)

In [None]:
# Deterministic probe inputs (all ones) plus the time input for step 0,
# tiled over the batch dimension.
#x = tf.random.normal(dynamics.x_shape)
#v = tf.random.normal(dynamics.x_shape)
x = tf.ones(dynamics.x_shape)
v = tf.ones(dynamics.x_shape)
t = dynamics._get_time(0, tile=tf.shape(x)[0])

In [None]:
# Check that the `.h5` round-trip of `xnet` restored identical weights: for
# each pair of corresponding layers, print the largest absolute element-wise
# difference of every weight tensor (0.0 means identical). Layers without
# weights simply print their names.
for l1, l2 in zip(dynamics.xnet.layers, xnet_copy.layers):
    print(f'{l1.name}, {l2.name}')
    # A signed sum of differences can cancel to ~0 even when the weights
    # differ, so report the max absolute difference instead.
    for w1, w2 in zip(l1.weights, l2.weights):
        print(f'max|dw| = {tf.reduce_max(tf.abs(w1 - w2)).numpy()}')

In [None]:
# Compare outputs of the in-memory `xnet` with its two reloaded copies
# (`xnet_copy` from `.h5`, `xnet_copy1` from the extension-less path) on
# identical inputs.
x = tf.ones(dynamics.x_shape)
v = tf.ones(dynamics.x_shape)
# Keep the time input under its own name: the network's second output is
# bound to `t` below, and reusing that name would feed the network output
# (not the time input) to the reloaded copies.
t_in = dynamics._get_time(0, tile=tf.shape(x)[0])

s, t, q = dynamics.xnet((x, v, t_in), training=False)
s_, t_, q_ = xnet_copy((x, v, t_in), training=False)
s1_, t1_, q1_ = xnet_copy1((x, v, t_in), training=False)

# Only a cell's last expression is displayed, so print every comparison.
output_triples = {
    's': (s, s_, s1_),
    't': (t, t_, t1_),
    'q': (q, q_, q1_),
}
for out_name, (orig_out, h5_out, dir_out) in output_triples.items():
    print(f'{out_name}: '
          f'orig~h5={np.allclose(orig_out.numpy(), h5_out.numpy())}, '
          f'orig~dir={np.allclose(orig_out.numpy(), dir_out.numpy())}, '
          f'h5~dir={np.allclose(h5_out.numpy(), dir_out.numpy())}')


In [None]:
# Raw `s` output of the in-memory `xnet`.
s.numpy()

In [None]:
# Raw `s` output of the `.h5`-reloaded copy, for visual comparison.
s_.numpy()

In [None]:
from network.layers import ScaledTanhLayer


# Standalone `ScaledTanhLayer` instance, used below to inspect its config.
stl = ScaledTanhLayer(128, 1., name='scale1')

In [None]:
# NOTE: rebinds the global `config` (previously unused at top level here).
config = stl.get_config()

In [None]:
# Display the layer's serialization config.
config

## Run inference

In [None]:
import utils.file_io as io

# Restore the latest training checkpoint (model + optimizer) if one exists
# in `dirs.ckpt_dir`; `current_step` is only defined when a checkpoint was
# found.
ckpt = tf.train.Checkpoint(model=dynamics, optimizer=dynamics.optimizer)
manager = tf.train.CheckpointManager(ckpt, dirs.ckpt_dir, max_to_keep=5)
if manager.latest_checkpoint:
    io.log(f'INFO:Checkpoint restored from: {manager.latest_checkpoint}')
    ckpt.restore(manager.latest_checkpoint)
    current_step = dynamics.optimizer.iterations.numpy()

In [None]:
from utils.inference_utils import run_dynamics
import utils.file_io as io

# Run inference with the trained L2HMC sampler at the final beta, writing
# TF summaries to `<log_dir>/inference/summaries`.
flags.log_dir = dirs.log_dir
flags.beta = flags.beta_final

summary_dir = os.path.join(flags.log_dir, 'inference', 'summaries')
io.check_else_make_dir(summary_dir)
writer = tf.summary.create_file_writer(summary_dir)
writer.set_as_default()

flags.run_steps = 5000
# With `save_x=True`, `x_arr` presumably holds the per-step samples
# (indexed [step, chain, coord] by the plotting helpers) — confirm.
run_data, x, x_arr = run_dynamics(dynamics, flags, save_x=True)

writer.flush()
writer.close()

In [None]:
# Inspect what the run wrote to its log directory.
os.listdir(dirs.log_dir)

In [None]:
# Inspect the saved training artifacts.
os.listdir(os.path.join(dirs.log_dir, 'training'))

In [None]:
# `keras` is not defined in this notebook (only `tf` is), so the old
# `help(keras.fun)` raised NameError; go through `tf.keras` instead.
help(tf.keras)

In [None]:
from dynamics.config import DynamicsConfig

# Baseline: run plain HMC with step size 0.15 for comparison with the
# trained L2HMC sampler.
hmc_flags = AttrDict(dict(flags))
hmc_flags.logging_steps = hmc_flags.train_steps // 20
hmc_flags.beta_final = hmc_flags.beta_init
hmc_flags.compile = True
hmc_config = DynamicsConfig(eps=0.15,
                            num_steps=hmc_flags.num_steps,
                            hmc=True,
                            eps_fixed=flags.eps_fixed,
                            model_type=MODEL_TYPE)
hmc_dynamics = GenericDynamics(params=hmc_flags,
                               config=hmc_config,
                               lr_config=dynamics.lr_config,
                               network_config=dynamics.net_config,
                               potential_fn=POTENTIAL_FN,
                               name=MODEL_TYPE)
# All net weights set to zero (presumably disabling the learned terms).
hmc_dynamics._parse_net_weights(NetWeights(0., 0., 0., 0., 0., 0.))

# `hmc_flags.log_dir` is copied from `flags.log_dir`, so writing to
# '<log_dir>/inference/summaries' would mix HMC summaries with the L2HMC
# inference summaries; use a separate directory.
summary_dir_hmc = os.path.join(hmc_flags.log_dir, 'inference_hmc', 'summaries')
io.check_else_make_dir(summary_dir_hmc)
writer_hmc = tf.summary.create_file_writer(summary_dir_hmc)
writer_hmc.set_as_default()
hmc_flags.run_steps = 5000
x_init = tf.random.normal(x.shape)
run_data_hmc, x_hmc, x_arr_hmc = run_dynamics(hmc_dynamics, hmc_flags, save_x=True, x=x_init)

writer_hmc.flush()
writer_hmc.close()

In [None]:
# Compare L2HMC samples (`x_arr`) against the HMC baseline (`x_arr_hmc`);
# previously `x_arr` was passed for both, so the "HMC" panels showed L2HMC.
plot_density_estimation(dynamics.potential_fn, x_arr, x_arr_hmc,
                        num_plots=2, title=MODEL_TYPE, cmap='viridis')

In [None]:
from utils.distributions import contour_potential
# Side-by-side view per chain: true potential, KDE of the L2HMC samples and
# KDE of the HMC samples (all steps kept; no thermalization dropped).
xl2hmc = np.array(x_arr)
xhmc = np.array(x_arr_hmc)

# Up to five chains, but never more than both arrays contain.
num_plots = min(5, xl2hmc.shape[1], xhmc.shape[1])
for idx in range(num_plots):
    fig, axes = plt.subplots(ncols=3, figsize=(12, 4))
    # Title follows the selected target instead of a hard-coded
    # 'Rough Well' label.
    _ = contour_potential(POTENTIAL_FN, title=f'{MODEL_TYPE} (true)', ax=axes[0])
    _ = sns.kdeplot(xl2hmc[:, idx, 0], xl2hmc[:, idx, 1],
                    shade=True, cmap='inferno', ax=axes[1])
    _ = sns.kdeplot(xhmc[:, idx, 0], xhmc[:, idx, 1],
                    shade=True, cmap='inferno', ax=axes[2])
    _ = axes[1].set_title('L2HMC samples')
    _ = axes[2].set_title('HMC samples')
    plt.tight_layout()