In [1]:
# Mean-field reference: density-fitted RHF/cc-pVDZ for the S22 water dimer.
# The resulting `mol` / `mf` objects are consumed by the LNO cells below.
import builtins
from functools import partial
from pyscf import gto, scf

from jax import config
# Enable float64 in JAX; set at startup, before any JAX arrays are created.
config.update("jax_enable_x64", True)

# Flush every print immediately so progress from long-running cells shows up live.
# Wrap `builtins.print` explicitly (not whatever name `print` is currently bound
# to in the kernel) so the result of this cell does not depend on kernel state.
print = partial(builtins.print, flush=True)

# Alternative system: linear H chain with spacing `a` (uncomment to use).
# a = 1
# nH = 4
# atoms = ""
# for i in range(nH):
#     atoms += f"H {i*a} 0 0 \n"

# water dimer s22
atoms = '''
    O  -1.551007  -0.114520   0.000000
    H  -1.934259   0.762503   0.000000
    H  -0.599677   0.040712   0.000000
    O   1.350625   0.111469   0.000000
    H   1.680398  -0.373741  -0.758561
    H   1.680398  -0.373741   0.758561
    '''

mol = gto.M(atom=atoms, basis="ccpvdz", verbose=4)
mf = scf.RHF(mol).density_fit()
mf.kernel()

# Optional references (disabled): non-DF RHF `mf2` and canonical CCSD / CCSD(T).
# Uncommenting the CC part also requires `from pyscf import cc`.
# mf2 = scf.RHF(mol)#.density_fit()
# mf2.kernel()

# cc
# mycc = cc.CCSD(mf)
# mycc.kernel()
# et = mycc.ccsd_t()

# mycc = cc.CCSD(mf2)
# # mycc.kernel()

# print(f"ccsd energy is {mycc.e_tot}")
# print(f"ccsd_t energy is {mycc.e_tot+et}")

System: uname_result(system='Linux', node='yichi-thinkpad', release='4.4.0-26100-Microsoft', version='#1882-Microsoft Fri Jan 01 08:00:00 PST 2016', machine='x86_64')  Threads 12
Python 3.10.16 | packaged by conda-forge | (main, Dec  5 2024, 14:16:10) [GCC 13.3.0]
numpy 1.24.3  scipy 1.14.1  h5py 3.12.1
Date: Fri Jul 18 16:37:00 2025
PySCF version 2.8.0
PySCF path  /home/yichi/research/software/lno_pyscf
GIT HEAD (branch master) ef75f4190e4de208685670651dc6c467f72b6794

[ENV] PYSCF_EXT_PATH /home/yichi/research/software/pyscf
[CONFIG] conf_file None
[INPUT] verbose = 4
[INPUT] num. atoms = 6
[INPUT] num. electrons = 20
[INPUT] charge = 0
[INPUT] spin (= nelec alpha-beta = 2S) = 0
[INPUT] symmetry False subgroup None
[INPUT] Mole.unit = angstrom
[INPUT] Symbol           X                Y                Z      unit          X                Y                Z       unit  Magmom
[INPUT]  1 O     -1.551007000000  -0.114520000000   0.000000000000 AA   -2.930978447283  -0.216411435785   0.0

-152.06249064692815

In [4]:
# LNO-CCSD setup: localize the active occupied orbitals, make one fragment per
# localized orbital (LO), then for each selected fragment solve the impurity
# CCSD problem and write the AFQMC input files via prep_lno_amp_chol_file.
from ad_afqmc.lno.cc import LNOCCSD
from pyscf.lib import logger
import sys
from ad_afqmc.lno.base import lno
from ad_afqmc.lno_ccsd import lno_ccsd
import numpy as np
from functools import reduce
from jax import random
log = logger.Logger(sys.stdout, 6)

# AFQMC options; "seed" is overwritten with a per-fragment seed in the loop below.
options = {
    "n_eql": 4,
    "n_ene_blocks": 1,
    "n_sr_blocks": 20,
    "n_blocks": 20,
    "n_walkers": 5,
    "seed": 98,
    "trial": "cisd",
    "walker_type": "rhf",
    "dt":0.005,
    "ene0": 0,
}

