In [25]:
from functools import partial
from pyscf import gto, scf, cc, ao2mo, df
import numpy as np

from jax import config
config.update("jax_enable_x64", True)

print = partial(print, flush=True)

a = 1.05835 # 2aB
d = 2
atoms = f'''
H     0.000     0.000    -0.370
H     0.000     0.000     0.370
H     {d}       0.000     0.000
H     {0.74+d}  0.000     0.000
'''
frozen=0

mol = gto.M(atom=atoms, basis="ccpvdz", verbose=4)
mol.set_range_coulomb(0)
mol.build()

mf = scf.RHF(mol).density_fit()
e = mf.kernel()

mycc = cc.CCSD(mf,frozen=frozen)
eccsd = mycc.kernel()[0]


System: uname_result(system='Linux', node='yichi-thinkpad', release='4.4.0-26100-Microsoft', version='#1882-Microsoft Fri Jan 01 08:00:00 PST 2016', machine='x86_64')  Threads 12
Python 3.10.16 | packaged by conda-forge | (main, Dec  5 2024, 14:16:10) [GCC 13.3.0]
numpy 1.24.3  scipy 1.14.1  h5py 3.12.1
Date: Mon Aug 25 18:50:24 2025
PySCF version 2.8.0
PySCF path  /home/yichi/research/software/lno_pyscf
GIT HEAD (branch master) ef75f4190e4de208685670651dc6c467f72b6794

[ENV] PYSCF_EXT_PATH /home/yichi/research/software/pyscf
[CONFIG] conf_file None
[INPUT] verbose = 4
[INPUT] num. atoms = 4
[INPUT] num. electrons = 4
[INPUT] charge = 0
[INPUT] spin (= nelec alpha-beta = 2S) = 0
[INPUT] symmetry False subgroup None
[INPUT] Mole.unit = angstrom
[INPUT] Symbol           X                Y                Z      unit          X                Y                Z       unit  Magmom
[INPUT]  1 H      0.000000000000   0.000000000000  -0.370000000000 AA    0.000000000000   0.000000000000  -0.69

******** <class 'pyscf.df.df.DF'> ********
auxbasis = None
max_memory = 4000
Default auxbasis cc-pvdz-jkfit is used for H ccpvdz
init E= -1.64338182871186
  HOMO = -0.499546199307415  LUMO = 0.151279444472866
cycle= 1 E= -2.25072827302064  delta_E= -0.607  |g|= 0.0704  |ddm|=    1
  HOMO = -0.568559035775527  LUMO = 0.186543530604378
cycle= 2 E= -2.25308251303881  delta_E= -0.00235  |g|= 0.0123  |ddm|= 0.0678
  HOMO = -0.561939889043839  LUMO = 0.191781821737803
cycle= 3 E= -2.25316769129707  delta_E= -8.52e-05  |g|= 0.000867  |ddm|= 0.0176
  HOMO = -0.561979249804308  LUMO = 0.191813203066073
cycle= 4 E= -2.25316809947644  delta_E= -4.08e-07  |g|= 4.59e-05  |ddm|= 0.00208
  HOMO = -0.56199275814008  LUMO = 0.19182317909231
cycle= 5 E= -2.25316810080334  delta_E= -1.33e-09  |g|= 3.89e-06  |ddm|= 0.000151
  HOMO = -0.561992915441165  LUMO = 0.191821892643455
cycle= 6 E= -2.25316810081162  delta_E= -8.28e-12  |g|= 3.19e-07  |ddm|= 1.22e-05
  HOMO = -0.56199300539708  LUMO = 0.19182204635

In [3]:
from ad_afqmc.lno.base import lno
from pyscf.lib import logger
import numpy as np
from functools import reduce

_fdot = np.dot
fdot = lambda *args: reduce(_fdot, args)

In [26]:
from ad_afqmc.lno.cc import LNOCCSD
from pyscf.lib import logger
import sys
from ad_afqmc.lno.base import lno
from ad_afqmc.lno_ccsd import lno_ccsd
import numpy as np
from functools import reduce
log = logger.Logger(sys.stdout, 6)

options = {
    "n_eql": 4,
    "n_ene_blocks": 1,
    "n_sr_blocks": 20,
    "n_blocks": 20,
    "n_walkers": 5,
    "seed": 98,
    "trial": "cisd",
    "walker_type": "rhf",
    "dt":0.005,
    "ene0": 0,
}

localize = False
mf_cc = mycc
thresh = 1e-5
# chol_cut = 1e-6
run_frg_list = None
lo_type = 'boys'
no_type = 'ie' # cim
frag_lolist = '1o'

