In [1]:
import sys

# Put the parent directory on the import path so the local packages imported
# in the next cell (vibrojet.*) can be found.
sys.path.insert(1, '../')

In [2]:
import itertools
import os

import jax
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp

from jax import config
from scipy import optimize

from vibrojet.basis_utils2 import ContrBasis, HermiteBasis, AssociateLegendreBasis, FourierBasis, generate_prod_ind

from vibrojet.jet_prim import acos
from vibrojet.potentials import nh3_POK
from vibrojet.taylor import deriv_list
from vibrojet.keo import Gmat, com, pseudo

# Figure text is rendered by LaTeX with Computer Modern serif fonts.
_latex_fonts = {
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Computer Modern"],
}
plt.rcParams.update(_latex_fonts)

# JAX defaults to 32-bit floats; switch on 64-bit precision for all jnp arrays.
config.update("jax_enable_x64", True)

In [4]:
# Morse constant necessary for defining y-coordinates for stretches
a_morse = 2.0

def internal_to_y(q, q0):
    """Map internal coordinates to the expansion y-coordinates.

    Args:
        q: internal coordinates (r1, r2, r3, s4, s5, rho).
        q0: reference internal coordinates; only the three stretch entries
            q0[0], q0[1], q0[2] are used.

    Returns:
        jnp array [y1, ..., y6] with
          y_i = 1 - exp(-a_morse * (r_i - q0[i]))  for the stretches i = 1..3,
          y4 = s4, y5 = s5 (passed through unchanged),
          y6 = sin(rho).
    """
    r1, r2, r3, s4, s5, rho = q
    morse_y = [
        1 - jnp.exp(-a_morse * (r_i - q0[i]))
        for i, r_i in enumerate((r1, r2, r3))
    ]
    return jnp.array([*morse_y, s4, s5, jnp.sin(rho)])


def y_to_internal(y, q0):
    """Inverse of the y-coordinate map: expansion coordinates -> internal ones.

    Args:
        y: expansion coordinates (y1, ..., y6).
        q0: reference internal coordinates; only q0[0..2] (stretches) are used.

    Returns:
        jnp array [r1, r2, r3, s4, s5, rho] with
          r_i = q0[i] - ln(1 - y_i) / a_morse  for the stretches,
          s4 = y4, s5 = y5,
          rho = pi/2 - acos(y6), i.e. asin(y6); acos comes from
          vibrojet.jet_prim.
    """
    y1, y2, y3, y4, y5, y6 = y
    r1, r2, r3 = (
        -jnp.log(1 - y_i) / a_morse + q0[i]
        for i, y_i in enumerate((y1, y2, y3))
    )
    rho = np.pi / 2 - acos(y6)  # = asin(y6)
    return jnp.array([r1, r2, r3, y4, y5, rho])

