In [1]:
from functools import partial
from pyscf import gto, scf, cc

from jax import config
config.update("jax_enable_x64", True)

# Rebind print so every call flushes immediately (keeps notebook output
# interleaved correctly with the PySCF log).
print = partial(print, flush=True)

# H chain: nH hydrogen atoms placed along x with spacing a
# (the PySCF log in the cell output reports Mole.unit = angstrom).
a = 1
nH = 10
atoms = ""
for i in range(nH):
    atoms += f"H {i*a} 0 0 \n"

# Alternative geometry (disabled): water dimer s22
# atoms = '''
#     O  -1.551007  -0.114520   0.000000
#     H  -1.934259   0.762503   0.000000
#     H  -0.599677   0.040712   0.000000
#     O   1.350625   0.111469   0.000000
#     H   1.680398  -0.373741  -0.758561
#     H   1.680398  -0.373741   0.758561
#     '''

# Build the molecule and run density-fitted RHF.
mol = gto.M(atom=atoms, basis="sto6g", verbose=4)
mf = scf.RHF(mol).density_fit()
mf.kernel()

# Disabled: same RHF without density fitting.
# mf2 = scf.RHF(mol)#.density_fit()
# mf2.kernel()

# Full-system CCSD on top of the DF-RHF reference; element [0] of the
# kernel() return value is what the cell displays (the output shows it as
# the correlation energy, -0.1673...). `mycc.t1` / `mycc.t2` are reused by
# later cells.
mycc = cc.CCSD(mf)
mycc.kernel()[0]
# et = mycc.ccsd_t()

# mycc = cc.CCSD(mf2)
# # mycc.kernel()

# print(f"ccsd energy is {mycc.e_tot}")
# print(f"ccsd_t energy is {mycc.e_tot+et}")

System: uname_result(system='Linux', node='yichi-thinkpad', release='4.4.0-26100-Microsoft', version='#1882-Microsoft Fri Jan 01 08:00:00 PST 2016', machine='x86_64')  Threads 12
Python 3.10.16 | packaged by conda-forge | (main, Dec  5 2024, 14:16:10) [GCC 13.3.0]
numpy 1.24.3  scipy 1.14.1  h5py 3.12.1
Date: Tue Aug  5 12:29:49 2025
PySCF version 2.8.0
PySCF path  /home/yichi/research/software/lno_pyscf
GIT HEAD (branch master) ef75f4190e4de208685670651dc6c467f72b6794

[ENV] PYSCF_EXT_PATH /home/yichi/research/software/pyscf
[CONFIG] conf_file None
[INPUT] verbose = 4
[INPUT] num. atoms = 10
[INPUT] num. electrons = 10
[INPUT] charge = 0
[INPUT] spin (= nelec alpha-beta = 2S) = 0
[INPUT] symmetry False subgroup None
[INPUT] Mole.unit = angstrom
[INPUT] Symbol           X                Y                Z      unit          X                Y                Z       unit  Magmom
[INPUT]  1 H      0.000000000000   0.000000000000   0.000000000000 AA    0.000000000000   0.000000000000   0.

-0.1673383950712719

In [3]:
from ad_afqmc.lno.base import lno
from pyscf.lib import logger
import numpy as np
from functools import reduce

_fdot = np.dot


def fdot(*args):
    """Left-to-right chained matrix product: fdot(A, B, C) == (A·B)·C.

    The operands are folded pairwise with ``np.dot`` via ``functools.reduce``.
    """
    return reduce(_fdot, args)

In [None]:
from ad_afqmc.lno.cc import LNOCCSD
from pyscf.lib import logger
import sys
from ad_afqmc.lno.base import lno
from ad_afqmc.lno_ccsd import lno_ccsd
import numpy as np
from functools import reduce
# Logger writing to stdout; 6 is the verbosity level handed to pyscf's Logger.
log = logger.Logger(sys.stdout, 6)

# Run options; this dict is passed to lno_ccsd.prep_lno_amp_chol_file
# further down in this cell (its "seed" entry is replaced per fragment there).
options = {
    "n_eql": 4,
    "n_ene_blocks": 1,
    "n_sr_blocks": 20,
    "n_blocks": 20,
    "n_walkers": 5,
    "seed": 98,
    "trial": "cisd",
    "walker_type": "rhf",
    "dt":0.005,
    "ene0": 0,
}

# LNO parameters.
thresh = 1e-4          # scalar (or [occ, vir] list) truncation threshold
# chol_cut = 1e-6
frozen = 0
run_frg_list = None    # None -> replaced by range(nfrag) further below
lo_type = 'pm'
no_type = 'ie' # cim
frag_lolist = '1o'     # '1o' -> one fragment per localized orbital

# Full-system CCSD amplitudes from the first cell.
t1 = mycc.t1
t2 = mycc.t2
# CISD-like amplitudes assembled from CCSD: c1 = t1, c2 = t2 + t1 (x) t1.
full_ci1 = np.array(t1)
full_ci2 = t2 + np.einsum("ia,jb->ijab", full_ci1, full_ci1)
# Reorder axes from (i, j, a, b) to (i, a, j, b).
full_ci2 = full_ci2.transpose(0, 2, 1, 3)
# Persist to disk; the fragment loop below reloads and extends this file.
np.savez("amplitudes.npz", full_ci1=full_ci1, full_ci2=full_ci2)

