In [None]:
import jax 
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np 
from functools import partial
import matplotlib.pyplot as plt

In [None]:
print(jax.devices())
!nvidia-smi


We consider a classical Coulomb gas with the Hamiltonian 

$$H= \sum_{i<j} \frac{1}{|\boldsymbol{x}_i - \boldsymbol{x}_j|} + \sum_i  \boldsymbol{x}_i^2 , $$
where the two terms are the Coulomb interaction and the harmonic trapping potential, respectively.

In [None]:
def energy_fn(x, n, dim):
    """Energy H of a trapped Coulomb gas.

    H = sum_{i<j} 1 / |x_i - x_j|  +  sum_i |x_i|^2

    Args:
        x: particle positions; any array reshapeable to (n, dim).
        n: number of particles.
        dim: spatial dimension.

    Returns:
        Scalar total energy.
    """
    positions = jnp.reshape(x, (n, dim))
    # All pairwise displacement vectors, shape (n, n, dim).
    displacements = positions[:, None, :] - positions[None, :, :]
    # Keep only the strictly upper triangle (i < j): each pair once, no self-pairs.
    upper_i, upper_j = jnp.triu_indices(n, k=1)
    pair_distances = jnp.linalg.norm(displacements[upper_i, upper_j], axis=-1)
    trap_energy = jnp.sum(x**2)
    coulomb_energy = jnp.sum(1/pair_distances)
    return trap_energy + coulomb_energy

We can obtain the gradient function via `jax.grad`

In [None]:
# dH/dx with respect to the positions (first argument); the result has the same shape as x.
grad_fn = jax.grad(energy_fn, argnums=0)

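Let's have a look at the particles and the forces acting on them. The force is minus the energy gradient, $\boldsymbol{F}_i = -\partial H / \partial \boldsymbol{x}_i$.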

In [None]:
n, dim = 20, 2  # number of particles, spatial dimension
key = jax.random.PRNGKey(42)
x = jax.random.normal(key, (n, dim))  # random particle positions, shape (n, dim)
# The force on each particle is minus the energy gradient, F = -dH/dx;
# grad_fn alone would draw arrows pointing uphill in energy.
f = -grad_fn(x, n, dim)

fig, ax = plt.subplots(figsize=(6, 6), dpi=300)
ax.scatter(x[:, 0], x[:, 1])
ax.quiver(x[:, 0], x[:, 1], f[:, 0], f[:, 1], color='red', scale=1)
ax.set(xlim=(-3, 3), ylim=(-3, 3), xlabel='$x$', ylabel='$y$',
       title='Random configuration and forces');

## Ground state 

We want to find the minimal energy configuration

$$x^\ast  = \mathrm{argmin}_{x} H(x) .$$

For that, we carry out a gradient descent

$$x \leftarrow x - \eta \frac{\partial H }{\partial x}$$ 



In [None]:
def optimize(x, steps=500, eta=1e-2, n=None, dim=None):
    """Relax a configuration by gradient descent on the energy, x <- x - eta * dH/dx.

    Args:
        x: particle positions of shape (n, dim).
        steps: number of gradient-descent steps.
        eta: step size (learning rate).
        n, dim: number of particles and spatial dimension. If either is None, both
            are read from ``x.shape``. The function therefore does not depend on
            the notebook-global ``n``/``dim``, and composes with ``jax.vmap`` over
            a batch of configurations.

    Returns:
        The configuration after ``steps`` updates, with the same shape as ``x``.
    """
    if n is None or dim is None:
        n, dim = x.shape

    def descent_step(_, x_cur):
        return x_cur - eta * grad_fn(x_cur, n, dim)

    # fori_loop traces the update once and runs it as a single compiled loop,
    # instead of dispatching `steps` separate eager gradient evaluations from Python.
    return jax.lax.fori_loop(0, steps, descent_step, x)

Run the optimization, and have a look at the result!

In [None]:
# Relax the random configuration: 500 gradient-descent steps with step size 1e-2.
x = optimize(x, steps=500, eta=1e-2)

In [None]:
# Residual forces F = -dH/dx on the relaxed configuration; close to zero if the descent converged.
f = -grad_fn(x, n, dim)

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(x[:, 0], x[:, 1])
ax.quiver(x[:, 0], x[:, 1], f[:, 0], f[:, 1], color='red', scale=1)
ax.set(xlim=(-3, 3), ylim=(-3, 3), xlabel='$x$', ylabel='$y$',
       title='Relaxed configuration and residual forces');

We can actually run a batch of optimizations in parallel with `vmap`

In [None]:
batchsize = 64
# Derive a fresh key instead of reusing `key`, which already generated the
# single configuration above (JAX keys must not be reused).
key_batch = jax.random.fold_in(key, 1)
x = jax.random.normal(key_batch, (batchsize, n, dim))  # batch of random starts
# vmap maps `optimize` over the leading batch axis, relaxing every configuration in parallel.
x = jax.vmap(optimize)(x)

In [None]:
fig = plt.figure(figsize=(6, 6))

# Overlay every relaxed configuration of the batch in a single panel.
for config in x[:batchsize]:
    plt.scatter(config[:, 0], config[:, 1], color='b', alpha=0.5)
plt.xlim([-3, 3])
plt.ylim([-3, 3])

## Finite temperature

We want to sample configurations from the equilibrium Boltzmann distribution at inverse temperature $\beta$

$$ x\sim \frac{e^{-\beta H(x)}}{Z}. $$ 


For that, we use the Metropolis Monte Carlo algorithm. We randomly displace the particles and accept the move with probability the expression below. The normalization $Z$ cancels in the ratio, so it never has to be computed.

$$ A(x \rightarrow x^\prime ) = \min \left[ 1, \frac{e^{-\beta H(x^\prime)}}{e ^{-\beta H(x)}} \right] $$ 






In [None]:
@partial(jax.jit, static_argnums=0)
def mcmc(logp_fn, x_init, key, mc_steps, mc_width):
    """
        Metropolis Markov chain Monte Carlo with Gaussian random-walk proposals.

    INPUT:
        logp_fn: callable giving the (unnormalized) log-probability of each configuration
            in a batch. It is called as logp_fn(x) with x of shape (batch, n, dim) and must
            return shape (batch,). It is a static jit argument and therefore hashed:
            pass the same function object across calls to avoid recompilation.
        x_init: starting configurations, shape (batch, n, dim).
        key: PRNG key.
        mc_steps: number of Metropolis steps.
        mc_width: standard deviation of the Gaussian proposal.

    OUTPUT:
        x: final batch of samples, same shape as `x_init`.
        accept_rate: fraction of accepted proposals over all chains and steps.
    """
    def metropolis_step(_, carry):
        x_cur, logp_cur, rng, n_acc = carry
        rng, rng_move, rng_test = jax.random.split(rng, 3)

        # Gaussian random-walk proposal for every chain in the batch.
        x_try = x_cur + mc_width * jax.random.normal(rng_move, x_cur.shape)
        logp_try = logp_fn(x_try)

        # Metropolis test: accept with probability min(1, p(x_try) / p(x_cur)).
        ratio = jnp.exp((logp_try - logp_cur))
        accepted = jax.random.uniform(rng_test, ratio.shape) < ratio

        # Per-chain select; broadcast the (batch,) mask over the (n, dim) axes.
        x_next = jnp.where(accepted[:, None, None], x_try, x_cur)
        logp_next = jnp.where(accepted, logp_try, logp_cur)
        return x_next, logp_next, rng, n_acc + accepted.sum()

    initial_carry = (x_init, logp_fn(x_init), key, 0.)
    x_final, _, _, n_acc = jax.lax.fori_loop(0, mc_steps, metropolis_step, initial_carry)
    return x_final, n_acc / (mc_steps * x_final.shape[0])

In [None]:
@partial(jax.vmap, in_axes=(0, None, None, None))
def logp(x, n, dim, beta):
    """Unnormalized log Boltzmann weight -beta * H(x).

    Mapped over the leading (batch) axis of x; n, dim and beta are shared.
    """
    energy = energy_fn(x, n, dim)
    return -beta * energy

beta = 10.0  # inverse temperature
batchsize = 8192  # number of independent Markov chains
mc_steps = 100  # Metropolis steps per adaptation round
mc_width = 0.02  # initial proposal width, adapted below
n_rounds = 20  # number of adaptation rounds

# Use a dedicated subkey for the initial state; `key` itself is split again below.
key, key_init = jax.random.split(key)
x = jax.random.normal(key_init, (batchsize, n, dim))
energy_fn_vmap = jax.vmap(energy_fn, (0, None, None), 0)


def logp_fn(xs):
    """Batched log-weight -beta * H(xs) for the n, dim, beta defined above.

    `mcmc` takes this function as a static (hashed) jit argument. Defining it once,
    rather than passing a fresh lambda every round, lets `mcmc` compile only once.
    The globals n, dim, beta are baked in at that compile, so re-run this cell
    after changing them.
    """
    return logp(xs, n, dim, beta)


for i in range(n_rounds):
    key, subkey = jax.random.split(key)
    x, acc = mcmc(logp_fn, x, subkey, mc_steps, mc_width)
    energy = jnp.mean(energy_fn_vmap(x, n, dim))
    print("%.2d    %.6f    %.6f    %.6f" % (i, acc, mc_width, energy))

    # Tune the proposal width toward a ~50% acceptance rate.
    if acc > 0.525:
        mc_width *= 1.05
    if acc < 0.475:
        mc_width *= 0.95

In [None]:
# Pool the particles of all chains into one (batchsize*n, dim) array.
# Use a new name so the batched samples `x` stay intact.
positions = jnp.reshape(x, (batchsize * n, dim))

# Density plot of the sampled particle positions.
H, xedges, yedges = np.histogram2d(positions[:, 0], positions[:, 1],
                                   bins=100,
                                   range=((-4, 4), (-4, 4)),
                                   density=True)

fig, ax = plt.subplots(figsize=(6, 6))
# np.histogram2d returns H[x_bin, y_bin], whereas imshow draws rows along the
# vertical axis, top row first by default. Transpose and use origin="lower" so
# that x runs horizontally and y increases upward.
im = ax.imshow(H.T, interpolation="nearest", origin="lower",
               extent=(xedges[0], xedges[-1], yedges[0], yedges[-1]),
               cmap="inferno")
fig.colorbar(im, ax=ax, label="density")
ax.set(xlabel="$x$", ylabel="$y$", title="Particle density (MCMC samples)");

# Create a network

In [None]:
def make_network(layer_sizes):
    """Build a fully connected ReLU network (multi-layer perceptron).

    Args:
        layer_sizes: sequence of layer widths [n_in, hidden_1, ..., n_out].

    Returns:
        (init, apply):
          init(key, scale=1e-2) -> list of (weight, bias) pairs. Weight has shape
              (n_in, n_out) and bias shape (n_out,), both drawn as scale * N(0, 1).
          apply(params, x) -> network output. ReLU is applied after every layer
              except the last, which is linear.
    """

    def init(key, scale=1e-2):
        params = []
        for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            # Advance the key for every layer. Splitting the same key each time
            # would give all layers of equal shape identical weights.
            key, weight_key, bias_key = jax.random.split(key, 3)
            weight = scale * jax.random.normal(weight_key, (n_in, n_out))
            bias = scale * jax.random.normal(bias_key, (n_out,))
            params.append((weight, bias))
        return params

    def relu(x):
        return jnp.maximum(0, x)

    def apply(params, x):
        for w, b in params[:-1]:
            x = relu(jnp.dot(x, w) + b)
        # Final layer has no activation.
        final_w, final_b = params[-1]
        return jnp.dot(x, final_w) + final_b

    return init, apply
     

In [None]:

# A 784 -> 128 -> 64 -> 10 perceptron: 784 input pixels in, 10 class scores out.
key = jax.random.PRNGKey(42)
layer_sizes = [784, 128, 64, 10]
init_fn, apply_fn = make_network(layer_sizes)
params = init_fn(key)

In [None]:

from jax.flatten_util import ravel_pytree

# Number of trainable parameters: flatten all weights and biases into one 1-D vector.
# Expected: 784*128 + 128*64 + 64*10 (weights) + 128 + 64 + 10 (biases) = 109386
len(ravel_pytree(params)[0])