In [1]:
from pyscf import gto, scf, cc
import numpy as np
from jax import numpy as jnp
from jax import vmap, jvp, jit
import jax
from functools import partial

a = 2  # H-H spacing in bohr (2 a_B)
nH = 4  # number of H atoms in the linear chain
# Linear H chain along x, one "H x y z" line per atom.
atoms = "".join(f"H {i*a:.5f} 0.00000 0.00000 \n" for i in range(nH))

mol = gto.M(atom=atoms, basis="sto6g", unit='bohr', spin=0, verbose=4)
mol.build()

mf = scf.UHF(mol)
mf.kernel()

# Signs of the diagonal of the alpha/beta MO overlap <A|B>.
s1e = mf.get_ovlp()
olp = mf.mo_coeff[0].T @ s1e @ mf.mo_coeff[1]
sign = np.array(np.sign(olp.diagonal()), dtype=int)
print('<A|B> sign: ', sign)
# When no alpha/beta pair has negative overlap, flip beta MO 1 on purpose
# (presumably to exercise the alpha->beta sign handling downstream).
if not np.any(sign == -1):
    mf.mo_coeff[1][:, 1] *= -1
olp = mf.mo_coeff[0].T @ s1e @ mf.mo_coeff[1]
sign = np.array(np.sign(olp.diagonal()), dtype=int)
print('new <A|B> sign: ', sign)

nfrozen = 0
mycc = cc.CCSD(mf, frozen=nfrozen)
mycc.kernel()[0]

System: uname_result(system='Linux', node='yichi-thinkpad', release='4.4.0-26100-Microsoft', version='#5074-Microsoft Fri Jan 01 08:00:00 PST 2016', machine='x86_64')  Threads 12
Python 3.10.16 | packaged by conda-forge | (main, Dec  5 2024, 14:16:10) [GCC 13.3.0]
numpy 1.24.3  scipy 1.14.1  h5py 3.12.1
Date: Sat Oct 25 16:48:09 2025
PySCF version 2.8.0
PySCF path  /home/yichi/research/software/lno_pyscf
GIT HEAD (branch master) ef75f4190e4de208685670651dc6c467f72b6794

[ENV] PYSCF_EXT_PATH /home/yichi/research/software/pyscf
[CONFIG] conf_file None
[INPUT] verbose = 4
[INPUT] num. atoms = 4
[INPUT] num. electrons = 4
[INPUT] charge = 0
[INPUT] spin (= nelec alpha-beta = 2S) = 0
[INPUT] symmetry False subgroup None
[INPUT] Mole.unit = bohr
[INPUT] Symbol           X                Y                Z      unit          X                Y                Z       unit  Magmom
[INPUT]  1 H      0.000000000000   0.000000000000   0.000000000000 AA    0.000000000000   0.000000000000   0.000000

-0.07659940081305099

In [2]:
# CCS / CCSD reference energies from pyscf, before and after amplifying t1.
T1_SCALE = 10  # factor applied to the converged t1 amplitudes

# Remember the converged t1 the first time this cell sees the current `mycc`.
# A re-run then rescales from the original amplitudes instead of compounding
# the factor (10x, 100x, ...). Re-running the CCSD cell creates a new `mycc`,
# which resets the stored copy.
if globals().get("_t1_source") is not mycc:
    _t1_source = mycc
    t1_unscaled = mycc.t1

eris = mycc.ao2mo(mycc.mo_coeff)
# Doubles switched off -> CCS correlation energy.
zero_t2 = (0*mycc.t2[0], 0*mycc.t2[1], 0*mycc.t2[2])
eccs = mycc.energy(t1_unscaled, zero_t2, eris)
print(mf.e_tot)
print(mf.e_tot+eccs)
# Later cells (prep_afqmc, the t2 rotation, the reference energies) read the
# amplified mycc.t1 set here.
mycc.t1 = (t1_unscaled[0]*T1_SCALE, t1_unscaled[1]*T1_SCALE)
eccs = mycc.energy(mycc.t1, zero_t2, eris)
print(mf.e_tot+eccs)
eccsd = mycc.energy(mycc.t1, mycc.t2, eris)
print(mf.e_tot+eccsd)

-2.088692381947721
-2.088694274795575
-2.088879997348302
-2.165477505313499


