In [1]:
import jax
# Must run before any JAX arrays are created: el2_n below documents that it
# requires 64-bit floats.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from bulirsch import * 
from scipy.special import ellipe, ellipk, elliprc, elliprd, elliprf, elliprj
import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import quad

In [None]:
def el1_numerical(x, kc):
    """Reference value of Bulirsch's el1 computed by direct quadrature.

    el1(x, kc) = int_0^{arctan x} dphi / sqrt(cos^2 phi + kc^2 sin^2 phi)

    Args:
        x: scalar upper-limit parameter (the integral runs to arctan(x)).
        kc: scalar complementary modulus.

    Returns:
        The quadrature estimate as a Python float (error estimate discarded).
    """
    def integrand(phi):
        cos_sq = np.cos(phi) ** 2
        sin_sq = np.sin(phi) ** 2
        return np.sqrt(1 / (cos_sq + kc**2 * sin_sq))

    upper = np.arctan(x)
    value, _abserr = quad(integrand, 0, upper)
    return value

# Compare el1 with the quadrature reference: first sweeping x, then kc.
x = np.linspace(-5, 5, 100)
fig, ax = plt.subplots()
ax.plot(x, np.vectorize(el1_numerical)(x, 0.2), '-', alpha=0.5, linewidth=5, label='quad')
ax.plot(x, el1(x, 0.2), 'k--', label='el1')
ax.set(xlabel='x', title='el1(x, kc=0.2)')
ax.legend()

# kc is the swept argument here, so it belongs on the horizontal axis.
kc = np.linspace(0.01, 2, 100)
fig, ax = plt.subplots()
ax.plot(kc, np.vectorize(el1_numerical)(0.1, kc), '-', alpha=0.5, linewidth=5, label='quad')
ax.plot(kc, el1(0.1, kc), 'k--', label='el1')
ax.set(xlabel='kc', title='el1(x=0.1, kc)')
ax.legend();

In [None]:
def el2_numerical(x, kc, a, b):
    """Reference value of Bulirsch's el2 computed by direct quadrature.

    el2(x, kc, a, b) = int_0^{arctan x} (a + b tan^2 t)
                       / sqrt((1 + tan^2 t)(1 + kc^2 tan^2 t)) dt

    The integrand uses NumPy rather than jax.numpy: quad evaluates it at many
    scalar points, and a JAX op per evaluation only adds dispatch overhead
    (and hands JAX arrays back to SciPy).

    Args:
        x: scalar upper-limit parameter (the integral runs to arctan(x)).
        kc: scalar complementary modulus.
        a, b: scalar coefficients of the numerator.

    Returns:
        The quadrature estimate as a Python float.
    """
    def integrand(theta):
        tan_sq = np.tan(theta) ** 2
        return (a + b * tan_sq) / np.sqrt((1 + tan_sq) * (1 + kc**2 * tan_sq))

    return quad(integrand, 0, np.arctan(x))[0]

# Compare el2 with the quadrature reference: first sweeping x, then kc.
x = np.linspace(-5, 5, 100)
fig, ax = plt.subplots()
ax.plot(x, np.vectorize(el2_numerical)(x, 0.2, 0.1, 0.5), '-', alpha=0.5, linewidth=5, label='quad')
ax.plot(x, el2(x, 0.2, 0.1, 0.5), 'k--', label='el2')
ax.set(xlabel='x', title='el2(x, kc=0.2, a=0.1, b=0.5)')
ax.legend()

# kc is the swept argument here, so it belongs on the horizontal axis.
kc = np.linspace(0.01, 2, 100)
fig, ax = plt.subplots()
ax.plot(kc, np.vectorize(el2_numerical)(0.1, kc, 0.1, 0.5), '-', alpha=0.5, linewidth=5, label='quad')
ax.plot(kc, el2(0.1, kc, 0.1, 0.5), 'k--', label='el2')
ax.set(xlabel='kc', title='el2(x=0.1, kc, a=0.1, b=0.5)')
ax.legend();

