In [1]:
from pyscf import gto, scf
#import os
import numpy as np
from matplotlib import pyplot as plt

In [4]:
from functools import partial
from typing import List, Optional, Union

import jax
import jax.numpy as jnp
import numpy as np
from jax import random

from typing import Any, Tuple

import jax
import jax.numpy as jnp
from jax import checkpoint, jit, lax, random

from ad_afqmc import linalg_utils, sampling
from ad_afqmc.hamiltonian import hamiltonian
from ad_afqmc.propagation import propagator
from ad_afqmc.wavefunctions import wave_function

# Shared sampler used by the block-scan helpers below: its n_prop_steps sets the number of
# propagation steps per block and its _step_scan performs each individual step.
sampler_eq = sampling.sampler(n_sr_blocks=10, n_ene_blocks=5, n_prop_steps=50)

def init_prop(ham_data, ham, prop, trial, wave_data, options, MPI):
    """Build Hamiltonian intermediates and the initial propagation state for one system.

    Args:
        ham_data: Hamiltonian data dict; returned augmented with measurement and
            propagation intermediates.
        ham: object providing build_measurement_intermediates / build_propagation_intermediates.
        prop: propagator providing init_prop_data.
        trial: trial wave function providing get_rdm1.
        wave_data: trial data dict; mutated in place to hold "rdm1" if it is missing.
        options: run options; only options["seed"] is read here.
        MPI: the mpi4py MPI module (or an object exposing COMM_WORLD.Get_rank()).

    Returns:
        (prop_data, ham_data). prop_data carries a PRNG key seeded with seed + rank and
        n_killed_walkers set to 0.
    """
    rank = MPI.COMM_WORLD.Get_rank()
    seed = options["seed"]
    # No stored walkers are supplied; the propagator creates its own initial walkers.
    walkers_in: Optional[Union[List, jax.Array]] = None

    rdm1_from_trial = trial.get_rdm1(wave_data)
    wave_data.setdefault("rdm1", rdm1_from_trial)

    ham_data = ham.build_measurement_intermediates(ham_data, trial, wave_data)
    ham_data = ham.build_propagation_intermediates(ham_data, prop, trial, wave_data)

    prop_data = prop.init_prop_data(trial, wave_data, ham_data, walkers_in)
    # Offset the seed by rank so every MPI process draws a different random stream.
    prop_data.update(key=random.PRNGKey(seed + rank), n_killed_walkers=0)
    print(f"# initial energy: {prop_data['e_estimate']:.9e}")

    return prop_data, ham_data

def en_samples(prop_data,ham_data,prop,trial,wave_data):
    '''Return the real local energy of every walker, with outliers clamped.

    Energies farther than sqrt(2 / prop.dt) from prop_data["e_estimate"] are replaced
    by e_estimate, which keeps a few extreme walkers from dominating block averages.

    Args:
        prop_data: propagation state; reads "walkers" and "e_estimate".
        ham_data: Hamiltonian data passed through to trial.calc_energy.
        prop: propagator; only prop.dt is read.
        trial: trial wave function providing calc_energy.
        wave_data: trial data passed through to trial.calc_energy.

    Returns:
        Real array of clamped energies, as returned per walker by trial.calc_energy.
    '''
    energy_samples = jnp.real(
        trial.calc_energy(prop_data["walkers"], ham_data, wave_data)
    )
    energy_samples = jnp.where(
        jnp.abs(energy_samples - prop_data["e_estimate"]) > jnp.sqrt(2.0 / prop.dt),
        prop_data["e_estimate"],
        energy_samples,
    )
    return energy_samples

def block_en_weight(prop_data,ham_data,prop,trial,wave_data):
    '''Return the weighted mean energy and the total weight of the walker population.

    Args:
        prop_data: propagation state; reads "walkers", "weights" and "e_estimate".
        ham_data, prop, trial, wave_data: forwarded to en_samples.

    Returns:
        (block_energy, block_weight), where block_energy = sum(w_i * E_i) / sum(w_i)
        and block_weight = sum(w_i).
    '''
    # trial / wave_data are passed by keyword so the two cannot be silently swapped.
    energy_samples = en_samples(
        prop_data, ham_data, prop, trial=trial, wave_data=wave_data
    )

    block_weight = jnp.sum(prop_data["weights"])
    block_energy = jnp.sum(energy_samples * prop_data["weights"]) / block_weight
    return block_energy, block_weight