In [11]:
def thouless_trans(t1):
    '''Orthonormal occupied orbitals of the Thouless-rotated determinant exp(T1)|HF>.

    In the original MO basis (occupied rows first, then virtual rows) the
    occupied orbitals of exp(T1)|HF> span the columns of [I; t1^T]. This
    returns an orthonormal basis of that span. The sign of each column is
    chosen so that its occupied-diagonal element is non-negative, i.e. a
    non-negative overlap with the matching original occupied orbital.

    Parameters
    ----------
    t1 : array_like, shape (nocc, nvir)
        Singles amplitudes t_ia (real).

    Returns
    -------
    jax.Array, shape (nocc + nvir, nocc)
        mo_t[p, i] = <psi_p | psi'_i>.
    '''
    t1 = jnp.asarray(t1)
    # t1 = q r  =>  [q; r^T] = [I; t1^T] q for real t1, i.e. a rotated copy of
    # the Thouless occupied block. mode='complete' keeps q square
    # (nocc x nocc) even when nocc > nvir.
    q, r = jnp.linalg.qr(t1, mode='complete')
    u_occ = jnp.vstack((q, r.T))
    # Orthonormalize the columns (reduced QR: q is (nocc + nvir, nocc)).
    q, _ = jnp.linalg.qr(u_occ)
    diag = jnp.diagonal(q)
    # np.sign would give 0 for an exactly-zero diagonal entry and wipe out
    # that column (this happens e.g. when t1[0, 0] == 0). Use a strict +/-1
    # so every column survives.
    sgn = jnp.where(diag < 0, -1.0, 1.0).astype(q.dtype)
    return q * sgn[None, :]

In [3]:
# AFQMC driver options for ad_afqmc: 3 walkers of type 'rhf', the
# 'uccsd_pt2_ad' trial, CPU only.
options = {'n_eql': 4,
           'n_prop_steps': 50,
            'n_ene_blocks': 20,
            'n_sr_blocks': 10,
            'n_blocks': 10,
            'n_walkers': 3,
            'seed': 2,
            'walker_type': 'rhf',
            'trial': 'uccsd_pt2_ad',
            'dt':0.005,
            'free_projection':False,
            'ad_mode':None,
            'use_gpu': False,
            }

# Export mean-field / CC data for AFQMC. This presumably writes the Cholesky
# file read back at the end of this cell ('FCIDUMP_chol'). It receives the
# current `mycc`, so any earlier rescaling of mycc.t1 is presumably carried
# into the trial. TODO confirm.
from ad_afqmc import pyscf_interface
# from ad_afqmc.ccsd_pt import ccsd_pt, sample_ccsd_pt
# t1 = 5 * mycc.t1
# mycc.t1 = [10*mycc.t1[0],10*mycc.t1[1]]
# ccsd_pt.prep_afqmc(mycc,chol_cut=1e-7)
# Cholesky decomposition cutoff: 1e-6.
pyscf_interface.prep_afqmc(mycc,options,chol_cut=1e-6)

from ad_afqmc import config, mpi_jax
import time
from jax import random
# Build Hamiltonian, propagator, trial and wavefunction data from the exported
# files. Note that `options` is rebound to the dict returned by _prep_afqmc.
ham_data, ham, prop, trial, wave_data, sampler, observable, options, _ \
    = (mpi_jax._prep_afqmc(options))

# Read the raw integrals back from the Cholesky file. In the cells shown, only
# h0 (core energy) is used later; h1, h1_mod and chol are kept for inspection.
import h5py
chol_file='FCIDUMP_chol'
with h5py.File(chol_file, "r") as fh5:
    [nelec, nmo, ms, nchol] = fh5["header"]
    h0 = jnp.array(fh5.get("energy_core"))
    h1 = jnp.array(fh5.get("hcore")).reshape(nmo, nmo)
    h1_mod = jnp.array(fh5.get("hcore_mod")).reshape(nmo, nmo)
    chol = jnp.array(fh5.get("chol")).reshape(-1, nmo, nmo)


