In [2]:
from functools import partial
from pyscf import gto, scf, cc, ci
import numpy as np

from jax import config
config.update("jax_enable_x64", True)

# Re-bind print so every call flushes stdout immediately.
print = partial(print, flush=True)

# Linear chain of nH hydrogen atoms along x. `a` is the H-H spacing in the
# Mole's length unit, which is Angstrom here (see "Mole.unit = angstrom" in the
# build log), i.e. about 1.89 Bohr.
a = 1
nH = 6
atoms = "".join(f"H {i*a:.5f} 0.00000 0.00000 \n" for i in range(nH))

# gto.M constructs the Mole and already calls Mole.build(); a second explicit
# mol.build() was redundant.
mol = gto.M(atom=atoms, basis="sto6g", verbose=4)

# Restricted Hartree-Fock reference.
mf = scf.RHF(mol)
e = mf.kernel()

# Full-system CISD on top of the RHF reference. CISD.kernel returns
# (correlation energy, CI vector); note this overwrites the RHF energy in `e`.
myci = ci.CISD(mf)
e, v = myci.kernel()


System: uname_result(system='Linux', node='yichi-thinkpad', release='4.4.0-26100-Microsoft', version='#5074-Microsoft Fri Jan 01 08:00:00 PST 2016', machine='x86_64')  Threads 12
Python 3.10.16 | packaged by conda-forge | (main, Dec  5 2024, 14:16:10) [GCC 13.3.0]
numpy 1.24.3  scipy 1.14.1  h5py 3.12.1
Date: Mon Sep 22 11:02:11 2025
PySCF version 2.8.0
PySCF path  /home/yichi/research/software/lno_pyscf
GIT HEAD (branch master) ef75f4190e4de208685670651dc6c467f72b6794

[ENV] PYSCF_EXT_PATH /home/yichi/research/software/pyscf
[CONFIG] conf_file None
[INPUT] verbose = 4
[INPUT] num. atoms = 6
[INPUT] num. electrons = 6
[INPUT] charge = 0
[INPUT] spin (= nelec alpha-beta = 2S) = 0
[INPUT] symmetry False subgroup None
[INPUT] Mole.unit = angstrom
[INPUT] Symbol           X                Y                Z      unit          X                Y                Z       unit  Magmom
[INPUT]  1 H      0.000000000000   0.000000000000   0.000000000000 AA    0.000000000000   0.000000000000   0.00

Set gradient conv threshold to 3.16228e-05
Initial guess from minao.
init E= -2.43456844968186
  HOMO = -0.206186844053804  LUMO = 0.134147801579113
cycle= 1 E= -3.1487806963508  delta_E= -0.714  |g|= 0.0984  |ddm|= 1.71
  HOMO = -0.32236405088479  LUMO = 0.205265687081063
cycle= 2 E= -3.15516555755206  delta_E= -0.00638  |g|= 0.0314  |ddm|= 0.222
  HOMO = -0.32657231270121  LUMO = 0.219132025413817
cycle= 3 E= -3.15597595168783  delta_E= -0.00081  |g|= 0.00575  |ddm|= 0.0797
  HOMO = -0.3295049622818  LUMO = 0.219324890545481
cycle= 4 E= -3.15600091849169  delta_E= -2.5e-05  |g|= 0.000147  |ddm|= 0.018
  HOMO = -0.329429609111352  LUMO = 0.219283193175081
cycle= 5 E= -3.15600092949307  delta_E= -1.1e-08  |g|= 8.09e-06  |ddm|= 0.000293
  HOMO = -0.329440909371799  LUMO = 0.219288316614678
cycle= 6 E= -3.1560009295473  delta_E= -5.42e-11  |g|= 1.06e-07  |ddm|= 2.49e-05
  HOMO = -0.329441053062948  LUMO = 0.219288365072662
Extra cycle  E= -3.1560009295473  delta_E= -8.88e-15  |g|= 3.42e-

