In [1]:
from pyscf import gto, scf, mp, cc
from pyscf import lo
import numpy as np

# Alternative model system kept for reference: a chain of `nc*na` atoms built
# from monomers of `na` atoms (bond length `a` inside a monomer, spacing `d`
# between monomers).
# a = 2 # bond length in a cluster
# d = 3 # distance between each cluster
# unit = 'b' # unit of length
# na = 3  # size of a cluster (monomer)
# nc = 2 # set as integer multiple of monomers
# spin = 1 # spin per monomer
# frozen = 0 # frozen orbital per monomer
# elmt = 'H'
# basis = 'sto6g'
# atoms = ""
# for n in range(nc*na):
#     shift = ((n - n % na) // na) * (d-a)
#     atoms += f"{elmt} {n*a+shift:.5f} 0.00000 0.00000 \n"

# Water dimer geometry (default unit, printed as angstrom in the log) and basis.
atoms = '''
O  -1.551007  -0.114520   0.000000
H  -1.934259   0.762503   0.000000
H  -0.599677   0.040712   0.000000
O   1.350625   0.111469   0.000000
H   1.680398  -0.373741  -0.758561
H   1.680398  -0.373741   0.758561
'''
basis = 'ccpvdz'

# Density-fitted UHF reference.
mol = gto.M(
    atom=atoms,
    basis=basis,
    spin=0,
    verbose=4,
    max_memory=16000,
)
mf = scf.UHF(mol).density_fit()
mf.kernel()

# Number of lowest MOs per spin excluded from the localization below.
nfrozen = 2

# Occupied MO columns nfrozen .. nocc-1 for each spin, then Pipek-Mezey localized.
orbocca, orboccb = (
    mo[:, nfrozen:np.count_nonzero(occ)]
    for mo, occ in zip(mf.mo_coeff, mf.mo_occ)
)
orbloca, orblocb = (
    lo.PipekMezey(mol, orbocc).kernel() for orbocc in (orbocca, orboccb)
)
lo_coeff = [orbloca, orblocb]

# One fragment per localized orbital: alpha-only fragments, then beta-only ones.
oa = [[[i], []] for i in range(orbloca.shape[1])]
ob = [[[], [i]] for i in range(orblocb.shape[1])]
frag_lolist = [*oa, *ob]

System: uname_result(system='Linux', node='sharmagroup-rn', release='6.14.0-37-generic', version='#37~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Nov 20 10:25:38 UTC 2', machine='x86_64')  Threads 16
Python 3.11.14 (main, Oct 21 2025, 18:31:21) [GCC 11.2.0]
numpy 2.3.1  scipy 1.16.2  h5py 3.14.0
Date: Thu Dec 11 15:09:58 2025
PySCF version 2.11.0
PySCF path  /home/sharmagroup/sharmagroup/pyscf
GIT HEAD (branch master) 3d1768f5e33b144b606c3d2c81c12ee54d794501

[ENV] PYSCF_EXT_PATH /home/sharmagroup/sharmagroup/pyscf-forge
[CONFIG] conf_file None
[INPUT] verbose = 4
[INPUT] num. atoms = 6
[INPUT] num. electrons = 20
[INPUT] charge = 0
[INPUT] spin (= nelec alpha-beta = 2S) = 0
[INPUT] symmetry False subgroup None
[INPUT] Mole.unit = angstrom
[INPUT] Symbol           X                Y                Z      unit          X                Y                Z       unit  Magmom
[INPUT]  1 O     -1.551007000000  -0.114520000000   0.000000000000 AA   -2.930978447283  -0.216411435785   0.00000000000

In [2]:
from collections.abc import Iterable
from pyscf.lno import ulnoccsd
from ad_afqmc.lno_afqmc import ulno_afqmc

# LNO settings for the fragment(s) processed below.
thresh = 1e-5
run_frg_list = [0]  # indices into frag_lolist of the fragments to process
chol_cut = 1e-5

mlno = ulnoccsd.ULNOCCSD_T(mf, lo_coeff, frag_lolist, frozen=nfrozen).set(verbose=4)
mlno.lno_thresh = [thresh*10, thresh]
lno_thresh = mlno.lno_thresh
lno_type = ['1h', '1h']
lno_thresh = [1e-5, 1e-6] if lno_thresh is None else lno_thresh
print(lno_thresh)
lno_pct_occ = None
lno_norb = None
lo_proj_thresh = 1e-10
lo_proj_thresh_active = 0.1
eris = None

# Select the fragments into a NEW list: filtering `frag_lolist` in place made
# a re-run of this cell index an already-filtered list (and rebuild `mlno`
# on it).
if run_frg_list is not None:
    frag_lolist_run = [frag_lolist[i] for i in run_frg_list]
else:
    frag_lolist_run = frag_lolist

nfrag = len(frag_lolist_run)
if lno_pct_occ is None:
    lno_pct_occ = [None, None]
if lno_norb is None:
    # One independent [None, None] per fragment; `[[None, None]] * nfrag`
    # would make every fragment share (and co-mutate) the same inner list.
    lno_norb = [[None, None] for _ in range(nfrag)]
mf = mlno._scf

if eris is None:
    eris = mlno.ao2mo()

