# In [1]:
import itertools

import jax
import jax.numpy as jnp

from vibrojet.eckart import eckart
from vibrojet.keo import Gmat
from vibrojet.taylor import deriv_list

# Enable 64-bit (double) precision; JAX uses 32-bit floats by default.
jax.config.update("jax_enable_x64", True)

# Atomic masses, listed in the atom order O, H, H
# (presumably in unified atomic mass units -- TODO confirm)
masses = [15.9994] + [1.00782505] * 2

# Equilibrium valence coordinates: two bond lengths and the bond angle.
# alpha is fed to sin/cos below, so it is in radians; the bond-length unit
# is presumably Angstrom -- TODO confirm.
r1 = r2 = 0.958
alpha = 1.824
q0 = [r1, r2, alpha]

# Valence-to-Cartesian coordinate transformation
#   input: array of three valence coordinates
#   output: array of shape (number of atoms, 3)
#           containing Cartesian coordinates of atoms


@eckart(q0, masses)
def valence_to_cartesian(q):
    """Convert valence coordinates to Cartesian coordinates of the atoms.

    Args:
        q: Sequence of three valence coordinates ``(r1, r2, a)``: two bond
            lengths and the angle between the bonds (radians).

    Returns:
        Array of shape (3, 3): one row of x, y, z coordinates per atom.
        The first atom sits at the origin and the other two lie in the
        xz-plane, placed symmetrically about the z-axis at half the angle.

    NOTE(review): the function is wrapped by ``eckart(q0, masses)``;
    presumably this re-expresses the result in the Eckart frame defined
    at ``q0`` -- confirm against the vibrojet documentation.
    """
    r1, r2, a = q

    # Both bonded atoms share the same half-angle about the z-axis.
    half_angle = a / 2
    sin_half, cos_half = jnp.sin(half_angle), jnp.cos(half_angle)

    central_atom = [0.0, 0.0, 0.0]
    first_bonded = [r1 * sin_half, 0.0, r1 * cos_half]
    second_bonded = [-r2 * sin_half, 0.0, r2 * cos_half]
    return jnp.array([central_atom, first_bonded, second_bonded])


# Build the multi-indices of the Taylor series expansion: one tuple of
# non-negative integer exponents per term, one exponent per coordinate,
# keeping only tuples whose total order does not exceed max_order.

max_order = 4  # max total expansion order
deriv_ind = [
    powers
    for powers in itertools.product(range(max_order + 1), repeat=len(q0))
    if not sum(powers) > max_order
]
print("max expansion order:", max_order)
print("number of expansion terms:", len(deriv_ind))

# Function for computing kinetic G-matrix for given masses of atoms
# and internal coordinates
def func(x):
    """Evaluate the kinetic G-matrix at internal coordinates ``x``.

    Thin wrapper that fixes the atomic ``masses`` and the
    ``valence_to_cartesian`` coordinate transformation, so that the
    result is a function of the internal coordinates only.

    Args:
        x: Internal (valence) coordinates at which to evaluate.

    Returns:
        Whatever ``Gmat(x, masses, valence_to_cartesian)`` returns.
    """
    return Gmat(x, masses, valence_to_cartesian)


# Compute Taylor series expansion coefficients of the G-matrix about q0,
# one set per multi-index in deriv_ind.
# NOTE(review): if_taylor=True presumably selects Taylor coefficients
# rather than raw derivatives -- confirm against the vibrojet docs.
Gmat_coefs = deriv_list(func, deriv_ind, q0, if_taylor=True)

# Output:
#   max expansion order: 4
#   number of expansion terms: 35
#   Time for d= 0 : 10.84 s
#   Time for d= 1 : 0.02 s
#   Time for d= 2 : 21.88 s
#   Time for d= 3 : 39.31 s
#   Time for d= 4 : 68.37 s
