In [1]:
from functools import partial
from pyscf import gto, scf
import numpy as np

from jax import config
# Use 64-bit (double-precision) JAX types throughout.
config.update("jax_enable_x64", True)

# Flush every print immediately so progress appears while long cells run.
print = partial(print, flush=True)

# Linear H chain: nH hydrogens spaced `a` angstrom (= 2 bohr) apart along x.
nH = 4
a = 1.05835
atoms = "".join(f"H {i*a:.5f} 0.00000 0.00000 \n" for i in range(nH))

# gto.M constructs *and builds* the Mole, so no separate mol.build() is needed
# (calling it again only repeats the build and its verbose log).
mol = gto.M(atom=atoms, basis="sto6g", verbose=4)

# Restricted Hartree-Fock reference. Use scf.RHF(mol).density_fit() instead
# for density-fitted integrals.
mf = scf.RHF(mol)
e = mf.kernel()


System: uname_result(system='Linux', node='yichi-thinkpad', release='4.4.0-26100-Microsoft', version='#5074-Microsoft Fri Jan 01 08:00:00 PST 2016', machine='x86_64')  Threads 12
Python 3.10.16 | packaged by conda-forge | (main, Dec  5 2024, 14:16:10) [GCC 13.3.0]
numpy 1.24.3  scipy 1.14.1  h5py 3.12.1
Date: Mon Sep 29 19:22:16 2025
PySCF version 2.8.0
PySCF path  /home/yichi/research/software/lno_pyscf
GIT HEAD (branch master) ef75f4190e4de208685670651dc6c467f72b6794

[ENV] PYSCF_EXT_PATH /home/yichi/research/software/pyscf
[CONFIG] conf_file None
[INPUT] verbose = 4
[INPUT] num. atoms = 4
[INPUT] num. electrons = 4
[INPUT] charge = 0
[INPUT] spin (= nelec alpha-beta = 2S) = 0
[INPUT] symmetry False subgroup None
[INPUT] Mole.unit = angstrom
[INPUT] Symbol           X                Y                Z      unit          X                Y                Z       unit  Magmom
[INPUT]  1 H      0.000000000000   0.000000000000   0.000000000000 AA    0.000000000000   0.000000000000   0.00

In [2]:
from ad_afqmc.lno_afqmc import lno_maker, afqmc_maker, lnoafqmc_runner
# AFQMC run options (insertion order is kept; the run header echoes them).
options = dict(
    n_eql=4,
    n_prop_steps=10,
    n_ene_blocks=1,
    n_sr_blocks=10,
    n_blocks=10,
    n_walkers=3,
    seed=2,            # base seed; per-fragment seeds are drawn from it in the fragment loop
    walker_type='rhf',
    trial='cisd',      # 'cisd' -> fragment loop builds CI coefficients from LNO-CCSD amplitudes
    dt=0.005,
    ad_mode=None,
    use_gpu=False,
)

from ad_afqmc.lno.cc import LNOCCSD
from pyscf.lib import logger
import sys
from ad_afqmc.lno.base import lno
import numpy as np

In [3]:
from pyscf.cc.ccsd import CCSD
from pyscf.cc.uccsd import UCCSD

# ---- LNO fragment-run parameters ----
mf_cc = mf            # a mean-field object, or a CCSD/UCCSD object wrapping one
thresh = 1e-3         # scalar -> (occ, vir) = (10*thresh, thresh); or an [occ, vir] list
frozen = 0
run_frg_list = None   # None -> run every fragment
lo_type = 'boys'
no_type = 'ie'        # NOTE(review): original note lists 'cim' as an alternative
frag_lolist = '1o'

# Unwrap a CC object to the SCF object underneath it.
mf = mf_cc._scf if isinstance(mf_cc, (CCSD, UCCSD)) else mf_cc

# Split the LNO threshold into its occupied and virtual parts.
if isinstance(thresh, list):
    thresh_occ, thresh_vir = thresh
else:
    thresh_occ, thresh_vir = thresh * 10, thresh

lno_cc = LNOCCSD(mf, thresh=thresh, frozen=frozen)
lno_cc.thresh_occ = thresh_occ
lno_cc.thresh_vir = thresh_vir
lno_cc.lo_type = lo_type
lno_cc.no_type = no_type
lno_cc.frag_lolist = frag_lolist
lno_cc.force_outcore_ao2mo = True

frag_atmlist = lno_cc.frag_atmlist

