In [1]:
%load_ext autoreload
%autoreload 2

import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"


In [3]:
import numpy as np
from functools import partial

from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
from jax import random as rnd
from jax import lax

from pytorch.models.og.model import fermiNet
from pytorch.sampling import MetropolisHasting
from pytorch.vmc import *
from pytorch.pretraining import Pretrainer
from pytorch.systems import Molecule as Moleculetc
from pytorch.utils import update_state_dict, from_np
import torch as tc
tc.set_default_dtype(tc.float64)

from ops.vmc.utils import create_atom_batch
from ops.systems import Molecule
from ops.wf.ferminet import model, create_masks
from ops.wf.parameters import initialise_params, count_mixed_features

  h5py.get_config().default_file_mode = 'a'


ImportError: cannot import name 'model' from 'ops.wf.ferminet' (/home/xmax/nn_ansatz/src/ops/wf/ferminet.py)

In [None]:
def batched_cdist_l2(x1, x2):
    """Pairwise Euclidean distances between two point sets.

    Args:
        x1: points of shape (..., n1, d). A 2-D (n1, d) array (e.g. nuclear
            positions shared by every sample) broadcasts over the batch of ``x2``.
        x2: points of shape (..., n2, d).

    Returns:
        Array of shape (..., n2, n1) with ``out[..., j, i] = |x2[..., j] - x1[..., i]|``.
    """
    x1_sq = jnp.sum(x1 ** 2, axis=-1)[..., None, :]  # (..., 1, n1)
    x2_sq = jnp.sum(x2 ** 2, axis=-1)[..., :, None]  # (..., n2, 1)
    cross = jnp.matmul(x2, jnp.swapaxes(x1, -1, -2))  # (..., n2, n1); matmul broadcasts batch dims
    # |a|^2 + |b|^2 - 2 a.b can round to a tiny negative number (notably on the zero
    # diagonal of self-distances); clamp so sqrt never returns NaN
    sq_dist = jnp.maximum(x1_sq + x2_sq - 2. * cross, 0.)
    return jnp.sqrt(sq_dist)

def compute_potential_energy_jax(r_atom, r_electron, z_atom):
    """Coulomb potential energy of each electron configuration (atomic units).

    V = sum_{i<j} 1/|r_i - r_j| - sum_{i,A} Z_A/|r_i - R_A| + sum_{A<B} Z_A Z_B/|R_A - R_B|

    Args:
        r_atom: nuclear positions, shape (n_atom, 3); a per-sample batch of
            shape (n_samples, n_atom, 3) is also accepted.
        r_electron: electron positions, shape (n_samples, n_electron, 3).
        z_atom: nuclear charges, shape (n_atom,) (or (n_samples, n_atom)).

    Returns:
        Array of shape (n_samples,).
    """
    n_electron = r_electron.shape[1]
    n_atom = r_atom.shape[-2]
    z_atom = jnp.asarray(z_atom)

    # Electron-electron repulsion. Only the i > j pairs are gathered, so the zero
    # self-distances are never formed (no 1/0 and no NaN gradients).
    i_e, j_e = np.tril_indices(n_electron, k=-1)
    e_e_dist = jnp.linalg.norm(r_electron[:, i_e] - r_electron[:, j_e], axis=-1)  # (n_samples, n_pairs)
    potential_energy = jnp.sum(1. / e_e_dist, axis=-1)

    # Electron-nucleus attraction. a_e_dist is (n_samples, n_electron, n_atom), so the
    # charges broadcast along the last (atom) axis.
    a_e_dist = jnp.linalg.norm(r_electron[:, :, None, :] - r_atom[..., None, :, :], axis=-1)
    potential_energy = potential_energy - jnp.sum(z_atom[..., None, :] / a_e_dist, axis=(-1, -2))

    # Nucleus-nucleus repulsion: independent of the electrons, each pair counted once.
    if n_atom > 1:
        i_a, j_a = np.tril_indices(n_atom, k=-1)
        a_a_dist = jnp.linalg.norm(r_atom[..., i_a, :] - r_atom[..., j_a, :], axis=-1)
        potential_energy = potential_energy + jnp.sum(
            z_atom[..., i_a] * z_atom[..., j_a] / a_a_dist, axis=-1)

    return potential_energy