def find_alpha_from_s_delta(s4, s5, delta, no_iter: int = 10):
    """Recover the three bond angles (alpha1, alpha2, alpha3) from s4, s5, delta.

    The angles are expressed through symmetrized coordinates (s4, s5, s6):
        alpha1 = (sqrt(2) s6 + 2 s4) / sqrt(6)
        alpha2 = (sqrt(2) s6 - s4 + sqrt(3) s5) / sqrt(6)
        alpha3 = (sqrt(2) s6 - s4 - sqrt(3) s5) / sqrt(6)
    The unknown s6 is found by a fixed number (`no_iter`) of Newton steps,
    with no convergence check, on g(s6) = sin(delta)**2. Here g = tau_2 / norm_2
    is built from the cosines and sines of the three angles. Only
    sin(delta)**2 enters, so the sign of delta is irrelevant.

    Args:
        s4, s5: symmetrized bending coordinates.
        delta: angle whose squared sine fixes s6.
        no_iter: number of Newton iterations.

    Returns:
        Tuple (alpha1, alpha2, alpha3).
    """
    root2, root3, root6 = jnp.sqrt(2.0), jnp.sqrt(3.0), jnp.sqrt(6.0)

    def angles_from_s(s6, s4, s5):
        # linear map (s6, s4, s5) -> (alpha1, alpha2, alpha3), see docstring
        return (
            (root2 * s6 + 2 * s4) / root6,
            (root2 * s6 - s4 + root3 * s5) / root6,
            (root2 * s6 - s4 - root3 * s5) / root6,
        )

    def sin2_delta_of_s6(s6, s4, s5):
        # tau_2 / norm_2 as a function of s6 (first argument, differentiated below)
        a1, a2, a3 = angles_from_s(s6, s4, s5)
        c1, c2, c3 = jnp.cos(a1), jnp.cos(a2), jnp.cos(a3)
        sn1, sn2, sn3 = jnp.sin(a1), jnp.sin(a2), jnp.sin(a3)
        tau_2 = 1 - c1**2 - c2**2 - c3**2 + 2 * c1 * c2 * c3
        norm_2 = (
            sn3**2 + sn2**2 + sn1**2
            + 2 * c3 * c1 - 2 * c2
            + 2 * c2 * c3 - 2 * c1
            + 2 * c2 * c1 - 2 * c3
        )
        return tau_2 / norm_2

    slope = jax.grad(sin2_delta_of_s6)
    target = jnp.sin(delta) ** 2

    # start from equal angles alpha1 = alpha2 = alpha3 = 2 pi / 3
    s6 = 2 * jnp.pi / 3 * root3
    for _ in range(no_iter):
        residual = sin2_delta_of_s6(s6, s4, s5) - target
        s6 = s6 - residual / slope(s6, s4, s5)

    return angles_from_s(s6, s4, s5)

@jax.jit
def poten_in_y(y, q0):
    """Potential energy as a function of the expansion coordinates y.

    y is mapped back to internal coordinates (r1, r2, r3, s4, s5, rho) about the
    reference q0. The three bond angles are recovered from s4, s5 and
    delta = rho - pi/2, and everything is passed to nh3_POK.poten.
    """
    r1, r2, r3, s4, s5, rho = y_to_internal(y, q0)
    # NOTE(review): `poten` (used to locate q0) reads its 6th coordinate as tau
    # with rho = tau + pi/2, while here the 6th coordinate is rho itself —
    # confirm the two conventions agree.
    bond_angles = find_alpha_from_s_delta(s4, s5, rho - jnp.pi / 2)
    return nh3_POK.poten((r1, r2, r3, *bond_angles))

In [8]:
@jax.jit
def poten(q):
    """NH3 potential energy in coordinates (r1, r2, r3, s4, s5, tau).

    The 6th coordinate is tau, with rho = tau + pi/2. The angles beta_i are
    displaced from 2 pi / 3 by the bending coordinates s4, s5. Each bond angle
    then follows from cos(alpha) = cos(rho)**2 + sin(rho)**2 * cos(beta), and
    the result is passed to nh3_POK.poten.
    """
    r1, r2, r3, s4, s5, tau = q
    rho = tau + np.pi / 2

    # beta1 = sqrt(6)/3 * s4 + 2 pi/3 is not needed explicitly: with these
    # definitions beta1 + beta2 + beta3 = 2 pi, so cos(beta1) = cos(beta2 + beta3).
    beta2 = -1 / jnp.sqrt(6) * s4 + 1 / jnp.sqrt(2) * s5 + 2 * np.pi / 3
    beta3 = -1 / jnp.sqrt(6) * s4 - 1 / jnp.sqrt(2) * s5 + 2 * np.pi / 3

    cosrho = jnp.cos(rho)
    sinrho = jnp.sin(rho)
    cosrho2 = cosrho * cosrho
    sinrho2 = sinrho * sinrho

    cosalpha2 = cosrho2 + sinrho2 * jnp.cos(beta2)
    cosalpha3 = cosrho2 + sinrho2 * jnp.cos(beta3)
    cosalpha1 = cosrho2 + sinrho2 * jnp.cos(beta2 + beta3)
    alpha1 = acos(cosalpha1)
    alpha2 = acos(cosalpha2)
    alpha3 = acos(cosalpha3)
    v = nh3_POK.poten((r1, r2, r3, alpha1, alpha2, alpha3))
    return v

