In [1]:
%load_ext autoreload
%autoreload 2
    
from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
from jax import random as rnd

import torch as tc
from torch import nn
tc.set_default_dtype(tc.float64)

import numpy as np
from tutorial_functions import compare

In [2]:
from ops.wf.fnstar.ferminet import model, create_masks
from ops.wf.fnstar.parameters import initialise_params, count_mixed_features
from ops.vmc.utils import create_atom_batch

from pytorch.model import fermiNet
from pytorch.utils import from_np, compute_local_energy

def update_state_dict(model_tc, params):
    tmp = []
    for k, value in params.items():
        print(k)
        
        if k == 'intermediate':
            for intermediate in zip(*params[k]):
                for ps in intermediate: 
                    tmp.append(ps)

        elif k == 'envelopes':
            order = ('linear', 'sigma', 'pi')
            for spin in (0, 1):
                for layer in order:
                    ps = params[k][layer][spin]
                    tmp.append(ps)
                    
        else:
            tmp.append(value)
                
    sd = model_tc.state_dict()
    for (k, val), p in zip(sd.items(), tmp):
#         print(val.shape, p.shape)
        assert val.shape == p.shape
        sd[k] = from_np(p)
        
    model_tc.load_state_dict(sd, strict=True)
    
#     sd = model_tc.state_dict()
#     for (k, val), p in zip(sd.items(), tmp):
#         compare(val, p)
        
    return model_tc
    

In [3]:
n_electrons = 4
n_atom = 1
n_up = 2
n_down = n_electrons - n_up
n_layers = 2
n_sh = 20
n_ph = 10
key = rnd.PRNGKey(1)
key, *subkeys = rnd.split(key, 2)
n_det = 5
n_samples = 100

re = rnd.normal(key, (n_samples, n_electrons, 3))
ra = create_atom_batch(rnd.normal(key, (n_atom, 3)), n_samples)
z_atoms = jnp.array([n_electrons])[None, ...]

print(re.shape, ra.shape)

(100, 4, 3) (100, 1, 3)


In [4]:
model_tc = fermiNet(n_layers, n_electrons,n_up,n_atom,n_sh,n_ph,n_det,ra)

masks = create_masks(n_atom, n_electrons, n_up, n_layers, n_sh, n_ph)
params = initialise_params(key, n_atom, n_up, n_down, n_layers, n_sh, n_ph, n_det)

model_tc = update_state_dict(model_tc, params)

from functools import partial
pmodel = partial(model, r_atoms=ra, masks=masks, n_up=n_up, n_down=n_down)

split0
s0
p0
intermediate
envelopes


In [16]:
# compare the forward pass
from ops.wf.fnstar_wdiag.ferminet import model, create_masks
from ops.wf.fnstar_wdiag.parameters import initialise_params, count_mixed_features
from ops.vmc.utils import create_atom_batch

# import importlib
# importlib.reload(ops.wf.fnstar_wdiag.ferminet)

masks = create_masks(n_atom, n_electrons, n_up, n_layers, n_sh, n_ph)
pmodel = partial(model, r_atoms=ra, masks=masks, n_up=n_up, n_down=n_down)
lp = pmodel(params, re)


[[ 0.76519381 -0.99294267 -0.97528894 -0.99947825]
 [-0.99935745  0.76519381 -0.99511248 -0.99982506]
 [-0.91016879 -0.81575175  0.76519381 -0.92986409]
 [-0.5846823  -0.01619707  0.81325687  0.76519381]]


In [6]:
compute_local_energy(from_np(ra), from_np(re), from_np(z_atoms), model_tc)


torch.Size([100, 2, 20]) torch.Size([100, 2, 1, 3])


