In [7]:
import jax.random as rnd
import jax.numpy as jnp
from jax import lax, vmap, grad, jit
from functools import partial
import jax


In [35]:

key = rnd.PRNGKey(1)
L = 3.
x = rnd.normal(key, (10,3))
z = jnp.linalg.norm(x, axis=-1)
w = rnd.normal(key, (3, 1))
mask = z < L
print(jnp.sum(mask))

def linear(x):
    z = x @ w
    return jnp.tanh(z).squeeze()

def wf(x):
    z = jnp.linalg.norm(w.squeeze() * x, axis=-1)
    out = linear(x) * jnp.exp(-z)
    print(z.shape, out.shape, x.shape, linear(x).shape, jnp.exp(-z).shape)
    mask = (z < (L / 2.)).astype(x.dtype)
    # assert jnp.sum(mask) > 0.
    out = mask * out
    out = jnp.tanh(linear(x) * out)
    return out.squeeze()

vwf = vmap(wf, in_axes=(0,))

9


In [37]:
# vwf version

def local_kinetic_energy_i(wf):
    """
    FUNCTION SLIGHTLY ADAPTED FROM DEEPMIND JAX FERMINET IMPLEMTATION
    https://github.com/deepmind/ferminet/tree/jax

    """
    def _lapl_over_f(walkers):
        n_walkers = walkers.shape[0]
        walkers = walkers.reshape(n_walkers, -1)
        n = walkers.shape[-1]
        eye = jnp.eye(n, dtype=walkers.dtype)[None, ...].repeat(n_walkers, axis=0)
        wf_new = lambda walkers: wf(walkers).sum()
        grad_f = jax.grad(wf_new, argnums=0)
        # grad_f_closure = lambda y: grad_f(params, y, d0s)  # ensuring the input can be just x

        def _body_fun(i, val):
            # primal is the first order evaluation
            # tangent is the second order
            primal, tangent = jax.jvp(grad_f, (walkers,), (eye[..., i],))
            print(val.shape, primal.shape)
            return val + primal[:, i]**2 + tangent[:, i]

        # from lower to upper
        # (lower, upper, func(int, a) -> a, init_val)
        # this is like functools.reduce()
        # val is the previous  val (initialised to 0.0)
        return -0.5 * lax.fori_loop(0, n, _body_fun, jnp.zeros(walkers.shape[0]))

    return _lapl_over_f

full = local_kinetic_energy_i(vwf)
full(x)

() () (3,) () ()
(10,) (10, 3)


DeviceArray([-0.07822731, -0.        ,  0.79436773, -0.        ,
             -0.        , -0.73799956,  0.16120832,  0.3664742 ,
              0.64229894, -1.0483482 ], dtype=float32)

In [16]:
# second order derivs

def second(wf):
    def _lapl_over_f(x):
        n = x.shape[0]
        print(n)
        eye = jnp.eye(n, dtype=x.dtype)
        print(eye.shape)
        grad_f = jax.grad(wf)
        # grad_f_closure = lambda y: grad_f(y)  # ensuring the input can be just x

        def _body_fun(i, val):
            # primal is the first order evaluation
            # tangent is the second order
            primal, tangent = jax.jvp(grad_f, (x,), (eye[:, i],))
            print(i, primal.shape, tangent.shape)
            return val + primal[i]**2 + tangent[i]
        return -0.5 * lax.fori_loop(0, n, _body_fun, 0.0)

    return vmap(_lapl_over_f, in_axes=(0,))

grad = second(wf)
grad(x)

3
(3, 3)
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=2/0)> (3,) (3,)


DeviceArray([-0.07822731, -0.        ,  0.79436773, -0.        ,
             -0.        , -0.73799956,  0.16120833,  0.3664742 ,
              0.64229894, -1.0483482 ], dtype=float32)

In [None]:
n = x.shape[1]
print(n)
eye = jnp.eye(n, dtype=x.dtype)
print(eye.shape)
grad_f = jax.grad(wf)
# grad_f_closure = lambda y: grad_f(y)  # ensuring the input can be just x

i = 0
val = 0.0
primal, tangent = jax.jvp(grad_f, (x,), (eye[None, :, i].repeat(x.shape[0], axis=0),))
print(i)
print(primal, '\n', tangent, '\n', mask)

d = val + primal[:, i]**2 + tangent[:, i]

# _body_fun(0, 0.0)


# -0.5 * lax.fori_loop(0, n, _body_fun, 0.0)