In [1]:
import sys
sys.path.insert(1, '../')

In [2]:
import itertools
import os

import jax
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp

from jax import config
from scipy import optimize

from vibrojet.basis_utils2 import ContrBasis, HermiteBasis, AssociateLegendreBasis, FourierBasis, generate_prod_ind

from vibrojet.jet_prim import acos
from vibrojet.potentials import nh3_POK
from vibrojet.taylor import deriv_list
from vibrojet.keo import Gmat, com, pseudo

plt.rcParams.update(
    {"text.usetex": True, "font.family": "serif", "font.serif": ["Computer Modern"]}
)

config.update("jax_enable_x64", True)

In [3]:
@jax.jit
def poten(q):
    """Evaluate the NH3 potential in the symmetrized internal coordinates.

    Args:
        q: six coordinates ``(r1, r2, r3, s4, s5, tau)``.

    Returns:
        The value of ``nh3_POK.poten`` at the valence coordinates
        ``(r1, r2, r3, alpha1, alpha2, alpha3)`` derived from ``q``.
    """
    r1, r2, r3, s4, s5, tau = q
    rho = tau + np.pi / 2

    # Only beta2 and beta3 are needed explicitly. The third angle,
    # beta1 = sqrt(6)/3 * s4 + 2*pi/3, satisfies beta1 + beta2 + beta3 = 2*pi,
    # hence cos(beta1) == cos(beta2 + beta3), which is the form used below.
    beta2 = -1 / jnp.sqrt(6) * s4 + 1 / jnp.sqrt(2) * s5 + 2 * np.pi / 3
    beta3 = -1 / jnp.sqrt(6) * s4 - 1 / jnp.sqrt(2) * s5 + 2 * np.pi / 3

    cosrho = jnp.cos(rho)
    sinrho = jnp.sin(rho)
    cosrho2 = cosrho * cosrho
    sinrho2 = sinrho * sinrho

    # cos(alpha_i) = cos^2(rho) + sin^2(rho) * cos(beta_i)
    cosalpha1 = cosrho2 + sinrho2 * jnp.cos(beta2 + beta3)
    cosalpha2 = cosrho2 + sinrho2 * jnp.cos(beta2)
    cosalpha3 = cosrho2 + sinrho2 * jnp.cos(beta3)

    alpha1 = acos(cosalpha1)
    alpha2 = acos(cosalpha2)
    alpha3 = acos(cosalpha3)
    return nh3_POK.poten((r1, r2, r3, alpha1, alpha2, alpha3))


# Minimise the potential to obtain the equilibrium geometry (q0) and the
# potential value at the minimum (v0).
vmin = optimize.minimize(poten, x0=[1.1] * 3 + [0.1] * 3)
q0, v0 = vmin.x, vmin.fun

print("equilibrium coordinates:", q0)
print("min of the potential:", v0)

equilibrium coordinates: [1.01159999e+00 1.01159999e+00 1.01159999e+00 6.95763752e-09
 2.12715030e-09 3.85722364e-01]
min of the potential: 3.4875112211089974e-11


In [4]:
# Expansion coordinates (y) for the PES.
# The stretches use Morse-type variables, which need the constant below.
a_morse = 2.0


def internal_to_y(q):
    """Convert internal coordinates to the y-coordinates of the PES expansion.

    Args:
        q: six coordinates ``(r1, r2, r3, s4, s5, tau)``.

    Returns:
        Array ``[y1, ..., y6]`` with ``y_k = 1 - exp(-a_morse * (r_k - q0[k]))``
        for the three stretches, ``y4 = s4``, ``y5 = s5`` and
        ``y6 = sin(tau + pi/2)``.
    """
    r1, r2, r3, s4, s5, tau = q

    def morse(r, r_ref):
        return 1 - jnp.exp(-a_morse * (r - r_ref))

    return jnp.array(
        [
            morse(r1, q0[0]),
            morse(r2, q0[1]),
            morse(r3, q0[2]),
            s4,
            s5,
            jnp.sin(tau + jnp.pi / 2),
        ]
    )