#
# Preparing AFQMC calculation
# If you import pyscf cc modules and use MPI for AFQMC in the same script, finalize MPI before calling the AFQMC driver.
[[[ 0.28971215 -0.527348    0.7357868   0.63221951]
  [ 0.43257524 -0.38371943 -0.49700052 -1.04656273]
  [ 0.43257524  0.38371943 -0.49700052  1.04656273]
  [ 0.28971215  0.527348    0.7357868  -0.63221951]]

 [[ 0.2897204   0.52735258  0.73578355  0.63221569]
  [ 0.43256967  0.38371184 -0.49700537 -1.04656551]
  [ 0.43256967 -0.38371184 -0.49700537  1.04656551]
  [ 0.2897204  -0.52735258  0.73578355 -0.63221569]]]
# Calculating Cholesky integrals
# Finished calculating Cholesky integrals
#
# Size of the correlation space:
# Number of electrons: (2, 2)
# Number of basis functions: 4
# Number of Cholesky vectors: 9
#
[[[ 0.28971215 -0.527348    0.7357868   0.63221951]
  [ 0.43257524 -0.38371943 -0.49700052 -1.04656273]
  [ 0.43257524  0.38371943 -0.49700052  1.04656273]
  [ 0.28971215  0.527348    0.7357868  -0.63221951]]

 [[ 0.289720

In [None]:
# q, r = jnp.linalg.qr(mycc.t1[1],mode='complete')
# u_ji = q
# u_ai = r.T
# u_occ = jnp.vstack((u_ji,u_ai))
# q, r = jnp.linalg.qr(np.eye(u_occ.shape[0]).T @ u_occ)
# np.sign((np.eye(u_occ.shape[0]).T @ q).diagonal())
# print(q)
# print(r)

[[-9.57848552e-01  2.96476418e-13]
 [-2.82649309e-13 -9.73694765e-01]
 [ 2.87273652e-01 -2.74455326e-15]
 [-4.19290267e-14 -2.27856324e-01]]
[[ 1.04400638e+00 -9.81189819e-15]
 [ 0.00000000e+00 -1.02701589e+00]]


In [None]:
# (mo_aA, mo_bA) = jnp.array(np.load('mo_coeff.npz')["mo_coeff"])
# print(mo_aA)
# print(mo_bA)
# print(np.sign((mo_aA.T @ mo_bA).diagonal()))

[[ 1.00000000e+00 -2.42861287e-17  2.20218299e-16 -2.49800181e-16]
 [ 2.42861287e-17  1.00000000e+00 -1.66533454e-16  5.63892150e-17]
 [-2.20218299e-16  1.66533454e-16  1.00000000e+00 -2.74086309e-16]
 [ 2.49800181e-16 -5.63892150e-17  2.74086309e-16  1.00000000e+00]]
[[ 1.00000000e+00 -1.24554863e-15 -1.12127101e-05  2.22039179e-16]
 [-1.24553146e-15 -1.00000000e+00  1.38779299e-15 -7.24497898e-06]
 [ 1.12127101e-05  1.38778135e-15  1.00000000e+00 -3.20909288e-16]
 [-2.22044605e-16 -7.24497898e-06  3.20921832e-16  1.00000000e+00]]
[ 1. -1.  1.  1.]


In [None]:
# sgn = np.sign((mo_aA.T @ mo_bA).diagonal())
# mo_bA = np.einsum('pq,q->pq',mo_bA,sgn)
# print(mo_bA)
# print(np.sign((mo_aA.T @ mo_bA).diagonal()))

[[ 1.00000000e+00  1.24554863e-15 -1.12127101e-05  2.22039179e-16]
 [-1.24553146e-15  1.00000000e+00  1.38779299e-15 -7.24497898e-06]
 [ 1.12127101e-05 -1.38778135e-15  1.00000000e+00 -3.20909288e-16]
 [-2.22044605e-16  7.24497898e-06  3.20921832e-16  1.00000000e+00]]
[1. 1. 1. 1.]


In [None]:
# noccA, noccB = trial.nelec[0], trial.nelec[1]
# wave_data['mo_A2B'] = mo_bA.T
# wave_data['mo_coeff'] = [mo_aA[:,:noccA],
#                          mo_bA[:,:noccB]]

In [None]:
# wave_data['mo_ta'] = thouless_trans(mycc.t1[0])
# wave_data['mo_tb'] = thouless_trans(mycc.t1[1])
# print(wave_data['mo_ta'])
# print(wave_data['mo_tb'])

[[ 9.57980957e-01 -7.82245435e-15]
 [ 7.42228747e-15  9.73626847e-01]
 [-2.86831809e-01  8.35197038e-17]
 [ 1.27628623e-15  2.28146363e-01]]
[[ 9.57848552e-01  2.00218081e-13]
 [-1.86544548e-13  9.73694765e-01]
 [-2.87273652e-01 -3.17533290e-15]
 [ 4.85101236e-14 -2.27856324e-01]]


In [6]:
# # ham_data['h1_mod'] = h1_mod
# # mo_coeff = jnp.array(np.load('mo_coeff.npz')["mo_coeff"])
# # transform walker from A basis to B basis
# # |phi_i> = C^A_pi |A_p>
# #         = C^A_pi |B_q><B_q|A_p>
# #         = C^BA_qp C^A_pi |B_q>
# # s1e = mf.get_ovlp()
# # mo_A2B = mf.mo_coeff[1].T @ s1e @ mf.mo_coeff[0]
# # wave_data['mo_A2B'] = mo_A2B

# CCSD doubles reordered from t2[i, j, a, b] to [i, a, j, b] (assumes pyscf's
# UCCSD layout t2 = (t2aa, t2ab, t2bb)). T1*T1 products are not added.
t2aa, t2ab, t2bb = (
    amp.transpose(0, 2, 1, 3) for amp in (mycc.t2[0], mycc.t2[1], mycc.t2[2])
)
t1a = np.array(mycc.t1[0])
t1b = np.array(mycc.t1[1])

noccA, noccB = trial.nelec[0], trial.nelec[1]
# Beta Thouless orbitals re-expressed in the alpha basis (read by _tls_olp).
wave_data['mo_tb_A'] = wave_data['mo_A2B'].T @ wave_data['mo_tb']
mo_a_A = wave_data['mo_ta']
mo_b_B = wave_data['mo_tb']

# Rotate both occupied indices of every doubles block with the transposed
# occupied-occupied block of the corresponding Thouless orbitals.
u_occ_a = mo_a_A[:noccA, :noccA].T
u_occ_b = mo_b_B[:noccB, :noccB].T
wave_data["rot_t2AA"] = jnp.einsum('ik,jl,kalb->iajb', u_occ_a, u_occ_a, t2aa)
wave_data["rot_t2BB"] = jnp.einsum('ik,jl,kalb->iajb', u_occ_b, u_occ_b, t2bb)
wave_data["rot_t2AB"] = jnp.einsum('ik,jl,kalb->iajb', u_occ_a, u_occ_b, t2ab)

In [7]:
config.setup_jax()
MPI = config.setup_comm()
init = time.time()
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
seed = options["seed"]
neql = options["n_eql"]

# The trial RDM1 is only a fallback; an rdm1 already in wave_data is kept.
trial_rdm1 = trial.get_rdm1(wave_data)
if "rdm1" not in wave_data:
    wave_data["rdm1"] = trial_rdm1

ham_data = ham.build_measurement_intermediates(ham_data, trial, wave_data)
ham_data = ham.build_propagation_intermediates(
    ham_data, prop, trial, wave_data
)
prop_data = prop.init_prop_data(trial, wave_data, ham_data, None)

prop_data["key"] = random.PRNGKey(seed + rank)
print('init walker energy: ', prop_data['e_estimate'])
print('mf energy: ', mf.e_tot)
print('err', mf.e_tot - prop_data['e_estimate'])

# Energy of the exact Thouless determinant exp(T1)|HF>. The beta orbitals are
# mapped with mo_A2B.T, i.e. expressed in the alpha basis.
et1 = trial._calc_energy(
    wave_data['mo_ta'], wave_data["mo_A2B"].T @ wave_data['mo_tb'], ham_data, wave_data
)
print('exact T1 transformed init walker energy: ', et1)

# CCS reference from pyscf with the current mycc.t1 (doubles zeroed). It gets
# its own name so this cell does not overwrite the notebook-level `eccsd`
# (the CCSD energy) with a CCS number.
eris = mycc.ao2mo(mycc.mo_coeff)
ecorr_ccs = mycc.energy(mycc.t1, (0*mycc.t2[0], 0*mycc.t2[1], 0*mycc.t2[2]), eris)
print('ccs energy: ', mf.e_tot + ecorr_ccs)
# mycc.e_tot comes from the original CCSD solve and does not reflect any later
# rescaling of mycc.t1.
print('ccsd energy: ', mycc.e_tot)

# Hostname: yichi-thinkpad
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64
init walker energy:  -2.0886923811427613
mf enegry:  -2.088692381947721
err -8.049596544879023e-10
exact T1 transformed init walker energy:  -2.0888799980568957
ccs energy:  -2.088879997348302
ccsd energy:  -2.165291782760772


In [None]:
@partial(jit, static_argnums=0)
def _tls_olp(
    trial,
    walker_up: jax.Array,
    walker_dn: jax.Array,
    wave_data: dict,
) -> complex:
    '''<exp(T1)HF|walker>'''
    # everything in alpha basis
    # walker_dn = wave_data['mo_B'].T @ walker_dn
    olp = jnp.linalg.det(
        wave_data["mo_ta"].T.conj() @ walker_up
    ) * jnp.linalg.det(wave_data["mo_tb_A"].T.conj() @ walker_dn)

    return olp

@partial(jit, static_argnums=5)
def _tls_exp1(x: float, h1_mod: jax.Array, walker_up: jax.Array,
              walker_dn: jax.Array, wave_data: dict, trial):
    '''Ratio <exp(T1)HF|(1 + x*h1_mod)|walker> / <trial|walker> for a UHF walker.

    This is the first-order expansion of exp(x*h1_mod) applied to both spin
    walkers (alpha basis, as _tls_olp expects). Its x-derivative at x = 0 is
    the one-body energy numerator. The denominator is trial._calc_overlap on
    the unperturbed walkers.
    '''
    def _apply(w):
        return w + x * h1_mod.dot(w)

    numerator = _tls_olp(trial, _apply(walker_up), _apply(walker_dn), wave_data)
    denominator = trial._calc_overlap(walker_up, walker_dn, wave_data)
    return numerator / denominator

@partial(jit, static_argnums=5)
def _tls_exp2(x: float, chol_i: jax.Array, walker_up: jax.Array,
              walker_dn: jax.Array, wave_data: dict, trial) -> complex:
    '''Ratio <exp(T1)HF|(1 + x L + x^2/2 L^2)|walker> / <trial|walker>, L = chol_i.

    This is the second-order expansion of exp(x*chol_i) applied to both spin
    walkers (alpha basis). The caller takes the second derivative in x at
    x = 0 to get this Cholesky vector's two-body contribution.
    '''
    def _expand(w):
        lw = chol_i.dot(w)
        return w + x * lw + x**2 / 2.0 * chol_i.dot(lw)

    numerator = _tls_olp(trial, _expand(walker_up), _expand(walker_dn), wave_data)
    denominator = trial._calc_overlap(walker_up, walker_dn, wave_data)
    return numerator / denominator

In [None]:
from jax import lax
@partial(jit, static_argnums=0)
def _ut2_walker_olp(
    trial, walker_up: jax.Array, walker_dn: jax.Array, wave_data: dict
) -> complex:
    '''<exp(T1)HF|(t1+t2)|walker> = (t_ia G_ia + t_iajb G_iajb) * <exp(T1)HF|walker>'''
    noccA, t2AA = trial.nelec[0], wave_data["rot_t2AA"]
    noccB, t2BB = trial.nelec[1], wave_data["rot_t2BB"]
    t2AB = wave_data["rot_t2AB"]
    mo_A = wave_data['mo_ta'] # in alpha basis
    mo_B = wave_data['mo_tb'] # in beta basis
    green_a = (
        walker_up.dot(jnp.linalg.inv(mo_A.T.conj() @ walker_up))
    ).T
    # convert walker_dn into beta basis
    # s.t. green_b can be contracted with beta amplitude
    walker_dn = wave_data["mo_A2B"] @ walker_dn
    green_b = (
        walker_dn.dot(jnp.linalg.inv(mo_B.T.conj() @ walker_dn))
    ).T
    green_a, green_b = green_a[:noccA, noccA:], green_b[:noccB, noccB:]
    o0 = _tls_olp(trial,walker_up,walker_dn,wave_data)
    # o1 = jnp.einsum("ia,ia", t1A, green_a) + jnp.einsum("ia,ia", t1B, green_b)
    o2 = (
        0.5 * jnp.einsum("iajb, ia, jb", t2AA, green_a, green_a)
        + 0.5 * jnp.einsum("iajb, ia, jb", t2BB, green_b, green_b)
        + jnp.einsum("iajb, ia, jb", t2AB, green_a, green_b)
    )
    return o2 * o0

@partial(jit, static_argnums=5)
def _ut2_exp1(x: float, h1_mod: jax.Array, walker_up: jax.Array,
              walker_dn: jax.Array, wave_data: dict, trial):
    '''Ratio <exp(T1)HF| T2 (1 + x*h1_mod)|walker> / <trial|walker> for a UHF walker.

    This is the first-order expansion of exp(x*h1_mod) on both spin walkers
    (alpha basis), projected with T2 via _ut2_walker_olp. Its x-derivative at
    x = 0 is the one-body part of the T2 energy numerator.
    '''
    def _apply(w):
        return w + x * h1_mod.dot(w)

    numerator = _ut2_walker_olp(trial, _apply(walker_up), _apply(walker_dn), wave_data)
    denominator = trial._calc_overlap(walker_up, walker_dn, wave_data)
    return numerator / denominator

@partial(jit, static_argnums=5)
def _ut2_exp2(x: float, chol_i: jax.Array, walker_up: jax.Array,
              walker_dn: jax.Array, wave_data: dict, trial) -> complex:
    '''Ratio <exp(T1)HF| T2 (1 + x L + x^2/2 L^2)|walker> / <trial|walker>, L = chol_i.

    This is the second-order expansion of exp(x*chol_i) on both spin walkers
    (alpha basis), projected with T2 via _ut2_walker_olp. The caller
    differentiates twice in x at x = 0.
    '''
    def _expand(w):
        lw = chol_i.dot(w)
        return w + x * lw + x**2 / 2.0 * chol_i.dot(lw)

    numerator = _ut2_walker_olp(trial, _expand(walker_up), _expand(walker_dn), wave_data)
    denominator = trial._calc_overlap(walker_up, walker_dn, wave_data)
    return numerator / denominator

def _pt2_numerator_terms(exp1_fn, exp2_fn, trial, walker_up, walker_dn,
                         ham_data, wave_data, eps):
    '''Return (overlap ratio at x=0, energy numerator) for one bra.

    exp1_fn(x, h1_mod, walker_up, walker_dn, wave_data, trial) and
    exp2_fn(x, chol_i, walker_up, walker_dn, wave_data, trial) are the
    first/second-order expansions used by _calc_energy_pt2.
    - One-body part: d/dx exp1_fn at x = 0, via forward-mode AD (jvp).
    - Two-body part: 1/2 * sum_i d^2/dx^2 exp2_fn(x, chol_i) at x = 0, via a
      central finite difference with step `eps` over all Cholesky vectors.
    '''
    norb = trial.norb
    h1_mod = ham_data['h1_mod']
    chol = ham_data["chol"].reshape(-1, norb, norb)

    f1 = lambda a: exp1_fn(a, h1_mod, walker_up, walker_dn, wave_data, trial)
    value, d_exp1 = jvp(f1, [0.0], [1.0])

    def exp2_over_chol(x):
        # Evaluate exp2_fn at a fixed x for every Cholesky vector.
        def step(carry, chol_i):
            return carry, exp2_fn(x, chol_i, walker_up, walker_dn, wave_data, trial)
        _, vals = lax.scan(step, None, chol)
        return vals

    exp2_p = exp2_over_chol(eps)
    exp2_0 = exp2_over_chol(0.0)
    exp2_m = exp2_over_chol(-1.0 * eps)
    d2_exp2 = (exp2_p - 2.0 * exp2_0 + exp2_m) / eps / eps

    return value, (d_exp1 + jnp.sum(d2_exp2) / 2.0)


@partial(jit, static_argnums=0)
def _calc_energy_pt2(trial, walker_up, walker_dn, ham_data, wave_data, eps=1e-4):
    '''Quantities for the T2-perturbative energy estimate of one UHF walker.

    Returns the real parts of
      t1 = <exp(T1)HF|walker> / <trial|walker>
      t2 = <exp(T1)HF|T2|walker> / <trial|walker>
      e0 = <exp(T1)HF|(h1+h2)|walker> / <trial|walker>
      e1 = <exp(T1)HF|T2 (h1+h2)|walker> / <trial|walker>
    (the T1 term is not included in the T2 projections). `eps` is the
    finite-difference step for the two-body second derivatives.
    '''
    t1, e0 = _pt2_numerator_terms(_tls_exp1, _tls_exp2, trial, walker_up,
                                  walker_dn, ham_data, wave_data, eps)
    t2, e1 = _pt2_numerator_terms(_ut2_exp1, _ut2_exp2, trial, walker_up,
                                  walker_dn, ham_data, wave_data, eps)
    return jnp.real(t1), jnp.real(t2), jnp.real(e0), jnp.real(e1)

In [20]:
# Trial overlap of the first walker. It is taken straight from prop_data so the
# cell does not depend on walkers defined in a later cell. With
# walker_type 'rhf' the same orbitals are passed for both spins (presumably
# correct for rhf walkers; TODO confirm).
first_walker = prop_data['walkers'][0]
o0 = trial._calc_overlap(first_walker, first_walker, wave_data)
print(o0)

(-0.9999999999108924+0j)


In [10]:
# Total reference energies from pyscf with the current amplitudes:
# CCS (doubles zeroed) and CCSD.
eris = mycc.ao2mo(mycc.mo_coeff)
zero_t2 = (0*mycc.t2[0], 0*mycc.t2[1], 0*mycc.t2[2])
eccs = mf.e_tot + mycc.energy(mycc.t1, zero_t2, eris)
print('ccs energy: ', eccs)
eccsd = mf.e_tot + mycc.energy(mycc.t1, mycc.t2, eris)
print('ccsd energy: ', eccsd)

ccs energy:  -2.088879997348302
ccsd energy:  -2.165477505313499


In [18]:
# Both spins use the first walker (walker_type 'rhf'; TODO confirm for uhf walkers).
walker_up = walker_dn = prop_data['walkers'][0]
t1, t2, e0, e1 = _calc_energy_pt2(trial, walker_up, walker_dn, ham_data, wave_data, 1e-4)
print(t1, t2, e0, e1)
# Leading term: should reproduce the CCS energy.
e_ccs_est = h0 + 1/t1 * e0
print(e_ccs_est)
print(e_ccs_est - eccs)
# First-order correction from (e0 + e1)/(t1 + t2): should approach CCSD.
e_pt2_est = e_ccs_est + 1/t1 * e1 - 1/t1**2 * t2 * e0
print(e_pt2_est)
print(e_pt2_est - eccsd)

-0.9942404644877956 2.7592508685257e-12 4.23103655115489 0.0761564459985375
-2.088879855808256
1.4154004590594127e-07
-2.16547746848892
3.6824578941718755e-08


In [13]:
from ad_afqmc import wavefunctions

# Same diagnostics with the library trial implementation.
# NOTE(review): in the recorded run, t1 and e0 come out with the opposite sign
# to _calc_energy_pt2 while e1 does not. The PT2 estimate is then off by ~0.153
# Ha (about 2*e1/t1), which points to a sign-convention mismatch in the
# library's T2 (e1/t2) path. TODO investigate.
trial_pt2 = wavefunctions.uccsd_pt2_ad(trial.norb, trial.nelec)
t1, t2, e0, e1 = trial_pt2._calc_energy_pt(walker_up, walker_dn, ham_data, wave_data)
print(t1, t2, e0, e1)
e_ccs_lib = h0 + 1/t1 * e0
print(e_ccs_lib)
print(e_ccs_lib - eccs)
e_pt2_lib = e_ccs_lib + 1/t1 * e1 - 1/t1**2 * t2 * e0
print(e_pt2_lib)
print(e_pt2_lib - eccsd)

0.9942404645763899 2.7592508685257e-12 -4.231036551531908 0.07615644599853749
-2.088879855808256
1.4154004590594127e-07
-2.012282243134418
0.1531952621790813