thresh = 1e-6      # float -> (occ, vir) = (10*thresh, thresh); or pass [occ, vir]
chol_cut = 1e-6    # Cholesky cutoff forwarded to prep_lno_amp_chol_file
frozen = 2
# Fragment indices to run (0-based); None runs every fragment.
# Currently only fragment 2 (index 1) is run.
run_frg_list = [1]
lo_type = 'pm'
no_type = 'ie' # cim
frag_lolist = '1o'
THRESH_INTERNAL = 1e-10  # internal threshold passed to lno.make_fpno1


def fdot(*args):
    """Chained matrix product: fdot(a, b, c) == np.dot(np.dot(a, b), c)."""
    return reduce(np.dot, args)


# Validate no_type before it is used: two letters, each one of 'r', 'e', 'i'.
if not (len(no_type) == 2 and no_type[0] in 'rei' and no_type[1] in 'rei'):
    log.warn('Input no_type "%s" is invalid.', no_type)
    raise ValueError(f'Input no_type "{no_type}" is invalid.')

if isinstance(thresh, list):
    thresh_occ, thresh_vir = thresh
else:
    thresh_occ = thresh*10
    thresh_vir = thresh

lno_cc = LNOCCSD(mf, thresh=thresh, frozen=frozen)
lno_cc.thresh_occ = thresh_occ
lno_cc.thresh_vir = thresh_vir
lno_cc.lo_type = lo_type
lno_cc.no_type = no_type
lno_cc.frag_lolist = frag_lolist
# lno_cc.ccsd_t = True
lno_cc.force_outcore_ao2mo = True

frag_atmlist = lno_cc.frag_atmlist
s1e = lno_cc._scf.get_ovlp()


log.info('no_type = %s', no_type)

# LO construction
orbloc = lno_cc.get_lo(lo_type=lo_type) # localized active occ orbitals
orbactocc = lno_cc.split_mo()[1] # non-localized active occ
m = fdot(orbloc.T, s1e, orbactocc)
# Span(occ) must be contained in Span(LO).
lospanerr = abs(fdot(m.T, m) - np.eye(m.shape[1])).max()
if lospanerr > 1e-10:
    log.error('LOs do not fully span the occupied space! '
              'Max|<occ|LO><LO|occ>| = %e', lospanerr)
    raise RuntimeError('LOs do not fully span the occupied space '
                       f'(max error {lospanerr:e}).')

# check Span(LO) == Span(occ)
occspanerr = abs(fdot(m, m.T) - np.eye(m.shape[0])).max()
if occspanerr < 1e-10:
    log.info('LOs span exactly the occupied space.')
    if no_type not in ['ir','ie']:
        log.error('"no_type" must be "ir" or "ie".')
        raise ValueError('"no_type" must be "ir" or "ie".')
else:
    log.info('LOs span occupied space plus some virtual space.')

# LO assignment to fragments: with '1o' every active localized occupied orbital
# labels its own fragment.
if frag_lolist == '1o':
    log.info('Using single-LO fragment')
    frag_lolist = [[i] for i in range(orbloc.shape[1])]
elif isinstance(frag_lolist, str):
    # Any other string would otherwise be measured/indexed character by character.
    raise ValueError(f'Unsupported frag_lolist "{frag_lolist}"; '
                     'only "1o" (single-LO fragments) is supported.')
else:
    print('Only support single LO fragment!')
nfrag = len(frag_lolist)
frag_nonvlist = lno_cc.frag_nonvlist

# dump info
log.info('nfrag = %d  nlo = %d', nfrag, orbloc.shape[1])
log.info('frag_atmlist = %s', frag_atmlist)
log.info('frag_lolist = %s', frag_lolist)
log.info('frag_nonvlist = %s', frag_nonvlist)

if frag_nonvlist is None:
    frag_nonvlist = [[None,None]] * nfrag

eris = lno_cc.ao2mo()
frozen_mask = lno_cc.get_frozen_mask()
thresh_pno = [thresh_occ,thresh_vir]
print(f'lno thresh {thresh_pno}')

if run_frg_list is None:
    run_frg_list = range(nfrag)