def y_to_internal(y):
    """Convert PES expansion y-coordinates back to internal coordinates.

    Args:
        y: six coordinates ``(y1, ..., y6)``.

    Returns:
        Array ``[r1, r2, r3, s4, s5, tau]``.
    """
    y1, y2, y3, y4, y5, y6 = y

    def inverse_morse(y_k, r_ref):
        return -jnp.log(1 - y_k) / a_morse + r_ref

    # asin(y6), written through the project's acos primitive
    rho = np.pi / 2 - acos(y6)
    tau = rho - np.pi / 2
    return jnp.array(
        [
            inverse_morse(y1, q0[0]),
            inverse_morse(y2, q0[1]),
            inverse_morse(y3, q0[2]),
            y4,
            y5,
            tau,
        ]
    )

@jax.jit
def poten_y(y):
    """Potential as a function of the expansion y-coordinates."""
    return poten(y_to_internal(y))

# masses of N, H1, H2, H3
masses = [14.00307400, 1.007825035, 1.007825035, 1.007825035]


@com(masses)
def internal_to_cartesian(internal_coords):
    """Build Cartesian coordinates of the four atoms from internal coordinates.

    Args:
        internal_coords: six coordinates ``(r1, r2, r3, s4, s5, tau)``.

    Returns:
        A ``(4, 3)`` array with rows ordered N, H1, H2, H3. Inside this
        function N is placed at the origin; the ``com(masses)`` decorator
        post-processes the result (presumably a shift to the centre of mass,
        which matches the printed reference geometry -- confirm in
        ``vibrojet.keo``).
    """
    r1, r2, r3, s4, s5, tau = internal_coords
    rho = tau + np.pi / 2

    # Only beta2 and beta3 enter the geometry below; the unused
    # beta1 = sqrt(6)/3 * s4 + 2*pi/3 is therefore not computed.
    beta2 = -1 / jnp.sqrt(6) * s4 + 1 / jnp.sqrt(2) * s5 + 2 * np.pi / 3
    beta3 = -1 / jnp.sqrt(6) * s4 - 1 / jnp.sqrt(2) * s5 + 2 * np.pi / 3

    cos_rho = jnp.cos(rho)
    sin_rho = jnp.sin(rho)
    cos_beta2 = jnp.cos(beta2)
    cos_beta3 = jnp.cos(beta3)
    sin_beta2 = jnp.sin(beta2)
    sin_beta3 = jnp.sin(beta3)

    xyz = jnp.array(
        [
            [0.0, 0.0, 0.0],
            [r1 * sin_rho, 0.0, r1 * cos_rho],
            [r2 * sin_rho * cos_beta3, r2 * sin_rho * sin_beta3, r2 * cos_rho],
            [r3 * sin_rho * cos_beta2, -r3 * sin_rho * sin_beta2, r3 * cos_rho],
        ]
    )
    return xyz

In [5]:
# Reference geometry expressed in the PES expansion coordinates and in Cartesians
y0 = internal_to_y(q0)
xyz = internal_to_cartesian(q0)

for header, values in (
    ("Reference values of internal coordinates:\n", q0),
    ("Reference values of expansion y-coordinates:\n", y0),
    ("Reference values of Cartesian coordinates:\n", xyz),
):
    print(header, values)

Reference values of internal coordinates:
 [1.01159999e+00 1.01159999e+00 1.01159999e+00 6.95763752e-09
 2.12715030e-09 3.85722364e-01]
Reference values of expansion y-coordinates:
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 6.95763752e-09
 2.12715030e-09 9.26526900e-01]
Reference values of Cartesian coordinates:
 [[-2.61385144e-10 -2.97063299e-10  6.75834408e-02]
 [ 9.37274605e-01 -2.97063299e-10 -3.13009332e-01]
 [-4.68637300e-01  8.11703622e-01 -3.13009333e-01]
 [-4.68637301e-01 -8.11703618e-01 -3.13009331e-01]]


