In [1]:
from pyscf import gto, scf, cc
import numpy as np
from jax import numpy as jnp
from jax import vmap, jvp, jit
import jax
from functools import partial

# Nearest-neighbour H-H spacing: 2 Bohr radii expressed in Angstrom
# (the Mole unit reported in the log below is angstrom).
a = 1.05835 # 2aB
nH = 6  # number of hydrogen atoms in the chain

# Place the atoms equally spaced along the x axis.
atoms = ""
for i in range(nH):
    atoms += f"H {i*a:.5f} 0.00000 0.00000 \n"

# gto.M() returns an already-built Mole, so no separate mol.build() call is
# needed; a second build only repeats the setup work and the verbose log.
mol = gto.M(atom=atoms, basis="ccpvdz", verbose=4)

# Restricted Hartree-Fock reference.
mf = scf.RHF(mol)#.density_fit()
e = mf.kernel()

# CCSD on top of the RHF reference. `e` is rebound here to whatever
# CCSD.kernel() returns, replacing the RHF result stored above.
mycc = cc.CCSD(mf)
e = mycc.kernel()

System: uname_result(system='Linux', node='yichi-thinkpad', release='4.4.0-26100-Microsoft', version='#5074-Microsoft Fri Jan 01 08:00:00 PST 2016', machine='x86_64')  Threads 12
Python 3.10.16 | packaged by conda-forge | (main, Dec  5 2024, 14:16:10) [GCC 13.3.0]
numpy 1.24.3  scipy 1.14.1  h5py 3.12.1
Date: Sat Oct 11 12:47:53 2025
PySCF version 2.8.0
PySCF path  /home/yichi/research/software/lno_pyscf
GIT HEAD (branch master) ef75f4190e4de208685670651dc6c467f72b6794

[ENV] PYSCF_EXT_PATH /home/yichi/research/software/pyscf
[CONFIG] conf_file None
[INPUT] verbose = 4
[INPUT] num. atoms = 6
[INPUT] num. electrons = 6
[INPUT] charge = 0
[INPUT] spin (= nelec alpha-beta = 2S) = 0
[INPUT] symmetry False subgroup None
[INPUT] Mole.unit = angstrom
[INPUT] Symbol           X                Y                Z      unit          X                Y                Z       unit  Magmom
[INPUT]  1 H      0.000000000000   0.000000000000   0.000000000000 AA    0.000000000000   0.000000000000   0.00

In [None]:
# AFQMC energy: -8.7671 +/- 0.0010

In [None]:
import numpy as np

np.set_printoptions(precision=5)

# One row per run: (number of H2 monomers, AFQMC/CCSD_PT energy, quoted +/- error).
pt_runs = [
    (1, -1.096171, 0.000554),
    (2, -2.192221, 0.001636),
    (4, -4.384578, 0.006862),
    (8, -8.769366, 0.019849),
    (16, -17.538377, 0.091299),
    (25, -27.405295, 0.246627),
]

nm = np.array([n_monomers for n_monomers, _, _ in pt_runs])
ept = np.array([energy for _, energy, _ in pt_runs])

# Normalise each total energy by its monomer count so the runs can be compared.
ept_perm = ept / nm

print('number of H2 monomers:            ', nm)
print('AFQMC/CCSD_PT energy per monomer: ', ept_perm)

number of H2 monomers:             [ 1  2  4  8 16 25]
AFQMC/CCSD_PT energy per monomer:  [-1.09617 -1.09611 -1.09614 -1.09617 -1.09615 -1.09621]


In [2]:
from ad_afqmc import pyscf_interface
from ad_afqmc.cisd_perturb import ccsd_pt, sample_pt2

# Run parameters for the AFQMC calculation; key order is kept unchanged.
options = {
    'n_eql': 4,
    'n_prop_steps': 50,
    'n_ene_blocks': 20,
    'n_sr_blocks': 10,
    'n_blocks': 10,
    'n_walkers': 40,
    'seed': 2,
    'walker_type': 'rhf',
    'trial': 'cisd',
    'dt': 0.005,
    'free_projection': False,
    'ad_mode': None,
    'use_gpu': False,
}

# Prepare the AFQMC inputs from the converged CCSD object; chol_cut is the
# threshold passed for the Cholesky integrals reported in the log below.
pyscf_interface.prep_afqmc(mycc, chol_cut=1e-7)

