In [1]:
import sys
import numpy as np
from functools import reduce
from pyscf.lib import logger
from ad_afqmc.lno.cc import LNOCCSD
from ad_afqmc import lno_ccsd
from pyscf import gto, scf, mp, cc

log = logger.Logger(sys.stdout, 6)

a = 0.9
nH = 6
atoms = ""
for i in range(nH):
    atoms += f"H {i*a} 0 0 \n"

mol = gto.M(atom=atoms, basis="ccpvdz", verbose=4)
mf = scf.RHF(mol).density_fit()
mf.kernel()

frozen = 0
mmp = mp.MP2(mf,frozen=frozen)
mmp.kernel()[0]

# cc
mycc = cc.CCSD(mf)
mycc.kernel()
et = mycc.ccsd_t()

# fci
#cisolver = fci.FCI(mf)
#fci_ene, fci_vec = cisolver.kernel()

print(f'rhf energy is {mf.e_tot}')
print(f"ccsd energy is {mycc.e_tot}")
print(f"ccsd_t energy is {mycc.e_tot+et}")
print(f"ccsd correlation energy is {mycc.e_corr}")
#print(f"fci_ene: {fci_ene}", flush=True)

# Hostname: yichi-thinkpad
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64
System: uname_result(system='Linux', node='yichi-thinkpad', release='4.4.0-26100-Microsoft', version='#1882-Microsoft Fri Jan 01 08:00:00 PST 2016', machine='x86_64')  Threads 12
Python 3.10.16 | packaged by conda-forge | (main, Dec  5 2024, 14:16:10) [GCC 13.3.0]
numpy 1.24.3  scipy 1.14.1  h5py 3.12.1
Date: Mon Jun  2 15:54:18 2025
PySCF version 2.8.0
PySCF path  /home/yichi/research/software/lno_pyscf
GIT HEAD (branch master) ef75f4190e4de208685670651dc6c467f72b6794

[ENV] PYSCF_EXT_PATH /home/yichi/research/software/pyscf
[CONFIG] conf_file None
[INPUT] verbose = 4
[INPUT] num. atoms = 6
[INPUT] num. electrons = 6
[INPUT] charge = 0
[INPUT] spin (= nelec alpha-beta = 2S) = 0
[INPUT] symmetry False subgroup None
[INPUT] Mole.unit = angstrom
[INPUT] Symbol           X                Y                Z      unit          X                Y                Z       unit  Magmom
[INPUT]  1 H      0.

In [2]:
import jax
import jax.numpy as jnp
from jax import jvp, lax
from ad_afqmc import wavefunctions
from ad_afqmc import sampling
from jax import vmap, random
from typing import Tuple

In [44]:
def _calc_hf_olp(walker):
    nocc = walker.shape[1]
    o0 = jnp.linalg.det(walker[: nocc, :]) ** 2
    return o0

def calc_hf_olp(walkers):
    hf_olp = vmap(_calc_hf_olp, in_axes=(0))(walkers)
    return hf_olp

def _calc_olp_ratio_restricted(walker: jax.Array, wave_data: dict) -> complex:
    '''
    <psi_hf|walker>/<psi_ccsd|walker>
    '''
    nocc, ci1, ci2 = walker.shape[1], wave_data["ci1"], wave_data["ci2"]
    GF = (walker.dot(jnp.linalg.inv(walker[: walker.shape[1], :]))).T
    #o0 = jnp.linalg.det(walker[: walker.shape[1], :]) ** 2
    o1 = jnp.einsum("ia,ia", ci1, GF[:, nocc:])
    o2 = 2 * jnp.einsum(
        "iajb, ia, jb", ci2, GF[:, nocc:], GF[:, nocc:]
    ) - jnp.einsum("iajb, ib, ja", ci2, GF[:, nocc:], GF[:, nocc:])
    return 1/(1.0 + 2 * o1 + o2)

def frg_hf_orb_cr(walkers,ham_data,wave_data,trial):
    hf_orb_cr = vmap(_frg_hf_orb_cr, in_axes=(None, None, 0, None, None))(
        ham_data['rot_h1'], ham_data['rot_chol'], walkers, trial, wave_data)
    return hf_orb_cr

def _frg_hf_orb_cr(rot_h1, rot_chol, walker, trial, wave_data):
    m = jnp.dot(wave_data["prjlo"].T,wave_data["prjlo"])
    nocc = rot_h1.shape[0]
    _calc_green = wavefunctions.rhf(
        trial.norb,trial.nelec,n_batch=trial.n_batch)._calc_green
    green_walker = _calc_green(walker, wave_data)
    f = jnp.einsum('gij,jk->gik', rot_chol[:,:nocc,nocc:],
                    green_walker.T[nocc:,:nocc], optimize='optimal')
    c = vmap(jnp.trace)(f)
    eneo2Jt = jnp.einsum('Gxk,xk,G->',f,m,c)*2 
    eneo2ext = jnp.einsum('Gxy,Gyk,xk->',f,f,m)
    hf_orb_en = eneo2Jt - eneo2ext
    olp_ratio = _calc_olp_ratio_restricted(walker,wave_data)
    hf_orb_cr = jnp.real(olp_ratio*hf_orb_en)
    return hf_orb_cr