from pyscf.cc.ccsd import CCSD
from pyscf.cc.uccsd import UCCSD
if isinstance(mf_cc, (CCSD, UCCSD)):
    full_cisd = True
else:
    full_cisd = False

mf = mf_cc._scf

if isinstance(thresh, list):
    thresh_occ, thresh_vir = thresh
else:
    thresh_occ = thresh*10
    thresh_vir = thresh

lno_cc = LNOCCSD(mf, thresh=thresh, frozen=frozen)
lno_cc.thresh_occ = thresh_occ
lno_cc.thresh_vir = thresh_vir
lno_cc.lo_type = lo_type
lno_cc.no_type = no_type
lno_cc.frag_lolist = frag_lolist
# lno_cc.ccsd_t = True
lno_cc.force_outcore_ao2mo = True

frag_atmlist = lno_cc.frag_atmlist
s1e = lno_cc._scf.get_ovlp()


# log.info('no_type = %s', no_type)

# LO construction
# orbloc = lno_cc.get_lo(lo_type=lo_type) # localized active occ orbitals
orbactocc = lno_cc.split_mo()[1] # non-localized active occ
if localize:
    orbloc = lno_cc.get_lo(lo_type=lo_type) # localized active occ orbitals
    m = fdot(orbloc.T, s1e, orbactocc)
    lospanerr = abs(fdot(m.T, m) - np.eye(m.shape[1])).max()
    if lospanerr > 1e-10:
        log.error('LOs do not fully span the occupied space! '
                    'Max|<occ|LO><LO|occ>| = %e', lospanerr)
        raise RuntimeError
    # check Span(LO) == Span(occ)
    occspanerr = abs(fdot(m, m.T) - np.eye(m.shape[0])).max()
    if occspanerr < 1e-10:
        log.info('LOs span exactly the occupied space.')
        if no_type not in ['ir','ie']:
            log.error('"no_type" must be "ir" or "ie".')
            raise ValueError
    else:
        log.info('LOs span occupied space plus some virtual space.')
else:
    orbloc = orbactocc
# LO assignment to fragments

if frag_lolist == '1o':
    log.info('Using single-LO fragment') # this is what we use, every active local occ labels a fragment
    frag_lolist = [[i] for i in range(orbloc.shape[1])]
else: print('Only support single LO fragment!')
nfrag = len(frag_lolist)
frag_nonvlist = lno_cc.frag_nonvlist

# dump info
log.info('nfrag = %d  nlo = %d', nfrag, orbloc.shape[1])
log.info('frag_atmlist = %s', frag_atmlist)
log.info('frag_lolist = %s', frag_lolist)
log.info('frag_nonvlist = %s', frag_nonvlist)

if not (no_type[0] in 'rei' and no_type[1] in 'rei'):
    log.warn('Input no_type "%s" is invalid.', no_type)
    raise ValueError

if frag_nonvlist is None: frag_nonvlist = [[None,None]] * nfrag

eris = lno_cc.ao2mo()
frozen_mask = lno_cc.get_frozen_mask()
thresh_pno = [thresh_occ,thresh_vir]
print(f'lno thresh {thresh_pno}')

if run_frg_list is None:
    run_frg_list = range(nfrag)

from jax import random
seeds = random.randint(random.PRNGKey(options["seed"]),
                    shape=(len(run_frg_list),), minval=0, maxval=100000*nfrag)