#
# Preparing AFQMC calculation
# If you import pyscf cc modules and use MPI for AFQMC in the same script, finalize MPI before calling the AFQMC driver.
# Calculating Cholesky integrals
# Finished calculating Cholesky integrals
#
# Size of the correlation space:
# Number of electrons: (3, 3)
# Number of basis functions: 30
# Number of Cholesky vectors: 191
#


In [3]:
from ad_afqmc import config, mpi_jax, wavefunctions
import time
from jax import random
# Build the objects used below from `options`. The last element of the returned
# tuple is discarded, and `options` is rebound to the version returned here.
ham_data, ham, prop, trial, wave_data, sampler, observable, options, _ \
    = (mpi_jax._prep_afqmc(options))

norb = trial.norb
# View the Cholesky vectors as a stack of (norb, norb) matrices.
chol = ham_data["chol"].reshape(-1, norb, norb)
# Average of the two h1 components (presumably spin-up / spin-down -- TODO confirm).
h1 = (ham_data["h1"][0] + ham_data["h1"][1]) / 2.0
# v0[i, j] = 0.5 * sum_{g,k} chol[g, i, k] * chol[g, j, k].
# The reshape calls below are no-ops: chol already has this shape.
v0 = 0.5 * jnp.einsum("gik,gjk->ij",
                        chol.reshape(-1, norb, norb),
                        chol.reshape(-1, norb, norb),
                        optimize="optimal")

# The leading dimension of ci1 is used as the number of occupied orbitals.
nocc = wave_data['ci1'].shape[0]
# First nocc columns of the identity as orbital coefficients (presumably the
# occupied orbitals expressed in the MO basis -- TODO confirm).
wave_data["mo_coeff"] = np.eye(norb)[:,:nocc]
ci1,ci2 = wave_data['ci1'],wave_data['ci2']
# Subtract the outer product ci1 x ci1 from ci2 and store the result as "t2"
# (looks like a CI -> CC-style amplitude conversion; verify the index convention).
t2 = ci2 - jnp.einsum('ia,jb->iajb',ci1,ci1)
wave_data['t1'] = ci1
wave_data['t2'] = t2

# One-body term with v0 subtracted, stored for later use under 'h1_mod'.
h1_mod = h1 - v0 
ham_data['h1_mod'] = h1_mod

config.setup_jax()
MPI = config.setup_comm()
init = time.time()
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
seed = options["seed"]
neql = options["n_eql"]

# Use the trial's 1-RDM as wave_data["rdm1"] unless one is already present.
trial_rdm1 = trial.get_rdm1(wave_data)
if "rdm1" not in wave_data:
    wave_data["rdm1"] = trial_rdm1
# Measurement intermediates are built first through a freshly constructed rhf
# wavefunction object, then through `ham` with the actual trial.
ham_data = wavefunctions.rhf(
    trial.norb, trial.nelec,n_batch=trial.n_batch
    )._build_measurement_intermediates(ham_data, wave_data)
ham_data = ham.build_measurement_intermediates(ham_data, trial, wave_data)
ham_data = ham.build_propagation_intermediates(
    ham_data, prop, trial, wave_data
)
prop_data = prop.init_prop_data(trial, wave_data, ham_data, None)
# Offset the seed by the MPI rank so every rank gets a different PRNG key.
prop_data["key"] = random.PRNGKey(seed + rank)

# Hostname: yichi-thinkpad
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64
# Number of MPI ranks: 1
#
# norb: 30
# nelec: (3, 3)
#
# n_eql: 4
# n_prop_steps: 50
# n_ene_blocks: 20
# n_sr_blocks: 10
# n_blocks: 10
# n_walkers: 40
# seed: 2
# walker_type: rhf
# trial: cisd
# dt: 0.005
# free_projection: False
# use_gpu: False
# n_exp_terms: 6
# orbital_rotation: True
# do_sr: True
# symmetry: False
# save_walkers: False
# ene0: 0.0
# n_batch: 1
# LNO: False
# orbE: 0
# maxError: 0.001
#
# Hostname: yichi-thinkpad
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64