In [6]:
def internal_to_z(q):
    """Convert internal coordinates to the z-coordinates of the KEO expansion.

    Args:
        q: six coordinates ``(r1, r2, r3, s4, s5, tau)``.

    Returns:
        Array ``[z1, ..., z6]`` with ``z_k = r_k - q0[k]`` for the three
        stretches, ``z4 = s4``, ``z5 = s5`` and ``z6 = cos(tau + pi/2)``.
    """
    r1, r2, r3, s4, s5, tau = q
    return jnp.array(
        [
            r1 - q0[0],
            r2 - q0[1],
            r3 - q0[2],
            s4,
            s5,
            jnp.cos(tau + jnp.pi / 2),
        ]
    )

def z_to_internal(z):
    """Convert KEO expansion z-coordinates back to internal coordinates.

    Args:
        z: six coordinates ``(z1, ..., z6)``.

    Returns:
        Array ``[r1, r2, r3, s4, s5, tau]`` with ``tau = acos(z6) - pi/2``.
    """
    z1, z2, z3, z4, z5, z6 = z
    tau = acos(z6) - np.pi / 2
    return jnp.array([z1 + q0[0], z2 + q0[1], z3 + q0[2], z4, z5, tau])


Define the planar reference geometry and z-coordinates for the KEO expansion

In [7]:
# Reference geometry for the KEO expansion: the PES minimum with tau reset to 0,
# i.e., planar molecular geometry
q0_keo = np.copy(q0)
q0_keo[-1] = 0

# z-coordinates (expansion coordinates of the KEO) at that reference geometry
z0 = internal_to_z(q0_keo)
print("Reference values of expansion z-coordinates:\n", z0)

Reference values of expansion y-coordinates:
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 6.95763752e-09
 2.12715030e-09 6.12323400e-17]


Generate expansion power indices

In [8]:
ncoo = len(q0)
max_order = 6


def _within_max_order(ind):
    """Accept a multi-index whose total power does not exceed max_order."""
    return np.sum(ind) <= max_order


# every coordinate may carry powers 0..max_order; the selector caps the total
powers = [np.arange(max_order + 1)] * ncoo
deriv_ind, deriv_mind = next(generate_prod_ind(powers, select=_within_max_order))

print("max expansion power:", max_order)
print("number of expansion terms:", len(deriv_ind))

max expansion power: 6
number of expansion terms: 924


Generate expansion of PES in terms of the y-coordinates

In [9]:
# Taylor-expansion coefficients of the PES in y-coordinates around y0.
# They are cached in an .npz archive under the key 'coefs'.
poten_file = f"nh3_poten_beta_y_coefs_{max_order}.npz"
if os.path.exists(poten_file):
    print(f"load potential expansion coefs from file {poten_file}")
    with np.load(poten_file) as data:
        poten_coefs = data['coefs']
else:
    print(f"calculate potential expansion coefs and save to {poten_file}")
    poten_coefs = deriv_list(poten_y, deriv_ind, y0, if_taylor=True)
    # np.savez keeps the .npz name and stores the array under 'coefs', which is
    # what the loading branch expects (np.save would write '<name>.npz.npy').
    np.savez(poten_file, coefs=poten_coefs)

load potential expansion coefs from file nh3_poten_beta_y_coefs_6.npz


In [10]:
# Taylor-expansion coefficients of the G-matrix in z-coordinates around z0.
# They are cached in an .npz archive under the key 'coefs'.
gmat_file = f"nh3_gmat_beta_y_coefs_{max_order}.npz"
if os.path.exists(gmat_file):
    print(f"load G-matrix expansion coefs from file {gmat_file}")
    with np.load(gmat_file) as data:
        gmat_coefs = data['coefs']
else:
    print(f"calculate G-matrix expansion coefs and save to {gmat_file}")
    gmat_coefs = deriv_list(
        lambda z: Gmat(z_to_internal(z), masses, internal_to_cartesian),
        deriv_ind,
        z0,
        if_taylor=True,
    )
    # np.savez keeps the .npz name and stores the array under 'coefs', which is
    # what the loading branch expects (np.save would write '<name>.npz.npy').
    np.savez(gmat_file, coefs=gmat_coefs)


load G-matrix expansion coefs from file nh3_gmat_beta_y_coefs_6.npz


In [11]:
# Locate, in the list of expansion multi-indices, the zeroth-order term and the
# pure quadratic term of every coordinate.
mask = deriv_ind != 0
ind0 = np.where(mask.sum(axis=1) == 0)[0][0]
# diagonal of the zeroth-order G-matrix term, restricted to the ncoo coordinates
mu = np.diag(gmat_coefs[ind0])[:ncoo]

