In [None]:
from functools import lru_cache
import numpy as np
from jax.scipy.special import sph_harm
import jax
import jax.numpy as jnp
from functools import partial
import scipy
import timeit
from matplotlib import pyplot as plt

# Spherical Bessel functions from NeuralIL
from otftleed.spherical_bessel import functions

# Small numerical tolerances (guard values against division by zero).
# NOTE(review): neither constant appears to be referenced by the cells of this
# notebook — TODO confirm before removing.
EPS = 1e-8
bessel_EPS = 1e-6  # identical float to the literal 10e-7

from tqdm import tqdm

In [None]:
def bessel_scipy(z, n1, derivative=False):
    """SciPy reference: j_0(z) ... j_{n1-1}(z), or their derivatives if `derivative` is true."""
    orders = np.arange(n1)
    return scipy.special.spherical_jn(orders, z, derivative)

In [None]:
def bessel_tenserleed(Z, N1):
    """
    Spherical Bessel functions j_0(Z) ... j_{N1-1}(Z) from their power series.

    For each order L the value is Z**L / (2L+1)!! times the series
    sum_K t_K with t_0 = 1 and t_K = t_{K-1} * (-Z**2/2) / (K (2L+2K+1));
    terms are added until |t_K| <= 1e-16.

    Arguments:
        Z: scalar (real or complex) argument.
        N1: number of orders (Python int).
    Returns:
        Array of length N1 with entry L equal to j_L(Z).
    """
    half_z_sq = Z**2 / 2.0
    # Z**L / (2L+1)!!, advanced to the next order after each L.
    prefactor = 1.0 + 0j

    def series_sum(order):
        def keep_going(state):
            _, term, _ = state
            return abs(term) > 1e-16

        def add_next_term(state):
            k, term, total = state
            term = -half_z_sq / (k * (2 * order + 2 * k + 1)) * term
            return k + 1, term, total + term

        _, _, total = jax.lax.while_loop(
            keep_going, add_next_term, (1.0, 1.0 + 0j, 1.0 + 0j))
        return total

    values = []
    for order in range(N1):
        values.append(prefactor * series_sum(order))
        prefactor *= Z / (2.0 * (order + 1) + 1.0)
    return jnp.asarray(values)

In [None]:
import numpy as np
import jax.numpy as jnp
import jax
from jax import jit
from functools import partial

from jax.scipy.special import gamma

from jax import config
# Double precision is required: csphjy builds complex128 outputs and seeds its
# recurrence with 1e-100, which would underflow in float32.
jax.config.update("jax_enable_x64", True)


from jax._src.lax import lax
from jax._src.typing import Array, ArrayLike
from jax._src.numpy.util import (
   check_arraylike, promote_dtypes_inexact, _where)
from jax._src.custom_derivatives import custom_jvp

# Short alias for JAX's private helper lax._const(x, value), used below as
# _lax_const(x, 0) — presumably a constant with x's dtype, as in jnp.sinc's
# implementation (TODO confirm). Private API: it may move between JAX releases.
_lax_const = lax._const

def jint(n):
    """Round toward zero and cast to JAX's default integer type (Fortran INT())."""
    truncated = jnp.trunc(n)
    return truncated.astype('int')

def spb1(x: ArrayLike, /) -> Array:
    '''
    Spherical Bessel function of the first kind of order one, j_1(x).

    For |x| >= 0.1 the closed form sin(x)/x**2 - cos(x)/x is used. Below that
    both terms are ~1/x and cancel catastrophically (at |x| = 1e-8 no correct
    digit survives in float64), so the Maclaurin polynomial
        x/3 - x**3/30 + x**5/840 - x**7/45360 + x**9/3991680
    is used instead; its first omitted term is < 1e-18 relative for |x| < 0.1.
    The closed form is evaluated at a stand-in argument on the small-|x| side,
    so neither branch yields inf/nan values or gradients (the jnp.sinc
    "double where" pattern). Derivatives of every order are continuous at 0.

    Arguments:
        x: The (real or complex) argument; integer/bool inputs are promoted
           to the default float dtype.

    Returns:
        j_1(x), with the same shape as x.
    '''
    x = jnp.asarray(x)
    if not jnp.issubdtype(x.dtype, jnp.inexact):
        x = x.astype(float)

    small = jnp.abs(x) < 0.1
    safe_x = jnp.where(small, jnp.ones_like(x), x)
    closed_form = jnp.sin(safe_x) / safe_x**2 - jnp.cos(safe_x) / safe_x

    # Horner form of the Maclaurin polynomial above.
    x2 = x * x
    series = x * (1.0 / 3.0 + x2 * (-1.0 / 30.0 + x2 * (1.0 / 840.0
             + x2 * (-1.0 / 45360.0 + x2 / 3991680.0))))
    return jnp.where(small, series, closed_form)

