In [1]:
from ad_afqmc import config
from ad_afqmc.lno_afqmc.ccsd_pt2 import lno_afqmc
config.setup_jax()
MPI = config.setup_comm()
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

ham_data, prop, trial, wave_data, sampler, options, _ = (
    lno_afqmc._prep_afqmc())

from jax import numpy as jnp
import numpy as np
mo_coeff = jnp.array(np.eye(trial.norb))
wave_data["mo_coeff"] = mo_coeff[:, :trial.nelec[0]]

# Hostname: sharmagroup-rn
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64
# Hostname: sharmagroup-rn
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64
# Using GPU.
# System: Linux
# Node Name: sharmagroup-rn
# Release: 6.14.0-37-generic
# Version: #37~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Nov 20 10:25:38 UTC 2
# Machine: x86_64
# Processor: x86_64
# Number of MPI ranks: 1
#
# norb: 5
# nelec: 6
#
# n_eql: 3
# n_prop_steps: 50
# n_ene_blocks: 1
# n_sr_blocks: 5
# n_blocks: 100
# n_walkers: 10
# seed: 99
# walker_type: rhf
# trial: ccsd_pt2
# dt: 0.005
# free_projection: False
# use_gpu: True
# max_error: 0.001
# n_exp_terms: 6
# ene0: 0.0
# n_batch: 1
#


In [2]:
import time
from jax import numpy as jnp
from jax import random
init_time = time.time()
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

### initialize propagation
seed = options["seed"]
init_walkers = None
# dm_up = jnp.array(wave_data["mo_coeff"][0] @ wave_data["mo_coeff"][0].T.conj())
# dm_dn = jnp.array(wave_data["mo_coeff"][1] @ wave_data["mo_coeff"][1].T.conj())
# trial_rdm1 = [dm_up, dm_dn]
trial_rdm1 = trial.get_rdm1(wave_data)
if "rdm1" not in wave_data:
    wave_data["rdm1"] = trial_rdm1
ham_data = trial._build_measurement_intermediates(ham_data, wave_data)
ham_data = prop._build_propagation_intermediates(ham_data, trial, wave_data)

prop_data = prop.init_prop_data(trial, wave_data, ham_data, init_walkers)
if jnp.abs(jnp.sum(prop_data["overlaps"])) < 1.0e-6:
    raise ValueError(
        "Initial overlaps are zero. Pass walkers with non-zero overlap."
    )
prop_data["key"] = random.PRNGKey(seed + rank)

prop_data["overlaps"] = trial.calc_overlap(prop_data["walkers"], wave_data)
prop_data["n_killed_walkers"] = 0
prop_data["pop_control_ene_shift"] = prop_data["e_estimate"]
print(prop_data["e_estimate"])

-148.9632014280337


In [5]:
t2 = jnp.array(wave_data['t2'])
print(t2.shape)

(3, 2, 3, 2)


In [19]:
def modified_cholesky(mat: np.ndarray, max_error: float = 1e-6) -> np.ndarray:
    """Modified cholesky decomposition for a given matrix.

    Args:
        mat (np.ndarray): Matrix to decompose.
        max_error (float, optional): Maximum error allowed. Defaults to 1e-6.

    Returns:
        np.ndarray: Cholesky vectors.
    """
    diag = mat.diagonal()
    size = mat.shape[0]
    nchol_max = size
    chol_vecs = np.zeros((nchol_max, nchol_max))
    # ndiag = 0
    nu = np.argmax(diag)
    delta_max = diag[nu]
    Mapprox = np.zeros(size)
    chol_vecs[0] = np.copy(mat[nu]) / delta_max**0.5

    nchol = 0
    while abs(delta_max) > max_error and (nchol + 1) < nchol_max:
        Mapprox += chol_vecs[nchol] * chol_vecs[nchol]
        delta = diag - Mapprox
        nu = np.argmax(np.abs(delta))
        delta_max = np.abs(delta[nu])
        R = np.dot(chol_vecs[: nchol + 1, nu], chol_vecs[: nchol + 1, :])
        chol_vecs[nchol + 1] = (mat[nu] - R) / (delta_max + 1e-10) ** 0.5
        nchol += 1

    return chol_vecs[:nchol]

In [26]:
from ad_afqmc import pyscf_interface
t2 = np.array(wave_data['t2'])
ni, na = t2.shape[0], t2.shape[1]
t2 = t2.reshape(ni*na,ni*na)*100
chol_cut_t2 = 1e-8
dt2 = pyscf_interface.modified_cholesky(t2,max_error=chol_cut_t2)
# dt2 = dt2.reshape(-1,ni,na)

  chol_vecs[0] = np.copy(mat[nu]) / delta_max**0.5


In [27]:
print(dt2.shape)
print(dt2)

(0, 6)
[]


In [28]:
print(t2.shape)
print(t2)

(6, 6)
[[-2.10317084e+00  2.04895665e-07  1.97030439e+00 -4.66699467e-06
  -3.49069904e-05  1.05075880e+00]
 [ 6.32671507e-06 -1.83775671e+00 -1.35669146e-06  1.45397701e+00
  -4.56890993e-01 -9.73343652e-06]
 [ 8.01694502e-01 -7.81028937e-08 -7.51047973e-01  1.77898243e-06
   1.33059768e-05 -4.00532157e-01]
 [-2.41164084e-06  7.00522955e-01  5.17148711e-07 -5.54232377e-01
   1.74159412e-01  3.71022762e-06]
 [-6.77967802e-06  6.60491585e-13  6.35137628e-06 -1.50442944e-11
  -1.12524457e-10  3.38717436e-06]
 [ 2.03944874e-11 -5.92410210e-06 -4.37336384e-12  4.68696873e-06
  -1.47281132e-06 -3.13762270e-11]]


In [29]:
L = np.linalg.cholesky(t2)
print(L.shape)
print(L)

LinAlgError: Matrix is not positive definite

In [10]:
rt2 = jnp.einsum('dia,djb->iajb',dt2,dt2)
print(rt2)

[[[[0. 0.]
   [0. 0.]
   [0. 0.]]

  [[0. 0.]
   [0. 0.]
   [0. 0.]]]


 [[[0. 0.]
   [0. 0.]
   [0. 0.]]

  [[0. 0.]
   [0. 0.]
   [0. 0.]]]


 [[[0. 0.]
   [0. 0.]
   [0. 0.]]

  [[0. 0.]
   [0. 0.]
   [0. 0.]]]]