def _modified_ccsd_olp_restricted(walker: jax.Array, wave_data: dict) -> complex:
    '''
    <psi_ccsd|walker>=<psi_0|walker>+C_ia^*G_ia+C_iajb^*(G_iaG_jb-G_ibG_ja)
    modified CCSD overlap returns the second and the third term
    that is, the overlap of the walker with the CCSD wavefunction
    without the hartree-fock part
    and skip one sum over the occ
    '''
    #prjlo = wave_data["prjlo"].reshape(walker.shape[1])
    #pick_i = jnp.where(abs(prjlo) > 1e-6, 1, 0)
    # m = jnp.dot(wave_data["prjlo"].T,wave_data["prjlo"])
    nocc, ci1, ci2 = walker.shape[1], wave_data["ci1"], wave_data["ci2"]
    gf = (walker.dot(jnp.linalg.inv(walker[: walker.shape[1], :]))).T
    o0 = jnp.linalg.det(walker[: walker.shape[1], :]) ** 2
    #o1 = jnp.einsum("ia,ja,ij->", ci1, gf[:, nocc:],m)
    o1 = jnp.einsum("ia,ja->", ci1, gf[:, nocc:])
    # o2 = 2 * jnp.einsum(
    #     "iajb,ia,jb,ij->", ci2, gf[:, nocc:], gf[:, nocc:],m
    # ) - jnp.einsum("iajb,ib,ja,ij->", ci2, gf[:, nocc:], gf[:, nocc:],m)
    o2 = 2 * jnp.einsum("iajb,ia,jb->", ci2, gf[:, nocc:], gf[:, nocc:]) \
        - jnp.einsum("iajb,ib,ja->", ci2, gf[:, nocc:], gf[:, nocc:])
    olp = (2*o1+o2)*o0
    #olp_i = jnp.einsum("i,i->", olp, pick_i)
    return olp

def modified_ccsd_olp_restricted(walkers: jax.Array, wave_data: dict, trial) -> jax.Array:
    n_batch = trial.n_batch
    norb = trial.norb
    n_walkers = walkers.shape[0]
    #nocc = walkers.shape[2]
    batch_size = n_walkers // n_batch

    def scanned_fun(carry, walker_batch):
        overlap_batch = vmap(_modified_ccsd_olp_restricted, in_axes=(0, None))(
            walker_batch, wave_data
        )
        return carry, overlap_batch

    _, overlaps = lax.scan(
        scanned_fun, None, walkers.reshape(n_batch, batch_size, norb, -1)
    )
    return overlaps.reshape(n_walkers)

def _frg_modified_ccsd_olp_restricted(walker: jax.Array, wave_data: dict) -> complex:
    '''
    <psi_ccsd|walker>=<psi_0|walker>+C_ia^*G_ia+C_iajb^*(G_iaG_jb-G_ibG_ja)
    modified CCSD overlap returns the second and the third term
    that is, the overlap of the walker with the CCSD wavefunction
    without the hartree-fock part
    and skip one sum over the occ
    '''
    prjlo = wave_data["prjlo"].reshape(walker.shape[1])
    pick_i = jnp.where(abs(prjlo) > 1e-6, 1, 0)
    # m = jnp.dot(wave_data["prjlo"].T,wave_data["prjlo"])
    nocc, ci1, ci2 = walker.shape[1], wave_data["ci1"], wave_data["ci2"]
    gf = (walker.dot(jnp.linalg.inv(walker[: walker.shape[1], :]))).T
    o0 = jnp.linalg.det(walker[: walker.shape[1], :]) ** 2
    #o1 = jnp.einsum("ia,ja,ij->", ci1, gf[:, nocc:],m)
    o1 = jnp.einsum("ia,ja->i", ci1, gf[:, nocc:])
    # o2 = 2 * jnp.einsum(
    #     "iajb,ia,jb,ij->", ci2, gf[:, nocc:], gf[:, nocc:],m
    # ) - jnp.einsum("iajb,ib,ja,ij->", ci2, gf[:, nocc:], gf[:, nocc:],m)
    o2 = 2 * jnp.einsum("iajb,ia,jb->i", ci2, gf[:, nocc:], gf[:, nocc:]) \
        - jnp.einsum("iajb,ib,ja->i", ci2, gf[:, nocc:], gf[:, nocc:])
    olp = (2*o1+o2)*o0
    olp_i = jnp.einsum("i,i->", olp, pick_i)
    return olp_i

def frg_modified_ccsd_olp_restricted(walkers: jax.Array, wave_data: dict, trial) -> jax.Array:
    n_batch = trial.n_batch
    norb = trial.norb
    n_walkers = walkers.shape[0]
    #nocc = walkers.shape[2]
    batch_size = n_walkers // n_batch

    def scanned_fun(carry, walker_batch):
        overlap_batch = vmap(_frg_modified_ccsd_olp_restricted, in_axes=(0, None))(
            walker_batch, wave_data
        )
        return carry, overlap_batch

    _, overlaps = lax.scan(
        scanned_fun, None, walkers.reshape(n_batch, batch_size, norb, -1)
    )
    return overlaps.reshape(n_walkers)

def _olp_exp1(x: float, h1_mod: jax.Array, walker: jax.Array,
                  wave_data: dict) -> complex:
    '''
    <psi_ccsd|exp(x*h1_mod)|walker>/<psi_ccsd|walker> without the hf part
    '''
    walker_1x = (
            walker
            + x * h1_mod.dot(walker)
            # + x**2 / 2.0 * h1_mod.dot(h1_mod.dot(walker))
        )
    #lno_ccsd._thouless_linear(x*h1_mod, walker)
    olp = _modified_ccsd_olp_restricted(walker_1x, wave_data)
    return olp

def _frg_olp_exp1(x: float, h1_mod: jax.Array, walker: jax.Array,
                  wave_data: dict) -> complex:
    '''
    <psi_ccsd|exp(x*h1_mod)|walker>/<psi_ccsd|walker> without the hf part
    '''
    walker_1x = (
            walker
            + x * h1_mod.dot(walker)
            # + x**2 / 2.0 * h1_mod.dot(h1_mod.dot(walker))
        )
    #lno_ccsd._thouless_linear(x*h1_mod, walker)
    olp = _frg_modified_ccsd_olp_restricted(walker_1x, wave_data)
    return olp