@partial(custom_jvp, nondiff_argnums=(0,))
def _spb1_maclaurin(k, x):
  # compute the kth derivative of x -> sin(x)/x evaluated at zero (since we
  # compute the monomial term in the jvp rule)
  # TODO(mattjj): see https://github.com/google/jax/issues/10750
  if k % 2:
    return x * 0
  else:
    top = 1.j* (1.j/2)**k * jnp.sqrt(jnp.pi)
    bottom = (k-1) * gamma(2+k/2) * gamma(0.5*(k-1))
    return x * 0 + jnp.real(top / bottom)

@_spb1_maclaurin.defjvp
def _spb1_maclaurin_jvp(k, primals, tangents):
    '''JVP rule: d/dx of _spb1_maclaurin(k, x) is taken to be _spb1_maclaurin(k + 1, x).'''
    (x,) = primals
    (x_dot,) = tangents
    value = _spb1_maclaurin(k, x)
    slope = _spb1_maclaurin(k + 1, x)
    return value, slope * x_dot


def envj(n, x):
    '''
    Envelope function 0.5*log10(6.28*n) - n*log10(1.36*x/n) used by msta1/msta2.

    Presumably (as in Zhang & Jin's ENVJ) an estimate of -log10|J_n(x)|, so
    larger values mean smaller Bessel magnitudes — TODO confirm.
    '''
    ratio = 1.36 * x / n
    return 0.5 * jnp.log10(6.28 * n) - n * jnp.log10(ratio)

def msta1(x, mp):
    '''
    Starting order for Miller's backward recurrence (Zhang & Jin, MSTA1).

    Returns the integer order nn at which envj(nn, |x|) reaches mp. This is
    used as the starting point where |j_nn(x)| is about 10**(-mp). The root is
    found with at most 20 secant steps on integer orders; iteration stops as
    soon as a new estimate agrees with the previous one (|nn - n1| < 1).

    Arguments:
        x: the argument (its absolute value is used).
        mp: target value of the envelope function.
    Returns:
        nn: integer starting order.
    '''
    a0 = jnp.abs(x)
    n0 = jint(1.1 * a0) + 1
    f0 = envj(n0, a0) - mp
    n1 = n0 + 5
    f1 = envj(n1, a0) - mp
    # Secant step number 1 (first pass of the Fortran DO loop).
    nn = jint(n1 - (n1 - n0) / (1.0 - f0 / f1))

    def cond_fun(state):
        n0, f0, n1, f1, nn, it = state
        # Keep iterating while the estimate still moves and steps remain.
        return jnp.logical_and(jnp.abs(nn - n1) >= 1, it < 20)

    def body_fun(state):
        n0, f0, n1, f1, nn, it = state
        f = envj(nn, a0) - mp
        # Shift the secant bracket onto the newest estimate, then step again.
        n0, f0, n1, f1 = n1, f1, nn, f
        nn = jint(n1 - (n1 - n0) / (1.0 - f0 / f1))
        return n0, f0, n1, f1, nn, it + 1

    _, _, _, _, nn, _ = jax.lax.while_loop(
        cond_fun, body_fun, (n0, f0, n1, f1, nn, 1))
    return nn

def msta2(x, n, mp):
    '''
    Starting order for Miller's backward recurrence when order n is needed
    (Zhang & Jin, MSTA2), returned with a margin of 10 extra orders.

    The target for envj is mp when envj(n, |x|) <= mp/2, otherwise
    mp/2 + envj(n, |x|). The root is found with at most 20 secant steps on
    integer orders, stopping once two successive estimates agree.
    '''
    a0 = jnp.abs(x)
    half_mp = 0.5 * mp
    env_n = envj(n, a0)

    easy = env_n <= half_mp
    target = jnp.where(easy, mp * 1.0, half_mp + env_n)
    lo = jnp.where(easy, jint(1.1 * a0) + 1, jint(n))
    f_lo = envj(lo, a0) - target
    hi = lo + 5
    f_hi = envj(hi, a0) - target
    guess = jint(hi - (hi - lo) / (1.0 - f_lo / f_hi))

    def not_converged(state):
        lo, f_lo, hi, f_hi, guess, step = state
        return jnp.logical_and(jnp.abs(guess - hi) >= 1, step < 20)

    def secant_step(state):
        lo, f_lo, hi, f_hi, guess, step = state
        f_guess = envj(guess, a0) - target
        lo, f_lo, hi, f_hi = hi, f_hi, guess, f_guess
        guess = jint(hi - (hi - lo) / (1.0 - f_lo / f_hi))
        return lo, f_lo, hi, f_hi, guess, step + 1

    final_state = jax.lax.while_loop(
        not_converged, secant_step, (lo, f_lo, hi, f_hi, guess, 1))
    return final_state[4] + 10