# for ifrag in run_frg_list:
for ifrag in [0]:
    print(f'########### running fragment {ifrag+1} ##########')
    fraglo = frag_lolist[ifrag]
    orbfragloc = orbloc[:,fraglo] # the specific local active occ
    frag_target_nocc, frag_target_nvir = frag_nonvlist[ifrag]
    THRESH_INTERNAL = 1e-10
    frzfrag, orbfrag, can_orbfrag = lno.make_fpno1(lno_cc, eris, orbfragloc, no_type,
                                                THRESH_INTERNAL, thresh_pno,
                                                frozen_mask=frozen_mask,
                                                frag_target_nocc=frag_target_nocc,
                                                frag_target_nvir=frag_target_nvir,
                                                canonicalize=False)
    
    mol = mf.mol
    nocc = mol.nelectron // 2 
    nao = mol.nao
    actfrag = np.array([i for i in range(nao) if i not in frzfrag])
    frzocc = np.array([i for i in range(nocc) if i in frzfrag])
    actocc = np.array([i for i in range(nocc) if i in actfrag])
    actvir = np.array([i for i in range(nocc,nao) if i in actfrag])
    print(f'# active orbitals: {len(actfrag)} {actfrag}')
    print(f'# active occupied orbitals: {len(actocc)} {actocc}')
    print(f'# active virtual orbitals: {len(actvir)} {actvir}')
    print(f'# frozen orbitals: {len(frzfrag)} {frzfrag}')
    s1e = mf.get_ovlp() if eris is None else eris.s1e
    prjlo = orbfragloc.T@s1e@orbfrag[:,actocc]
        
    # if full_cisd:
    #     prj_mo2no = lno_ccsd.no2mo(mf.mo_coeff,s1e,orbfrag).T
    #     # prj_oo = prj_mo2no[:nocc,:nocc]
    #     # prj_vv = prj_mo2no[nocc:,nocc:]
    #     prj_oo_act = prj_mo2no[actocc,:nocc]
    #     prj_vv_act = prj_mo2no[actvir,nocc:]
    #     full_t1 = mycc.t1
    #     full_t2 = mycc.t2
    #     # t2 in ijab orber
    #     # ci2 in iajb order
    #     t1 = np.einsum("ji,jb,ab->ia",prj_oo_act.T,full_t1,prj_vv_act)
    #     t2 = np.einsum("ki,lj,klcd,ac,bd->ijab",
    #                       prj_oo_act.T,prj_oo_act.T,full_t2,prj_vv_act,prj_vv_act)

    # else:
    #     ecorr_ccsd,t1,t2 \
    #         = lno_ccsd.cc_impurity_solve(
    #             mf,orbfrag,orbfragloc,frozen=frzfrag,eris=eris,log=log)
    #     print(f'# lno-ccsd correlation energy: {ecorr_ccsd}')
    
    # nelec_act = len(actocc)*2
    # norb_act = len(actfrag)
    # print(f'# number of active electrons: {nelec_act}')
    # print(f'# number of active orbitals: {norb_act}')
    # print(f'# number of frozen orbitals: {len(frzfrag)}')

    # ci1 = np.array(t1)        
    # ci2 = t2 + np.einsum("ia,jb->ijab", ci1, ci1)
    # ci2 = ci2.transpose(0, 2, 1, 3)

    # options["seed"] = seeds[ifrag]
    # lno_ccsd.prep_lno_amp_chol_file(
    #     mf,orbfrag,options,norb_act,nelec_act,prjlo=prjlo,
    #     norb_frozen=frzfrag,ci1=ci1,ci2=ci2,use_df_vecs=False)  
    

Using single-LO fragment
nfrag = 2  nlo = 2
frag_atmlist = None
frag_lolist = [[0], [1]]
frag_nonvlist = None
Lov is saved to /tmp/pwifwwhe
lno thresh [0.0001, 1e-05]


########### running fragment 1 ##########
# active orbitals: 18 [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17]
# active occupied orbitals: 2 [0 1]
# active virtual orbitals: 16 [ 2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17]
# frozen orbitals: 2 [18 19]


In [28]:
frz_mo = np.where(np.array(frozen_mask) == False)[0]
print(frz_mo)

[]


In [29]:
full_t1 = mycc.t1
full_t2 = mycc.t2
print(full_t1.shape)
print(full_t2.shape)

(2, 18)
(2, 2, 18, 18)


In [30]:
act_mo_occ = [i for i in range(nocc) if i not in frz_mo]
act_mo_vir = [i for i in range(nocc,nao) if i not in frz_mo]
print(act_mo_occ)
print(act_mo_vir)

[0, 1]
[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]


In [43]:
print(frz_mo)
act_mo_occ = np.array([i for i in range(nocc) if i not in frz_mo])
act_mo_vir = np.array([i for i in range(nocc,nao) if i not in frz_mo])
prj_no2mo = lno_ccsd.no2mo(mf.mo_coeff,s1e,orbfrag)
prj_oo_act = prj_no2mo[np.ix_(act_mo_occ,actocc)]
prj_vv_act = prj_no2mo[np.ix_(act_mo_vir,actvir)]
print(prj_oo_act.shape)
print(prj_vv_act.shape)


[]
(2, 2)
(18, 16)


In [None]:
prj_oo_act = prj_no2mo[act_mo_occ,actocc]

In [36]:
print(prj_no2mo[1,])

[-1.00000000e+00 -1.32706346e-16  8.32667256e-17 -2.32591668e-16
 -7.09120179e-19  3.93849547e-16 -8.99127326e-18 -2.77555760e-17
 -5.93491359e-16 -1.73472345e-17 -2.01856265e-17 -3.04392705e-17
  2.77555756e-17 -1.10583490e-17 -1.87350135e-16 -2.08166817e-16
  5.48172618e-16  4.61274795e-17  2.77555756e-16 -1.83607205e-16]