# Build the LNO active space of each selected fragment.
# NOTE: only the last fragment's lno_coeff / frozen / uocc_loc survive this
# loop; the cells below assume a single fragment in run_frg_list.
for ifrag, loidx in enumerate(frag_lolist_run):
    if len(loidx) == 2 and isinstance(loidx[0], Iterable):  # Unrestricted
        orbloc = [lo_coeff[0][:, loidx[0]], lo_coeff[1][:, loidx[1]]]
        # lno_param[s][i]: s indexes the spin, i indexes the two LNO
        # parameter slots (occ/vir per lno_type); scalar settings are shared
        # across spins.
        lno_param = [
            [
                {
                    'thresh': (
                        lno_thresh[i][s] if isinstance(lno_thresh[i], Iterable)
                        else lno_thresh[i]
                    ),
                    'pct_occ': (
                        lno_pct_occ[i][s] if isinstance(lno_pct_occ[i], Iterable)
                        else lno_pct_occ[i]
                    ),
                    'norb': (
                        lno_norb[ifrag][i][s] if isinstance(lno_norb[ifrag][i], Iterable)
                        else lno_norb[ifrag][i]
                    ),
                } for i in [0, 1]
            ] for s in range(2)
        ]

    else:
        # Restricted-style fragment (flat list of LO indices). This needs a
        # single LO coefficient matrix; a per-spin list cannot be indexed as
        # lo_coeff[:, idx], so fail with a clear message instead of TypeError
        # from list indexing.
        if isinstance(lo_coeff, (list, tuple)):
            raise TypeError(
                f"fragment {ifrag} is not of the form [[alpha LOs], [beta LOs]], "
                "but lo_coeff is a per-spin list"
            )
        orbloc = lo_coeff[:, loidx]
        lno_param = [{'thresh': lno_thresh[i], 'pct_occ': lno_pct_occ[i],
                      'norb': lno_norb[ifrag][i]} for i in [0, 1]]

    lno_coeff, frozen, uocc_loc, frag_msg = mlno.make_las(eris, orbloc, lno_type, lno_param)
    # TODO: solve the impurity per fragment, e.g.
    # frag_res[ifrag], frag_msg = mlno.impurity_solve(mf, mo_coeff, uocc_loc, eris, frozen=frozen)


[0.0001, 1e-05]
LO occ proj: 1 active | 0 standby | 7 orthogonal
LO occ proj: 0 active | 0 standby | 8 orthogonal


In [3]:
# Frozen MO indices per spin (alpha, beta) returned by mlno.make_las for the
# last fragment processed in the LNO set-up cell.
frozen

[array([ 0,  1,  2,  3,  4,  5, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37,
        38, 39, 40, 41, 42, 43, 44, 45, 46, 47]),
 array([ 0,  1,  2,  3,  4,  5, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37,
        38, 39, 40, 41, 42, 43, 44, 45, 46, 47])]

In [4]:
mol = mf.mol


def _spin_active_space(frozen_s, mo_occ_s):
    """Split one spin's MOs into frozen/active and occupied/virtual index sets.

    Args:
        frozen_s: frozen MO indices for this spin (as returned by make_las).
        mo_occ_s: MO occupation array for this spin.

    Returns:
        (nocc, actfrag, frzocc, actocc, actvir): the number of occupied MOs,
        then sorted index arrays of active MOs, frozen occupied MOs, active
        occupied MOs and active virtual MOs.

    The MO range is ``mo_occ_s.size`` (number of MOs), not ``mol.nao``: the two
    differ when linearly dependent AO combinations are removed, in which case
    iterating over ``nao`` would produce indices beyond the MO space.
    """
    nmo = mo_occ_s.size
    nocc = int(sum(mo_occ_s))
    all_mo = np.arange(nmo)
    actfrag = np.setdiff1d(all_mo, frozen_s)
    frzocc = np.intersect1d(all_mo[:nocc], frozen_s)
    actocc = actfrag[actfrag < nocc]
    actvir = actfrag[actfrag >= nocc]
    return nocc, actfrag, frzocc, actocc, actvir


nocc_a, actfrag_a, frzocc_a, actocc_a, actvir_a = _spin_active_space(frozen[0], mf.mo_occ[0])
nfrzocc_a = len(frzocc_a)
nactocc_a = len(actocc_a)
nactvir_a = len(actvir_a)
nactorb_a = len(actfrag_a)

nocc_b, actfrag_b, frzocc_b, actocc_b, actvir_b = _spin_active_space(frozen[1], mf.mo_occ[1])
nfrzocc_b = len(frzocc_b)
nactocc_b = len(actocc_b)
nactvir_b = len(actvir_b)
nactorb_b = len(actfrag_b)

# (alpha, beta) active orbitals, frozen core and active electrons.
ncas = (nactorb_a, nactorb_b)
ncore = (nfrzocc_a, nfrzocc_b)
nelec = (nactocc_a, nactocc_b)
print(ncas)
print(ncore)
print(nelec)

(21, 21)
(6, 6)
(4, 4)


In [5]:
# Impurity UCCSD(T) on the LNO active space of the selected fragment.
mo_occ = mlno.mo_occ
mlno.verbose_imp = 4
# NOTE(review): this overwrites `frozen` with get_maskact's result, so a
# re-run passes the already-converted value back in.
frozen, maskact = ulnoccsd.get_maskact(frozen, [occ.size for occ in mo_occ])

mcc = ulnoccsd.UCCSD(mf, mo_coeff=lno_coeff, frozen=frozen).set(
    verbose=mlno.verbose_imp
)
# Reuse the parent LNO object's _s1e / _h1e / _vhf on the impurity solver.
for _attr in ("_s1e", "_h1e", "_vhf"):
    setattr(mcc, _attr, getattr(mlno, _attr))
if mlno.kwargs_imp is not None:
    mcc = mcc.set(**mlno.kwargs_imp)

(emp2, eccsd, ept), t1, t2, prjlo = ulno_afqmc.ulno_ccsd(
    mcc, lno_coeff, uocc_loc, mo_occ, maskact, ccsd_t=True
)
print(emp2, eccsd, ept)


WARN: CCSD detected DF being used in the HF object. MO integrals are computed based on the DF 3-index tensors.
It's recommended to use dfccsd.CCSD for the DF-CCSD calculations

Init t2, MP2 energy = -0.201885564166394

******** <class 'pyscf.lno.ulnoccsd.MODIFIED_UCCSD'> ********
CC2 = 0
CCSD nocc = (np.int64(4), np.int64(4)), nmo = (21, 21)
frozen orbitals [array([ 0,  1,  2,  3,  4,  5, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37,
       38, 39, 40, 41, 42, 43, 44, 45, 46, 47]), array([ 0,  1,  2,  3,  4,  5, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37,
       38, 39, 40, 41, 42, 43, 44, 45, 46, 47])]
max_cycle = 50
direct = 0
conv_tol = 1e-07
conv_tol_normt = 1e-06
diis_space = 6
diis_start_cycle = 0
diis_start_energy_diff = 1e+09
max_memory 16000 MB (current use 313 MB)
Init E_corr(MODIFIED_UCCSD) = -0.201885564166394
cycle = 1  E_corr(MODIFIED_UCCSD) = -0.206558458468376  dE = -0.0046728943  norm(t1,t2) = 0.0318274
cycle = 2  E_corr(MODIFIED_UCCSD) = -0.209732406374202  dE = -0.003