In [None]:
def blocking_analysis(weights, energies, neql=0, printQ=False, writeBlockedQ=False):
    """Weighted blocking (re-blocking) analysis of a series of energy samples.

    After the first ``neql`` samples are dropped, the series is cut into
    consecutive blocks of increasing size. For each block size the weighted
    mean of the block energies and a weighted error estimate are computed. The
    first block size whose error is at most 5% larger than the previous one
    defines the plateau error.

    Args:
        weights: 1-D array of sample weights.
        energies: 1-D array of sample energies, same length as ``weights``.
        neql: number of leading samples to discard.
        printQ: if True, print the overall mean, one line per block size and
            the plateau error (when one was found).
        writeBlockedQ: if True, write the blocked weights and energies for
            every block size to ``samples_blocked_<size>.dat``.

    Returns:
        Tuple ``(meanEnergy, plateauError)``. ``plateauError`` is ``None`` when
        no plateau was detected, which includes the case where the series is
        too short for any block size to be tried.
    """
    nSamples = weights.shape[0] - neql
    weights = weights[neql:]
    energies = energies[neql:]
    weightedEnergies = np.multiply(weights, energies)
    meanEnergy = weightedEnergies.sum() / weights.sum()
    if printQ:
        print(f"#\n# Mean: {meanEnergy:.8e}")
        print("# Block size    # of blocks         Mean                Error")
    blockSizes = np.array([1, 2, 3, 5, 10, 15, 20, 30, 50, 100, 200, 300, 400, 500, 1000, 10000])
    prevError = 0.0
    plateauError = None
    # Only block sizes that leave at least two blocks are considered, so the
    # (nBlocks - 1) divisor below is never zero.
    for i in blockSizes[blockSizes < nSamples / 2.0]:
        nBlocks = nSamples // i
        # Double precision buffers: in single precision, fluctuations that are
        # small compared with the energy itself are rounded away and the
        # resulting error estimate is wrong (it can even collapse to zero).
        blockedWeights = np.zeros(nBlocks, dtype="float64")
        blockedEnergies = np.zeros(nBlocks, dtype="float64")
        for j in range(nBlocks):
            start, stop = j * i, (j + 1) * i
            blockedWeights[j] = weights[start:stop].sum()
            blockedEnergies[j] = weightedEnergies[start:stop].sum() / blockedWeights[j]
        v1 = blockedWeights.sum()
        v2 = (blockedWeights**2).sum()
        mean = np.multiply(blockedWeights, blockedEnergies).sum() / v1
        # Weighted variance of the block energies (v1 - v2 / v1 is the
        # weighted analogue of nBlocks - 1), divided by (nBlocks - 1).
        error = (
            np.multiply(blockedWeights, (blockedEnergies - mean) ** 2).sum()
            / (v1 - v2 / v1)
            / (nBlocks - 1)
        ) ** 0.5
        if writeBlockedQ:
            np.savetxt(
                f"samples_blocked_{i}.dat",
                np.stack((blockedWeights, blockedEnergies)).T,
            )
        if printQ:
            print(f"  {i:5d}           {nBlocks:6d}       {mean:.8e}       {error:.6e}")
        # Plateau: the error stopped growing (within 5%) when the block size
        # was increased. Only the first such occurrence is recorded.
        if error <= 1.05 * prevError and plateauError is None:
            plateauError = max(error, prevError)
        prevError = error

    if printQ:
        if plateauError is not None:
            print(f"# Stocahstic error estimate: {plateauError:.6e}\n#")

    return meanEnergy, plateauError

In [9]:
# h0 is added to e0 below as a constant offset.
h0 = ham_data['h0']
# Evaluate the three quantities (e0, e1, t) for the first walker only.
e0, e1, t = ccsd_pt._ccsd_walker_energy_pt(
    prop_data['walkers'][0],ham_data,wave_data,trial)
print(e0+h0)
print(e1)
print(t)

(-1.0892834338083683+0j)
(-0.04142308585635617+0j)
(8.231184445851372e-18+0j)


In [5]:
from ad_afqmc.cisd_perturb import sample_pt2
# Set the n_killed_walkers entry to zero before running the scan
# (presumably a counter updated by the sampler -- confirm in sample_pt2).
prop_data["n_killed_walkers"] = 0

In [6]:
# One call to _block_scan: returns the updated propagation state together with
# the block weight and the block values of t, e0 and e1.
prop_data, (blk_wt, blk_t, blk_e0, blk_e1) = \
    sample_pt2._block_scan(prop_data,ham_data,prop,trial,wave_data,sampler)

In [7]:
# Block weight first, then t, e0 and e1 of that block on one line.
print(blk_wt)
print(blk_t, blk_e0, blk_e1)

39.864298549868025
0.02212654217667736 -7.613390945177146 -0.2625478474481564


In [9]:
# Total walker weight after the block; compare with blk_wt printed above.
sum(prop_data['weights'])

Array(39.86429855, dtype=float64)

In [10]:
# Walker weights normalised to sum to one.
rho =  prop_data['weights']/sum(prop_data['weights'])
# e0, e1 and t evaluated for the whole walker set in one call.
e0, e1, t = ccsd_pt.ccsd_walker_energy_pt(
        prop_data["walkers"],ham_data,wave_data,trial)
