In [None]:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np
import haiku as hk
import optax
import matplotlib.pyplot as plt 
from IPython import display

In [None]:
# Sanity check: list the devices JAX will dispatch computations to.
print(jax.devices())
# IPython shell magic: show NVIDIA GPU status; only meaningful on hosts with an NVIDIA driver.
!nvidia-smi

In [None]:
# Root PRNG key (seed 0); the sampling cell splits fresh subkeys from it.
key = jax.random.PRNGKey(0)

# --- System parameters ---
n = 6         # number of particles (energy_fn reshapes one sample to (n, dim))
dim = 2       # spatial dimension of each particle
batch = 1024  # number of independent MCMC walkers; samples have shape (batch, n, dim)
beta = 10.0   # inverse temperature: walkers sample p(x) proportional to exp(-beta * H)

# --- Metropolis sampler settings ---
mc_steps = 100  # Metropolis steps per call to mcmc()
mc_width = 0.1  # initial std of the Gaussian proposal; adapted (and overwritten) by the sampling loop
iters = 20      # number of outer sample/adapt rounds

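Energy of $n$ mutually repelling particles in a $d$-dimensional isotropic harmonic trap (Coulomb interaction plus quadratic confinement, both with unit prefactor). Walkers are sampled from the Boltzmann distribution $p(\boldsymbol{x}) \propto e^{-\beta H}$:

$$H= \sum_{i<j} \frac{1}{|\boldsymbol{x}_i - \boldsymbol{x}_j|} + \sum_i \boldsymbol{x}_i^2 . $$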

In [None]:
def energy_fn(x, n, dim):
    """Energy H of a single configuration: quadratic trap plus pairwise 1/r repulsion.

    Args:
        x: positions of the ``n`` particles; reshaped to ``(n, dim)`` for the
            pair term, so any array holding ``n * dim`` elements is accepted.
        n: number of particles. It determines array shapes, so it must be a
            static Python int.
        dim: spatial dimension of each particle.

    Returns:
        Scalar ``sum_i |x_i|^2 + sum_{i<j} 1 / |x_i - x_j|``.
    """
    positions = jnp.reshape(x, (n, dim))
    # Every pairwise displacement x_a - x_b, shape (n, n, dim).
    displacements = positions[:, None, :] - positions[None, :, :]
    # Keep each unordered pair exactly once (strict upper triangle, a < b).
    # This also drops the zero-length diagonal, so 1/r never sees r == 0 there.
    rows, cols = jnp.triu_indices(n, k=1)
    distances = jnp.linalg.norm(displacements[rows, cols], axis=-1)

    confinement = jnp.sum(x**2)
    repulsion = jnp.sum(1/distances)
    return confinement + repulsion

# Energy of every walker: map energy_fn over the leading axis of x; n and dim are shared.
batch_energy = jax.vmap(energy_fn, in_axes=(0, None, None), out_axes=0)

In [None]:
from functools import partial
@partial(jax.jit, static_argnums=0)
def mcmc(logp_fn, x_init, key, mc_steps, mc_width):
    """Random-walk Metropolis sampling for a batch of independent walkers.

    Each of the ``mc_steps`` iterations does three things. It adds isotropic
    Gaussian noise of std ``mc_width`` to every walker. It accepts each move
    with probability ``min(1, p(x') / p(x))``. It keeps the old position
    otherwise.

    Args:
        logp_fn: callable that maps a batch ``x`` (leading axis = walkers) to an
            array of unnormalised log-probabilities, one per walker. It is a
            *static* jit argument (hashed by identity), so passing a new
            callable object, e.g. a fresh lambda, triggers recompilation.
        x_init: initial walker positions with shape ``(batch, ...)``. Any number
            of trailing sample axes is supported.
        key: ``jax.random`` PRNG key.
        mc_steps: number of Metropolis steps.
        mc_width: standard deviation of the Gaussian proposal.

    Returns:
        ``(x, accept_rate)``: the final positions, with the same shape as
        ``x_init``, and the fraction of proposals accepted over all steps and
        walkers.
    """
    def step(i, state):
        x, logp, key, num_accepts = state
        key, key_proposal, key_accept = jax.random.split(key, 3)

        x_proposal = x + mc_width * jax.random.normal(key_proposal, x.shape)
        logp_proposal = logp_fn(x_proposal)

        # Metropolis test done in log space: u < exp(d) <=> log(u) < d. This
        # avoids overflowing exp() when the proposal is far more probable.
        log_ratio = logp_proposal - logp
        log_u = jnp.log(jax.random.uniform(key_accept, log_ratio.shape))
        accept = log_u < log_ratio

        # Broadcast the per-walker decision over however many trailing sample
        # axes x has, instead of assuming exactly two.
        accept_x = jnp.reshape(accept, accept.shape + (1,) * (x.ndim - accept.ndim))
        x_new = jnp.where(accept_x, x_proposal, x)
        logp_new = jnp.where(accept, logp_proposal, logp)
        num_accepts += accept.sum()
        return x_new, logp_new, key, num_accepts

    logp_init = logp_fn(x_init)

    x, logp, key, num_accepts = jax.lax.fori_loop(0, mc_steps, step, (x_init, logp_init, key, 0.))
    batch = x.shape[0]
    accept_rate = num_accepts / (mc_steps * batch)
    return x, accept_rate

In [None]:
@partial(jax.vmap, in_axes=(None, 0, None, None), out_axes=0)
def logp(beta, x, n, dim):
    """Unnormalised Boltzmann log-density, log p(x) = -beta * H(x).

    Vectorised over the leading (walker) axis of ``x``. ``beta``, ``n`` and
    ``dim`` are shared by every walker.
    """
    energy = energy_fn(x, n, dim)
    return -beta * energy

In [None]:
# Thermalise `batch` walkers with Metropolis MCMC, nudging the proposal width
# towards a ~50% acceptance rate after every round.

# Draw the initial configuration from a dedicated subkey so it does not share
# PRNG state with the keys split off inside the loop.
key, key_init = jax.random.split(key)
x = jax.random.normal(key_init, (batch, n, dim))

# Build the target log-density ONCE. `mcmc` treats `logp_fn` as a static jit
# argument (hashed by identity), so creating a new lambda every round forced a
# full recompilation of the sampler on each iteration.
def logp_fn(x_batch):
    return logp(beta, x_batch, n, dim)

for ii in range(iters):
    key, key_mcmc = jax.random.split(key, 2)

    x, acc = mcmc(logp_fn, x, key_mcmc, mc_steps, mc_width)
    energy = batch_energy(x, n, dim)  # per-walker energies, shape (batch,)
    E = jnp.mean(energy)

    print("step: %.3d    E: %.6f    acc: %.3f    mc_width: %.3f"
          % (ii, E, acc, mc_width))

    # Adapt the proposal width: widen it when accepting too often, shrink it
    # when rejecting too often. NOTE: this overwrites the global `mc_width`, so
    # re-running this cell starts from the previously tuned value.
    if acc > 0.525: mc_width *= 1.05
    if acc < 0.475: mc_width *= 0.95

In [None]:
# Particle density: pool the positions of all particles from every walker and
# histogram them in the plane.
plot_x = jnp.reshape(x, (batch*n, dim))
fig, ax = plt.subplots(figsize=(4, 4), dpi=300)
H, xedges, yedges = np.histogram2d(plot_x[:, 0], plot_x[:, 1],
                                   bins=200, range=((-4, 4), (-4, 4)), density=True)

# np.histogram2d bins x along axis 0 (rows) and y along axis 1 (columns).
# Transpose for imshow, and use origin="lower" so y increases upwards and the
# image agrees with the extent/axis labels.
ax.imshow(H.T, interpolation="nearest", origin="lower",
          extent=(xedges[0], xedges[-1], yedges[0], yedges[-1]),
          cmap="inferno")
ax.set(xlabel="$x$", ylabel="$y$", title="Particle density");