In [None]:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np
import haiku as hk
import optax
import matplotlib.pyplot as plt 
from IPython import display

In [None]:
# List the devices JAX will run on (CPU/GPU/TPU) to confirm which backend is active.
print(jax.devices())
# IPython shell escape (not plain Python): show NVIDIA GPU status. Only works
# where the NVIDIA driver tools are installed.
!nvidia-smi

In [None]:
# Root PRNG key (fixed seed for reproducibility).
key = jax.random.PRNGKey(42)

# Physical system: n particles living in `dim` spatial dimensions.
n, dim = 20, 2

# Training setup: samples per step, inverse temperature, Adam step size,
# and hidden-layer widths of the MLP flow.
batch = 8192
beta = 10.0
lr = 1e-3
hidden_sizes = [64, 64]

# Classical Coulomb gas

We'd like to study the thermodynamic properties of the classical Coulomb gas, whose Hamiltonian reads

$$H= \sum_{i<j} \frac{1}{|\boldsymbol{x}_i - \boldsymbol{x}_j|} + \sum_i  \boldsymbol{x}_i^2 . $$
The second term is a harmonic trapping potential. It confines the particles and makes our story easier: there is no need to consider periodic boundary conditions or Ewald sums for the long-range interaction.

The way to go is to minimize the variational free energy with respect to a variational probability density $p(\boldsymbol{x})$

$$\mathcal{L} = \mathbb{E}_{\boldsymbol{x} \sim p(\boldsymbol{x})} \left [\frac{1}{\beta}\ln p(\boldsymbol{x}) +  H(\boldsymbol{x}) \right] \ge -\frac{1}{\beta} \ln Z, $$ 
where $Z = \int d \boldsymbol{x} e^{-\beta H}$ and $\beta$ is the inverse temperature. The equality holds when $p(\boldsymbol{x}) = e^{-\beta H}/Z$, i.e., we achieve the exact solution. 

First things first, here is the energy function:

In [None]:
#################################################################################### 
def energy_fn(x, n, dim):
    """Return the Coulomb-gas energy H(x) of one configuration.

    H = sum_{i<j} 1 / |x_i - x_j|  +  sum_i |x_i|^2

    Args:
        x: particle coordinates; anything reshapeable to (n, dim).
        n: number of particles.
        dim: spatial dimension.

    Returns:
        Scalar energy (harmonic trap term plus pairwise Coulomb repulsion).
    """
    pos = jnp.reshape(x, (n, dim))
    # Every pairwise displacement vector, shape (n, n, dim).
    disp = pos[:, None, :] - pos[None, :, :]
    # Keep each unordered pair exactly once: strict upper triangle (i < j).
    row, col = jnp.triu_indices(n, k=1)
    dist = jnp.linalg.norm(disp[row, col], axis=-1)
    trap = jnp.sum(x**2)
    coulomb = jnp.sum(1 / dist)
    return trap + coulomb


# Map over a leading batch axis of `x`; `n` and `dim` are shared, unbatched.
batch_energy = jax.vmap(energy_fn, in_axes=(0, None, None), out_axes=0)

The probabilistic model we will use is the flow model, which involves a change of variables. The coordinate $x$ is expressed as a function of $z$:
$$
x = g_{\theta}(z),
$$
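where $\theta$ represents the parameters of the neural network. The change of variables below assumes $g_{\theta}$ is invertible, and $|\cdot|$ denotes the absolute value of the Jacobian determinant.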

The probabilistic transformation can then be written as:
$$
p_{\theta}(g_{\theta}(z)) = q(z) \left|\frac{\partial z}{\partial g_{\theta}(z)}\right|.
$$
This can be simplified to:
$$
p(x) = q(z) \left|\frac{\partial z}{\partial x}\right|,
$$
which implies:
$$
\ln p(x) = \ln q(z) - \ln \left|\frac{\partial x}{\partial z}\right|.
$$

In [None]:
#################################################################################### 
class NeuralNetwork(hk.Module):
    """MLP map z -> x between (n, dim) particle configurations, used as g_theta.

    The input is flattened to a vector of length n*dim, passed through an MLP
    with softplus activations (hidden widths `hidden_sizes`, output width
    n*dim), and reshaped back to (n, dim).

    NOTE: an unconstrained MLP is not guaranteed to be invertible, so a
    change-of-variables log-density computed from its Jacobian is only
    meaningful where the map is one-to-one and the Jacobian is non-singular.
    """

    def __init__(self, hidden_sizes, n, dim):
        super().__init__()
        # Copy into a list so tuples (or any iterable of ints) are accepted;
        # `tuple + [n*dim]` would otherwise raise TypeError.
        self.hidden_sizes = list(hidden_sizes)
        self.n = n
        self.dim = dim
        self.network = hk.nets.MLP(
            self.hidden_sizes + [n * dim],
            activation=jax.nn.softplus,
            w_init=hk.initializers.TruncatedNormal(0.1),
            b_init=hk.initializers.TruncatedNormal(1.0),
        )

    def __call__(self, z):
        """Map latent `z` (n*dim elements) to coordinates `x` of shape (n, dim)."""
        z_flat = jnp.reshape(z, (self.n * self.dim,))
        x_flat = self.network(z_flat)
        return x_flat.reshape((self.n, self.dim))

#################################################################################### 
def make_flow(hidden_sizes, n, dim):
    """Wrap `NeuralNetwork` with `hk.transform` to get a pure (init, apply) pair.

    The result is used as `flow.init(rng, z)` and `flow.apply(params, rng, z)`,
    with `z` an (n, dim) latent configuration.
    """
    def _network_forward(z):
        # The Haiku module is constructed inside the transformed function.
        return NeuralNetwork(hidden_sizes, n, dim)(z)

    return hk.transform(_network_forward)

Let's have a look at the model.
This is the number of parameters in the neural network:

In [None]:
from jax.flatten_util import ravel_pytree

# Build the flow and initialise its parameters from a dummy (n, dim) input.
flow = make_flow(hidden_sizes, n, dim)
params = flow.init(key, jnp.zeros((n, dim)))

# Flatten every parameter leaf into one vector just to count the entries.
raveled_params = ravel_pytree(params)[0]
print("parameters in the flow model: %d" % raveled_params.size, flush=True)

In [None]:
def make_logp(flow):
    """Return `logp(z, params)`, the model log-density for ONE latent sample.

    `flow` must expose `apply(params, rng, z)` (e.g. an `hk.transform` result);
    it is always called with rng=None. `z` must be 2-D, shape (n, dim).

    `logp` returns the tuple (log p(x), log q(z), log|det dx/dz|, x), where
        x        = flow.apply(params, None, z)
        log q(z) = standard-normal log-density summed over all entries of z
        log p(x) = log q(z) - log|det dx/dz|   (change of variables)
    The dense (n*dim) x (n*dim) Jacobian of the flattened map is built with
    `jax.jacfwd`; `slogdet` gives its log-|det| (the sign is discarded).
    """
    def logp(z, params):
        n, dim = z.shape

        def flat_map(z_vec):
            # Flattened view of the flow so the Jacobian is a square matrix.
            return flow.apply(params, None, z_vec.reshape(n, dim)).reshape(-1)

        x = flow.apply(params, None, z)
        log_qz = jax.scipy.stats.norm.logpdf(z).sum()

        jacobian = jax.jacfwd(flat_map)(z.reshape(-1))
        _, log_jac_det = jnp.linalg.slogdet(jacobian)

        return log_qz - log_jac_det, log_qz, log_jac_det, x

    return logp

# Per-sample log-density, then batched over the leading axis of z with params
# shared; every one of the four outputs gets the batch axis first.
logp_novmap = make_logp(flow)
logp = jax.vmap(logp_novmap, in_axes=(0, None), out_axes=0)

x is the output of the neural network, so its shape should match that of z. Each log-density should have shape (batch,), i.e. one value per sample:

In [None]:
# Draw a diagnostic batch with a fresh subkey: `key` itself was already used
# by `flow.init`, and reusing a JAX key yields correlated random numbers.
key, subkey = jax.random.split(key)
z = jax.random.normal(subkey, (batch, n, dim))
print("z.shape:", z.shape)

logpx, logqz, logjacdet, x = logp(z, params)
print("x.shape:", x.shape)
print("batch logp(x):", logpx)
print("batch logq(z):", logqz)
# `logjacdet` is log|det J| with J = jacfwd of the map z -> x, i.e. dx/dz.
print("batch logjacdet(dx/dz):", logjacdet)

## Loss function

The gradient of the objective function is:

$$
\nabla_{\theta} \mathcal{L} 
= \mathbb{E}_{\boldsymbol{z} \sim \mathcal{N}(\boldsymbol{z})} 
\left[   \nabla_{\theta} f(g( \boldsymbol{z}) )\right]
,
$$ 
where $f (\boldsymbol{x}) =\frac{1}{\beta}\ln p(\boldsymbol{x}) +  H(\boldsymbol{x})$.
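Since $\ln p(g_{\theta}(\boldsymbol{z})) = \ln q(\boldsymbol{z}) - \ln \left|\frac{\partial g_{\theta}(\boldsymbol{z})}{\partial \boldsymbol{z}}\right|$ and $\ln q(\boldsymbol{z})$ does not depend on $\theta$, this simplifies to: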
$$
\nabla_{\theta} \mathcal{L} 
= \mathbb{E}_{\boldsymbol{z} \sim \mathcal{N}(\boldsymbol{z})} 
\left[   \nabla_{\theta} 
\left( - \frac{1}{\beta} \ln \left|\frac{\partial g_{\theta}(z)}{\partial z}\right| + H(g(z))
\right)\right].
$$ 
This is known as the reparameterization gradient estimator. See https://arxiv.org/abs/1906.10652 for more details.

In [None]:
# Adam with constant step size `lr`; its state is initialised to match `params`.
optimizer = optax.adam(learning_rate=lr)
opt_state = optimizer.init(params)

In [None]:
def make_loss(logp, batch, n, dim, beta):
    """Build the training objective for the variational free energy.

    Args:
        logp: batched log-density, `logp(z, params) -> (logpx, logqz, logjacdet, x)`
            with a leading batch axis on every output.
        batch: nominal batch size. Kept for backward compatibility; standard
            errors are computed from the actual number of samples in `z`.
        n: number of particles (passed to the energy function).
        dim: spatial dimension (passed to the energy function).
        beta: inverse temperature.

    Returns:
        `observable_and_lossfn(params, z) -> (loss, observables)`.
        `loss` = mean(-logjacdet / beta + E) is the reparameterization
        surrogate. It differs from the free-energy estimate F_mean only by
        mean(logqz) / beta, which does not depend on params, so both have the
        same gradient. `observables` is
        (F_mean, F_std, E_mean, E_std, S_mean, S_std, x), where each *_std is
        the standard error of the corresponding mean.
    """
    def observable_and_lossfn(params, z):
        logpx, _, logjacdet, x = logp(z, params)

        Eloc = batch_energy(x, n, dim)
        Floc = logpx / beta + Eloc
        Sloc = -logpx

        # Use the real sample count rather than the `batch` argument, so the
        # standard error stays correct if `z` has a different batch size.
        sqrt_m = jnp.sqrt(Floc.shape[0])

        F_mean, F_std = Floc.mean(), jnp.std(Floc) / sqrt_m
        E_mean, E_std = Eloc.mean(), jnp.std(Eloc) / sqrt_m
        S_mean, S_std = Sloc.mean(), jnp.std(Sloc) / sqrt_m
        observable = (F_mean, F_std, E_mean, E_std, S_mean, S_std, x)

        # log q(z) does not depend on params, so it is omitted from the loss;
        # only the Jacobian term and the energy carry gradient.
        loss = jnp.mean(-logjacdet / beta + Eloc)
        return loss, observable

    return observable_and_lossfn

# Loss + observables, and its value-and-gradient w.r.t. params (first argument);
# `has_aux=True` returns ((loss, observables), grads).
observable_and_lossfn = make_loss(logp, batch, n, dim, beta)
value_and_gradfn = jax.value_and_grad(observable_and_lossfn, has_aux=True)

### Training

Here is the training loop. During training we monitor the density and the loss history.

In [None]:
@jax.jit
def update(params, opt_state, z):
    """Take one optimizer step on the free-energy surrogate loss.

    Args:
        params: current flow parameters.
        opt_state: optimizer state matching `params`.
        z: latent samples for this step ((batch, n, dim) at the call site).

    Returns:
        (new_params, new_opt_state, datas), where `datas` is the
        `(loss, observables)` pair produced by `value_and_gradfn`.
    """
    loss_and_obs, grads = value_and_gradfn(params, z)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss_and_obs

In [None]:
def _plot_density(ax, points, lim=4, bins=100):
    """Draw a 2-D density histogram of `points` (shape (m, 2)) on `ax` over [-lim, lim]^2."""
    hist, xedges, yedges = np.histogram2d(points[:, 0], points[:, 1], bins=bins,
                                          range=((-lim, lim), (-lim, lim)), density=True)
    # histogram2d indexes hist[x_bin, y_bin], while imshow draws the first axis
    # as rows from the top; transpose and use origin="lower" so that x runs
    # horizontally and y increases upward.
    ax.imshow(hist.T, interpolation="nearest", origin="lower",
              extent=(xedges[0], xedges[-1], yedges[0], yedges[-1]), cmap="inferno")
    ax.set_xlim([-lim, lim])
    ax.set_ylim([-lim, lim])


loss_history = []
for i in range(5000):
    # Split first and sample with the fresh subkey; `key` is only carried
    # forward to be split again, so no key is consumed twice.
    key, subkey = jax.random.split(key)
    z = jax.random.normal(subkey, (batch, n, dim))
    params, opt_state, datas = update(params, opt_state, z)

    F_mean, F_std, E_mean, E_std, S_mean, S_std, x = datas[1]
    # Store history on the host as NumPy so re-plotting it every epoch does not
    # re-convert all past device arrays.
    loss_history.append(np.array([F_mean, F_std, E_mean, E_std, S_mean, S_std]))
    status = ("epoch: %04d    F: %.6f (%.6f)    E: %.6f (%.6f)    S: %.6f (%.6f)"
              % (i, F_mean, F_std, E_mean, E_std, S_mean, S_std))
    print(status)

    plot_x = np.asarray(x).reshape(batch * n, dim)
    plot_z = np.asarray(z).reshape(batch * n, dim)

    display.clear_output(wait=True)

    fig, (ax_x, ax_z, ax_loss) = plt.subplots(1, 3, figsize=(18, 6), dpi=300)
    # Figure-level title instead of plt.title on an implicit full-figure axes,
    # which overlapping subplots may remove in older matplotlib versions.
    fig.suptitle(status, fontsize=16)

    #====== plot x ======
    _plot_density(ax_x, plot_x)
    #====== plot z ======
    _plot_density(ax_z, plot_z)

    #====== plot loss ======
    y = np.reshape(np.array(loss_history), (-1, 6))
    ax_loss.errorbar(np.arange(i + 1), y[:, 0], yerr=y[:, 1], marker='o', capsize=8)
    ax_loss.set_xlabel('epochs')
    ax_loss.set_ylabel('variational free energy')
    plt.pause(0.0001)
    # Release the figure; otherwise non-inline backends keep every epoch's
    # figure alive.
    plt.close(fig)

print(loss_history[-1])

Do you see structure emerge from training? 
Yes! It is called a [Wigner molecule](https://en.wikipedia.org/wiki/Wigner_crystal). Physicists have already studied [the ground state of Wigner crystals for such small clusters](https://www.sciencedirect.com/science/article/abs/pii/S0749603683710268). An interesting result is that for six electrons, five electrons arrange themselves around a single electron at the center, which is what this model finds when trained with `n = 6` (the notebook above uses `n = 20`).