# Examples: How to use jQMC modules.

## TREX-IO documentation: https://trex-coe.github.io/trexio/trex.html

## PySCF-forge: https://github.com/pyscf/pyscf-forge
The TREX-IO interface is implemented in PySCF-forge. Please install PySCF-forge **from the GitHub repository**: `pip install git+https://github.com/pyscf/pyscf-forge`

In [19]:
import numpy as np

The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.


## pyscf-forge example

```Python
from pyscf import gto, scf
from pyscf.tools import trexio

filename = 'water_ccecp_ccpvqz.h5'

mol = gto.Mole()
mol.verbose  = 5
mol.atom     = '''
               O    5.00000000   7.14707700   7.65097100
               H    4.06806600   6.94297500   7.56376100
               H    5.38023700   6.89696300   6.80798400
               '''
mol.basis    = 'ccecp-ccpvqz'
mol.unit     = 'A'
mol.ecp      = 'ccecp'
mol.charge   = 0
mol.spin     = 0
mol.symmetry = False
mol.cart     = False  # spherical AOs: the outputs below (AOs_sphe_data, 114 AOs) correspond to this setting
mol.output   = 'water.out'
mol.build()

mf = scf.HF(mol)
mf.max_cycle=200
mf_scf = mf.kernel()

trexio.to_trexio(mf, filename)
```

## jQMC part

In [24]:
from jqmc.trexio_wrapper import read_trexio_file
from jqmc.atomic_orbital import compute_AOs_jax
from jqmc.molecular_orbital import compute_MOs_jax
from jqmc.jastrow_factor import Jastrow_two_body_data, Jastrow_data
from jqmc.wavefunction import Wavefunction_data
from jqmc.coulomb_potential import compute_bare_coulomb_potential_ion_ion_jax
from jqmc.coulomb_potential import compute_bare_coulomb_potential_el_el_jax
from jqmc.coulomb_potential import compute_bare_coulomb_potential_el_ion_jax
from jqmc.coulomb_potential import compute_ecp_local_parts_all_pairs_jax
from jqmc.coulomb_potential import compute_ecp_non_local_parts_all_pairs_jax
from jqmc.coulomb_potential import compute_ecp_non_local_parts_nearest_neighbors_jax

In [27]:
# Path of the TREX-IO file produced by the PySCF-forge example.
trexio_filename = "water_ccecp_ccpvqz.h5"

# read data from a TREX-IO file
(
    structure_data,
    aos_data,
    mos_data_up,
    mos_data_dn,
    geminal_mo_data,
    coulomb_potential_data,
) = read_trexio_file(trexio_filename)

In [28]:
# Define the value of Jastrow: only the two-body term is set; the one- and
# three-body terms are left as None.
jastrow_twobody_data = Jastrow_two_body_data.init_jastrow_two_body_data(
    jastrow_2b_param=1.0
)
jastrow_data = Jastrow_data(
    jastrow_two_body_data=jastrow_twobody_data,
    jastrow_one_body_data=None,
    jastrow_three_body_data=None,
)

# Combine the Jastrow factor with the geminal read from the TREX-IO file.
wavefunction_data = Wavefunction_data(
    geminal_data=geminal_mo_data,
    jastrow_data=jastrow_data,
)

In [4]:
# Elements, nuclear charges and Cartesian positions read from the TREX-IO file.
structure_data.get_info()

['**Structure_data',
 '  PBC flag = False',
 '  --------------------------------------------------',
 '  element, label, Z, x, y, z in cartesian (Bohr)',
 '  --------------------------------------------------',
 '  O, O, 8.0, -1.32695823, -0.10593853, 0.01878815',
 '  H, H, 1.0, -1.93166524, 1.60017432, -0.02171052',
 '  H, H, 1.0, 0.48664428, 0.07959809, 0.00986248',
 '  --------------------------------------------------']

In [5]:
# Atomic-orbital basis: number of AOs and, per atom, the shells with their exponents/coefficients.
aos_data.get_info()

