In [None]:
from jax import numpy as jnp
from astrodynx.twobody.ivp import U3
import jax
from jax.typing import ArrayLike
from jax import Array


# Enable 64-bit floats/complex numbers globally. Per the JAX docs this flag should be
# set at startup, before arrays are created, so it stays at the top of the notebook.
jax.config.update("jax_enable_x64", True)
# Show which devices JAX will run on, as context for the benchmark timings below.
jax.devices()

In [None]:
def U3v2(chi: ArrayLike, alpha: ArrayLike) -> Array:
    """Evaluate U3(chi, alpha) with a single complex-arithmetic expression.

    For alpha != 0 this computes

        (sqrt(alpha) * chi - sin(sqrt(alpha) * chi)) / alpha / sqrt(alpha)

    with a complex square root. For alpha > 0 that is the trigonometric form.
    For alpha < 0, sqrt(alpha) is imaginary and the expression reduces to
    (sinh(s * chi) - s * chi) / s**3 with s = sqrt(-alpha). For alpha == 0
    the limit chi**3 / 6 is returned.

    Args:
        chi: Universal variable; broadcasts against ``alpha``.
        alpha: Parameter selecting the trigonometric (>0), hyperbolic (<0)
            or cubic (==0) regime.

    Returns:
        Real-valued array (for real inputs) with the broadcast shape of
        ``chi`` and ``alpha``.

    Note:
        Only the exact ``alpha == 0`` case is special-cased. For
        0 < |alpha * chi**2| << 1 the difference ``x - sin(x)`` loses
        precision to cancellation.
    """
    is_zero = alpha == 0.0
    # Double-where: give the complex branch a harmless alpha where alpha == 0.
    # With a single jnp.where the unselected branch still evaluates 0/0, and its
    # NaN leaks into gradients (w.r.t. both chi and alpha) even though the
    # forward value is masked out.
    safe_alpha = jnp.where(is_zero, 1.0, alpha)
    alpha_c = jax.lax.complex(safe_alpha, 0.0)
    sqrt_alpha = jnp.sqrt(alpha_c)
    x = sqrt_alpha * chi
    # The imaginary parts cancel analytically; keep only the real part.
    general = jax.lax.real((x - jnp.sin(x)) / alpha_c / sqrt_alpha)
    return jnp.where(is_zero, chi**3 / 6.0, general)


# --- JIT compiled version ---
# Wrap both implementations with jax.jit so the tests and timings below run
# XLA-compiled code rather than op-by-op dispatch.
U3_jit, U3v2_jit = jax.jit(U3), jax.jit(U3v2)