# Split the threshold into occupied / virtual parts: a list is taken as
# [occ, vir]; a scalar uses a 10x looser value for the occupied space.
if isinstance(thresh, list):
    thresh_occ, thresh_vir = thresh
else:
    thresh_occ = thresh*10
    thresh_vir = thresh

# LNO-CCSD driver built on the DF-RHF object; the attributes below override
# whatever the constructor set from `thresh`.
lno_cc = LNOCCSD(mf, thresh=thresh, frozen=frozen)
lno_cc.thresh_occ = thresh_occ
lno_cc.thresh_vir = thresh_vir
lno_cc.lo_type = lo_type
lno_cc.no_type = no_type
lno_cc.frag_lolist = frag_lolist
# lno_cc.ccsd_t = True
lno_cc.force_outcore_ao2mo = True

frag_atmlist = lno_cc.frag_atmlist
# AO overlap matrix from the underlying SCF object.
s1e = lno_cc._scf.get_ovlp()


# log.info('no_type = %s', no_type)

# LO construction
orbloc = lno_cc.get_lo(lo_type=lo_type) # localized active occ orbitals
orbactocc = lno_cc.split_mo()[1] # non-localized active occ
# m = <LO|S|occ>: overlap between the LOs (rows) and the active occupied
# MOs (columns).
m = fdot(orbloc.T, s1e, orbactocc)
# m^T m == 1 means every active occupied MO is fully represented by the LOs.
lospanerr = abs(fdot(m.T, m) - np.eye(m.shape[1])).max()
if lospanerr > 1e-10:
    log.error('LOs do not fully span the occupied space! '
                'Max|<occ|LO><LO|occ>| = %e', lospanerr)
    raise RuntimeError

# check Span(LO) == Span(occ)
# m m^T == 1 means the LOs contain nothing outside the occupied space.
occspanerr = abs(fdot(m, m.T) - np.eye(m.shape[0])).max()
if occspanerr < 1e-10:
    log.info('LOs span exactly the occupied space.')
    # In this case only the 'ir' / 'ie' natural-orbital types are accepted.
    if no_type not in ['ir','ie']:
        log.error('"no_type" must be "ir" or "ie".')
        raise ValueError
else:
    log.info('LOs span occupied space plus some virtual space.')

# LO assignment to fragments
#
# '1o' expands to one fragment per localized orbital. Any other *string* is
# rejected up front: previously the cell only printed a message and then used
# len() / indexing on the string itself, which failed later with an unrelated
# error. A non-string value (e.g. an explicit list of LO index lists) is
# passed through unchanged, as before.
if frag_lolist == '1o':
    # this is what we use: every active local occ labels a fragment
    log.info('Using single-LO fragment')
    frag_lolist = [[i] for i in range(orbloc.shape[1])]
elif isinstance(frag_lolist, str):
    log.error('Only support single LO fragment!')
    raise ValueError(
        f'Unsupported frag_lolist {frag_lolist!r}: only "1o" is supported.')
else:
    print('Only support single LO fragment!')
nfrag = len(frag_lolist)
frag_nonvlist = lno_cc.frag_nonvlist

# dump info
log.info('nfrag = %d  nlo = %d', nfrag, orbloc.shape[1])
log.info('frag_atmlist = %s', frag_atmlist)
log.info('frag_lolist = %s', frag_lolist)
log.info('frag_nonvlist = %s', frag_nonvlist)

# no_type must be exactly two characters, each one of 'r', 'e', 'i'.
# The length guard avoids an IndexError on a one-character input.
if not (len(no_type) == 2 and no_type[0] in 'rei' and no_type[1] in 'rei'):
    log.warn('Input no_type "%s" is invalid.', no_type)
    raise ValueError(f'Invalid no_type {no_type!r}.')

# No per-fragment orbital-count targets given: use (None, None) for each.
if frag_nonvlist is None:
    frag_nonvlist = [[None, None]] * nfrag

# Integrals and frozen-orbital mask from the LNO-CCSD driver; both are passed
# to lno.make_fpno1 in the fragment loop.
eris = lno_cc.ao2mo()
frozen_mask = lno_cc.get_frozen_mask()
thresh_pno = [thresh_occ,thresh_vir]
print(f'lno thresh {thresh_pno}')

# Default: run every fragment.
if run_frg_list is None:
    run_frg_list = range(nfrag)

# One integer seed per entry of run_frg_list, drawn from a JAX PRNG keyed by
# options["seed"], uniformly in [0, 100000*nfrag).
from jax import random
seeds = random.randint(random.PRNGKey(options["seed"]),
                    shape=(len(run_frg_list),), minval=0, maxval=100000*nfrag)