# Localized orbitals; each one becomes its own single-LO fragment. This
# rebinds the notebook-level frag_lolist (the '1o' string was already handed
# to lno_cc above).
orbloc = lno_cc.get_lo(lo_type=lo_type)
frag_lolist = [[i] for i in range(orbloc.shape[1])]
nfrag = len(frag_lolist)
# Per-fragment [target nocc, target nvir] placeholders. A comprehension gives
# each fragment its own list (`[[None, None]] * nfrag` aliases a single list).
frag_nonvlist = [[None, None] for _ in range(nfrag)]

frozen_mask = lno_cc.get_frozen_mask()
thresh_pno = [thresh_occ, thresh_vir]
print(f'lno thresh {thresh_pno}')

if run_frg_list is None:
    run_frg_list = range(nfrag)

from jax import random

# One AFQMC seed per entry of run_frg_list, derived from the base options["seed"].
seeds = random.randint(random.PRNGKey(options["seed"]),
                    shape=(len(run_frg_list),), minval=0, maxval=100000*nfrag)
# Fragment index -> seed. Looking seeds up by position in run_frg_list stays
# correct when run_frg_list is a subset of fragments (JAX clamps out-of-range
# indices silently). `options` itself is never mutated, so re-running this
# cell reproduces the same seeds.
frag_seeds = {frg: int(seed) for frg, seed in zip(run_frg_list, seeds)}

THRESH_INTERNAL = 1e-10
CHOL_CUT = 1e-6  # Cholesky cutoff passed to prep_lnoafqmc_file

# Fragment-independent quantities, computed once outside the loop.
mol = mf.mol
nocc = mol.nelectron // 2
nao = mol.nao
s1e = mf.get_ovlp()

# NOTE(review): only fragment 0 is run here (debugging); use
# `for ifrag in run_frg_list:` for the full LNO sweep.
for ifrag in [0]:
    print(f'########### running fragment {ifrag+1} ##########')
    fraglo = frag_lolist[ifrag]
    orbfragloc = orbloc[:, fraglo]  # the localized orbital(s) defining this fragment
    frag_target_nocc, frag_target_nvir = frag_nonvlist[ifrag]
    frzfrag, orbfrag, can_orbfrag = lno_maker.make_lno(
        lno_cc, orbfragloc, THRESH_INTERNAL, thresh_pno)

    # dtype=int keeps these index arrays valid for fancy indexing even when empty.
    actfrag = np.array([i for i in range(nao) if i not in frzfrag], dtype=int)
    frzocc = np.array([i for i in range(nocc) if i in frzfrag], dtype=int)
    actocc = np.array([i for i in range(nocc) if i in actfrag], dtype=int)
    actvir = np.array([i for i in range(nocc, nao) if i in actfrag], dtype=int)
    print(f'# active orbitals: {actfrag}')
    print(f'# active occupied orbitals: {actocc}')
    print(f'# active virtual orbitals: {actvir}')
    print(f'# frozen orbitals: {frzfrag}')
    # Overlap <fragment LO | S | active occupied orbitals>.
    prjlo = orbfragloc.T @ s1e @ orbfrag[:, actocc]

    if options["trial"] == "cisd":
        ecorr_ccsd, t1, t2 = lno_maker.lno_cc_solver(
            mf, orbfrag, orbfragloc, frozen=frzfrag)
        print(f'# lno-ccsd correlation energy: {ecorr_ccsd}')
        # CI coefficients from CCSD amplitudes: c1 = t1 and
        # c2 = t2 + t1 (x) t1, reordered from (i, j, a, b) to (i, a, j, b).
        ci1 = np.array(t1)
        ci2 = t2 + np.einsum("ia,jb->ijab", ci1, ci1)
        ci2 = ci2.transpose(0, 2, 1, 3)
    else:
        ci1 = ci2 = None

    # Active-space sizes are needed by prep_lnoafqmc_file for every trial
    # type, so they are computed outside the CISD branch.
    nelec_act = len(actocc) * 2
    norb_act = len(actfrag)
    print(f'# number of active electrons: {nelec_act}')
    print(f'# number of active orbitals: {norb_act}')
    print(f'# number of frozen orbitals: {len(frzfrag)}')

    # Per-fragment copy of the options carrying this fragment's seed.
    frag_options = {**options, "seed": frag_seeds[ifrag]}
    lnoafqmc_runner.prep_lnoafqmc_file(
        mf, orbfrag, frag_options,
        norb_act=norb_act, nelec_act=nelec_act,
        prjlo=prjlo, norb_frozen=frzfrag,
        ci1=ci1, ci2=ci2, chol_cut=CHOL_CUT)
    

lo_type = boys