def field_block_scan(
        prop_data: dict,
        fields,
        ham_data: dict,
        prop: "propagator",
        trial: "wave_function",
        wave_data: dict,
        ) -> dict:
    """Propagate the walkers through one block driven by pre-drawn auxiliary fields.

    Args:
        prop_data: propagation state dict (walkers, weights, overlaps, key,
            pop_control_ene_shift, n_killed_walkers, ...).
        fields: auxiliary fields; lax.scan consumes one slice of the leading axis per
            step (cs_block_scan builds them as (n_prop_steps, n_walkers, n_chol)).
        ham_data: Hamiltonian data for this system.
        prop: propagator for this system.
        trial: trial wave function for this system.
        wave_data: trial data for this system.

    Returns:
        The updated prop_data dict (walkers orthonormalized, overlaps recomputed,
        killed-walker count and population-control shift updated).
    """
    def _one_step(carry, step_fields):
        return sampler_eq._step_scan(
            carry, step_fields, ham_data, prop, trial, wave_data
        )

    prop_data, _ = lax.scan(_one_step, prop_data, fields)
    # Walkers whose weight is exactly zero after the block are counted as killed.
    prop_data["n_killed_walkers"] += prop_data["weights"].size - jnp.count_nonzero(
        prop_data["weights"]
    )
    prop_data = prop.orthonormalize_walkers(prop_data)
    prop_data["overlaps"] = trial.calc_overlap(prop_data["walkers"], wave_data)

    block_energy, _ = block_en_weight(
        prop_data, ham_data, prop, trial=trial, wave_data=wave_data
    )
    # Exponential moving average of block energies (weight 0.1 on the newest block).
    prop_data["pop_control_ene_shift"] = (
        0.9 * prop_data["pop_control_ene_shift"] + 0.1 * block_energy
    )
    return prop_data

def cs_block_scan(
        prop_data1: dict,
        ham_data1: dict,
        prop1: propagator,
        trial1: wave_function,
        wave_data1: dict,
        prop_data2: dict,
        ham_data2: dict,
        prop2: propagator,
        trial2: wave_function,
        wave_data2: dict):
    '''Correlated sampling: propagate two systems through one block with identical fields.

    One field array is drawn from system 1's key (prop_data1["key"] is advanced;
    prop_data2["key"] is not touched) and fed to field_block_scan for both systems.

    Returns:
        (prop_data1, prop_data2, fields). fields has shape
        (sampler_eq.n_prop_steps, n_walkers, n_chol).

    Raises:
        ValueError: if the two systems differ in walker count or number of Cholesky
            vectors, since one field array cannot then drive both propagations.
    '''
    n_walkers = prop1.n_walkers
    n_chol = ham_data1["chol"].shape[0]
    # Fail early with a clear message instead of a shape error deep inside lax.scan.
    if prop2.n_walkers != n_walkers or ham_data2["chol"].shape[0] != n_chol:
        raise ValueError(
            "cs_block_scan needs matching field shapes for both systems: "
            f"n_walkers {n_walkers} vs {prop2.n_walkers}, "
            f"n_chol {n_chol} vs {ham_data2['chol'].shape[0]}"
        )

    prop_data1["key"], subkey1 = random.split(prop_data1["key"])
    fields = random.normal(
        subkey1,
        shape=(
            sampler_eq.n_prop_steps,
            n_walkers,
            n_chol,
        )
    )
    prop_data1 = field_block_scan(prop_data1,fields,ham_data1,prop1,trial1,wave_data1)
    prop_data2 = field_block_scan(prop_data2,fields,ham_data2,prop2,trial2,wave_data2)

    return prop_data1, prop_data2, fields