In [6]:
# AFQMC run options passed to ulno_afqmc.prep_afqmc below; the same values
# appear in the options echoed by ulno_afqmc._prep_afqmc in the next session.
options = {'n_eql': 5,
        'n_prop_steps': 50,
        'n_ene_blocks': 1,
        'n_sr_blocks': 10,
        'n_blocks': 10,
        'n_walkers': 1,
        'seed': 98,
        'walker_type': 'uhf',
        'trial': 'uccsd_pt_ad',
        'dt':0.005,
        'free_projection':False,
        'ad_mode':None,
        'use_gpu': False,
        'max_error': 1e-4
        }

# NOTE(review): MPI is finalized in this kernel *before* prep_afqmc runs;
# presumably so a separate MPI-launched AFQMC run is not tied to this
# process — confirm the intent. MPI cannot be re-initialized afterwards in
# this kernel.
from mpi4py import MPI
if not MPI.Is_finalized():
    MPI.Finalize()

# Prepare the AFQMC inputs for the LNO fragment (the log below reports the
# one-electron and Cholesky integrals being built). This overwrites the
# `nelec` and `ncas` values from the active-space cell.
nelec, ncas = ulno_afqmc.prep_afqmc(
    mf,lno_coeff,t1,t2,frozen,prjlo,
    options,chol_cut=chol_cut)

print(nelec, ncas)

# Calculating Effective Active Space One-electron Integrals
# Generating Cholesky Integrals
# Composing AO ERIs from DF basis
# Composing active space MO ERIs from AO ERIs
# Finished calculating Cholesky integrals
# Size of the correlation space
# Number of electrons: (4, 4)
# Number of basis functions: (21, 21)
# Alpha Basis Cholesky shape: (232, 21, 21)
#  Beta Basis Cholesky shape: (232, 21, 21)
(4, 4) (21, 21)


In [1]:
# Configure JAX first (before importing the LNO-AFQMC module), then MPI.
from ad_afqmc import config
config.setup_jax()
MPI = config.setup_comm()
comm = MPI.COMM_WORLD
size, rank = comm.Get_size(), comm.Get_rank()

from ad_afqmc.lno_afqmc import ulno_afqmc

# Rebuild Hamiltonian, propagator, trial and sampler objects; presumably
# from what prep_afqmc produced in the previous session — TODO confirm.
ham_data, prop, trial, wave_data, sampler, options, _ = ulno_afqmc._prep_afqmc()

# Hostname: sharmagroup-rn
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64
# Hostname: sharmagroup-rn
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64
# Number of MPI ranks: 1
#
# norb: (21, 21)
# nelec: (4, 4)
#
# n_eql: 5
# n_prop_steps: 50
# n_ene_blocks: 1
# n_sr_blocks: 10
# n_blocks: 10
# n_walkers: 1
# seed: 98
# walker_type: uhf
# trial: uccsd_pt_ad
# dt: 0.005
# free_projection: False
# use_gpu: False
# max_error: 0.0001
# n_exp_terms: 6
# orbital_rotation: True
# do_sr: True
# symmetry: False
# save_walkers: False
# ene0: 0.0
# n_batch: 1
#


In [2]:
# Initialize propagation data for a single AFQMC run (same steps, same order,
# as the propagation driver: build intermediates, init walkers, check overlap,
# seed the RNG).
import time
from jax import numpy as jnp
from jax import random
init_time = time.time()  # not used by the cells shown here
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

### initialize propagation
seed = options["seed"]
init_walkers = None  # let init_prop_data choose the initial walkers
# dm_up = jnp.array(wave_data["mo_coeff"][0] @ wave_data["mo_coeff"][0].T.conj())
# dm_dn = jnp.array(wave_data["mo_coeff"][1] @ wave_data["mo_coeff"][1].T.conj())
# trial_rdm1 = [dm_up, dm_dn]
trial_rdm1 = trial.get_rdm1(wave_data)
if "rdm1" not in wave_data:
    wave_data["rdm1"] = trial_rdm1
# NOTE(review): ham_data is rebound to the augmented result; re-running this
# cell feeds the already-augmented dict back in — confirm that is harmless.
ham_data = trial._build_measurement_intermediates(ham_data, wave_data)
ham_data = prop._build_propagation_intermediates(ham_data, trial, wave_data)

prop_data = prop.init_prop_data(trial, wave_data, ham_data, init_walkers)
if jnp.abs(jnp.sum(prop_data["overlaps"])) < 1.0e-6:
    raise ValueError(
        "Initial overlaps are zero. Pass walkers with non-zero overlap."
    )
# Per-rank PRNG key so MPI ranks draw different random streams.
prop_data["key"] = random.PRNGKey(seed + rank)

# Overlaps are recomputed with the trial after the zero-overlap check.
prop_data["overlaps"] = trial.calc_overlap(prop_data["walkers"], wave_data)
prop_data["n_killed_walkers"] = 0
prop_data["pop_control_ene_shift"] = prop_data["e_estimate"]
print(prop_data["e_estimate"])
# NOTE(review): the two results below look stale, from a different system
# (total energy about -3.085 Eh, versus about -152.06 Eh printed for this cell).
# E(MODIFIED_UCCSD) = -3.085067447898267  E_corr = -0.01684699454462395
# LNO-UCCSD (T1 = 0) fragment energy: -0.008283544474655563


-152.06249064690837


In [3]:
# Evaluate the trial's perturbative energy pieces on the first walker of
# each spin channel.
walker_up = prop_data['walkers'][0][0]
walker_dn = prop_data['walkers'][1][0]
eorb, teorb, torb, ecorr = trial._calc_eorb_pt(
    walker_up, walker_dn, ham_data, wave_data
)
print(eorb, teorb, torb, ecorr)

0.0 -0.023489784995926567 0.0 -152.06249064690837


In [20]:
import numpy as np

