In [1]:
from pyscf import gto, scf, cc
import numpy as np
from jax import numpy as jnp
from jax import vmap, jvp, jit
import jax
from functools import partial

# Linear chain of nH hydrogen atoms along x with uniform spacing a.
# Coordinates are in Angstrom (PySCF default, see "Mole.unit = angstrom" in the log);
# 1.05835 Angstrom corresponds to 2 bohr.
a = 1.05835  # 2 a_B, expressed in Angstrom
nH = 2
atoms = "".join(f"H {i*a:.5f} 0.00000 0.00000 \n" for i in range(nH))

# gto.M already builds the molecule, so a separate mol.build() call is redundant
# (it would only rebuild and re-dump the input log).
mol = gto.M(atom=atoms, basis="ccpvdz", verbose=4)

# Append .density_fit() to scf.RHF(mol) to switch to density-fitted integrals.
mf = scf.RHF(mol)
e_hf = mf.kernel()

# CCSD on top of the RHF reference. CCSD.kernel() returns (e_corr, t1, t2), so keep only
# the correlation energy here instead of overwriting the RHF energy with a tuple.
mycc = cc.CCSD(mf)
e_ccsd_corr = mycc.kernel()[0]

System: uname_result(system='Linux', node='yichi-thinkpad', release='4.4.0-26100-Microsoft', version='#5074-Microsoft Fri Jan 01 08:00:00 PST 2016', machine='x86_64')  Threads 12
Python 3.10.16 | packaged by conda-forge | (main, Dec  5 2024, 14:16:10) [GCC 13.3.0]
numpy 1.24.3  scipy 1.14.1  h5py 3.12.1
Date: Thu Oct  9 22:14:27 2025
PySCF version 2.8.0
PySCF path  /home/yichi/research/software/lno_pyscf
GIT HEAD (branch master) ef75f4190e4de208685670651dc6c467f72b6794

[ENV] PYSCF_EXT_PATH /home/yichi/research/software/pyscf
[CONFIG] conf_file None
[INPUT] verbose = 4
[INPUT] num. atoms = 2
[INPUT] num. electrons = 2
[INPUT] charge = 0
[INPUT] spin (= nelec alpha-beta = 2S) = 0
[INPUT] symmetry False subgroup None
[INPUT] Mole.unit = angstrom
[INPUT] Symbol           X                Y                Z      unit          X                Y                Z       unit  Magmom
[INPUT]  1 H      0.000000000000   0.000000000000   0.000000000000 AA    0.000000000000   0.000000000000   0.00

In [2]:
from ad_afqmc.cisd_perturb import ccsd_pt

# AFQMC run settings.
# NOTE(review): nothing in this cell consumes this dict (prep_afqmc below does not take
# it); it only affects a run if it is handed to the AFQMC driver explicitly.
options = {
    "n_eql": 4,
    "n_prop_steps": 50,
    "n_ene_blocks": 5,
    "n_sr_blocks": 10,
    "n_blocks": 10,
    "n_walkers": 40,
    "seed": 2,
    "walker_type": "rhf",
    "trial": "cisd",
    "dt": 0.005,
    "free_projection": False,
    "ad_mode": None,
    "use_gpu": False,
}

from ad_afqmc import pyscf_interface
from ad_afqmc.cisd_perturb import sample_pt2

# Prepare the AFQMC inputs (Cholesky-decomposed integrals, per the log) from the
# converged CCSD object; chol_cut is the Cholesky truncation threshold.
pyscf_interface.prep_afqmc(mycc, chol_cut=1e-7)

#
# Preparing AFQMC calculation
# If you import pyscf cc modules and use MPI for AFQMC in the same script, finalize MPI before calling the AFQMC driver.
# Calculating Cholesky integrals
# Finished calculating Cholesky integrals
#
# Size of the correlation space:
# Number of electrons: (1, 1)
# Number of basis functions: 10
# Number of Cholesky vectors: 53
#


In [8]:
from ad_afqmc import config, mpi_jax, wavefunctions
import time
from jax import random
# Configure JAX and MPI before any JAX arrays are created in this cell. In the original
# order these calls came after several jnp operations had already run, which is too
# late for global JAX settings to be applied consistently.
config.setup_jax()
MPI = config.setup_comm()