def _olp_exp2(x: float, chol_i: jax.Array, walker: jax.Array,
                  wave_data: dict) -> complex:
    '''
    <psi_ccsd|exp(x*h2_mod)|walker>/<psi_ccsd|walker> without the hf part
    '''
    walker_2x = (
            walker
            + x * chol_i.dot(walker)
            + x**2 / 2.0 * chol_i.dot(chol_i.dot(walker))
            #+ x**3 / 6.0 * chol_i.dot(chol_i.dot(chol_i.dot(walker)))
        )
    olp = _modified_ccsd_olp_restricted(walker_2x, wave_data)
    return olp

def _frg_olp_exp2(x: float, chol_i: jax.Array, walker: jax.Array,
                  wave_data: dict) -> complex:
    '''
    <psi_ccsd|exp(x*h2_mod)|walker>/<psi_ccsd|walker> without the hf part
    '''
    walker_2x = (
            walker
            + x * chol_i.dot(walker)
            + x**2 / 2.0 * chol_i.dot(chol_i.dot(walker))
            #+ x**3 / 6.0 * chol_i.dot(chol_i.dot(chol_i.dot(walker)))
        )
    olp = _frg_modified_ccsd_olp_restricted(walker_2x, wave_data)
    return olp

def _mod_ccsd_cr(
    walker: jax.Array,
    ham_data: dict,
    wave_data: dict,
    trial: wavefunctions,
    eps :float = 1e-5,
):
    """Calculates local energy using AD and finite difference for the two body term"""

    norb = trial.norb
    chol = ham_data["chol"].reshape(-1, norb, norb)
    h1 = (ham_data["h1"][0] + ham_data["h1"][1]) / 2.0
    # v1 the one-body energy from the reordering of the 
    # two-body operators into non-normal ordered form
    v0 = 0.5 * jnp.einsum("gik,gjk->ij",chol,chol,optimize="optimal")
    h1_mod = h1 - v0
    ccsd_olp = trial._calc_overlap_restricted(walker, wave_data)

    # one body
    x = 0.0
    f1 = lambda a: _olp_exp1(a,h1_mod,walker,wave_data)
    _, d_overlap = jvp(f1, [x], [1.0])
    #d_overlap = (f1(eps)-f1(0.0))/eps

    # two body

    # carry: [eps, walker, wave_data]
    def scanned_fun(carry, x):
        eps, walker, wave_data = carry
        return carry, _olp_exp2(eps,x,walker,wave_data)

    _, overlap_p = lax.scan(scanned_fun, (eps, walker, wave_data), chol)
    _, overlap_0 = lax.scan(scanned_fun, (0.0, walker, wave_data), chol)
    _, overlap_m = lax.scan(scanned_fun, (-1.0 * eps, walker, wave_data), chol)
    #print(overlap_p,overlap_0,overlap_m)
    #print(eps)
    d_2_overlap = (overlap_p - 2.0 * overlap_0 + overlap_m) / eps / eps

    ccsd_cr = jnp.real((d_overlap + jnp.sum(d_2_overlap) / 2.0) / ccsd_olp)
    return ccsd_cr

def _frg_ccsd_orb_cr(
    walker: jax.Array,
    ham_data: dict,
    wave_data: dict,
    trial: wavefunctions,
    eps :float = 1e-5,
):
    """Calculates local energy using AD and finite difference for the two body term"""

    norb = trial.norb
    chol = ham_data["chol"].reshape(-1, norb, norb)
    h1 = (ham_data["h1"][0] + ham_data["h1"][1]) / 2.0
    # v1 the one-body energy from the reordering of the 
    # two-body operators into non-normal ordered form
    v0 = 0.5 * jnp.einsum("gik,gjk->ij",chol,chol,optimize="optimal")
    h1_mod = h1 - v0
    ccsd_olp = trial._calc_overlap_restricted(walker, wave_data)

    x = 0.0
    # one body
    f1 = lambda a: _frg_olp_exp1(a,h1_mod,walker,wave_data)
    _, d_overlap = jvp(f1, [x], [1.0])

    # two body

    # carry: [eps, walker, wave_data]
    def scanned_fun(carry, x):
        eps, walker, wave_data = carry
        return carry, _frg_olp_exp2(eps,x,walker,wave_data)

    _, overlap_p = lax.scan(scanned_fun, (eps, walker, wave_data), chol)
    _, overlap_0 = lax.scan(scanned_fun, (0.0, walker, wave_data), chol)
    _, overlap_m = lax.scan(scanned_fun, (-1.0 * eps, walker, wave_data), chol)
    #print(overlap_p,overlap_0,overlap_m)
    #print(eps)
    d_2_overlap = (overlap_p - 2.0 * overlap_0 + overlap_m) / eps / eps

    ccsd_cr = jnp.real((d_overlap + jnp.sum(d_2_overlap) / 2.0) / ccsd_olp)

    return ccsd_cr

def mod_ccsd_cr(
        walkers: jax.Array, 
        ham_data: dict, 
        wave_data: dict, 
        trial: wavefunctions,
        eps: float= 1e-5) -> jax.Array:
    n_walkers = walkers.shape[0]
    #nocc = walkers.shape[2]
    batch_size = n_walkers // trial.n_batch

    def scanned_fun(carry, walker_batch):
        energy_batch = vmap(_mod_ccsd_cr, in_axes=(0, None, None, None, None))(
            walker_batch, ham_data, wave_data, trial, eps
        )
        return carry, energy_batch

    _, energies = lax.scan(
        scanned_fun,
        None,
        walkers.reshape(trial.n_batch, batch_size, trial.norb, -1),
    )
    return energies.reshape(n_walkers)