# One seed per fragment, indexed by fragment id, so a fragment's seed does not
# depend on which subset of fragments is being run.
seeds = random.randint(random.PRNGKey(options["seed"]),
                       shape=(nfrag,), minval=0, maxval=100000*nfrag)

for ifrag in run_frg_list:
    print(f'########### running fragment {ifrag+1} ##########')
    fraglo = frag_lolist[ifrag]
    orbfragloc = orbloc[:,fraglo] # the localized active occ orbital(s) of this fragment
    frag_target_nocc, frag_target_nvir = frag_nonvlist[ifrag]
    frzfrag, orbfrag, can_orbfrag = lno.make_fpno1(lno_cc, eris, orbfragloc, no_type,
                                                THRESH_INTERNAL, thresh_pno,
                                                frozen_mask=frozen_mask,
                                                frag_target_nocc=frag_target_nocc,
                                                frag_target_nvir=frag_target_nvir,
                                                canonicalize=False)

    (ecorr_ccsd, t1, t2, prjlo,
     nactocc, nactvir, maskact, maskocc) = lno_ccsd.cc_impurity_solve(
        mf, orbfrag, orbfragloc, frozen=frzfrag, eris=eris, log=log)

    nelec_act = nactocc*2
    norb_act = nactocc+nactvir
    print(f'# lno-ccsd correlation energy: {ecorr_ccsd}')
    print(f'# number of active electrons: {nelec_act}')
    print(f'# number of active orbitals: {norb_act}')
    print(f'# number of frozen orbitals: {len(frzfrag)}')

    # Plain Python int rather than a 0-d JAX array, so `options` stays portable.
    options["seed"] = int(seeds[ifrag])
    lno_ccsd.prep_lno_amp_chol_file(
        mf,orbfrag,options,
        norb_act=norb_act,nelec_act=nelec_act,
        prjlo=prjlo,norb_frozen=frzfrag,
        t1=t1,t2=t2,chol_cut=chol_cut
                )
    

no_type = ie
lo_type = pm


LOs span exactly the occupied space.
Using single-LO fragment
nfrag = 8  nlo = 8
frag_atmlist = None
frag_lolist = [[0], [1], [2], [3], [4], [5], [6], [7]]
frag_nonvlist = None
Lov is saved to /tmp/xe2f97gk
lno thresh [9.999999999999999e-06, 1e-06]
########### running fragment 2 ##########

WARN: CCSD detected DF being used in the HF object. MO integrals are computed based on the DF 3-index tensors.
It's recommended to use dfccsd.CCSD for the DF-CCSD calculations

Init t2, MP2 energy = -152.367586940767  E_corr(MP2) -0.305096293838968

******** <class 'pyscf.cc.dfccsd.RCCSD'> ********
CC2 = 0
CCSD nocc = 7, nmo = 35
frozen orbitals [ 0  1  2 38 39 40 41 42 43 44 45 46 47]
max_cycle = 50
direct = 0
conv_tol = 1e-07
conv_tol_normt = 1e-05
diis_space = 6
diis_start_cycle = 0
diis_start_energy_diff = 1e+09
max_memory 4000 MB (current use 417 MB)
Init E_corr(RCCSD) = -0.305096293839516
cycle = 1  E_corr(RCCSD) = -0.307999846367845  dE = -0.00290355253  norm(t1,t2) = 0.0410041
cycle = 2  E_c

In [5]:
# Set up the AFQMC objects from the fragment inputs and initialize the walkers.
# NOTE(review): presumably prep_lnoccsd_afqmc() reads the files written by
# prep_lno_amp_chol_file in the LNO cell above — confirm.
# ad_afqmc's config is aliased so it does not shadow `jax.config` from the first cell.
from ad_afqmc import config as afqmc_config, wavefunctions
from jax import numpy as jnp
from jax import random, jit, vmap, jvp, lax
import jax

afqmc_config.setup_jax()
MPI = afqmc_config.setup_comm()

ham_data, ham, prop, trial, wave_data, sampler, observable, options, _ = (
    lno_ccsd.prep_lnoccsd_afqmc())

comm = MPI.COMM_WORLD
rank = comm.Get_rank()  # Process rank
size = comm.Get_size()