@partial(custom_jvp, nondiff_argnums=(0,))
@partial(jit, static_argnums=0)
def csphjy(n, z):
    '''
    Spherical Bessel functions of the first kind j_0(z), ..., j_n(z).

    Follows the implementation of https://github.com/emsr/maths_burkhardt/blob/master/special_functions.f90.
    j_0 and j_1 come from closed forms. Higher orders come from Miller's
    downward recurrence, started at an order chosen by msta1/msta2 and
    normalised against whichever of j_0, j_1 is larger in magnitude.
    Derivatives are supplied by the custom JVP rule registered on this
    function.

    Arguments:
        n: Highest order (static Python int).
        z: The (real or complex) argument.
    Returns:
        csj: complex array of length n + 1 with csj[k] = j_k(z).
            - If msta1's estimate is below n, the recurrence only fills orders
              up to that estimate; higher entries stay 0.
            - For |z| < 1e-60 the z -> 0 limit [1, 0, ..., 0] is returned.
    '''
    cdtype = jax.dtypes.canonicalize_dtype(jnp.complex128)

    # Mirror the Fortran small-|z| special case (j_0 = 1, others 0). All
    # formulas below divide by z, so evaluate them at a harmless stand-in
    # argument there and select the exact limit at the end (otherwise z = 0
    # gives nan for n >= 2).
    tiny = jnp.abs(z) < 1.0e-60
    zs = jnp.where(tiny, jnp.ones_like(z), z)
    a0 = jnp.abs(zs)

    csj = jnp.zeros(n + 1, dtype=cdtype)
    csj = csj.at[0].set(jnp.sinc(zs / jnp.pi))
    csj = csj.at[1].set(spb1(zs))

    if n >= 2:
        csa = csj[0]
        csb = csj[1]
        m = msta1(a0, 200)

        m, nm = jax.lax.cond(m < n,
                             lambda _: (m, m),
                             lambda _: (msta2(a0, n, 15), n),
                             operand=None)

        # Arbitrary tiny seed; the overall scale is fixed by the normalisation.
        # NOTE(review): each step grows the values by roughly (2k+3)/|z|, so for
        # very small |z| and large m this may overflow float64 — TODO confirm for
        # the arguments/orders used here.
        cf0 = jnp.zeros_like(zs)
        cf1 = jnp.full_like(zs, 1.0e-100)
        cf = (2.0 * m + 3.0) * cf1 / zs - cf0

        def body_fun(kk, carry):
            k = m - kk  # k runs m, m-1, ..., 0
            cf, csj, cf0, cf1 = carry
            cf = (2.0 * k + 3.0) * cf1 / zs - cf0
            csj = jax.lax.cond(k <= nm, lambda c: c.at[k].set(cf), lambda c: c, csj)
            return cf, csj, cf1, cf

        cf, csj, cf0, cf1 = jax.lax.fori_loop(0, m + 1, body_fun, (cf, csj, cf0, cf1))

        # After the loop cf holds the k = 0 value and cf0 the k = 1 value.
        cs = jax.lax.cond(jnp.abs(csa) <= jnp.abs(csb),
                          lambda _: csb / cf0,
                          lambda _: csa / cf,
                          operand=None)

        csj = cs * csj

    limit_at_zero = jnp.zeros(n + 1, dtype=cdtype).at[0].set(1.0)
    return jnp.where(tiny, limit_at_zero, csj)