def frg_ccsd_orb_cr(walkers: jax.Array,
                ham_data: dict,
                wave_data: dict,
                trial: wavefunctions,
                eps: float = 1e-5) -> jax.Array:
    n_walkers = walkers.shape[0]
    #nocc = walkers.shape[2]
    batch_size = n_walkers // trial.n_batch

    def scanned_fun(carry, walker_batch):
        energy_batch = vmap(_frg_ccsd_orb_cr, in_axes=(0, None, None, None, None))(
            walker_batch, ham_data, wave_data, trial, eps
        )
        return carry, energy_batch

    _, energies = lax.scan(
        scanned_fun,
        None,
        walkers.reshape(trial.n_batch, batch_size, trial.norb, -1),
    )
    return energies.reshape(n_walkers)

sampler_eq = sampling.sampler(n_prop_steps=10, n_ene_blocks=5, n_sr_blocks=10)

def block(prop_data,ham_data,prop,trial,wave_data,sampler):
        """Block scan function. Propagation and calculate total ccsd correction energy."""
        prop_data["key"], subkey = random.split(prop_data["key"])
        fields = random.normal(
            subkey,
            shape=(
                sampler.n_prop_steps,
                prop.n_walkers,
                ham_data["chol"].shape[0],
            ),
        )
        _step_scan_wrapper = lambda x, y: sampler._step_scan(
            x, y, ham_data, prop, trial, wave_data
        )
        prop_data, _ = lax.scan(_step_scan_wrapper, prop_data, fields)
        prop_data["n_killed_walkers"] += prop_data["weights"].size - jnp.count_nonzero(
            prop_data["weights"]
        )
        prop_data = prop.orthonormalize_walkers(prop_data)
        prop_data["overlaps"] = trial.calc_overlap(prop_data["walkers"], wave_data)

        ccsd_cr = mod_ccsd_cr(
            prop_data["walkers"],ham_data,wave_data,trial,1e-5)
        energy_samples = jnp.real(
            trial.calc_energy(prop_data["walkers"], ham_data, wave_data)
        )
        energy_samples = jnp.where(
            jnp.abs(energy_samples - prop_data["e_estimate"]) > jnp.sqrt(2.0 / prop.dt),
            prop_data["e_estimate"],
            energy_samples,
        )
        blk_wt = jnp.sum(prop_data["weights"])
        blk_ccsd_cr = jnp.sum(ccsd_cr * prop_data["weights"]) / blk_wt

        blk_energy = jnp.sum(energy_samples * prop_data["weights"]) / blk_wt
        prop_data["pop_control_ene_shift"] = (
            0.9 * prop_data["pop_control_ene_shift"] + 0.1 * blk_energy
        )
        
        return prop_data,(blk_ccsd_cr,blk_wt)

def block_orb(prop_data,ham_data,prop,trial,wave_data,sampler):
        """Block scan function. Propagation and orbital_i energy calculation."""
        prop_data["key"], subkey = random.split(prop_data["key"])
        fields = random.normal(
            subkey,
            shape=(
                sampler.n_prop_steps,
                prop.n_walkers,
                ham_data["chol"].shape[0],
            ),
        )
        _step_scan_wrapper = lambda x, y: sampler._step_scan(
            x, y, ham_data, prop, trial, wave_data
        )
        prop_data, _ = lax.scan(_step_scan_wrapper, prop_data, fields)
        prop_data["n_killed_walkers"] += prop_data["weights"].size - jnp.count_nonzero(
            prop_data["weights"]
        )
        prop_data = prop.orthonormalize_walkers(prop_data)
        prop_data["overlaps"] = trial.calc_overlap(prop_data["walkers"], wave_data)
        ### hartree-fock orbital energy ###
        hf_orb_cr = frg_hf_orb_cr(prop_data["walkers"],ham_data,wave_data,trial)
        #hf_olp = calc_hf_olp(prop_data["walkers"])
        olp_ratio = lno_ccsd.cal_olp_ratio(prop_data["walkers"], wave_data,trial)
        #hf_olp = calc_hf_olp(prop_data["walkers"])
        #ccsd_mod_olp = frg_modified_ccsd_olp_restricted(prop_data["walkers"], wave_data, trial)
        ccsd_orb_cr = frg_ccsd_orb_cr(prop_data["walkers"], ham_data, 
                                        wave_data, trial,1e-5)
        #orb_energy_samples = jnp.real(hf_orb_energy*olp_ratio + ccsd_orb_correction)
        #hf_correction = -(1-olp_ratio)*hf_elec_frg
        #orb_energy_samples = jnp.real(
        #        sum(prop_data["weights"]*orb_energy)/sum(prop_data["weights"]))
        #orb_energy_samples = jnp.real(hf_correction + olp_ratio*hf_orb_energy + ccsd_orb_correction)
        energy_samples = jnp.real(
            trial.calc_energy(prop_data["walkers"], ham_data, wave_data)
        )
        energy_samples = jnp.where(
            jnp.abs(energy_samples - prop_data["e_estimate"]) > jnp.sqrt(2.0 / prop.dt),
            prop_data["e_estimate"],
            energy_samples,
        )
        blk_wt = jnp.sum(prop_data["weights"])
        # print(orb_energy_samples)
        blk_hf_orb_cr = jnp.sum(hf_orb_cr * prop_data["weights"]) / blk_wt
        blk_ccsd_orb_cr = jnp.sum(ccsd_orb_cr * prop_data["weights"]) / blk_wt
        #blk_hf_olp = jnp.sum(hf_olp * prop_data["weights"]) / blk_wt
        #blk_ccsd_mod_olp = jnp.sum(ccsd_mod_olp * prop_data["weights"]) / blk_wt
        blk_olp_ratio = jnp.sum(olp_ratio * prop_data["weights"]) / blk_wt
        # block_orb_energy = jnp.sum(orb_energy_samples * prop_data["weights"]) / block_weight
        # block_olp_ratio = jnp.sum(olp_ratio * prop_data["weights"]) / block_weight
        blk_energy = jnp.sum(energy_samples * prop_data["weights"]) / blk_wt
        prop_data["pop_control_ene_shift"] = (
            0.9 * prop_data["pop_control_ene_shift"] + 0.1 * blk_energy
        )
        
        return prop_data,(blk_hf_orb_cr,blk_ccsd_orb_cr,blk_olp_ratio,blk_wt)


