In [1]:
from functools import partial
from pyscf import gto, scf, cc
import numpy as np

from jax import config
config.update("jax_enable_x64", True)

# Make every print in this notebook flush immediately.
print = partial(print, flush=True)

# Geometry: nH hydrogen atoms placed on the x axis at multiples of `a`
# (the PySCF log for this cell reports Mole.unit = angstrom).
a = 1.
nH = 10
atoms = "".join(f"H {i*a} 0 0 \n" for i in range(nH))

# Density-fitted RHF reference in the cc-pVDZ basis.
mol = gto.M(atom=atoms, basis="ccpvdz", verbose=4)
mf = scf.RHF(mol).density_fit()
mf.kernel()

# Reference CCSD correlation energy on top of the RHF solution.
nfrozen = 0
mycc = cc.CCSD(mf, frozen=nfrozen)
mycc.kernel()
print(mycc.e_corr)

System: uname_result(system='Linux', node='sharmagroup-rn', release='6.14.0-35-generic', version='#35~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Oct 14 13:55:17 UTC 2', machine='x86_64')  Threads 16
Python 3.11.14 (main, Oct 21 2025, 18:31:21) [GCC 11.2.0]
numpy 2.3.1  scipy 1.16.2  h5py 3.14.0
Date: Fri Nov 21 12:47:34 2025
PySCF version 2.11.0
PySCF path  /home/sharmagroup/sharmagroup/pyscf
GIT HEAD (branch master) 3d1768f5e33b144b606c3d2c81c12ee54d794501

[CONFIG] conf_file None
[INPUT] verbose = 4
[INPUT] num. atoms = 10
[INPUT] num. electrons = 10
[INPUT] charge = 0
[INPUT] spin (= nelec alpha-beta = 2S) = 0
[INPUT] symmetry False subgroup None
[INPUT] Mole.unit = angstrom
[INPUT] Symbol           X                Y                Z      unit          X                Y                Z       unit  Magmom
[INPUT]  1 H      0.000000000000   0.000000000000   0.000000000000 AA    0.000000000000   0.000000000000   0.000000000000 Bohr   0.0
[INPUT]  2 H      1.000000000000   0.000000000000

In [2]:
# Settings dictionary handed to lno_afqmc.prep_lnoafqmc / lno_afqmc.run_afqmc
# in the fragment loop further down.
options = dict(
    n_eql=4,
    n_prop_steps=10,
    n_ene_blocks=1,
    n_sr_blocks=5,
    n_blocks=10,
    n_walkers=10,
    seed=2,
    walker_type='rhf',
    trial='cisd_ad',
    dt=0.005,
    free_projection=False,
    ad_mode=None,
    use_gpu=False,
)

In [5]:
import os
from pyscf.ci.cisd import CISD
from pyscf.cc.ccsd import CCSD
from pyscf import lib
from ad_afqmc.lno.cc import LNOCCSD
from ad_afqmc.lno_afqmc import lno_maker, lno_afqmc
from ad_afqmc.lno.base import lno

from jax import random

# ---------------------------------------------------------------------------
# LNO set-up
# ---------------------------------------------------------------------------
frozen = nfrozen
thresh = 1e-5
chol_cut = 1e-5
eris = None
run_frg_list = None

mfcc = mf

# If a CCSD/CISD object is supplied instead of a mean-field object, the LNO
# construction is driven by its underlying SCF object.
if isinstance(mfcc, (CCSD, CISD)):
    full_cisd = True
    lnomf = mfcc._scf
else:
    full_cisd = False
    lnomf = mfcc

# A [occ, vir] pair is taken as-is; a scalar is used for the virtuals and
# multiplied by 10 for the occupieds.
if isinstance(thresh, list):
    thresh_occ, thresh_vir = thresh
else:
    thresh_occ = thresh*10
    thresh_vir = thresh

lno_cc = LNOCCSD(lnomf, thresh=thresh, frozen=frozen)
lno_cc.thresh_occ = thresh_occ
lno_cc.thresh_vir = thresh_vir
lno_cc.lo_type = 'boys'
lno_cc.no_type = 'ie'
no_type = 'ie'
lno_cc.frag_lolist = '1o'
lno_cc.force_outcore_ao2mo = True

s1e = lnomf.get_ovlp()
lococc = lno_cc.get_lo(lo_type='boys') # localized active occ orbitals
# lococc,locvir = lno_maker.get_lo(lno_cc,lo_type) ### fix this for DF
if eris is None: eris = lno_cc.ao2mo()