In [None]:
# --- Test code ---
def run_test(seed: int = 0) -> bool:
    """Check that ``U3_jit`` and ``U3v2_jit`` agree on a random (chi, alpha) grid.

    chi is sampled uniformly from [-10, 10). alpha combines 50 values from
    [0.1, 20), 49 values from [-20, -0.1) and one exact 0.0, so the alpha > 0,
    alpha < 0 and alpha == 0 cases are all exercised. Every (chi, alpha) pair
    is compared with ``jnp.allclose(atol=1e-6, rtol=1e-6)``. On failure, the
    location and values of the largest absolute difference are printed.

    NOTE(review): alpha values with 0 < |alpha| < 0.1 are not sampled, so
    near-zero (cancellation-prone) inputs are not covered by this check.

    Args:
        seed: Seed for ``jax.random.PRNGKey``; the default 0 keeps runs reproducible.

    Returns:
        True if both implementations agree within tolerance everywhere.
    """
    # One independent subkey per sample: JAX keys must never be reused, and
    # sampling from a key and then splitting that same key counts as reuse.
    key_chi, key_pos, key_neg = jax.random.split(jax.random.PRNGKey(seed), 3)

    # chi can be any value
    chi_values = jax.random.uniform(key_chi, shape=(100,), minval=-10.0, maxval=10.0)

    # alpha values need to cover >0, <0, and =0 cases
    alpha_pos = jax.random.uniform(key_pos, shape=(50,), minval=0.1, maxval=20.0)
    alpha_neg = jax.random.uniform(key_neg, shape=(49,), minval=-20.0, maxval=-0.1)
    alpha_zero = jnp.array([0.0])
    alpha_values = jnp.concatenate([alpha_pos, alpha_neg, alpha_zero])

    # To test all combinations of chi and alpha, build a grid over both.
    chi_grid, alpha_grid = jnp.meshgrid(chi_values, alpha_values)

    print(f"Testing with chi grid shape: {chi_grid.shape}")
    print(f"Testing with alpha grid shape: {alpha_grid.shape}")

    result1 = U3_jit(chi_grid, alpha_grid)
    result2 = U3v2_jit(chi_grid, alpha_grid)

    # Element-wise comparison within tolerance (NaNs count as mismatches).
    are_they_close = bool(jnp.allclose(result1, result2, atol=1e-6, rtol=1e-6))

    if are_they_close:
        print(
            "\n✅ Test passed: Results from both functions are consistent within tolerance."
        )
    else:
        print("\n❌ Test failed: Results from both functions are inconsistent.")
        # Locate the worst mismatch for debugging.
        diff = jnp.abs(result1 - result2)
        max_diff = jnp.max(diff)
        print(f"Maximum difference: {max_diff}")

        # Convert to plain ints so the printed index is readable.
        idx = tuple(int(i) for i in jnp.unravel_index(jnp.argmax(diff), diff.shape))
        print(f"Maximum difference occurs at index {idx}")
        print(f"  - chi: {chi_grid[idx]}")
        print(f"  - alpha: {alpha_grid[idx]}")
        print(f"  - U3 result: {result1[idx]}")
        print(f"  - U3v2 result: {result2[idx]}")

    return are_they_close


run_test()

In [None]:
# --- Prepare benchmark data ---
key = jax.random.PRNGKey(0)
key_chi, key_alpha = jax.random.split(key)
# Give each alpha case its own subkey. Drawing all three arrays from key_alpha
# reuses one key (alpha_neg would be exactly -alpha_pos).
key_pos, key_neg, key_mix = jax.random.split(key_alpha, 3)

# Create large arrays with 100,000 elements
N = 100_000
# chi uniform on [0, 2*pi)
chi = jax.random.uniform(key_chi, (N,)) * jnp.pi * 2

# Case 1: alpha >= 0, uniform on [0, 1)
alpha_pos = jax.random.uniform(key_pos, (N,))

# Case 2: alpha <= 0, uniform on (-1, 0]
alpha_neg = -jax.random.uniform(key_neg, (N,))

# Case 3: Mixed positive and negative alpha (standard normal)
alpha_mix = jax.random.normal(key_mix, (N,))

# Warm-up: jax.jit compiles separately for each input shape/dtype, so warm up
# with the same (N,)-shaped arrays that are timed below. All three alpha cases
# share shape and dtype, so one call per function covers them.
for fn in (U3_jit, U3v2_jit):
    fn(chi, alpha_pos).block_until_ready()

In [None]:
print("Timing U3 and U3v2 for positive alpha...")
# Reference implementation. block_until_ready() makes each timed call wait for
# JAX's asynchronous dispatch to finish, so the actual computation is measured.
%timeit U3_jit(chi, alpha_pos).block_until_ready()

# Complex-arithmetic variant on the same inputs.
%timeit U3v2_jit(chi, alpha_pos).block_until_ready()

In [None]:
print("Timing U3 and U3v2 for negative alpha...")
# Reference implementation. block_until_ready() makes each timed call wait for
# JAX's asynchronous dispatch to finish, so the actual computation is measured.
%timeit U3_jit(chi, alpha_neg).block_until_ready()

# Complex-arithmetic variant on the same inputs.
%timeit U3v2_jit(chi, alpha_neg).block_until_ready()

In [None]:
print("Timing U3 and U3v2 for mixed alpha...")
# Reference implementation. block_until_ready() makes each timed call wait for
# JAX's asynchronous dispatch to finish, so the actual computation is measured.
%timeit U3_jit(chi, alpha_mix).block_until_ready()

# Complex-arithmetic variant on the same mixed-sign inputs.
%timeit U3v2_jit(chi, alpha_mix).block_until_ready()