ind2 = np.array(
    [
        np.where((mask.sum(axis=1) == 1) & (deriv_ind[:, icoo] == 2))[0][0]
        for icoo in range(ncoo)
    ]
)
# twice the pure quadratic Taylor coefficient of the PES for each coordinate
freq = poten_coefs[ind2] * 2

# basis family assigned to each coordinate
list_herm = [0,1,2]
list_leg = [3,4]
list_fourier = [5]

limits_r = [[0,jnp.inf],[0,jnp.inf],[0,jnp.inf],[-2.5,3.5],[-2.5,2.5],[-1.5,1.5]] #[-1.2,1.2]

# Linear map r = lin_a * x + lin_b from the quadrature variable x to coordinate r:
#   Hermite coordinates:  a = (mu / freq)**(1/4),   b = q0
#   Legendre coordinates: a = (r_max - r_min) / pi,     b = r_min
#   Fourier coordinates:  a = (r_max - r_min) / (2 pi), b = r_min
lin_a = jnp.array([jnp.sqrt(jnp.sqrt(mu[i] / freq[i])) if i in list_herm
                    else (limits_r[i][1]-limits_r[i][0])/np.pi if i in list_leg
                    else (limits_r[i][1]-limits_r[i][0])/(2*np.pi)
                    for i in range(ncoo)])

lin_b = jnp.array([q0[i] if i in list_herm
                    else limits_r[i][0]
                    for i in range(ncoo)])

print("x->r linear mapping parameters 'a':", lin_a)
print("x->r linear mapping parameters 'b':", lin_b)

# x->r linear mapping function
x_to_r_map = lambda x, icoo: lin_a[icoo] * x + lin_b[icoo]
jac_x_to_r_map = lambda x, icoo: np.ones_like(x) * lin_a[icoo]

#Functions for change of coordinates for Taylor expansion
# Each one moves a single coordinate icoo to the value x, keeps the others at
# the reference geometry, and returns the displacement of the corresponding
# expansion coordinate from its reference value.
r_to_y_pes_ = lambda x, icoo: internal_to_y(jnp.array(q0).at[icoo].set(x))[icoo] - y0[icoo]
r_to_z_gmat_ = lambda x, icoo: internal_to_z(jnp.array(q0_keo).at[icoo].set(x))[icoo] - z0[icoo]

# vectorise over an array of coordinate values; icoo is not mapped
r_to_y_pes = jax.vmap(r_to_y_pes_,in_axes=(0,None))
r_to_z_gmat = jax.vmap(r_to_z_gmat_,in_axes=(0,None))


x->r linear mapping parameters 'a': [0.14152789 0.14152789 0.14152789 1.90985932 1.59154943 0.47746483]
x->r linear mapping parameters 'b': [ 1.01159999  1.01159999  1.01159999 -2.5        -2.5        -1.5       ]


In [12]:
nbas = [100] * ncoo
npoints = [121] * ncoo


def _primitive_basis(icoo):
    """Build the primitive 1D basis for coordinate ``icoo``.

    Coordinates in ``list_herm`` get a HermiteBasis, those in ``list_leg`` an
    AssociateLegendreBasis with ``m=1``, and all others a FourierBasis.
    """
    shared_args = (
        icoo,
        nbas[icoo],
        npoints[icoo],
        lambda x, icoo=icoo: x_to_r_map(x, icoo),
        lambda x, icoo=icoo: r_to_y_pes(x, icoo),
        lambda x, icoo=icoo: r_to_z_gmat(x, icoo),
        q0[icoo],
        deriv_ind[:, icoo],
    )
    if icoo in list_herm:
        return HermiteBasis(*shared_args)
    if icoo in list_leg:
        return AssociateLegendreBasis(*shared_args, m=1)
    return FourierBasis(*shared_args)


p0, p1, p2, p3, p4, p5 = [_primitive_basis(icoo) for icoo in range(ncoo)]


In [13]:
# total expansion power of every term
sum_deriv_ind = np.sum(deriv_ind,axis=1)


