In [3]:
# Force JAX onto the CPU backend. `! export VAR=...` runs in a throw-away
# subshell and never reaches the kernel process, so the variable has to be set
# in this process' environment, before `jax` is imported.
import os

os.environ["JAX_PLATFORMS"] = "cpu"

  pid, fd = os.forkpty()


# Rovibrational matrix elements of Cartesian tensor operators

Compute matrix elements of spin-rotation tensors on $\text{H}_1$ and $\text{H}_2$ atoms and electric dipole moment of $\text{H}_2\text{S}$ in the basis of rovibrational states.

In [4]:
from typing import Dict, List

import h5py
import jax
import numpy as np
from jax import config
from jax import numpy as jnp
from rovib.c2v import C2V_IRREPS
from rovib.cartens import CART_IND
from rovib.symtop import threej_wang
from scipy.sparse import csr_matrix

# switch JAX to 64-bit (double precision) mode before any array is created
jax.config.update("jax_enable_x64", True)


The rovibrational matrix elements of a Cartesian tensor operator $T$ of rank $\Omega$ are calculated using the following formula

$$
T_{\omega}^{(J',l',J,l)}=\sum_{v',k'}\sum_{v,k}c_{v',k'}^{(J',\Gamma',l')*}c_{v,k}^{(J,\Gamma,l)}(-1)^{k'}\sum_{\sigma=-\omega}^{\omega}\sum_{\alpha}
\left(\begin{array}{ccc}
J & \omega & J' \\
k & \sigma & -k'\\
\end{array}\right)
U_{\omega\sigma,\alpha}^{(\Omega)}\langle v'|\bar{T}_{\alpha}|v\rangle.
$$

Here, $\omega=0..\Omega$, $\alpha=x,y,z$ and $xx,xy,xz,...zz$ for $\Omega=1$ and 2, respectively,
$\bar{T}_{\alpha}$ is the electronic tensor (e.g., dipole moment, see `h2s_dipole.py`, or spin-rotation tensor, see `h2s_spinrot.py`),
$v$ and $k$ denote the vibrational and rotational quanta,
$J$ and $\Gamma$ denote 'good' quantum numbers of the rotational angular momentum and the state symmetry, and $l$ is the state running index within
a set of states with the same $J$ and $\Gamma$.
The rovibrational state expansion coefficients $c_{v,k}^{(J,\Gamma,l)}$
are computed in `h2s_rovib.ipynb` and stored in separate files
for different values of $J$ quanta.
The matrix $U_{\omega\sigma,\alpha}^{(\Omega)}$ transforms Cartesian tensor of rank $\Omega$ into its spherical-tensor representation
(see `cartens.UMAT_CART_TO_SPHER`).

The rotational part of the expression
$$
R_{\omega,k',k,\alpha}^{(J',J)} = \sum_{\sigma=-\omega}^{\omega}
\left(\begin{array}{ccc}
J & \omega & J' \\
k & \sigma & -k'\\
\end{array}\right)
U_{\omega\sigma,\alpha}^{(\Omega)}
$$
is computed by the `symtop.threej_wang` function, where it is transformed from the symmetric-top basis $|J,k\rangle$ ($k=-J..J$) into the so-called Wang symmetry-adapted basis $|J,k,\tau\rangle$ ($k=0..J$; $\tau=0,1$ defines the parity $(-1)^\tau$).