In [None]:
def el3_numerical(x, kc, p):
    """Reference value of Bulirsch's el3 computed by direct quadrature.

    el3(x, kc, p) = int_0^{arctan x} dt
                    / ((cos^2 t + p sin^2 t) sqrt(cos^2 t + kc^2 sin^2 t))

    Args:
        x: scalar upper-limit parameter (the integral runs to arctan(x)).
        kc: scalar complementary modulus.
        p: scalar characteristic parameter.

    Returns:
        The quadrature estimate as a Python float (error estimate discarded).
    """
    def integrand(phi):
        cos_sq = np.cos(phi) ** 2
        sin_sq = np.sin(phi) ** 2
        return 1 / (cos_sq + p * sin_sq) / np.sqrt(cos_sq + kc * kc * sin_sq)

    result, _abserr = quad(integrand, 0, np.arctan(x))
    return result

# Compare el3 with the quadrature reference: first sweeping x, then kc.
x = np.linspace(-5, 5, 100)
fig, ax = plt.subplots()
ax.plot(x, np.vectorize(el3_numerical)(x, 0.2, 0.1), '-', alpha=0.5, linewidth=5, label='quad')
ax.plot(x, el3(x, 0.2, 0.1), 'k--', label='el3')
ax.set(xlabel='x', title='el3(x, kc=0.2, p=0.1)')
ax.legend()

# kc is the swept argument here, so it belongs on the horizontal axis.
kc = np.linspace(0.01, 2, 100)
fig, ax = plt.subplots()
ax.plot(kc, np.vectorize(el3_numerical)(0.1, kc, 0.1), '-', alpha=0.5, linewidth=5, label='quad')
ax.plot(kc, el3(0.1, kc, 0.1), 'k--', label='el3')
ax.set(xlabel='kc', title='el3(x=0.1, kc, p=0.1)')
ax.legend();

In [None]:
def cel_numerical(kc, p, a, b):
    """Reference value of Bulirsch's complete integral cel by quadrature.

    cel(kc, p, a, b) = int_0^{pi/2} (a cos^2 t + b sin^2 t)
                       / ((cos^2 t + p sin^2 t) sqrt(cos^2 t + kc^2 sin^2 t)) dt

    Args:
        kc: scalar complementary modulus.
        p: scalar characteristic parameter.
        a, b: scalar coefficients of the numerator.

    Returns:
        The quadrature estimate as a Python float.
    """
    def integrand(phi):
        c2, s2 = np.cos(phi) ** 2, np.sin(phi) ** 2
        weight = (a * c2 + b * s2) / (c2 + p * s2)
        return weight / np.sqrt(c2 + kc**2 * s2)

    half_pi = np.pi / 2
    return quad(integrand, 0, half_pi)[0]

# Compare cel with the quadrature reference: first sweeping kc, then p.
kc_sweep = np.linspace(-5, 5, 100)
fig, ax = plt.subplots()
ax.plot(kc_sweep, np.vectorize(cel_numerical)(kc_sweep, 0.2, 0.1, 0.5), '-', alpha=0.5, linewidth=5, label='quad')
ax.plot(kc_sweep, cel(kc_sweep, 0.2, 0.1, 0.5), 'k--', label='cel')
ax.set(xlabel='kc', title='cel(kc, p=0.2, a=0.1, b=0.5)')
ax.legend()

# The second argument of cel is p. Start above 0: at p = 0 the integrand
# behaves like b / cos^2 near pi/2 and the integral diverges.
p_sweep = np.linspace(0.01, 2, 100)
fig, ax = plt.subplots()
ax.plot(p_sweep, np.vectorize(cel_numerical)(0.1, p_sweep, 0.1, 0.5), '-', alpha=0.5, linewidth=5, label='quad')
ax.plot(p_sweep, cel(0.1, p_sweep, 0.1, 0.5), 'k--', label='cel')
ax.set(xlabel='p', title='cel(kc=0.1, p, a=0.1, b=0.5)')
ax.legend();

In [None]:
x = np.linspace(-5, 5, 5000)
%timeit elliprc(x, 0.2)
rc(x, 0.2)
%timeit rc(x, 0.2)
plt.plot(x, elliprc(x, 0.2), '-', alpha=0.5, linewidth=5)
plt.plot(x, rc(x, 0.2), 'k--')

plt.figure()
x = np.linspace(0, 2, 100)
plt.plot(kc, elliprc(0.1, x), '-', alpha=0.5, linewidth=5)
plt.plot(kc, rc(0.1, x), 'k--')