def poten_coefs_1d(icoo):
    """PES coefficients with every term that involves another coordinate zeroed.

    A term is kept when the power of coordinate ``icoo`` equals the total power
    of the term; all other entries are replaced by 0.0.
    """
    nterms = len(sum_deriv_ind)
    return np.array(
        [
            poten_coefs[i] if deriv_ind[i, icoo] == sum_deriv_ind[i] else 0.0
            for i in range(nterms)
        ]
    )


def gmat_coefs_1d(icoo):
    """G-matrix coefficients with every term involving another coordinate zeroed.

    Same selection rule as ``poten_coefs_1d``; discarded terms are replaced by
    zero arrays of the same shape.
    """
    nterms = len(sum_deriv_ind)
    return np.array(
        [
            gmat_coefs[i]
            if deriv_ind[i, icoo] == sum_deriv_ind[i]
            else np.zeros_like(gmat_coefs[i])
            for i in range(nterms)
        ]
    )


# one contracted basis per coordinate, built from the 1D cuts defined above
_contracted_1d = []
for icoo in range(ncoo):
    _contracted_1d.append(
        ContrBasis(
            (icoo,),
            (p0, p1, p2, p3, p4, p5),
            lambda _: True,
            gmat_coefs_1d(icoo),
            poten_coefs_1d(icoo),
            store_int = True,
        )
    )
c0, c1, c2, c3, c4, c5 = _contracted_1d

Number of states: 100
batch no: 0  out of: 1
iteration time: 0.41
[ 1774.91355813  5213.491785    8512.59377321 11674.42472321
 14700.75739516 17592.97817448 20352.13635446 22978.99049318
 25474.04499778 27837.58473124]
Saving integrals...
batch no: 0 out of: 1
iteration time compute objects to store: 0.25
Storing time: 0.06
Number of states: 100
batch no: 0  out of: 1
iteration time: 0.46
[ 1774.91356261  5213.49274159  8512.59420405 11674.42520216
 14700.75838262 17592.97830686 20352.13636062 22978.99012116
 25474.04495198 27837.58468009]
Saving integrals...
batch no: 0 out of: 1
iteration time compute objects to store: 0.06
Storing time: 0.09
Number of states: 100
batch no: 0  out of: 1
iteration time: 0.66
[ 1774.91355278  5213.49170717  8512.59380298 11674.42549047
 14700.75814475 17592.97827889 20352.13627381 22978.99057172
 25474.0449734  27837.58470803]
Saving integrals...
batch no: 0 out of: 1
iteration time compute objects to store: 0.04
Storing time: 0.02
Number of states: 1

In [14]:
# couple the three stretching coordinates r1, r2, r3 (coordinates 0, 1, 2)
p_coefs = np.array([1, 1, 1, 1, 1, 1])
pmax = 60
emax_trunc_ = 20000

e0,e1,e2,e3,e4,e5 = c0.enr,c1.enr,c2.enr,c3.enr,c4.enr,c5.enr
# 1D energies relative to the lowest level of each coordinate
e_sum = [e0 - e0[0],e1 - e1[0],e2 - e2[0],e3 - e3[0],e4 - e4[0],e5 - e5[0],]
#Make energy truncation a part of basis generation for speed.
f_e_sum = lambda ind: np.sum(np.array([e_sum[i][ind[i]] for i in range(len(ind))])) < emax_trunc_
f_pmax = lambda ind: np.sum(np.array(ind) * p_coefs[:len(ind)]) <= pmax

c012 = ContrBasis(
    (0, 1, 2),
    (c0, c1, c2, c3, c4, c5),
    # `f_e_sum or f_pmax` evaluated to `f_e_sum` alone (a function object is
    # always truthy), so the pmax rule was never applied. Require both rules.
    lambda ind: f_e_sum(ind) and f_pmax(ind),
    gmat_coefs,
    poten_coefs,
    emax=30000,
    batch_size = 10000, #73441
    store_int = True,
)

e012 = c012.enr
print(e012[0], e012[0:5] - e012[0])

Number of states: 77
batch no: 0  out of: 1
iteration time: 0.5
[ 6352.46854271  9705.49128091  9832.33614128  9836.41590088
 12990.45210104 13071.11289154 13075.27157002 13237.16394349
 13289.85183478 13290.02096885]