# masses of N, H1, H2, H3
masses = [14.00307400, 1.007825035, 1.007825035, 1.007825035]


@com(masses)
def internal_to_cartesian(internal_coords):
    """Cartesian coordinates of N, H1, H2, H3 from internal coordinates.

    Before the `com(masses)` decorator is applied, N sits at the origin and
    every N-H bond makes angle rho with the z-axis. H1 lies in the xz-plane.
    H2 and H3 are rotated about z by +beta3 and -beta2, where
    cos(beta) = (cos(alpha) - cos(rho)**2) / sin(rho)**2. The decorator
    presumably shifts the result to the centre-of-mass frame (the printed
    geometry has N off the origin) — confirm in vibrojet.keo.

    NOTE(review): the reference q0 is obtained by minimising `poten`, whose 6th
    coordinate is tau (rho = tau + pi/2). Here q0[5] is used directly as rho,
    and the printed reference geometry has an H-N-H angle of about 38 degrees.
    Check that the two conventions agree.
    """
    r1, r2, r3, s4, s5, rho = internal_coords
    _, alpha2, alpha3 = find_alpha_from_s_delta(s4, s5, rho - jnp.pi / 2)

    cos_rho, sin_rho = jnp.cos(rho), jnp.sin(rho)
    cos2_rho, sin2_rho = cos_rho**2, sin_rho**2

    cos_beta3 = (jnp.cos(alpha3) - cos2_rho) / sin2_rho
    cos_beta2 = (jnp.cos(alpha2) - cos2_rho) / sin2_rho
    # going through acos keeps beta on the branch 0 <= beta <= pi
    sin_beta3 = jnp.sin(acos(cos_beta3))
    sin_beta2 = jnp.sin(acos(cos_beta2))

    nitrogen = [0.0, 0.0, 0.0]
    h1 = [r1 * sin_rho, 0.0, r1 * cos_rho]
    h2 = [r2 * sin_rho * cos_beta3, r2 * sin_rho * sin_beta3, r2 * cos_rho]
    h3 = [r3 * sin_rho * cos_beta2, -r3 * sin_rho * sin_beta2, r3 * cos_rho]
    return jnp.array([nitrogen, h1, h2, h3])

Find equilibrium values of internal coordinates

In [9]:
# Minimise the PES `poten` from a rough starting guess; q0 is the reference
# geometry for all expansions below.
vmin = optimize.minimize(poten, [1.1, 1.1, 1.1, 0.1, 0.1, 0.1])
q0, v0 = vmin.x, vmin.fun
# NOTE(review): `poten` treats q0[5] as tau (rho = tau + pi/2), but
# internal_to_y / internal_to_cartesian below take the 6th coordinate as rho.
# The printed Cartesian geometry has an H-N-H angle of about 38 degrees —
# confirm the conventions before relying on q0[5].
y0 = internal_to_y(q0, q0)
xyz = internal_to_cartesian(q0)

for title, values in (
    ("Reference values of internal coordinates:\n", q0),
    ("Reference values of expansion y-coordinates:\n", y0),
    ("Reference values of Cartesian coordinates:\n", xyz),
):
    print(title, values)

Reference values of internal coordinates:
 [1.01159999e+00 1.01159999e+00 1.01159999e+00 6.95693777e-09
 2.12708155e-09 3.85722364e-01]
Reference values of expansion y-coordinates:
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 6.95693777e-09
 2.12708155e-09 3.76228525e-01]
Reference values of Cartesian coordinates:
 [[-5.52272859e-10 -2.57034490e-10 -1.66435748e-01]
 [ 3.80592772e-01 -2.57034490e-10  7.70838857e-01]
 [-1.90296380e-01  3.29603014e-01  7.70838859e-01]
 [-1.90296384e-01 -3.29603010e-01  7.70838855e-01]]