# One fragment per localized occupied orbital.
frag_lolist = [[i] for i in range(lococc.shape[1])]
print(frag_lolist)
nfrag = len(frag_lolist)

frozen_mask = lno_cc.get_frozen_mask()
thresh_pno = [thresh_occ,thresh_vir]
print(f'# lno thresh {thresh_pno}')

if run_frg_list is None:
    run_frg_list = range(nfrag)

frag_nonvlist = None
if frag_nonvlist is None: frag_nonvlist = lno_cc.frag_nonvlist
if frag_nonvlist is None: frag_nonvlist = [[None,None]] * nfrag

# NaN-initialised so that fragments that were not run (run_frg_list being a
# subset) are visible downstream instead of contributing uninitialised memory.
eorb_cc = np.full(nfrag, np.nan, dtype='float64')

# One AFQMC seed per *position* in run_frg_list, all derived from the base
# seed in `options`. `options` itself is never modified, so re-running this
# cell reproduces the same seeds.
seeds = random.randint(random.PRNGKey(options["seed"]),
                    shape=(len(run_frg_list),), minval=0, maxval=100*nfrag)

# Loop invariants.
THRESH_INTERNAL = 1e-10
mol = mf.mol
nocc = mol.nelectron // 2
nao = mol.nao

# ---------------------------------------------------------------------------
# Fragment loop: LNO construction -> MP2 -> CCSD -> AFQMC
# ---------------------------------------------------------------------------
for iseed, ifrag in enumerate(run_frg_list):
    print(f'\n########### running fragment {ifrag+1} ##########')

    fraglo = frag_lolist[ifrag]
    orbfragloc = lococc[:,fraglo]
    # NOTE(review): the per-fragment targets are unpacked here but None is
    # passed to make_fpno1 below, exactly as before -- confirm whether the
    # unpacked values were meant to be forwarded.
    frag_target_nocc, frag_target_nvir = frag_nonvlist[ifrag]
    frzfrag, orbfrag, can_orbfrag \
         = lno.make_fpno1(lno_cc, eris, orbfragloc, no_type,
                            THRESH_INTERNAL, thresh_pno,
                            frozen_mask=frozen_mask,
                            frag_target_nocc=None,
                            frag_target_nvir=None,
                            canonicalize=True)

    # Split the orbital indices 0..nao-1 into frozen / active-occupied /
    # active-virtual using the frozen list returned for this fragment.
    actfrag = np.array([i for i in range(nao) if i not in frzfrag])
    # frzocc = np.array([i for i in range(nocc) if i in frzfrag])
    actocc = np.array([i for i in range(nocc) if i in actfrag])
    actvir = np.array([i for i in range(nocc,nao) if i in actfrag])
    nactocc = len(actocc)
    nactvir = len(actvir)
    # Overlap of the fragment's localized orbital(s) with the active occupied
    # fragment orbitals.
    prjlo = orbfragloc.T @ s1e @ orbfrag[:,actocc]
    nelec_act = nactocc*2
    norb_act = nactocc+nactvir

    print(f'# active orbitals: {actfrag}')
    print(f'# active occupied orbitals: {actocc}')
    print(f'# active virtual orbitals: {actvir}')
    print(f'# frozen orbitals: {frzfrag}')
    print(f'# number of active electrons: {nelec_act}')
    print(f'# number of active orbitals: {norb_act}')
    print(f'# number of frozen orbitals: {len(frzfrag)}')

    # MP2 is not invariant to the LNO transformation; it needs to be done in
    # the canonical HF orbitals, which the global MP2 is calculated in.
    print('# running fragment MP2')
    ecorr_p2 = \
        lno_maker.lno_mp2_frg_e(lnomf,frzfrag,orbfragloc,can_orbfrag)
    print(f'# LNO-MP2 Orbital Energy: {ecorr_p2:.8f}')

    print('# running fragment CCSD')
    mcc, ecorr_cc = \
        lno_maker.lno_cc_solver(lnomf,orbfrag,orbfragloc,frozen=frzfrag)
    eorb_cc[ifrag] = ecorr_cc
    print(f'# LNO-CCSD Energy: {mcc.e_tot}')
    print(f'# LNO-CCSD Orbital Energy: {ecorr_cc:.8f}')

    # CISD-like coefficients from the CCSD amplitudes: c1 = t1, c2 = t2 + t1 x t1.
    ci1 = np.array(mcc.t1)
    ci2 = mcc.t2 + lib.einsum("ia,jb->ijab",ci1,ci1)

    # Per-fragment copy of the options carrying this fragment's seed.
    # `seeds` is indexed by loop position (it has len(run_frg_list) entries),
    # not by fragment id: JAX clamps out-of-range indices silently, so indexing
    # by fragment id would pick wrong seeds for a subset list.
    # NOTE(review): the seed is passed as a plain int rather than a JAX scalar
    # -- assumed equivalent for prep_lnoafqmc/run_afqmc; confirm.
    frag_options = {**options, "seed": int(seeds[iseed])}
    lno_afqmc.prep_lnoafqmc(
        mf,orbfrag,frag_options,
        norb_act=norb_act,nelec_act=nelec_act,
        prjlo=prjlo,norb_frozen=frzfrag,
        ci1=ci1,ci2=ci2,chol_cut=chol_cut,
        )
    lno_afqmc.run_afqmc(frag_options)
    # Keep this fragment's output; raises if the run produced no file instead
    # of silently leaving a stale lno_afqmc.out<N> from an earlier run.
    os.replace('lno_afqmc.out', f'lno_afqmc.out{ifrag+1}')