def ucs_block_scan(
        prop_data1: dict,
        ham_data1: dict,
        prop1: propagator,
        trial1: wave_function,
        wave_data1: dict,
        prop_data2: dict,
        ham_data2: dict,
        prop2: propagator,
        trial2: wave_function,
        wave_data2: dict):
    '''Uncorrelated sampling: propagate two systems through one block with independent fields.

    Unlike cs_block_scan, each system draws its own fields from its own key
    (prop_data1["key"] and prop_data2["key"] are both advanced), so the two
    propagations do not share noise. Presumably intended as the uncorrelated
    baseline for cs_block_scan — confirm.

    Returns:
        (prop_data1, prop_data2, fields1, fields2); each fields array has shape
        (sampler_eq.n_prop_steps, prop.n_walkers, ham_data["chol"].shape[0]) for
        its own system.
    '''
    def _draw_fields(prop_data, prop, ham_data):
        # Advance this system's key and draw standard-normal fields for one block.
        prop_data["key"], subkey = random.split(prop_data["key"])
        return random.normal(
            subkey,
            shape=(
                sampler_eq.n_prop_steps,
                prop.n_walkers,
                ham_data["chol"].shape[0],
            )
        )

    fields1 = _draw_fields(prop_data1, prop1, ham_data1)
    prop_data1 = field_block_scan(prop_data1,fields1,ham_data1,prop1,trial1,wave_data1)

    fields2 = _draw_fields(prop_data2, prop2, ham_data2)
    prop_data2 = field_block_scan(prop_data2,fields2,ham_data2,prop2,trial2,wave_data2)

    return prop_data1, prop_data2, fields1, fields2

In [5]:
from pyscf import gto, scf
basis = 'sto6g'


def _h2_atom(bond_length):
    """Return a PySCF atom string for H2 on the x axis with the given bond length."""
    return f'''
H 0 0 0
H {bond_length} 0 0
'''


def _run_density_fitted_rhf(atom_spec, basis_name):
    """Build a Mole for atom_spec in basis_name and run density-fitted RHF on it.

    Returns (mol, mf), with mf.kernel() already executed.
    """
    mol = gto.Mole(verbose=3, atom=atom_spec, basis=basis_name)
    mol.build()
    mf = scf.RHF(mol).density_fit()
    mf.kernel()
    return mol, mf


# System 1: H2 with bond length d1.
d1 = 4
atom1 = _h2_atom(d1)
mol1, mf1 = _run_density_fitted_rhf(atom1, basis)

# System 2: H2 with bond length d2.
d2 = 5
atom2 = _h2_atom(d2)
mol2, mf2 = _run_density_fitted_rhf(atom2, basis)

print('the rhf energy difference is: ',mf1.e_tot-mf2.e_tot)


WARN: Even tempered Gaussians are generated as DF auxbasis for  H

converged SCF energy = -0.624498144373364

WARN: Even tempered Gaussians are generated as DF auxbasis for  H

converged SCF energy = -0.608011113059654
the rhf energy difference is:  -0.016487031313709766


In [6]:
from ad_afqmc import pyscf_interface, driver, mpi_jax

mo_file1, amp_file1, chol_file1 = "mo1.npz", "amp1.npz", "chol1"
mo_file2, amp_file2, chol_file2 = "mo2.npz", "amp2.npz", "chol2"

# Write each system's AFQMC input (MO, amplitude and Cholesky files) under its own names.
for _mf, _mo, _amp, _chol in (
    (mf1, mo_file1, amp_file1, chol_file1),
    (mf2, mo_file2, amp_file2, chol_file2),
):
    pyscf_interface.prep_afqmc(_mf, mo_file=_mo, amp_file=_amp, chol_file=_chol)

# Hostname: YICHI
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64
#
# Preparing AFQMC calculation
# Calculating Cholesky integrals
# Decomposing ERI with DF
# Finished calculating Cholesky integrals
#
# Size of the correlation space:
# Number of electrons: (1, 1)
# Number of basis functions: 2
# Number of Cholesky vectors: 18
#
#
# Preparing AFQMC calculation
# Calculating Cholesky integrals
# Decomposing ERI with DF
# Finished calculating Cholesky integrals
#
# Size of the correlation space:
# Number of electrons: (1, 1)
# Number of basis functions: 2
# Number of Cholesky vectors: 18
#