Generate expansion power indices

In [10]:
ncoo = len(q0)  # number of internal coordinates
max_order = 2   # maximum total power of the expansion


def _within_max_order(ind):
    """Keep a multi-index when its total power does not exceed max_order."""
    return np.sum(ind) <= max_order


# allowed powers 0..max_order for every coordinate
powers = [np.arange(max_order + 1)] * ncoo
deriv_ind, deriv_mind = next(generate_prod_ind(powers, select=_within_max_order))

print("max expansion power:", max_order)
print("number of expansion terms:", len(deriv_ind))

max expansion power: 2
number of expansion terms: 28


Generate the Taylor expansion of the PES in terms of the y-coordinates (around the reference point y0)

In [11]:
# Taylor coefficients of the PES in the y-coordinates, cached on disk.
# NOTE(review): the file name only encodes max_order; delete the file after
# changing the potential, the coordinates, or the reference geometry q0.
poten_file = f"nh3_poten_beta_y_coefs_{max_order}.npz"
if os.path.exists(poten_file):
    print(f"load potential expansion coefs from file {poten_file}")
    # context manager closes the underlying NpzFile after reading
    with np.load(poten_file) as data:
        poten_coefs = data['coefs']
else:
    print(f"calculate potential expansion coefs and save to {poten_file}")
    # poten_in_y takes (y, q0); bind q0 so deriv_list gets a function of y only
    poten_coefs = deriv_list(
        lambda y: poten_in_y(y, q0), deriv_ind, y0, if_taylor=True
    )
    # np.savez keeps the '.npz' name and stores the array under 'coefs', as the
    # load branch expects. np.save would write '<name>.npz.npy' with no key.
    np.savez(poten_file, coefs=poten_coefs)

load potential expansion coefs from file nh3_poten_beta_y_coefs_2.npz


In [15]:
# Taylor coefficients of the G-matrix (vibrojet.keo.Gmat) in the y-coordinates,
# cached on disk.
# NOTE(review): the file name only encodes max_order; delete the file after
# changing the coordinates, masses, or the reference geometry q0.
gmat_file = f"nh3_gmat_beta_y_coefs_{max_order}.npz"
if os.path.exists(gmat_file):
    print(f"load G-matrix expansion coefs from file {gmat_file}")
    # context manager closes the underlying NpzFile after reading
    with np.load(gmat_file) as data:
        gmat_coefs = data['coefs']
else:
    print(f"calculate G-matrix expansion coefs and save to {gmat_file}")
    gmat_coefs = deriv_list(
        # y_to_internal needs the reference geometry q0 as its second argument
        lambda y: Gmat(y_to_internal(y, q0), masses, internal_to_cartesian),
        deriv_ind,
        y0,
        if_taylor=True,
    )
    # savez + key 'coefs' matches the load branch above
    np.savez(gmat_file, coefs=gmat_coefs)


load G-matrix expansion coefs from file nh3_gmat_beta_y_coefs_2.npz


In [35]:
# Harmonic estimates from the expansion coefficients:
#   mu   - diagonal of the constant (all-zero-power) G-matrix term,
#   freq - twice the coefficient of the pure y_i**2 PES term of each coordinate.
mask = deriv_ind != 0
ind0 = np.where(mask.sum(axis=1) == 0)[0][0]
mu = np.diag(gmat_coefs[ind0])[:ncoo]

ind2 = np.array(
    [
        np.where((mask.sum(axis=1) == 1) & (deriv_ind[:, icoo] == 2))[0][0]
        for icoo in range(ncoo)
    ]
)
freq = poten_coefs[ind2] * 2

# Primitive basis type of each coordinate
list_herm = [0,1,2]
list_leg = [3,4]
list_fourier = [5]  # every coordinate not in list_herm / list_leg is mapped as Fourier