In [32]:
act_mo_occ = np.array([i for i in range(nocc) if i not in frz_mo])
act_mo_vir = np.array([i for i in range(nocc,nao) if i not in frz_mo])
prj_no2mo = lno_ccsd.no2mo(mf.mo_coeff,s1e,orbfrag)
print(prj_no2mo.shape)
prj_oo_act = prj_no2mo[act_mo_occ,actocc]
prj_vv_act = prj_no2mo[act_mo_vir,actvir]
print(prj_oo_act.shape)
print(prj_vv_act.shape)

(20, 20)


ValueError: Incompatible shapes for broadcasting: shapes=[(18,), (16,)]

In [None]:
prj_no2mo = lno_ccsd.no2mo(mf.mo_coeff,s1e,orbfrag)
prj_oo_act = prj_no2mo[frozen:nocc,actocc]
prj_vv_act = prj_no2mo[nocc:,actvir]
full_t1 = mycc.t1
full_t2 = mycc.t2
# t2 in ijab orber
# ci2 in iajb order
t1 = np.einsum("ij,ia,ba->jb",prj_oo_act,full_t1,prj_vv_act.T)
t2 = np.einsum("ik,jl,ijab,db,ca->klcd",
                    prj_oo_act,prj_oo_act,full_t2,prj_vv_act.T,prj_vv_act.T)

In [45]:
mc_no = cc.CCSD(mf,mo_coeff=orbfrag)
mc_no.kernel(t1=t1, t2=t2)[0]


******** <class 'pyscf.cc.dfccsd.RCCSD'> ********
CC2 = 0
CCSD nocc = 5, nmo = 50
max_cycle = 50
direct = 0
conv_tol = 1e-07
conv_tol_normt = 1e-05
diis_space = 6
diis_start_cycle = 0
diis_start_energy_diff = 1e+09
max_memory 4000 MB (current use 376 MB)
Init E_corr(RCCSD) = -0.232309707392391
cycle = 1  E_corr(RCCSD) = -0.232309697434484  dE = 9.95790714e-09  norm(t1,t2) = 2.03184e-07
RCCSD converged
E(RCCSD) = -5.587691364956932  E_corr = -0.2323096974344841


-0.23230969743448412

In [38]:
print(t1.shape)
print(t2.shape)

(4, 18)
(4, 4, 18, 18)


In [73]:
full_t1 = mycc.t1
full_t2 = mycc.t2
full_t1_no = np.einsum("ji,jb,ab->ia",prj_oo.T,full_t1,prj_vv)
full_t2_no = np.einsum("ki,lj,klcd,ac,bd->ijab",prj_oo.T,prj_oo.T,full_t2,prj_vv,prj_vv)
full_ci1 = np.array(full_t1_no)
full_ci2 = full_t2_no + np.einsum("ia,jb->ijab", full_ci1, full_ci1)
full_ci2 = full_ci2.transpose(0, 2, 1, 3)
mc_no = cc.CCSD(mf,mo_coeff=orbfrag)
mc_no.kernel(t1=full_t1_no, t2=full_t2_no)[0]


******** <class 'pyscf.cc.dfccsd.RCCSD'> ********
CC2 = 0
CCSD nocc = 4, nmo = 40
max_cycle = 50
direct = 0
conv_tol = 1e-07
conv_tol_normt = 1e-05
diis_space = 6
diis_start_cycle = 0
diis_start_energy_diff = 1e+09
max_memory 4000 MB (current use 819 MB)
Init E_corr(RCCSD) = -0.173902165045132
cycle = 1  E_corr(RCCSD) = -0.173902188225659  dE = -2.31805262e-08  norm(t1,t2) = 4.5547e-07
RCCSD converged
E(RCCSD) = -4.505126391785148  E_corr = -0.1739021882256587


-0.17390218822565867

In [74]:
from ad_afqmc import config, wavefunctions
from jax import numpy as jnp
from jax import random, jit, vmap, jvp, lax
from ad_afqmc.lno_ccsd import lno_ccsd
import jax

config.setup_jax()
MPI = config.setup_comm()

ham_data, ham, prop, trial, wave_data, sampler, observable, options, _ = (
    lno_ccsd.prep_lnoccsd_afqmc())

comm = MPI.COMM_WORLD
rank = comm.Get_rank()  # Process rank
size = comm.Get_size() 

seed = options["seed"]
propagator = prop
init_walkers = None
ham_data = wavefunctions.rhf(trial.norb, trial.nelec,n_batch=trial.n_batch
                                )._build_measurement_intermediates(ham_data, wave_data)