# Per-fragment driver. Currently hard-wired to fragment 0 only; the commented
# line is the full loop.
# for ifrag in run_frg_list:
for ifrag in [0]:
    print(f'########### running fragment {ifrag+1} ##########')
    fraglo = frag_lolist[ifrag]
    orbfragloc = orbloc[:,fraglo] # the specific local active occ
    frag_target_nocc, frag_target_nvir = frag_nonvlist[ifrag]
    THRESH_INTERNAL = 1e-10
    # Build the fragment's orbitals. From how the results are used below:
    # frzfrag = indices of frozen orbitals, orbfrag = orbital coefficients.
    # can_orbfrag is not used in this cell.
    frzfrag, orbfrag, can_orbfrag = lno.make_fpno1(lno_cc, eris, orbfragloc, no_type,
                                                THRESH_INTERNAL, thresh_pno,
                                                frozen_mask=frozen_mask,
                                                frag_target_nocc=frag_target_nocc,
                                                frag_target_nvir=frag_target_nvir,
                                                canonicalize=False)
    
    full_cisd = True
    if full_cisd:
        # Reload the full-space amplitudes and append the projector blocks
        # for this fragment, then rewrite the same file.
        content = dict(np.load("amplitudes.npz"))
        mol = mf.mol
        nocc = mol.nelectron // 2 
        nao = mol.nao
        # Orbital indices not frozen / frozen-and-occupied.
        # NOTE(review): nao is used as the number of orbitals here — assumes
        # the number of MOs equals the number of AOs; confirm.
        # NOTE(review): if no occupied orbital is frozen, np.array([]) is
        # float-typed and cannot be used as an index below — confirm whether
        # that case can occur.
        actfrag = np.array([i for i in range(nao) if i not in frzfrag])
        frzocc = np.array([i for i in range(nocc) if i in frzfrag])
        #actocc = np.array([i for i in range(nocc) if i in actfrag])
        no2mo_prj = lno_ccsd.no2mo(mf.mo_coeff,s1e,orbfrag).T
        content["no2mo_frzocc"] = no2mo_prj[:, frzocc]
        content["no2mo_act"] = no2mo_prj[:, actfrag]
        np.savez("amplitudes.npz", **content)

    # Fragment CCSD solve within the active space (frozen=frzfrag).
    ecorr_ccsd,t1,t2,prjlo,nactocc,nactvir,maskact,maskocc \
        = lno_ccsd.cc_impurity_solve(
            mf,orbfrag,orbfragloc,frozen=frzfrag,eris=eris,log=log)
    
    nelec_act = nactocc*2
    norb_act = nactocc+nactvir
    print(f'# lno-ccsd correlation energy: {ecorr_ccsd}')
    print(f'# number of active electrons: {nelec_act}')
    print(f'# number of active orbitals: {norb_act}')
    print(f'# number of frozen orbitals: {len(frzfrag)}')


    # Per-fragment seed.
    # NOTE(review): seeds has len(run_frg_list) entries but is indexed by the
    # fragment id, not by position in run_frg_list — verify this is intended
    # when run_frg_list is a subset of fragments.
    options["seed"] = seeds[ifrag]
    # Write the files consumed later by lno_ccsd.prep_lnoccsd_afqmc
    # (presumably amplitudes + Cholesky/DF integrals, per the log output).
    lno_ccsd.prep_lno_amp_chol_file(
        mf,orbfrag,options,
        norb_act=(nactocc+nactvir),nelec_act=nactocc*2,
        prjlo=prjlo,norb_frozen=frzfrag,t1=t1,t2=t2,full_cisd=full_cisd)  
    

lo_type = pm
LOs span exactly the occupied space.
Using single-LO fragment
nfrag = 5  nlo = 5
frag_atmlist = None
frag_lolist = [[0], [1], [2], [3], [4]]
frag_nonvlist = None
Lov is saved to /tmp/de3pwd9t
lno thresh [0.001, 0.0001]
########### running fragment 1 ##########

WARN: CCSD detected DF being used in the HF object. MO integrals are computed based on the DF 3-index tensors.
It's recommended to use dfccsd.CCSD for the DF-CCSD calculations

Init t2, MP2 energy = -5.28169984672868  E_corr(MP2) -0.0309313967729368

******** <class 'pyscf.cc.dfccsd.RCCSD'> ********
CC2 = 0
CCSD nocc = 2, nmo = 6
frozen orbitals [0 1 2 9]
max_cycle = 50
direct = 0
conv_tol = 1e-07
conv_tol_normt = 1e-05
diis_space = 6
diis_start_cycle = 0
diis_start_energy_diff = 1e+09
max_memory 4000 MB (current use 258 MB)
Init E_corr(RCCSD) = -0.030931396772961
cycle = 1  E_corr(RCCSD) = -0.0364125454806054  dE = -0.00548114871  norm(t1,t2) = 0.0305595
cycle = 2  E_corr(RCCSD) = -0.038913244970123  dE = -0.002500

Overwritten attributes  ao2mo  of <class 'pyscf.cc.dfccsd.RCCSD'>


cycle = 5  E_corr(RCCSD) = -0.0405260096778854  dE = 2.14226615e-05  norm(t1,t2) = 0.000407755
cycle = 6  E_corr(RCCSD) = -0.0405247526343399  dE = 1.25704355e-06  norm(t1,t2) = 0.000107132
cycle = 7  E_corr(RCCSD) = -0.0405262780422874  dE = -1.52540795e-06  norm(t1,t2) = 3.03296e-05
cycle = 8  E_corr(RCCSD) = -0.0405262700810433  dE = 7.96124409e-09  norm(t1,t2) = 5.55459e-06
RCCSD converged
E(RCCSD) = -5.291294720036783  E_corr = -0.04052627008104327
# lno-ccsd correlation energy: -0.026643193849104952
# number of active electrons: 4
# number of active orbitals: 6
# number of frozen orbitals: 4
# Generating Cholesky Integrals
# frozen orbitals: [0 1 2 9]
# local active orbitals: [3 4 5 6 7 8]
# local active space size: 6
# using density fitting
# Decomposing ERI with DF
# chol shape: (90, 36)
# Finished calculating Cholesky integrals

