In [2]:
import os

# `! export ...` runs in a throw-away subshell and never reaches the kernel
# process; set the variable in the kernel itself, before JAX is imported
os.environ["JAX_PLATFORMS"] = "cpu"

  pid, fd = os.forkpty()


# Rotational-vibrational solutions for H2S molecule (rotational cluster states)

In [3]:
import jax
import numpy as np
from jax import config
from jax import numpy as jnp
from numpy.polynomial.hermite import hermgauss
from scipy import optimize

from rovib.h2s_potential import potential
from rovib.keo import Molecule, batch_Gmat, batch_pseudo, com
from rovib.primbas import hermite
from rovib.symtop import rotme_cor, rotme_ovlp, rotme_rot
from rovib.vib_xy2 import vibrations_xy2
from rovib import c2v

# switch JAX to double precision (float64) for all subsequent computations
jax.config.update("jax_enable_x64", True)

## Vibrational coordinates

Define atomic masses and internal coordinates for describing vibrations. Here, we use the valence-bond coordinates, $r_1\equiv \text{S--H}_1$, $r_2\equiv \text{S--H}_2$, and $\alpha = \angle\text{H}_1\text{SH}_2$. The function `valence_bond_coordinates` is required to build the kinetic energy operator.

In [4]:
NCOO = 3  # number of internal vibrational coordinates: (r1, r2, alpha)

# atomic masses of S and H (presumably in Dalton -- confirm the units expected by rovib.keo)
MASS_S = 31.97207070
MASS_H = 1.00782505

@com
def valence_bond_coordinates(coords):
    """Cartesian positions of S, H1 and H2 from valence-bond coordinates.

    Parameters
    ----------
    coords : sequence of three values (r1, r2, alpha)
        S-H1 bond length, S-H2 bond length and H1-S-H2 angle (radians).

    Returns
    -------
    array of shape (3, 3)
        Rows are the S, H1, H2 positions, further processed by the project
        `com` decorator (presumably a centre-of-mass shift -- confirm).
    """
    r1, r2, alpha = coords
    cos_half = jnp.cos(alpha / 2)
    sin_half = jnp.sin(alpha / 2)
    # S at the origin; both H atoms in the xz-plane, mirrored about the x-axis,
    # so that the angle between the two S-H vectors equals alpha
    sulfur = [0.0, 0.0, 0.0]
    hydrogen1 = [r1 * cos_half, 0.0, r1 * sin_half]
    hydrogen2 = [r2 * cos_half, 0.0, -r2 * sin_half]
    return jnp.array([sulfur, hydrogen1, hydrogen2])

# atom order S, H1, H2 matches the rows returned above
Molecule.masses = np.array([MASS_S, MASS_H, MASS_H])
Molecule.internal_to_cartesian = valence_bond_coordinates

We use Hermite functions $H_n(x)e^{-x^2/2}$ as the vibrational basis functions and employ Gauss-Hermite quadratures for computing matrix elements.

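The Hermite functions are defined over the coordinate range $x\in(-\infty,\infty)$. To map $x$ onto the vibrational valence-bond coordinates $r_1\in(0,\infty)$, $r_2\in(0,\infty)$, and $\alpha\in(0,\pi)$,
we use linear transformations $r_1=a_1x_1+b_1$, $r_2=a_2x_2+b_2$, $\alpha=a_3x_3+b_3$.
The parameters $a_1, b_1, \ldots, b_3$ are determined by mapping the vibrational Hamiltonian in valence-bond coordinates onto the harmonic oscillator Hamiltonian. Here $b_i$ is the equilibrium value of the $i$-th coordinate. The one-dimensional Hamiltonian is $\frac{1}{2}G_{ii}p_i^2+\frac{1}{2}f_{ii}(r_i-b_i)^2$, with kinetic matrix element $G_{ii}$ and force constant $f_{ii}=\partial^2V/\partial r_i^2$. For it, the choice $a_i=(G_{ii}/f_{ii})^{1/4}$ yields a harmonic oscillator in $x_i$ with equal kinetic and potential coefficients.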

In [5]:
import warnings

# equilibrium geometry: minimum of the potential energy surface
vmin = optimize.minimize(potential, [1.0, 1.0, np.pi / 2])
if not vmin.success:
    # BFGS may report precision loss even close to a good minimum, so warn
    # instead of aborting the notebook
    warnings.warn(f"potential minimisation did not report success: {vmin.message}")
r0 = vmin.x
v0 = vmin.fun
print("minimum of the potential:", r0, v0)

# diagonal force constants f_ii = d^2V/dr_i^2 at the minimum
# (the name `freq` is kept for compatibility; these are not frequencies)
freq = jnp.diag(jax.hessian(potential)(r0))
# diagonal kinetic (G-matrix) elements G_ii at the minimum
mu = jnp.diag(batch_Gmat(jnp.array([r0]))[0, :NCOO, :NCOO])
# with r = a*x + b, (G/2) p_r^2 + (f/2) (r-b)^2 = (G/(2a^2)) p_x^2 + (f a^2/2) x^2,
# which is a harmonic oscillator with equal coefficients for a = (G/f)^(1/4)
lin_a = jnp.sqrt(jnp.sqrt(mu / freq))
lin_b = r0

print("x->r linear mapping parameters 'a' and 'b':", lin_a, lin_b)


def x_to_r_map(x):
    """Map Hermite coordinates x onto valence-bond coordinates r = a*x + b."""
    return lin_a * x + lin_b

mininum of the potential: [1.3358387  1.3358387  1.61042427] -0.0007846164037194908
x->r linear mapping parameters 'a' and 'b': [0.11245619 0.11245619 0.17845517] [1.3358387  1.3358387  1.61042427]


## Vibrational basis set contraction

To optimize the vibrational basis set, we begin by solving a simplified vibrational problem for each individual vibrational coordinate.
For this, we set the basis functions for all coordinates except one to $e^{-x^2/2}$.
We then use the solutions obtained sequentially for each coordinate to construct a contracted product basis set.

In [6]:
# list of 1D primitive basis functions for each vibrational coordinate
list_psi = (hermite, hermite, hermite)

# number of primitive basis functions for each vibrational coordinate
nmax = 40

# number of quadrature points for the coordinate being solved for
npoints = 80

# number of quadrature points for the frozen coordinates, which carry only the
# single basis function with quantum 0
npoints_frozen = 10

contr_vec = []

for icoo in range(NCOO):
    print(f"\nSolve for coordinate #{icoo}")

    # quanta: 0..nmax-1 for the `icoo` coordinate, only 0 for all others
    list_q = [np.arange(nmax) if i == icoo else [0] for i in range(NCOO)]

    # Gauss-Hermite quadratures; the weights are divided by e^{-x^2} so that the
    # quadrature integrates the bare integrand (the Gaussian factor comes from
    # the basis functions H_n(x)e^{-x^2/2} themselves)
    n = [npoints if i == icoo else npoints_frozen for i in range(NCOO)]
    quad = [hermgauss(npts) for npts in n]
    list_x = tuple(x for x, _ in quad)
    list_w = tuple(w / np.exp(-(x**2)) for x, w in quad)

    # solve for basis functions for the `icoo` vibrational coordinate
    e, v, *_ = vibrations_xy2(
        list_psi,
        list_x,
        list_w,
        list_q,
        select_points=lambda x, w: True,
        # non-restrictive: only one coordinate is excited and its quanta
        # never exceed nmax - 1
        select_quanta=lambda q: np.sum(q) <= nmax,
        x_to_r_map=x_to_r_map,
        gmat=lambda x: batch_Gmat(x),
        pseudo=batch_pseudo,
        potential=potential,
    )

    # keep eigenvectors: they define the contracted basis for coordinate `icoo`
    contr_vec.append(v)

    print("energies:")
    print(e[0], e - e[0])


Solve for coordinate #0
energies:
3337.3945251424393 [0.00000000e+00 2.62278235e+03 5.15020514e+03 7.58246436e+03
 9.91956242e+03 1.21613549e+04 1.43075890e+04 1.63579482e+04
 1.83123940e+04 2.01753782e+04 2.19804107e+04 2.38279199e+04
 2.58337642e+04 2.80450416e+04 3.04613694e+04 3.30740045e+04
 3.58769744e+04 3.88679012e+04 4.20472946e+04 4.54181980e+04
 4.89859992e+04 5.27577540e+04 5.67453996e+04 6.09584058e+04
 6.54167889e+04 7.01557847e+04 7.51299753e+04 8.05653155e+04
 8.63724076e+04 9.19469603e+04 1.00529194e+05 1.18644018e+05
 1.52570839e+05 2.14310500e+05 3.31561510e+05 5.69151477e+05
 1.08710878e+06 2.31554408e+06 5.58548925e+06 1.63735418e+07]

Solve for coordinate #1
energies:
3337.3945251416912 [0.00000000e+00 2.62278235e+03 5.15020514e+03 7.58246436e+03
 9.91956242e+03 1.21613549e+04 1.43075890e+04 1.63579482e+04
 1.83123940e+04 2.01753782e+04 2.19804107e+04 2.38279199e+04
 2.58337642e+04 2.80450416e+04 3.04613694e+04 3.30740045e+04
 3.58769744e+04 3.88679012e+04 4.2047

Transform the primitive basis into the contracted basis

In [7]:
def _contracted_psi(icoo):
    """Return psi(x, n) = hermite(x, n) @ contr_vec[icoo] for coordinate `icoo`."""
    # `contr_vec` is looked up when psi is called, exactly like an inline lambda
    return lambda x, n: jnp.dot(hermite(x, n), contr_vec[icoo])


# override primitive basis with contracted basis: each coordinate's primitive
# Hermite functions are projected onto the eigenvectors of its 1D problem
list_psi = [_contracted_psi(icoo) for icoo in range(NCOO)]

## Vibrational energies and matrix elements

Using the optimized basis set from the previous step, we solve the vibrational problem employing a polyad basis set truncation.
In this approach, we include only those products of functions in the basis set for which the sum of the corresponding quanta, weighted by specific factors,
is less than or equal to a certain maximum value.
Specifically, we include basis functions that satisfy the condition $2n_1+2n_2+n_3\leq P_\text{max}$, where $n_1$, $n_2$, and $n_3$ are the quantum numbers for the two stretching and one bending coordinates, respectively.

In [8]:
# polyad number for basis set truncation
pmax = 20

# polyad weights: a product basis function with quanta (n1, n2, n3) is kept
# when 2*n1 + 2*n2 + n3 <= pmax (built once, not on every selection call)
polyad_weights = np.array([2, 2, 1])

# quanta
list_q = [np.arange(nmax) for _ in range(NCOO)]

# Gauss-Hermite quadratures with `npoints` points per coordinate; the weights are
# divided by e^{-x^2} because the basis functions carry the Gaussian factor
n = [npoints] * NCOO
quad = [hermgauss(npts) for npts in n]
list_x = tuple(x for x, _ in quad)
list_w = tuple(w / np.exp(-(x**2)) for x, w in quad)

# solutions, including vibrational matrix elements of rotational part of kinetic energy operator
vib_enr, vib_vec, vib_quanta, grot_me, gcor_me = vibrations_xy2(
    list_psi,
    list_x,
    list_w,
    list_q,
    select_points=lambda x, w: True,
    select_quanta=lambda q: np.sum(q * polyad_weights) <= pmax,
    x_to_r_map=x_to_r_map,
    gmat=lambda x: batch_Gmat(x),
    pseudo=batch_pseudo,
    potential=potential,
    assign_c2v=True,  # assign symmetries to vibrational states
)

nbas_vib = len(vib_quanta)

Print vibrational energies with assignments

In [None]:
print("number of vibrational states:", nbas_vib)

# label each eigenstate by the basis function with the largest |coefficient|
# (assumes eigenvectors are stored column-wise in vib_vec)
ind = np.abs(vib_vec).argmax(axis=0)
zpe = vib_enr[0]  # ground-state (zero-point) energy, used as the energy origin
for enr, assignment in zip(vib_enr, vib_quanta[ind]):
    print(enr - zpe, assignment)

number of vibrational states: 506
0.0 ['0' '0' '0' 'A1']
1182.5695896006414 ['0' '0' '1' 'A1']
2353.9073088766704 ['0' '0' '2' 'A1']
2614.394866827381 ['1' '0' '0' 'B2']
2628.4633278424367 ['0' '1' '0' 'B2']
3513.705049529291 ['0' '0' '3' 'A1']
3779.1894940842135 ['1' '0' '1' 'A1']
3789.2699680649735 ['0' '1' '1' 'A1']
4661.605915283004 ['0' '0' '4' 'B2']
4932.68958994003 ['0' '1' '2' 'A1']
4939.130339236135 ['1' '0' '2' 'B2']
5145.0325080259245 ['2' '0' '0' 'B2']
5147.16722639663 ['0' '2' '0' 'A1']
5243.159164558614 ['1' '1' '0' 'A1']
5797.207682793023 ['0' '0' '5' 'A1']
6074.566899182262 ['1' '0' '3' 'B2']
6077.62728540151 ['0' '1' '3' 'B2']
6288.136259845905 ['2' '0' '1' 'A1']
6289.129795754327 ['0' '2' '1' 'B2']
6385.321551894391 ['1' '1' '1' 'B2']
6920.081854231556 ['0' '0' '6' 'A1']
7204.3099529200335 ['1' '0' '4' 'A1']
7204.437296445947 ['0' '1' '4' 'A1']
7419.8516538955355 ['2' '0' '2' 'A1']
7420.081813130413 ['0' '2' '2' 'A1']
7516.827921278722 ['1' '1' '2' 'A1']
7576.41581430

## Rotational basis and matrix elements

In [None]:
j_angmom = 10

# rotational basis for this J: overlaps <jk'|jk> and the (k, tau, rot_sym) labels
s_rot, _, ktau_quanta = rotme_ovlp(j_angmom)
nbas_rot = len(ktau_quanta)

# <jk'|Ja*Jb|jk> and <jk'|i*Ja|jk>; only the first returned item of each is used
jab_rot, *_ = rotme_rot(j_angmom)
ja_rot, *_ = rotme_cor(j_angmom)

for label, value in (
    ("rotational angular momentum:", j_angmom),
    ("number of rotational states:", nbas_rot),
):
    print(label, value)

rotational angular momentum: 10
number of rotational states: 21


## Symmetry-adapted rovibrational basis and solutions

In [None]:
# index grid of all vibrational x rotational product states, vibrational index
# running slowest: row k pairs vib state k // nbas_rot with rot state k % nbas_rot
vib_grid, rot_grid = np.indices((nbas_vib, nbas_rot)).reshape(2, -1)

# mapping between rovibrational state index and indices of vibrational and rotational functions
rovib_ind = np.stack((vib_grid, rot_grid), axis=-1)

# combined quanta [(v1, v2, v3, vib_sym, k, tau, rot_sym), ...]
rovib_quanta = np.concatenate((vib_quanta[vib_grid], ktau_quanta[rot_grid]), axis=-1)

# symmetry of rovibrational product prod_sym = vib_sym * rot_sym
prod_sym = np.array(
    [c2v.C2V_PRODUCT_TABLE[pair] for pair in zip(rovib_quanta[:, 3], rovib_quanta[:, 6])]
)
# prepend it: [(prod_sym, v1, v2, v3, vib_sym, k, tau, rot_sym), ...]
rovib_quanta = np.column_stack((prod_sym, rovib_quanta))

# indices of rovibrational product states for each symmetry ...
ind_sym = {sym: np.flatnonzero(rovib_quanta[:, 0] == sym) for sym in c2v.C2V_IRREPS}

# ... and the corresponding (vibrational, rotational) function index pairs
rovib_ind_sym = {sym: rovib_ind[ind] for sym, ind in ind_sym.items()}

Compute and diagonalise the total Hamiltonian matrix for each symmetry

In [None]:
for sym, ind in rovib_ind_sym.items():
    # vibrational and rotational function indices of the product states in this block
    vind, rind = ind.T

    # rotational KEO: vibrational matrix elements `grot_me` contracted with <jk'|Ja*Jb|jk>
    rot_me = jnp.einsum(
        "ijab,ijab->ij", grot_me[np.ix_(vind, vind)], jab_rot[np.ix_(rind, rind)]
    )
    # Coriolis KEO: vibrational matrix elements `gcor_me` contracted with <jk'|i*Ja|jk>
    cor_me = jnp.einsum(
        "ija,ija->ij", gcor_me[np.ix_(vind, vind)], ja_rot[np.ix_(rind, rind)]
    )
    # vibrational part E_v * delta(v, v') * <jk'|jk>: the Kronecker delta acts on
    # vibrational indices, so overlaps between different rotational functions that
    # share a vibrational state are kept (a `diag` over the product index drops
    # them and agrees with this only when `s_rot` is diagonal)
    same_vib = vind[:, None] == vind[None, :]
    vib_me = same_vib * (vib_enr[vind][:, None] * s_rot[np.ix_(rind, rind)])
    hmat = vib_me + 0.5 * (rot_me + cor_me)

    e, _ = jnp.linalg.eigh(hmat)
    print(sym, e[:10] - zpe)

A1 [ 569.90926366  743.7686041   876.17557632  961.21210357 1013.51934513
 1094.34829073 1752.69674068 1937.09772795 2076.97263485 2165.62253671]
A2 [ 662.05144926  815.17049336  925.84473619  985.60653639 1051.54220369
 1850.51603166 2012.6103016  2129.03391837 2191.60095239 2264.09541252]
B1 [ 662.05145647  815.17946745  927.50865756 1010.29345649 1094.33712764
 1850.51604369 2012.62360602 2131.21317097 2219.63274223 2310.04937859]
B2 [ 569.90926373  743.76893476  876.33051783  970.60760031 1051.24807439
 1752.69674081 1937.09824792 2077.18916476 2176.98609361 2263.83596548]


Repeat the calculation for all $J=0,\ldots,J_\text{max}$

In [None]:
Jmax = 10


def _rovib_symmetry_blocks(vib_quanta, ktau_quanta):
    """Split the vibrational x rotational product basis into C2v symmetry blocks.

    Parameters
    ----------
    vib_quanta : array, shape (nvib, 4)
        Vibrational labels (v1, v2, v3, vib_sym).
    ktau_quanta : array, shape (nrot, 3)
        Rotational labels (k, tau, rot_sym).

    Returns
    -------
    dict
        Maps every irrep in `c2v.C2V_IRREPS` to an integer array of shape
        (nstates, 2) with the (vibrational, rotational) function indices of the
        product states whose total symmetry vib_sym x rot_sym is that irrep.
    """
    nvib, nrot = len(vib_quanta), len(ktau_quanta)
    # all (vib index, rot index) pairs, vibrational index running slowest
    pairs = np.indices((nvib, nrot)).reshape(2, -1).T
    vib_sym = vib_quanta[pairs[:, 0], 3]
    rot_sym = ktau_quanta[pairs[:, 1], 2]
    prod_sym = np.array(
        [c2v.C2V_PRODUCT_TABLE[(s1, s2)] for s1, s2 in zip(vib_sym, rot_sym)]
    )
    return {sym: pairs[prod_sym == sym] for sym in c2v.C2V_IRREPS}


def _rovib_hamiltonian(vind, rind, vib_enr, grot_me, gcor_me, s_rot, jab_rot, ja_rot):
    """Rovibrational Hamiltonian matrix over the product states (vind[i], rind[i]).

    Combines the vibrational energies with the rotational and Coriolis parts of
    the kinetic energy operator, i.e. vib + 0.5 * (rot + cor).
    """
    rot_me = jnp.einsum(
        "ijab,ijab->ij", grot_me[np.ix_(vind, vind)], jab_rot[np.ix_(rind, rind)]
    )
    cor_me = jnp.einsum(
        "ija,ija->ij", gcor_me[np.ix_(vind, vind)], ja_rot[np.ix_(rind, rind)]
    )
    # E_v * delta(v, v') * <jk'|jk>; the delta is over vibrational indices so
    # off-diagonal rotational overlaps are not dropped
    same_vib = vind[:, None] == vind[None, :]
    vib_me = same_vib * (vib_enr[vind][:, None] * s_rot[np.ix_(rind, rind)])
    return vib_me + 0.5 * (rot_me + cor_me)


for j_angmom in range(Jmax + 1):  # J = 0..Jmax inclusive
    s_rot, _, ktau_quanta = rotme_ovlp(j_angmom)
    jab_rot, *_ = rotme_rot(j_angmom)
    ja_rot, *_ = rotme_cor(j_angmom)

    print("\nrotational angular momentum:", j_angmom)

    for sym, ind in _rovib_symmetry_blocks(vib_quanta, ktau_quanta).items():
        vind, rind = ind.T
        hmat = _rovib_hamiltonian(
            vind, rind, vib_enr, grot_me, gcor_me, s_rot, jab_rot, ja_rot
        )
        e, _ = jnp.linalg.eigh(hmat)
        print(sym, e[:10] - zpe)


rotational angular momentum: 0
A1 [   0.         1182.5695896  2353.90730888 2614.39486683 3513.70504953
 3779.18949408 4661.60591528 4932.68958994 5145.03250803 5243.15916456]
A2 []
B1 []
B2 [2628.46332784 3789.26996806 4939.13033924 5147.1672264  6077.6272854
 6289.12979575 7204.30995292 7420.08181313 7576.59892752 7779.35291637]

rotational angular momentum: 1
A1 [2647.54126491 3808.90368229 4959.36414899 5165.96321599 6098.51110543
 6308.4765805  7225.90061341 7440.02234903 7595.10928399 7797.8460545 ]
A2 [  15.10223923 1198.00741192 2369.71410962 2629.26569589 2642.08118346
 3529.91487099 3794.38608619 3803.0740387  4678.25484547 4948.24227916]
B1 [  13.76301869 1196.51407411 2368.04608448 2627.96216838 2643.27827413
 3528.05594666 3792.93979506 3804.40387351 4676.19161991 4946.63531256]
B2 [  19.37556179 1202.51369779 2374.46590698 2633.48603599 3534.9297447
 3798.84312128 4683.55525098 4952.95047311 5163.8324248  5261.9492171 ]

rotational angular momentum: 2
A1 [  38.07477143 