In [1]:
from ad_afqmc import mpi_jax, config, driver
import pickle
import time
from functools import partial
from typing import List, Optional, Union, Sequence

import jax
import jax.numpy as jnp
import numpy as np
from jax import dtypes, jvp, random, vjp

from ad_afqmc import hamiltonian, propagation, sampling, stat_utils, wavefunctions
from ad_afqmc.corr_sample_test import corr_sample

print = partial(print, flush=True)

# Hostname: yichi-thinkpad
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64


In [2]:
ham_data1,ham1,prop1,trial1,wave_data1,sampler1,observable1,options1,_=\
    mpi_jax._prep_afqmc(option_file='option1.bin',mo_file='mo1.npz',chol_file='chol1')
ham_data2,ham2,prop2,trial2,wave_data2,sampler2,observable2,options2,_=\
    mpi_jax._prep_afqmc(option_file='option2.bin',mo_file='mo2.npz',chol_file='chol2')

# Number of MPI ranks: 1
#
# Using Local Natural Orbital Approximation
# norb: 13
# nelec: (4, 4)
#
# n_eql: 1
# n_ene_blocks: 1
# n_sr_blocks: 10
# n_blocks: 100
# n_walkers: 20
# seed: 98
# walker_type: rhf
# trial: rhf
# dt: 0.01
# orbE: -2
# prjlo: [[-8.32667268e-17  0.00000000e+00  1.38777878e-17  1.00000000e+00]]
# maxError: 0.0001
# LNO: True
# n_exp_terms: 6
# n_prop_steps: 50
# orbital_rotation: True
# do_sr: True
# symmetry: False
# save_walkers: False
# ene0: 0.0
# free_projection: False
# n_batch: 1
#
# Number of MPI ranks: 1
#
# Using Local Natural Orbital Approximation
# norb: 13
# nelec: (4, 4)
#
# n_eql: 1
# n_ene_blocks: 1
# n_sr_blocks: 10
# n_blocks: 100
# n_walkers: 20
# seed: 98
# walker_type: rhf
# trial: rhf
# dt: 0.01
# orbE: -2
# prjlo: [[ 5.55111512e-17 -5.55111512e-17 -1.38777878e-17  1.00000000e+00]]
# maxError: 0.0001
# LNO: True
# n_exp_terms: 6
# n_prop_steps: 50
# orbital_rotation: True
# do_sr: True
# symmetry: False
# save_walkers: False
# ene0: 0.0
# 

In [3]:
config.setup_jax()
MPI = config.setup_comm()

# Hostname: yichi-thinkpad
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64


In [None]:
# def LNOafqmc(
#     ham_data: dict,
#     ham: hamiltonian.hamiltonian,
#     propagator: propagation.propagator,
#     trial: wavefunctions.wave_function,
#     wave_data: dict,
#     sampler: sampling.sampler,
#     observable,
#     options: dict,
#     MPI,
#     init_walkers: Optional[Sequence] = None,
# ):
def init_prop_lno(
        ham_data: dict,
        ham: hamiltonian.hamiltonian,
        propagator: propagation.propagator,
        trial: wavefunctions.wave_function,
        wave_data: dict,
        options: dict,
        MPI
):
    
    init_walkers: Optional[Union[List, jax.Array]] = None
    init = time.time()
    comm = MPI.COMM_WORLD
    # size = comm.Get_size()
    rank = comm.Get_rank()
    seed = options["seed"]
    # neql = options["n_eql"]
    # orbE = options['orbE']

    # if observable is not None:
    #     observable_op = jnp.array(observable[0])
    #     observable_constant = observable[1]
    # else:
    #     observable_op = jnp.array(ham_data["h1"])
    #     observable_constant = 0.0

    trial_rdm1 = trial.get_rdm1(wave_data)
    # trial_observable = np.sum(trial_rdm1 * observable_op)

    if "rdm1" not in wave_data:
        wave_data["rdm1"] = trial_rdm1
    ham_data = ham.build_measurement_intermediates(ham_data,trial,wave_data)
    ham_data = ham.build_propagation_intermediates(ham_data,propagator,trial, wave_data)
    prop_data = propagator.init_prop_data(trial, wave_data, ham_data, init_walkers)
    if jnp.abs(jnp.sum(prop_data["overlaps"])) < 1.0e-6:
        raise ValueError(
            "Initial overlaps are zero. Pass walkers with non-zero overlap."
        )
    prop_data["key"] = random.PRNGKey(seed + rank)
    prop_data["n_killed_walkers"] = 0

    comm.Barrier()
    init_time = time.time() - init
    if rank == 0:
        print(f"# initial energy: {prop_data['e_estimate']:.6f} time: {init_time:.2f}s ")
    comm.Barrier()
    return prop_data, ham_data

