In [1]:
import itertools

import jax
import jax.numpy as jnp
import numpy as np
from scipy.special import factorial

from vibrojet.keo import Gmat, batch_Gmat, com, eckart, batch_pseudo
from vibrojet.taylor import deriv_list

jax.config.update("jax_enable_x64", True)

In [2]:
# Atomic masses of the S, H, H atoms (atomic mass units). 31.97207070 is the
# mass of the 32S isotope (16O would be ~15.995), so the molecule is H2S, not
# water. The order matches the rows produced by `valence_to_cartesian`
# (heavy atom first, at the origin).
masses = np.array([31.97207070, 1.00782505, 1.00782505])

# Equilibrium values of the valence coordinates: the two S-H bond lengths
# r1, r2 (presumably in Angstrom -- TODO confirm against the PES source) and
# the H-S-H bond angle alpha, converted here from degrees to radians.
r1, r2, alpha = 1.3358387, 1.3358387, 92.2705139*np.pi/180
# dtype=float64 is only honoured because `jax_enable_x64` is switched on in
# the imports cell; otherwise JAX would silently downcast to float32.
x0 = jnp.array([r1, r2, alpha], dtype=jnp.float64)

# Per the original notes, the `eckart` decorator (vibrojet.keo) rotates the
# returned coordinates into the Eckart frame and corrects for the center of mass.
@eckart(x0, masses)
def valence_to_cartesian(internal_coords):
    """Convert valence coordinates (r1, r2, alpha) to Cartesian atom positions.

    The heavy atom (row 0) sits at the origin. The two H atoms lie in the
    xz-plane on opposite sides of the z-axis, each making an angle alpha/2
    with it, at distances r1 and r2 respectively.

    Args:
        internal_coords: sequence of three values (r1, r2, alpha), with the
            bond angle alpha in radians.

    Returns:
        jnp array of shape (number of atoms, 3) = (3, 3), one (x, y, z) row
        per atom, in the same order as `masses`.
    """
    bond1, bond2, angle = internal_coords
    half_angle = angle / 2
    sin_h, cos_h = jnp.sin(half_angle), jnp.cos(half_angle)

    heavy = [0.0, 0.0, 0.0]
    hydrogen1 = [bond1 * sin_h, 0.0, bond1 * cos_h]
    hydrogen2 = [-bond2 * sin_h, 0.0, bond2 * cos_h]
    return jnp.array([heavy, hydrogen1, hydrogen2])

In [3]:
# Generate a regular 3-D grid of valence coordinates around the equilibrium
# geometry: N_GRID points per coordinate, spanning +/-GRID_DR around r1 and r2
# and +/-GRID_DALPHA (radians) around alpha.
# NOTE: with an even N_GRID the equilibrium point itself is not a grid node.
N_GRID = 10
GRID_DR = 0.5
GRID_DALPHA = 40 * np.pi / 180

r1_arr = np.linspace(r1 - GRID_DR, r1 + GRID_DR, N_GRID)
r2_arr = np.linspace(r2 - GRID_DR, r2 + GRID_DR, N_GRID)
alpha_arr = np.linspace(alpha - GRID_DALPHA, alpha + GRID_DALPHA, N_GRID)
# indexing="ij" makes r1 the slowest-varying coordinate after ravel().
xa, xb, xc = np.meshgrid(r1_arr, r2_arr, alpha_arr, indexing="ij")
# One (r1, r2, alpha) row per grid point.
x = np.column_stack([xa.ravel(), xb.ravel(), xc.ravel()])
assert x.shape == (N_GRID**3, 3)
x.shape

(1000, 3)


In [4]:
# Evaluate vibrojet's `batch_pseudo` (presumably the KEO pseudo-potential
# term -- see vibrojet.keo) at every grid point in `x`.
pseudo_vals = batch_pseudo(x, masses, valence_to_cartesian)
# Expect exactly one value per grid point.
assert np.shape(pseudo_vals) == (len(x),)
np.shape(pseudo_vals)

(1000,)