In [7]:
options1 = {
    "dt": 0.005,
    "n_eql": 4,
    "n_ene_blocks": 1,
    "n_sr_blocks": 10,
    "n_blocks": 200,
    "n_walkers": 50,
    "seed": 98,
    "walker_type": "rhf",
    "trial": "rhf",
}

# System 2 uses identical run settings; only its RNG seed differs.
# Overriding an existing key keeps its position, so the key order matches options1.
options2 = {**options1, "seed": 2}

In [None]:
from mpi4py import MPI
import numpy as np
from ad_afqmc.corr_sample_test import corr_sample

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()  # this process's rank, number of processes
print(f'# rank = {rank} size = {size}')

# Load each system from the files passed to prep_afqmc; options1/options2 are rebound
# to the options dicts that _prep_afqmc returns.
(ham_data1, ham1, prop1, trial1, wave_data1,
 sampler1, observable1, options1, _) = mpi_jax._prep_afqmc(
    options1, mo_file=mo_file1, amp_file=amp_file1, chol_file=chol_file1)
(ham_data2, ham2, prop2, trial2, wave_data2,
 sampler2, observable2, options2, _) = mpi_jax._prep_afqmc(
    options2, mo_file=mo_file2, amp_file=amp_file2, chol_file=chol_file2)

# NOTE(review): these call corr_sample.init_prop from ad_afqmc.corr_sample_test,
# not the init_prop defined earlier in this notebook.
prop_data1, ham_data1 = corr_sample.init_prop(ham_data1, ham1, prop1, trial1, wave_data1, options1, MPI)
prop_data2, ham_data2 = corr_sample.init_prop(ham_data2, ham2, prop2, trial2, wave_data2, options2, MPI)


# rank = 0 size = 1
# Number of MPI ranks: 1
#
# norb: 2
# nelec: (1, 1)
#
# dt: 0.005
# n_eql: 4
# n_ene_blocks: 1
# n_sr_blocks: 10
# n_blocks: 200
# n_walkers: 50
# seed: 98
# walker_type: rhf
# trial: rhf
# n_prop_steps: 50
# orbital_rotation: True
# do_sr: True
# symmetry: False
# save_walkers: False
# ene0: 0.0
# free_projection: False
# n_batch: 1
# LNO: False
# orbE: 0
# maxError: 0.001
#
# Number of MPI ranks: 1
#
# norb: 2
# nelec: (1, 1)
#
# dt: 0.005
# n_eql: 4
# n_ene_blocks: 1
# n_sr_blocks: 10
# n_blocks: 200
# n_walkers: 50
# seed: 2
# walker_type: rhf
# trial: rhf
# n_prop_steps: 50
# orbital_rotation: True
# do_sr: True
# symmetry: False
# save_walkers: False
# ene0: 0.0
# free_projection: False
# n_batch: 1
# LNO: False
# orbE: 0
# maxError: 0.001
#


In [9]:
# Advance both systems through one block driven by a shared set of auxiliary fields.
prop_data1, prop_data2, fields = cs_block_scan(
    prop_data1, ham_data1, prop1, trial1, wave_data1,
    prop_data2, ham_data2, prop2, trial2, wave_data2,
)
# Per-walker local energies of each system after the block.
loc_en_sample1, loc_en_sample2 = (
    en_samples(prop_data1, ham_data1, prop1, trial1, wave_data1),
    en_samples(prop_data2, ham_data2, prop2, trial2, wave_data2),
)

In [15]:
# Split system 1's key in place (as cs_block_scan does) and show the new key and the subkey.
prop_data1["key"], subkey1 = random.split(prop_data1["key"])
print(prop_data1["key"], subkey1, sep="\n")

[2260632897 2682822699]
[ 231470018 1619616203]


In [19]:
# Rebuild the key the way init_prop does (seed + rank, with 98 = options1's seed) and split it once.
key = random.PRNGKey(98 + rank)
print(key)
key,subkey = random.split(key)
print(key,subkey)

[ 0 98]
[ 336490316 3848988999] [3614062411 3294896607]


In [24]:
# Inspect system 1's current PRNG key.
print(prop_data1["key"])

[2260632897 2682822699]