# @partial(jit, static_argnums=(1, 3, 5))
def propagate_phaseless_orb(
    ham_data: dict,
    prop,
    prop_data: dict,
    trial: wavefunctions,
    wave_data: dict,
    sampler,
) -> Tuple[jax.Array, dict]:
    def _sr_block_scan_wrapper(x,_):
        return _sr_block_scan_orb(x, ham_data, prop, trial, wave_data)

    prop_data["overlaps"] = trial.calc_overlap(prop_data["walkers"], wave_data)
    prop_data["n_killed_walkers"] = 0
    prop_data["pop_control_ene_shift"] = prop_data["e_estimate"]
    prop_data,(blk_hf_orb_en,blk_ccsd_orb_cr,blk_olp_ratio,blk_wt) \
        = lax.scan(
        _sr_block_scan_wrapper, prop_data, xs=None, length=sampler.n_sr_blocks
        )
    prop_data["n_killed_walkers"] /= (
        sampler.n_sr_blocks * sampler.n_ene_blocks * prop.n_walkers
    )
    #orb_energy = jnp.sum(block_orb_energy * block_weight) / jnp.sum(block_weight)
    #olp_ratio = jnp.sum(block_olp_ratio * block_weight) / jnp.sum(block_weight)
    hf_orb_en = jnp.sum(blk_hf_orb_en*blk_wt)/jnp.sum(blk_wt)
    ccsd_orb_cr = jnp.sum(blk_ccsd_orb_cr*blk_wt)/jnp.sum(blk_wt)
    #hf_olp = jnp.sum(blk_hf_olp*blk_wt)/jnp.sum(blk_wt)
    #ccsd_mod_olp = jnp.sum(blk_ccsd_mod_olp*blk_wt)/jnp.sum(blk_wt)
    olp_ratio = jnp.sum(blk_olp_ratio*blk_wt)/jnp.sum(blk_wt)
    return (hf_orb_en,ccsd_orb_cr,olp_ratio), prop_data

# @partial(jit, static_argnums=( 5))
def _sr_block_scan_orb(
    prop_data: dict,
    ham_data: dict,
    prop,
    trial: wavefunctions,
    wave_data: dict,
    sampler,
) -> Tuple[dict, Tuple[jax.Array, jax.Array]]:
    def _block_scan_wrapper(x,_):
        return block_orb(x, ham_data, prop, trial, wave_data)
    
    prop_data, (blk_hf_orb_en,blk_ccsd_orb_cr,blk_olp_ratio,blk_wt) \
        = lax.scan(
            _block_scan_wrapper, prop_data, xs=None, length=sampler.n_ene_blocks
            )
    prop_data = prop.stochastic_reconfiguration_local(prop_data)
    prop_data["overlaps"] = trial.calc_overlap(prop_data["walkers"], wave_data)
    return prop_data,(blk_hf_orb_en,blk_ccsd_orb_cr,blk_olp_ratio,blk_wt)

In [3]:
# first build the local active space and run lno-ccsd
# then use the lno-ccsd wavefunction as the reference 
# of the lno-afqmc walker

In [4]:
options = {'n_eql': 2,
           'n_prop_steps': 20,
            'n_ene_blocks': 1,
            'n_sr_blocks': 10,
            'n_blocks': 10,
            'n_walkers': 50,
            'seed': 2,
            'walker_type': 'rhf',
            'trial': 'cisd',
            'dt':0.01,
            'ad_mode':None,
            }

In [12]:
### build full wavefunction ###
lno_ccsd.prep_lno_amp_chol_file(mycc,mf.mo_coeff,mol.nao,mol.nelectron,norb_frozen=[],
    t1=mycc.t1,t2=mycc.t2,chol_cut=1e-6,mo_file="mo_full",amp_file="amp_full",chol_file="chol_full")

ham_data_f, ham_f, prop_f, trial_f, wave_data_f, sampler_f, observable_f, options_f, _ = (
lno_ccsd.prep_lnoccsd_afqmc(
    options,prjlo=[],mo_file="mo_full.npz",amp_file="amp_full.npz",chol_file="chol_full"))

TypeError: prep_lno_amp_chol_file() missing 1 required positional argument: 'nelec_act'

In [None]:
seed = options_f["seed"]
init_walkers = None
ham_data_f = wavefunctions.rhf(trial_f.norb, trial_f.nelec,n_batch=trial_f.n_batch
                            )._build_measurement_intermediates(ham_data_f, wave_data_f)
ham_data_f = ham_f.build_measurement_intermediates(ham_data_f, trial_f, wave_data_f)
ham_data_f = ham_f.build_propagation_intermediates(ham_data_f, prop_f, trial_f, wave_data_f)

prop_data_f = prop_f.init_prop_data(trial_f, wave_data_f, ham_data_f, init_walkers)
if jnp.abs(jnp.sum(prop_data_f["overlaps"])) < 1.0e-6:
    raise ValueError(
        "Initial overlaps are zero. Pass walkers with non-zero overlap."
    )
prop_data_f["key"] = random.PRNGKey(seed)

prop_data_f["overlaps"] = trial_f.calc_overlap(prop_data_f["walkers"], wave_data_f)
prop_data_f["n_killed_walkers"] = 0
prop_data_f["pop_control_ene_shift"] = prop_data_f["e_estimate"]
#e_elec = lno_ccsd.e_elec_frg(mf,orbfragloc,orbfrag,frzfrag)