# Size of the correlation space
# Number of electrons: (2, 2)
# Number of basis functions: 6
# Number of DF vectors: 90



In [15]:
# Short aliases used by the following cells.
mo = mf.mo_coeff      # canonical RHF orbitals
no = orbfrag          # fragment orbitals from the LNO cell
lo_occ = orbloc       # localized occupied orbitals
# NOTE(review): only the orbital coefficients are passed; occupations are
# presumably taken from mf.mo_occ by default — confirm.
dm1 = mf.make_rdm1(mo)
# NOTE(review): dm2 is built from the fragment orbitals via make_rdm2 and is
# later handed to mf.kernel(dm=dm2); confirm that a 2-RDM (rather than a
# 1-RDM) is what is intended there.
dm2 = mf.make_rdm2(no)

In [21]:
# Check that the virtual block of the fragment orbitals (`no`) spans the same
# space as the canonical virtual MOs (`mo`).
nocc = mol.nelectron // 2
mo_occ = mo[:,:nocc]
mo_vir = mo[:,nocc:]
no_occ = no[:,:nocc]
no_vir = no[:,nocc:]
# m = <NO_vir|S|MO_vir>; m^T m == 1 iff every virtual MO is fully
# represented in the virtual NO space.
m = fdot(no_vir.T, s1e, mo_vir)
lospanerr = abs(fdot(m.T, m) - np.eye(m.shape[1])).max()
print(lospanerr)
if lospanerr > 1e-10:
    # The original message was copied from the occupied/LO check and
    # described the wrong test; report what is actually being verified.
    log.error('Virtual NOs do not fully span the virtual MO space! '
              'Max|<vir|NO><NO|vir> - 1| = %e', lospanerr)
    raise RuntimeError('virtual NO / virtual MO span check failed')

3.5860203695392556e-14


In [7]:
# Re-run the SCF, passing dm2 under the keyword `dm`.
# NOTE(review): verify against the PySCF SCF.kernel signature that `dm=` is
# actually consumed as the initial density (the argument is presumably named
# `dm0`); the log shows convergence in one cycle at the previous energy.
mf.kernel(dm=dm2)



******** <class 'pyscf.df.df_jk.DFRHF'> ********
method = DFRHF
initial guess = minao
damping factor = 0
level_shift factor = 0
DIIS = <class 'pyscf.scf.diis.CDIIS'>
diis_start_cycle = 1
diis_space = 8
diis_damp = 0
SCF conv_tol = 1e-09
SCF conv_tol_grad = None
SCF max_cycles = 50
direct_scf = False
chkfile to save SCF result = /tmp/tmpw20bw0co
max_memory 4000 MB (current use 261 MB)
Set gradient conv threshold to 3.16228e-05
init E= -5.25076844995581
  HOMO = -0.27317661221491  LUMO = 0.162812857650495
cycle= 1 E= -5.25076844995581  delta_E=    0  |g|= 2.38e-08  |ddm|= 1.22e-07
  HOMO = -0.273176607440234  LUMO = 0.162812852424909
Extra cycle  E= -5.25076844995581  delta_E= -5.33e-15  |g|= 1.08e-08  |ddm|= 5.54e-08
converged SCF energy = -5.25076844995581


-5.250768449955812

In [26]:
# AO overlap and the MO<->NO transformation from lno_ccsd.no2mo, transposed.
# NOTE(review): if lno_ccsd.no2mo matches the notebook's local no2mo
# (mo^T S no), then after .T rows index NOs and columns index MOs — confirm.
s1e = mf.get_ovlp()
prj_mo2no = lno_ccsd.no2mo(mo,s1e,no).T

In [23]:
# Sanity check: shapes of the full-space amplitudes.
print(full_ci1.shape, full_ci2.shape)

(5, 5) (5, 5, 5, 5)


In [None]:
# Partition orbital indices into frozen-occupied / active-occupied /
# active-virtual sets for the current fragment (based on `frzfrag`).
nocc = mol.nelectron // 2
nao = mol.nao
print(nao)
# dtype=int keeps empty selections usable as index arrays: np.array([])
# defaults to float64, which numpy refuses for fancy indexing.
actfrag = np.array([i for i in range(nao) if i not in frzfrag], dtype=int)
frzocc = np.array([i for i in range(nocc) if i in frzfrag], dtype=int)
actocc = np.array(
    [i for i in range(nocc) if i in actfrag], dtype=int)
actvir = np.array(
    [i for i in range(nocc, nao) if i in actfrag], dtype=int)
print(frzocc)
print(actocc)
print(actvir)

10
[0 1 2 9]
[3 4 5 6 7 8]
[0 1 2]
[3 4]
[5 6 7 8]


In [38]:
# Occupied-occupied and virtual-virtual blocks of the MO/NO projector.
prj_oo = prj_mo2no[:nocc,:nocc]
prj_vv = prj_mo2no[nocc:,nocc:]
print(prj_oo.shape, prj_vv.shape)
# Keep only the rows belonging to the fragment's active orbitals.
# NOTE(review): prj_vv has been sliced to start at nocc, yet it is indexed
# with actvir, whose entries are absolute indices (>= nocc). The printed
# shapes are consistent, but verify whether `actvir - nocc` was intended.
prj_oo_act = prj_oo[actocc]
prj_vv_act = prj_vv[actvir]
print(prj_oo_act.shape, prj_vv_act.shape)