# Contribution of each walker to the weighted average of e0.
rho_e0 = rho*e0

In [None]:
# Weighted average of e0 over the walkers; compare with blk_e0 from _block_scan.
print(sum(rho_e0))

-7.613390945177144


In [14]:
# Mean and spread of the individual weighted contributions rho * e0
# (the mean is the weighted average divided by the number of walkers).
print(np.mean(rho_e0))
print(np.std(rho_e0))

-0.19033477362942863
0.004987375822228835


In [18]:
# Energy estimate for this block: E = h0 + e0 + e1 - t*e0.
h0 + blk_e0 + blk_e1 - blk_t * blk_e0

Array(-1.13092409+2.60668348e-05j, dtype=complex128)

In [78]:
from ad_afqmc.cisd_perturb import sample_pt2
# Set n_killed_walkers to zero again, then run _sr_block_scan. Unlike
# _block_scan, the returned blk_* values are arrays with one entry per block
# (see the printed output of the next cell).
prop_data["n_killed_walkers"] = 0
prop_data, (blk_wt, blk_t, blk_e0, blk_e1) = \
    sample_pt2._sr_block_scan(prop_data,ham_data,prop,trial,wave_data,sampler)

In [79]:
# Per-block weights, then per-block t, e0 and e1.
print(blk_wt)
print(blk_t)
print(blk_e0)
print(blk_e1)

[39.8643  39.75175 39.95041 39.84242 39.90087 39.91925 39.85964 39.89267
 39.83531 39.88537 40.012   39.83649 39.84835 39.89709 39.85604 39.93099
 39.99298 39.92618 39.9368  39.85308]
[0.02213 0.05203 0.06344 0.06884 0.08227 0.08173 0.07148 0.07985 0.07498
 0.08254 0.10617 0.10188 0.10868 0.08211 0.08513 0.08841 0.08676 0.08971
 0.09535 0.11229]
[-7.61339 -7.65532 -7.6684  -7.67835 -7.69217 -7.69571 -7.68553 -7.6872
 -7.67719 -7.68696 -7.70383 -7.69393 -7.69161 -7.66816 -7.67018 -7.67631
 -7.67297 -7.68156 -7.68697 -7.70173]
[-0.26255 -0.45117 -0.52716 -0.5588  -0.64903 -0.64177 -0.57226 -0.63571
 -0.60766 -0.65615 -0.82256 -0.79898 -0.85424 -0.67039 -0.6918  -0.71061
 -0.70197 -0.71608 -0.75405 -0.87203]


In [84]:
# Weight-averaged t, e0 and e1 over the blocks.
t = np.sum(blk_wt * blk_t)/np.sum(blk_wt)
e0 = np.sum(blk_wt * blk_e0)/np.sum(blk_wt)
e1 = np.sum(blk_wt * blk_e1)/np.sum(blk_wt)
# NOTE(review): t_err, e0_err and e1_err are read by later cells but exist here
# only as commented-out code, so a fresh kernel run fails on them. The disabled
# lines originally all took np.std(blk_t), which looks like a copy-paste slip;
# the per-quantity form is shown below.
# t_err = np.std(blk_t)
# e0_err = np.std(blk_e0)
# e1_err = np.std(blk_e1)

In [85]:
# E = e0 + e1 - t*e0
# Gradient of E with respect to (t, e0, e1), in that order:
# dE/dt = -e0, dE/de0 = 1 - t, dE/de1 = 1.
dE = np.array([-e0,1-t,1])

In [86]:
# 3x3 sample covariance of the block series, with rows ordered (t, e0, e1).
cov_te0e1 = np.cov([blk_t,blk_e0,blk_e1])
print(cov_te0e1)

[[ 0.00043 -0.00036 -0.00294]
 [-0.00036  0.00039  0.00236]
 [-0.00294  0.00236  0.02035]]


In [87]:
# First-order error propagation: sqrt(dE^T . Cov . dE).
np.sqrt(dE @ cov_te0e1 @ dE)

0.0007600322264764268

In [None]:
# Pairwise 2x2 covariance matrices; element [0][1] of each is the covariance
# between the two series. The first one is bound to cov_e0t: that is the name
# the following cells read, and binding it to cov_te0e1 would both leave
# cov_e0t undefined and overwrite the 3x3 covariance of (t, e0, e1).
cov_e0t = np.cov(blk_e0,blk_t)
cov_e0e1 = np.cov(blk_e0,blk_e1)
cov_e1t = np.cov(blk_e1,blk_t)