# [r_min, r_max] for the Legendre / Fourier coordinates (Hermite entries unused)
limits_r = [[0,jnp.inf],[0,jnp.inf],[0,jnp.inf],[-2.5,3.5],[-2.5,2.5],[-1.2,1.2]]


def _lin_map_params(icoo):
    """Return (a, b) of the linear map r = a * x + b for coordinate icoo.

    Hermite coordinates: a = (mu / freq)**(1/4), b = q0.
    Legendre / Fourier coordinates: an x-interval of length pi (Legendre) or
    2 pi (Fourier) is stretched onto [r_min, r_max], starting at r_min.
    The x-intervals are presumably those used by the basis classes in
    vibrojet.basis_utils2 — confirm.
    """
    if icoo in list_herm:
        return jnp.sqrt(jnp.sqrt(mu[icoo] / freq[icoo])), q0[icoo]
    r_min, r_max = limits_r[icoo]
    x_span = np.pi if icoo in list_leg else 2 * np.pi
    return (r_max - r_min) / x_span, r_min


lin_a = jnp.array([_lin_map_params(icoo)[0] for icoo in range(ncoo)])
lin_b = jnp.array([_lin_map_params(icoo)[1] for icoo in range(ncoo)])

print("x->r linear mapping parameters 'a':", lin_a)
print("x->r linear mapping parameters 'b':", lin_b)

# x->r linear mapping function and its (constant) Jacobian
x_to_r_map = lambda x, icoo: lin_a[icoo] * x + lin_b[icoo]
jac_x_to_r_map = lambda x, icoo: np.ones_like(x) * lin_a[icoo]

# r -> y for a single coordinate icoo, with all other coordinates held at q0.
# The PES and the G-matrix are expanded in the same y-coordinates, so a single
# map serves both.
r_to_y_pes_ = lambda x, icoo: internal_to_y(jnp.array(q0).at[icoo].set(x), q0)[icoo]
r_to_y_gmat_ = r_to_y_pes_

r_to_y_pes = jax.vmap(r_to_y_pes_,in_axes=(0,None))
r_to_y_gmat = jax.vmap(r_to_y_gmat_,in_axes=(0,None))


x->r linear mapping parameters 'a': [0.14152789 0.14152789 0.14152789 1.90985932 1.59154943 0.38197186]
x->r linear mapping parameters 'b': [ 1.01159999  1.01159999  1.01159999 -2.5        -2.5        -1.2       ]


In [36]:
nbas = [80] * ncoo      # number of basis functions per coordinate
npoints = [101] * ncoo  # number of quadrature points per coordinate


def _primitive_basis(icoo):
    """1D primitive basis for coordinate icoo.

    Hermite for coordinates in list_herm, associated Legendre (m=1) for
    list_leg, and Fourier for everything else.
    """
    args = (
        icoo,
        nbas[icoo],
        npoints[icoo],
        lambda x, icoo=icoo: x_to_r_map(x, icoo),
        lambda x, icoo=icoo: r_to_y_pes(x, icoo),
        lambda x, icoo=icoo: r_to_y_gmat(x, icoo),
        q0[icoo],
        deriv_ind[:, icoo],
    )
    if icoo in list_herm:
        return HermiteBasis(*args)
    if icoo in list_leg:
        return AssociateLegendreBasis(*args, m=1)
    return FourierBasis(*args)


p0, p1, p2, p3, p4, p5 = [_primitive_basis(icoo) for icoo in range(ncoo)]


In [37]:
sum_deriv_ind = np.sum(deriv_ind,axis=1)


def _pure_terms(icoo):
    """Boolean mask of expansion terms whose whole power sits on coordinate icoo.

    The constant (all-zero-power) term passes for every icoo, so it is kept in
    each 1D slice.
    """
    return deriv_ind[:, icoo] == sum_deriv_ind


def poten_coefs_1d(icoo):
    """PES coefficients with every term that is not purely in icoo set to 0."""
    return np.where(_pure_terms(icoo), poten_coefs, 0.0)