The vibrational matrix elements
$$
V_{v',v,\alpha} = \langle v'|\bar{T}_{\alpha}|v\rangle
$$
are computed in `h2s_rovib.ipynb` and stored in a file.

Define functions to load the energies and quantum numbers for selected rovibrational states (`rovib_states`) and to compute the above matrix elements (`tensor_rovib_me`).

In [5]:
def rovib_states(j: int, state_ind_list: Dict[str, List[int]] = None, pmax: int = 20):
    """Load rovibrational energies and quantum-number labels for one value of J.

    The data are read from `h2s_coefficients_pmax{pmax}_j{j}.h5`
    (the files are created in `h2s_rovib.ipynb`).

    Args:
        j: Value of the J quantum number; selects the input file.
        state_ind_list: Optional mapping `symmetry -> list of state indices`.
            For a symmetry present in the mapping only the listed states are
            returned; symmetries absent from the mapping are returned in full.
        pmax: Value substituted into the input file name.

    Returns:
        Tuple `(energies, quanta)` of dictionaries keyed by symmetry label.
        `energies[sym]` is the array stored under `energies/<sym>`;
        `quanta[sym]` is a string array obtained by splitting each stored
        quanta label on commas.
    """
    energies = {}
    quanta = {}
    # the context manager guarantees the file is closed even if a read fails
    with h5py.File(f"h2s_coefficients_pmax{pmax}_j{j}.h5", "r") as h5:
        for sym in h5["energies"].keys():
            # only energies and quanta are needed here; the (potentially large)
            # coefficient and index datasets are deliberately not read
            enr = h5["energies"][sym][:]
            qua = np.array(
                [elem[0].decode("utf-8").split(",") for elem in h5["quanta"][sym][:]]
            )
            if state_ind_list is not None and sym in state_ind_list:
                ind = state_ind_list[sym]
                enr = enr[ind]
                qua = qua[ind]
            energies[sym] = enr
            quanta[sym] = qua
    return energies, quanta

In [6]:
def tensor_rovib_me(
    rank: int,
    j1: int,
    j2: int,
    vib_me: np.ndarray,
    state_ind_list1: Dict[str, List[int]] = None,
    state_ind_list2: Dict[str, List[int]] = None,
    pmax: int = 20,
    linear: bool = False,
    tol: float = 1e-12,
):
    """Computes rovibrational matrix elements of a Cartesian tensor operator.

    This function calculates the rovibrational matrix elements using wavefunctions stored in
    separate files for different values of the J quantum number (the files are created in
    `h2s_rovib.ipynb`). The computation is performed for the specified bra (`j1`) and ket (`j2`)
    values of the J quantum number.

    Optionally, `state_ind_list1[symmetry][<list of state indices>]` and `state_ind_list2[...]`
    can be used to select specific bra and ket states by their indices for different symmetries.
    When a selection is given, symmetries that are absent from it are skipped entirely.

    Args:
        rank: Rank of the Cartesian tensor; only 1 and 2 are supported.
        j1: J quantum number of the bra states (selects the bra coefficients file).
        j2: J quantum number of the ket states (selects the ket coefficients file).
        vib_me: Vibrational matrix elements, indexed as `vib_me[v1, v2, i]` for rank 1
            and `vib_me[v1, v2, i, j]` for rank 2, with `i`, `j` Cartesian indices
            in the order x, y, z.
        state_ind_list1: Optional bra-state selection, `symmetry -> state indices`.
        state_ind_list2: Optional ket-state selection, `symmetry -> state indices`.
        pmax: Value substituted into the coefficients file names.
        linear: Forwarded unchanged to `symtop.threej_wang`.
        tol: Matrix elements with absolute value below `tol` are set to zero.

    Returns:
        Dictionary keyed by `(sym1, sym2)`. Each value is a `csr_matrix` whose rows
        run over the flattened (bra state, ket state) pairs and whose columns run over
        the keys of the `rot_me` mapping returned by `threej_wang`. Symmetry pairs for
        which all elements are zero are omitted.

    Raises:
        ValueError: If `rank` is neither 1 nor 2.
    """

    @jax.jit
    def _matelem(vind1, vind2, rind1, rind2, coefs1, coefs2):
        # pick vibrational elements for the (bra, ket) basis-function pairs
        vib_me_ = vib_me2[jnp.ix_(vind1, vind2)]
        me = []
        for omega in rot_me.keys():
            # contract vibrational and rotational parts over the Cartesian index
            me_ = jnp.einsum(
                "ijc,ijc->ij", vib_me_, rot_me[omega][jnp.ix_(rind1, rind2)]
            )
            # transform from basis functions to rovibrational eigenstates
            me.append(jnp.einsum("ik,ij,jl->kl", jnp.conj(coefs1), me_, coefs2))
        me = jnp.moveaxis(jnp.array(me), 0, -1)
        me = jnp.where(jnp.abs(me) < tol, 0.0, me)
        return me

    def _read_states(h5, sym, state_ind_list):
        # Returns (vib-indices, rot-indices, coefficients) for symmetry `sym`,
        # or None when a selection is given that does not contain `sym`.
        # The membership test comes first so skipped symmetries are never read.
        if state_ind_list is not None and sym not in state_ind_list:
            return None
        coefs = h5["coefficients"][sym][:]
        if state_ind_list is not None:
            coefs = coefs[:, state_ind_list[sym]]
        return h5["vib-indices"][sym][:], h5["rot-indices"][sym][:], coefs

    # determine the order of Cartesian indices in the Cartesian-to-spherical tensor
    #   transformation matrix (in cartens.CART_IND and symtop.threej_wang)
    cart_ind = [["xyz".index(x) for x in elem] for elem in CART_IND[rank]]

    # reshape vibrational matrix elements such that the order of Cartesian indices
    #   correspond to the order in symtop.threej_wang output
    if rank == 1:
        vib_me2 = jnp.moveaxis(jnp.array([vib_me[:, :, i] for (i,) in cart_ind]), 0, -1)
    elif rank == 2:
        vib_me2 = jnp.moveaxis(
            jnp.array([vib_me[:, :, i, j] for (i, j) in cart_ind]), 0, -1
        )
    else:
        raise ValueError(
            f"Index mapping for tensor of rank = {rank} is not implemented"
        )

    # compute rotational matrix elements of three-j symbol contracted with
    #   Cartesian-to-spherical tensor transformation matrix
    jktau_list1, jktau_list2, rot_me = threej_wang(rank, j1, j2, linear=linear)
    # rot_me[omega].shape = (2*j1+1, 2*j2+1, ncart)

    res = {}

    # both files are closed on exit, also when an exception is raised; this
    # function is called repeatedly, so leaked handles would accumulate
    with h5py.File(
        f"h2s_coefficients_pmax{pmax}_j{j1}.h5", "r"
    ) as h5_1, h5py.File(f"h2s_coefficients_pmax{pmax}_j{j2}.h5", "r") as h5_2:
        for sym1 in h5_1["energies"].keys():
            bra = _read_states(h5_1, sym1, state_ind_list1)
            if bra is None:
                continue
            vind1, rind1, coefs1 = bra

            for sym2 in h5_2["energies"].keys():
                ket = _read_states(h5_2, sym2, state_ind_list2)
                if ket is None:
                    continue
                vind2, rind2, coefs2 = ket

                # single device-to-host conversion before the NumPy/SciPy calls
                me = np.asarray(_matelem(vind1, vind2, rind1, rind2, coefs1, coefs2))
                if np.count_nonzero(me) > 0:
                    # rows: flattened (bra, ket) state pairs; columns: omega
                    npairs = np.prod(me.shape[:2])
                    res[(sym1, sym2)] = csr_matrix(me.reshape(npairs, -1))

    return res

Using the functions above, compute matrix elements of spin-rotation tensors and dipole moment for $J=50..60$ and lowest 1000 rovibrational states for each symmetry, and store matrix elements in files.

In [7]:
# label used in the input/output file names, and the range of J quanta to process
pmax = 20
min_J, max_J = 50, 60

# request the first 1000 state indices for every symmetry label in C2V_IRREPS
state_ind_list = dict((sym, np.arange(1000)) for sym in C2V_IRREPS)

# load vibrational matrix elements of the dipole moment and of the
# spin-rotation tensors on the two hydrogen atoms

with h5py.File(f"h2s_vibme_pmax{pmax}.h5", "r") as h5:
    dipole_vib = h5["dipole-moment"][:]
    sr_h1_vib, sr_h2_vib = (h5["spin-rotation"][atom][:] for atom in ("h1", "h2"))

# compute rovibrational matrix elements and store them in file

# Collect energies and quanta labels of the selected states for every J
# into a single file, grouped as `energies/<J>/<sym>` and `quanta/<J>/<sym>`.
with h5py.File(f"h2s_enr_pmax{pmax}.h5", "w") as h5:
    for J in range(min_J, max_J + 1):
        enr, qua = rovib_states(J, state_ind_list=state_ind_list)
        # the reported counts are the numbers of requested indices per symmetry
        requested = [(sym, len(ind)) for sym, ind in state_ind_list.items()]
        print(f"store energies for J = {J}, no. states = {requested}")
        for sym, sym_enr in enr.items():
            # quanta are stored as fixed-width ASCII strings, one label per row
            qua_str = list(map(",".join, qua[sym]))
            max_len = max(map(len, qua_str))
            qua_ascii = [label.encode("ascii", "ignore") for label in qua_str]
            h5.create_dataset(f"energies/{J}/{sym}", data=sym_enr)
            h5.create_dataset(
                f"quanta/{J}/{sym}",
                shape=(len(qua_ascii), 1),
                dtype=f"S{max_len}",
                data=qua_ascii,
            )

# Compute and store the matrix elements for every (J1, J2) pair, one output
# file per pair. The dipole is evaluated only for |J1 - J2| <= 1 and the
# spin-rotation tensors only for |J1 - J2| <= 2.
for J1 in range(min_J, max_J + 1):
    if J1 < 54:
        continue  # restart from certain J1

    for J2 in range(min_J, max_J + 1):
        dJ = abs(J1 - J2)

        # the same state selection is applied to bra and ket states
        selection = dict(state_ind_list1=state_ind_list, state_ind_list2=state_ind_list)

        # Debye units
        dipole_me = (
            tensor_rovib_me(1, J1, J2, dipole_vib, tol=1e-6, **selection)
            if dJ <= 1
            else {}
        )

        # kHz units
        if dJ <= 2:
            sr1_me, sr2_me = (
                tensor_rovib_me(2, J1, J2, sr_vib, tol=1e-3, **selection)
                for sr_vib in (sr_h1_vib, sr_h2_vib)
            )
        else:
            sr1_me, sr2_me = {}, {}

        # nothing to store for this pair of J quanta
        if not (dipole_me or sr1_me or sr2_me):
            continue

        print(f"store matrix elements for J1 = {J1} J2 = {J2}, delta J = {dJ}")
        operators = {
            "dipole": dipole_me,
            "spin-rotation-H1": sr1_me,
            "spin-rotation-H2": sr2_me,
        }
        with h5py.File(f"h2s_me_pmax{pmax}_j{J1}_j{J2}.h5", "w") as h5:
            for label, oper in operators.items():
                for (sym1, sym2), me in oper.items():
                    nnz = len(me.data)
                    percent = round(nnz / np.prod(me.shape) * 100, 3)
                    print(
                        f"{label}, sym = {(sym1, sym2)}, no. elements = {nnz}  {percent}%"
                    )
                    # store the CSR representation component by component
                    for field in ("data", "indices", "indptr", "shape"):
                        h5.create_dataset(
                            f"{label}/{sym1}/{sym2}/{field}", data=getattr(me, field)
                        )

store energies for J = 50, no. states = [('A1', 1000), ('A2', 1000), ('B1', 1000), ('B2', 1000)]
store energies for J = 51, no. states = [('A1', 1000), ('A2', 1000), ('B1', 1000), ('B2', 1000)]
store energies for J = 52, no. states = [('A1', 1000), ('A2', 1000), ('B1', 1000), ('B2', 1000)]
store energies for J = 53, no. states = [('A1', 1000), ('A2', 1000), ('B1', 1000), ('B2', 1000)]
store energies for J = 54, no. states = [('A1', 1000), ('A2', 1000), ('B1', 1000), ('B2', 1000)]
store energies for J = 55, no. states = [('A1', 1000), ('A2', 1000), ('B1', 1000), ('B2', 1000)]
store energies for J = 56, no. states = [('A1', 1000), ('A2', 1000), ('B1', 1000), ('B2', 1000)]
store energies for J = 57, no. states = [('A1', 1000), ('A2', 1000), ('B1', 1000), ('B2', 1000)]
store energies for J = 58, no. states = [('A1', 1000), ('A2', 1000), ('B1', 1000), ('B2', 1000)]
store energies for J = 59, no. states = [('A1', 1000), ('A2', 1000), ('B1', 1000), ('B2', 1000)]
store energies for J = 60, no.