@csphjy.defjvp
def csphjy_jvp(n, primals, tangents):
    '''
    JVP rule for csphjy: returns (j_k(z))_k and (j_k'(z) * dz)_k, k = 0..n.

    Uses
        j_0'(z) = -j_1(z)
        j_k'(z) = j_{k-1}(z) - (k + 1) j_k(z) / z,   k >= 1.
    The first identity replaces (cos z - sin z / z) / z, which cancels
    catastrophically for small |z| and is nan at z = 0. The 1/z in the second
    is guarded for |z| < 1e-60, where the limits j_1'(0) = 1/3 and
    j_k'(0) = 0 (k >= 2) are used.
    '''
    z, = primals
    z_dot, = tangents
    csj = csphjy(n, z)

    # csj contains j_1 only when n >= 1; otherwise evaluate it directly.
    j1 = csj[1] if n >= 1 else spb1(z)

    near_zero = jnp.abs(z) < 1.0e-60
    safe_z = jnp.where(near_zero, jnp.ones_like(z), z)
    orders = jnp.arange(1, n + 1)
    higher = csj[:-1] - (orders + 1.0) * csj[1:] / safe_z
    higher = jnp.where(near_zero, jnp.where(orders == 1, 1.0 / 3.0, 0.0), higher)

    cdj = jnp.concatenate([jnp.atleast_1d(-j1), higher]).astype(csj.dtype)
    return csj, cdj * z_dot


def maketriples_all(mask, verbose=False):
    """
    Enumerate every hole triple i < j < k of `mask` with its two difference vectors.

    Arguments:
        mask: array whose first axis indexes holes.
        verbose: print intermediate arrays while building the result.

    Returns:
        tarray: int32 array of (i, j, k) index triples in lexicographic order
            (0-based).
        uv: array with one (mask[i] - mask[j], mask[j] - mask[k]) pair per triple.
    """
    count = mask.shape[0]
    triples = [
        (first, second, third)
        for first in range(count)
        for second in range(first + 1, count)
        for third in range(second + 1, count)
    ]
    tarray = np.array(triples).astype(np.int32)
    if verbose:
        print("tarray", tarray.shape, "\n", tarray)

    names = []
    vectors = []
    for row in tarray:
        a, b, c = row[0], row[1], row[2]
        names.append("{0:d}_{1:d}_{2:d}".format(a, b, c))
        if verbose:
            print('triple:', row, names[-1])
        vectors.append((mask[a] - mask[b], mask[b] - mask[c]))
    if verbose:
        print(tarray.shape, np.array(vectors).shape)
    return tarray, np.array(vectors)

def makebaselines(mask):
    """
    Enumerate every hole pair i < j of `mask` with its baseline vector.

    Arguments:
        mask: array whose first axis indexes holes, e.g. hole centres of
            shape (nholes, 2).

    Returns:
        barray: int32 array of (i, j) index pairs in lexicographic order
            (nholes*(nholes-1)/2 rows, e.g. 21 for 7 holes).
        baselines: array of mask[i] - mask[j], one entry per pair.
    """
    count = mask.shape[0]
    pairs = [(first, second) for first in range(count) for second in range(first + 1, count)]
    barray = np.array(pairs).astype(np.int32)
    baselines = np.array([mask[first] - mask[second] for first, second in pairs])
    return barray, baselines

# Check results

In [None]:
# Largest order prepared for the comparisons: the NeuralIL table holds orders
# 0 .. max_order, and the benchmark loops iterate over range(1, max_order).
max_order = 37

In [None]:
# Scalar complex test point (the next cell rebinds `z` to an array).
z = complex(2.0, 3.0)

In [None]:
# Test points on the diagonal Re z = Im z: a real grid from 1e-100 to 10,
# to which the same values are then added as imaginary parts.
z = np.linspace(1e-100, 10, 1000)
z = z + z * 1j

In [None]:
# Harmonix (csphjy): the order n is argument 0 and must be static; z is argument 1.
bessel_harmonix = jax.jit(csphjy, static_argnums=0)
bessel_harmonix_v = jax.jit(
    jax.vmap(csphjy, in_axes=(None, 0)),              # batch over z only
    static_argnums=0,
)
bessel_harmonix_jac = jax.jit(
    jax.jacrev(csphjy, argnums=1, holomorphic=True),  # d/dz of every order
    static_argnums=0,
)
bessel_harmonix_v_jac = jax.jit(
    jax.vmap(bessel_harmonix_jac, in_axes=(None, 0)),
    static_argnums=0,
)


In [None]:
# TensErLEED series expansion.
# Keep the undecorated function under its own name so this cell is idempotent:
# on a re-run `bessel_tenserleed` is already the jitted wrapper, and its
# `__wrapped__` attribute recovers the plain function instead of re-wrapping.
bessel_tenserleed_base = getattr(bessel_tenserleed, "__wrapped__", bessel_tenserleed)
bessel_tenserleed = jax.jit(bessel_tenserleed_base, static_argnums=1)
bessel_tenserleed_v = jax.jit(jax.vmap(bessel_tenserleed_base, in_axes=(0, None)), static_argnums=1)
bessel_tenserleed_jac = jax.jit(jax.jacfwd(bessel_tenserleed_base, argnums=0, holomorphic=True), static_argnums=1)
bessel_tenserleed_v_jac = jax.jit(jax.vmap(bessel_tenserleed_jac, in_axes=(0, None)), static_argnums=1)