In [8]:
from ad_afqmc.lno_ccsd import lno_ccsd
# AFQMC run options. The base seed below is used later to draw one seed per
# LNO fragment.
options = dict(
    n_eql=4,                  # number of equilibration blocks
    n_prop_steps=10,
    n_ene_blocks=1,
    n_sr_blocks=10,
    n_blocks=10,
    n_walkers=10,
    seed=2,
    walker_type='rhf',
    trial='cisd',
    dt=0.005,
    free_projection=True,
    ad_mode=None,
    use_gpu=False,
)

In [4]:
# Imports for the LNO fragment setup cell; `log` is the logger used there for
# info/warn/error messages.
from ad_afqmc.lno.cc import LNOCCSD
import sys
from ad_afqmc.lno.base import lno  # NOTE(review): `lno` is not used in the visible cells
from ad_afqmc.lno_ccsd import lno_ccsd  # NOTE(review): already imported in the options cell
from pyscf.lib import logger
from pyscf import lib
# PySCF Logger writing to stdout at verbosity level 6 (presumably logger.DEBUG1 — confirm).
log = logger.Logger(sys.stdout, 6)

In [9]:
# ---- LNO setup --------------------------------------------------------------
# mfcc: either a full-system CCSD/CISD object (its amplitudes are projected onto
# each fragment) or a mean-field object (each fragment then solves its own CCSD).
mfcc = myci
thresh = 1e-3          # scalar -> thresh_occ = 10*thresh, thresh_vir = thresh
frozen = None
run_frg_list = None    # None -> run every fragment
use_df_vecs = False
chol_cut = 1e-7
lo_type = 'boys'
no_type = 'ie'
frag_lolist = '1o'     # '1o': one fragment per localized occupied orbital

from pyscf.cc.ccsd import CCSD
# from pyscf.cc.uccsd import UCCSD
from pyscf.ci.cisd import CISD
if isinstance(mfcc, (CCSD, CISD)):
    full_cisd = True
    mf = mfcc._scf
else:
    full_cisd = False
    mf = mfcc

# Accept an explicit (occ, vir) pair as list or tuple.
if isinstance(thresh, (list, tuple)):
    thresh_occ, thresh_vir = thresh
else:
    thresh_occ = thresh*10
    thresh_vir = thresh

lno_cc = LNOCCSD(mf, thresh=thresh, frozen=frozen)
lno_cc.thresh_occ = thresh_occ
lno_cc.thresh_vir = thresh_vir
lno_cc.lo_type = lo_type
lno_cc.no_type = no_type
lno_cc.frag_lolist = frag_lolist
lno_cc.force_outcore_ao2mo = True

s1e = lno_cc._scf.get_ovlp()
orbactocc = lno_cc.split_mo()[1] # non-localized active occ
orbloc = lno_cc.get_lo(lo_type=lo_type) # localized active occ orbitals
# m[p, i] = <LO_p|occ_i>
m = orbloc.T @ s1e @ orbactocc
# check 1: Span(occ) is contained in Span(LO)  <=>  m^T m == 1
lospanerr = abs((m.T @ m) - np.eye(m.shape[1])).max()
if lospanerr > 1e-10:
    log.error('LOs do not fully span the occupied space! '
                'Max|<occ|LO><LO|occ>| = %e', lospanerr)
    raise RuntimeError(
        f'LOs do not fully span the occupied space (max error {lospanerr:.3e}).')
# check 2: Span(LO) == Span(occ)  <=>  additionally m m^T == 1
occspanerr = abs(m @ m.T - np.eye(m.shape[0])).max()
if occspanerr < 1e-10:
    log.info('LOs span exactly the occupied space.')
    if no_type not in ['ir','ie']:
        log.error('"no_type" must be "ir" or "ie".')
        raise ValueError('"no_type" must be "ir" or "ie".')
else:
    log.info('LOs span occupied space plus some virtual space.')