seed = options["seed"]
propagator = prop
init_walkers = None
# RHF measurement intermediates are built first, then the trial's own
# measurement and propagation intermediates are added on top.
ham_data = wavefunctions.rhf(trial.norb, trial.nelec,n_batch=trial.n_batch
                                )._build_measurement_intermediates(ham_data, wave_data)
ham_data = ham.build_measurement_intermediates(ham_data, trial, wave_data)
ham_data = ham.build_propagation_intermediates(ham_data, propagator, trial, wave_data)

prop_data = propagator.init_prop_data(trial, wave_data, ham_data, init_walkers)
if jnp.abs(jnp.sum(prop_data["overlaps"])) < 1.0e-6:
    raise ValueError(
        "Initial overlaps are zero. Pass walkers with non-zero overlap."
    )
# Distinct random stream per MPI rank.
prop_data["key"] = random.PRNGKey(seed + rank)

# NOTE(review): overlaps are recomputed with the trial after the zero-overlap
# check above; confirm the check is meant to use init_prop_data's overlaps.
prop_data["overlaps"] = trial.calc_overlap(prop_data["walkers"], wave_data)
prop_data["n_killed_walkers"] = 0
prop_data["pop_control_ene_shift"] = prop_data["e_estimate"]

# Hostname: yichi-thinkpad
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64
# Hostname: yichi-thinkpad
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64
# Number of MPI ranks: 1
#
# norb: 35
# nelec: (7, 7)
# nchol: 232
#
# n_eql: 4
# n_ene_blocks: 1
# n_sr_blocks: 20
# n_blocks: 20
# n_walkers: 5
# seed: 551475
# trial: cisd
# walker_type: rhf
# dt: 0.005
# ene0: 0
# n_exp_terms: 6
# n_prop_steps: 50
# orbital_rotation: True
# do_sr: True
# symmetry: False
# save_walkers: False
# free_projection: False
# n_batch: 1
# use_gpu: False
#


In [6]:
# Compare the fragment CCSD correlation energy evaluated on the first initial
# walker with the LNO-CCSD fragment energy returned by cc_impurity_solve.
from ad_afqmc.lno_ccsd import lno_ccsd
walker = prop_data['walkers'][0]
# NOTE(review): the meaning of the last argument (1e-6) of _frg_ccsd_cr is not
# visible in this notebook — confirm against ad_afqmc.lno_ccsd.
ccsd_cr0, ccsd_cr1, ccsd_cr2, ccsd_cr = lno_ccsd._frg_ccsd_cr(
    walker, ham_data, wave_data, trial, 1e-6)
print('lno-afqmc/ccsd init frg correlation energy:',ccsd_cr)
print('lno-ccsd init correlation energy:',ecorr_ccsd)

lno-afqmc/ccsd init frg correlation energy: -0.05381149472499322
lno-ccsd init correlation energy: -0.05381149431194333


In [8]:
# Run lno_ccsd.propagate_phaseless_orb once: it returns the updated prop_data
# plus a tuple of energy / weight / orbital correlation-energy estimators.
prop_data, (
    energy, wt,
    hf_orb_cr, olp_ratio,
    ccsd_orb_cr0, ccsd_orb_cr1, ccsd_orb_cr2,
    ccsd_orb_cr, orb_cr,
) = lno_ccsd.propagate_phaseless_orb(
    ham_data, prop, prop_data, trial, wave_data, sampler
)

In [9]:
# Show the propagation outputs, one related group per line.
for _group in ((energy, wt),
               (hf_orb_cr, olp_ratio),
               (ccsd_orb_cr0, ccsd_orb_cr1, ccsd_orb_cr2, ccsd_orb_cr),
               (orb_cr,)):
    print(*_group)

-152.39566046710257 99.99068471987039
-0.04435036268421961 (0.9220578928295465-0.0011193421072858052j)
0.5468795913624183 -1.071288001043854 0.5143074233087206 -0.010100986372715313
-0.05445134905693493


In [10]:
# Consistency check of the estimator decomposition:
#   ccsd_orb_cr == ccsd_orb_cr0 + ccsd_orb_cr1 + ccsd_orb_cr2
#   orb_cr      == hf_orb_cr + ccsd_orb_cr
# Values are still printed, but mismatches now fail loudly instead of being
# left for the reader to compare by eye.
import numpy as np