In [None]:
def _generate_bessel_functions(l_max):
    """Return [jit(functions.create_j_l(l)) for l = 0 .. l_max], indexed by order."""
    return [jax.jit(functions.create_j_l(order)) for order in range(l_max + 1)]


# BESSEL_FUNCTIONS[l] is the jitted NeuralIL function for order l (0 .. max_order).
BESSEL_FUNCTIONS = _generate_bessel_functions(max_order)

In [None]:
# NeuralIL, single order: only order n1 is evaluated (not orders 0 .. n1).
def bessel_neuralil_base(z, n1):
    """NeuralIL spherical Bessel function of the single order n1, evaluated at z."""
    j_n1 = BESSEL_FUNCTIONS[n1]
    return j_n1(z)

bessel_neuralil = jax.jit(bessel_neuralil_base, static_argnums=1)
bessel_neuralil_v = jax.jit(jax.vmap(bessel_neuralil_base, in_axes=(0, None)), static_argnums=1)
bessel_neuralil_jac = jax.jit(jax.jacrev(bessel_neuralil_base, argnums=0, holomorphic=True), static_argnums=1)
bessel_neuralil_v_jac = jax.jit(jax.vmap(bessel_neuralil_jac, in_axes=(0, None)), static_argnums=1)

In [None]:
# NeuralIL, all orders: orders 0 .. n1-1 stacked along a new leading axis.
def bessel_neuralil_series_base(z, n1):
    """Evaluate the first n1 NeuralIL spherical Bessel functions at z and stack them."""
    values = []
    for order in range(n1):
        values.append(BESSEL_FUNCTIONS[order](z))
    return jnp.asarray(values)

bessel_neuralil_series = jax.jit(bessel_neuralil_series_base, static_argnums=1)
bessel_neuralil_series_v = jax.jit(jax.vmap(bessel_neuralil_series_base, in_axes=(0, None)), static_argnums=1)
bessel_neuralil_series_jac = jax.jit(jax.jacrev(bessel_neuralil_series_base, argnums=0, holomorphic=True), static_argnums=1)
bessel_neuralil_series_v_jac = jax.jit(jax.vmap(bessel_neuralil_series_jac, in_axes=(0, None)), static_argnums=1)

In [None]:
def custom_spherical_jn(n, z):
    """Order-n NeuralIL function at z, with n allowed to be traced (dispatch via lax.switch)."""
    return jax.lax.switch(n, BESSEL_FUNCTIONS, z)

def bessel_neuralil_select_base(z, n1):
    """Orders 0 .. n1-1 at z, obtained by vmapping custom_spherical_jn over the order."""
    orders = jnp.arange(n1)
    per_order = jax.vmap(custom_spherical_jn, (0, None))
    return per_order(orders, z)


bessel_neuralil_select = jax.jit(bessel_neuralil_select_base, static_argnums=1)
bessel_neuralil_select_v = jax.jit(jax.vmap(bessel_neuralil_select_base, in_axes=(0, None)), static_argnums=1)
bessel_neuralil_select_jac = jax.jit(jax.jacrev(bessel_neuralil_select_base, argnums=0, holomorphic=True), static_argnums=1)
bessel_neuralil_select_v_jac = jax.jit(jax.vmap(bessel_neuralil_select_jac, in_axes=(0, None)), static_argnums=1)

# Timings

