In [None]:
# JIT-compilation - extreme case
import jax
import jax.numpy as jnp


# Two independent PRNG keys, one per coefficient array.
rng = jax.random.key(42)
a_key, b_key = jax.random.split(rng)

# jax.random.uniform documents `shape` as a tuple of ints, so pass (n,) rather
# than a bare int. With a in [5, 10) and b in [10, 20), b**2 >= 100 > 40 > 4*a,
# so for c = 1 the discriminant is positive and every root is real.
n_samples = 1_000_000
a = jax.random.uniform(a_key, shape=(n_samples,), minval=5, maxval=10)
b = jax.random.uniform(b_key, shape=(n_samples,), minval=10, maxval=20)


def quadratic_formula(a, b, c=1):
  """Return the "+" root of a*x**2 + b*x + c = 0, element-wise.

  Args:
    a: quadratic coefficients (array-like; broadcast against b and c).
    b: linear coefficients.
    c: constant term. Defaults to 1, so quadratic_formula(a, b) computes
      exactly what the original two-argument version did.

  Returns:
    (-b + sqrt(b**2 - 4*a*c)) / (2*a). Entries with a negative discriminant
    come out as NaN (real square root of a negative number).
  """
  return (-b + jnp.sqrt(b**2 - 4*a*c)) / (2*a)

fast_quadratic_formula = jax.jit(quadratic_formula)

print("\nOriginal:")
%timeit -n 10 quadratic_formula(a, b)


print("\nJIT (with compile-time):")
%timeit -n 1 fast_quadratic_formula(a, b).block_until_ready()

print("\nJIT (without compile-time):")
%timeit -n 10 fast_quadratic_formula(a, b).block_until_ready()



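# Why is the JIT version so much faster? Intermediate arrays: run op-by-op, every
# step of the formula (spelled out below) materializes a full-size temporary
# array, whereas under jax.jit XLA can fuse the element-wise steps.
# (Adapted from: https://github.com/jpivarski-talks/2023-07-11-scipy2023-tutorial-thinking-in-arrays)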
def pedantic_quadratic_formula(a, b, c=1):
  """Compute the "+" quadratic root one jnp call at a time.

  Same result as (-b + sqrt(b**2 - 4*a*c)) / (2*a), but each elementary
  operation is its own jnp call bound to its own temporary. When called
  outside jax.jit, each call produces a separate full-size array. ``del``
  drops each Python reference as soon as it is no longer needed, so the
  buffer can be released.

  Args:
    a: quadratic coefficients.
    b: linear coefficients.
    c: constant term. Defaults to 1, the value quadratic_formula(a, b) uses.
      Previously this argument was required but silently ignored.

  Returns:
    (-b + sqrt(b**2 - 4*a*c)) / (2*a), element-wise.
  """
  tmp1 = jnp.negative(b)            # -b
  tmp2 = jnp.square(b)              # b**2
  tmp3 = jnp.multiply(4, a)         # 4*a
  tmp4 = jnp.multiply(tmp3, c)      # 4*a*c
  del tmp3
  tmp5 = jnp.subtract(tmp2, tmp4)   # b**2 - 4*a*c
  del tmp2, tmp4
  tmp6 = jnp.sqrt(tmp5)             # sqrt(b**2 - 4*a*c)
  del tmp5
  tmp7 = jnp.add(tmp1, tmp6)        # -b + sqrt(...)
  del tmp1, tmp6
  tmp8 = jnp.multiply(2, a)         # 2*a
  return jnp.divide(tmp7, tmp8)     # (-b + sqrt(...)) / (2*a)


print("\nOriginal - written out:")
%timeit -n 10 pedantic_quadratic_formula(a, b, c)