(5, 5) (5, 5)
(2, 5) (4, 5)


In [None]:
# Full-space CCSD amplitudes; kept under t1/t2 because a later cell restarts
# CCSD with `mcc.kernel(t1=t1, t2=t2)`.
t1 = mycc.t1
t2 = mycc.t2
# Project the full-space CISD-like amplitudes into the fragment's active
# space. The result is stored as ci1/ci2 — not t1 — so the full-space t1
# above is not clobbered and the next cell's `print(ci1.shape, ci2.shape)`
# refers to these arrays.
ci1 = np.einsum("ji,jb,ab->ia",prj_oo_act.T,full_ci1,prj_vv_act)
ci2 = np.einsum("ki,lj,kcld,ac,bd->iajb",prj_oo_act.T,prj_oo_act.T,full_ci2,prj_vv_act,prj_vv_act)
# Disabled: convert the projected CI amplitudes back to CC-style t1/t2.
# t1 = ci1
# t2 = ci2.transpose(0, 2, 1, 3)
# t2 -= np.einsum('ia,jb->ijab',t1,t1)

In [40]:
# Shapes of the fragment-projected amplitudes: (nactocc, nactvir) and
# (nactocc, nactvir, nactocc, nactvir) per the recorded output.
print(ci1.shape, ci2.shape)

(2, 4) (2, 4, 2, 4)


In [34]:
from pyscf.cc import CCSD
# CCSD using the fragment orbitals `no` as the orbital basis (no orbitals
# frozen), started from the amplitudes currently bound to t1/t2.
# NOTE(review): t1/t2 must be full-space (nocc x nvir) amplitudes here; the
# recorded output (nocc = 5, nmo = 10) matches mycc.t1 / mycc.t2.
mcc = CCSD(mf,mo_coeff=no)
mcc.kernel(t1=t1, t2=t2)


******** <class 'pyscf.cc.dfccsd.RCCSD'> ********
CC2 = 0
CCSD nocc = 5, nmo = 10
max_cycle = 50
direct = 0
conv_tol = 1e-07
conv_tol_normt = 1e-05
diis_space = 6
diis_start_cycle = 0
diis_start_energy_diff = 1e+09
max_memory 4000 MB (current use 275 MB)
Init E_corr(RCCSD) = -0.167338402005305
cycle = 1  E_corr(RCCSD) = -0.167338385254197  dE = 1.67511078e-08  norm(t1,t2) = 6.62485e-07
RCCSD converged
E(RCCSD) = -5.418106835209934  E_corr = -0.1673383852541969