In [None]:
from numpy import array
def time_functions(z, suffix, max_order, n_repeats):
    '''
    Mean wall-clock seconds per call of every implementation, for each order.

    Each wrapper is looked up in this notebook's global namespace as
    ``bessel_<impl><suffix>`` and called directly with ``z``. Timing the
    callables, instead of timeit strings built from ``repr(z)``, keeps the
    rebuilding of the argument array out of every repetition and passes
    ``z`` at full precision rather than its printed (rounded) form.

    Arguments:
        z: scalar or 1-D array argument, passed unchanged to every function.
        suffix: wrapper variant: "", "_v", "_jac" or "_v_jac". A suffix
            containing "jac" times scipy with derivative=True.
        max_order: orders 1 .. max_order - 1 are timed.
        n_repeats: calls per measurement; the mean per call is recorded.

    Returns:
        (scipy_time, harmonix_time, neural_il_series_time,
         neural_il_select_time, neural_il_single_time, tenserleed_time),
        each a list with one mean time per order.
    '''
    namespace = globals()
    derivative = 'jac' in suffix
    neuralil_series = namespace[f"bessel_neuralil_series{suffix}"]
    neuralil_select = namespace[f"bessel_neuralil_select{suffix}"]
    neuralil_single = namespace[f"bessel_neuralil{suffix}"]
    harmonix = namespace[f"bessel_harmonix{suffix}"]
    tenserleed = namespace[f"bessel_tenserleed{suffix}"]

    def mean_time(fn, *args):
        return timeit.timeit(lambda: fn(*args), number=n_repeats) / n_repeats

    def mean_time_jax(fn, *args):
        # JAX dispatches asynchronously; wait for the result so the
        # computation itself is timed, not just its dispatch.
        return timeit.timeit(lambda: fn(*args).block_until_ready(),
                             number=n_repeats) / n_repeats

    scipy_time = []
    harmonix_time = []
    neural_il_series_time = []
    neural_il_select_time = []
    neural_il_single_time = []
    tenserleed_time = []

    for order in tqdm(range(1, max_order)):
        scipy_time.append(mean_time(scipy.special.spherical_jn, order, z, derivative))
        neural_il_series_time.append(mean_time_jax(neuralil_series, z, order))
        neural_il_select_time.append(mean_time_jax(neuralil_select, z, order))
        neural_il_single_time.append(mean_time_jax(neuralil_single, z, order))
        harmonix_time.append(mean_time_jax(harmonix, order, z))
        tenserleed_time.append(mean_time_jax(tenserleed, z, order))

    return (scipy_time, harmonix_time, neural_il_series_time,
            neural_il_select_time, neural_il_single_time, tenserleed_time)

In [None]:
def plot_results(times, ax, title=""):
    '''
    Draw one panel of per-order timings on `ax` with a logarithmic y axis.

    `times` is the 6-tuple returned by time_functions. Its scipy entry only
    sets the x range (orders 1 .. len) and is not drawn.
    '''
    (scipy_time, harmonix_time, neural_il_series_time,
     neural_il_select_time, neural_il_single_time, tenserleed_time) = times
    orders = range(1, len(scipy_time) + 1)
    curves = (
        (harmonix_time, 'harmonix'),
        (neural_il_series_time, 'neural_il_series'),
        (neural_il_select_time, 'neural_il_select'),
        (neural_il_single_time, 'neural_il_single'),
        (tenserleed_time, 'tenserleed'),
    )
    for series, name in curves:
        ax.plot(orders, series, ls='-', marker='s', label=name)
    ax.set_yscale('log')
    ax.set_title(title)
    ax.legend()


In [None]:
z_scalar = complex(2.0, 3.0)
# 1000 points on the diagonal Re z = Im z, from 1e-8 to 2.
z_arr = np.linspace(1e-8, 2.0, 1000) * (0 + 1j) + np.linspace(1e-8, 2.0, 1000)

In [None]:
suffix = ""  # plain wrappers: scalar argument, function values

# single call per order: on a fresh kernel this includes JIT compilation
compile_times = time_functions(z=z_scalar, suffix=suffix, max_order=max_order, n_repeats=1)
# 200 calls per order: cached executables, i.e. steady-state execution
times = time_functions(z=z_scalar, suffix=suffix, max_order=max_order, n_repeats=200)

## Vectorized

In [None]:
suffix = "_v"  # vmapped wrappers: array argument, function values

# single call per order: on a fresh kernel this includes JIT compilation
compile_times_v = time_functions(z=z_arr, suffix=suffix, max_order=max_order, n_repeats=1)
# 200 calls per order: steady-state execution
times_v = time_functions(z=z_arr, suffix=suffix, max_order=max_order, n_repeats=200)

## Scalar-valued Derivative

In [None]:
suffix = "_jac"  # Jacobian wrappers: scalar argument, derivatives

# single call per order: on a fresh kernel this includes JIT compilation
compile_times_jac = time_functions(z=z_scalar, suffix=suffix, max_order=max_order, n_repeats=1)
# 200 calls per order: steady-state execution
times_jac = time_functions(z=z_scalar, suffix=suffix, max_order=max_order, n_repeats=200)

## Vector-valued Derivative

In [None]:
suffix = "_v_jac"  # vmapped Jacobian wrappers: array argument, derivatives

# single call per order: on a fresh kernel this includes JIT compilation
compile_times_v_jac = time_functions(z=z_arr, suffix=suffix, max_order=max_order, n_repeats=1)
# 200 calls per order: steady-state execution
times_v_jac = time_functions(z=z_arr, suffix=suffix, max_order=max_order, n_repeats=200)