******** <class 'pyscf.lo.boys.Boys'> ********
conv_tol = 1e-06
conv_tol_grad = None
max_cycle = 100
max_stepsize = 0.01
max_iters = 20
kf_interval = 5
kf_trust_region = 5
ah_start_tol = 1000000000.0
ah_start_cycle = 1
ah_level_shift = 0
ah_conv_tol = 1e-12
ah_lindep = 1e-14
ah_max_cycle = 40
ah_trust_region = 3
init_guess = atomic
Set conv_tol_grad to 0.000316228
macro= 1  f(x)= 13.847644346723  delta_f= 13.8476  |g|= 0.0288916  1 KF 1 Hx
macro= 2  f(x)= 13.847644346723  delta_f= 0  |g|= 1.25198e-07  1 KF 1 Hx
macro X = 2  f(x)= 13.847644346723  |g|= 1.25198e-07  4 intor 2 KF 2 Hx
lno thresh [0.01, 0.001]
########### running fragment 1 ##########
Using true 4-index integrals
Using true 4-index integrals
# active orbitals: [1 2 3]
# active occupied orbitals: [1]
# active virtual orbitals: [2 3]
# frozen orbitals: [0]
Init t2, MP2 energy = -2.09908869027016  E_corr(MP2) -0.0103943439483701

******** <class 'pyscf.cc.ccsd.CCSD'> ********
CC2 = 0
CCSD nocc = 1, nmo = 3
fr

Overwritten attributes  ao2mo  of <class 'pyscf.cc.ccsd.CCSD'>


cycle = 4  E_corr(CCSD) = -0.0138399574068676  dE = 4.57204331e-06  norm(t1,t2) = 0.000251567
cycle = 5  E_corr(CCSD) = -0.0138399859926739  dE = -2.85858063e-08  norm(t1,t2) = 7.48975e-06
CCSD converged
E(CCSD) = -2.102534332314466  E_corr = -0.01383998599267393
# lno-ccsd correlation energy: -0.01383999265557911
# number of active electrons: 2
# number of active orbitals: 3
# number of frozen orbitals: 1
# Generating Cholesky Integrals
# frozen orbitals: [0]
# local active orbitals: [1 2 3]
# local active space size: 3
# Finished calculating Cholesky integrals

# Size of the correlation space
# Number of electrons: (1, 1)
# Number of basis functions: 3
# Number of Cholesky vectors: 6



In [5]:
from ad_afqmc import config, wavefunctions
from jax import numpy as jnp
# Load everything needed for the AFQMC run.
# NOTE(review): presumably reads the files written by prep_lnoafqmc_file in the
# fragment loop (the echoed seed matches the fragment seed) — confirm.
ham_data, ham, prop, trial, wave_data, sampler, observable, options, _ = (
    lnoafqmc_runner.prep_lnoafqmc_run())

if options["use_gpu"]:
    config.afqmc_config["use_gpu"] = True

# JAX / MPI setup. `config` here is ad_afqmc.config (imported in this cell),
# which shadows the jax config imported in the first cell.
config.setup_jax()
MPI = config.setup_comm()

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

seed = options["seed"]
propagator = prop
init_walkers = None
# Build RHF measurement intermediates into ham_data first, then the trial's
# own measurement and propagation intermediates on top of them.
# NOTE(review): presumably the RHF intermediates (e.g. rot_h1 / rot_chol) are
# what the HF-orbital energy helpers below read — confirm.
ham_data = wavefunctions.rhf(trial.norb, trial.nelec,n_batch=trial.n_batch
                                )._build_measurement_intermediates(ham_data, wave_data)
ham_data = ham.build_measurement_intermediates(ham_data, trial, wave_data)
ham_data = ham.build_propagation_intermediates(ham_data, propagator, trial, wave_data)

# Initialize propagation data (no explicit walkers passed) and abort if the
# summed walker/trial overlap vanishes.
prop_data = propagator.init_prop_data(trial, wave_data, ham_data, init_walkers)
if jnp.abs(jnp.sum(prop_data["overlaps"])) < 1.0e-6:
    raise ValueError(
        "Initial overlaps are zero. Pass walkers with non-zero overlap."
    )
# Per-rank PRNG key. NOTE: `random` is jax.random, imported in the fragment-loop
# cell, not in this one.
prop_data["key"] = random.PRNGKey(seed + rank)

# Refresh overlaps and reset bookkeeping; the population-control energy shift
# starts at the current energy estimate.
prop_data["overlaps"] = trial.calc_overlap(prop_data["walkers"], wave_data)
prop_data["n_killed_walkers"] = 0
prop_data["pop_control_ene_shift"] = prop_data["e_estimate"]