Saving integrals...
batch no: 0 out of: 1
iteration time compute objects to store: 0.63
Storing time: 0.26
6352.468542708762 [   0.         3353.0227382  3479.86759857 3483.94735818 6637.98355833]


In [17]:
#This is the code that is slow for saving integrals 
#The first iteration is fast, and then it is slow...

# couple coordinates 3 and 4 (s4 and s5)
p_coefs = np.array([1, 1, 1, 1, 1, 1])
pmax = 60
emax_trunc_ = 40000

#Make energy truncation a part of basis generation for speed.
# NOTE(review): f_e_sum looks up e_sum[i] by the position i inside `ind`;
# confirm that position i corresponds to the intended coordinate here.
f_e_sum = lambda ind: np.sum(np.array([e_sum[i][ind[i]] for i in range(len(ind))])) < emax_trunc_
f_pmax = lambda ind: np.sum(np.array(ind) * p_coefs[:len(ind)]) <= pmax

c34 = ContrBasis(
    (3, 4),
    (c0, c1, c2, c3, c4, c5),
    # `f_e_sum or f_pmax` evaluated to `f_e_sum` alone (a function object is
    # always truthy), so the pmax rule was never applied. Require both rules.
    lambda ind: f_e_sum(ind) and f_pmax(ind),
    gmat_coefs,
    poten_coefs,
    emax=30000,
    batch_size = 10000,
    store_int = True,
)

e34 = c34.enr
print(e34[0], e34[0:5] - e34[0])