if isinstance(frag_lolist, str):
    # Any other string used to fall through and make nfrag = len(string).
    if frag_lolist != '1o':
        raise ValueError(
            f'Unsupported frag_lolist {frag_lolist!r}; only "1o" '
            '(single-LO fragments) or an explicit list is accepted.')
    log.info('Using single-LO fragment') # this is what we use, every active local occ labels a fragment
    frag_lolist = [[i] for i in range(orbloc.shape[1])]
else:
    # Explicit user fragment list: kept, but only single-LO fragments are supported.
    print('Only support single LO fragment!')
nfrag = len(frag_lolist)

# no_type must be exactly two characters from 'r', 'e', 'i' (occ/vir schemes).
if not (len(no_type) == 2 and all(c in 'rei' for c in no_type)):
    log.warn('Input no_type "%s" is invalid.', no_type)
    raise ValueError(f'Input no_type "{no_type}" is invalid.')

frozen_mask = lno_cc.get_frozen_mask()
thresh_pno = [thresh_occ,thresh_vir]
print(f'# lno thresh {thresh_pno}')

if run_frg_list is None:
    run_frg_list = range(nfrag)

from jax import random

# One seed per fragment that is actually run, derived from the base seed in
# `options`. The base seed is never overwritten below (each fragment gets its
# own copy of `options`), so re-running this cell reproduces the same seeds.
seeds = random.randint(random.PRNGKey(options["seed"]),
                    shape=(len(run_frg_list),), minval=0, maxval=100*nfrag)

mol = mf.mol
nocc = mol.nelectron // 2
THRESH_INTERNAL = 1e-10

if full_cisd:
    # Full-system MO indices that are active (frozen_mask is True for active MOs).
    # These do not depend on the fragment, so compute them once.
    nmo_full = mf.mo_coeff.shape[1]
    frz_mo_idx = set(np.where(~np.asarray(frozen_mask, dtype=bool))[0].tolist())
    act_mo_occ = np.array([i for i in range(nocc) if i not in frz_mo_idx], dtype=int)
    act_mo_vir = np.array([i for i in range(nocc, nmo_full) if i not in frz_mo_idx],
                          dtype=int)