def sumpmodel(pmodel):
    """Wrap ``pmodel`` so the wrapped call returns the sum of its outputs.

    The wrapper has the same ``(params, r_electrons)`` signature and yields a
    scalar, which is what ``jax.grad`` requires.
    """
    def _summed(params, r_electrons):
        return jnp.sum(pmodel(params, r_electrons))

    return _summed

In [None]:
# System / model hyperparameters
n_el = 4
n_atom = 1
n_up = 2
n_down = n_el - n_up
n_layers = 2
n_sh = 20
n_ph = 10
n_det = 5
n_walkers = 20

# Never reuse a PRNG key for two draws: split off one key per random quantity.
# `key` itself stays fresh for later cells (e.g. parameter initialisation).
key = rnd.PRNGKey(1)
key, walker_key, atom_key = rnd.split(key, 3)

walkers = rnd.normal(walker_key, (n_walkers, n_el, 3))
r_atoms = rnd.normal(atom_key, (n_atom, 3))
z_atoms = jnp.array([n_el])  # one nucleus with charge n_el, i.e. a neutral atom

# Torch counterparts of the same system
# (create_atom_batch presumably tiles r_atoms once per walker — confirm)
walkers_tc = from_np(walkers)
r_atoms_tc = from_np(create_atom_batch(r_atoms, n_walkers))
z_atoms_tc = from_np(z_atoms)

mol_tc = Moleculetc(r_atoms_tc, z_atoms_tc, n_el, device='cpu', dtype=r_atoms_tc.dtype)
mol = Molecule(r_atoms, z_atoms, n_el)

# TODO (this phase):
# - integrate the diagonal dropping into the jax model
# - compare logpsi and energy computations between the torch and jax implementations
# - get the energy computation working for the jax implementation
# - port the pytorch pretrainer to jax
# - establish the jax samplers
# - test the jax pretrainer

In [None]:
# Set up the models: a single-walker jax FermiNet (`pmodel`), its walker-batched
# version (`vmodel`), and a torch FermiNet built for `mol_tc` whose state dict is
# updated from the same `params` (presumably a weight copy) for comparison.
params = initialise_params(key, n_atom, n_up, n_down, n_layers, n_sh, n_ph, n_det)
masks = create_masks(n_atom, n_el, n_up, n_layers, n_sh, n_ph)
pmodel = partial(model, r_atoms=r_atoms, masks=masks, n_up=n_up, n_down=n_down)
vmodel = jax.vmap(pmodel, in_axes=(None, 0))  # params shared, walkers mapped along axis 0

model_tc = update_state_dict(
    fermiNet(mol_tc, n_det=n_det, n_sh=n_sh, n_ph=n_ph, diagonal=False),
    params,
)

In [None]:
# Forward-pass check: the jax and torch models evaluated on the same walkers should agree.
logpsi = vmodel(params, walkers)
logpsi_tc = model_tc(walkers_tc)
compare(logpsi_tc, logpsi)