# Hostname: yichi-thinkpad
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64
# Number of MPI ranks: 1
#
# norb: 3
# nelec: (1, 1)
# nchol: 6
#
# n_eql: 4
# n_prop_steps: 10
# n_ene_blocks: 1
# n_sr_blocks: 10
# n_blocks: 10
# n_walkers: 3
# seed: 70495
# walker_type: rhf
# trial: cisd
# dt: 0.005
# use_gpu: False
# n_exp_terms: 6
# orbital_rotation: True
# do_sr: True
# symmetry: False
# save_walkers: False
# free_projection: False
# n_batch: 1
# ene0: 0
#
# Hostname: yichi-thinkpad
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64


In [8]:
# Propagator's initial energy estimate next to the HF + LNO-CCSD total energy,
# one value per line.
print(prop_data["e_estimate"], mf.e_tot + ecorr_ccsd, sep="\n")

-2.1025343315221208
-2.10253433897737


In [9]:
# Run one block with afqmc_maker.block_orb. The result names mirror the local
# block_orb copy defined at the end of this notebook.
prop_data, (blk_en, blk_wt, blk_hf_orb_en, blk_cc_orb_en) = afqmc_maker.block_orb(
    prop_data, ham_data, prop, trial, wave_data, sampler,
)

In [10]:
# Block-averaged energy returned by the block_orb call above.
print(blk_en)

-2.102534329869551


In [13]:
# Display the first walker's coefficient array.
prop_data['walkers'][0]

Array([[-0.95344139-0.25599746j],
       [-0.0452224 +0.14403511j],
       [-0.01775788+0.04804474j]], dtype=complex128)

In [None]:
# Restricted overlap-ratio helper evaluated on the first walker (the same
# quantity _frg_hf_cr multiplies into its HF-orbital energy).
afqmc_maker._calc_olp_ratio_restricted(prop_data['walkers'][0],wave_data)

Array(0.99870124+0.00065381j, dtype=complex128)

In [18]:
from jax import vmap
def _frg_hf_cr(rot_h1, rot_chol, walker, trial, wave_data):
    """Fragment HF-orbital energy of one walker, bare and times its overlap ratio.

    With ``G_vo = green.T[nocc:, :nocc]`` taken from the walker's RHF Green's
    function, ``f[g] = rot_chol[g, :nocc, nocc:] @ G_vo`` and the fragment
    matrix ``m = prjlo.T @ prjlo``, the energy is

        hf_orb_en = 2 * sum_G tr(f_G) * sum_xk f_G[x, k] m[x, k]
                    - sum_G sum_xyk f_G[x, y] f_G[y, k] m[x, k]

    Args:
        rot_h1: only its leading dimension is used, as ``nocc``.
        rot_chol: rotated Cholesky vectors, sliced as ``[:, :nocc, nocc:]``.
        walker: a single walker.
        trial: supplies ``norb``, ``nelec`` and ``n_batch`` for an RHF helper.
        wave_data: must contain ``"prjlo"``; also passed to the Green's
            function and overlap-ratio helpers.

    Returns:
        ``(olp_ratio * hf_orb_en, hf_orb_en)`` where ``olp_ratio`` comes from
        ``afqmc_maker._calc_olp_ratio_restricted(walker, wave_data)``.
    """
    nocc = rot_h1.shape[0]
    m_lo = jnp.dot(wave_data["prjlo"].T, wave_data["prjlo"])

    rhf_helper = wavefunctions.rhf(trial.norb, trial.nelec, n_batch=trial.n_batch)
    green_vo = rhf_helper._calc_green(walker, wave_data).T[nocc:, :nocc]

    f_mat = jnp.einsum('gij,jk->gik', rot_chol[:, :nocc, nocc:], green_vo,
                       optimize='optimal')
    f_traces = vmap(jnp.trace)(f_mat)

    # Direct (J-like) and exchange-like contractions with the fragment matrix.
    ene_direct = jnp.einsum('Gxk,xk,G->', f_mat, m_lo, f_traces) * 2
    ene_exchange = jnp.einsum('Gxy,Gyk,xk->', f_mat, f_mat, m_lo)
    hf_orb_en = ene_direct - ene_exchange

    olp_ratio = afqmc_maker._calc_olp_ratio_restricted(walker, wave_data)
    return olp_ratio * hf_orb_en, hf_orb_en

In [None]:
# Evaluate the HF-orbital energy helper on the first walker only.
hf_orb_cr, hf_orb_en = _frg_hf_cr(
    rot_h1=ham_data['rot_h1'],
    rot_chol=ham_data['rot_chol'],
    walker=prop_data['walkers'][0],
    trial=trial,
    wave_data=wave_data,
)