# Replace the trial's CC amplitudes with the ones stored in amplitudes.npz.
# The NpzFile keeps the archive open until closed, so read it inside `with`;
# jnp.array copies each array, so the data outlives the file handle.
amp_file = "./amplitudes.npz"
with np.load(amp_file) as amplitudes:
    t1a = jnp.array(amplitudes["t1a"])
    t1b = jnp.array(amplitudes["t1b"])
    t2aa = jnp.array(amplitudes["t2aa"])
    t2ab = jnp.array(amplitudes["t2ab"])
    t2bb = jnp.array(amplitudes["t2bb"])
# beta-alpha doubles: swap the (i, a) and (j, b) index pairs of t2ab.
t2ba = t2ab.transpose(2, 3, 0, 1)
wave_data["t1a"], wave_data["t1b"] = t1a, t1b
wave_data["t2aa"], wave_data["t2bb"] = t2aa, t2bb
wave_data["t2ab"], wave_data["t2ba"] = t2ab, t2ba

In [None]:
# opt_einsum provides `oe.contract(..., backend="jax")`, used by the energy
# functions below.
import opt_einsum as oe
# Disabled experiment: contract the first (occupied) index of every amplitude
# block with the per-spin wave_data['prjlo'] matrices and store the results as
# "*_prj" entries in wave_data.
# prja, prjb = wave_data['prjlo']
# wave_data["t1a"], wave_data["t1b"] = t1a, t1b
# wave_data["t2aa"], wave_data["t2bb"] = t2aa, t2bb
# wave_data["t2ab"], wave_data["t2ba"] = t2ab, t2ba
# t1a_prj = oe.contract('ia,ik->ka',t1a,prja,backend='jax')
# t1b_prj = oe.contract('ia,ik->ka',t1b,prjb,backend='jax')
# t2aa_prj = oe.contract('iajb,ik->kajb',t2aa,prja,backend='jax')
# t2ab_prj = oe.contract('iajb,ik->kajb',t2ab,prja,backend='jax')
# t2ba_prj = oe.contract('iajb,ik->kajb',t2ba,prjb,backend='jax')
# t2bb_prj = oe.contract('iajb,ik->kajb',t2bb,prjb,backend='jax')
# wave_data["t1a_prj"], wave_data["t1b_prj"] = t1a_prj, t1b_prj
# wave_data["t2aa_prj"], wave_data["t2ab_prj"] = t2aa_prj, t2ab_prj
# wave_data["t2ba_prj"], wave_data["t2bb_prj"] = t2ba_prj, t2bb_prj