ham_data = ham.build_measurement_intermediates(ham_data, trial, wave_data)
ham_data = ham.build_propagation_intermediates(ham_data, propagator, trial, wave_data)

prop_data = propagator.init_prop_data(trial, wave_data, ham_data, init_walkers)
if jnp.abs(jnp.sum(prop_data["overlaps"])) < 1.0e-6:
    raise ValueError(
        "Initial overlaps are zero. Pass walkers with non-zero overlap."
    )
prop_data["key"] = random.PRNGKey(seed + rank)

prop_data["overlaps"] = trial.calc_overlap(prop_data["walkers"], wave_data)
prop_data["n_killed_walkers"] = 0
prop_data["pop_control_ene_shift"] = prop_data["e_estimate"]

init_walker = prop_data["walkers"][0]

# Hostname: yichi-thinkpad
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64
# Hostname: yichi-thinkpad
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64
# Number of MPI ranks: 1
#
# norb: 12
# nelec: (2, 2)
# nchol: 184
#
# n_eql: 4
# n_ene_blocks: 1
# n_sr_blocks: 20
# n_blocks: 20
# n_walkers: 5
# seed: 296021
# trial: cisd
# walker_type: rhf
# dt: 0.005
# ene0: 0
# n_exp_terms: 6
# n_prop_steps: 50
# orbital_rotation: True
# do_sr: True
# symmetry: False
# save_walkers: False
# free_projection: False
# n_batch: 1
# use_gpu: False
#


In [75]:
nocc = mol.nelectron // 2 
nao = mol.nao
print(nao)
actfrag = np.array(
    [i for i in range(nao) if i not in frzfrag])
print(frzfrag)
print(actfrag)
frzocc = np.array(
    [i for i in range(nocc) if i in frzfrag])
actocc = np.array(
    [i for i in range(nocc) if i in actfrag])
print(frzocc)
print(actocc)

40
[ 0  1 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
 36 37 38 39]
[ 2  3  4  5  6  7  8  9 10 11 12 13]
[0 1]
[2 3]


In [43]:
def no2mo(mo_coeff,s1e,no_coeff):
    prj = mo_coeff.T@s1e@no_coeff
    return prj

def prj_walker(projector,walker,frzocc,actfrag):
    p_act = projector[:,actfrag]
    walker_act = p_act@walker
    if len(frzocc) != 0:
        p_frzocc = projector[:,frzocc]
        walker_new = jnp.hstack((p_frzocc,walker_act))
    else:
        walker_new = walker_act

    return walker_new


In [18]:
def cisd_walker_overlap(walker,ci1,ci2):
    nocc = walker.shape[1]
    GF = (walker.dot(jnp.linalg.inv(walker[: nocc, :]))).T
    o0 = jnp.linalg.det(walker[: nocc, :]) ** 2
    o1 = jnp.einsum("ia,ia", ci1, GF[:, nocc:])
    o2 = 2 * jnp.einsum(
        "iajb, ia, jb", ci2, GF[:, nocc:], GF[:, nocc:]
    ) - jnp.einsum("iajb, ib, ja", ci2, GF[:, nocc:], GF[:, nocc:])
    return (1.0 + 2 * o1 + o2) * o0

def cisd_walker_overlap_ratio(walker,ci1,ci2):
    nocc = walker.shape[1]
    GF = (walker.dot(jnp.linalg.inv(walker[: nocc, :]))).T
    # o0 = jnp.linalg.det(walker[: nocc, :]) ** 2
    o1 = jnp.einsum("ia,ia", ci1, GF[:, nocc:])
    o2 = 2 * jnp.einsum(
        "iajb, ia, jb", ci2, GF[:, nocc:], GF[:, nocc:]
    ) - jnp.einsum("iajb, ib, ja", ci2, GF[:, nocc:], GF[:, nocc:])
    return 1/(1.0 + 2 * o1 + o2)

def full_cisd_walker_overlap(walker,ci1,ci2):
    
    nocc = walker.shape[1]
    GF = (walker.dot(jnp.linalg.inv(walker[: nocc, :]))).T
    o0 = jnp.linalg.det(walker[: nocc, :]) ** 2
    o1 = jnp.einsum("ia,ia", ci1, GF[:, nocc:])
    o2 = 2 * jnp.einsum(
        "iajb, ia, jb", ci2, GF[:, nocc:], GF[:, nocc:]
    ) - jnp.einsum("iajb, ib, ja", ci2, GF[:, nocc:], GF[:, nocc:])
    return (1.0 + 2 * o1 + o2) * o0