In [None]:
x = np.linspace(-5, 5, 100)
fig, ax = plt.subplots()
ax.plot(x, elliprd(x, 0.2, 0.1), '-', alpha=0.5, linewidth=5, label='scipy elliprd')
ax.plot(x, rd(x, 0.2, 0.1), 'k--', label='rd')
ax.set(xlabel='x', title='R_D(x, 0.2, 0.1)')
ax.legend()

# Sweep the second argument and plot against that same array, not a `kc`
# left over from an earlier cell.
y = np.linspace(0, 2, 100)
fig, ax = plt.subplots()
ax.plot(y, elliprd(0.1, y, 0.1), '-', alpha=0.5, linewidth=5, label='scipy elliprd')
ax.plot(y, rd(0.1, y, 0.1), 'k--', label='rd')
ax.set(xlabel='y', title='R_D(0.1, y, 0.1)')
ax.legend();

In [None]:
x = np.linspace(-5, 5, 100)
fig, ax = plt.subplots()
ax.plot(x, elliprf(x, 0.2, 0.1), '-', alpha=0.5, linewidth=5, label='scipy elliprf')
ax.plot(x, rf(x, 0.2, 0.1), 'k--', label='rf')
ax.set(xlabel='x', title='R_F(x, 0.2, 0.1)')
ax.legend()

# Sweep the second argument and plot against that same array, not a `kc`
# left over from an earlier cell.
y = np.linspace(0, 2, 100)
fig, ax = plt.subplots()
ax.plot(y, elliprf(0.1, y, 0.1), '-', alpha=0.5, linewidth=5, label='scipy elliprf')
ax.plot(y, rf(0.1, y, 0.1), 'k--', label='rf')
ax.set(xlabel='y', title='R_F(0.1, y, 0.1)')
ax.legend();

In [None]:
x = np.linspace(-5, 5, 100)
fig, ax = plt.subplots()
ax.plot(x, elliprj(x, 0.2, 0.1, 0.5), '-', alpha=0.5, linewidth=5, label='scipy elliprj')
ax.plot(x, rj(x, 0.2, 0.1, 0.5), 'k--', label='rj')
ax.set(xlabel='x', title='R_J(x, 0.2, 0.1, 0.5)')
ax.legend()

# Sweep the second argument and plot against that same array, not a `kc`
# left over from an earlier cell.
y = np.linspace(0, 2, 100)
fig, ax = plt.subplots()
ax.plot(y, elliprj(0.1, y, 0.1, 0.5), '-', alpha=0.5, linewidth=5, label='scipy elliprj')
ax.plot(y, rj(0.1, y, 0.1, 0.5), 'k--', label='rj')
ax.set(xlabel='y', title='R_J(0.1, y, 0.1, 0.5)')
ax.legend();

In [None]:
# Reverse-mode derivatives of el2 with respect to all four arguments.
# NOTE(review): if el2 is built on lax.while_loop (as el2_n is), reverse mode
# raises; the In[3] cell uses jax.jacfwd instead — confirm and drop one.
jax.grad(el2, argnums=(0, 1, 2, 3))(0.5, 0.1, 0.5, 0.4)