In [None]:
def local_kinetic_energy(f):
    """Build a single-walker local kinetic energy function from ``f``.

    ``f(params, x)`` should return log(psi) for one walker ``x``. The returned
    ``_lapl_over_f(params, x)`` evaluates

        -1/2 * (nabla^2 psi) / psi = -1/2 * sum_i [ (d_i f)^2 + d_i^2 f ]

    It loops over the flattened coordinates and takes a JVP of grad f along
    each unit vector to get the diagonal Hessian entry d_i^2 f.
    """

    def _lapl_over_f(params, x):
        shape = x.shape
        x_flat = x.reshape(-1)
        n = x_flat.shape[0]
        # the tangent dtype must match the primal dtype for jax.jvp
        eye = jnp.eye(n, dtype=x_flat.dtype)
        grad_f = jax.grad(f, argnums=1)

        def grad_f_flat(y):
            # f sees the caller's original layout (e.g. (n_electron, 3)); the
            # gradient is flattened so coordinate i can be indexed directly
            return grad_f(params, y.reshape(shape)).reshape(-1)

        def _body_fun(i, val):
            # primal: full gradient of f; tangent: i-th column of the Hessian
            primal, tangent = jax.jvp(grad_f_flat, (x_flat,), (eye[i],))
            return val + primal[i] ** 2 + tangent[i]

        # lax.fori_loop(lower, upper, body(i, carry) -> carry, init) is a compiled fold
        total = lax.fori_loop(0, n, _body_fun, jnp.zeros((), dtype=x_flat.dtype))
        return -0.5 * total

    return _lapl_over_f

# Jitted local kinetic energy for every walker: params shared (None), walkers mapped along axis 0.
# NOTE(review): `lap` is rebound later in the notebook with a different signature (params, wf, walkers).
lap = jax.jit(jax.vmap(local_kinetic_energy(pmodel), in_axes=(None, 0)))

In [None]:

# Potential energies: torch reference and jax implementation.
ep_tc = compute_potential_energy(walkers_tc, r_atoms_tc, z_atoms_tc)
ep = compute_potential_energy_jax(r_atoms, walkers, z_atoms)

# Kinetic energies. The torch `laplacian` result is indexed as two terms that are each
# summed over the last axis and scaled by -1/2 (presumably the second-derivative and
# squared-gradient parts — confirm). The jax `lap` already applies the -1/2.
ek_tc = laplacian(model_tc, walkers_tc)
ek_tc = -0.5 * (ek_tc[0].sum(-1) + ek_tc[1].sum(-1))
ek = lap(params, walkers)

In [None]:
# Check the jax kinetic and potential energies against the torch reference values.
compare(ek_tc, ek)
compare(ep_tc, ep)

# NOTE(review): debug print; a tolerance assertion (np.testing.assert_allclose) would be sturdier.
print(ek, '\n', ek_tc)

In [None]:

def laplacian(params, wf, x):
    """Local kinetic energy of a single walker.

    Despite the name, this returns -1/2 * (nabla^2 psi) / psi (the local kinetic
    energy). It is computed from f = wf(params, x), taken to be log(psi):

        (nabla^2 psi) / psi = sum_i [ (d_i f)^2 + d_i^2 f ]

    Args:
        params: parameters forwarded to ``wf``.
        wf: callable ``wf(params, x) -> scalar``.
        x: one walker's coordinates, any shape (e.g. (n_electron, 3)); ``wf``
            always receives this original shape.

    Returns:
        Scalar local kinetic energy.
    """
    # NOTE(review): this name shadows the torch `laplacian` (presumably from
    # `from pytorch.vmc import *`) that the torch energy cell calls.
    shape = x.shape
    x_flat = x.reshape(-1)
    n = x_flat.shape[0]
    eye = jnp.eye(n, dtype=x_flat.dtype)
    grad_wf = jax.grad(wf, argnums=1)

    def grad_flat(y):
        # evaluate wf on the caller's layout, expose a flat gradient indexable by i
        return grad_wf(params, y.reshape(shape)).reshape(-1)

    def _body_fun(i, val):
        # primal: full gradient; tangent: Hessian-vector product with unit vector e_i
        primal, tangent = jax.jvp(grad_flat, (x_flat,), (eye[i],))
        return val + primal[i] ** 2 + tangent[i]

    return -0.5 * lax.fori_loop(0, n, _body_fun, jnp.zeros((), dtype=x_flat.dtype))


def _batched_laplacian(params, wf, walkers):
    """Apply `laplacian` to every walker along axis 0 of `walkers`."""
    return jax.vmap(lambda x: laplacian(params, wf, x))(walkers)