['**AOs_sphe_data',
 '  Number of AOs = 114',
 '  Number of primitive AOs = 160',
 '  Angular part is the real spherical (solid) Harmonics.',
 '  ------------------------------------',
 '  **basis set for atom index 1: O**',
 '  ------------------------------------',
 '  O s',
 '    54.775216 -0.0012444',
 '    25.616801 0.0107330',
 '    11.980245 0.0018889',
 '    6.992317 -0.1742537',
 '    2.620277 0.0017622',
 '    1.225429 0.3161846',
 '    0.577797 0.4512023',
 '    0.268022 0.3121534',
 '    0.125346 0.0511167',
 '  O s',
 '    1.351771 1.0000000',
 '  O s',
 '    0.843157 1.0000000',
 '  O s',
 '    0.224380 1.0000000',
 '  O p',
 '    22.217266 0.0104866',
 '    10.747550 0.0366435',
 '    5.315785 0.0803674',
 '    2.660761 0.1627010',
 '    1.331816 0.2377791',
 '    0.678626 0.2811422',
 '    0.333673 0.2643189',
 '    0.167017 0.1466014',
 '    0.083598 0.0458145',
 '  O p',
 '    1.106737 1.0000000',
 '  O p',
 '    0.452364 1.0000000',
 '  O p',
 '    0.148562 1.0000000

In [6]:
# MO data for the up-spin channel: MO coefficient matrix over the AO basis.
mos_data_up.get_info()

['**MOs_data',
 '  Number of MOs = 4',
 '  dim. of MOs coeff = (4, 114)',
 '**AOs_sphe_data',
 '  Number of AOs = 114',
 '  Number of primitive AOs = 160',
 '  Angular part is the real spherical (solid) Harmonics.',
 '  ------------------------------------',
 '  **basis set for atom index 1: O**',
 '  ------------------------------------',
 '  O s',
 '    54.775216 -0.0012444',
 '    25.616801 0.0107330',
 '    11.980245 0.0018889',
 '    6.992317 -0.1742537',
 '    2.620277 0.0017622',
 '    1.225429 0.3161846',
 '    0.577797 0.4512023',
 '    0.268022 0.3121534',
 '    0.125346 0.0511167',
 '  O s',
 '    1.351771 1.0000000',
 '  O s',
 '    0.843157 1.0000000',
 '  O s',
 '    0.224380 1.0000000',
 '  O p',
 '    22.217266 0.0104866',
 '    10.747550 0.0366435',
 '    5.315785 0.0803674',
 '    2.660761 0.1627010',
 '    1.331816 0.2377791',
 '    0.678626 0.2811422',
 '    0.333673 0.2643189',
 '    0.167017 0.1466014',
 '    0.083598 0.0458145',
 '  O p',
 '    1.106737 1.0000000

In [7]:
# MO data for the down-spin channel: MO coefficient matrix over the AO basis.
mos_data_dn.get_info()

['**MOs_data',
 '  Number of MOs = 4',
 '  dim. of MOs coeff = (4, 114)',
 '**AOs_sphe_data',
 '  Number of AOs = 114',
 '  Number of primitive AOs = 160',
 '  Angular part is the real spherical (solid) Harmonics.',
 '  ------------------------------------',
 '  **basis set for atom index 1: O**',
 '  ------------------------------------',
 '  O s',
 '    54.775216 -0.0012444',
 '    25.616801 0.0107330',
 '    11.980245 0.0018889',
 '    6.992317 -0.1742537',
 '    2.620277 0.0017622',
 '    1.225429 0.3161846',
 '    0.577797 0.4512023',
 '    0.268022 0.3121534',
 '    0.125346 0.0511167',
 '  O s',
 '    1.351771 1.0000000',
 '  O s',
 '    0.843157 1.0000000',
 '  O s',
 '    0.224380 1.0000000',
 '  O p',
 '    22.217266 0.0104866',
 '    10.747550 0.0366435',
 '    5.315785 0.0803674',
 '    2.660761 0.1627010',
 '    1.331816 0.2377791',
 '    0.678626 0.2811422',
 '    0.333673 0.2643189',
 '    0.167017 0.1466014',
 '    0.083598 0.0458145',
 '  O p',
 '    1.106737 1.0000000

In [30]:
# Summary of the wavefunction (geminal + Jastrow) assembled in the setup cell.
wavefunction_data.get_info()

['**Geminal_data',
 '  dim. of lambda_matrix = (4, 4)',
 '  lambda_matrix is symmetric? = True',
 '  lambda_matrix_paired.shape = (4, 4)',
 '  lambda_matrix_unpaired.shape = (4, 0)',
 '  num_electron_up = 4',
 '  num_electron_dn = 4',
 '**MOs_data',
 '  Number of MOs = 4',
 '  dim. of MOs coeff = (4, 114)',
 '**AOs_sphe_data',
 '  Number of AOs = 114',
 '  Number of primitive AOs = 160',
 '  Angular part is the real spherical (solid) Harmonics.',
 '  ------------------------------------',
 '  **basis set for atom index 1: O**',
 '  ------------------------------------',
 '  O s',
 '    54.775216 -0.0012444',
 '    25.616801 0.0107330',
 '    11.980245 0.0018889',
 '    6.992317 -0.1742537',
 '    2.620277 0.0017622',
 '    1.225429 0.3161846',
 '    0.577797 0.4512023',
 '    0.268022 0.3121534',
 '    0.125346 0.0511167',
 '  O s',
 '    1.351771 1.0000000',
 '  O s',
 '    0.843157 1.0000000',
 '  O s',
 '    0.224380 1.0000000',
 '  O p',
 '    22.217266 0.0104866',
 '    10.747550 

In [31]:
# Coulomb-potential data read from the TREX-IO file (the output reports ecp_flag).
coulomb_potential_data.get_info()

['**Coulomb_potential_data', '  ecp_flag = True']

In [33]:
# Geminal data returned by read_trexio_file. Note: the variable is `geminal_mo_data`;
# `geminal_data` is never defined in this notebook and raised NameError on a fresh kernel.
geminal_mo_data.get_info()

['**Geminal_data',
 '  dim. of lambda_matrix = (4, 4)',
 '  lambda_matrix is symmetric? = True',
 '  lambda_matrix_paired.shape = (4, 4)',
 '  lambda_matrix_unpaired.shape = (4, 0)',
 '  num_electron_up = 4',
 '  num_electron_dn = 4',
 '**MOs_data',
 '  Number of MOs = 4',
 '  dim. of MOs coeff = (4, 114)',
 '**AOs_sphe_data',
 '  Number of AOs = 114',
 '  Number of primitive AOs = 160',
 '  Angular part is the real spherical (solid) Harmonics.',
 '  ------------------------------------',
 '  **basis set for atom index 1: O**',
 '  ------------------------------------',
 '  O s',
 '    54.775216 -0.0012444',
 '    25.616801 0.0107330',
 '    11.980245 0.0018889',
 '    6.992317 -0.1742537',
 '    2.620277 0.0017622',
 '    1.225429 0.3161846',
 '    0.577797 0.4512023',
 '    0.268022 0.3121534',
 '    0.125346 0.0511167',
 '  O s',
 '    1.351771 1.0000000',
 '  O s',
 '    0.843157 1.0000000',
 '  O s',
 '    0.224380 1.0000000',
 '  O p',
 '    22.217266 0.0104866',
 '    10.747550 

In [34]:
# NOTE(review): duplicates the earlier structure_data.get_info() cell; re-shown here for reference.
structure_data.get_info()

['**Structure_data',
 '  PBC flag = False',
 '  --------------------------------------------------',
 '  element, label, Z, x, y, z in cartesian (Bohr)',
 '  --------------------------------------------------',
 '  O, O, 8.0, -1.32695823, -0.10593853, 0.01878815',
 '  H, H, 1.0, -1.93166524, 1.60017432, -0.02171052',
 '  H, H, 1.0, 0.48664428, 0.07959809, 0.00986248',
 '  --------------------------------------------------']

In [49]:
# Random electron positions, used only to exercise the API below.
# A fixed seed makes "Restart Kernel -> Run All" reproduce the same coordinates
# (and therefore the same potential values) on every run.
RANDOM_SEED = 0
rng = np.random.default_rng(RANDOM_SEED)

num_ele_up = 4  # matches num_electron_up = 4 reported by the geminal info
num_ele_dn = 4  # matches num_electron_dn = 4 reported by the geminal info
r_up_carts = rng.random((num_ele_up, 3))  # (N_up, 3), uniform in [0, 1)
r_dn_carts = rng.random((num_ele_dn, 3))  # (N_dn, 3), uniform in [0, 1)

# compute_AOs_jax
```
def compute_AOs_jax(aos_data: AOs_sphe_data | AOs_cart_data, r_carts: jnpt.ArrayLike) -> jax.Array:
    """Compute AO values at the given r_carts.

    The method is for computing the value of the given atomic orbital at r_carts

    Args:
        aos_data (AOs_sphe_data | AOs_cart_data): an instance of AOs_sphe_data or AOs_cart_data
        r_carts (jnpt.ArrayLike): Cartesian coordinates of electrons (dim: N_e, 3)

    Returns:
        jax.Array: Arrays containing values of the AOs at r_carts. (dim: num_ao, N_e)
    """
```

In [50]:
# AO values at the up-electron positions; per the docstring the shape is (num_ao, N_e).
ao_up_values = compute_AOs_jax(aos_data=aos_data, r_carts=r_up_carts)
ao_up_values.shape

(114, 4)

In [51]:
# AO values at the down-electron positions; shape (num_ao, N_e).
ao_dn_values = compute_AOs_jax(aos_data=aos_data, r_carts=r_dn_carts)
ao_dn_values.shape

(114, 4)

# compute_MOs_jax
```
def compute_MOs_jax(mos_data: MOs_data, r_carts: jnpt.ArrayLike) -> jax.Array:
    """Compute the values of the molecular orbitals at r_carts simultaneously.

    Args:
        mos_data (MOs_data): an instance of MOs_data
        r_carts (jnpt.ArrayLike): Cartesian coordinates of electrons (dim: N_e, 3)

    Returns:
        Arrays containing values of the MOs at r_carts. (dim: num_mo, N_e)
    """
```

In [52]:
# Up-spin MO values at the up-electron positions; per the docstring the shape is (num_mo, N_e).
mo_up_values = compute_MOs_jax(mos_data=mos_data_up, r_carts=r_up_carts)
mo_up_values.shape

(4, 4)

In [53]:
# Down-spin MO values at the down-electron positions; shape (num_mo, N_e).
mo_dn_values = compute_MOs_jax(mos_data=mos_data_dn, r_carts=r_dn_carts)
mo_dn_values.shape

(4, 4)

# compute V_bare (ion-ion)
```
@jit
def compute_bare_coulomb_potential_ion_ion_jax(
    coulomb_potential_data: Coulomb_potential_data,
) -> float:
```

In [54]:
# Ion-ion Coulomb term: takes no electron coordinates, only the Coulomb-potential data.
V_ii = compute_bare_coulomb_potential_ion_ion_jax(coulomb_potential_data=coulomb_potential_data)

# compute V_bare (ele-ele)
```
@jit
def compute_bare_coulomb_potential_el_el_jax(
    r_up_carts: jnpt.ArrayLike,
    r_dn_carts: jnpt.ArrayLike,
) -> float:
```

In [55]:
# Electron-electron Coulomb term for the current up/down electron positions.
V_ee = compute_bare_coulomb_potential_el_el_jax(r_up_carts=r_up_carts, r_dn_carts=r_dn_carts)

# compute V_bare (ele-ion)
```
@jit
def compute_bare_coulomb_potential_el_ion_jax(
    coulomb_potential_data: Coulomb_potential_data,
    r_up_carts: jnpt.ArrayLike,
    r_dn_carts: jnpt.ArrayLike,
) -> float:
    """See compute_bare_coulomb_potential_api."""
```

In [56]:
# Electron-ion Coulomb term for the current up/down electron positions.
V_ei = compute_bare_coulomb_potential_el_ion_jax(
    coulomb_potential_data=coulomb_potential_data, r_up_carts=r_up_carts, r_dn_carts=r_dn_carts
)

# compute V_ecp_local_part
```
@jit
def compute_ecp_local_parts_all_pairs_jax(
    coulomb_potential_data: Coulomb_potential_data,
    r_up_carts: jnpt.ArrayLike,
    r_dn_carts: jnpt.ArrayLike,
) -> float:
    """Compute ecp local parts, considering all nucleus-electron pairs.

    The method is for computing the local part of the given ECPs at (r_up_carts, r_dn_carts).
    A much faster implementation using JAX.

    Args:
        coulomb_potential_data (Coulomb_potential_data): an instance of Coulomb_potential_data
        r_up_carts (jnpt.ArrayLike): Cartesian coordinates of up-spin electrons (dim: N_e^{up}, 3)
        r_dn_carts (jnpt.ArrayLike): Cartesian coordinates of dn-spin electrons (dim: N_e^{dn}, 3)

    Returns:
        float: The sum of local part of the given ECPs with r_up_carts and r_dn_carts.
    """
```

In [57]:
# Local ECP part, summed over all nucleus-electron pairs (a scalar, per the docstring).
V_local = compute_ecp_local_parts_all_pairs_jax(
    coulomb_potential_data=coulomb_potential_data, r_up_carts=r_up_carts, r_dn_carts=r_dn_carts
)

In [58]:
# Display the local ECP energy (depends on the random electron positions).
V_local

Array(0.00355238, dtype=float64)

# compute V_ecp_non_local_part (if you need all neighbors)
```
@partial(jit, static_argnums=(4))
def compute_ecp_non_local_parts_all_pairs_jax(
    coulomb_potential_data: Coulomb_potential_data,
    wavefunction_data: Wavefunction_data,
    r_up_carts: jnpt.ArrayLike,
    r_dn_carts: jnpt.ArrayLike,
    RT: jnpt.ArrayLike,
    Nv: int = Nv_default,
) -> tuple[list, list, list, float]:
    """Compute ecp non-local parts using JAX, considering all nucleus-electron pairs.

    The method is for computing the non-local part of the given ECPs at (r_up_carts, r_dn_carts).

    Args:
        coulomb_potential_data (Coulomb_potential_data): an instance of Coulomb_potential_data
        wavefunction_data (Wavefunction_data): an instance of Wavefunction_data
        r_up_carts (jnpt.ArrayLike): Cartesian coordinates of up-spin electrons (dim: N_e^{up}, 3)
        r_dn_carts (jnpt.ArrayLike): Cartesian coordinates of dn-spin electrons (dim: N_e^{dn}, 3)
        RT (jnpt.ArrayLike): Rotation matrix. equiv R.T
        Nv (int): The number of quadrature points for the spherical part.
        flag_determinant_only (bool): If True, only the determinant part is considered for the non-local ECP part.

    Returns:
        list[jax.Array]: The list of grids for up electrons on which the non-local part is computed.
        list[jax.Array]: The list of grids for dn electrons on which the non-local part is computed.
        list[float]: The list of non-local part of the given ECPs with r_up_carts and r_dn_carts.
        float: The sum of non-local part of the given ECPs with r_up_carts and r_dn_carts.
    """
```

In [60]:
# Non-local ECP part over all nucleus-electron pairs, with the identity matrix as RT
# (no rotation of the quadrature grid) and Nv=6 quadrature points for the spherical part.
# Returns the up/down grids, the per-grid-point values and their sum.
mesh_non_local_ecp_part_r_up_carts, mesh_non_local_ecp_part_r_dn_carts, V_nonlocal, sum_V_nonlocal = (
    compute_ecp_non_local_parts_all_pairs_jax(
        coulomb_potential_data=coulomb_potential_data,
        wavefunction_data=wavefunction_data,
        r_up_carts=r_up_carts,
        r_dn_carts=r_dn_carts,
        RT=np.eye(3),
        Nv=6,
    )
)

In [61]:
# Up-electron configurations on which the non-local part is evaluated.
# NOTE(review): leading dim 144 looks like 3 nuclei x 8 electrons x Nv=6 — TODO confirm.
mesh_non_local_ecp_part_r_up_carts.shape

(144, 4, 3)

In [62]:
# Down-electron configurations on which the non-local part is evaluated (same leading dim as the up grid).
mesh_non_local_ecp_part_r_dn_carts.shape

(144, 4, 3)

In [63]:
# One non-local value per grid configuration (matches the leading dim of the mesh arrays).
V_nonlocal.shape

(144,)

In [64]:
# Total non-local ECP contribution (sum of V_nonlocal).
sum_V_nonlocal

Array(2.97560322e-12, dtype=float64)

# compute V_ecp_non_local_part (if you need only several nearest neighbors: NN)
```
@partial(jit, static_argnums=(5, 6, 7))
def compute_ecp_non_local_parts_nearest_neighbors_jax(
    coulomb_potential_data: Coulomb_potential_data,
    wavefunction_data: Wavefunction_data,
    r_up_carts: jnpt.ArrayLike,
    r_dn_carts: jnpt.ArrayLike,
    RT: jnpt.ArrayLike,
    NN: int = NN_default,
    Nv: int = Nv_default,
    flag_determinant_only: bool = False,
) -> tuple[list, list, list, float]:
    """Compute ecp non-local parts.

    The method is for computing the non-local part of the given ECPs at (r_up_carts, r_dn_carts)
    with a cutoff considering only up to NN-th nearest neighbors.

    Args:
        coulomb_potential_data (Coulomb_potential_data): an instance of Coulomb_potential_data
        wavefunction_data (Wavefunction_data): an instance of Wavefunction_data
        r_up_carts (jnpt.ArrayLike): Cartesian coordinates of up-spin electrons (dim: N_e^{up}, 3)
        r_dn_carts (jnpt.ArrayLike): Cartesian coordinates of dn-spin electrons (dim: N_e^{dn}, 3)
        RT (jnpt.ArrayLike): Rotation matrix. equiv R.T
        NN (int): Consider only up to N-th nearest neighbors.
        Nv (int): The number of quadrature points for the spherical part.
        flag_determinant_only (bool): If True, only the determinant part is considered for the non-local ECP part.

    Returns:
        list[jax.Array]: The list of grids for up electrons on which the non-local part is computed.
        list[jax.Array]: The list of grids for dn electrons on which the non-local part is computed.
        list[float]: The list of non-local part of the given ECPs with r_up_carts and r_dn_carts.
        float: sum of the V_nonlocal
    """
```

In [65]:
# Non-local ECP part restricted to the NN=1 nearest neighbour(s) instead of all pairs;
# same RT (identity) and Nv=6 as the all-pairs call. Overwrites the all-pairs results.
mesh_non_local_ecp_part_r_up_carts, mesh_non_local_ecp_part_r_dn_carts, V_nonlocal, sum_V_nonlocal = (
    compute_ecp_non_local_parts_nearest_neighbors_jax(
        coulomb_potential_data=coulomb_potential_data,
        wavefunction_data=wavefunction_data,
        r_up_carts=r_up_carts,
        r_dn_carts=r_dn_carts,
        RT=np.eye(3),
        Nv=6,
        NN=1,
    )
)

In [66]:
# Up-electron configurations for the NN=1 evaluation.
# NOTE(review): leading dim 48 looks like NN=1 x 8 electrons x Nv=6 — TODO confirm.
mesh_non_local_ecp_part_r_up_carts.shape

(48, 4, 3)

In [67]:
# Down-electron configurations for the NN=1 evaluation.
mesh_non_local_ecp_part_r_dn_carts.shape

(48, 4, 3)

In [68]:
# One non-local value per NN=1 grid configuration.
V_nonlocal.shape

(48,)

In [69]:
# Total non-local ECP contribution within the NN=1 cutoff.
sum_V_nonlocal

Array(0., dtype=float64)