# irun indexes `seeds` (one entry per run fragment); ifrag is the fragment id.
# Indexing seeds with ifrag was wrong for a custom run_frg_list, and JAX clamps
# out-of-bounds indices instead of raising.
for irun, ifrag in enumerate(run_frg_list):
    print(f'\n########### running fragment {ifrag+1} ##########')

    fraglo = frag_lolist[ifrag]
    orbfragloc = orbloc[:,fraglo] # the specific local active occ
    # frag_target_nocc, frag_target_nvir = frag_nonvlist[ifrag]
    frzfrag, orbfrag, can_orbfrag \
        = lno_ccsd.make_lno(lno_cc, orbfragloc, THRESH_INTERNAL, thresh_pno)

    # Split the fragment orbitals (columns of orbfrag) into frozen/active, and
    # the active ones into occupied (index < nocc) and virtual. Use the number
    # of orbitals in orbfrag rather than mol.nao (they differ if MOs were dropped).
    nmo_frag = orbfrag.shape[1]
    frzset = {int(i) for i in frzfrag}
    actfrag = np.array([i for i in range(nmo_frag) if i not in frzset], dtype=int)
    actocc = actfrag[actfrag < nocc]
    actvir = actfrag[actfrag >= nocc]
    nactocc = len(actocc)
    nactvir = len(actvir)
    # Overlap of the fragment LO(s) with the fragment's active occupied orbitals.
    prjlo = orbfragloc.T @ s1e @ orbfrag[:,actocc]

    print(f'# active orbitals: {actfrag}')
    print(f'# active occupied orbitals: {actocc}')
    print(f'# active virtual orbitals: {actvir}')
    print(f'# frozen orbitals: {frzfrag}')

    if full_cisd:
        # print('# This method is not size-extensive')
        # presumably <MO|NO> projector with MO rows / fragment-orbital columns — TODO confirm
        prj_no2mo = lno_ccsd.no2mo(mf.mo_coeff,s1e,orbfrag)
        prj_oo_act = prj_no2mo[np.ix_(act_mo_occ,actocc)]
        prj_vv_act = prj_no2mo[np.ix_(act_mo_vir,actvir)]
        if isinstance(mfcc, CCSD):
            print('# Use full CCSD wavefunction')
            print('# Project CC amplitudes from MO to NO')
            t1 = mfcc.t1
            t2 = mfcc.t2
            # project to active no
            t1 = lib.einsum("ij,ia,ba->jb",prj_oo_act,t1,prj_vv_act.T)
            t2 = lib.einsum("ik,jl,ijab,db,ca->klcd",
                    prj_oo_act,prj_oo_act,t2,prj_vv_act.T,prj_vv_act.T)
            ci1 = np.array(t1)
            # CI doubles from CC amplitudes: C2 = T2 + T1 T1
            ci2 = t2 + lib.einsum("ia,jb->ijab",ci1,ci1)
            # ijab -> iajb layout (presumably what prep_lno_amp_chol_file expects)
            ci2 = ci2.transpose(0, 2, 1, 3)
        elif isinstance(mfcc, CISD):
            print('# Use full CISD wavefunction')
            print('# Project CI coefficients from MO to NO')
            v_ci = mfcc.ci
            ci0,ci1,ci2 = mfcc.cisdvec_to_amplitudes(v_ci)
            # Intermediate normalization (divide by the reference coefficient).
            ci1 = ci1/ci0
            ci2 = ci2/ci0
            ci1 = lib.einsum("ij,ia,ba->jb",prj_oo_act,ci1,prj_vv_act.T)
            ci2 = lib.einsum("ik,jl,ijab,db,ca->klcd",
                    prj_oo_act,prj_oo_act,ci2,prj_vv_act.T,prj_vv_act.T)
            ci2 = ci2.transpose(0, 2, 1, 3)
        print('# Finished MO to NO projection')
        ecorr_ccsd = '  None  '
    else:
        print('# Solving LNO-CCSD')
        ecorr_ccsd,t1,t2 = lno_ccsd.cc_impurity_solve(
                mf,orbfrag,orbfragloc,frozen=frzfrag,eris=None
                )
        ci1 = np.array(t1)
        ci2 = t2 + lib.einsum("ia,jb->ijab",ci1,ci1)
        ci2 = ci2.transpose(0, 2, 1, 3)
        ecorr_ccsd = f'{ecorr_ccsd:.8f}'
        print(f'# lno-ccsd fragment correlation energy: {ecorr_ccsd}')

    nelec_act = nactocc*2
    norb_act = nactocc+nactvir

    print(f'# number of active electrons: {nelec_act}')
    print(f'# number of active orbitals: {norb_act}')
    print(f'# number of frozen orbitals: {len(frzfrag)}')

    # Per-fragment copy so the shared `options` (and its base seed) stays intact.
    frag_options = {**options, "seed": int(seeds[irun])}
    lno_ccsd.prep_lno_amp_chol_file(
        mf,orbfrag,frag_options,
        norb_act=norb_act,nelec_act=nelec_act,
        prjlo=prjlo,norb_frozen=frzfrag,
        ci1=ci1,ci2=ci2,use_df_vecs=use_df_vecs,
        chol_cut=chol_cut,
        )


lo_type = boys


******** <class 'pyscf.lo.boys.Boys'> ********
conv_tol = 1e-06
conv_tol_grad = None
max_cycle = 100
max_stepsize = 0.01