Number of states: 443
batch no: 0  out of: 20
iteration time: 0.25
batch no: 1  out of: 20
iteration time: 0.23
batch no: 2  out of: 20
iteration time: 0.23
batch no: 3  out of: 20
iteration time: 0.22
batch no: 4  out of: 20
iteration time: 0.25
batch no: 5  out of: 20
iteration time: 0.24
batch no: 6  out of: 20
iteration time: 0.24
batch no: 7  out of: 20
iteration time: 0.25
batch no: 8  out of: 20
iteration time: 0.24
batch no: 9  out of: 20
iteration time: 0.23
batch no: 10  out of: 20
iteration time: 0.23
batch no: 11  out of: 20
iteration time: 0.23
batch no: 12  out of: 20
iteration time: 0.23
batch no: 13  out of: 20
iteration time: 0.24
batch no: 14  out of: 20
iteration time: 0.25
batch no: 15  out of: 20
iteration time: 0.23
batch no: 16  out of: 20
iteration time: 0.21
batch no: 17  out of: 20
iteration time: 0.22
batch no: 18  out of: 20
iteration time: 0.23
batch no: 19  out of: 20
iteration time: 0.17
[4626.15882857 6266.98745669 6266.98745822 7889.96094544 7912.266229

In [16]:
e5 = c5.enr

# lowest level and first excitation energies of each contracted basis
for label, levels in (("e012", e012), ("e34", e34), ("e5", e5)):
    print(label, levels[0], levels[0:5] - levels[0])

e012 6352.468542708762 [   0.         3353.0227382  3479.86759857 3483.94735818 6637.98355833]
e34 4626.1588285659 [   0.         1640.82862812 1640.82862965 3263.80211687 3286.10740114]
e5 524.3068674305788 [0.00000000e+00 1.10356225e+00 9.34336945e+02 9.82886179e+02
 1.59345183e+03]


In [20]:
# couple all together
p_coefs = np.array([1, 1, 1])
pmax = 300
emax_trunc_ = 10000
# energies of the three contracted bases relative to their lowest levels
e_sum = [e012 - e012[0],e34 - e34[0],e5 - e5[0]]
#Make energy truncation a part of basis generation for speed.
f_e_sum = lambda ind: np.sum(np.array([e_sum[i][ind[i]] for i in range(len(ind))])) < emax_trunc_
f_pmax = lambda ind: np.sum(np.array(ind) * p_coefs[:len(ind)]) <= pmax
c = ContrBasis(
    # here 0, 1, 2 index the three contracted bases (c012, c34, c5)
    (0, 1, 2),
    (c012, c34, c5),
    # `f_e_sum or f_pmax` evaluated to `f_e_sum` alone (a function object is
    # always truthy), so the pmax rule was never applied. Require both rules.
    lambda ind: f_e_sum(ind) and f_pmax(ind),
    gmat_coefs,
    poten_coefs,
    store_int = False,
    emax=25000,
    batch_size = 100000,
)
e = c.enr
print(e[0], e[0:20] - e[0])


Number of states: 581
batch no: 0  out of: 4
iteration time: 1.15
batch no: 1  out of: 4
iteration time: 0.76
batch no: 2  out of: 4
iteration time: 0.64
batch no: 3  out of: 4
iteration time: 0.62
[7434.73512217 7436.78284844 8279.45047522 8360.34413322 8886.78269244
 9054.36949106 9054.37019057 9057.78098363 9057.78169251 9273.54336993]
7434.735122166902 [0.00000000e+00 2.04772628e+00 8.44715353e+02 9.25609011e+02
 1.45204757e+03 1.61963437e+03 1.61963507e+03 1.62304586e+03
 1.62304657e+03 1.83880825e+03 2.35574088e+03 2.41699477e+03
 2.41699571e+03 2.52719601e+03 2.52719727e+03 2.91256976e+03
 3.02812040e+03 3.02812189e+03 3.20051619e+03 3.20544083e+03]


In [None]:
from vibrojet.basis_utils2 import fourgauss

# total expansion power of every term
sum_deriv_ind = np.sum(deriv_ind,axis=1)


def poten_coefs_1d(icoo):
    """PES coefficients keeping only the terms that depend on ``icoo`` alone."""
    nterms = len(sum_deriv_ind)
    return np.array(
        [
            poten_coefs[i] if deriv_ind[i, icoo] == sum_deriv_ind[i] else 0.0
            for i in range(nterms)
        ]
    )


def gmat_coefs_1d(icoo):
    """G-matrix coefficients keeping only the terms that depend on ``icoo`` alone."""
    nterms = len(sum_deriv_ind)
    return np.array(
        [
            gmat_coefs[i]
            if deriv_ind[i, icoo] == sum_deriv_ind[i]
            else np.zeros_like(gmat_coefs[i])
            for i in range(nterms)
        ]
    )


# coordinate whose 1D cut is inspected
icoo = 5
deriv_ind_icoo = deriv_ind[:,icoo]
gmat_c = gmat_coefs_1d(icoo)
poten_c = poten_coefs_1d(icoo)

# quadrature grid mapped to the coordinate and to both sets of expansion variables
npoints = 200
x, w = fourgauss(npoints)
r = x_to_r_map(x,icoo)
z = r_to_z_gmat(r,icoo)
y = r_to_y_pes(r,icoo) #- y0[icoo]

# powers of the expansion variables, indexed as [grid point, expansion term]
powers_icoo = np.array(deriv_ind_icoo)[None, :]
z_pow = z[:, None] ** powers_icoo
y_pow = y[:, None] ** powers_icoo

# 1D cuts reconstructed from the Taylor expansions
poten_1d_y = jnp.einsum('i,gi->g',poten_c,y_pow)
gmat_1d_z = jnp.einsum('i,gi->g',gmat_c[:,icoo,icoo],z_pow)

# 1D cut of the potential itself, for comparison with poten_1d_y
q0_jnp = jnp.array(q0)


def _poten_cut(value):
    return poten(q0_jnp.at[5].set(value))


poten_1d = jax.vmap(_poten_cut)
# plt.plot(r,poten_1d(r))
# plt.plot(r,poten_1d_y+100)
# plt.xlim([-1.0,1.0])
# plt.ylim([0,5000])

# 1D cut of the G-matrix itself, for comparison with gmat_1d_z
q0_keo_jnp = jnp.array(q0_keo)


def _gmat_cut(value):
    return Gmat(q0_keo_jnp.at[5].set(value), masses, internal_to_cartesian)


gmat_1d = jax.vmap(_gmat_cut)
#plt.plot(r,gmat_1d(r)[:,icoo,icoo])
#plt.plot(r,gmat_1d_z)