In [24]:
import jax.numpy as jnp
import jax.scipy as jsp
from jaxtyping import Array
from diffrax import ODETerm, Dopri5, diffeqsolve, SaveAt
import matplotlib.pyplot as plt

### Define relevant constants

In [None]:
# Physical parameters.
# NOTE(review): the physical meaning of omega_0 and e0 is not stated here — confirm and document.
omega_0 = jnp.sqrt(0.44022)
e0 = 0.95289
viscosity = 1
T = 100  # presumably the final time for a later ODE solve (diffrax is imported) — TODO confirm

# Quadrature settings read by A(): the integral over [0, inf) is truncated
# at U and sampled on n_disc uniformly spaced points.
U = 10_000.0
n_disc = 1_000_000

# Initial shape vector. The relative perturbation eps splits the first two
# entries; the exponents (-1/6, -1/6, 1/3) make the product of the three
# entries equal to 1 for any e.
e = 0.94
eps = 1e-5
a_0 = jnp.array([
    (1 - e**2) ** (-1 / 6) * (1 + eps),
    (1 - e**2) ** (-1 / 6) / (1 + eps),
    (1 - e**2) ** (1 / 3),
])

### Set up the numerical index symbols and their derivative

In [26]:
def A(a: Array) -> Array:
    assert a.shape == (3,)
    u = jnp.expand_dims(jnp.linspace(0, U, n_disc), axis=1)

    y = (a[0] * a[1] * a[2]) / (
        (a**2 + u) * jnp.sqrt((a[0] ** 2 + u) * (a[1] ** 2 + u) * (a[2] ** 2 + u))
    )
    return jsp.integrate.trapezoid(y.T, u.flatten())

In [29]:
# Time one evaluation of A on a = (1, 2, 4). Over [0, inf) the three integrands
# of A sum to an exact derivative, giving A_0 + A_1 + A_2 = 2. The printed
# values should therefore sum to ~2, up to truncation and quadrature error.
%time A(jnp.array([1, 2, 4]))

CPU times: user 11.3 ms, sys: 28.3 ms, total: 39.7 ms
Wall time: 13 ms


Array([1.2057469 , 0.56955755, 0.22469594], dtype=float32)