In [None]:
%load_ext autoreload
%autoreload 2 

from jax import numpy as jnp
import jax
from jax.typing import ArrayLike, DTypeLike
from jax import Array


jax.config.update("jax_enable_x64", True)
jax.devices()

[CudaDevice(id=0), CudaDevice(id=1)]

In [None]:
def U3v1(chi: ArrayLike, alpha: DTypeLike) -> Array:
    return jax.lax.cond(
        alpha > 0,
        lambda: (jnp.sqrt(alpha) * chi - jnp.sin(jnp.sqrt(alpha) * chi))
        / alpha
        / jnp.sqrt(alpha),
        lambda: jax.lax.cond(
            alpha < 0,
            lambda: (jnp.sqrt(-alpha) * chi - jnp.sinh(jnp.sqrt(-alpha) * chi))
            / alpha
            / jnp.sqrt(-alpha),
            lambda: chi**3 / 6.0,
        ),
    )


def U3v2(chi: ArrayLike, alpha: DTypeLike) -> Array:
    alpha_c = jax.lax.complex(alpha, 0.0)
    return jax.lax.cond(
        alpha == 0.0,
        lambda _: chi**3 / 6.0,
        lambda _: jax.lax.real(
            (jnp.sqrt(alpha_c) * chi - jnp.sin(jnp.sqrt(alpha_c) * chi))
            / alpha_c
            / jnp.sqrt(alpha_c)
        ),
        alpha,
    )


def U3v3(chi: ArrayLike, alpha: DTypeLike) -> Array:
    return jnp.where(
        alpha > 0,
        (jnp.sqrt(alpha) * chi - jnp.sin(jnp.sqrt(alpha) * chi))
        / alpha
        / jnp.sqrt(alpha),
        jnp.where(
            alpha < 0,
            (jnp.sqrt(-alpha) * chi - jnp.sinh(jnp.sqrt(-alpha) * chi))
            / alpha
            / jnp.sqrt(-alpha),
            chi**3 / 6.0,
        ),
    )


U3v1_jit = jax.jit(U3v1)
U3v2_jit = jax.jit(U3v2)
U3v3_jit = jax.jit(U3v3)

In [24]:
chi = jnp.pi
alpha = jnp.array([0.0, 1.0, -1.0])

result1 = jax.vmap(U3v1_jit, in_axes=(None, 0))(chi, alpha)
print(f"Result 1: {result1}")
result2 = jax.vmap(U3v2_jit, in_axes=(None, 0))(chi, alpha)
print(f"Result 2: {result2}")
result3 = U3v3_jit(chi, alpha)
print(f"Result 3: {result3}")


grad1 = jax.vmap(jax.grad(U3v1), in_axes=(None, 0))(chi, alpha)
print(f"Grad 1: {grad1}")
grad2 = jax.vmap(jax.grad(U3v2), in_axes=(None, 0))(chi, alpha)
print(f"Grad 2: {grad2}")
grad3 = jax.vmap(jax.grad(U3v3), in_axes=(None, 0))(chi, alpha)
print(f"Grad 3: {grad3}")

# Results show that U3v3 does not work with jax.grad

Result 1: [5.16771278 3.14159265 8.4071467 ]
Result 2: [5.16771278 3.14159265 8.4071467 ]
Result 3: [5.16771278 3.14159265 8.4071467 ]
Grad 1: [ 4.9348022   2.         10.59195328]
Grad 2: [ 4.9348022   2.         10.59195328]
Grad 3: [nan nan nan]


In [None]:
# --- Prepare benchmark data ---
key = jax.random.PRNGKey(0)
key_chi, key_alpha = jax.random.split(key)

# Create a large array with 100,000 elements
N = 1_000_00
chi = jax.random.uniform(key_chi, (N,)) * jnp.pi * 2

# Case 1: All alpha > 0
alpha_pos = jax.random.uniform(key_alpha, (N,))

# Case 2: All alpha < 0
alpha_neg = -jax.random.uniform(key_alpha, (N,))

# Case 3: Mixed positive and negative alpha
alpha_mix = jax.random.normal(key_alpha, (N,))

# Ensure JIT compilation is complete (warm-up)
jax.vmap(U3v1_jit, in_axes=(0, 0))(chi[:10], alpha_pos[:10]).block_until_ready()
jax.vmap(U3v2_jit, in_axes=(0, 0))(chi[:10], alpha_pos[:10]).block_until_ready()

Array([3.24633870e+01, 1.07914181e+00, 6.30361776e+00, 1.30161151e+01,
       8.48328728e-01, 1.55862004e+00, 3.84984327e+00, 1.15542347e+00,
       8.34801777e-03, 4.98232697e+00], dtype=float64)

In [42]:
# --- Test cosistency ---
result1 = jax.vmap(U3v1_jit, in_axes=(0, 0))(chi, alpha_mix)
result2 = jax.vmap(U3v2_jit, in_axes=(0, 0))(chi, alpha_mix)
if jnp.allclose(result1, result2):
    print("✅ U3v1 and U3v2 are consistent for mixed alpha.")
else:
    print("❌ U3v1 and U3v2 are Inconsistent results for mixed alpha.")

✅ U3v1 and U3v2 are consistent for mixed alpha.


In [44]:
print("Timing U3v1 for positive alpha...")
%timeit jax.vmap(U3v1_jit, in_axes=(0, 0))(chi, alpha_pos).block_until_ready()
print("Timing U3v2 for positive alpha...")
%timeit jax.vmap(U3v2_jit, in_axes=(0, 0))(chi, alpha_pos).block_until_ready()

Timing U3v1 for positive alpha...
439 μs ± 4.43 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Timing U3v2 for positive alpha...
450 μs ± 7.78 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [45]:
print("Timing U3v1 for negative alpha...")
%timeit jax.vmap(U3v1_jit, in_axes=(0, 0))(chi, alpha_neg).block_until_ready()
print("Timing U3v2 for negative alpha...")
%timeit jax.vmap(U3v2_jit, in_axes=(0, 0))(chi, alpha_neg).block_until_ready()

Timing U3v1 for negative alpha...
417 μs ± 4.96 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Timing U3v2 for negative alpha...
418 μs ± 2.54 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [46]:
print("Timing U3v1 for negative alpha...")
%timeit jax.vmap(U3v1_jit, in_axes=(0, 0))(chi, alpha_mix).block_until_ready()
print("Timing U3v2 for negative alpha...")
%timeit jax.vmap(U3v2_jit, in_axes=(0, 0))(chi, alpha_mix).block_until_ready()

Timing U3v1 for negative alpha...
415 μs ± 1.71 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Timing U3v2 for negative alpha...
415 μs ± 5.24 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