def gmat_coefs_1d(icoo):
    """G-matrix coefficients with every term that is not purely in icoo zeroed."""
    keep = _pure_terms(icoo).reshape((-1,) + (1,) * (np.ndim(gmat_coefs) - 1))
    return np.where(keep, gmat_coefs, 0.0)


c0, c1, c2, c3, c4, c5 = [
    ContrBasis(
        (icoo,), (p0, p1, p2, p3, p4, p5), lambda _: True, gmat_coefs_1d(icoo), poten_coefs_1d(icoo), store_int = True,
    )
    for icoo in range(ncoo)
]

Number of states: 80
batch no: 0  out of: 1
iteration time: 0.32
[ 1772.4304359   5209.70774439  8503.54030181 11653.92810815
 14660.87116341 17524.36946761 20244.42302072 22821.03182277
 25254.19587374 27543.91517363]
Saving integrals...
batch no: 0 out of: 1
iteration time: 0.01
Number of states: 80
batch no: 0  out of: 1
iteration time: 0.02
[ 1772.4304359   5209.70774438  8503.54030179 11653.92810813
 14660.87116339 17524.36946758 20244.42302069 22821.03182273
 25254.19587369 27543.91517358]
Saving integrals...
batch no: 0 out of: 1
iteration time: 0.01
Number of states: 80
batch no: 0  out of: 1
iteration time: 0.03
[ 1772.43043777  5209.70775     8503.54031116 11653.92812125
 14660.87118026 17524.3694882  20244.42304505 22821.03185084
 25254.19590555 27543.91520919]