# `wf` is a Python callable, not an array, so jit must treat it as a static
# (hashable) argument; only the walkers are mapped by vmap.
lap = jax.jit(_batched_laplacian, static_argnums=1)


In [None]:
# Local kinetic energy of every walker via `lap(params, wf, walkers)`.
# NOTE(review): for this call to trace, `lap` must treat `pmodel` (a Python callable)
# as a static argument; a plain jax.jit rejects non-array arguments.
lap(params, pmodel, walkers)

In [None]:
# compare the forward pass
from ops.wf.fnstar_wdiag.ferminet import model, create_masks
from ops.wf.fnstar_wdiag.parameters import initialise_params, count_mixed_features
from ops.vmc.utils import create_atom_batch

# import importlib
# importlib.reload(ops.wf.fnstar_wdiag.ferminet)

# NOTE(review): the original cell used `n_electrons`, `ra` and `re`, none of which is
# defined in this notebook (NameError on a fresh kernel). They are mapped to `n_el`,
# `r_atoms` and the `walkers` batch here — confirm this matches the intent.
# NOTE(review): `params` still comes from the earlier `initialise_params`; re-initialise
# with the fnstar_wdiag version if its parameter layout differs.
masks = create_masks(n_atom, n_el, n_up, n_layers, n_sh, n_ph)
pmodel = partial(model, r_atoms=r_atoms, masks=masks, n_up=n_up, n_down=n_down)
# evaluate every walker, mirroring how `vmodel` wraps `pmodel` in the model-setup cell
lp = jax.vmap(pmodel, in_axes=(None, 0))(params, walkers)


In [None]:
# NOTE(review): the original passed the undefined names `ra` / `re`. The torch tensors
# built in the setup cell are used instead; `r_atoms_tc` is the per-walker atom batch,
# which may differ from what `ra` was — confirm against compute_local_energy's signature.
compute_local_energy(r_atoms_tc, walkers_tc, z_atoms_tc, model_tc)


In [None]:




# def _lapl_over_f(params, x):  # this is recalled everytime
#         x = x.reshape(x.shape[0], -1)
#         n = x.shape[1]
#         eye = jnp.eye(n)
#         eye = jnp.repeat(eye[None, ...], x.shape[0], 0)
#         grad_f = jax.grad(f, argnums=1)
#         grad_f_closure = lambda y: grad_f(params, y)  # ensuring the input can be just x
    
#         def _body_fun(i, val):
#             # primal is the first order evaluation
#             # tangent is the second order
#             primal, tangent = jax.jvp(grad_f_closure, (x,), (eye[..., i],))
#             return val + primal[i]**2 + tangent[i]
    
#         # from lower to upper
#         # (lower, upper, func(int, a) -> a, init_val)
#         # this is like functools.reduce()
#         # val is the previous  val (initialised to 0.0)
#         return -0.5 * lax.fori_loop(0, n, _body_fun, 0.0)
  
#     return _lapl_over_f


# Toy check of JAX derivatives through diagonal extraction (for the "drop diagonal" work).
key = rnd.PRNGKey(1)
x = rnd.normal(key, (25,))

def function(x):
    """Sum of the diagonal of ``x`` viewed as a 5x5 matrix."""
    x = x.reshape((5, 5))
    n = x.shape[0]
    # a NumPy boolean mask is a compile-time constant, so this indexing also works under jit
    eye = np.eye(n, dtype=bool)
    y = x[eye]
    return jnp.sum(y)

z = function(x)
grad1f = jax.grad(function)
g1 = grad1f(x)  # the original called the undefined name `gradf`
# grad1f returns a (25,) vector and jax.grad only accepts scalar-output functions,
# so the second derivative (the Hessian) is taken with jax.jacfwd
grad2f = jax.jacfwd(grad1f)
g2 = grad2f(x)

# Hessian-vector product with the i-th unit vector (the pattern used by the kinetic
# energy code): the tangent equals column i of the Hessian. The original line used
# the undefined names `grad_f_closure`, `eye` and `i`.
i = 0
unit_i = jnp.eye(x.shape[0], dtype=x.dtype)[i]
primal, tangent = jax.jvp(grad1f, (x,), (unit_i,))