prop_data_f,(ccsd_tot_cr,_) \
    = lno_ccsd.propagate_phaseless_tot(ham_data_f,prop_f,prop_data_f,trial_f,wave_data_f,sampler_f)
    #= lno_ccsd.block_tot(prop_data_f,ham_data_f,prop_f,trial_f,wave_data_f,sampler_f)
print(ccsd_tot_cr)

-0.20271197449273853


In [5]:
from ad_afqmc.lno.base import lno
from ad_afqmc.lno.afqmc import LNOAFQMC
import os
from ad_afqmc import config
config.setup_jax()
MPI = config.setup_comm()
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

thresh = 1e-4
nproc = 4
chol_cut = 1e-6
_fdot = np.dot
fdot = lambda *args: reduce(_fdot, args)

lno_cc = LNOCCSD(mf, thresh=thresh, frozen=frozen)
lno_cc.thresh_occ = thresh
lno_cc.thresh_vir = thresh
lno_cc.lo_type = 'pm'
lno_cc.no_type = 'cim'
lno_cc.frag_lolist = '1o'
lno_cc.ccsd_t = True
lno_cc.force_outcore_ao2mo = True
orbloc = lno_cc.orbloc
lo_type = lno_cc.lo_type
no_type = lno_cc.no_type
frag_atmlist = lno_cc.frag_atmlist
frag_lolist = lno_cc.frag_lolist
s1e = lno_cc._scf.get_ovlp()

lno_qmc = LNOAFQMC(mf, thresh=thresh)
lno_qmc.thresh_occ = thresh
lno_qmc.thresh_vir = thresh
lno_qmc.nblocks = options["n_blocks"]
lno_qmc.nwalk_per_proc = options["n_walkers"]
lno_qmc.nproc = nproc
lno_qmc.lo_type = 'pm'
lno_qmc.no_type = 'cim'
lno_qmc.frag_lolist = '1o'
lno_qmc.chol_cut = chol_cut

# NO type
no_type = 'ie'
frag_lolist = '1o'
log.info('no_type = %s', no_type)

# LO construction
orbloc = lno_cc.get_lo(lo_type=lo_type) # localized active occ orbitals
orbactocc = lno_cc.split_mo()[1] # non-localized active occ
m = fdot(orbloc.T, s1e, orbactocc)
lospanerr = abs(fdot(m.T, m) - np.eye(m.shape[1])).max()
if lospanerr > 1e-10:
    log.error('LOs do not fully span the occupied space! '
                'Max|<occ|LO><LO|occ>| = %e', lospanerr)
    raise RuntimeError

# check 2: Span(LO) == Span(occ)
occspanerr = abs(fdot(m, m.T) - np.eye(m.shape[0])).max()
if occspanerr < 1e-10:
    log.info('LOs span exactly the occupied space.')
    if no_type not in ['ir','ie']:
        log.error('"no_type" must be "ir" or "ie".')
        raise ValueError
else:
    log.info('LOs span occupied space plus some virtual space.')

# LO assignment to fragments

if frag_lolist == '1o':
    log.info('Using single-LO fragment') # this is what we use, every active local occ stands for a fragment
    frag_lolist = [[i] for i in range(orbloc.shape[1])]
else: print('Only support single LO fragment!')
nfrag = len(frag_lolist)
frag_nonvlist = lno_cc.frag_nonvlist

# dump info
log.info('nfrag = %d  nlo = %d', nfrag, orbloc.shape[1])
log.info('frag_atmlist = %s', frag_atmlist)
log.info('frag_lolist = %s', frag_lolist)
log.info('frag_nonvlist = %s', frag_nonvlist)

if not (no_type[0] in 'rei' and no_type[1] in 'rei'):
    log.warn('Input no_type "%s" is invalid.', no_type)
    raise ValueError

if frag_nonvlist is None: frag_nonvlist = [[None,None]] * nfrag

