In [None]:
# Re-import edited modules (e.g. the local ad_afqmc_prototype package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import numpy as np
import os

# Make the repository root (parent of the notebook's working directory) importable so
# `ad_afqmc_prototype` can be loaded from the local checkout. The path is appended,
# so an installed copy appearing earlier on sys.path would still take precedence.
module_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if module_path not in sys.path:
    sys.path.append(module_path)

from ad_afqmc_prototype import config

# Called before `jax.numpy` is imported below. Presumably this sets global JAX options
# that must be in place before first use; confirm in config.setup_jax.
config.setup_jax()

from jax import numpy as jnp

from pyscf import gto, scf, mcscf, cc, ao2mo

# Compact array printing for the rest of the notebook.
np.set_printoptions(precision=5, suppress=True)

In [None]:
# Linear chain of `n_h` hydrogen atoms along z with uniform spacing `r`.
# No `unit=` is passed to gto.M, so PySCF's default length unit (Angstrom) applies.
n_h = 4  # number of H atoms
r = 1.0  # H-H spacing
# The closed-shell RHF reference and the singlet (ss=0) CASCI below need an even electron count.
assert n_h % 2 == 0, "n_h must be even for the closed-shell RHF / singlet setup"

mol = gto.M(
    atom="; ".join(f"H 0 0 {i * r}" for i in range(n_h)),
    basis="sto-6g",
    verbose=3,
)
mf = scf.RHF(mol)
mf.kernel()

# CASCI over all orbitals and all electrons is full CI: the exact energy in this basis.
mc = mcscf.CASCI(mf, mol.nao, mol.nelectron)
mc.fix_spin_(ss=0)  # target the singlet state (S(S+1) = 0)
mc.kernel()

# CCSD(T) as an additional reference energy.
mycc = cc.CCSD(mf)
mycc.kernel()
et = mycc.ccsd_t()
print("CCSD(T) energy: ", mycc.e_tot + et)

# Hamiltonian in the RHF MO basis: nuclear repulsion constant, one-body and two-body integrals.
h0 = mf.energy_nuc()
h1 = np.array(mf.mo_coeff.T @ mf.get_hcore() @ mf.mo_coeff)
eri = np.array(ao2mo.kernel(mol, mf.mo_coeff))
# Ensure 4-fold-symmetric packed storage (presumably the layout modified_cholesky expects; confirm).
eri = ao2mo.restore(4, eri, mol.nao)

from ad_afqmc_prototype import integrals

# Cholesky vectors of the two-electron integrals, truncated at the given error threshold.
chol = integrals.modified_cholesky(eri, max_error=1e-6)

In [None]:
from ad_afqmc_prototype.core.system import system
from ad_afqmc_prototype.ham.chol import ham_chol
from ad_afqmc_prototype.trial import rhf as rhf_trial

# data
# Named `afqmc_sys` rather than `sys` so the stdlib `sys` module imported earlier is not shadowed.
afqmc_sys = system(norb=mol.nao, nelec=mol.nelec, walker_kind="restricted")
ham_data = ham_chol(h0=jnp.array(h0), h1=jnp.array(h1), chol=jnp.array(chol))
# h1/chol are in the RHF MO basis, so the RHF determinant is simply the first
# nelectron // 2 columns of the identity (the occupied MOs).
rhf_trial_data = rhf_trial.rhf_trial(jnp.eye(mol.nao, mol.nelectron // 2))

# trial and measurement operations
from ad_afqmc_prototype.core.ops import trial_ops, meas_ops
from ad_afqmc_prototype.meas import rhf as rhf_meas

rhf_trial_ops = trial_ops(rhf_trial.overlap_r, rhf_trial.get_rdm1)
rhf_meas_ops = meas_ops(
    rhf_meas.overlap_r,
    build_meas_ctx=rhf_meas.build_meas_ctx,
    kernels={
        "force_bias": rhf_meas.force_bias_kernel_r,
        "energy": rhf_meas.energy_kernel_r,
    },
)
# or using factory functions
# rhf_trial_ops = rhf_trial.make_rhf_trial_ops(sys=afqmc_sys)
# rhf_meas_ops = rhf_meas.make_rhf_meas_ops(sys=afqmc_sys)


# propagation operations
from ad_afqmc_prototype.prop.chol_afqmc_ops import make_chol_afqmc_ops
from ad_afqmc_prototype.prop.blocks import afqmc_block
from ad_afqmc_prototype.prop.types import afqmc_params

# Fixed seed so Restart & Run All reproduces the same AFQMC trajectory;
# change it to sample an independent run.
AFQMC_SEED = 12345

afqmc_prop_ops = make_chol_afqmc_ops(ham_data, afqmc_sys.walker_kind)
params = afqmc_params(n_eql_blocks=20, n_blocks=200, seed=AFQMC_SEED)

# driver
from ad_afqmc_prototype.driver import run_afqmc_energy

mean_manual, err_manual, block_e_manual, block_w_manual = run_afqmc_energy(
    sys=afqmc_sys,
    params=params,
    ham_data=ham_data,
    trial_ops=rhf_trial_ops,
    trial_data=rhf_trial_data,
    meas_ops=rhf_meas_ops,
    prop_ops=afqmc_prop_ops,
    block_fn=afqmc_block,
)

# or simply
from ad_afqmc_prototype import default

afqmc = default.Rhf(mf)
afqmc.params = params  # same params/seed as the manual run, to get the same results
mean, err, block_e_all, block_w_all = afqmc.kernel()

# Both runs are kept so the expected agreement between the two drivers is visible.
print(f"AFQMC (manual ops)  : {mean_manual} +/- {err_manual}")
print(f"AFQMC (default.Rhf) : {mean} +/- {err}")

In [None]:
from matplotlib import pyplot as plt

# Per-block AFQMC energies from the default.Rhf run, drawn against the exact
# (full-space CASCI) and CCSD(T) reference energies.
# NOTE(review): assumes `block_e_all` holds one energy per block; confirm against the driver.
fig, ax = plt.subplots()
ax.plot(block_e_all, "o-", label="AFQMC block energy")
ax.axhline(mc.e_tot, color="k", linestyle="-", label="exact (CASCI, all orbitals)")
ax.axhline(mycc.e_tot + et, color="gray", linestyle="--", label="CCSD(T)")
ax.set(xlabel="block index", ylabel="energy (Ha)", title="AFQMC block energies")
ax.legend()
plt.show()