# Plots

In [None]:
fig, axs = plt.subplots(2, 2, sharex=True, figsize=(12,8))
fig.suptitle('Compile time')
fig.supxlabel('Order l')
fig.supylabel('Time (s)')
# Panels in row-major order: scalar, vectorized, scalar derivative, vectorized derivative.
compile_panels = [
    (compile_times, "Scalar"),
    (compile_times_v, "Vectorized"),
    (compile_times_jac, "Scalar-valued Derivative"),
    (compile_times_v_jac, "Vectorized Derivative"),
]
for panel_ax, (panel_times, panel_title) in zip(axs.flat, compile_panels):
    plot_results(panel_times, panel_ax, title=panel_title)
fig.savefig("bessel_functions_compile_times.pdf")

In [None]:
fig, axs = plt.subplots(2, 2, sharex=True, figsize=(12,8))
fig.suptitle('Execution time')
fig.supxlabel('Order l')
fig.supylabel('Time (s)')
# Panels in row-major order: scalar, vector, scalar derivative, vector derivative.
execution_panels = [
    (times, "scalar valued argument"),
    (times_v, "vector valued argument"),
    (times_jac, "derivative for scalar valued argument"),
    (times_v_jac, "derivative for vector valued argument"),
]
for panel_ax, (panel_times, panel_title) in zip(axs.flat, execution_panels):
    plot_results(panel_times, panel_ax, title=panel_title)
fig.savefig("bessel_functions_execution_times.pdf")

In [None]:
# Per-order ratio: NeuralIL single-order time / scipy time (tuple entries -2 and 0).
np.asarray(times[-2]) / np.asarray(times[0])

# Accuracy

In [None]:
from numpy import array
def benchmark_functions(z, suffix, max_order):
    '''
    Maximum absolute deviation from scipy.special.spherical_jn, per order.

    The wrappers ``bessel_harmonix<suffix>``, ``bessel_neuralil_series<suffix>``
    and ``bessel_tenserleed<suffix>`` are looked up in this notebook's global
    namespace and called directly with ``z``.

    Arguments:
        z: scalar or 1-D array argument (must match the chosen suffix).
        suffix: wrapper variant: "", "_v", "_jac" or "_v_jac". A suffix
            containing "jac" compares against scipy's derivative.
        max_order: orders 1 .. max_order - 1 are compared.

    Returns:
        (harmonix_err, neuralil_err, tenserleed_err): lists with one maximum
        absolute error per order.
    '''
    namespace = globals()
    derivative = 'jac' in suffix
    harmonix = namespace[f"bessel_harmonix{suffix}"]
    neuralil_series = namespace[f"bessel_neuralil_series{suffix}"]
    tenserleed = namespace[f"bessel_tenserleed{suffix}"]

    def max_abs_error(values, order, reference):
        # Orders run along the last axis; convert to NumPy so an out-of-range
        # order raises instead of being clamped as JAX indexing would.
        return np.max(np.abs(np.asarray(values)[..., order] - reference))

    harmonix_err = []
    neuralil_err = []
    tenserleed_err = []

    for order in tqdm(range(1, max_order)):
        reference = scipy.special.spherical_jn(order, z, derivative)
        # csphjy(n, z) returns orders 0 .. n.
        harmonix_err.append(max_abs_error(harmonix(order, z), order, reference))
        # The NeuralIL series and TensErLEED helpers return orders 0 .. n1-1,
        # so ask for order + 1 terms to include `order` itself.
        neuralil_err.append(max_abs_error(neuralil_series(z, order + 1), order, reference))
        tenserleed_err.append(max_abs_error(tenserleed(z, order + 1), order, reference))
    return harmonix_err, neuralil_err, tenserleed_err


In [None]:
harmonix_err = []
neuralil_err = []
tenserleed_err = []


for order in tqdm(range(1, max_order)):
    # Compare function values. Do not derive this from `suffix`: that global
    # still holds whatever the last timing cell set ("_v_jac"), which would
    # make scipy return derivatives here.
    scipy_value = scipy.special.spherical_jn(order, z_arr, derivative=False)

    # The NeuralIL series and TensErLEED helpers return orders 0 .. n1-1, so
    # request order + 1 terms. JAX clamps out-of-range indices, so column
    # `order` of an `order`-wide result would silently be j_{order-1}.
    neuralil_err.append(
        np.max(abs(bessel_neuralil_series_v(z_arr, order + 1)[:, order] - scipy_value))
    )


    harmonix_err.append(
        np.max(abs(bessel_harmonix_v(order, z_arr)[:, order] - scipy_value))
    )

    tenserleed_err.append(
        np.max(abs(bessel_tenserleed_v(z_arr, order + 1)[:, order] - scipy_value))
    )