eris = lno_cc.ao2mo()
ecorr_ccsd = np.empty(nfrag)
frozen_mask = lno_cc.get_frozen_mask()
thresh_pno = thresh
for ifrag in range(1):
    print(f'running fragment {ifrag+1}')
    #if(len(lno_cc.runfrags)>0):
    #    if(ifrag not in lno_cc.runfrags):frag_res[ifrag] = (0,0,0)
    fraglo = frag_lolist[ifrag]
    orbfragloc = orbloc[:,fraglo] # the specific local active occ
    frag_target_nocc, frag_target_nvir = frag_nonvlist[ifrag]
    THRESH_INTERNAL = 1e-10
    frzfrag, orbfrag, _ = lno.make_fpno1(lno_cc, eris, orbfragloc, no_type,
                                                THRESH_INTERNAL, thresh_pno,
                                                frozen_mask=frozen_mask,
                                                frag_target_nocc=frag_target_nocc,
                                                frag_target_nvir=frag_target_nvir,
                                                canonicalize=False)

    #print('lo projection on its fragment subspace: ',fdot(orbfragloc.T, s1e, orbfrag))
    #print('frozen orbitals: ',frzfrag)

    ecorr_ccsd[ifrag],_,t1,t2 = lno_ccsd.cc_impurity_solve(
        mf,orbfrag,orbfragloc,frozen=frzfrag,eris=eris,log=log)
    print(f'ccsd correlation energy for fragment {ifrag+1}: {ecorr_ccsd[ifrag]}')

    frozen=frzfrag

    maskocc = mf.mo_occ>1e-10
    nmo = mf.mo_occ.size

    # Convert frozen to 0 bc PySCF solvers do not support frozen=None or empty list
    if frozen is None:
        frozen = 0
    elif isinstance(frozen, (list,tuple,np.ndarray)) and len(frozen) == 0:
        frozen = 0

    if isinstance(frozen, (int,np.integer)):
        maskact = np.hstack([np.zeros(frozen,dtype=bool),
                                np.ones(nmo-frozen,dtype=bool)])
    elif isinstance(frozen, (list,tuple,np.ndarray)):
        maskact = np.array([i not in frozen for i in range(nmo)])
    else:
        raise RuntimeError

    orbfrzocc = orbfrag[:,~maskact& maskocc]
    orbactocc = orbfrag[:, maskact& maskocc]
    orbactvir = orbfrag[:, maskact&~maskocc]
    orbfrzvir = orbfrag[:,~maskact&~maskocc]
    nfrzocc, nactocc, nactvir, nfrzvir = [orb.shape[1]
                                            for orb in [orbfrzocc,orbactocc,
                                                        orbactvir,orbfrzvir]]
    s1e = mf.get_ovlp() if eris is None else eris.s1e
    prjlo = fdot(orbfragloc.T, s1e, orbactocc) ### overlap between the lo and each active occ in its fragment
    #print('# lo projection on its active occupied fragment subspace', prjlo)

    lno_ccsd.prep_lno_amp_chol_file(
        mycc,orbfrag,options,
        norb_act=(nactocc+nactvir),nelec_act=nactocc*2,prjlo=prjlo,norb_frozen=frzfrag,
        t1=t1,t2=t2,mo_file="mo_test.npz",amp_file="amp_test.npz",chol_file="chol_test"
                )

    # mpi_prefix = "mpirun "
    # if nproc is not None:
    #     mpi_prefix += f"-np {nproc} "
    # script = "/home/yichi/research/software/cs_afqmc/ad_afqmc/lnocc/hchain/cc/run_lnocc_frg.py"
    # os.system(
    #     f"export OMP_NUM_THREADS=1; export MKL_NUM_THREADS=1; {mpi_prefix} python {script} |tee lno_5_ccsd_afqmc_frg_{ifrag+1}.out"
    # )
    
    ham_data, ham, prop, trial, wave_data, sampler, observable, options, _ = (
        lno_ccsd.prep_lnoccsd_afqmc(options,mo_file="mo_test.npz",
                                    amp_file="amp_test.npz",chol_file="chol_test"))

    seed = options["seed"]
    propagator = prop
    init_walkers = None
    ham_data = wavefunctions.rhf(trial.norb, trial.nelec,n_batch=trial.n_batch
                                 )._build_measurement_intermediates(ham_data, wave_data)
    ham_data = ham.build_measurement_intermediates(ham_data, trial, wave_data)
    ham_data = ham.build_propagation_intermediates(ham_data, propagator, trial, wave_data)

    prop_data = propagator.init_prop_data(trial, wave_data, ham_data, init_walkers)
    if jnp.abs(jnp.sum(prop_data["overlaps"])) < 1.0e-6:
        raise ValueError(
            "Initial overlaps are zero. Pass walkers with non-zero overlap."
        )
    prop_data["key"] = random.PRNGKey(seed + rank)

    prop_data["overlaps"] = trial.calc_overlap(prop_data["walkers"], wave_data)
    prop_data["n_killed_walkers"] = 0
    prop_data["pop_control_ene_shift"] = prop_data["e_estimate"]

    # prop_data,(hf_orb_cr,ccsd_orb_cr,olp_ratio,wt) \
    #         = lno_ccsd.propagate_phaseless_orb(ham_data,prop,prop_data,trial,wave_data,sampler)
    print('################ run equalibration ################')
    print(' energy \t   orb_cr')
    for n in range(10):
        prop_data, (energy,orb_cr,_) \
            = lno_ccsd.propagate_phaseless_orb(ham_data,prop,prop_data,trial,wave_data,sampler)
        print(f'{energy:.6f} \t {orb_cr:.6f}')

# Hostname: yichi-thinkpad
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64
no_type = ie
lo_type = pm


LOs span exactly the occupied space.
Using single-LO fragment
nfrag = 3  nlo = 3
frag_atmlist = None
frag_lolist = [[0], [1], [2]]
frag_nonvlist = None
Lov is saved to /tmp/ymo5w1jj
running fragment 1
    impsol:  1 LOs  13/30 MOs  3 occ  10 vir
    CPU time for imp sol - eri          0.00 sec, wall time      0.01 sec
    CPU time for imp sol - mp2 amp      0.00 sec, wall time      0.00 sec
    CPU time for imp sol - mp2 ene      0.02 sec, wall time      0.00 sec
    CPU time for imp sol - cc  amp      0.39 sec, wall time      0.40 sec
    CPU time for imp sol - cc  ene      0.00 sec, wall time      0.00 sec
ccsd correlation energy for fragment 1: -0.04172052265982682
# Generating Cholesky Integrals
# frozen orbitals are [13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29]
# local active orbitals are [ 0  1  2  3  4  5  6  7  8  9 10 11 12]
# local active space size 13
# loc_eris shape: (169, 169)
# chol shape: (75, 169)
# Finished calculating Cholesky integrals

# Size of the correlat

In [6]:
for n in range(10):
    prop_data, (energy,orb_cr,_) \
        = lno_ccsd.propagate_phaseless_orb(ham_data,prop,prop_data,trial,wave_data,sampler)
    print(f'{energy:.6f} \t {orb_cr:.6f}')

-3.329485 	 -0.041841
-3.329980 	 -0.042113
-3.329293 	 -0.041641
-3.329412 	 -0.041880
-3.328884 	 -0.041367
-3.329615 	 -0.041967
-3.328886 	 -0.041267
-3.329381 	 -0.041889
-3.329848 	 -0.041779
-3.328972 	 -0.041936


In [14]:
def _frg_zero_cr(walker,trial,wave_data,h0_E0):
    '''
    zeroth order correction
    h0-E0 term of the fragment energy
    '''
    mod_olp = lno_ccsd._frg_modified_ccsd_olp_restricted(walker,wave_data)
    ccsd_olp = trial._calc_overlap_restricted(walker, wave_data)
    zero_cr = jnp.real(h0_E0*mod_olp/ccsd_olp)
    return zero_cr