In [None]:
def _frg_mod_ccsd_olp2(walker: jax.Array, wave_data: dict) -> complex:
    '''
    <psi_ccsd|walker>=<psi_0|walker>+C_ia^*G_ia+C_iajb^*(G_iaG_jb-G_ibG_ja)
    modified CCSD overlap returns the second and the third term
    that is, the overlap of the walker with the CCSD wavefunction
    without the hartree-fock part
    and skip one sum over the occ
    '''
    # prjlo = wave_data["prjlo"].reshape(walker.shape[1])
    m = jnp.dot(wave_data["prjlo"].T,wave_data["prjlo"])
    nocc, ci1, ci2 = walker.shape[1], wave_data["full_ci1"], wave_data["full_ci2"]
    gf = (walker.dot(jnp.linalg.inv(walker[: walker.shape[1], :]))).T
    o0 = jnp.linalg.det(walker[: nocc, :]) ** 2
    
    o1 = jnp.einsum("ia,ka,ik->", ci1, gf[:, nocc:],m)
    o2 = 2 * jnp.einsum("iajb,ka,jb,ik->", ci2, gf[:, nocc:], gf[:, nocc:],m) \
        - jnp.einsum("iajb,kb,ja,ik->", ci2, gf[:, nocc:], gf[:, nocc:],m)
    olp = (2*o1+o2)*o0
    # olp_i = jnp.einsum("i,i->", olp, pick_i)
    return olp

In [76]:
prop_data, (blk_e, blk_wt)\
    = sampler._block_scan(prop_data,None,ham_data,prop,trial,wave_data)
print(blk_e,blk_wt)

-4.39617965082072 5.006703695239711


In [77]:
walker = prop_data["walkers"][0]
print(walker)
print(walker.shape)

[[-2.73884231e-02+0.00372321j  9.91077098e-01-0.0977019j ]
 [-9.58816243e-01-0.10180332j -1.67284583e-02-0.00692188j]
 [ 5.20185869e-03-0.04353126j -4.70182700e-03-0.05134089j]
 [ 2.84551884e-03+0.05418152j -4.04917627e-04-0.02912197j]
 [ 1.97660352e-02+0.04373476j -1.77426500e-03-0.02989733j]
 [ 1.72833437e-02-0.10546156j -1.47756918e-03-0.01604573j]
 [ 2.12459380e-02-0.13308867j -2.03663740e-04-0.02754648j]
 [-1.22835740e-02+0.10767316j  6.81215573e-03+0.02228497j]
 [ 1.52360158e-02-0.02056734j -4.09440182e-03-0.03186463j]
 [-3.91432851e-02+0.07184061j -7.09684967e-03+0.02024532j]
 [ 1.28483929e-02-0.11148736j -2.28841335e-03-0.01469083j]
 [ 2.19263097e-02+0.02492758j -1.28011057e-02+0.00712175j]]
(12, 2)


In [78]:
prj_w = np.eye(nao)
walker_t = prj_walker(prj_w, walker, frzocc, actfrag)
print(walker_t)
print(walker_t.shape)