lo_type = boys


******** <class 'pyscf.lo.boys.Boys'> ********
conv_tol = 1e-06
conv_tol_grad = None
max_cycle = 100
max_stepsize = 0.01
max_iters = 20
kf_interval = 5
kf_trust_region = 5
ah_start_tol = 1000000000.0
ah_start_cycle = 1
ah_level_shift = 0
ah_conv_tol = 1e-12
ah_lindep = 1e-14
ah_max_cycle = 40
ah_trust_region = 3
init_guess = atomic
Set conv_tol_grad to 0.000316228
macro= 1  f(x)= 24.313357633615  delta_f= 24.3134  |g|= 10.0357  2 KF 10 Hx
macro= 2  f(x)= 22.495993805957  delta_f= -1.81736  |g|= 4.35347  2 KF 10 Hx
macro= 3  f(x)= 22.252027697998  delta_f= -0.243966  |g|= 0.127638  2 KF 7 Hx
macro= 4  f(x)= 22.252027476411  delta_f= -2.21587e-07  |g|= 0.00282975  1 KF 2 Hx
macro= 5  f(x)= 22.252027476411  delta_f= 0  |g|= 5.82531e-08  1 KF 1 Hx
macro X = 5  f(x)= 22.252027476411  |g|= 5.82531e-08  10 intor 8 KF 30 Hx
Lov is saved to /tmp/vuan0843
[[0], [1], [2], [3], [4]]
# lno thresh [0.0001, 1e-05]

########### running fragment 1 ##########
# active orbitals: [ 0  1  

In [6]:
# Sum of the per-fragment LNO-CCSD orbital energies stored in eorb_cc by the
# fragment loop.
sum(eorb_cc)

np.float64(-0.22697898626751437)

In [9]:
# Collect the per-fragment AFQMC results from the lno_afqmc.out<N> files
# written by the fragment loop.
# NaN-initialised: a fragment whose file lacks a marker line shows up as NaN
# instead of whatever uninitialised memory np.empty would have left behind.
eo = np.full(nfrag, np.nan)  # value on the "AFQMC/CISD E_Orbital" line
oo = np.full(nfrag, np.nan)  # value on the "AFQMC/CISD Olp_Orbital" line
for i in run_frg_list:
    with open(f"lno_afqmc.out{i+1}", "r") as rf:
        for line in rf:
            # Result lines look like "<label>: <value> +/- <error>", so the
            # value is the third whitespace-separated token from the end.
            if "AFQMC/CISD E_Orbital" in line:
                eo[i] = float(line.split()[-3])
            if "AFQMC/CISD Olp_Orbital" in line:
                oo[i] = float(line.split()[-3])
print(eo)
print(oo)

[-0.058953 -0.052051 -0.050146 -0.046904 -0.053803]
[0.040351 0.020717 0.066304 0.014136 0.043787]


In [10]:
# Combine the parsed per-fragment values: summed E_Orbital entries divided by
# (1 + summed Olp_Orbital entries).
sum(eo)/(1+sum(oo))

np.float64(-0.22092137400393996)

In [1]:
# Final Results
# AFQMC/HF energy: -4.195203 +/- 0.002045
# AFQMC/HF E_Orbital: -0.027976 +/- 0.002045
# AFQMC/CISD E_Orbital: -0.028062 +/- 0.000048
# AFQMC/CISD E12_Orbital: -0.000085 +/- 0.001997
# AFQMC/CISD O12_Orbital: 0.024310 +/- 0.001750
# Hand check with the values pasted above:
# CISD E_Orbital / (1 + CISD O12_Orbital).
-0.028062/(1+0.024310)