In [26]:
import jax
from jax import vmap, lax
def _calc_energy_pt_ref(
    trial,
    walker_up: jax.Array,
    walker_dn: jax.Array,
    ham_data: dict,
    wave_data: dict,
) -> tuple[jax.Array, jax.Array, jax.Array]:
    """Reference evaluation of the UCCSD perturbative-energy pieces for one walker.

    The opposite-spin doubles are handled with ``t2ab`` alone (no separate
    ``t2ba`` block).

    Args:
        trial: object exposing ``norb`` and ``nelec`` as (alpha, beta) pairs.
        walker_up, walker_dn: walker matrices per spin; their first ``nocc``
            rows are inverted, so presumably shaped (norb, nocc) — TODO confirm.
        ham_data: needs "h0", "h1" (alpha, beta), "chol" (alpha, beta; each
            reshaped to (-1, norb, norb)), "lt1_a" and "lt1_b".
        wave_data: needs "t1a", "t1b", "t2aa", "t2bb", "t2ab".

    Returns:
        Real parts of (t, te, e0):
            t  = <HF|T1+T2|walker>/<HF|walker>
            te = <HF|(T1+T2)(h1+h2)|walker>/<HF|walker>
            e0 = h0 + <HF|h1+h2|walker>/<HF|walker>
    """
    norb_a, norb_b = trial.norb
    nocc_a, nocc_b = trial.nelec
    t1_a, t2_aa = wave_data["t1a"], wave_data["t2aa"]
    t1_b, t2_bb = wave_data["t1b"], wave_data["t2bb"]
    t2_ab = wave_data["t2ab"]
    # Overlap-normalized Green's function per spin, G = (W W[:nocc]^-1)^T;
    # green_occ_* keeps its columns from nocc onward (G_ia).
    green_a = (walker_up.dot(jnp.linalg.inv(walker_up[:nocc_a, :]))).T
    green_b = (walker_dn.dot(jnp.linalg.inv(walker_dn[:nocc_b, :]))).T
    green_occ_a = green_a[:, nocc_a:].copy()
    green_occ_b = green_b[:, nocc_b:].copy()
    # G_ia stacked on -1 over the virtual block: shape (norb, norb - nocc).
    greenp_a = jnp.vstack((green_occ_a, -jnp.eye(norb_a - nocc_a)))
    greenp_b = jnp.vstack((green_occ_b, -jnp.eye(norb_b - nocc_b)))

    chol_a = ham_data["chol"][0].reshape(-1, norb_a, norb_a)
    chol_b = ham_data["chol"][1].reshape(-1, norb_b, norb_b)
    # Cholesky vectors restricted to occupied rows (HF bra side).
    rot_chol_a = chol_a[:, :nocc_a, :]
    rot_chol_b = chol_b[:, :nocc_b, :]
    h1_a = ham_data["h1"][0]
    h1_b = ham_data["h1"][1]
    hg_a = oe.contract("pj,pj->", h1_a[:nocc_a, :], green_a, backend="jax")
    hg_b = oe.contract("pj,pj->", h1_b[:nocc_b, :], green_b, backend="jax")
    hg = hg_a + hg_b

    # 0 body energy
    h0 = ham_data["h0"]

    # 1 body energy
    # ref
    e1_0 = hg # <HF|h1|walker>/<HF|walker>

    # single excitations
    t1g_a = oe.contract("ia,ia->", t1_a, green_occ_a, backend="jax")
    t1g_b = oe.contract("ia,ia->", t1_b, green_occ_b, backend="jax")
    t1g = t1g_a + t1g_b
    e1_1_1 = t1g * hg
    gpt1_a = greenp_a @ t1_a.T
    gpt1_b = greenp_b @ t1_b.T
    t1_green_a = gpt1_a @ green_a
    t1_green_b = gpt1_b @ green_b
    e1_1_2 = -(
        oe.contract("pq,pq->", h1_a, t1_green_a, backend="jax")
        + oe.contract("pq,pq->", h1_b, t1_green_b, backend="jax")
    )
    e1_1 = e1_1_1 + e1_1_2 # <HF|T1 h1|walker>/<HF|walker>

    # double excitations
    # Same-spin halves carry a 1/4 factor; t2g_ab_* contract t2_ab with the
    # other spin's G (no extra factor).
    t2g_a = oe.contract("ptqu,pt->qu", t2_aa, green_occ_a, backend="jax") / 4
    t2g_b = oe.contract("ptqu,pt->qu", t2_bb, green_occ_b, backend="jax") / 4
    t2g_ab_a = oe.contract("ptqu,qu->pt", t2_ab, green_occ_b, backend="jax")
    t2g_ab_b = oe.contract("ptqu,pt->qu", t2_ab, green_occ_a, backend="jax")
    gt2g_a = oe.contract("qu,qu->", t2g_a, green_occ_a, backend="jax")
    gt2g_b = oe.contract("qu,qu->", t2g_b, green_occ_b, backend="jax")
    gt2g_ab = oe.contract("pt,pt->", t2g_ab_a, green_occ_a, backend="jax")
    gt2g = 2 * (gt2g_a + gt2g_b) + gt2g_ab
    e1_2_1 = hg * gt2g
    t2_green_a = (greenp_a @ t2g_a.T) @ green_a
    t2_green_ab_a = (greenp_a @ t2g_ab_a.T) @ green_a
    t2_green_b = (greenp_b @ t2g_b.T) @ green_b
    t2_green_ab_b = (greenp_b @ t2g_ab_b.T) @ green_b
    e1_2_2_a = -oe.contract(
        "ij,ij->", h1_a, 4 * t2_green_a + t2_green_ab_a, backend="jax"
    )
    e1_2_2_b = -oe.contract(
        "ij,ij->", h1_b, 4 * t2_green_b + t2_green_ab_b, backend="jax"
    )
    e1_2_2 = e1_2_2_a + e1_2_2_b
    e1_2 = e1_2_1 + e1_2_2 # <HF|T2 h1|walker>/<HF|walker>

    # two body energy
    # ref
    # Coulomb-like term from the traces of L G, exchange-like term from the
    # sum over g of (L G)_pq (L G)_qp.
    lg_a = oe.contract("gpj,pj->g", rot_chol_a, green_a, backend="jax")
    lg_b = oe.contract("gpj,pj->g", rot_chol_b, green_b, backend="jax")
    e2_0_1 = ((lg_a + lg_b) @ (lg_a + lg_b)) / 2.0
    lg1_a = oe.contract("gpj,qj->gpq", rot_chol_a, green_a, backend="jax")
    lg1_b = oe.contract("gpj,qj->gpq", rot_chol_b, green_b, backend="jax")
    e2_0_2 = (
        -(
            jnp.sum(vmap(lambda x: x * x.T)(lg1_a))
            + jnp.sum(vmap(lambda x: x * x.T)(lg1_b))
        )
        / 2.0
    )
    e2_0 = e2_0_1 + e2_0_2 # <HF|h2|walker>/<HF|walker>

    # single excitations
    e2_1_1 = e2_0 * t1g
    lt1g_a = oe.contract("gij,ij->g", chol_a, t1_green_a, backend="jax")
    lt1g_b = oe.contract("gij,ij->g", chol_b, t1_green_b, backend="jax")
    e2_1_2 = -((lt1g_a + lt1g_b) @ (lg_a + lg_b))
    t1g1_a = t1_a @ green_occ_a.T
    t1g1_b = t1_b @ green_occ_b.T
    e2_1_3_1 = oe.contract(
        "gpq,gqr,rp->", lg1_a, lg1_a, t1g1_a, backend="jax"
    ) + oe.contract("gpq,gqr,rp->", lg1_b, lg1_b, t1g1_b, backend="jax")
    # lt1g_* is rebound here: now built from the precomputed ham_data["lt1_*"].
    lt1g_a = oe.contract(
        "gip,qi->gpq", ham_data["lt1_a"], green_a, backend="jax"
    )
    lt1g_b = oe.contract(
        "gip,qi->gpq", ham_data["lt1_b"], green_b, backend="jax"
    )
    e2_1_3_2 = -oe.contract(
        "gpq,gqp->", lt1g_a, lg1_a, backend="jax"
    ) - oe.contract("gpq,gqp->", lt1g_b, lg1_b, backend="jax")
    e2_1_3 = e2_1_3_1 + e2_1_3_2
    e2_1 = e2_1_1 + e2_1_2 + e2_1_3 # <HF|T1 h2|walker>/<HF|walker>

    # double excitations
    e2_2_1 = e2_0 * gt2g
    lt2g_a = oe.contract(
        "gij,ij->g",
        chol_a,
        8 * t2_green_a + 2 * t2_green_ab_a,
        backend="jax",
    )
    lt2g_b = oe.contract(
        "gij,ij->g",
        chol_b,
        8 * t2_green_b + 2 * t2_green_ab_b,
        backend="jax",
    )
    e2_2_2_1 = -((lt2g_a + lt2g_b) @ (lg_a + lg_b)) / 2.0

    # Loop over Cholesky vectors one at a time (lax.scan) to avoid
    # materializing per-vector intermediates for all g at once.
    # carry = [running e2_2_2_2, running e2_2_3]
    def scanned_fun(carry, x):
        chol_a_i, rot_chol_a_i, chol_b_i, rot_chol_b_i = x
        gl_a_i = oe.contract("ir,pr->ip", green_a, chol_a_i, backend="jax")
        gl_b_i = oe.contract("ir,pr->ip", green_b, chol_b_i, backend="jax")
        lt2_green_a_i = oe.contract(
            "pi,ji->pj",
            rot_chol_a_i,
            8 * t2_green_a + 2 * t2_green_ab_a,
            backend="jax",
        )
        lt2_green_b_i = oe.contract(
            "pi,ji->pj",
            rot_chol_b_i,
            8 * t2_green_b + 2 * t2_green_ab_b,
            backend="jax",
        )
        carry[0] += 0.5 * (
            oe.contract("pi,pi->", gl_a_i, lt2_green_a_i, backend="jax")
            + oe.contract("pi,pi->", gl_b_i, lt2_green_b_i, backend="jax")
        )
        glgp_a_i = oe.contract(
            "pi,it->pt", gl_a_i, greenp_a, backend="jax"
        )
        glgp_b_i = oe.contract(
            "pi,it->pt", gl_b_i, greenp_b, backend="jax"
        )
        l2t2_a = 0.5 * oe.contract(
            "pt,qu,ptqu->",
            glgp_a_i,
            glgp_a_i,
            t2_aa,
            backend="jax",
        )
        l2t2_b = 0.5 * oe.contract(
            "pt,qu,ptqu->",
            glgp_b_i,
            glgp_b_i,
            t2_bb,
            backend="jax",
        )
        l2t2_ab = oe.contract(
            "pt,qu,ptqu->",
            glgp_a_i,
            glgp_b_i,
            t2_ab,
            backend="jax",
        )
        carry[1] += l2t2_a + l2t2_b + l2t2_ab
        return carry, 0.0

    [e2_2_2_2, e2_2_3], _ = lax.scan(
        scanned_fun, [0.0, 0.0], (chol_a, rot_chol_a, chol_b, rot_chol_b)
    )
    e2_2_2 = e2_2_2_1 + e2_2_2_2
    e2_2 = e2_2_1 + e2_2_2 + e2_2_3 # <HF|T2 h2|walker>/<HF|walker>

    t = t1g + gt2g # <HF|T1+T2|walker>/<HF|walker>
    te = e1_1 + e1_2 + e2_1 + e2_2 # <HF|(T1+T2)(h1+h2)|walker>/<HF|walker>
    e0 = h0 + e1_0 + e2_0 # h0 + <HF|h1+h2|walker>/<HF|walker>
    # return jnp.real(t), jnp.real(e0), jnp.real(e1_1), jnp.real(e1_2), jnp.real(e2_1), jnp.real(e2_2), jnp.real(e2_2_2_1), jnp.real(e2_2_2_2), jnp.real(e2_2_3)
    return jnp.real(t), jnp.real(te), jnp.real(e0)