def frg_zero_cr(walkers,trial,wave_data,h0_E0) -> jax.Array:

    zero_cr = vmap(_frg_zero_cr, in_axes=(0, None, None, None))(
        walkers, trial, wave_data, h0_E0)

    return zero_cr


In [15]:
_frg_zero_cr(prop_data["walkers"][0],trial,wave_data,8.35842125688341)

Array(0., dtype=float64)

In [16]:
frg_zero_cr(prop_data["walkers"],trial,wave_data,8.35842125688341)

Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float64)

In [24]:
print(sum(prop_data["weights"]))

49.921928532819386


In [30]:
ecorr_ccsd[0]

-0.04172052266340458

In [22]:
print(hf_orb_cr,ccsd_orb_cr,olp_ratio,wt)

-0.03238653807853634 -0.1490909363194065 0.9620427586098549 999.5256693501076


In [6]:
(mf.energy_nuc()-mf.e_tot)

8.35842125688341

In [None]:
(mf.energy_nuc()-mf.e_tot)*(1-olp_ratio)/3+hf_orb_cr+ccsd_orb_cr

Array(-0.07572327, dtype=float64)

In [32]:
def walker_frg_std(prop_data,ham_data,wave_data,trial):
    ccsd_cr_sp = lno_ccsd.frg_ccsd_orb_cr(
        prop_data["walkers"],ham_data,wave_data,trial,1e-5)
    wt_sp = prop_data["weights"]
    ccsd_cr = jnp.sum(ccsd_cr_sp*wt_sp)/jnp.sum(wt_sp)
    ccsd_cr_std = jnp.sqrt(jnp.sum((ccsd_cr_sp-ccsd_cr)**2
                                   *wt_sp)/jnp.sum(wt_sp))
    return ccsd_cr, ccsd_cr_std

In [136]:
hf_elec = mf.energy_elec()[0]
print('hf orbital correction: ',hf_orb_cr_frag)
print('ccsd orbital correction: ',ccsd_orb_cr_frag)
print('overlap frag: ',olp_ratio_frag)
hf_cr = jnp.sum(hf_orb_cr_frag)
print('total hf correction: ',hf_cr)
ccsd_cr = jnp.sum(ccsd_orb_cr_frag)
print('total ccsd correction: ',ccsd_cr)
olp_ratio = jnp.prod(olp_ratio_frag)**(1/nfrag)
print('total overlap ration: ',olp_ratio)
e_corr = jnp.real(-hf_elec*(1-olp_ratio)+hf_cr+ccsd_cr)
print('total correction',e_corr)

hf orbital correction:  [-0.00996635+5.34090730e-05j -0.00863478-2.51830908e-05j
 -0.00855371-9.40076385e-05j]
ccsd orbital correction:  [-0.05152426-0.00059426j -0.05398796-0.00223225j -0.05399157-0.00228491j]
overlap frag:  [0.98815963-1.14789992e-05j 0.98926828-7.28736397e-05j
 0.98976098-3.45241774e-04j]
total hf correction:  (-0.02715483933804915-6.578165636273012e-05j)
total ccsd correction:  (-0.15950379322187802-0.005111421670194715j)
total overlap ration:  (0.9890627465294073-0.00014311540007165075j)
total correction -0.09524046065940323


In [None]:
# H6  ccpvdz
# 100 walkers
# propagated 10*dt steps
### thresh ### sum(ccsd_orb_cr)  ### std_mean ###
### 1e-3   ### -0.1106004004972  +/- 0.001497 ###
### 5e-4   ### -0.1441026474524  +/- 0.001430 ###
### 1e-4   ### -0.1945364116683  +/- 0.002883 ###
### 5e-5   ### -0.1980519867077  +/- 0.002686 ###
### 1e-5   ### -0.1967053505585  +/- 0.002390 ###
### 5e-6   ### -0.1959884369450  +/- 0.002299 ###
### 1e-6   ### -0.2044927850568  +/- 0.002934 ###
### 1e-7   ### -0.2010146842649  +/- 0.002461 ### (all active)
#################################################
###       full afqmc ccsd orb correction      ###
###      -0.1970800087688 +/- 0.00264436      ###
#################################################

In [None]:
# H6  ccpvdz
# 100 walkers 
# propagated 10 steps
### thresh ### sum(ccsd_orb_cr)  ### std_mean ###
### 5e-3   ### -0.0586291158568  +/- 0.000174 ###
### 1e-3   ### -0.1145982676340  +/- 0.002202 ###
### 5e-4   ### -0.1445928123278  +/- 0.002011 ###
### 1e-4   ### -0.1932868509642  +/- 0.003631 ###
### 5e-5   ### -0.1903748674934  +/- 0.003383 ###
### 1e-5   ### -0.2030035185779  +/- 0.003750 ###
### 5e-6   ### -0.1986708389975  +/- 0.003022 ###
### 1e-6   ### -0.1925713846568  +/- 0.003103 ###
#################################################
###       full afqmc ccsd orb correction      ###
###       -0.19935798681758174  +/-  0.00347  ###
#################################################

In [None]:
# 100 walkers 
# propagated 10 steps
### thresh ### sum(ccsd_orb_cr)  ### std_mean ###
### 5e-3   ### -0.0676485716177  +/- 0.000149 ###
### 1e-3   ### -0.0843427544721  +/- 0.001266 ###
### 5e-4   ### -0.1127718966548  +/- 0.002356 ###
### 1e-4   ### -0.1140461406550  +/- 0.002385 ###
### 5e-5   ### -0.1316369973965  +/- 0.003507 ###
### 1e-5   ### -0.1316370067256  +/- 0.003507 ###
#################################################
###        full afqmc ccsd correlation        ###
###       -0.1360351259078871  +/-  0.005016  ###
#################################################