In [None]:
# Load IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os
# Request the CPU backend from JAX. This is done in the first cell, ahead of
# the `import jax` in the next one -- presumably because JAX reads the variable
# when it initialises (NOTE(review): confirm for the pinned JAX version).
os.environ["JAX_PLATFORM_NAME"] = "cpu"


In [None]:
import numpy as np
from ops.utils import compare
from functools import partial

import jax
import jax.numpy as jnp

# Run JAX in double precision. The flag is set through `jax.config` directly:
# the `jax.config` *submodule* (`from jax.config import config`) was deprecated
# and later removed from JAX, whereas `jax.config.update` is available on both
# old and new releases.
jax.config.update("jax_enable_x64", True)
# Keep the historical name `config` bound for any later cell that still uses it.
config = jax.config
from jax import random as rnd
from jax import lax, jit, vmap
from jax.tree_util import tree_structure, tree_flatten, tree_unflatten

from pytorch.models.og.model import fermiNet
from pytorch.sampling import MetropolisHasting
from pytorch.vmc import *
from pytorch.pretraining import Pretrainer
from pytorch.systems import Molecule as Moleculetc
from pytorch.utils import update_state_dict, from_np
import torch as tc
# Make float64 the default dtype for torch, in line with the "jax_enable_x64"
# setting applied to JAX earlier in this cell, so both frameworks work in
# double precision.
tc.set_default_dtype(tc.float64)

from ops.vmc.utils import create_atom_batch
from ops.systems import Molecule
from ops.wf.ferminet import create_wf, create_masks
from ops.wf.parameters import initialise_params, count_mixed_features
from ops.sampling import create_sampler
from ops.vmc import create_energy_fn, local_kinetic_energy, compute_potential_energy
from ops.pretraining import create_loss_and_sampler

In [None]:
# ---------------------------------------------------------------------------
# Randomness
# ---------------------------------------------------------------------------
# A fixed seed keeps the notebook reproducible. The root key is split three
# ways: one key is carried forward as `key`, the other two land in `subkeys`.
key = rnd.PRNGKey(1)
key, *subkeys = rnd.split(key, num=3)

# ---------------------------------------------------------------------------
# System
# ---------------------------------------------------------------------------
# One atom placed at the origin with 4 electrons (a neutral beryllium atom,
# assuming `z_atoms` holds nuclear charges -- TODO confirm against `Molecule`).
n_walkers = 1024
n_el = 4
r_atoms = jnp.array([[0.0, 0.0, 0.0]])
z_atoms = jnp.array([4.])

# ---------------------------------------------------------------------------
# Ansatz (JAX implementation)
# ---------------------------------------------------------------------------
mol = Molecule(r_atoms, z_atoms, n_el, n_det=1)
walkers = mol.initialise_walkers(n_walkers=n_walkers)
wf, wf_orbitals = create_wf(mol)
# Batched wavefunction: first argument shared, the remaining two mapped over
# their leading axis.
vwf = vmap(wf, in_axes=(None, 0, 0))
# Sampler built directly from the wavefunction. It gets its own name because
# this cell used to bind it to `sampler` and then overwrite that name a few
# lines later, which made this object unreachable.
vmc_sampler = create_sampler(wf, correlation_length=10)
params = initialise_params(subkeys[0], mol)
compute_energy = create_energy_fn(wf, r_atoms, z_atoms)
# Jitted, batched kinetic-energy function: first argument shared, second mapped.
laplacian_jax = jit(vmap(local_kinetic_energy(wf), in_axes=(None, 0)))
loss_function, pretrain_sampler = create_loss_and_sampler(mol, wf, wf_orbitals)
# Backward compatibility: after this cell ran, `sampler` always referred to the
# sampler returned by `create_loss_and_sampler`; keep that binding intact.
sampler = pretrain_sampler

# ---------------------------------------------------------------------------
# Reference model (PyTorch implementation)
# ---------------------------------------------------------------------------
# Convert the JAX-side inputs with the project's `from_np` helper so both
# implementations are fed the same walkers and atom data.
walkers_tc = from_np(walkers)
r_atoms_tc = from_np(create_atom_batch(r_atoms, n_walkers))
z_atoms_tc = from_np(z_atoms)

mol_tc = Moleculetc(r_atoms_tc, z_atoms_tc, n_el, device='cpu', dtype=r_atoms_tc.dtype)

model_tc = fermiNet(mol_tc, n_det=1, n_sh=64, n_ph=16, diagonal=False)
# Push the JAX-initialised `params` into the torch model (via the project's
# `update_state_dict`) -- presumably so both models share identical weights.
model_tc = update_state_dict(model_tc, params)