In [2]:
@jax.jit
@jnp.vectorize
def el2_n(x, kc, a, b):
    r"""JAX implementation of Bulirsch's el2 integral.

    Computed using the algorithm in Bulirsch, 1969b: https://doi.org/10.1007/BF02165405

    .. math::

       \operatorname{el2}\left(x,k_{c},a,b\right)=\int_{0}^{\arctan x}
       \frac{a+b\tan^{2}\theta}
            {\sqrt{(1+\tan^{2}\theta)(1+k_{c}^{2}\tan^{2}\theta)}}\,\mathrm{d}\theta.

    For :math:`k_c = 0` the integrand reduces to
    :math:`(a + b\tan^{2}\theta)\cos\theta`, and the closed form
    :math:`(a - b)\,x/\sqrt{1 + x^{2}} + b\,\operatorname{arsinh}x` is returned.

    Args:
      x: arraylike, real valued.
      kc: arraylike, real valued.
      a: arraylike, real valued.
      b: arraylike, real valued.

    Returns:
      The value of the integral el2.

    Notes:
      ``el2_n`` does not support complex-valued inputs.
      ``el2_n`` requires ``jax.config.update("jax_enable_x64", True)``: the
      convergence tolerance ``ca`` (~3e-8) is below float32 machine epsilon.
    """
    D = 15.0
    ca = 10 ** (-D / 2.0)  # relative convergence tolerance of the iteration
    cb = 10 ** (-D + 2.0)  # small value substituted when y becomes exactly 0

    z = a - b

    # Closed form for kc == 0. The iteration below cannot converge there:
    # kc stays 0 while m stays 1, so |g - kc| never shrinks.
    el2_kc0 = z * x / jnp.sqrt(1.0 + x**2) + b * jnp.arcsinh(x)

    # jnp.vectorize batches with vmap, and a batched lax.cond evaluates both
    # branches. Run the iteration with a placeholder kc where kc == 0 so the
    # while_loop always terminates, then select the right result at the end.
    kc_is_zero = kc == 0
    kc = jnp.where(kc_is_zero, 1.0, jnp.abs(kc))

    c = x**2
    dd = c + 1.0
    p = jnp.sqrt((1.0 + kc**2 * c) / dd)
    dd = x / dd
    c = dd * 0.5 / p
    y = jnp.abs(1 / x)

    def step(s):
        # One pass of Bulirsch's update. s['kc'] must already hold the current
        # kc; every right-hand side reads the previous pass's values.
        e = s['m'] * s['kc']
        g = e / s['p']
        b_new = s['ik'] * s['kc'] + s['b']
        dd_new = s['f'] * g + s['dd']
        p_new = g + s['p']
        m_new = s['kc'] + s['m']
        y_new = -(e / s['y']) + s['y']
        return {
            'l': s['l'],
            'b': b_new,
            'e': e,
            'g': s['m'],
            'dd': dd_new,
            'f': s['c'],
            'ik': s['a'],
            'p': p_new,
            'c': (dd_new / p_new + s['c']) * 0.5,
            'm': m_new,
            'a': (b_new / m_new + s['a']) * 0.5,
            'y': jnp.where(y_new == 0, jnp.sqrt(e) * cb, y_new),
            'kc': s['kc'],
        }

    def cond_fun(s):
        return jnp.abs(s['g'] - s['kc']) > ca * s['g']

    def body_fun(s):
        s = dict(s)
        s['kc'] = 2 * jnp.sqrt(s['e'])
        s['l'] = 2 * s['l'] + (s['y'] < 0)
        return step(s)

    initial = {
        'l': 0.0,
        'b': b,
        'dd': dd,
        'f': 0.0,
        'ik': a,
        'p': p,
        'c': c,
        'm': 1.0,
        'a': (b + a) * 0.5,
        'y': y,
        'kc': kc,
    }
    s = jax.lax.while_loop(cond_fun, body_fun, step(initial))

    l = s['l'] + (s['y'] < 0)
    e = (jnp.arctan(s['m'] / s['y']) + jnp.pi * l) * s['a'] / s['m']
    # The arctan term is built from |1/x|; restore the sign for x < 0 (el2 is
    # odd in x). The original draft noted this differs slightly from gefera
    # (ellip.f90 line 127) but matches the numerical integral.
    e = jnp.where(x < 0, -e, e) + s['c'] * z

    return jnp.where(kc_is_zero, el2_kc0, e)

In [3]:
# Forward-mode derivatives of el2 with respect to (x, kc, a, b).
# NOTE(review): the d/dkc entry (4.106) disagrees with el2_n's (-0.00189, In[4]).
# Differentiating the integrand under the integral sign gives
# -kc * int (a + b t^2) t^2 / (sqrt(1 + t^2) (1 + kc^2 t^2)^(3/2)) dtheta, t = tan(theta),
# which is about -1.9e-3 at these arguments, so el2's kc-derivative looks wrong — confirm.
jax.jacfwd(el2, argnums=(0, 1, 2, 3))(0.5, 0.1, 0.5, 0.4)

(Array(0.4287894, dtype=float64, weak_type=True),
 Array(4.10643656, dtype=float64, weak_type=True),
 Array(0.44704379, dtype=float64, weak_type=True),
 Array(0.03397374, dtype=float64, weak_type=True))

In [4]:
# Same forward-mode derivatives from the local el2_n, for comparison with el2.
jax.jacfwd(el2_n, argnums=(0, 1, 2, 3))(0.5, 0.1, 0.5, 0.4)

(Array(0.4287894, dtype=float64, weak_type=True),
 Array(-0.0018919, dtype=float64, weak_type=True),
 Array(0.44704379, dtype=float64, weak_type=True),
 Array(0.03397374, dtype=float64, weak_type=True))