In [None]:
# compare the energy
def local_kinetic_energy(f):
    """Build a single-walker local kinetic energy function from ``f``.

    NOTE: this redefinition replaces any earlier ``local_kinetic_energy``.

    ``f(params, x)`` should return log(psi) for one walker ``x``. The returned
    ``_lapl_over_f(params, x)`` evaluates

        -1/2 * (nabla^2 psi) / psi = -1/2 * sum_i [ (d_i f)^2 + d_i^2 f ]

    It works on the flattened coordinates, so ``x`` may have any shape
    (e.g. (n_electron, 3)); ``f`` always receives the original shape.
    """

    def _lapl_over_f(params, x):
        shape = x.shape
        x_flat = x.reshape(-1)
        n = x_flat.shape[0]
        # tangent must have the same shape and dtype as the flat primal for jax.jvp
        eye = jnp.eye(n, dtype=x_flat.dtype)
        grad_f = jax.grad(f, argnums=1)

        def grad_f_flat(y):
            return grad_f(params, y.reshape(shape)).reshape(-1)

        def _body_fun(i, val):
            # primal: full gradient of f; tangent: i-th column of the Hessian
            primal, tangent = jax.jvp(grad_f_flat, (x_flat,), (eye[i],))
            return val + primal[i] ** 2 + tangent[i]

        # lax.fori_loop(lower, upper, body(i, carry) -> carry, init) is a compiled fold
        return -0.5 * lax.fori_loop(0, n, _body_fun, jnp.zeros((), dtype=x_flat.dtype))

    return _lapl_over_f


# Single-walker local kinetic energy closure for the current `pmodel`, built with the
# `local_kinetic_energy` defined in this cell (which shadows the earlier definition).
ke = local_kinetic_energy(pmodel)


In [None]:
# Display `lp`, the `pmodel` output from the forward-pass comparison cell.
lp


In [None]:
# strip the diagonal 
# https://stackoverflow.com/questions/46736258/deleting-diagonal-elements-of-a-numpy-array

# Approach #1 masking

# A[~np.eye(A.shape[0],dtype=bool)].reshape(A.shape[0],-1)

# # Approach #2

# # Using the regular pattern of non-diagonal elements that could be traced with broadcasted additions with range arrays -

# m = A.shape[0]
# idx = (np.arange(1,m+1) + (m+1)*np.arange(m-1)[:,None]).reshape(m,-1)
# out = A.ravel()[idx]

# # Approach #3 (Strides Strikes!)

# # Abusing the regular pattern of non-diagonal elements from previous approach, we can introduce np.lib.stride_tricks.as_strided and some slicing help, like so -

# m = A.shape[0]
# strided = np.lib.stride_tricks.as_strided
# s0,s1 = A.strides
# out = strided(A.ravel()[1:], shape=(m-1,m), strides=(s0+s1,s1)).reshape(m,-1)

def skip_diag_masking(A):
    """Return square matrix ``A`` without its diagonal, shape (m, m - 1), via a boolean mask."""
    off_diagonal = np.logical_not(np.eye(A.shape[0], dtype=bool))
    kept = A[off_diagonal]  # row-major order, m - 1 entries per row
    return kept.reshape(A.shape[0], -1)

def skip_diag_broadcasting(A):
    """Return square matrix ``A`` without its diagonal, shape (m, m - 1), via flat indices.

    The off-diagonal entries form m - 1 runs of m consecutive flat positions,
    each starting one past a diagonal element.
    """
    size = A.shape[0]
    run_starts = (size + 1) * np.arange(size - 1)
    run_offsets = np.arange(1, size + 1)
    flat_idx = np.add.outer(run_starts, run_offsets).reshape(size, -1)
    # np.take without an axis indexes the C-order flattened array
    return np.take(A, flat_idx)