sampler_eq = sampling.sampler(n_prop_steps=50, n_ene_blocks=5, n_sr_blocks=10)
orbE = options1['orbE']

In [None]:
prop_data1_init,ham_data1_init = init_prop_lno(ham_data1,ham1,prop1,trial1,wave_data1,options1,MPI)
prop_data2_init,ham_data2_init = init_prop_lno(ham_data2,ham2,prop2,trial2,wave_data2,options2,MPI)

# initial energy: -152.062510 time: 0.80s 
# initial energy: -152.053292 time: 0.81s 


In [None]:
init_time = time.time()
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
rlx_steps = 0
(prop_data1_rlx,prop_data2_rlx),(loc_en1,loc_orb_en1,loc_wt1,loc_en2,loc_orb_en2,loc_wt2) \
    = corr_sample.lno_cs_steps_scan(rlx_steps,
                      prop_data1_init,ham_data1,prop1,trial1,wave_data1,-2,
                      prop_data2_init,ham_data2,prop2,trial2,wave_data2,-2,
                      )

In [18]:
comm.Barrier()
if rank == 0:
    print('# rlx_step \t system1_en \t system2_en \t en_diff \t orb1_en \t orb2_en \t orb_en_diff')
    print(f'    {0}'
          f'\t \t {prop_data1_init["e_estimate"]:.6f}' 
          f'\t {prop_data2_init["e_estimate"]:.6f}'
          f'\t {prop_data1_init["e_estimate"]-prop_data2_init["e_estimate"]:.6f}')
comm.Barrier()

comm.Barrier()
if rank == 0:
    glb_en1 = np.empty(size*loc_en1.size,dtype='float32')
    glb_en2 = np.empty(size*loc_en2.size,dtype='float32')
    glb_orb_en1 = np.empty(size*loc_orb_en1.size,dtype='float32')
    glb_orb_en2 = np.empty(size*loc_orb_en2.size,dtype='float32')
    glb_wt1 = np.empty(size*loc_wt2.size,dtype='float32')
    glb_wt2 = np.empty(size*loc_wt2.size,dtype='float32')
else:
    glb_en1 = None
    glb_en2 = None
    glb_orb_en1 = None
    glb_orb_en2 = None
    glb_wt1 = None
    glb_wt2 = None
comm.Barrier()

loc_en1 = np.asarray(loc_en1,dtype='float32')
loc_en2 = np.asarray(loc_en2,dtype='float32')
loc_orb_en1 = np.asarray(loc_orb_en1,dtype='float32')
loc_orb_en2 = np.asarray(loc_orb_en2,dtype='float32')
loc_wt1 = np.asarray(loc_wt1,dtype='float32')
loc_wt2 = np.asarray(loc_wt2,dtype='float32')

comm.Gather(loc_en1,glb_en1,root=0)
comm.Gather(loc_en2,glb_en2,root=0)
comm.Gather(loc_orb_en1,glb_orb_en1,root=0)
comm.Gather(loc_orb_en2,glb_orb_en2,root=0)
comm.Gather(loc_wt1,glb_wt1,root=0)
comm.Gather(loc_wt2,glb_wt2,root=0)