# Hand the options dict defined in the setup cell to the driver explicitly, so the run
# is reproducible from the notebook alone.
# NOTE(review): without this argument _prep_afqmc has to obtain its options some other
# way (presumably saved options or defaults) — confirm in ad_afqmc.mpi_jax.
ham_data, ham, prop, trial, wave_data, sampler, observable, options, _ = (
    mpi_jax._prep_afqmc(options=options)
)

norb = trial.norb
chol = ham_data["chol"].reshape(-1, norb, norb)  # (n_chol, norb, norb)
# Average of the two spin components of the one-body Hamiltonian.
h1 = (ham_data["h1"][0] + ham_data["h1"][1]) / 2.0
# v0_ij = 1/2 * sum_g sum_k L^g_ik L^g_jk. chol is already 3-D, so no further reshape.
v0 = 0.5 * jnp.einsum("gik,gjk->ij", chol, chol, optimize="optimal")

# Reference determinant in the MO basis: the first nocc orbitals.
nocc = wave_data["ci1"].shape[0]
wave_data["mo_coeff"] = np.eye(norb)[:, :nocc]
# Split off the disconnected part of the doubles: t2 = ci2 - ci1 (x) ci1 in (i, a, j, b) order.
# NOTE(review): assumes ci2 was stored as t2 + t1 (x) t1 from the CCSD run — confirm.
ci1, ci2 = wave_data["ci1"], wave_data["ci2"]
t2 = ci2 - jnp.einsum("ia,jb->iajb", ci1, ci1)
wave_data["t1"] = ci1
wave_data["t2"] = t2

h1_mod = h1 - v0
ham_data["h1_mod"] = h1_mod

init = time.time()
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
seed = options["seed"]
neql = options["n_eql"]

trial_rdm1 = trial.get_rdm1(wave_data)
if "rdm1" not in wave_data:
    wave_data["rdm1"] = trial_rdm1
# RHF measurement intermediates are built first, then the trial's own intermediates on
# top of the same ham_data.
# NOTE(review): presumably the perturbative estimators need both sets — confirm.
ham_data = wavefunctions.rhf(
    trial.norb, trial.nelec, n_batch=trial.n_batch
)._build_measurement_intermediates(ham_data, wave_data)
ham_data = ham.build_measurement_intermediates(ham_data, trial, wave_data)
ham_data = ham.build_propagation_intermediates(ham_data, prop, trial, wave_data)
prop_data = prop.init_prop_data(trial, wave_data, ham_data, None)
# Per-rank RNG stream.
prop_data["key"] = random.PRNGKey(seed + rank)

# Number of MPI ranks: 1
#
# norb: 10
# nelec: (1, 1)
#
# n_eql: 4
# n_prop_steps: 50
# n_ene_blocks: 5
# n_sr_blocks: 10
# n_blocks: 10
# n_walkers: 40
# seed: 2
# walker_type: rhf
# trial: cisd
# dt: 0.005
# free_projection: False
# use_gpu: False
# n_exp_terms: 6
# orbital_rotation: True
# do_sr: True
# symmetry: False
# save_walkers: False
# n_batch: 1
# ene0: 0
# LNO: False
# orbE: 0
# maxError: 0.001
#
# Hostname: yichi-thinkpad
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64


In [9]:
h0 = ham_data["h0"]
# Quick check: evaluate _ccsd_walker_energy_pt on the first walker only and print
# (e0 + h0), e1 and t, one per line.
first_walker = prop_data["walkers"][0]
e0, e1, t = ccsd_pt._ccsd_walker_energy_pt(first_walker, ham_data, wave_data, trial)
for quantity in (e0 + h0, e1, t):
    print(quantity)

(-1.0892834338083683+0j)
(-0.04142308585635617+0j)
(8.231184445851372e-18+0j)


In [11]:
# sample_pt2 is already imported in the options/prep cell; re-importing it here had no
# effect, so only the counter reset remains.
# Initialise the killed-walker counter in prop_data before running the block scans.
# NOTE(review): presumably the scans read/update this key — confirm in sample_pt2.
prop_data["n_killed_walkers"] = 0