In [28]:
# Show the three pairwise covariance matrices: (e0, t), (e0, e1), (e1, t).
print(cov_e0t)
print(cov_e0e1)
print(cov_e1t)

[[ 6.67078e-05 -4.39983e-05]
 [-4.39983e-05  2.92440e-05]]
[[6.67078e-05 4.99679e-06]
 [4.99679e-06 9.78546e-07]]
[[ 9.78546e-07 -3.66370e-06]
 [-3.66370e-06  2.92440e-05]]


In [None]:
# Diagonal variance contributions from t and e0 only (no e1, no covariances).
e0**2*t_err**2 + (1-t)**2*e0_err**2

Array(8.38046e-05, dtype=float64)

In [41]:
# Magnitude of the (e0, t) cross term taken on its own.
# NOTE(review): for E = e0 + e1 - t*e0 this term enters the variance with a
# minus sign, because dE/dt = -e0.
2*e0*(1-t)*cov_e0t[0][1]

Array(0.00014, dtype=float64)

In [None]:
# First-order error propagation for E = e0 + e1 - t*e0, whose partial
# derivatives are dE/dt = -e0, dE/de0 = 1 - t and dE/de1 = 1. Each cross term is
# 2 * (product of the two derivatives) * covariance, so the two terms that
# involve dE/dt carry a minus sign.
ept_err = jnp.sqrt(
        e0**2*t_err**2 + (1-t)**2*e0_err**2 + e1_err**2
        - 2*e0*(1-t)*cov_e0t[0][1] + 2*(1-t)*cov_e0e1[0][1] - 2*e0*cov_e1t[0][1])

In [39]:
# Sum of the three diagonal variance contributions (covariances ignored).
e0**2*t_err**2 + (1-t)**2*e0_err**2 + e1_err**2 

Array(0.00011, dtype=float64)

In [40]:
# All three covariance cross terms with a leading minus sign.
# NOTE(review): for E = e0 + e1 - t*e0 the (e0, e1) cross term enters the
# variance with a plus sign; confirm whether the minus on the middle term is
# intended here.
- 2*e0*(1-t)*cov_e0t[0][1] - 2*(1-t)*cov_e0e1[0][1] - 2*e0*cov_e1t[0][1]

Array(-0.00016, dtype=float64)

In [34]:
# Error estimate that keeps only the diagonal terms, i.e. treats t, e0 and e1
# as uncorrelated.
ept_err_n = jnp.sqrt(
        e0**2*t_err**2 + (1-t)**2*e0_err**2 + e1_err**2)

In [37]:
# Compare the estimate with covariances (ept_err) against the one without
# (ept_err_n). A nan means the argument of the square root was negative.
print(ept_err)
print(ept_err_n)

nan
0.010353734951627976


In [30]:
# Total block weight and weight-averaged t, e0, e1 (this rebinds t, e0, e1).
wt = jnp.sum(blk_wt)
t = jnp.sum(blk_t * blk_wt) / wt
e0 = jnp.sum(blk_e0 * blk_wt) / wt
e1 = jnp.sum(blk_e1 * blk_wt) / wt
print(wt)
print(t)
print(e0)
print(e1)
# Energy estimate E = h0 + e0 + e1 - t*e0.
print(h0 + e0 + e1 - t*e0)

199.74337761096814
(0.02770313707358223-0.005324835443396208j)
(-1.627140681997922+0.006549220259902778j)
(-0.0487240968399901+0.0021310646102316315j)
(-1.1308206619813184-0.00016540545134517172j)


In [32]:
# Run propagate_phaseless; it returns the updated propagation state plus a
# weight and the t, e0, e1 values, which rebind the names used above.
prop_data, (wt, t, e0, e1) = \
    sample_pt2.propagate_phaseless(prop_data,ham_data,prop,trial,wave_data,sampler)
print(wt)
print(t)
print(e0)
print(e1)
# Energy estimate E = h0 + e0 + e1 - t*e0.
print(h0 + e0 + e1 - t*e0)

1998.06851006178
(0.03371363519748113+0.0030895190685928105j)
(-1.632272124020626-0.0020501796440972747j)
(-0.05339024327780739-0.0029899172325453726j)
(-1.130636785409919+7.195798426196283e-05j)


In [34]:
# Total walker weight held in prop_data after the propagation call.
sum(prop_data['weights'])

Array(39.98750123, dtype=float64)