(-0.16733838525419692,
 array([[ 1.39819405e-08, -6.06413416e-03, -2.89459364e-07,
         -5.66297684e-03, -1.60331255e-03],
        [-4.31046070e-04, -3.51704052e-07,  4.96727169e-03,
         -3.26094427e-07, -9.01739444e-08],
        [-4.67171529e-03,  5.66653549e-09,  5.43123467e-03,
          2.12349032e-08,  3.14604479e-08],
        [ 5.00514652e-09, -4.95759938e-03,  1.65353294e-08,
         -1.30702726e-02,  1.99504430e-03],
        [-1.87984015e-03,  4.93313271e-09,  8.81412535e-03,
          1.35642826e-08,  1.29457628e-08]]),
 array([[[[-3.86377641e-02,  1.88085543e-06,  1.95575389e-02,
            5.12149096e-07,  4.65328008e-06],
          [ 1.88085543e-06, -2.00035876e-02, -2.41744010e-06,
           -1.20063347e-02, -1.84896857e-02],
          [ 1.95575389e-02, -2.41744010e-06, -2.11011667e-02,
           -1.33848373e-06, -2.86591994e-06],
          [ 5.12149096e-07, -1.20063347e-02, -1.33848373e-06,
           -8.47579136e-03, -5.34276749e-03],
          [ 4.65328008e

In [19]:
from ad_afqmc import config, wavefunctions
from jax import numpy as jnp
from jax import random, jit, vmap, jvp, lax
from ad_afqmc.lno_ccsd import lno_ccsd
import jax

# `config` here is ad_afqmc.config (imported at the top of this cell), which
# shadows the jax `config` imported in the first cell.
config.setup_jax()
MPI = config.setup_comm()

# Load everything needed for the fragment AFQMC run. Note that this rebinds
# `options` (the recorded output shows the per-fragment seed, 365694).
ham_data, ham, prop, trial, wave_data, sampler, observable, options, _ = (
    lno_ccsd.prep_lnoccsd_afqmc(full_cisd=True))

comm = MPI.COMM_WORLD
rank = comm.Get_rank()  # Process rank
size = comm.Get_size() 

seed = options["seed"]
propagator = prop
init_walkers = None
# Build intermediates: first for an RHF wavefunction object, then for the
# actual trial, then for propagation. Each call returns the updated ham_data.
ham_data = wavefunctions.rhf(trial.norb, trial.nelec,n_batch=trial.n_batch
                                )._build_measurement_intermediates(ham_data, wave_data)
ham_data = ham.build_measurement_intermediates(ham_data, trial, wave_data)
ham_data = ham.build_propagation_intermediates(ham_data, propagator, trial, wave_data)

# Initialise the walker population and refuse to continue if the summed
# trial overlap is numerically zero.
prop_data = propagator.init_prop_data(trial, wave_data, ham_data, init_walkers)
if jnp.abs(jnp.sum(prop_data["overlaps"])) < 1.0e-6:
    raise ValueError(
        "Initial overlaps are zero. Pass walkers with non-zero overlap."
    )
# Rank-dependent PRNG key so each MPI rank gets a different stream.
prop_data["key"] = random.PRNGKey(seed + rank)

prop_data["overlaps"] = trial.calc_overlap(prop_data["walkers"], wave_data)
prop_data["n_killed_walkers"] = 0
prop_data["pop_control_ene_shift"] = prop_data["e_estimate"]

# First walker, used by later cells for overlap experiments.
init_walker = prop_data["walkers"][0]

# Hostname: yichi-thinkpad
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64
# Hostname: yichi-thinkpad
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64
# Number of MPI ranks: 1
#
# norb: 6
# nelec: (2, 2)
# nchol: 90
#
# n_eql: 4
# n_ene_blocks: 1
# n_sr_blocks: 20
# n_blocks: 20
# n_walkers: 5
# seed: 365694
# trial: cisd
# walker_type: rhf
# dt: 0.005
# ene0: 0
# n_exp_terms: 6
# n_prop_steps: 50
# orbital_rotation: True
# do_sr: True
# symmetry: False
# save_walkers: False
# free_projection: False
# n_batch: 1
# use_gpu: False
#


In [13]:
# Fragment (NO) and canonical (MO) coefficients plus the frozen / active
# index sets for the current fragment.
no_coeff = orbfrag
mo_coeff = mf.mo_coeff
nocc = mol.nelectron // 2
nao = mol.nao
print(nao)
# dtype=int keeps empty selections usable as index arrays: np.array([])
# defaults to float64, which numpy refuses for fancy indexing.
actfrag = np.array(
    [i for i in range(nao) if i not in frzfrag], dtype=int)
print(frzfrag)
print(actfrag)
frzocc = np.array(
    [i for i in range(nocc) if i in frzfrag], dtype=int)
actocc = np.array(
    [i for i in range(nocc) if i in actfrag], dtype=int)
print(frzocc)
print(actocc)
10
[0 1 2 9]
[3 4 5 6 7 8]
[0 1 2]
[3 4]


In [23]:
# MO^T S NO restricted to the frozen-occupied and active NO columns:
# each column expresses one fragment orbital in the canonical MO basis.
s1e = mf.get_ovlp()
p_frzocc = mo_coeff.T@s1e@no_coeff[:,frzocc]
p_act = mo_coeff.T@s1e@no_coeff[:,actfrag]
print(p_frzocc)
print(p_act)

[[ 3.51133733e-05 -5.95482343e-01  5.20188284e-01]
 [-9.03461486e-01 -5.27799098e-05  1.11429654e-06]
 [ 4.51715920e-05 -7.79664163e-01 -1.90451689e-01]
 [ 4.28669266e-01  2.47545116e-05  1.54452614e-06]
 [-1.11920902e-05  1.93712593e-01  8.32545676e-01]
 [-3.68144236e-16 -1.66533454e-16 -1.11022302e-16]
 [-1.94289029e-16 -8.05951849e-17  1.99509399e-16]
 [-1.08637058e-16  1.11022302e-16 -4.02455846e-16]
 [-1.11022302e-16 -5.16419047e-17 -1.62898809e-16]
 [ 1.57615891e-17  1.24900090e-16 -2.77555756e-16]]
[[-5.05048718e-07  6.12213138e-01  3.12525358e-17 -7.04889644e-17
   3.42468757e-16  3.90406953e-17]
 [ 4.28669267e-01 -1.12818678e-07  1.32899461e-16 -3.97558684e-16
   1.60908874e-16 -2.14810867e-16]
 [-4.54641866e-07 -5.96533272e-01 -1.07350339e-16 -2.36045603e-16
  -1.52175214e-16 -1.18410953e-16]
 [ 9.03461488e-01 -1.07528190e-06  3.83207963e-16  5.49564569e-17
  -6.70501768e-17  1.78677753e-17]
 [-2.03826829e-06 -5.18982782e-01 -1.85137037e-16  8.51262453e-17
   9.03317439e-17 -

In [29]:
# Express the active-space walker in the MO basis and prepend the frozen
# occupied columns, giving a full-space walker (shape (10, 5) in the output).
walker_act = p_act@init_walker
walker_new = jnp.hstack((p_frzocc,walker_act))
print(walker_new.shape)

(10, 5)


In [88]:
def no2mo(mo_coeff,s1e,no_coeff):
    """Return the overlap-weighted projector ``mo_coeff.T @ s1e @ no_coeff``.

    Rows of the result follow the columns of ``mo_coeff``; columns follow the
    columns of ``no_coeff``. ``s1e`` is the metric sandwiched between them.
    """
    return (mo_coeff.T @ s1e) @ no_coeff

def prj_walker(projector,walker,frzocc,actfrag):
    """Embed an active-space walker into the space spanned by ``projector``.

    The ``actfrag`` columns of ``projector`` are multiplied onto ``walker``,
    and the ``frzocc`` columns are stacked horizontally in front of that
    product.
    """
    frozen_cols = projector[:, frzocc]
    rotated_active = projector[:, actfrag] @ walker
    return jnp.hstack((frozen_cols, rotated_active))


In [89]:
def _cisd_overlap_factor(walker, ci1, ci2):
    """Shared core of the CISD overlap helpers.

    Builds GF = (walker @ inv(walker[:nocc])).T with nocc = walker.shape[1],
    takes its columns from nocc onward, and contracts them with the singles
    (``ci1``, indexed ia) and doubles (``ci2``, indexed iajb) amplitudes.

    Returns:
        (occ_block, factor), where occ_block = walker[:nocc, :] and
        factor = 1 + 2*o1 + o2.
    """
    nocc = walker.shape[1]
    occ_block = walker[:nocc, :]
    gf_ov = (walker.dot(jnp.linalg.inv(occ_block))).T[:, nocc:]
    o1 = jnp.einsum("ia,ia", ci1, gf_ov)
    o2 = 2 * jnp.einsum(
        "iajb, ia, jb", ci2, gf_ov, gf_ov
    ) - jnp.einsum("iajb, ib, ja", ci2, gf_ov, gf_ov)
    return occ_block, 1.0 + 2 * o1 + o2


def cisd_walker_overlap(walker,ci1,ci2):
    """Return (1 + 2*o1 + o2) * det(walker[:nocc])**2 for a single walker.

    Args:
        walker: (norb, nocc) matrix; its first nocc rows must be invertible.
        ci1: singles amplitudes, indexed (i, a).
        ci2: doubles amplitudes, indexed (i, a, j, b).
    """
    occ_block, factor = _cisd_overlap_factor(walker, ci1, ci2)
    o0 = jnp.linalg.det(occ_block) ** 2
    return factor * o0


def cisd_walker_overlap_ratio(walker,ci1,ci2):
    """Return 1 / (1 + 2*o1 + o2), i.e. the determinant prefactor
    det(walker[:nocc])**2 divided by ``cisd_walker_overlap``.
    """
    _, factor = _cisd_overlap_factor(walker, ci1, ci2)
    return 1/factor


def full_cisd_walker_overlap(walker,ci1,ci2):
    """Same formula as ``cisd_walker_overlap``; kept as a separate name for
    callers that pass full-space walkers and amplitudes.
    """
    return cisd_walker_overlap(walker, ci1, ci2)

In [21]:
# Same embedding as the notebook-local prj_walker, but via the library
# helper, which takes the two projector blocks directly (read from wave_data).
walker = prop_data["walkers"][0]
p_frzocc,p_act = wave_data["no2mo_frzocc"], wave_data["no2mo_act"]
walker_t = lno_ccsd.prj_walker(p_frzocc,p_act,walker)


In [None]:
def _frg_mod_ccsd_olp2(walker: jax.Array, wave_data: dict) -> complex:
    '''
    <psi_ccsd|walker>=<psi_0|walker>+C_ia^*G_ia+C_iajb^*(G_iaG_jb-G_ibG_ja)
    modified CCSD overlap returns the second and the third term
    that is, the overlap of the walker with the CCSD wavefunction
    without the hartree-fock part
    and skip one sum over the occ
    '''
    # prjlo = wave_data["prjlo"].reshape(walker.shape[1])
    m = jnp.dot(wave_data["prjlo"].T,wave_data["prjlo"])
    nocc, ci1, ci2 = walker.shape[1], wave_data["full_ci1"], wave_data["full_ci2"]
    gf = (walker.dot(jnp.linalg.inv(walker[: walker.shape[1], :]))).T
    o0 = jnp.linalg.det(walker[: nocc, :]) ** 2
    
    o1 = jnp.einsum("ia,ka,ik->", ci1, gf[:, nocc:],m)
    o2 = 2 * jnp.einsum("iajb,ka,jb,ik->", ci2, gf[:, nocc:], gf[:, nocc:],m) \
        - jnp.einsum("iajb,kb,ja,ik->", ci2, gf[:, nocc:], gf[:, nocc:],m)
    olp = (2*o1+o2)*o0
    # olp_i = jnp.einsum("i,i->", olp, pick_i)
    return olp

In [90]:
# Rebuild the full-space CISD-like amplitudes from the full CCSD t1/t2
# (same construction as in the LNO cell, here with jax arrays).
t1 = mycc.t1
t2 = mycc.t2
ci2 = t2 + jnp.einsum("ia,jb->ijab", jnp.array(t1), jnp.array(t1))
# (i, j, a, b) -> (i, a, j, b)
ci2 = ci2.transpose(0, 2, 1, 3)
ci1 = jnp.array(t1)
# NOTE(review): this rewrites amplitudes.npz with only these two arrays, so
# the no2mo_frzocc / no2mo_act entries added by the fragment loop are dropped
# — confirm nothing downstream re-reads them from this file.
np.savez("amplitudes.npz", full_ci1=ci1, full_ci2=ci2)

In [91]:
# Load the full-space amplitudes from disk and install them in wave_data.
# The arrays just read are the ones stored, so this cell no longer depends on
# `ci1` / `ci2` still being alive in the kernel; the context manager closes
# the .npz file handle once the arrays are copied out.
with np.load("amplitudes.npz") as amplitudes:
    full_ci1 = jnp.array(amplitudes["full_ci1"])
    full_ci2 = jnp.array(amplitudes["full_ci2"])
wave_data.update({"full_ci1": full_ci1, "full_ci2": full_ci2})

In [94]:
# Full-space CISD overlap of the first walker: embed the active-space walker
# with the MO^T S NO projector, then evaluate against the full amplitudes.
walker = prop_data["walkers"][0]
no2mo_prj = no2mo(mo_coeff, s1e, no_coeff)
walker_p = prj_walker(no2mo_prj,walker,frzocc,actfrag)
olp = cisd_walker_overlap(walker_p, ci1, ci2)
print(olp)

(0.5017825283861721-0.7473863577351859j)


In [95]:
# Advance the walkers by one sampler block; returns the updated prop_data
# together with a pair that is printed as (block energy, block weight).
prop_data, (blk_e, blk_wt)\
    = sampler._block_scan(prop_data,None,ham_data,prop,trial,wave_data)
print(blk_e,blk_wt)

-5.291230184312794 5.00300458169368


In [None]:
# Full-space overlap and overlap ratio for the current first walker.
# Relies on no2mo_prj, frzocc, actfrag, ci1, ci2 from earlier cells.
walker = prop_data["walkers"][0]
walker_p = prj_walker(no2mo_prj,walker,frzocc,actfrag)
olp = cisd_walker_overlap(walker_p, ci1, ci2)
olp_r = cisd_walker_overlap_ratio(walker_p, ci1, ci2)
print(olp)
print(olp_r)

(0.5017825283861721-0.7473863577351859j)
(0.9868800998319901-0.0031054239545766035j)


In [77]:
# Library overlap ratio on the un-embedded (active-space) walker, for
# comparison with olp_r from the notebook-local implementation.
lno_olp_r = lno_ccsd._calc_olp_ratio_restricted(walker,wave_data)
print(lno_olp_r)

(0.9894610408169681+0.0013870486435967869j)


In [61]:
# Squared determinant of the walker's top square block (first nocc rows,
# with nocc = walker.shape[1]); the o0 prefactor in the overlap functions.
jnp.linalg.det(walker[: walker.shape[1], :]) ** 2

Array(0.84799905-0.28465324j, dtype=complex128)

In [58]:
# Overlap of the same walker with the trial object, for comparison with the
# bare determinant factor above.
trial._calc_overlap_restricted(walker,wave_data)

Array(0.85262757-0.29056292j, dtype=complex128)

In [6]:
from ad_afqmc.lno_ccsd import lno_ccsd
# Fragment correlation energy evaluated on the initial walker, compared with
# the fragment LNO-CCSD correlation energy from the impurity solve.
walker = prop_data['walkers'][0]
# e_init = jnp.real(trial._calc_energy_restricted(walker,ham_data,wave_data))
# Returns three partial contributions and their total; the last positional
# argument (1e-6) is a numerical parameter of _frg_ccsd_cr — meaning not
# visible here, confirm in lno_ccsd.
ccsd_cr0,ccsd_cr1,ccsd_cr2,ccsd_cr = lno_ccsd._frg_ccsd_cr(walker,ham_data,wave_data,trial,1e-6)
# print('afqmc/ccsd init energy:',e_init)
# print('mean-field energy:',mf.e_tot)
# e_corr_init = e_init - mf.e_tot
print('lno-afqmc/ccsd init frg correlation energy:',ccsd_cr)
print('lno-ccsd init correlation energy:',ecorr_ccsd)
# Reference values recorded from an earlier run:
# -0.06902628862909294
# -0.06849834335942975

lno-afqmc/ccsd init frg correlation energy: -0.05381149472499322
lno-ccsd init correlation energy: -0.05381149431194333


In [8]:
# Run the phaseless propagation with per-orbital (fragment) estimators.
# Returns the updated prop_data plus: total energy and weight, the HF-based
# fragment correlation, the overlap ratio, three CCSD-correction components
# with their sum, and the combined fragment correlation energy.
prop_data,(energy,wt,
           hf_orb_cr,olp_ratio,
           ccsd_orb_cr0,ccsd_orb_cr1,ccsd_orb_cr2,
           ccsd_orb_cr,orb_cr)\
    = lno_ccsd.propagate_phaseless_orb(ham_data,prop,prop_data,trial,wave_data,sampler)

In [9]:
# Dump all estimators returned by propagate_phaseless_orb.
print(energy,wt)
print(hf_orb_cr,olp_ratio)
print(ccsd_orb_cr0,ccsd_orb_cr1,ccsd_orb_cr2,ccsd_orb_cr)
print(orb_cr)

-152.39566046710257 99.99068471987039
-0.04435036268421961 (0.9220578928295465-0.0011193421072858052j)
0.5468795913624183 -1.071288001043854 0.5143074233087206 -0.010100986372715313
-0.05445134905693493


In [10]:
# Consistency check: the three CCSD components should add up to ccsd_orb_cr,
# and adding hf_orb_cr should reproduce orb_cr (compare with the cell above).
print(ccsd_orb_cr0+ccsd_orb_cr1+ccsd_orb_cr2)
print(hf_orb_cr+ccsd_orb_cr0+ccsd_orb_cr1+ccsd_orb_cr2)

-0.01010098637271506
-0.05445134905693472


In [None]:
## build <full_ccsd|walker> ###