In [1]:
from pyscf import gto, scf, cc
import numpy as np
from jax import numpy as jnp
from jax import vmap, jvp, jit
import jax
from functools import partial

a = 1.05835  # H-H spacing in angstrom (= 2 bohr)
nH = 6
# Linear, equally spaced hydrogen chain along x: one "H x y z" line per atom.
atoms = "".join(f"H {i*a:.5f} 0.00000 0.00000 \n" for i in range(nH))

mol = gto.M(atom=atoms, basis="sto6g", verbose=4)
mol.build()

mf = scf.RHF(mol)  # .density_fit()
e = mf.kernel()

mycc = cc.CCSD(mf)
e = mycc.kernel()

System: uname_result(system='Linux', node='yichi-thinkpad', release='4.4.0-26100-Microsoft', version='#5074-Microsoft Fri Jan 01 08:00:00 PST 2016', machine='x86_64')  Threads 12
Python 3.10.16 | packaged by conda-forge | (main, Dec  5 2024, 14:16:10) [GCC 13.3.0]
numpy 1.24.3  scipy 1.14.1  h5py 3.12.1
Date: Wed Oct 15 21:25:41 2025
PySCF version 2.8.0
PySCF path  /home/yichi/research/software/lno_pyscf
GIT HEAD (branch master) ef75f4190e4de208685670651dc6c467f72b6794

[ENV] PYSCF_EXT_PATH /home/yichi/research/software/pyscf
[CONFIG] conf_file None
[INPUT] verbose = 4
[INPUT] num. atoms = 6
[INPUT] num. electrons = 6
[INPUT] charge = 0
[INPUT] spin (= nelec alpha-beta = 2S) = 0
[INPUT] symmetry False subgroup None
[INPUT] Mole.unit = angstrom
[INPUT] Symbol           X                Y                Z      unit          X                Y                Z       unit  Magmom
[INPUT]  1 H      0.000000000000   0.000000000000   0.000000000000 AA    0.000000000000   0.000000000000   0.00

In [None]:
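# AFQMC energy: -8.7671 +/- 0.0010  (NOTE: not this nH=6 system, whose RHF energy is about -3.13; it matches the 8-monomer entry listed in the next cell)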

In [None]:
# import numpy as np
# np.set_printoptions(5)
# # AFQMC/CCSD_PT energy: -1.096171 +/- 0.000554
# # AFQMC/CCSD_PT energy: -2.192221 +/- 0.001636
# # AFQMC/CCSD_PT energy: -4.384578 +/- 0.006862
# # AFQMC/CCSD_PT energy: -8.769366 +/- 0.019849
# # AFQMC/CCSD_PT energy: -17.538377 +/- 0.091299
# # AFQMC/CCSD_PT energy: -27.405295 +/- 0.246627
# nm = np.array([1,2,4,8,16,25])
# ept = np.array([-1.096171,-2.192221,-4.384578,-8.769366,-17.538377,-27.405295])
# ept_perm = ept/nm
# print('number of H2 monomers:            ', nm)
# print('AFQMC/CCSD_PT energy per monomer: ', ept_perm)

number of H2 monomers:             [ 1  2  4  8 16 25]
AFQMC/CCSD_PT energy per monomer:  [-1.09617 -1.09611 -1.09614 -1.09617 -1.09615 -1.09621]


In [16]:
# AFQMC run options; passed to mpi_jax._prep_afqmc in a later cell.
options = {
    'n_eql': 4,
    'n_prop_steps': 50,
    'n_ene_blocks': 20,
    'n_sr_blocks': 10,
    'n_blocks': 10,
    'n_walkers': 2,
    'seed': 2,
    'walker_type': 'rhf',
    'trial': 'cisd',
    'dt': 0.005,
    'free_projection': False,
    'ad_mode': None,
    'use_gpu': False,
}

from ad_afqmc import pyscf_interface
from ad_afqmc.cisd_perturb import sample_ccsd_pt, ccsd_pt

# Zero the CCSD singles in place before preparing AFQMC inputs, so that
# presumably only T2 enters the trial written by prep_afqmc.
mycc.t1 = mycc.t1 * 0
ccsd_pt.prep_afqmc(mycc, chol_cut=1e-5)

#
# Preparing AFQMC calculation
# If you import pyscf cc modules and use MPI for AFQMC in the same script, finalize MPI before calling the AFQMC driver.
# Calculating Cholesky integrals
# Finished calculating Cholesky integrals
#
# Size of the correlation space:
# Number of electrons: (3, 3)
# Number of basis functions: 6
# Number of Cholesky vectors: 15
#


In [11]:
from jax import lax
@jax.jit
def _t1t2_walker_olp(walker,wave_data):
    ''' <psi_0(t1+t2)|phi> '''
    t1, t2 = wave_data['t1'], wave_data['t2']
    nocc = walker.shape[1]
    GF = (walker.dot(jnp.linalg.inv(walker[: nocc, :]))).T
    o0 = jnp.linalg.det(walker[: nocc, :]) ** 2
    o1 = jnp.einsum("ia,ia", t1, GF[:, nocc:])
    o2 = 2 * jnp.einsum(
        "iajb, ia, jb", t2, GF[:, nocc:], GF[:, nocc:]
    ) - jnp.einsum("iajb, ib, ja", t2, GF[:, nocc:], GF[:, nocc:])
    return (2*o1 + o2) * o0

@jax.jit
def _t1t2_olp_exp1(x: float, h1_mod: jax.Array, walker: jax.Array,
                   wave_data: dict) -> complex:
    '''
    <psi_0(t1+t2)|exp(x*h1_mod)|walker>, with the exponential truncated to
    first order (1 + x*h1_mod). Its x-derivative at x = 0 is the one-body
    matrix element <psi_0(t1+t2)|h1_mod|walker>.
    '''
    propagated = walker + (x * h1_mod).dot(walker)
    return _t1t2_walker_olp(propagated, wave_data)

@jax.jit
def _t1t2_olp_exp2(x: float, chol_i: jax.Array, walker: jax.Array,
                   wave_data: dict) -> complex:
    '''
    <psi_0(t1+t2)|exp(x*L_i)|walker> for a single Cholesky vector L_i, with
    the exponential truncated at second order (1 + x L_i + x^2/2 L_i^2).
    Its second x-derivative at x = 0 is <psi_0(t1+t2)|L_i^2|walker>.
    '''
    l_walker = chol_i.dot(walker)
    ll_walker = chol_i.dot(l_walker)
    propagated = walker + x * l_walker + x**2 / 2.0 * ll_walker
    return _t1t2_walker_olp(propagated, wave_data)

@partial(jit, static_argnums=3)
def _et1t2(walker, ham_data, wave_data, trial, eps=3e-4):
    '''Walker-local quantities for the perturbative (t1+t2) energy.

    Parameters
    ----------
    walker : walker matrix passed to the overlap helpers and the trial.
    ham_data : dict; 'h1_mod' and 'chol' (reshaped to (-1, norb, norb)) are
        read here, and the dict is passed to trial._calc_energy_restricted.
    wave_data : dict passed to the overlap helpers (which read 't1', 't2')
        and to the trial methods.
    trial : static (hashable, because of static_argnums) trial object that
        provides norb, _calc_overlap_restricted and _calc_energy_restricted.
    eps : step of the central finite difference for the second derivative
        along each Cholesky vector.

    Returns
    -------
    Real parts of (t, e0, e1), where, with o0 = <psi_0|phi> from the trial,
        t  = <psi_0(t1+t2)|phi> / o0
        e0 = trial._calc_energy_restricted(walker, ham_data, wave_data)
        e1 = (<psi_0(t1+t2)|h1_mod|phi>
              + 1/2 sum_g <psi_0(t1+t2)|L_g^2|phi>) / o0
    '''
    norb = trial.norb
    h1_mod = ham_data['h1_mod']
    chol = ham_data["chol"].reshape(-1, norb, norb)

    # One-body term: exact first derivative at x = 0 via forward-mode AD.
    f1 = lambda x: _t1t2_olp_exp1(x, h1_mod, walker, wave_data)
    olp_t, d_olp = jvp(f1, [0.0], [1.0])

    # Two-body term: second derivative along each Cholesky vector by central
    # finite differences. walker / wave_data are closed over instead of being
    # threaded through the scan carry.
    def olps_along_chol(step):
        def body(carry, chol_i):
            return carry, _t1t2_olp_exp2(step, chol_i, walker, wave_data)
        _, olps = lax.scan(body, None, chol)
        return olps

    olp_p = olps_along_chol(eps)
    olp_m = olps_along_chol(-1.0 * eps)
    # At x = 0 the propagated walker equals the walker for every Cholesky
    # vector, so the centre point is a single overlap, broadcast over g.
    olp_0 = _t1t2_walker_olp(walker, wave_data)
    d_2_olp = (olp_p - 2.0 * olp_0 + olp_m) / eps / eps

    o0 = trial._calc_overlap_restricted(walker, wave_data)
    t = olp_t / o0
    e0 = trial._calc_energy_restricted(walker, ham_data, wave_data)
    e1 = (d_olp + jnp.sum(d_2_olp) / 2.0) / o0

    return jnp.real(t), jnp.real(e0), jnp.real(e1)

In [13]:
def _cal_t1t2(trial,walker,wave_data):
    nocc = trial.nelec[0]
    green = (walker.dot(jnp.linalg.inv(walker[:nocc, :]))).T
    green = green[:nocc,:]
    green_occ = green[:, nocc:].copy()
    t1 = wave_data['t1']
    t2 = wave_data['t2']
    t1g = jnp.einsum("ia,ia->", t1, green_occ, optimize="optimal")
    t2g_c = jnp.einsum("iajb,ia->jb", t2, green_occ)
    t2g_e = jnp.einsum("iajb,ib->ja", t2, green_occ)
    t2g = 2 * t2g_c - t2g_e
    gt2g = jnp.einsum("ia,ia->", t2g, green_occ, optimize="optimal")
    t = 2 * t1g + gt2g
    return jnp.real(t)

In [17]:
from ad_afqmc import config, mpi_jax, wavefunctions
import time
from jax import random
# Build Hamiltonian, propagator, trial and sampler from the options, then
# swap the CISD trial for an RHF trial with the same norb / nelec / n_batch.
ham_data, ham, prop, trial, wave_data, sampler, observable, options, _ = (
    mpi_jax._prep_afqmc(options)
)
trial = wavefunctions.rhf(trial.norb, trial.nelec, n_batch=trial.n_batch)

norb = trial.norb
chol = ham_data["chol"].reshape(-1, norb, norb)
# average of the two h1 components (presumably alpha / beta)
h1 = (ham_data["h1"][0] + ham_data["h1"][1]) / 2.0
# v0_ij = 1/2 sum_{g,k} L^g_ik L^g_jk, subtracted from h1 below
v0 = 0.5 * jnp.einsum("gik,gjk->ij", chol, chol, optimize="optimal")

# RHF trial = first nocc columns of the identity (MO basis); the t1/t2
# amplitudes used by the overlap helpers are taken from ci1 / ci2.
ci1, ci2 = wave_data['ci1'], wave_data['ci2']
nocc = ci1.shape[0]
wave_data["mo_coeff"] = np.eye(norb)[:, :nocc]
t2 = ci2
wave_data['t1'], wave_data['t2'] = ci1, t2

h1_mod = h1 - v0
ham_data['h1_mod'] = h1_mod
h0 = ham_data['h0']

# JAX / MPI setup
config.setup_jax()
MPI = config.setup_comm()
init = time.time()
comm = MPI.COMM_WORLD
size, rank = comm.Get_size(), comm.Get_rank()
seed, neql = options["seed"], options["n_eql"]

# Keep any rdm1 already in wave_data; otherwise use the trial's.
trial_rdm1 = trial.get_rdm1(wave_data)
if "rdm1" not in wave_data:
    wave_data["rdm1"] = trial_rdm1

# Hamiltonian intermediates, then the initial propagation state
ham_data = ham.build_measurement_intermediates(ham_data, trial, wave_data)
ham_data = ham.build_propagation_intermediates(ham_data, prop, trial, wave_data)
prop_data = prop.init_prop_data(trial, wave_data, ham_data, None)
# seed offset by rank: one PRNG stream per MPI process
prop_data["key"] = random.PRNGKey(seed + rank)

# Number of MPI ranks: 1
#
# norb: 6
# nelec: (3, 3)
#
# n_eql: 4
# n_prop_steps: 50
# n_ene_blocks: 20
# n_sr_blocks: 10
# n_blocks: 10
# n_walkers: 2
# seed: 2
# walker_type: rhf
# trial: cisd
# dt: 0.005
# free_projection: False
# use_gpu: False
# n_exp_terms: 6
# orbital_rotation: True
# do_sr: True
# symmetry: False
# save_walkers: False
# ene0: 0.0
# n_batch: 1
# LNO: False
# orbE: 0
# maxError: 0.001
#
# Hostname: yichi-thinkpad
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64


In [18]:
print('mf energy: ', mf.e_tot)
# CCSD energy from the converged doubles with the singles set to zero
eris = mycc.ao2mo(mo_coeff=mycc.mo_coeff)
eccsd = mycc.energy(t1=mycc.t1 * 0, t2=mycc.t2, eris=eris)
print('ccsd energy with t1*0 t2: ', mf.e_tot + eccsd)

mf energy:  -3.125545148766105
ccsd energy with t1*0 t2:  -3.238323737449727


In [19]:
walker = prop_data['walkers'][0]
t, e0, e1 = _et1t2(walker, ham_data, wave_data, trial, eps=3e-4)
# e0 vs. RHF energy, the perturbative energy e0 + e1 - t*(e0 - h0),
# and its deviation from the T2-only CCSD energy
print(e0 - mf.e_tot,
      e0 + e1 - t * (e0 - h0),
      e0 + e1 - t * (e0 - h0) - (mf.e_tot + eccsd),
      sep="\n")
print(t, e0, e1)

1.2522667347525385e-08
-3.2383237346961025
2.7536244395776066e-09
-4.926548544408139e-32 -3.1255451362434377 -0.11277859845266498


In [20]:
# should reproduce the t returned by _et1t2 for the same walker
_cal_t1t2(trial=trial, walker=walker, wave_data=wave_data)

Array(-4.92654854e-32, dtype=float64)

In [22]:
e0, e1, t = ccsd_pt._ccsd_walker_energy_pt(walker, ham_data, wave_data, trial)
# this e0 appears to exclude h0 (h0 + e0 ~ mf.e_tot), hence the explicit "+ h0"
print(h0 + e0 - mf.e_tot,
      h0 + e0 + e1 - t * e0,
      h0 + e0 + e1 - t * e0 - (mf.e_tot + eccsd),
      sep="\n")

2.1550087847543864e-07
-3.2383235317178913
2.0573183556749086e-07


In [5]:
from ad_afqmc.cisd_perturb import sample_pt2
prop_data.update(n_killed_walkers=0)  # start the counter at zero before _block_scan

In [6]:
prop_data, (blk_wt, blk_t, blk_e0, blk_e1) = sample_pt2._block_scan(
    prop_data, ham_data, prop, trial, wave_data, sampler
)

In [86]:
# 3x3 covariance of the block values; rows/columns ordered (t, e0, e1)
cov_te0e1 = np.cov(np.vstack([blk_t, blk_e0, blk_e1]))
print(cov_te0e1)

[[ 0.00043 -0.00036 -0.00294]
 [-0.00036  0.00039  0.00236]
 [-0.00294  0.00236  0.02035]]


In [87]:
# Delta-method propagation of the (t, e0, e1) block covariance to the energy
#   E = h0 + e0 + e1 - t * e0   (convention of ccsd_pt._ccsd_walker_energy_pt),
# whose gradient is dE = (dE/dt, dE/de0, dE/de1) = (-e0, 1 - t, 1).
# Previously dE was undefined here (NameError on a fresh kernel).
# NOTE(review): assumes blk_e0 excludes h0 like that function -- confirm
# against sample_pt2._block_scan.
t_mean, e0_mean = np.mean(blk_t), np.mean(blk_e0)
dE = np.array([-e0_mean, 1.0 - t_mean, 1.0])
np.sqrt(dE @ cov_te0e1 @ dE)

0.0007600322264764268

In [None]:
# Pairwise 2x2 covariances of the block values. cov_e0t is what the next
# cell prints; it previously overwrote the 3x3 cov_te0e1 instead, leaving
# cov_e0t undefined.
cov_e0t = np.cov(blk_e0, blk_t)
cov_e0e1 = np.cov(blk_e0, blk_e1)
cov_e1t = np.cov(blk_e1, blk_t)

In [28]:
print(cov_e0t, cov_e0e1, cov_e1t, sep="\n")

[[ 6.67078e-05 -4.39983e-05]
 [-4.39983e-05  2.92440e-05]]
[[6.67078e-05 4.99679e-06]
 [4.99679e-06 9.78546e-07]]
[[ 9.78546e-07 -3.66370e-06]
 [-3.66370e-06  2.92440e-05]]
