In [53]:
from jax import jacobian
import jax.numpy as jnp
import jax.scipy as jsp
from diffrax import ODETerm, Dopri5, diffeqsolve, SaveAt, AbstractTerm
from jaxtyping import Array, Int, Float
import matplotlib.pyplot as plt

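# To run in float64 (may help the precision of the index-symbol quadrature),
# add `import jax` and uncomment the line below; it must run at startup,
# before any JAX arrays are created.
# jax.config.update("jax_enable_x64", True)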

### Define relevant constants

In [54]:
# Physical parameters.
# NOTE(review): omega_0 and e0 are not used in the cells shown here; their
# meaning is not documented in this notebook.
omega_0 = jnp.sqrt(0.44022)
e0 = 0.95289
viscosity = 1

# Quadrature for the index-symbol integrals over [0, infinity):
# truncate the range at U and use n_disc grid points per integral.
U = 1e4
n_disc = 1_000_000

# Time integration: step size passed as dt0, and final time.
step_size = 1e-3
T = 100

# Initial axes. eps splits the first two axes symmetrically; for any e and eps
# the product a_0[0] * a_0[1] * a_0[2] equals 1.
eps, e = 1e-3, 0.94
a_0 = jnp.array(
    [
        (1 - e**2) ** (-1 / 6) * (1 + eps),
        (1 - e**2) ** (-1 / 6) / (1 + eps),
        (1 - e**2) ** (1 / 3),
    ]
)

a_0

Array([1.432546  , 1.4296852 , 0.48825982], dtype=float32)

### Set up the numerical index symbols and their derivatives (Jacobian)

In [55]:
def index_symbol(a: Array) -> Array:
    """
    Input: a [shape (3,)]
    Output: Index symbols A [shape (3,)]
    """
    assert a.shape == (3,)
    u = jnp.expand_dims(jnp.linspace(0, U, n_disc), axis=1)

    y = (a[0] * a[1] * a[2]) / (
        (a**2 + u) * jnp.sqrt((a[0] ** 2 + u) * (a[1] ** 2 + u) * (a[2] ** 2 + u))
    )
    return jsp.integrate.trapezoid(y.T, u.flatten())


# index_symbol(jnp.array([1, 2, 4]))

In [56]:
def d_index_symbol(a: Array) -> Array:
    r"""
    Jacobian of the index symbols with respect to the axes.

    Input: a [shape (3,)]
    Output: derivative matrix dA [shape (3,3)]
    with dA[i,j] = \partial A_i / \partial a_j (0-indexed)

    Computed by automatic differentiation (jax.jacobian) through the
    trapezoidal quadrature in index_symbol. It therefore inherits that
    quadrature's truncation and discretization error.
    """
    return jacobian(index_symbol)(a)


# d_index_symbol(jnp.array([1, 2, 4]))

### Define coefficients

In [57]:
def _Sigma(sign: Int, a: Array) -> Float:
    """
    Private helper for Lambda and Omega.

    Returns Sigma(+1) = (Lambda + Omega)^2 when sign == 1 and
    Sigma(-1) = (Lambda - Omega)^2 when sign == -1, computed from the axes a
    (shape (3,)) and their index symbols.

    For sign == 1 the first term divides by a[0] - a[1], so it is singular
    when a[0] == a[1].
    """
    assert abs(sign) == 1
    A = index_symbol(a)
    a0, a1, a2 = a[0], a[1], a[2]
    A0, A1, A2 = A[0], A[1], A[2]

    # Term built from the first two axes.
    pair_term = (a0 * A0 - sign * a1 * A1) / (a0 - sign * a1)
    # Term carrying the third axis.
    axis3_term = sign * (a2**2) * A2 / (a0 * a1)
    return 2 * (pair_term + axis3_term)


def Lambda(a: Array) -> Float:
    """
    Vorticity parameter Lambda = (sqrt(Sigma(+1)) + sqrt(Sigma(-1))) / 2.
    """
    root_plus = jnp.sqrt(_Sigma(1, a))
    root_minus = jnp.sqrt(_Sigma(-1, a))
    return 0.5 * (root_plus + root_minus)


def Omega(a: Array) -> Float:
    """
    Angular velocity Omega = (sqrt(Sigma(+1)) - sqrt(Sigma(-1))) / 2.
    """
    root_plus = jnp.sqrt(_Sigma(1, a))
    root_minus = jnp.sqrt(_Sigma(-1, a))
    return 0.5 * (root_plus - root_minus)


# a = jnp.array([4, 2, 1])
# assert jnp.abs((Lambda(a) - Omega(a)) ** 2 - _Sigma(-1, a)) < eps
# assert jnp.abs((Lambda(a) + Omega(a)) ** 2 - _Sigma(1, a)) < eps

In [58]:
def _Q(alpha: Int, epsilon: Int, a: Array) -> Float:
    """
    Helper function for computing b_i coefficients (from Detweiler and Lindblom).

    Args:
        alpha: 1 or 2 (1-based). Selects axis a[alpha - 1]; the partner axis
            is the other one of a[0], a[1].
        epsilon: +1 or -1.
        a: axes, shape (3,).

    Returns:
        Scalar Q; the whole bracketed expression is divided by
        Lambda(a) - epsilon * Omega(a).
    """
    assert alpha in [1, 2]
    assert abs(epsilon) == 1
    A = index_symbol(a)
    dA = d_index_symbol(a)
    # Extra sign on the a[2]-dependent term, only for (alpha, epsilon) = (1, -1).
    sgn = -1 if alpha == 1 and epsilon == -1 else 1
    i = alpha - 1  # 0-based index of the selected axis
    j = i ^ 1  # 0-based index of the partner axis (swaps 0 <-> 1)
    # Lambda and Omega each run two full index-symbol quadratures, so the
    # combination is evaluated once rather than twice.
    lam_eps = Lambda(a) - epsilon * Omega(a)

    return (
        A[i]
        + a[i] * dA[i, i]
        + epsilon * a[j] * dA[j, i]
        - a[2] * dA[i, 2]
        - a[2] * a[j] * dA[j, 2] / a[i]
        - (a[0] * A[0] + epsilon * a[1] * A[1]) / (a[0] + epsilon * a[1])
        + (
            sgn
            * (a[0] + epsilon * a[1])
            * (3 * A[2] + a[2] * dA[2, 2] - a[i] * dA[2, i])
            * (a[2] ** 2)
            / ((a[i] ** 2) * a[j])
        )
        + 2 * (lam_eps**2)
    ) / lam_eps


# _Q(1, -1, a_0)  # jnp.array([4.0, 2.0, 1.0]))
# Comparing the Q values against Mathematica was a good way to tune U and n_disc.
# With a perturbation of 1e-5, U=1e4 and n_disc=1e6 gave values that were way off when epsilon=-1.

In [None]:
def _b0(a: Array) -> Float:
    """
    Coefficient b_0 for the first axis a[0] (scaled by viscosity in the ODE).

    Each _Q call runs several 1e6-point quadratures plus a Jacobian. The four
    distinct Q values are therefore computed once and reused, instead of
    re-evaluating _Q(2, +-1) in the numerator and the denominator.
    """
    q1_plus, q1_minus = _Q(1, 1, a), _Q(1, -1, a)
    q2_plus, q2_minus = _Q(2, 1, a), _Q(2, -1, a)
    denom = q1_plus * q2_minus + q1_minus * q2_plus
    return (
        -5
        * Lambda(a)
        * ((a[0] ** 2) - (a[1] ** 2))
        / ((a[0] ** 2) * (a[1] ** 2))
        * (q2_plus * (a[0] + a[1]) + q2_minus * (a[0] - a[1]))
        / denom
    )


def _b1(a: Array) -> Float:
    """
    Coefficient b_1 for the second axis a[1] (scaled by viscosity in the ODE).

    Each _Q call runs several 1e6-point quadratures plus a Jacobian. The four
    distinct Q values are therefore computed once and reused, instead of
    re-evaluating _Q(1, +-1) in the numerator and the denominator.
    """
    q1_plus, q1_minus = _Q(1, 1, a), _Q(1, -1, a)
    q2_plus, q2_minus = _Q(2, 1, a), _Q(2, -1, a)
    denom = q1_plus * q2_minus + q1_minus * q2_plus
    return (
        -5
        * Lambda(a)
        * ((a[0] ** 2) - (a[1] ** 2))
        / ((a[0] ** 2) * (a[1] ** 2))
        * (q1_minus * (a[0] - a[1]) - q1_plus * (a[0] + a[1]))
        / denom
    )


def _b2(a: Array) -> Float:
    """
    Coefficient b_2, fixed by b_0 and b_1 so that
    b_0 / a[0] + b_1 / a[1] + b_2 / a[2] = 0.
    Hence a[0] * a[1] * a[2] is conserved when da/dt is proportional to b.
    """
    first = _b0(a)
    second = _b1(a)
    return -a[2] * (first / a[0] + second / a[1])


def _b(a: Array) -> Array:
    """
    Coefficients b = (b_0, b_1, b_2), shape (3,); the ODE right-hand side is
    viscosity * b.

    b_0 and b_1 are evaluated once and reused to form b_2, using the same
    formula as _b2. Calling _b2 here would evaluate _b0 and _b1 a second time,
    and each of those runs many index-symbol quadratures. Doing so doubled the
    cost of every right-hand-side evaluation.
    """
    b0 = _b0(a)
    b1 = _b1(a)
    b2 = -a[2] * (b0 / a[0] + b1 / a[1])  # same formula as _b2; keep in sync
    return jnp.stack([b0, b1, b2])


# _b(a_0)

Array([ 2.8212243e-03, -2.8270718e-03,  3.9211145e-06], dtype=float32)

In [62]:
def coeffs(t, a, args):
    """
    Vector field for diffrax: da/dt = nu * b(a), where nu = args[0].

    The system is autonomous, so t is not used.
    """
    visc = args[0]
    rates = _b(a)
    return visc * rates


term = ODETerm(coeffs)
solver = Dopri5()

# diffeqsolve uses a constant step size by default, so this run needs about
# T / step_size steps. That is far more than diffrax's default
# max_steps=4096, which would abort the solve. Raise the cap to cover the whole
# interval, with a small margin for floating-point accumulation of t.
# NOTE(review): Dopri5 is an adaptive method. Passing
# stepsize_controller=PIDController(rtol=..., atol=...) would likely need far
# fewer (very expensive) right-hand-side evaluations.
n_steps = round(T / step_size)

solution = diffeqsolve(
    term,
    solver,
    t0=0.0,
    t1=T,
    dt0=step_size,
    y0=a_0,
    args=(viscosity,),
    saveat=SaveAt(ts=jnp.linspace(0, T, T * 100)),
    max_steps=n_steps + 16,
)

# Time series at the T * 100 save points
ts = solution.ts  # shape (T * 100,)
ys = solution.ys  # shape (T * 100, 3): one column per axis of a

KeyboardInterrupt: 

In [None]:
# Evolution of the first two axes (columns 0 and 1 of ys).
fig, ax = plt.subplots()
ax.plot(ts, ys[:, 0], label="a1")
ax.plot(ts, ys[:, 1], label="a2")
ax.set_xlabel("Time")
ax.set_ylabel("Semi-axis length")  # was the placeholder label "y"
ax.set_title("Solutions a1, a2 over time")
ax.legend()
plt.show()