In [30]:
import jax
from jax import lax, jvp, vmap
import opt_einsum as oe

def _calc_energy_pt(
    trial,
    walker_up: jax.Array,
    walker_dn: jax.Array,
    ham_data: dict,
    wave_data: dict,
) -> tuple[jax.Array, jax.Array, jax.Array]:
    """UCCSD perturbative-energy pieces for one walker, with explicit ab/ba doubles.

    Computes the same quantities as ``_calc_energy_pt_ref``, but splits the
    opposite-spin doubles into ``t2ab`` and ``t2ba`` halves (each scaled by
    1/2). The two should agree when ``wave_data["t2ba"]`` equals
    ``t2ab.transpose(2, 3, 0, 1)``.

    Args:
        trial: object exposing ``norb`` and ``nelec`` as (alpha, beta) pairs.
        walker_up, walker_dn: walker matrices per spin; their first ``nocc``
            rows are inverted, so presumably shaped (norb, nocc) — TODO confirm.
        ham_data: needs "h0", "h1" (alpha, beta), "chol" (alpha, beta; each
            reshaped to (-1, norb, norb)), "lt1_a" and "lt1_b". ("E0" is not
            needed; the unused lookup that raised KeyError without it was
            removed.)
        wave_data: needs "t1a", "t1b", "t2aa", "t2ab", "t2ba", "t2bb".

    Returns:
        Real parts of (t, te, e0):
            t  = <HF|T1+T2|walker>/<HF|walker>
            te = <HF|(T1+T2)(h1+h2)|walker>/<HF|walker>
            e0 = h0 + <HF|h1+h2|walker>/<HF|walker>
    """
    norb_a, norb_b = trial.norb
    nocc_a, nocc_b = trial.nelec
    h0 = ham_data["h0"]
    h1a, h1b = ham_data["h1"]
    t1a, t1b = wave_data["t1a"], wave_data["t1b"]
    t2aa, t2ab = wave_data["t2aa"], wave_data["t2ab"]
    t2ba, t2bb = wave_data["t2ba"], wave_data["t2bb"]
    chol_a, chol_b = ham_data["chol"]
    chol_a = chol_a.reshape(-1, norb_a, norb_a)
    chol_b = chol_b.reshape(-1, norb_b, norb_b)
    # Cholesky vectors restricted to occupied rows (HF bra side).
    rot_chol_a = chol_a[:, :nocc_a, :]
    rot_chol_b = chol_b[:, :nocc_b, :]

    # Overlap-normalized Green's function G_ip = (W W[:nocc]^-1)^T per spin.
    green_a = (walker_up.dot(jnp.linalg.inv(walker_up[:nocc_a, :]))).T # G_ip
    green_b = (walker_dn.dot(jnp.linalg.inv(walker_dn[:nocc_b, :]))).T
    green_occ_a = green_a[:, nocc_a:].copy() # G_ia
    green_occ_b = green_b[:, nocc_b:].copy()
    # G_ia stacked on -1 over the virtual block: shape (norb, norb - nocc).
    greenp_a = jnp.vstack((green_occ_a, -jnp.eye(norb_a - nocc_a)))
    greenp_b = jnp.vstack((green_occ_b, -jnp.eye(norb_b - nocc_b)))

    # 1 body energy
    hg_a = oe.contract("pj,pj->", h1a[:nocc_a, :], green_a, backend="jax")
    hg_b = oe.contract("pj,pj->", h1b[:nocc_b, :], green_b, backend="jax")
    e1_0 = hg_a + hg_b #  <HF|h1|walker>/<HF|walker>

    # single excitations = t_ia (G_ia G_pq - G_iq Gp_pa) h_pq
    t1g_a = oe.contract("ia,ia->", t1a, green_occ_a, backend="jax")
    t1g_b = oe.contract("ia,ia->", t1b, green_occ_b, backend="jax")
    t1g = t1g_a + t1g_b
    e1_1_1 = t1g * e1_0
    t1_green_a = oe.contract("pa,ia,iq->pq", greenp_a, t1a, green_a, backend="jax")
    t1_green_b = oe.contract("pa,ia,iq->pq", greenp_b, t1b, green_b, backend="jax")
    e1_1_2 = -(oe.contract("pq,pq->", t1_green_a, h1a, backend="jax")
               + oe.contract("pq,pq->", t1_green_b, h1b, backend="jax"))
    e1_1 = e1_1_1 + e1_1_2 # <HF|T1 h1|walker>/<HF|walker>

    # double excitations
    # t2g_X_s: amplitudes contracted with one G; the suffix s is the spin the
    # contraction was taken over for the same-spin blocks, and for the mixed
    # blocks the naming follows the einsum strings below.
    t2g_a = oe.contract("iajb,ia->jb", t2aa, green_occ_a, backend="jax") / 4
    t2g_b = oe.contract("iajb,ia->jb", t2bb, green_occ_b, backend="jax") / 4
    t2g_ab_a = oe.contract("iajb,ia->jb", t2ab, green_occ_a, backend="jax") / 2
    t2g_ab_b = oe.contract("iajb,jb->ia", t2ab, green_occ_b, backend="jax") / 2
    t2g_ba_a = oe.contract("iajb,jb->ia", t2ba, green_occ_a, backend="jax") / 2
    t2g_ba_b = oe.contract("iajb,ia->jb", t2ba, green_occ_b, backend="jax") / 2
    gt2g_aa = oe.contract("jb,jb->", t2g_a, green_occ_a, backend="jax")
    gt2g_bb = oe.contract("jb,jb->", t2g_b, green_occ_b, backend="jax")
    gt2g_ab = oe.contract("jb,jb->", t2g_ab_a, green_occ_b, backend="jax")
    gt2g_ba = oe.contract("jb,jb->", t2g_ba_b, green_occ_a, backend="jax")
    gt2g = 2 * (gt2g_aa + gt2g_bb) + (gt2g_ab + gt2g_ba)
    e1_2_1 = gt2g * e1_0
    # t_iajb G_ia G_jq Gp_pb
    t2_green_aaa = oe.contract('pb,jb,jq->pq', greenp_a, t2g_a, green_a, backend="jax")
    t2_green_bbb = oe.contract('pb,jb,jq->pq', greenp_b, t2g_b, green_b, backend="jax")
    t2_green_aba = oe.contract('pa,ia,iq->pq', greenp_a, t2g_ab_b, green_a, backend="jax")
    t2_green_baa = oe.contract('pb,jb,jq->pq', greenp_a, t2g_ba_b, green_a, backend="jax")
    t2_green_bab = oe.contract('pa,ia,iq->pq', greenp_b, t2g_ba_a, green_b, backend="jax")
    t2_green_abb = oe.contract('pb,jb,jq->pq', greenp_b, t2g_ab_a, green_b, backend="jax")
    e1_2_2_a = -oe.contract(
        "pq,pq->", 4*t2_green_aaa + t2_green_aba + t2_green_baa, h1a, backend="jax")
    e1_2_2_b = -oe.contract(
        "pq,pq->", 4*t2_green_bbb + t2_green_bab + t2_green_abb, h1b, backend="jax")
    e1_2_2 = e1_2_2_a + e1_2_2_b
    e1_2 = e1_2_1 + e1_2_2 # <HF|T2 h1|walker>/<HF|walker>

    # two body energy
    # Coulomb-like term from the traces of L G, exchange-like term from the
    # sum over g of (L G)_pq (L G)_qp.
    lg_a = oe.contract("gpj,qj->gpq", rot_chol_a, green_a, backend="jax")
    lg_b = oe.contract("gpj,qj->gpq", rot_chol_b, green_b, backend="jax")
    tr_lg_a = oe.contract("gpp->g", lg_a, backend="jax")
    tr_lg_b = oe.contract("gpp->g", lg_b, backend="jax")
    lg_0 = tr_lg_a + tr_lg_b
    e2_0_1 = oe.contract('g,g->', lg_0, lg_0) / 2.0
    e2_0_2 = - (oe.contract("gpq,gqp->", lg_a, lg_a, backend="jax")
                + oe.contract("gpq,gqp->", lg_b, lg_b, backend="jax")) / 2.0
    e2_0 = e2_0_1 + e2_0_2 # <HF|h2|walker>/<HF|walker>

    # single excitations
    e2_1_1 = e2_0 * t1g
    lt1g_a = oe.contract("gpq,pq->g", chol_a, t1_green_a, backend="jax")
    lt1g_b = oe.contract("gpq,pq->g", chol_b, t1_green_b, backend="jax")
    e2_1_2 = -((lt1g_a + lt1g_b) @ lg_0)
    t1g1_a = t1a @ green_occ_a.T
    t1g1_b = t1b @ green_occ_b.T
    e2_1_3_1 = oe.contract("gpq,gqr,rp->", lg_a, lg_a, t1g1_a, backend="jax") \
        + oe.contract("gpq,gqr,rp->", lg_b, lg_b, t1g1_b, backend="jax")
    # lt1g_* is rebound here: now built from the precomputed ham_data["lt1_*"].
    lt1g_a = oe.contract("gip,qi->gpq", ham_data["lt1_a"], green_a, backend="jax")
    lt1g_b = oe.contract("gip,qi->gpq", ham_data["lt1_b"], green_b, backend="jax")
    e2_1_3_2 = -oe.contract("gpq,gqp->", lt1g_a, lg_a, backend="jax") \
        - oe.contract("gpq,gqp->", lt1g_b, lg_b, backend="jax")
    e2_1_3 = e2_1_3_1 + e2_1_3_2
    e2_1 = e2_1_1 + e2_1_2 + e2_1_3 # <HF|T1 h2|walker>/<HF|walker>

    # double excitations
    e2_2_1 = e2_0 * gt2g
    lt2g_a = oe.contract(
        "gpq,pq->g", chol_a, 8*t2_green_aaa + 2*(t2_green_aba + t2_green_baa),
        backend="jax")
    lt2g_b = oe.contract(
        "gpq,pq->g", chol_b, 8*t2_green_bbb + 2*(t2_green_bab + t2_green_abb),
        backend="jax")
    e2_2_2_1 = -((lt2g_a + lt2g_b) @ lg_0) / 2.0

    # Loop over Cholesky vectors one at a time (lax.scan) to avoid
    # materializing per-vector intermediates for all g at once.
    # carry = [running e2_2_2_2, running e2_2_3]
    def scanned_fun(carry, x):
        chol_a_i, rot_chol_a_i, chol_b_i, rot_chol_b_i = x
        gl_a_i = oe.contract("ir,pr->ip", green_a, chol_a_i, backend="jax")
        gl_b_i = oe.contract("ir,pr->ip", green_b, chol_b_i, backend="jax")
        lt2_green_a_i = oe.contract(
            "pi,ji->pj", rot_chol_a_i, 8*t2_green_aaa + 2*(t2_green_aba + t2_green_baa),
            backend="jax")
        lt2_green_b_i = oe.contract(
            "pi,ji->pj", rot_chol_b_i, 8*t2_green_bbb + 2*(t2_green_bab + t2_green_abb),
            backend="jax")
        carry[0] += (oe.contract("ip,ip->", gl_a_i, lt2_green_a_i, backend="jax")
                     + oe.contract("ip,ip->", gl_b_i, lt2_green_b_i, backend="jax")) / 2
        glgp_a_i = oe.contract("ip,pa->ia", gl_a_i, greenp_a, backend="jax")
        glgp_b_i = oe.contract("ip,pa->ia", gl_b_i, greenp_b, backend="jax")
        l2t2_aa = 0.5 * oe.contract(
            "ia,jb,iajb->", glgp_a_i, glgp_a_i, t2aa, backend="jax")
        l2t2_ab = 0.5 * oe.contract(
            "ia,jb,iajb->", glgp_a_i, glgp_b_i, t2ab, backend="jax")
        l2t2_ba = 0.5 * oe.contract(
            "ia,jb,iajb->", glgp_b_i, glgp_a_i, t2ba, backend="jax")
        l2t2_bb = 0.5 * oe.contract(
            "ia,jb,iajb->", glgp_b_i, glgp_b_i, t2bb, backend="jax")
        carry[1] += l2t2_aa + l2t2_ab + l2t2_ba + l2t2_bb
        return carry, 0.0

    [e2_2_2_2, e2_2_3], _ = lax.scan(
        scanned_fun, [0.0, 0.0], (chol_a, rot_chol_a, chol_b, rot_chol_b)
    )
    e2_2_2 = e2_2_2_1 + e2_2_2_2
    e2_2 = e2_2_1 + e2_2_2 + e2_2_3 # <HF|T2 h2|walker>/<HF|walker>

    # An (H - E0)-shifted variant would combine these with ham_data["E0"]
    # (te + t * (h0 - E0), e0 - E0); that is left to the caller.
    t = t1g + gt2g # <HF|T1+T2|walker>/<HF|walker>
    te = e1_1 + e1_2 + e2_1 + e2_2 # <HF|(T1+T2)(h1+h2)|walker>/<HF|walker>
    e0 = h0 + e1_0 + e2_0 # h0 + <HF|h1+h2|walker>/<HF|walker>

    return jnp.real(t), jnp.real(te), jnp.real(e0)