max_iters = 20
kf_interval = 5
kf_trust_region = 5
ah_start_tol = 1000000000.0
ah_start_cycle = 1
ah_level_shift = 0
ah_conv_tol = 1e-12
ah_lindep = 1e-14
ah_max_cycle = 40
ah_trust_region = 3
init_guess = atomic
Set conv_tol_grad to 0.000316228
macro= 1  f(x)= 11.642313747317  delta_f= 11.6423  |g|= 5.78794  1 KF 3 Hx
macro= 2  f(x)= 11.365198790146  delta_f= -0.277115  |g|= 5.09448  1 KF 3 Hx
macro= 3  f(x)= 11.125864884816  delta_f= -0.239334  |g|= 4.46837  1 KF 3 Hx
macro= 4  f(x)= 10.927922329461  delta_f= -0.197943  |g|= 3.77802  1 KF 3 Hx
macro= 5  f(x)= 10.77444009648  delta_f= -0.153482  |g|= 3.03256  1 KF 3 Hx
macro= 6  f(x)= 10.66783834402  delta_f= -0.106602  |g|= 2.24236  1 KF 3 Hx
macro= 7  f(x)= 10.609845053662  delta_f= -0.0579933  |g|= 1.41845  1 KF 3 Hx
macro= 8  f(x)= 10.598717407376  delta_f= -0.0111276  |g|= 0.572254  1 KF 3 Hx
macro= 9  f(x)= 10.59871735688  delta_f= -5.04956e-08  |g|= 0.00116685  1 KF 2 Hx
macro= 10  f(x)= 10.59871735688  delta_f= -1.06581e-14  |

In [10]:
from functools import partial
from jax import random
#from mpi4py import MPI
import numpy as np
from jax import numpy as jnp
from ad_afqmc import config, wavefunctions, stat_utils
from ad_afqmc.lno_ccsd import lno_ccsd
import time

# Flush every print immediately.
print = partial(print, flush=True)

# Load the prepared Hamiltonian, propagator, trial and options for the last
# fragment written by the LNO cell.
(ham_data, ham, prop, trial, wave_data, sampler,
 observable, options, _) = lno_ccsd.prep_lnoccsd_afqmc()

# `config` here is ad_afqmc.config (imported in this cell), not jax.config.
# Flip its GPU flag only when the loaded options ask for it.
if options["use_gpu"]:
    config.afqmc_config["use_gpu"] = True

config.setup_jax()
MPI = config.setup_comm()

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()  # this process's rank, number of processes

# Hostname: yichi-thinkpad
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64
# Number of MPI ranks: 1
#
# norb: 2
# nelec: (1, 1)
# nchol: 3
#
# n_eql: 4
# n_prop_steps: 10
# n_ene_blocks: 1
# n_sr_blocks: 10
# n_blocks: 10
# n_walkers: 10
# seed: 254
# walker_type: rhf
# trial: cisd
# dt: 0.005
# free_projection: True
# use_gpu: False
# n_exp_terms: 6
# orbital_rotation: True
# do_sr: True
# symmetry: False
# save_walkers: False
# n_batch: 1
# ene0: 0
#
# Hostname: yichi-thinkpad
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64


In [12]:

### initialize propagation
seed = options["seed"]
propagator = prop
init_walkers = None

# Measurement intermediates are built in two passes: first those of an RHF
# wavefunction with the trial's dimensions, then the ones `ham` builds for the
# actual trial; the propagation intermediates come last.
ham_data = wavefunctions.rhf(
    trial.norb, trial.nelec, n_batch=trial.n_batch
)._build_measurement_intermediates(ham_data, wave_data)
ham_data = ham.build_measurement_intermediates(ham_data, trial, wave_data)
ham_data = ham.build_propagation_intermediates(
    ham_data, propagator, trial, wave_data
)

prop_data = propagator.init_prop_data(
    trial, wave_data, ham_data, init_walkers
)
# Refuse to start when the summed initial overlap is (numerically) zero.
if jnp.abs(jnp.sum(prop_data["overlaps"])) < 1.0e-6:
    raise ValueError("Initial overlaps are zero. Pass walkers with non-zero overlap.")

# RNG key offset by the MPI rank so each rank gets a distinct key.
prop_data["key"] = random.PRNGKey(seed + rank)

prop_data["overlaps"] = trial.calc_overlap(prop_data["walkers"], wave_data)
prop_data["n_killed_walkers"] = 0
# Population-control shift starts at the initial energy estimate.
prop_data["pop_control_ene_shift"] = prop_data["e_estimate"]