In [21]:
# Real parts, then magnitudes, of (overlap-weighted, bare) HF-orbital energies.
print(*map(jnp.real, (hf_orb_cr, hf_orb_en)))
print(*map(jnp.abs, (hf_orb_cr, hf_orb_en)))

-0.00257410273107042 -0.002577611106114524
0.0025859362497955727 0.002589298583033388


In [None]:
def block_orb(prop_data: dict,
              ham_data: dict,
              prop: "propagation.propagator",
              trial: "wavefunctions",
              wave_data: dict,
              sampler: "sampling.sampler"):
    """Run one propagation block, then measure total and fragment-orbital energies.

    Draws ``sampler.n_prop_steps`` sets of standard-normal fields, scans
    ``sampler._step_scan`` over them, orthonormalizes the walkers, refreshes
    their overlaps and forms weighted block averages.

    Annotations are strings because ``propagation`` and ``sampling`` are not
    imported in this notebook; unquoted, the ``def`` itself raised NameError.

    Args:
        prop_data: propagation state; reads/writes "key", "walkers", "weights",
            "overlaps", "n_killed_walkers", "e_estimate", "pop_control_ene_shift".
        ham_data: Hamiltonian data; ``ham_data["chol"].shape[0]`` fixes the
            number of fields per walker per step.
        prop: propagator (``n_walkers``, ``dt``, ``orthonormalize_walkers``).
        trial: trial wavefunction (``calc_overlap``, ``calc_energy``).
        wave_data: trial wavefunction data.
        sampler: sampler (``n_prop_steps``, ``_step_scan``).

    Returns:
        ``(prop_data, (blk_en, blk_wt, blk_hf_orb_en, blk_cc_orb_en))``:
        weighted mean local energy, total weight, weighted mean HF-orbital
        energy, and the sum of the weighted mean HF- and CI-orbital
        ``*_cr`` terms.

    NOTE(review): ``frg_hf_cr`` and ``frg_ci_cr`` are not defined in this
    notebook (only ``_frg_hf_cr``, with a different signature, is); presumably
    they live in ``afqmc_maker`` — confirm before calling this copy.
    """
    from jax import lax, random, numpy as jnp  # lax was never imported here

    prop_data["key"], subkey = random.split(prop_data["key"])
    fields = random.normal(
        subkey,
        shape=(
            sampler.n_prop_steps,
            prop.n_walkers,
            ham_data["chol"].shape[0],
        ),
    )

    # Propagate n_prop_steps x dt.
    def _step_scan_wrapper(carry, field):
        return sampler._step_scan(carry, field, ham_data, prop, trial, wave_data)

    prop_data, _ = lax.scan(_step_scan_wrapper, prop_data, fields)
    prop_data["n_killed_walkers"] += prop_data["weights"].size - jnp.count_nonzero(
        prop_data["weights"]
    )
    prop_data = prop.orthonormalize_walkers(prop_data)
    prop_data["overlaps"] = trial.calc_overlap(prop_data["walkers"], wave_data)

    hf_orb_cr, hf_orb_en = frg_hf_cr(prop_data["walkers"], ham_data, wave_data, trial)
    # NOTE(review): meaning of the 1e-5 argument is not visible here.
    cc_orb_cr = frg_ci_cr(prop_data["walkers"], ham_data, wave_data, trial, 1e-5)

    energy_samples = jnp.real(
        trial.calc_energy(prop_data["walkers"], ham_data, wave_data)
    )
    # Samples further than sqrt(2/dt) from the running estimate are replaced by it.
    energy_samples = jnp.where(
        jnp.abs(energy_samples - prop_data["e_estimate"]) > jnp.sqrt(2.0 / prop.dt),
        prop_data["e_estimate"],
        energy_samples,
    )

    weights = prop_data["weights"]
    blk_wt = jnp.sum(weights)
    blk_en = jnp.sum(energy_samples * weights) / blk_wt
    blk_hf_orb_en = jnp.sum(hf_orb_en * weights) / blk_wt
    blk_hf_orb_cr = jnp.sum(hf_orb_cr * weights) / blk_wt
    blk_cc_orb_cr = jnp.sum(cc_orb_cr * weights) / blk_wt
    blk_cc_orb_en = blk_hf_orb_cr + blk_cc_orb_cr

    # Exponential moving average for the population-control energy shift.
    prop_data["pop_control_ene_shift"] = (
        0.9 * prop_data["pop_control_ene_shift"] + 0.1 * blk_en
    )

    return prop_data, (blk_en, blk_wt, blk_hf_orb_en, blk_cc_orb_en)