In [21]:
norba, norbb = trial.norb
nocca, noccb = trial.nelec


def _chol_dot_t1(chol_flat, norb, nocc, t1):
    """Return sum_t chol[g, i, nocc + t] * t1[p, t] as a (g, i, p) array.

    The columns of `chol` from `nocc` onward are contracted with the second
    index of `t1`.
    """
    chol_vir = chol_flat.reshape(-1, norb, norb)[:, :, nocc:]
    return oe.contract("git,pt->gip", chol_vir, t1, backend="jax")


# Recompute the T1-dressed Cholesky intermediates read by the energy functions
# (needed after wave_data's t1 amplitudes are replaced).
ham_data["lt1_a"] = _chol_dot_t1(ham_data["chol"][0], norba, nocca, wave_data["t1a"])
ham_data["lt1_b"] = _chol_dot_t1(ham_data["chol"][1], norbb, noccb, wave_data["t1b"])

In [32]:
# Re-evaluate the trial on the first walker with the loaded amplitudes.
walker_up = prop_data['walkers'][0][0]
walker_dn = prop_data['walkers'][1][0]
eorb, te, t, e0 = trial._calc_eorb_pt(walker_up, walker_dn, ham_data, wave_data)
# Print `t` from this call, not `torb` left over from an earlier cell, in the
# same order as the unpacking above.
print(eorb, te, t, e0)