tensor([ 1.8668e-01, -6.0279e+01, -4.8631e+01,  8.2714e+01, -3.0050e+01,
         5.6943e+01, -2.3383e+01, -1.8957e+02, -2.8564e+01, -2.5042e+01,
        -7.8011e+00, -2.5458e+01, -2.0699e+01, -3.7841e+02,  2.0051e+01,
        -9.1541e+01, -2.1140e+02, -1.1171e+02,  1.1981e+01,  1.9095e+01,
        -1.8624e+02,  9.2402e+01, -5.7140e+01, -1.3954e+01, -6.5256e+00,
        -6.0259e+01, -1.0805e+03, -7.4044e+00, -2.9448e+02, -5.4135e+01,
         2.4819e+00, -4.6979e+02, -1.3265e+01,  2.8792e+01,  2.3383e+01,
        -2.1644e+01, -1.9716e+01,  8.9349e-01, -4.0708e+01, -1.5699e+00,
        -2.5161e+02, -1.4999e+02, -8.3877e+01, -6.8211e+00, -2.1897e+02,
        -4.9404e+00,  1.6622e+01, -3.4159e+01,  1.8216e+01, -3.5473e+01,
        -2.6014e+01, -8.4958e+01, -3.0104e+01, -1.5076e+01, -4.2794e+02,
        -8.6820e+01, -3.0682e+01, -7.6872e+01,  1.8090e+00,  1.0626e+02,
        -9.2846e+00, -4.9585e+01,  1.5761e+01, -9.2635e+01, -9.0257e+00,
         1.2810e+00, -9.1573e+01, -6.8936e+01, -3.2

In [7]:
# compare grads
def sumpmodel(pmodel):
    def _sum_pmodel(params, r_electrons):
        logpsi = pmodel(params, r_electrons)
        return jnp.sum(logpsi)
    return _sum_pmodel

spmodel = sumpmodel(pmodel)

grad_model = jax.grad(spmodel, argnums=1)
grad = grad_model(params, re)
 
print(grad.shape)
    

(100, 16, 10)
(100, 16, 10)
(100, 4, 3)


In [8]:
from jax import lax

def local_kinetic_energy(f):

    def _lapl_over_f(params, x):  # this is recalled everytime
        x = x.reshape(x.shape[0], -1)
        n = x.shape[1]
        eye = jnp.eye(n)
        eye = jnp.repeat(eye[None, ...], x.shape[0], 0)
        grad_f = jax.grad(f, argnums=1)
        grad_f_closure = lambda y: grad_f(params, y)  # ensuring the input can be just x
    
        def _body_fun(i, val):
            # primal is the first order evaluation
            # tangent is the second order
            primal, tangent = jax.jvp(grad_f_closure, (x,), (eye[..., i],))
            return val + primal[i]**2 + tangent[i]
    
        # from lower to upper
        # (lower, upper, func(int, a) -> a, init_val)
        # this is like functools.reduce()
        # val is the previous  val (initialised to 0.0)
        return -0.5 * lax.fori_loop(0, n, _body_fun, 0.0)
  
    return _lapl_over_f


def batched_cdist_l2(x1, x2):
    x1_norm = jnp.sum(x1 ** 2, axis=-1, keepdims=True)
    x2_norm = jnp.sum(x2 ** 2, axis=-1, keepdims=True)
    cdist = jnp.sqrt(jnp.swapaxes(x2_norm, -1, -2) + x1_norm - 2 * x1 @ jnp.swapaxes(x2, -1, -2))
    return cdist

def compute_potential_energy(r_atom, r_electron, z_atom):
    n_samples, n_electron = r_electron.shape[:2]
    n_atom = r_atom.shape[1]

    potential_energy = jnp.zeros(n_samples)

    e_e_dist = batched_cdist_l2(r_electron, r_electron)  # electron - electron distances
    potential_energy += jnp.sum(jnp.tril(1. / e_e_dist, k=-1), axis=(-1, -2))

    a_e_dist = batched_cdist_l2(r_atom, r_electron)  # atom - electron distances
    potential_energy -= jnp.einsum('ba,bae->b', z_atom, 1./a_e_dist)

    if n_atom > 1:  # THIS HAS NOT BEEN VERIFIED
        a_a_dist = batched_cdist_l2(r_atom, r_atom)
        weighted_a_a = jnp.einsum('bn,bm,bnm->bnm', z_atom, z_atom, 1/a_a_dist)
        unique_a_a = weighted_a_a[:, jnp.tril(np.ones((n_atom, n_atom), dtype=bool), -1)]  # this will not work
        potential_energy += jnp.sum(unique_a_a, axis=-1)

    return potential_energy
             


def local_energy(f, r_atoms, z_atoms):
    ke = local_kinetic_energy(f)

    def _e_l(params, r_electrons):
    
        potential = compute_potential_energy(r_atoms, r_electrons, z_atoms)
        kinetic = ke(params, r_electrons)
        return potential + kinetic

    return _e_l

def sumpmodel(pmodel):
    def _sum_pmodel(params, r_electrons):
        logpsi = pmodel(params, r_electrons)
        return jnp.sum(logpsi)
    return _sum_pmodel

spmodel = sumpmodel(pmodel)
energyf = local_energy(spmodel, ra, z_atoms)
energyf(params, re)

(100, 16, 10)
(100, 16, 10)
(100, 16, 10)
(100, 16, 10)


TypeError: body_fun output and input must have identical types, got
(ShapedArray(int64[], weak_type=True), ShapedArray(int64[], weak_type=True), ShapedArray(float64[12]))
and
(ShapedArray(int64[], weak_type=True), ShapedArray(int64[], weak_type=True), ShapedArray(float64[])).

In [None]:




# def _lapl_over_f(params, x):  # this is recalled everytime
#         x = x.reshape(x.shape[0], -1)
#         n = x.shape[1]
#         eye = jnp.eye(n)
#         eye = jnp.repeat(eye[None, ...], x.shape[0], 0)
#         grad_f = jax.grad(f, argnums=1)
#         grad_f_closure = lambda y: grad_f(params, y)  # ensuring the input can be just x
    
#         def _body_fun(i, val):
#             # primal is the first order evaluation
#             # tangent is the second order
#             primal, tangent = jax.jvp(grad_f_closure, (x,), (eye[..., i],))
#             return val + primal[i]**2 + tangent[i]
    
#         # from lower to upper
#         # (lower, upper, func(int, a) -> a, init_val)
#         # this is like functools.reduce()
#         # val is the previous  val (initialised to 0.0)
#         return -0.5 * lax.fori_loop(0, n, _body_fun, 0.0)
  
#     return _lapl_over_f


key = rnd.PRNGKey(1)
x = rnd.normal(key, (25,))

def function(x):
    x = x.reshape((5, 5))
    n = x.shape[0]
    eye = jnp.eye(n, dtype=bool)
    y = x[eye]
    return jnp.sum(y)

z = function(x)
grad1f = jax.grad(function)
g1 = gradf(x)
grad2f = jax.grad(gradf)
g2 = grad2f(x)

primal, tangent = jax.jvp(grad_f_closure, (x,), (eye[..., i],))


In [None]:
# compare the energy
def local_kinetic_energy(f):

  def _lapl_over_f(params, x):  # this is recalled everytime
    n = x.shape[0]
    eye = jnp.eye(n)
    grad_f = jax.grad(f, argnums=1)
    grad_f_closure = lambda y: grad_f(params, y)  # ensuring the input can be just x

    def _body_fun(i, val):
      # primal is the first order evaluation
      # tangent is the second order 
      primal, tangent = jax.jvp(grad_f_closure, (x,), (eye[i],))
      return val + primal[i]**2 + tangent[i]

    # from lower to upper
    # (lower, upper, func(int, a) -> a, init_val)
    # this is like functools.reduce()
    # val is the previous  val (initialised to 0.0)
    return -0.5 * lax.fori_loop(0, n, _body_fun, 0.0)

  return _lapl_over_f


ke = local_kinetic_energy(pmodel)


In [None]:
lp


In [None]:
# strip the diagonal 
# https://stackoverflow.com/questions/46736258/deleting-diagonal-elements-of-a-numpy-array

# Approach #1 masking

# A[~np.eye(A.shape[0],dtype=bool)].reshape(A.shape[0],-1)

# # Approach #2

# # Using the regular pattern of non-diagonal elements that could be traced with broadcasted additions with range arrays -

# m = A.shape[0]
# idx = (np.arange(1,m+1) + (m+1)*np.arange(m-1)[:,None]).reshape(m,-1)
# out = A.ravel()[idx]

# # Approach #3 (Strides Strikes!)

# # Abusing the regular pattern of non-diagonal elements from previous approach, we can introduce np.lib.stride_tricks.as_strided and some slicing help, like so -

# m = A.shape[0]
# strided = np.lib.stride_tricks.as_strided
# s0,s1 = A.strides
# out = strided(A.ravel()[1:], shape=(m-1,m), strides=(s0+s1,s1)).reshape(m,-1)

def skip_diag_masking(A):
    return A[~np.eye(A.shape[0],dtype=bool)].reshape(A.shape[0],-1)

def skip_diag_broadcasting(A):
    m = A.shape[0]
    idx = (np.arange(1,m+1) + (m+1)*np.arange(m-1)[:,None]).reshape(m,-1)
    return A.ravel()[idx]

def skip_diag_strided(A):
    m = A.shape[0]
    strided = np.lib.stride_tricks.as_strided
    s0,s1 = A.strides
    return strided(A.ravel()[1:], shape=(m-1,m), strides=(s0+s1,s1)).reshape(m,-1)

# Timings -

A = np.random.randint(11,99,(5000,5000))

%timeit skip_diag_masking(A)
%timeit skip_diag_broadcasting(A)
%timeit skip_diag_strided(A)
#      ...: 
# 10 loops, best of 3: 56.1 ms per loop
# 10 loops, best of 3: 82.1 ms per loop
# 10 loops, best of 3: 32.6 ms per loop


    