In [None]:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.example_libraries import optimizers
import numpy as np
import haiku as hk
import optax
import matplotlib.pyplot as plt 
from IPython import display

In [None]:
# Report which devices JAX will run on (CPU / GPU / TPU) for this session.
print(jax.devices())
# IPython shell escape: NVIDIA driver / GPU status. Presumably only useful on
# machines with NVIDIA GPUs; elsewhere the shell just reports the command is missing.
!nvidia-smi

In [None]:
# Simulation configuration.
n, dim = 20, 2               # number of particles, spatial dimension of each
lr = 0.01                    # gradient-descent step size
key = jax.random.PRNGKey(0)  # fixed seed -> reproducible initial configuration

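Energy of $n$ unit charges with pairwise Coulomb repulsion in an isotropic harmonic trap; the cells below minimize it by gradient descent:

$$H= \sum_{i<j} \frac{1}{|\boldsymbol{x}_i - \boldsymbol{x}_j|} + \sum_i |\boldsymbol{x}_i|^2 . $$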

In [None]:
def energy_fn(x, n, dim):
    """Total energy H of one configuration of `n` particles in `dim` dimensions.

    H = sum_{i<j} 1/|x_i - x_j|  +  sum_i |x_i|^2

    Args:
        x: particle coordinates; any array holding n*dim values (it is
           reshaped to (n, dim) for the pair distances).
        n: number of particles; must be a concrete int because it sizes
           jnp.triu_indices.
        dim: spatial dimension of each particle.

    Returns:
        The scalar energy.
    """
    pos = jnp.reshape(x, (n, dim))
    # Every unordered pair (a, b) with a < b; k=1 skips self-pairs (r = 0).
    a, b = jnp.triu_indices(n, k=1)
    pair_dist = jnp.linalg.norm(pos[a] - pos[b], axis=-1)
    confinement = jnp.sum(x**2)
    repulsion = jnp.sum(1/pair_dist)
    return confinement + repulsion


# Energy of a batch of configurations: vmapped over the leading axis of `x`,
# with `n` and `dim` shared by every configuration.
batch_energy = jax.vmap(energy_fn, (0, None, None), 0)

In [None]:
# Initial configuration: n particles in dim dimensions, coordinates ~ N(0, 1).
# Sample with the fresh subkey `key2`; `key` is kept only for future splits,
# so no key is both consumed by a sampler and split again later.
key, key2 = jax.random.split(key)
x = jax.random.normal(key2, (n, dim))

# Energy and its gradient w.r.t. x (argnums=0). Jitted so the 500 optimization
# steps reuse one compiled kernel; n and dim are static because they size the
# reshape and triu_indices inside energy_fn.
energy_and_grad = jax.jit(jax.value_and_grad(energy_fn), static_argnums=(1, 2))


In [None]:
def energy_optimize(x, step_size=None, force_scale=0.1):
    """Take one gradient-descent step on the energy H.

    Args:
        x: current particle positions (the (n, dim) array built in the
            initialization cell).
        step_size: learning rate. When None, the global `lr` is used, looked
            up at call time so editing `lr` still takes effect.
        force_scale: factor applied to the force -dH/dx before it is
            returned; used only to size the arrows in the plot.

    Returns:
        (x_new, energy, force): the updated positions, plus the energy and the
        scaled force evaluated at the INPUT positions `x`, not at `x_new`.
    """
    if step_size is None:
        step_size = lr
    energy, grad = energy_and_grad(x, n, dim)
    x_new = x - step_size * grad
    return x_new, energy, -grad * force_scale

In [None]:
# Run 500 gradient-descent steps and redraw a live two-panel figure each step:
# left = particle positions with force arrows, right = energy history.
# NOTE: the loop reassigns the global `x`, so re-running this cell continues
# from the already-relaxed configuration (re-run the init cell to restart).
energy_history = []
for i in range(500):
    # `e` and `g` are evaluated at the positions *before* the step, so keep
    # those positions to draw a self-consistent frame.
    x_shown = x
    x, e, g = energy_optimize(x)
    energy_history.append([e])

    # Explicit figure/axes with a figure-level suptitle, instead of a title on an
    # extra full-figure axes that plt.subplot overlaps (older Matplotlib removed it).
    fig, (ax_pos, ax_energy) = plt.subplots(1, 2, figsize=(12, 6), dpi=300)
    fig.suptitle("epoch: %.3d    E: %.6f" % (i, e), fontsize=16)

    ax_pos.scatter(x_shown[:, 0], x_shown[:, 1], s=10)
    ax_pos.quiver(x_shown[:, 0], x_shown[:, 1], g[:, 0], g[:, 1], color='red', scale=1)
    ax_pos.set_xlim([-5, 5])
    ax_pos.set_ylim([-5, 5])

    y = np.array(energy_history)
    ax_energy.plot(np.arange(i + 1), y[:, 0], marker='o')
    ax_energy.set_xlabel('epochs')
    ax_energy.set_ylabel('energy')

    display.clear_output(wait=True)
    display.display(fig)
    # Close every frame; otherwise all 500 high-dpi figures stay alive in memory.
    plt.close(fig)

print(e)