def skip_diag_strided(A):
    """Return square matrix ``A`` without its diagonal, shape (m, m - 1), via a strided view.

    The view relies on ``A``'s row-major layout, so non-C-contiguous inputs
    (e.g. ``A.T``) are first made contiguous. Otherwise ``ravel()`` would copy
    into a new layout while the original strides were still applied, reading
    the wrong (even out-of-bounds) memory.
    """
    A = np.ascontiguousarray(A)
    m = A.shape[0]
    strided = np.lib.stride_tricks.as_strided
    s0, s1 = A.strides
    # rows start one past each diagonal element; a row stride of s0 + s1 skips the next diagonal
    return strided(A.ravel()[1:], shape=(m - 1, m), strides=(s0 + s1, s1)).reshape(m, -1)

# Timings - compare the three diagonal-dropping helpers on a large square integer matrix.
# NOTE(review): 5000x5000 entries is ~200 MB where the default int is 64-bit, and the
# matrix is unseeded; shrink it for quick Restart & Run All. The figures below look like
# the linked Stack Overflow answer's numbers, not measurements from this notebook — confirm.

A = np.random.randint(11,99,(5000,5000))

%timeit skip_diag_masking(A)
%timeit skip_diag_broadcasting(A)
%timeit skip_diag_strided(A)
#      ...: 
# 10 loops, best of 3: 56.1 ms per loop
# 10 loops, best of 3: 82.1 ms per loop
# 10 loops, best of 3: 32.6 ms per loop


    

In [None]:
def compute_potential_energy(r_atom, r_electron, z_atom):
    """Coulomb potential energy of each electron configuration (atomic units).

    NOTE(review): this jax function shadows the torch `compute_potential_energy`
    (presumably from `from pytorch.vmc import *`) used by the energy cell.

    V = sum_{i<j} 1/|r_i - r_j| - sum_{i,A} Z_A/|r_i - R_A| + sum_{A<B} Z_A Z_B/|R_A - R_B|

    Args:
        r_atom: nuclear positions, shape (n_atom, 3) or per-sample (n_samples, n_atom, 3).
            The atom count is read from axis -2, so both layouts work.
        r_electron: electron positions, shape (n_samples, n_electron, 3).
        z_atom: nuclear charges, shape (n_atom,) (or (n_samples, n_atom)).

    Returns:
        Array of shape (n_samples,).
    """
    n_electron = r_electron.shape[1]
    n_atom = r_atom.shape[-2]
    z_atom = jnp.asarray(z_atom)

    # Electron-electron repulsion over i > j pairs only, so no zero self-distances are formed.
    i_e, j_e = np.tril_indices(n_electron, k=-1)
    e_e_dist = jnp.linalg.norm(r_electron[:, i_e] - r_electron[:, j_e], axis=-1)
    potential_energy = jnp.sum(1. / e_e_dist, axis=-1)

    # Electron-nucleus attraction; a_e_dist is (n_samples, n_electron, n_atom).
    a_e_dist = jnp.linalg.norm(r_electron[:, :, None, :] - r_atom[..., None, :, :], axis=-1)
    potential_energy = potential_energy - jnp.sum(z_atom[..., None, :] / a_e_dist, axis=(-1, -2))

    # Nucleus-nucleus repulsion, each pair counted once.
    if n_atom > 1:
        i_a, j_a = np.tril_indices(n_atom, k=-1)
        a_a_dist = jnp.linalg.norm(r_atom[..., i_a, :] - r_atom[..., j_a, :], axis=-1)
        potential_energy = potential_energy + jnp.sum(
            z_atom[..., i_a] * z_atom[..., j_a] / a_a_dist, axis=-1)

    return potential_energy

# Shape check: expect one potential energy per walker, i.e. (n_walkers,).
# NOTE(review): after this cell runs, `compute_potential_energy` refers to the jax version
# above, so re-running the torch energy cell would no longer call the torch function.
compute_potential_energy(r_atoms, walkers, z_atoms).shape