In [291]:
import jax
from jax import numpy as jnp
from jax import config
config.update("jax_enable_x64", True)

# relative error will be "less in magnitude than r" 
r = 1.0e-15

@jax.jit 
def rf(x, y, z):
    
    xyz = jnp.array([x, y, z])
    A0 = jnp.sum(xyz) / 3.0
    v = jnp.max(jnp.abs(A0 - xyz))
    Q = (3 * r) ** (-1 / 6) * v

    cond = lambda s: s['f'] * Q > jnp.abs(s['An'])

    def body(s):

        xyz = s['xyz']
        lam = (
            jnp.sqrt(xyz[0]*xyz[1]) 
            + jnp.sqrt(xyz[0]*xyz[2]) 
            + jnp.sqrt(xyz[1]*xyz[2])
        )

        s['An'] = 0.25 * (s['An'] + lam)
        s['xyz'] = 0.25 * (s['xyz'] + lam)
        s['f'] = s['f'] * 0.25

        return s

    s = {'f': 1, 'An':A0, 'xyz':xyz}
    s = jax.lax.while_loop(cond, body, s)

    x = (A0 - x) / s['An'] * s['f']
    y = (A0 - y) / s['An'] * s['f']
    z = -(x + y)
    E2 = x * y - z * z
    E3 = x * y * z

    return (
        1 
        - 0.1 * E2 
        + E3 / 14 
        + E2 * E2 / 24 
        - 3 * E2 * E3 / 44
    ) / jnp.sqrt(s['An'])

@jax.jit 
def rf_unrolled(x, y, z):
    
    xyz = jnp.array([x, y, z])
    A0 = jnp.sum(xyz) / 3.0
    sqr_xyz = jnp.sqrt(xyz)
    An = A0

    for i in range(5):

        lam = (
            sqr_xyz[0] * (sqr_xyz[1] + sqr_xyz[2]) 
            + sqr_xyz[1] * sqr_xyz[2]
        )

        An = 0.25 * (An + lam)
        sqr_xyz = jnp.sqrt(0.25 * (sqr_xyz**2 + lam))

    lam = (
        sqr_xyz[0] * (sqr_xyz[1] + sqr_xyz[2]) 
        + sqr_xyz[1] * sqr_xyz[2]
    )

    m = 1 / (An + lam) * 0.0009765625

    x = (A0 - x) * m
    y = (A0 - y) * m
    z = -(x + y)
    E2 = x * y - z * z
    E3 = x * y * z

    return (
        1 
        #- 0.1 * E2 
        + E3 / 14 
        + E2 * (E2 / 24 - 3 * E3 / 44 - 0.1)
    ) / jnp.sqrt(An)

In [292]:
0.25 ** 5

0.0009765625

In [294]:
rf_unrolled(0.1, 0.2, 0.3)

Array(2.29880483, dtype=float64)

In [296]:
n = 100000
%timeit jax.vmap(rf)(0.1 * jnp.ones(n), 0.2 * jnp.ones(n), 0.3 * jnp.ones(n))
%timeit jax.vmap(rf_unrolled)(0.1 * jnp.ones(n), 0.2 * jnp.ones(n), 0.3 * jnp.ones(n))

11.6 ms ± 661 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
6.56 ms ± 957 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


2.2988048918835324


In [267]:
import numpy as np
x = np.random.rand(10000)
y = np.random.rand(10000)
z = np.random.rand(10000)
jnp.max(jax.vmap(rf)(x, y, z) - jax.vmap(rf_unrolled)(x, y, z))
#jnp.max(jax.vmap(rf, in_axes=(0, 0, None))(x, y, 0.000000001) - jax.vmap(rf_unrolled, in_axes=(0, 0, None))(x, y, 0.000000001))

Array(1.77635684e-15, dtype=float64)