0.0 -0.21090227608978404 0.0 -152.06249064690837


In [None]:
# Reference implementation on the same first walker, for comparison.
t, te, e0 = _calc_energy_pt_ref(
    trial, walker_up, walker_dn, ham_data=ham_data, wave_data=wave_data
)
print(t, te, e0)

0.0 -0.21090202947718203 -152.06249064690834


In [37]:
import numpy as np

# Random (non-HF) walkers exercise every term of the energy expressions; the
# first propagated walker gives t == 0.0 above. Seeded so that the comparison
# cells below are reproducible across kernel restarts.
RANDOM_WALKER_SEED = 0
rng_check = np.random.default_rng(RANDOM_WALKER_SEED)
walker_up_r = rng_check.standard_normal(walker_up.shape)
walker_dn_r = rng_check.standard_normal(walker_dn.shape)

In [41]:
# Trial implementation on the random walkers (first returned value discarded).
_, te, t, e0 = trial._calc_eorb_pt(
    walker_up_r, walker_dn_r, ham_data, wave_data
)
print(t, te, e0)

0.3190143225222868 -7.837120518335682 -152.28626510829508


In [39]:
# Reference implementation on the same random walkers.
t, te, e0 = _calc_energy_pt_ref(
    trial, walker_up_r, walker_dn_r, ham_data=ham_data, wave_data=wave_data
)
print(t, te, e0)

0.3190143225222866 -7.837105016435412 -152.28626510829503


In [40]:
# New ab/ba-split implementation on the same random walkers.
t, te, e0 = _calc_energy_pt(
    trial, walker_up_r, walker_dn_r, ham_data=ham_data, wave_data=wave_data
)
print(t, te, e0)

0.3190143225222868 -7.837105016435405 -152.28626510829503
