In [None]:
import numpy as np
import scipy
import scipy.special
from matplotlib import pyplot as plt

In [None]:
# 1000 evenly spaced real sample points on [0, 10] (endpoint included).
x = np.linspace(start=0, stop=10, num=1000)

In [None]:
# Complex evaluation grid of shape (1000, 1000): x[i, j] = re[j] + im[i], with the
# real part running over [0, 10] along columns and the imaginary part over
# [-5i, 5i] along rows.
x = (
    np.linspace(0, 10, 1000)[np.newaxis, :]
    + np.linspace(-5j, 5j, 1000)[:, np.newaxis]
)
x

In [None]:
# j_37 evaluated on the complex grid `x`: real part (left) and imaginary part (right).
# origin="lower" plus an extent taken from `x` labels the axes with actual Re z / Im z
# values and draws Im z increasing upward. This assumes rows of `x` vary in Im z and
# columns in Re z, each increasing with the index, as in the grid built above.
y = scipy.special.spherical_jn(37, x)
extent = (x.real.min(), x.real.max(), x.imag.min(), x.imag.max())
fig, axs = plt.subplots(1, 2, figsize=(10, 4), constrained_layout=True)
for ax, part, label in zip(axs, (y.real, y.imag), ("Re", "Im")):
    img = ax.imshow(part, origin="lower", extent=extent, aspect="auto")
    ax.set(title=f"{label} $j_{{37}}(z)$", xlabel="Re $z$", ylabel="Im $z$")
    fig.colorbar(img, ax=ax)

In [None]:
# Spherical Bessel functions j_n(x) on the real axis, for n = 0..36.
# The previous version called plt.imshow(x, j_n(x)). imshow's second positional
# parameter is `cmap`, so that call raised. At this point `x` is also the complex
# 2-D grid, which cannot be drawn as curves, so a dedicated real grid is used here.
# The loop variable keeps the name `n` because a later cell reads it.
x_line = np.linspace(0, 10, 1000)
fig_jn, ax_jn = plt.subplots()
for n in range(37):
    ax_jn.plot(x_line, scipy.special.spherical_jn(n, x_line))
ax_jn.set(title=r"$j_n(x)$ for $n = 0, \dots, 36$", xlabel="$x$", ylabel="$j_n(x)$");

In [None]:
# Derivatives j_n'(x) of the spherical Bessel functions on the real axis, for n = 0..36.
# With `x` being the 1000x1000 complex grid, plt.plot(x, ...) would discard the
# imaginary parts (ComplexWarning) and draw 1000 lines per order. This cell
# therefore uses its own 1-D real grid instead.
# The loop variable keeps the name `n` because a later cell reads it.
x_line = np.linspace(0, 10, 1000)
fig_djn, ax_djn = plt.subplots()
for n in range(37):
    ax_djn.plot(x_line, scipy.special.spherical_jn(n, x_line, derivative=True))
ax_djn.set(title=r"$j_n'(x)$ for $n = 0, \dots, 36$", xlabel="$x$", ylabel="$j_n'(x)$");

In [None]:
# Element-wise j_0(x + 0.2i) - j_{n_cmp}(x) over the complex grid `x`.
# NOTE(review): this cell previously read `n` left over from the plotting loops
# above (n == 36 after `range(37)`), so its result depended on which cells ran last.
# The order is pinned here to that same value so the output is unchanged.
# Confirm whether j_0(x) (n_cmp = 0) was the intended second term.
n_cmp = 36
scipy.special.spherical_jn(0, x + 0.2j) - scipy.special.spherical_jn(n_cmp, x)

In [None]:
import jax
import jax.numpy as jnp
from functools import partial

In [None]:
# Enable 64-bit (float64 / int64) array types in JAX, which otherwise defaults to
# 32-bit. JAX recommends setting this at startup, before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