comm.Barrier()
if rank == 0:
    glb_en1 = glb_en1.reshape(size,rlx_steps).T
    glb_en2 = glb_en2.reshape(size,rlx_steps).T
    glb_orb_en1 = glb_orb_en1.reshape(size,rlx_steps).T
    glb_orb_en2 = glb_orb_en2.reshape(size,rlx_steps).T
    glb_wt1 = glb_wt1.reshape(size,rlx_steps).T
    glb_wt2 = glb_wt2.reshape(size,rlx_steps).T

    en1 = np.zeros((rlx_steps))
    en2 = np.zeros((rlx_steps))
    en_diff = np.zeros((rlx_steps))
    orb_en1 = np.zeros((rlx_steps))
    orb_en2 = np.zeros((rlx_steps))
    orb_en_diff = np.zeros((rlx_steps))
    
    for step in range(rlx_steps):

        en1[step] = sum(glb_en1[step,:])/sum(glb_wt1[step,:])
        en2[step] = sum(glb_en2[step,:])/sum(glb_wt2[step,:])
        en_diff[step] = en1[step] - en2[step]
        orb_en1[step] = sum(glb_orb_en1[step,:])/sum(glb_wt1[step,:])
        orb_en2[step] = sum(glb_orb_en2[step,:])/sum(glb_wt2[step,:])
        orb_en_diff[step] = orb_en1[step] - orb_en2[step]

        print(f'    {step+1} \t \t {en1[step]:.6f} \t {en2[step]:.6f} \t {en_diff[step]:.6f}'
              f'\t {orb_en1[step]:.6f} \t {orb_en2[step]:.6f} \t {orb_en_diff[step]:.6f}')
    
    now_time = time.time()
    print(f'# relaxation time: {now_time - init_time:.2f}')
comm.Barrier()

# rlx_step 	 system1_en 	 system2_en 	 en_diff 	 orb1_en 	 orb2_en 	 orb_en_diff
    0	 	 -152.062510	 -152.053292	 -0.009219
    1 	 	 -152.217429 	 -152.195371 	 -0.022058	 -0.034887 	 -0.033706 	 -0.001181
    2 	 	 -152.241769 	 -152.194260 	 -0.047509	 -0.036968 	 -0.030234 	 -0.006734
    3 	 	 -152.208033 	 -152.220627 	 0.012593	 -0.040572 	 -0.045845 	 0.005274
# relaxation time: 64.46


In [32]:
n_runs = 100
prop_steps = 10
seeds = random.randint(random.PRNGKey(options1["seed"]),
                    shape=(n_runs,), minval=0, maxval=10000*n_runs)

_,loc_orb_en1,loc_wt1,_,loc_orb_en2,loc_wt2 \
    = corr_sample.lno_cs_seeds_scan(seeds,prop_steps,
                                    prop_data1_init,ham_data1_init,prop1,trial1,wave_data1,-2,
                                    prop_data2_init,ham_data2_init,prop2,trial2,wave_data2,-2,
                                    MPI)

In [31]:
nwalkers = options1["n_walkers"]
comm.Barrier()
if rank == 0:
    # print()
    # print(f'# multiple independent post relaxation propagation with step size {dt}s')
    # if options["corr_samp"]:
    #     print('# correlated sampling')
    # else: print('# uncorrelated sampling')

    print(f'# tot_walkers: {nwalkers*size}, propagation steps: {prop_steps}, number of independent runs: {n_runs}')
    print('# step' 
        #   '\t system1_en \t error' 
        #   '\t \t system2_en \t error'
        #   '\t \t energy_diff \t error'
          '\t orb1_en \t error' 
          '\t \t orb2_en \t error'
          '\t \t orb_en_diff \t error')
    
comm.Barrier()

comm.Barrier()
if rank == 0:
    # glb_en1 = np.empty(size*loc_en1.size,dtype='float32')
    # glb_en2 = np.empty(size*loc_en2.size,dtype='float32')
    glb_orb_en1 = np.empty(size*loc_orb_en1.size,dtype='float32')
    glb_orb_en2 = np.empty(size*loc_orb_en2.size,dtype='float32')
    glb_wt1 = np.empty(size*loc_wt1.size,dtype='float32')
    glb_wt2 = np.empty(size*loc_wt2.size,dtype='float32')
else:
    # glb_en1 = None
    # glb_en2 = None
    glb_orb_en1 = None
    glb_orb_en2 = None
    glb_wt1 = None
    glb_wt2 = None
comm.Barrier()

# loc_en1 = np.asarray(loc_en1,dtype='float32')
# loc_en2 = np.asarray(loc_en2,dtype='float32')
loc_orb_en1 = np.asarray(loc_orb_en1,dtype='float32')
loc_orb_en2 = np.asarray(loc_orb_en2,dtype='float32')
loc_wt1 = np.asarray(loc_wt1,dtype='float32')
loc_wt2 = np.asarray(loc_wt2,dtype='float32')