In [None]:
def block_orb(prop_data: dict,
              ham_data: dict,
              prop: "propagation.propagator",
              trial: "wave_function",
              wave_data: dict,
              sampler: "sampling.sampler"):
    """Block scan function. Propagation and orbital_i energy calculation.

    Draws auxiliary fields for ``sampler.n_prop_steps`` steps, propagates the
    walkers with ``sampler._step_scan`` via ``lax.scan``, counts killed walkers,
    orthonormalizes the walkers and refreshes their overlaps, then measures the
    total energy and the HF / CC fragment-orbital energies, weight-averaged over
    the walkers.

    Returns:
        (prop_data, (blk_en, blk_wt, blk_hf_orb_en, blk_cc_orb_en))

    NOTE(review): the type annotations are strings because ``propagation``,
    ``wave_function`` and ``sampling`` are not imported in the visible cells;
    as bare names they raised NameError as soon as this cell was executed.
    NOTE(review): ``frg_hf_cr`` and ``frg_ccsd_cr`` are not defined in the
    visible cells (the relaxation cell uses ``lno_ccsd._frg_hf_cr`` /
    ``lno_ccsd._frg_ccsd_cr`` with a different argument order and a 1e-6
    threshold) — TODO confirm which helpers are intended before calling this.
    """
    from jax import lax  # lax is not imported in the visible notebook cells

    prop_data["key"], subkey = random.split(prop_data["key"])
    # One standard-normal field per (step, walker, Cholesky vector).
    fields = random.normal(
        subkey,
        shape=(
            sampler.n_prop_steps,
            prop.n_walkers,
            ham_data["chol"].shape[0],
        ),
    )
    # propagate n_prop_steps x dt
    _step_scan_wrapper = lambda x, y: sampler._step_scan(
        x, y, ham_data, prop, trial, wave_data
    )
    prop_data, _ = lax.scan(_step_scan_wrapper, prop_data, fields)
    # Walkers with zero weight count as killed.
    prop_data["n_killed_walkers"] += prop_data["weights"].size - jnp.count_nonzero(
        prop_data["weights"]
    )
    prop_data = prop.orthonormalize_walkers(prop_data)
    prop_data["overlaps"] = trial.calc_overlap(prop_data["walkers"], wave_data)

    hf_orb_cr, hf_orb_en \
        = frg_hf_cr(prop_data["walkers"],ham_data,wave_data,trial)
    cc_orb_cr \
        = frg_ccsd_cr(prop_data["walkers"], ham_data, wave_data, trial,1e-5)
    energy_samples = jnp.real(
        trial.calc_energy(prop_data["walkers"], ham_data, wave_data)
    )
    # Replace outliers farther than sqrt(2/dt) from the running estimate.
    energy_samples = jnp.where(
        jnp.abs(energy_samples - prop_data["e_estimate"]) > jnp.sqrt(2.0 / prop.dt),
        prop_data["e_estimate"],
        energy_samples,
    )

    # Weighted block averages.
    blk_wt = jnp.sum(prop_data["weights"])
    blk_en = jnp.sum(energy_samples * prop_data["weights"]) / blk_wt
    blk_hf_orb_en = jnp.sum(hf_orb_en * prop_data["weights"]) / blk_wt
    blk_hf_orb_cr = jnp.sum(hf_orb_cr * prop_data["weights"]) / blk_wt
    blk_cc_orb_cr = jnp.sum(cc_orb_cr * prop_data["weights"]) / blk_wt
    blk_cc_orb_en = blk_hf_orb_cr+blk_cc_orb_cr

    # Exponential moving average of the block energy for population control.
    prop_data["pop_control_ene_shift"] = (
        0.9 * prop_data["pop_control_ene_shift"] + 0.1 * blk_en
    )

    return prop_data,(blk_en,blk_wt,blk_hf_orb_en,blk_cc_orb_en)