[[ 1.00000000e+00+0.j          0.00000000e+00+0.j
   0.00000000e+00+0.j          0.00000000e+00+0.j        ]
 [ 0.00000000e+00+0.j          1.00000000e+00+0.j
   0.00000000e+00+0.j          0.00000000e+00+0.j        ]
 [ 0.00000000e+00+0.j          0.00000000e+00+0.j
  -2.73884231e-02+0.00372321j  9.91077098e-01-0.0977019j ]
 [ 0.00000000e+00+0.j          0.00000000e+00+0.j
  -9.58816243e-01-0.10180332j -1.67284583e-02-0.00692188j]
 [ 0.00000000e+00+0.j          0.00000000e+00+0.j
   5.20185869e-03-0.04353126j -4.70182700e-03-0.05134089j]
 [ 0.00000000e+00+0.j          0.00000000e+00+0.j
   2.84551884e-03+0.05418152j -4.04917627e-04-0.02912197j]
 [ 0.00000000e+00+0.j          0.00000000e+00+0.j
   1.97660352e-02+0.04373476j -1.77426500e-03-0.02989733j]
 [ 0.00000000e+00+0.j          0.00000000e+00+0.j
   1.72833437e-02-0.10546156j -1.47756918e-03-0.01604573j]
 [ 0.00000000e+00+0.j          0.00000000e+00+0.j
   2.12459380e-02-0.13308867j -2.03663740e-04-0.02754648j]
 [ 0.00000000e+00+0

In [79]:
full_olp = cisd_walker_overlap(walker_t, full_ci1, full_ci2)
olp = trial._calc_overlap_restricted(walker, wave_data)
print(full_olp)
print(olp)

(0.9250873359675061+0.009700266067914782j)
(0.9250873359675064+0.009700266067914817j)


In [None]:
## make ccsd lno ###

In [80]:
lno_cc = LNOCCSD(mf, thresh=thresh, frozen=frozen)
lno_cc.thresh_occ = thresh_occ
lno_cc.thresh_vir = thresh_vir
lno_cc.lo_type = lo_type
lno_cc.no_type = no_type
lno_cc.frag_lolist = frag_lolist
lno_cc.ccsd_t = True
lno_cc.force_outcore_ao2mo = True
orbocc0, orbocc1, orbvir1, orbvir0 = lno_cc.split_mo()

In [83]:
print(orbocc1.shape)

(40, 4)


In [86]:
no_type[1]

'e'

In [None]:
def make_fpno1(mfcc, eris, orbfragloc, no_type, thresh_internal, thresh_external,
               frag_target_nocc=None, frag_target_nvir=None,canonicalize=True):
    log = logger.Logger(mfcc.stdout, mfcc.verbose)
    mf = mfcc._scf
    nocc = np.count_nonzero(mf.mo_occ>1e-10)
    nmo = mf.mo_occ.size
    orbocc0, orbocc1, orbvir1, orbvir0 = mfcc.split_mo() # frz_occ, act_occ, act_vir, frz_vir
    moeocc0, moeocc1, moevir1, moevir0 = mfcc.split_moe() # split energy

    s1e = eris.s1e # if eris.s1e is None else mf.get_ovlp()
    fock = eris.fock # if eris.fock is None else mf.get_fock()
    Lov = eris.Lov
    # chosen loc_orb overlap with act_vir
    lovir = abs(fdot(orbfragloc.T, s1e, orbvir1)).max() > 1e-10

    if isinstance(thresh_external, float):
        thresh_ext_occ = thresh_ext_vir = thresh_external
    else:
        thresh_ext_occ, thresh_ext_vir  = thresh_external

    # sanity check for no_type:
    if not lovir and no_type[0] != 'i':
        log.warn('Input LOs span only occ but input no_type[0] is not "i".')
        raise ValueError
    if not lovir and no_type[1] == 'i':
        log.warn('Input LOs span only occ but input no_type[1] is "i".')
        raise ValueError

    # split active occ/vir into internal(1) and external(2)
    m = fdot(orbfragloc.T, s1e, orbocc1) # overlap with all loc act_occs
    uocc1, uocc2 = lno.projection_construction(m, thresh_internal)
    moefragocc1, orbfragocc1 = lno.subspace_eigh(fock, fdot(orbocc1, uocc1))
    if lovir:
        m = fdot(orbfragloc.T, s1e, orbvir1)
        uvir1, uvir2 = lno.projection_construction(m, thresh_internal)
        moefragvir1, orbfragvir1 = lno.subspace_eigh(fock, fdot(orbvir1, uvir1))

    def moe_Ov(moefragocc):
        return (moefragocc[:,None] - moevir1).reshape(-1)
    def moe_oV(moefragvir):
        return (moeocc1[:,None] - moefragvir).reshape(-1)
    eov = moe_Ov(moeocc1)
    # Construct PT2 dm_vv
    if no_type[1] == 'r':   # OvOv: IaJc,IbJc->ab
        u = fdot(orbocc1.T, s1e, orbfragocc1)
        ovov = eris.get_OvOv(u)
        eia = ejb = moe_Ov(moefragocc1)
        e1_or_e2 = 'e1'
        swapidx = 'ab'
    elif no_type[1] == 'e': # Ovov: Iajc,Ibjc->ab
        u = fdot(orbocc1.T, s1e, orbfragocc1)
        ovov = eris.get_Ovov(u)
        eia = moe_Ov(moefragocc1)
        Ljb = Lov
        ejb = eov
        e1_or_e2 = 'e1'
        swapidx = 'ab'
    else:                   # oVov: iCja,iCjb->ab
        u = fdot(orbvir1.T, s1e, orbfragvir1)
        ovov = eris.get_oVov(u)
        eia = moe_oV(moefragvir1)
        Ljb = Lov
        ejb = eov
        e1_or_e2 = 'e2'
        swapidx = 'ij'

    eiajb = (eia[:,None]+ejb).reshape(*ovov.shape)
    t2 = ovov / eiajb

    dmvv = lno.make_rdm1_mp2(t2, 'vv', e1_or_e2, swapidx)
   
    if lovir:
        dmvv = fdot(uvir2.T, dmvv, uvir2)

    Lia = Ljb = ovov = eiajb = None
    # Construct PT2 dm_oo
    if no_type in ['ie','ei']: # ie/ei share same t2
        if no_type[0] == 'e':   # oVov: iAkb,jAkb->ij
            e1_or_e2 = 'e1'
            swapidx = 'ij'
        else:                   # Ovov: Kaib,Kajb->ij
            e1_or_e2 = 'e2'
            swapidx = 'ab'
    else:
        t2 = None

        if no_type[0] == 'r':   # oVoV: iAkB,jAkB->ij
            u = fdot(orbvir1.T, s1e, orbfragvir1)
            ovov = eris.get_oVoV(u)
            eia = ejb = moe_oV(moefragvir1)
            e1_or_e2 = 'e1'
            swapidx = 'ab'
        elif no_type[0] == 'e': # oVov: iAkb,jAkb->ij
            u = fdot(orbvir1.T, s1e, orbfragvir1)
            ovov = eris.get_oVov(u)
            eia = moe_oV(moefragvir1)
            Ljb = Lov
            ejb = eov
            e1_or_e2 = 'e1'
            swapidx = 'ij'
        else:                   # Ovov: Kaib,Kajb->ij
            u = fdot(orbocc1.T, s1e, orbfragocc1)
            ovov = eris.get_Ovov(u)
            eia = moe_Ov(moefragocc1)
            Ljb = Lov
            ejb = eov
            e1_or_e2 = 'e2'
            swapidx = 'ab'

        eiajb = (eia[:,None]+ejb).reshape(*ovov.shape)
        t2 = ovov / eiajb

        Lia = Ljb = ovov = eiajb = None

    dmoo = make_rdm1_mp2(t2, 'oo', e1_or_e2, swapidx)
    dmoo = fdot(uocc2.T, dmoo, uocc2)

    t2 = None
    # Compress external space by PNO
    
    if frag_target_nocc is not None: frag_target_nocc -= orbfragocc1.shape[1]
    orbfragocc2, orbfragocc0 = natorb_compression(dmoo, orbocc1, thresh_ext_occ,
                                                  uocc2, frag_target_nocc)
#    if (canonicalize): orbfragocc12 = subspace_eigh(fock, np.hstack([orbfragocc2, orbfragocc1]))[1]
#    else: orbfragocc12 = np.hstack([orbfragocc2, orbfragocc1])
    can_orbfragocc12 = subspace_eigh(fock, np.hstack([orbfragocc2, orbfragocc1]))[1]
    orbfragocc12 = np.hstack([orbfragocc2, orbfragocc1])
    if lovir:
        
        if frag_target_nvir is not None: frag_target_nvir -= orbfragvir1.shape[1]
        orbfragvir2, orbfragvir0 = natorb_compression(dmvv, orbvir1, thresh_ext_vir,
                                                      uvir2, frag_target_nvir)
        #if (canonicalize): orbfragvir12 = subspace_eigh(fock, np.hstack([orbfragvir2, orbfragvir1]))[1]
        #else: orbfragvir12 = np.hstack([orbfragvir2, orbfragvir1])
        can_orbfragvir12 = subspace_eigh(fock, np.hstack([orbfragvir2, orbfragvir1]))[1]
        orbfragvir12 = np.hstack([orbfragvir2, orbfragvir1])
    else: 
        orbfragvir2, orbfragvir0 = natorb_compression(dmvv, orbvir1, thresh_ext_vir,
                                                      None, frag_target_nvir)
        #if (canonicalize): orbfragvir12 = subspace_eigh(fock, orbfragvir2)[1]
        #else: orbfragvir12 = orbfragvir2
        can_orbfragvir12 = subspace_eigh(fock, orbfragvir2)[1]
        orbfragvir12 = orbfragvir2

    orbfrag = np.hstack([orbocc0, orbfragocc0, orbfragocc12,
                         orbfragvir12, orbfragvir0, orbvir0])
    can_orbfrag = np.hstack([orbocc0, orbfragocc0, can_orbfragocc12,
                        can_orbfragvir12, orbfragvir0, orbvir0])
    
    frzfrag = np.hstack([np.arange(orbocc0.shape[1]+orbfragocc0.shape[1]),
                         np.arange(nocc+orbfragvir12.shape[1],nmo)])

    #return frzfrag, orbfrag
    #import pdb;pdb.set_trace()
    if(canonicalize==True): return frzfrag, can_orbfrag, can_orbfrag 
    else: return frzfrag, orbfrag , can_orbfrag