In [None]:
@partial(jax.jit, static_argnums=(1,))
def sb(z, n=10):
    """Spherical Bessel functions of the first kind j_0(z), ..., j_{n-1}(z).

    j_0 and j_1 are evaluated from their closed forms; higher orders come from
    the upward recurrence j_{l+1}(z) = (2l + 1) / z * j_l(z) - j_{l-1}(z).

    Args:
        z: argument; evaluated element-wise, so any array shape works.
        n: number of orders to return. Static under ``jax.jit``, so every
            distinct value triggers a fresh compilation. Must be >= 0.

    Returns:
        A Python list of length ``n``. Entry ``l`` is j_l(z) with the shape of ``z``.

    Raises:
        ValueError: if ``n`` is negative.

    Notes:
        - z == 0 is not special-cased: sin(z) / z evaluates to nan there.
        - NOTE(review): upward recurrence is known to lose accuracy once the
          order exceeds |z|. Check against ``scipy.special.spherical_jn`` before
          trusting high orders at small arguments.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    j0 = jnp.sin(z) / z
    j1 = jnp.sin(z) / z**2 - jnp.cos(z) / z
    res = [j0, j1]
    # res[order + 1] = (2*order + 1) / z * res[order] - res[order - 1]
    for order in range(1, n - 1):
        res.append((2 * order + 1) / z * res[order] - res[order - 1])
    # Slicing makes n = 0 and n = 1 return exactly n entries as well.
    return res[:n]

In [None]:
def generate_bessel_interpolators(energies, v0i, lmax,
                                  intpol_deg=3, n_steps=1000,
                                  max_vib_amp=2, max_geo_disp=2):
    """Tabulate spherical Bessel functions j_l on a grid of vibrational amplitudes.

    Work in progress: only the vibrational-amplitude table is built.
    Not yet implemented:
      - the geometric-displacement table, which would use ``v0i`` and
        ``max_geo_disp``;
      - any interpolation step (``intpol_deg``).

    Args:
        energies: 1-D array-like of energies.
        v0i: currently unused (reserved for the geometry table).
        lmax: number of orders; l runs over 0..lmax-1.
            NOTE(review): if lmax is meant to be inclusive, use ``np.arange(lmax + 1)``.
        intpol_deg: currently unused.
        n_steps: number of vibrational amplitudes sampled on [0, max_vib_amp].
        max_vib_amp: largest vibrational amplitude sampled.
        max_geo_disp: currently unused.

    Returns:
        numpy array of shape (lmax, n_steps, len(energies)) with
        entry [l, r, e] = j_l(-2/3 * vib_amps[r]**2 * energies[e]).
        NOTE(review): the physical meaning of this argument is not stated here — confirm.
    """
    # scipy evaluates on host arrays, so build the grids with numpy directly
    # instead of round-tripping jnp arrays through the host.
    energies = np.asarray(energies)
    all_l = np.arange(lmax)
    vib_amps = np.linspace(0, max_vib_amp, n_steps)
    # (n_steps, n_energies) argument grid; equivalent to einsum('r,e->re', ...).
    vib_args = np.outer(-2/3 * vib_amps**2, energies)
    # spherical_jn broadcasts the (lmax, 1, 1) orders against the
    # (n_steps, n_energies) arguments, giving (lmax, n_steps, n_energies).
    all_vib_amp = scipy.special.spherical_jn(all_l[:, np.newaxis, np.newaxis], vib_args)
    # TODO: geometry table (unreachable after the early return in the previous version):
    #   abs_displacements = jnp.linspace(0, max_geo_disp, n_steps)
    #   kappa = jnp.sqrt(2*energies - v0i)
    return all_vib_amp

In [None]:
# Shape of the Bessel table for three energies and lmax = 10.
generate_bessel_interpolators(energies=jnp.array([1, 2, 3]), v0i=1.0, lmax=10).shape

In [None]:
import scipy.special


# Orders and arguments broadcast pairwise: j_0(1.2), j_1(2.2), j_3(2.2).
scipy.special.spherical_jn(n=[0, 1, 3], z=[1.2, 2.2, 2.2])

In [None]:
# Orders 0, 1, 2 as an int32 array of shape (3, 1, 1), ready to broadcast.
jnp.arange(3, dtype=jnp.int32).reshape(-1, 1, 1)

In [None]:
# Two 10-point sample grids: `dr` on [0, 1] and `en` on [0, 10].
dr, en = jnp.linspace(0, 1, num=10), jnp.linspace(0, 10, num=10)

In [None]:
# Outer product dr[r] * en[e], replicated 37 times along a new leading axis.
jnp.broadcast_to(
    jnp.outer(dr, en),
    (37, 10, 10),
)

In [None]:
# Shape of the 'r,e->re' outer-product einsum of `dr` and `en`.
jnp.shape(jnp.einsum("r,e->re", dr, en))