In [None]:
### relaxation ###
# Print the step-0 line of the equilibration table (rank 0 only).
init_time = time.time()
comm.Barrier()
if rank == 0:
    # Fragment-orbital terms of the initial state, evaluated on walker index 0.
    hf_orb_cr,hf_orb_en = lno_ccsd._frg_hf_cr(
        ham_data['rot_h1'], ham_data['rot_chol'],prop_data["walkers"][0],trial,wave_data)
    cc_orb_cr = lno_ccsd._frg_ccsd_cr(
        prop_data["walkers"][0],ham_data,wave_data,trial,1e-6)

    # cc_orb_en combines the HF and CC `_cr` terms.
    cc_orb_en = hf_orb_cr+cc_orb_cr
    e_init = prop_data["e_estimate"]

    print(f'# afqmc propagation with {options["n_walkers"]*size} walkers')
    print('# approaching equilibrium')
    print('# step  energy  hf_orb_en  cc_orb_en  time')
    print(f"  {0:3d}"
          f"  {e_init:.6f}"
          f"  {hf_orb_en:.6f}"
          f"  {cc_orb_en:.6f}"
          f"  {time.time() - init_time:.2f} "
        )
comm.Barrier()

In [None]:
# Equilibration: run n_eql propagation blocks. Each block's weighted results are
# summed over MPI ranks on root, turned into averages, broadcast back, used to
# update the running energy estimate, and printed by rank 0.
for n in range(options["n_eql"]):
    prop_data, (blk_en, blk_wt, blk_hf_orb_en, blk_cc_orb_en) = (
        lno_ccsd.propagate_phaseless_orb(
            ham_data, prop, prop_data, trial, wave_data, sampler))

    # Pack this rank's block results into one buffer so a single Reduce/Bcast
    # pair replaces four copy-pasted ones. Layout: [w*E, w, w*E_hf, w*E_cc].
    # Values are cast to float32 first and summed as MPI.FLOAT, as before.
    # NOTE(review): float32 keeps only ~7 significant digits of the energies.
    loc_en, loc_wt, loc_hf, loc_cc = np.array(
        [blk_en, blk_wt, blk_hf_orb_en, blk_cc_orb_en], dtype="float32")
    loc_sums = np.array(
        [loc_en * loc_wt, loc_wt, loc_hf * loc_wt, loc_cc * loc_wt],
        dtype="float32")
    tot_sums = np.zeros(4, dtype="float32")
    comm.Reduce(
        [loc_sums, MPI.FLOAT],
        [tot_sums, MPI.FLOAT],
        op=MPI.SUM,
        root=0,
    )

    comm.Barrier()
    # Root converts sums into block averages; layout [E, W, E_hf, E_cc].
    # NOTE(review): a zero total weight (all walkers killed) gives nan/inf here.
    blk_avgs = np.zeros(4, dtype="float32")
    if rank == 0:
        tot_wt = tot_sums[1]
        blk_avgs[:] = (tot_sums[0] / tot_wt, tot_wt,
                       tot_sums[2] / tot_wt, tot_sums[3] / tot_wt)
    comm.Bcast(blk_avgs, root=0)
    # Expose the same (1,)-shaped arrays the cell produced before.
    blk_en, blk_wt, blk_hf_orb_en, blk_cc_orb_en = np.split(blk_avgs, 4)

    prop_data = propagator.orthonormalize_walkers(prop_data)
    prop_data = propagator.stochastic_reconfiguration_global(prop_data, comm)
    # Exponential moving average of the block energy.
    prop_data["e_estimate"] = 0.9 * prop_data["e_estimate"] + 0.1 * blk_en[0]

    comm.Barrier()
    if rank == 0:
        print(
            f"  {n+1:3d}"
            f"  {blk_en[0]:.6f}"
            f"  {blk_hf_orb_en[0]:.6f}"
            f"  {blk_cc_orb_en[0]:.6f}"
            f"  {time.time() - init_time:.2f} "
        )
    comm.Barrier()