# comm.Gather(loc_en1,glb_en1,root=0)
# comm.Gather(loc_en2,glb_en2,root=0)
comm.Gather(loc_orb_en1,glb_orb_en1,root=0)
comm.Gather(loc_orb_en2,glb_orb_en2,root=0)
comm.Gather(loc_wt1,glb_wt1,root=0)
comm.Gather(loc_wt2,glb_wt2,root=0)


comm.Barrier()
if rank == 0:
    # glb_en1 = glb_en1.reshape(size,n_runs,prop_steps).T
    # glb_en2 = glb_en2.reshape(size,n_runs,prop_steps).T
    glb_orb_en1 = glb_orb_en1.reshape(size,n_runs,prop_steps).T
    glb_orb_en2 = glb_orb_en2.reshape(size,n_runs,prop_steps).T
    glb_wt1 = glb_wt1.reshape(size,n_runs,prop_steps).T
    glb_wt2 = glb_wt2.reshape(size,n_runs,prop_steps).T

    # en1 = np.zeros((prop_steps,n_runs))
    # en2 = np.zeros((prop_steps,n_runs))
    # en_diff = np.zeros((prop_steps,n_runs))
    orb_en1 = np.zeros((prop_steps,n_runs))
    orb_en2 = np.zeros((prop_steps,n_runs))
    orb_en_diff = np.zeros((prop_steps,n_runs))

    for step in range(prop_steps):

        for run in range(n_runs):
            # en1[step,run] = sum(glb_en1[step,run,:])/sum(glb_wt1[step,run,:])
            # en2[step,run] = sum(glb_en2[step,run,:])/sum(glb_wt2[step,run,:])
            # en_diff[step,run] = en1[step,run] - en2[step,run]
            orb_en1[step,run] = sum(glb_orb_en1[step,run,:])/sum(glb_wt1[step,run,:])
            orb_en2[step,run] = sum(glb_orb_en2[step,run,:])/sum(glb_wt2[step,run,:])
            orb_en_diff[step,run] = orb_en1[step,run] - orb_en2[step,run]

        # en_mean1 = en1[step,:].mean()
        # en_mean2 = en2[step,:].mean()
        # en_diff_mean = en_diff[step,:].mean()
        # en_err1 = en1[step,:].std()/np.sqrt(n_runs)
        # en_err2 = en2[step,:].std()/np.sqrt(n_runs)
        # en_diff_mean_err = en_diff[step,:].std()/np.sqrt(n_runs)
        orb_en_mean1 = orb_en1[step,:].mean()
        orb_en_mean2 = orb_en2[step,:].mean()
        orb_en_diff_mean = orb_en_diff[step,:].mean()
        orb_en_err1 = orb_en1[step,:].std()/np.sqrt(n_runs)
        orb_en_err2 = orb_en2[step,:].std()/np.sqrt(n_runs)
        orb_en_diff_mean_err = orb_en_diff[step,:].std()/np.sqrt(n_runs)

        print(f'  {step+1}'
            #   f'\t {en_mean1:.6f} \t {en_err1:.6f}' 
            #   f'\t {en_mean2:.6f} \t {en_err2:.6f}'
            #   f'\t {en_diff_mean:.6f} \t {en_diff_mean_err:.6f}'
              f'\t {orb_en_mean1:.6f} \t {orb_en_err1:.6f}' 
              f'\t {orb_en_mean2:.6f} \t {orb_en_err2:.6f}'
              f'\t {orb_en_diff_mean:.6f} \t {orb_en_diff_mean_err:.6f}')

    end_time = time.time()
    print(f'# total run time: {end_time - init_time:.2f}')
comm.Barrier()

# tot_walkers: 20, propagation steps: 6, number of independent runs: 100
# step	 orb1_en 	 error	 	 orb2_en 	 error	 	 orb_en_diff 	 error


  1	 -0.035104 	 0.000899	 -0.035905 	 0.000907	 0.000801 	 0.000420
  2	 -0.040522 	 0.001093	 -0.041198 	 0.001103	 0.000676 	 0.000576
  3	 -0.040270 	 0.001044	 -0.040648 	 0.001161	 0.000379 	 0.000486
  4	 -0.041809 	 0.001094	 -0.042157 	 0.001184	 0.000349 	 0.000572
  5	 -0.041808 	 0.001209	 -0.043099 	 0.001269	 0.001291 	 0.000622
  6	 -0.041767 	 0.001936	 -0.042298 	 0.001907	 0.000531 	 0.000663
# total run time: 8111.99