Saving integrals...
batch no: 0 out of: 1
iteration time: 0.01
Number of states: 80
batch no: 0  out of: 1
iteration time: 0.03
[  741.95247706  2226.10821052  3709.58876453  5192.3934641
  6674.52163181  8155.9725877

In [46]:
# Couple the three N-H stretching coordinates (0, 1, 2).
p_coefs = np.array([1, 1, 1, 1, 1, 1])  # polyad weight of each index in `ind`
pmax = 60                                # polyad cutoff
emax_trunc_ = 60000                      # cutoff on the summed 1D excitation energies

e0,e1,e2,e3,e4,e5 = c0.enr,c1.enr,c2.enr,c3.enr,c4.enr,c5.enr
# 1D excitation energies, each relative to its own lowest level
e_sum = [e - e[0] for e in (e0, e1, e2, e3, e4, e5)]

#Make energy truncation a part of basis generation for speed.
f_e_sum = lambda ind: np.sum(np.array([e_sum[i][ind[i]] for i in range(len(ind))])) < emax_trunc_
f_pmax = lambda ind: np.sum(np.array(ind) * p_coefs[:len(ind)]) <= pmax
# Keep a product function only if it passes BOTH cutoffs. `f_e_sum or f_pmax`
# evaluates to f_e_sum (a function object is always truthy), so pmax was
# never applied.
f_select = lambda ind: f_e_sum(ind) and f_pmax(ind)

c012 = ContrBasis(
    (0, 1, 2),
    (c0, c1, c2, c3, c4, c5),
    f_select,
    gmat_coefs,
    poten_coefs,
    emax=30000,
    batch_size = 1000000,
    store_int = True,
)

e012 = c012.enr
print(e012[0], e012[0:5] - e012[0])

Number of states: 969
batch no: 0  out of: 1
iteration time: 1.86
[ 6484.88743915  9980.66926004 10107.62978074 10108.0491876
 13476.45108094 13603.41160164 13603.8310085  13730.37212234
 13730.7915292  13731.21093606]
Saving integrals...
batch no: 0 out of: 1
iteration time: 1.6
6484.887439149066 [   0.         3495.7818209  3622.74234159 3623.16174845 6991.56364179]


In [52]:
# Couple the two bending coordinates s4, s5 (3, 4).
p_coefs = np.array([1, 1, 1, 1, 1, 1])  # polyad weight of each index in `ind`
pmax = 60                                # polyad cutoff
emax_trunc_ = 60000                      # cutoff on the summed 1D excitation energies

# 1D excitation energies, recomputed here so the cell does not depend on
# state left behind by the previous cell
e_sum = [c.enr - c.enr[0] for c in (c0, c1, c2, c3, c4, c5)]

#Make energy truncation a part of basis generation for speed.
f_e_sum = lambda ind: np.sum(np.array([e_sum[i][ind[i]] for i in range(len(ind))])) < emax_trunc_
f_pmax = lambda ind: np.sum(np.array(ind) * p_coefs[:len(ind)]) <= pmax
# Keep a product function only if it passes BOTH cutoffs. `f_e_sum or f_pmax`
# evaluates to f_e_sum (a function object is always truthy), so pmax was
# never applied.
f_select = lambda ind: f_e_sum(ind) and f_pmax(ind)

c34 = ContrBasis(
    (3, 4),
    (c0, c1, c2, c3, c4, c5),
    f_select,
    gmat_coefs,
    poten_coefs,
    emax=30000,
    batch_size = 1000000,
    store_int = True,
)

e34 = c34.enr
print(e34[0], e34[0:5] - e34[0])

Number of states: 662
batch no: 0  out of: 1
iteration time: 0.49
[4669.01333965 6381.29503916 6381.29504496 8087.6684551  8101.91140702
 8101.91140704 9802.38996118 9802.38997276 9830.84209984 9830.89555738]
Saving integrals...
batch no: 0 out of: 1
iteration time: 0.31
4669.0133396497595 [   0.         1712.28169951 1712.28170531 3418.65511545 3432.89806737]


In [53]:
e5 = c5.enr

# ground-state energy and the first four excitation energies of each group
for label, levels in (("e012", e012), ("e34", e34), ("e5", e5)):
    print(label, levels[0], levels[0:5] - levels[0])

e012 6484.887439149066 [   0.         3495.7818209  3622.74234159 3623.16174845 6991.56364179]
e34 4669.0133396497595 [   0.         1712.28169951 1712.28170531 3418.65511545 3432.89806737]
e5 545.0191963303099 [   0.         1090.42024668 2179.91669187 3268.48777966 4356.13194445]


In [56]:
# Couple the three contracted groups: stretches (c012), bends (c34) and c5.
p_coefs = np.array([1, 1, 1])  # polyad weight of each group index
pmax = 300                     # polyad cutoff
emax_trunc_ = 10000            # cutoff on the summed group excitation energies
e_sum = [e012 - e012[0],e34 - e34[0],e5 - e5[0]]
#Make energy truncation a part of basis generation for speed.
f_e_sum = lambda ind: np.sum(np.array([e_sum[i][ind[i]] for i in range(len(ind))])) < emax_trunc_
f_pmax = lambda ind: np.sum(np.array(ind) * p_coefs[:len(ind)]) <= pmax
# Keep a product function only if it passes BOTH cutoffs. `f_e_sum or f_pmax`
# evaluates to f_e_sum (a function object is always truthy), so pmax was
# never applied.
f_select = lambda ind: f_e_sum(ind) and f_pmax(ind)
c = ContrBasis(
    (0, 1, 2),
    (c012, c34, c5),
    f_select,
    gmat_coefs,
    poten_coefs,
    store_int = False,
    emax=30000,
    batch_size = 100000,
)
e = c.enr
print(e[0], e[0:20] - e[0])


Number of states: 229
batch no: 0  out of: 1
iteration time: 0.6
[ 7579.96128534  8662.02853184  9266.51697334  9266.51838097
  9740.51039791 10371.45781481 10371.4611865  10820.651369
 10905.12238972 10947.97052226]
7579.961285336338 [   0.         1082.0672465  1686.555688   1686.55709563 2160.54911258
 2791.49652948 2791.49990117 3240.69008366 3325.16110439 3368.00923692
 3368.01105746 3537.40217788 3635.34846431 3635.34984348 3898.59015068
 3898.59328554 4323.29137193 4459.81229012 4506.09748471 4506.09853061]