In [None]:
harmonix_err_jac = []
neuralil_err_jac = []
tenserleed_err_jac = []


for order in tqdm(range(1, max_order)):
    scipy_value = scipy.special.spherical_jn(order, z_arr, derivative=True)

    # As in the value comparison: the NeuralIL series and TensErLEED helpers
    # return orders 0 .. n1-1, so request order + 1 terms (JAX would clamp
    # an out-of-range column index to order - 1).
    neuralil_err_jac.append(
        np.max(abs(bessel_neuralil_series_v_jac(z_arr, order + 1)[:, order] - scipy_value))
    )


    harmonix_err_jac.append(
        np.max(abs(bessel_harmonix_v_jac(order, z_arr)[:, order] - scipy_value))
    )

    tenserleed_err_jac.append(
        np.max(abs(bessel_tenserleed_v_jac(z_arr, order + 1)[:, order] - scipy_value))
    )


In [None]:
fig, axs = plt.subplots(1, 2, sharex=True, figsize=(12,4))
fig.suptitle("Error vs. Scipy")

fig.supxlabel('Order l')
fig.supylabel('Deviation from Scipy')

# The error lists were filled for orders 1 .. max_order-1; plot against the
# order itself rather than the 0-based list index.
value_orders = np.arange(1, len(harmonix_err) + 1)
jac_orders = np.arange(1, len(harmonix_err_jac) + 1)

axs[0].set_title("Value")
axs[0].plot(value_orders, harmonix_err, label='Harmonix', ls='-', marker = 's')
axs[0].plot(value_orders, neuralil_err, label='NeuralIL', ls='-', marker = 's')
axs[0].plot(value_orders, tenserleed_err, label='TensErLEED like expansion', ls='--', marker = 'o')
axs[0].legend()
axs[0].set_yscale('log')

axs[1].set_title("Derivative")
axs[1].plot(jac_orders, harmonix_err_jac, label='Harmonix', ls='-', marker = 's')
axs[1].plot(jac_orders, neuralil_err_jac, label='NeuralIL', ls='-', marker = 's')
axs[1].plot(jac_orders, tenserleed_err_jac, label='TensErLEED like expansion', ls='--', marker = 'o')
axs[1].legend()
axs[1].set_yscale('log')

fig.savefig("bessel_functions_accuracy.pdf")

In [None]:
# Behaviour at and near z = 0. Only a cell's last expression is displayed,
# so return both results together.
bessel_harmonix(10, 0.0+0j), bessel_harmonix(10, 1e-8)

In [None]:
# SciPy reference values j_0 .. j_9 at a small complex argument.
bessel_scipy(z=1e-5 + 1e-5j, n1=10)

In [None]:
# Seed NumPy's global generator so A, and b (drawn from the same generator in
# a later cell), are reproducible across kernel restarts.
np.random.seed(0)
A = np.random.rand(121, 121)

In [None]:
# Summarise the 121x121 matrix instead of displaying it.
A.shape, A[:3, :3]

In [None]:
# Random weight vector; np.random.rand(121) is shorthand for this call.
b = np.random.random_sample((121,))

In [None]:
# A * b scales column j of A by b[j]; check it against the einsum contraction.
np.allclose((A * b) @ A, np.einsum('ij,j,jk->ik', A, b, A))

In [None]:
# Move the test data to JAX arrays for the jitted comparisons below.
A, b = jnp.asarray(A), jnp.asarray(b)

In [None]:
@jax.jit
def matrix(A, b):
    """A @ diag(b) @ A for 1-D b: scale A's columns by b, then multiply by A."""
    scaled = A * b
    return jnp.matmul(scaled, A)

@jax.jit
def ijjjk(A, b):
    """The same product as `matrix`, written as a single einsum contraction."""
    subscripts = 'ij,j,jk->ik'
    return jnp.einsum(subscripts, A, b, A, optimize=True)



In [None]:
%time matrix(A,b)
%time ijjjk(A,b)


In [None]:
%timeit -n 10000 matrix(A,b)
%timeit -n 10000 ijjjk(A,b)


In [None]:
einsum_result = ijjjk(A, b)
matmul_result = matrix(A, b)
np.allclose(einsum_result, matmul_result)