In [12]:
# Run a single sample_pt2._block_scan. It returns the updated propagation data and the
# block's weight together with its t, e0 and e1 estimates.
block_scan_args = (prop_data, ham_data, prop, trial, wave_data, sampler)
prop_data, (blk_wt, blk_t, blk_e0, blk_e1) = sample_pt2._block_scan(*block_scan_args)

In [14]:
# Block weight on its own line, then the block's t, e0 and e1 on a single line.
print(blk_wt)
print(*(blk_t, blk_e0, blk_e1))

39.937979554824814
(0.007980678814872811-0.0011422434553456228j) (-1.6025931421790407+0.0006389937738218602j) (-0.0411220839442803+0.0012227241932843069j)


In [18]:
# Energy estimate assembled from the single-block values: h0 + e0 + e1 - t * e0.
blk_t_e0 = blk_t * blk_e0
h0 + blk_e0 + blk_e1 - blk_t_e0

Array(-1.13092409+2.60668348e-05j, dtype=complex128)

In [29]:
# Run sample_pt2._sr_block_scan (sr presumably = stochastic reconfiguration). It returns
# the updated propagation data plus per-block arrays of weight, t, e0 and e1.
sr_scan_args = (prop_data, ham_data, prop, trial, wave_data, sampler)
prop_data, (blk_wt, blk_t, blk_e0, blk_e1) = sample_pt2._sr_block_scan(*sr_scan_args)

In [27]:
# Per-block weights on one line, then the per-block t, e0 and e1 arrays.
# NOTE(review): the saved output of this cell (In [27]) predates the _sr_block_scan run
# in In [29]. Its printed weights sum to ~199.81, while In [30] averages arrays whose
# weights sum to 199.74. Re-run the cells in order before trusting these numbers.
print(blk_wt)
print(*(blk_t, blk_e0, blk_e1))

[39.95968819 39.93548592 39.9854366  39.97233379 39.95363817]
[0.02082484-0.00142377j 0.01810244-0.0013127j  0.02462256-0.00200433j
 0.0253741 +0.00107529j 0.02426512-0.00416078j] [-1.61916311+0.00120948j -1.61632982+0.0013671j  -1.62516785+0.00548413j
 -1.62464365+0.00162736j -1.62409738+0.00537751j] [-0.04548985+0.00110913j -0.04388507+0.00077185j -0.04567555-0.00222289j
 -0.04742279-0.00338289j -0.04616288+0.00140441j]


In [30]:
# Combine the per-block estimates into weight-averaged values: sum_b w_b x_b / sum_b w_b.
wt = jnp.sum(blk_wt)
t, e0, e1 = (jnp.sum(blk_x * blk_wt) / wt for blk_x in (blk_t, blk_e0, blk_e1))
# Energy estimate from the averaged pieces: h0 + e0 + e1 - t * e0.
e_total = h0 + e0 + e1 - t * e0
for value in (wt, t, e0, e1, e_total):
    print(value)

199.74337761096814
(0.02770313707358223-0.005324835443396208j)
(-1.627140681997922+0.006549220259902778j)
(-0.0487240968399901+0.0021310646102316315j)
(-1.1308206619813184-0.00016540545134517172j)


In [32]:
# Full sample_pt2.propagate_phaseless run. It returns the updated propagation data and
# (wt, t, e0, e1), which are printed together with h0 + e0 + e1 - t * e0.
phaseless_args = (prop_data, ham_data, prop, trial, wave_data, sampler)
prop_data, (wt, t, e0, e1) = sample_pt2.propagate_phaseless(*phaseless_args)
e_total = h0 + e0 + e1 - t * e0
for value in (wt, t, e0, e1, e_total):
    print(value)

1998.06851006178
(0.03371363519748113+0.0030895190685928105j)
(-1.632272124020626-0.0020501796440972747j)
(-0.05339024327780739-0.0029899172325453726j)
(-1.130636785409919+7.195798426196283e-05j)


In [34]:
# Total walker weight after propagation. jnp.sum reduces the array in one JAX call
# instead of Python's builtin sum iterating over it element by element.
jnp.sum(prop_data["weights"])

Array(39.98750123, dtype=float64)