print(ccsd_orb_cr0+ccsd_orb_cr1+ccsd_orb_cr2)
print(hf_orb_cr+ccsd_orb_cr0+ccsd_orb_cr1+ccsd_orb_cr2)
np.testing.assert_allclose(ccsd_orb_cr0 + ccsd_orb_cr1 + ccsd_orb_cr2,
                           ccsd_orb_cr, rtol=0, atol=1e-10)
np.testing.assert_allclose(hf_orb_cr + ccsd_orb_cr0 + ccsd_orb_cr1 + ccsd_orb_cr2,
                           orb_cr, rtol=0, atol=1e-10)

-0.01010098637271506
-0.05445134905693472


In [134]:
# Evaluate _tot_* (presumably whole-system, as opposed to the _frg_* fragment
# version used earlier — inferred from the names, confirm) HF and CCSD
# correlation contributions on the first walker.
# NOTE(review): `_tot_ccsd_cr` is not defined in any visible cell, so on a fresh
# kernel this raises NameError; possibly `lno_ccsd._tot_ccsd_cr` was meant — confirm.
# NOTE(review): this cell overwrites ccsd_cr0..ccsd_cr from the earlier check, and its
# execution count (134) shows it was run out of order relative to its neighbours.
walker = prop_data['walkers'][0]
hf_cr = lno_ccsd._tot_hf_cr(ham_data['rot_h1'],ham_data['rot_chol'],walker,trial,wave_data)
ccsd_cr0,ccsd_cr1,ccsd_cr2,ccsd_cr = _tot_ccsd_cr(walker,ham_data,wave_data,trial,1e-5)
print(hf_cr)
print(ccsd_cr0,ccsd_cr1,ccsd_cr2,ccsd_cr)
print(hf_cr+ccsd_cr)

4.405561946234566
-0.009747164169830328
0.030705359520143017 -0.051134426375230684 -0.038552745120460793 -0.05898181197554846
-0.06872897614537879


In [131]:
# Energy components for the current `walker` from _afqmc_ccsd_corr; prints the
# mean-field pieces, the correlation pieces, and mean-field + correlation.
# NOTE(review): `_afqmc_ccsd_corr` is not defined in any visible cell, so this
# raises NameError on a fresh kernel — confirm where it should be imported from.
# NOTE(review): the recorded e_mf (~ -2.12) does not match the water-dimer RHF
# energy (~ -152.06) from the first cell, so the saved output is from another system.
e0, e1_0, e2_0, e_mf, ecorr1, ecorr2, ecorr \
    = _afqmc_ccsd_corr(walker,ham_data,wave_data,trial)
print(e0, e1_0, e2_0, e_mf)
print(ecorr1, ecorr2, ecorr)
print(ecorr+e_mf)

2.29310124732 -6.788713518225568 2.373358607505927 -2.1222536633996407
-0.03980199284071267 -0.05042937664142065 -0.09023136948213331
-2.2124850328817742


In [135]:
# Print the non-DF RHF reference energy.
# NOTE(review): `mf2` is only created in commented-out code in the first cell
# (`mf2 = scf.RHF(mol)`), so this raises NameError on a fresh kernel; the saved
# value (~ -2.1125) also does not match the water-dimer system defined there.
print(mf2.e_tot)

-2.112460698914566


In [26]:
# AFQMC correlation energy relative to the non-DF mean field, using totals
# recorded from an AFQMC run (these numbers are not computed in this notebook):
#   AFQMC energy: -2.18098 +/- 0.00002
#   lno tot_afqmc_cr -0.068465 +/- 0.000043   (LNO estimate to compare against)
# NOTE(review): these values do not correspond to the water dimer in the first
# cell (DF-RHF ~ -152.06); they appear to come from a different system.
E_AFQMC_TOTAL = -2.18098
E_MF_NONDF = -2.112460698914565  # non-df mean-field energy
E_CORR_AFQMC_NONDF = E_AFQMC_TOTAL - E_MF_NONDF
E_CORR_AFQMC_NONDF

-0.06851930108543502