In [32]:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import jax
from jax import grad, lax
from jax import numpy as jnp, random as rnd


key = rnd.PRNGKey(1)

def create_local_kinetic_energy(wf):
    ''' kinetic energy function which works on a vmapped wave function '''

    def _lapl_over_f(params, walkers):
        n_walkers = walkers.shape[0]
        
        walkers = walkers.reshape(n_walkers, -1)
        
        n = walkers.shape[-1]
        
        eye = jnp.eye(n, dtype=walkers.dtype)[None, ...].repeat(n_walkers, axis=0)
        
        wf_new = lambda walkers: wf(params, walkers)
        grad_f = jax.grad(wf_new, holomorphic=True)

        def _body_fun(i, val):
            primal, tangent = jax.jvp(grad_f, (walkers,), (eye[..., i],))  # primal / tangent first / second order derivatives
            return val + (primal[:, i]**2).squeeze() + (tangent[:, i]).squeeze()

        # from lower to upper (lower, upper, func(int, a) -> a, init_val)
        # val is the previous  val (initialised to 0.0)
        return -0.5 * lax.fori_loop(0, n, _body_fun, jnp.zeros(walkers.shape[0], dtype=jnp.complex64))

    return _lapl_over_f


In [33]:
n_particle = 5
n_walkers = 1

key, subkey = rnd.split(key)
r = rnd.uniform(subkey, (n_walkers, n_particle, 3)).astype(jnp.complex64)

key, subkey = rnd.split(key)
w = rnd.uniform(subkey, (3, n_particle)).astype(jnp.complex64)

key, subkey = rnd.split(key)
w_i = rnd.uniform(subkey, (3, n_particle)).astype(jnp.complex64)

def psi(params, r):
    n_walkers = r.shape[0]
    w, w_i = params

    if len(r.shape) == 2:
        r = r.reshape(n_walkers, -1, 3)
    
    real = jnp.exp(- jnp.matmul(r, w))
    imag = 1j * jnp.exp(- jnp.matmul(r, w_i))

    orbitals = real + imag
    print(orbitals.shape)

    psi = jnp.linalg.det(orbitals)
    print(psi.shape)
    return psi.sum()

first_order = grad(psi, holomorphic=True)
laplacian = create_local_kinetic_energy(psi)


In [34]:
first_order((w, w_i), r)
laplacian((w, w_i), r)

(1, 5, 5)
(1,)
(1, 5, 5)
(1,)


DeviceArray([-0.00047366-9.840645e-05j], dtype=complex64)

In [18]:
r

DeviceArray([[[0.24035037+0.j, 0.52228856+0.j, 0.71770823+0.j],
              [0.10579169+0.j, 0.10497069+0.j, 0.87328124+0.j],
              [0.56231284+0.j, 0.03035879+0.j, 0.64574075+0.j],
              [0.31272447+0.j, 0.3981874 +0.j, 0.21055484+0.j],
              [0.77974796+0.j, 0.6391331 +0.j, 0.61742246+0.j]]],            dtype=complex64)