In [23]:
# Ten pseudo-random integers in [0, 100) from a fixed key (maxval is exclusive).
seed1 = random.randint(key=random.PRNGKey(0), shape=(10,), minval=0, maxval=100)
print(seed1)

[ 2  4 79 76 54 92 94 82 93 76]


In [31]:
# Overwrite system 1's walker key with one seeded from seed1[5], offset by rank as in init_prop.
# NOTE(review): this mutates prop_data1 in place, so later cells depend on whether this cell ran.
prop_data1["key"] = random.PRNGKey(seed1[5] + rank)
print(prop_data1["key"])

[ 0 92]


In [22]:
# Buffer with room for every rank's per-walker energies (presumably an MPI gather target — confirm).
en_sample1 = np.empty(size * len(loc_en_sample1))
print(loc_en_sample1.shape, en_sample1.shape, sep="\n")

(50,)
(200,)


In [8]:
# Summarise prop_data1: plain ints are shown by value, everything else by shape.
for name, value in prop_data1.items():
    if type(value) is int:
        print(name, type(value), value)
    else:
        print(name, type(value), value.shape)

weights <class 'jaxlib.xla_extension.ArrayImpl'> (50,)
walkers <class 'jaxlib.xla_extension.ArrayImpl'> (50, 2, 1)
e_estimate <class 'jaxlib.xla_extension.ArrayImpl'> ()
pop_control_ene_shift <class 'jaxlib.xla_extension.ArrayImpl'> ()
overlaps <class 'jaxlib.xla_extension.ArrayImpl'> (50,)
key <class 'jaxlib.xla_extension.ArrayImpl'> (2,)
n_killed_walkers <class 'int'> 0


In [60]:
# List the fields stored in system 1's propagation state.
for name in prop_data1:
    print(name)

weights
walkers
e_estimate
pop_control_ene_shift
overlaps
key
n_killed_walkers


In [None]:
def gather_prop_data(prop_datas):
    """Merge a list of prop_data dicts into one dict, key by key.

    Merge rules:
        - "walkers": stacked along the first axis (np.vstack);
        - "n_killed_walkers": summed;
        - every other key: concatenated as 1-D (np.hstack), so scalars such as
          e_estimate become one entry per input dict.

    Args:
        prop_datas: sequence of prop_data dicts sharing the keys of the first one.

    Returns:
        The merged dict; {} for an empty sequence.
    """
    if not prop_datas:
        return {}
    tot_prop_data = {}
    for k in prop_datas[0]:
        values = [prop_data[k] for prop_data in prop_datas]
        if k == 'walkers':
            tot_prop_data[k] = np.vstack(values)
            print(tot_prop_data[k].shape)
        elif k == 'n_killed_walkers':
            tot_prop_data[k] = np.sum(values)
            print(tot_prop_data[k])
        else:
            tot_prop_data[k] = np.hstack(values)
    return tot_prop_data

(100, 2, 1)
0


In [75]:
# Show each field's shape; the killed-walker counter is a plain number, so show its value.
for name, value in prop_data1.items():
    print(name, value if name == 'n_killed_walkers' else value.shape)

weights (50,)
walkers (50, 2, 1)
e_estimate ()
pop_control_ene_shift ()
overlaps (50,)
key (2,)
n_killed_walkers 0


In [79]:
# System 1's energy estimate and population-control energy shift.
for name in ('e_estimate', 'pop_control_ene_shift'):
    print(prop_data1[name])

-0.6244981443733638
-0.6244981443733638


In [76]:
# NOTE(review): tot_prop_data is not assigned in any cell shown here; presumably it came
# from a gather_prop_data([prop_data1, prop_data2]) call that was deleted — confirm.
for name, value in tot_prop_data.items():
    print(name, value if name == 'n_killed_walkers' else value.shape)

weights (100,)
walkers (100, 2, 1)
e_estimate (2,)
pop_control_ene_shift (2,)
overlaps (100,)
key (4,)
n_killed_walkers 0


In [80]:
# Energy estimates and population-control shifts of the merged state (one entry per system).
for name in ('e_estimate', 'pop_control_ene_shift'):
    print(tot_prop_data[name])

[-0.62449814 -0.60801111]
[-0.62449814 -0.60801111]
