In [1]:
# Reload imported modules automatically before each cell runs, so edits to the
# project packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os
# Ask JAX for the CPU platform. This is set here, ahead of the first jax
# import in the following cell.
os.environ["JAX_PLATFORM_NAME"] = "cpu"


In [2]:
import jax.random as rnd
import jax.numpy as jnp

# Draw a (100, 2) standard-normal sample from a fixed seed and display, for
# every row, the larger of its two entries.
key = rnd.PRNGKey(1)
x = rnd.normal(key, shape=(100, 2))
x.max(axis=-1)

DeviceArray([ 1.2737546 ,  2.4117248 ,  1.3882742 ,  0.7035917 ,
              2.193741  ,  0.16973034,  1.3588555 , -0.5840599 ,
              0.60256624,  1.2763157 , -0.31434137,  1.4645936 ,
              2.126335  ,  0.78721184,  0.50604457, -1.538333  ,
              0.73629296,  1.7001914 ,  1.496785  ,  0.36924478,
              1.0543514 ,  0.9789985 ,  0.35151017, -0.066497  ,
             -0.11481832, -0.18533254, -0.29030418,  0.7003464 ,
             -0.76301354,  1.0972972 ,  0.69392604,  2.119694  ,
             -0.09748547,  0.50608397,  0.7576622 ,  0.41617453,
              0.32975954,  0.47700652, -0.27752584,  0.38763425,
             -0.5255122 ,  1.0126301 ,  1.5729288 ,  0.66961694,
              1.9571755 ,  0.65423685,  0.01874596,  0.6467748 ,
             -0.22355916,  0.27673158,  1.6970062 , -0.5874497 ,
              0.84318596, -0.10330604,  0.43957442, -0.03459455,
              0.6331224 , -0.0070661 , -0.37678435, -0.09135703,
             -0.2196901 ,

In [3]:
import numpy as np
from ops.utils import compare
from functools import partial

# Switch on JAX's "jax_enable_x64" flag (64-bit types), presumably so the JAX
# results can be compared against the float64 torch model configured below.
from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
from jax import random as rnd
from jax import lax, jit, vmap
from jax.tree_util import tree_structure, tree_flatten, tree_unflatten

from pytorch.models.og.model import fermiNet
from pytorch.sampling import MetropolisHasting
from pytorch.vmc import *
from pytorch.pretraining import Pretrainer
from pytorch.systems import Molecule as Moleculetc
from pytorch.utils import update_state_dict, from_np
import torch as tc
# Make float64 the default floating-point dtype for newly created torch tensors.
tc.set_default_dtype(tc.float64)

from ops.vmc.utils import create_atom_batch
from ops.systems import Molecule
from ops.wf.ferminet import create_wf, create_masks
from ops.wf.parameters import initialise_params, count_mixed_features
from ops.sampling import create_sampler
from ops.vmc import create_energy_fn, local_kinetic_energy, compute_potential_energy
from ops.pretraining import create_loss_and_sampler

  h5py.get_config().default_file_mode = 'a'


In [4]:
# randomness
key = rnd.PRNGKey(1)
key, *subkeys = rnd.split(key, num=3)

# system
n_walkers = 1024
n_el = 4
r_atoms = jnp.array([[0.0, 0.0, 0.0]])
z_atoms = jnp.array([4.])

# ansatz (JAX side)
mol = Molecule(r_atoms, z_atoms, n_el, n_det=1)
walkers = mol.initialise_walkers(n_walkers=n_walkers)
wf, wf_orbitals = create_wf(mol)
# The cells below call the batched wavefunction as vwf(params, walkers):
# params are shared across the batch, walkers are mapped over axis 0.
# in_axes must have one entry per positional argument; the previous
# three-entry spec (None, 0, 0) makes vmap reject those two-argument calls.
vwf = vmap(wf, in_axes=(None, 0))
# Kept under its own name: this sampler used to be bound to `sampler` and was
# then immediately overwritten by the create_loss_and_sampler result below.
vmc_sampler = create_sampler(wf, correlation_length=10)
params = initialise_params(subkeys[0], mol)
compute_energy = create_energy_fn(wf, r_atoms, z_atoms)
laplacian_jax = jit(vmap(local_kinetic_energy(wf), in_axes=(None, 0)))
# `sampler` refers to the sampler returned together with the loss function.
loss_function, sampler = create_loss_and_sampler(mol, wf, wf_orbitals)

# torch counterparts of the same walkers / atoms
walkers_tc = from_np(walkers)
r_atoms_tc = from_np(create_atom_batch(r_atoms, n_walkers))
z_atoms_tc = from_np(z_atoms)

mol_tc = Moleculetc(r_atoms_tc, z_atoms_tc, n_el, device='cpu', dtype=r_atoms_tc.dtype)

# torch model; update_state_dict presumably copies the JAX `params` into it so
# both implementations are evaluated with the same weights (verify in
# pytorch.utils).
model_tc = fermiNet(mol_tc, n_det=1, n_sh=64, n_ph=16, diagonal=False)
model_tc = update_state_dict(model_tc, params)


System: 
 n_atoms = 1 
 n_up    = 2 
 n_down  = 2 

converged SCF energy = -14.351880476202
System: 
 Device  = cpu 
 dtype   = torch.float64 
 n_atoms = 1 
 n_up    = 2 
 n_down  = 2 

converged SCF energy = -14.351880476202
Model: 
 device   = cpu 
 n_sh     = 64 
 n_ph     = 16 
 n_layers = 2 
 n_det    = 1 



AssertionError: 

In [None]:

# Query the pytree layout of `params`. The value is not shown because this is
# not the last expression of the cell.
tree_structure(params)
flat, tree = tree_flatten(params)
print(params['s0'][0, 0])

# Scale every leaf by 0.1, rebuild a pytree with the same structure, and print
# the same entry again for comparison.
flat = list(map(lambda leaf: leaf * 0.1, flat))
p2 = tree_unflatten(tree, flat)
print(p2['s0'][0, 0])

In [None]:
# `laplacian` and `compute_local_energy` are not imported by name in this
# notebook; presumably they come from `from pytorch.vmc import *`.

# 1) wavefunction output: vmapped JAX function vs. torch model
lp = vwf(params, walkers)
lp_tc = model_tc(walkers_tc)
compare(lp_tc, lp)

# 2) kinetic energy: -1/2 times the summed entries of the two components
#    returned by the torch laplacian, vs. the jitted JAX kinetic energy
lap_tc = laplacian(model_tc, walkers_tc)
ek_tc = (lap_tc[0].sum(-1) + lap_tc[1].sum(-1)) * -0.5
ek_jax = laplacian_jax(params, walkers)
compare(ek_tc, ek_jax)

# 3) local energy from both implementations
e_jax = compute_energy(params, walkers)
e_tc = compute_local_energy(model_tc, walkers_tc, r_atoms_tc, z_atoms_tc)
compare(e_tc, e_jax)

In [None]:
# Evaluate the loss function returned by create_loss_and_sampler on the
# current parameters and walkers; the value is shown as the cell output.
loss_function(params, walkers)

In [None]:
# Evaluate the vmapped wavefunction on the walkers and display the result.
vwf(params, walkers)

In [None]:
# Local energies from the torch implementation (computed in the comparison cell).
print(e_tc)

In [None]:
# Energies from the JAX implementation (computed in the comparison cell).
print(e_jax)

In [None]:
n, f, i, k =  2, 10, 3, 4
# Use a separate subkey per draw; passing the same key to two sampling calls
# yields correlated values.
key_x, key_y = rnd.split(key)
x = rnd.normal(key_x, (n, f))
y = rnd.normal(key_y, (f, i, k))
# Contract the shared feature axis f: (n, f) x (f, i, k) -> (n, i, k).
# Neither jnp.dot(x, y) nor x @ y can do this for these shapes: dot pairs x's
# last axis (length f) with y's second-to-last axis (length i), and @ treats
# y's leading axis as a batch axis, so both raise a shape error.
# tensordot with axes=1 pairs x's last axis with y's first axis.
jnp.tensordot(x, y, axes=1)

In [None]:
n, f, i, k =  2, 10, 3, 4
# Use a separate subkey per draw; passing the same key to two sampling calls
# yields correlated values.
key_x, key_y = rnd.split(key)
x = rnd.normal(key_x, (f, n))
y = rnd.normal(key_y, (k, i, f))
# jnp.dot pairs y's last axis with x's second-to-last axis (both of length f):
# (k, i, f) . (f, n) -> (k, i, n)
z = jnp.dot(y, x)
print(z.shape)

In [None]:
# Product of all entries of the array (1 * 2 * 3), shown as the cell output.
jnp.array((1, 2, 3)).prod()

In [5]:
# potential energy

def batched_cdist_l2(x1, x2):
    """Pairwise Euclidean distances between two batches of point sets.

    Uses the expansion |a - b|^2 = |a|^2 + |b|^2 - 2 a.b, evaluated with one
    batched matrix product.

    Args:
        x1: tensor of shape (..., n, d).
        x2: tensor of shape (..., m, d).

    Returns:
        Tensor of shape (..., n, m) whose [..., p, q] entry is the distance
        between x1[..., p, :] and x2[..., q, :].
    """
    x1_norm = (x1 ** 2).sum(-1, keepdim=True)
    x2_norm = (x2 ** 2).sum(-1, keepdim=True)
    sq_dist = x2_norm.transpose(-1, -2) + x1_norm - 2 * x1 @ x2.transpose(-1, -2)
    # Floating-point rounding in the expanded form can leave a tiny negative
    # value where the true distance is zero (e.g. a point against itself);
    # sqrt of that would be NaN, so clamp to zero first.
    cdist = sq_dist.clamp_min(0).sqrt()
    return cdist

def pe_tc(r_atom: tc.Tensor, r_electron: tc.Tensor, z_atom: tc.Tensor) -> tc.Tensor:
    """Potential energy per sample: electron-electron term minus atom-electron term.

    The result takes its dtype and device from the inputs rather than from
    torch's global defaults, so no precision is lost and no device mismatch
    occurs when the inputs differ from those defaults.

    Args:
        r_atom: atom positions, indexed (sample, atom, coordinate).
        r_electron: electron positions, indexed (sample, electron, coordinate).
        z_atom: one charge per atom, shape (atom,).

    Returns:
        Tensor of shape (sample,). Only the two terms above are included;
        no atom-atom term is computed here.
    """
    # Electron - electron distances. The strict lower triangle counts each
    # pair once and excludes the diagonal of zero self-distances.
    e_e_dist = batched_cdist_l2(r_electron, r_electron)
    e_e_term = tc.tril(1. / e_e_dist, diagonal=-1).sum((-1, -2))

    # Atom - electron distances, weighted by the atomic charges and summed
    # over atoms and electrons.
    a_e_dist = batched_cdist_l2(r_atom, r_electron)
    a_e_term = tc.einsum('a,bae->b', z_atom, 1. / a_e_dist)

    return e_e_term - a_e_term

# JAX value for the same walkers and atoms.
pe_jax = compute_potential_energy(walkers, r_atoms, z_atoms)

# torch value. This rebinds the name `pe_tc` from the function defined in this
# cell to its result tensor, so the function is only callable again after the
# cell (and its `def`) has been re-run.
pe_tc = pe_tc(r_atoms_tc, walkers_tc, z_atoms_tc)

In [8]:
# Element-wise difference between the torch and JAX potential energies.
print(pe_tc.numpy() - pe_jax)

[0.00000000e+00 1.77635684e-15 0.00000000e+00 ... 0.00000000e+00
 0.00000000e+00 0.00000000e+00]