-0.027396003163104917

In [2]:
# Hand-entered check: sum of five energy values divided by (1 + sum of five
# overlap values).
# NOTE(review): the provenance of these numbers is not recorded in this
# notebook -- presumably per-fragment E_Orbital / Olp_Orbital values from a
# different run; confirm.
(-0.045286+ -0.044839 + -0.042857 + -0.042662 + -0.044108) / (1+0.041431+0.041612+0.039334+0.036109+0.037628)

-0.1837216185079349

In [7]:
import jax
from jax import jit, vmap, lax
import opt_einsum as oe

@partial(jit, static_argnums=0)
def _calc_orb_energy_restricted(
    self, walker: jax.Array, ham_data: dict, wave_data: dict
) -> complex:
    """Evaluate an energy expression for one walker against a CISD-type trial.

    Reads the singles/doubles coefficients ``wave_data["ci1"]`` and
    ``wave_data["ci2"]`` and, from ``ham_data``, the entries ``"prj"``,
    ``"chol"``, ``"h1"``, ``"h0"`` and ``"lci1"``. ``self`` must provide
    ``nelec`` and ``norb`` and is treated as a static argument by ``jit``.

    Returns ``(e1 + e2) / overlap + e0``.

    NOTE(review): this function looks like a work in progress and cannot run
    as written:
      * ``jnp`` is used throughout but is not imported in this cell;
      * ``e1_1``, ``ci1g`` and ``ci1_green`` are referenced but never assigned
        (the only definition of ``ci1_green`` is in commented-out lines);
      * ``e1_1_orb`` (the only quantities using ``prj``) is computed but never
        enters the returned value;
      * ``ggg`` is unused and its contraction string is malformed.
    Presumably the intent is an orbital-projected variant of the trial's
    energy evaluator -- confirm with the author before using it.
    """
    ci1, ci2 = wave_data["ci1"], wave_data["ci2"]
    nocc = self.nelec[0]
    prj = ham_data['prj']

    # green = (walker @ inv(walker[:nocc, :])).T; green_occ keeps its columns
    # from nocc onward.
    green = (walker.dot(jnp.linalg.inv(walker[:nocc, :]))).T
    green_occ = green[:, nocc:].copy()
    # green_occ stacked on top of minus an identity of size norb - nocc.
    greenp = jnp.vstack((green_occ, -jnp.eye(self.norb - nocc)))

    # Cholesky vectors as (n_chol, norb, norb); rot_chol keeps the first
    # nelec[0] rows of each vector.
    chol = ham_data["chol"].reshape(-1, self.norb, self.norb)
    rot_chol = chol[:, : self.nelec[0], :]
    # Average of the two entries of ham_data["h1"].
    h1 = (ham_data["h1"][0] + ham_data["h1"][1]) / 2.0
    hg = oe.contract("pj,pj->", h1[:nocc, :], green, backend="jax")

    # 0 body energy
    e0 = ham_data["h0"]

    # 1 body energy
    # ref
    e1_0 = 2 * hg

    # single excitations
    # The "_orb" quantities additionally contract the occupied index with prj.
    ci1g_orb = oe.contract("ia,ka,ik->", ci1, green_occ, prj, backend="jax")
    e1_1_1_orb = 4 * ci1g_orb * hg # c_ia G_ia G_pg h_pg
    # gpci1 = greenp @ ci1.T
    # ci1_green = gpci1 @ green
    # ci1gp = ci1 @ greenp.T
    hgpg = oe.contract("pq,pa,iq->ia", h1, greenp, green_occ, backend="jax")
    # e1_1_2 = -2 * oe.contract("ij,ij->", h1, ci1_green, backend="jax")
    e1_1_2_orb = -2 * oe.contract("ia,ka,ik->", ci1, hgpg, prj, backend="jax")
    # NOTE(review): e1_1_orb is not used below; the total uses `e1_1` instead.
    e1_1_orb = e1_1_1_orb + e1_1_2_orb

    # double excitations
    # NOTE(review): `ggg` is never used, and "ia,ia->jb" names output indices
    # that do not appear in the inputs while ci2 is treated as 4-index
    # everywhere else in this function.
    ggg = oe.contract("ia,ia->jb", ci2, green_occ, backend="jax")
    ci2g_c = oe.contract("iajb,ia->jb", ci2, green_occ, backend="jax")
    ci2g_e = oe.contract("iajb,ib->ja", ci2, green_occ, backend="jax")
    ci2_green_c = (greenp @ ci2g_c.T) @ green
    ci2_green_e = (greenp @ ci2g_e.T) @ green
    ci2_green = 2 * ci2_green_c - ci2_green_e
    ci2g = 2 * ci2g_c - ci2g_e
    gci2g = oe.contract("qu,qu->", ci2g, green_occ, backend="jax")
    e1_2_1 = 2 * hg * gci2g
    e1_2_2 = -2 * oe.contract("ij,ij->", h1, ci2_green, backend="jax")
    e1_2 = e1_2_1 + e1_2_2
    # NOTE(review): `e1_1` is not defined anywhere in this function.
    e1 = e1_0 + e1_1 + e1_2

    # two body energy
    # ref
    lg = oe.contract("gpj,pj->g", rot_chol, green, backend="jax")
    # lg1 = jnp.einsum("gpj,pk->gjk", rot_chol, green, optimize="optimal")
    lg1 = oe.contract("gpj,qj->gpq", rot_chol, green, backend="jax")
    e2_0_1 = 2 * lg @ lg
    e2_0_2 = -jnp.sum(vmap(lambda x: x * x.T)(lg1))
    e2_0 = e2_0_1 + e2_0_2

    # single excitations
    # NOTE(review): `ci1g` and `ci1_green` are not defined anywhere in this
    # function.
    e2_1_1 = 2 * e2_0 * ci1g
    lci1g = oe.contract("gij,ij->g", chol, ci1_green, backend="jax")
    e2_1_2 = -2 * (lci1g @ lg)

    ci1g1 = ci1 @ green_occ.T
    # e2_1_3 = jnp.einsum("gpq,gpq->", glgpci1, lg1, optimize="optimal")
    e2_1_3_1 = oe.contract("gpq,gqr,rp->", lg1, lg1, ci1g1, backend="jax")
    # `lci1g` is rebound here to a different, 3-index quantity.
    lci1g = oe.contract("gip,qi->gpq", ham_data["lci1"], green, backend="jax")
    e2_1_3_2 = -oe.contract("gpq,gqp->", lci1g, lg1, backend="jax")
    e2_1_3 = e2_1_3_1 + e2_1_3_2
    e2_1 = e2_1_1 + 2 * (e2_1_2 + e2_1_3)

    # double excitations
    e2_2_1 = e2_0 * gci2g
    lci2g = oe.contract("gij,ij->g", chol, ci2_green, backend="jax")
    e2_2_2_1 = -lci2g @ lg

    # lci2g1 = jnp.einsum("gij,jk->gik", chol, ci2_green, optimize="optimal")
    # lci2_green = jnp.einsum("gpi,ji->gpj", rot_chol, ci2_green, optimize="optimal")
    # e2_2_2_2 = 0.5 * jnp.einsum("gpi,gpi->", gl, lci2_green, optimize="optimal")
    def scanned_fun(carry, x):
        # Processes one Cholesky vector per scan step and accumulates two
        # running sums: carry[0] -> e2_2_2_2, carry[1] -> e2_2_3.
        chol_i, rot_chol_i = x
        gl_i = oe.contract("pj,ji->pi", green, chol_i, backend="jax")
        lci2_green_i = oe.contract(
            "pi,ji->pj", rot_chol_i, ci2_green, backend="jax"
        )
        carry[0] += 0.5 * oe.contract(
            "pi,pi->", gl_i, lci2_green_i, backend="jax"
        )
        glgp_i = oe.contract("pi,it->pt", gl_i, greenp, backend="jax")
        l2ci2_1 = oe.contract(
            "pt,qu,ptqu->",
            glgp_i,
            glgp_i,
            ci2,
            backend="jax"
        )
        l2ci2_2 = oe.contract(
            "pu,qt,ptqu->",
            glgp_i,
            glgp_i,
            ci2,
            backend="jax"
        )
        carry[1] += 2 * l2ci2_1 - l2ci2_2
        return carry, 0.0

    # Scan over the leading axis of (chol, rot_chol), starting both sums at 0.
    [e2_2_2_2, e2_2_3], _ = lax.scan(scanned_fun, [0.0, 0.0], (chol, rot_chol))
    e2_2_2 = 4 * (e2_2_2_1 + e2_2_2_2)

    e2_2 = e2_2_1 + e2_2_2 + e2_2_3

    e2 = e2_0 + e2_1 + e2_2

    # overlap
    # NOTE(review): uses the undefined `ci1g` (see above).
    overlap_1 = 2 * ci1g
    overlap_2 = gci2g
    overlap = 1.0 + overlap_1 + overlap_2
    return (e1 + e2) / overlap + e0