In [1]:
# JAX
import jax
import jax.numpy as jnp

In [2]:
# Immutability: JAX arrays cannot be modified in place.
array = jnp.ones(10)

# In-place assignment raises a TypeError; catch it so "Run All" keeps going
# while still showing the error message to the reader.
try:
  array[0] = 0
except TypeError as err:
  print(f"array[0] = 0 fails: {err}")

# Do this instead: `.at[idx].set(value)` returns an updated copy and leaves
# `array` itself unchanged.
modified_array = array.at[0].set(0)
modified_array

In [3]:
# Reproducible random inputs. One PRNG key is split into two independent
# keys so the draws for `a` and `b` do not reuse a key.
N_SAMPLES = 1_000_000

rng = jax.random.key(42)
a_key, b_key = jax.random.split(rng)

# `shape` is documented as a sequence of ints, so pass a 1-tuple.
# With a in [5, 10) and b in [10, 20): b**2 - 4*a >= 10**2 - 4*10 = 60 > 0,
# so the square root in the quadratic formula below is always real.
a = jax.random.uniform(a_key, shape=(N_SAMPLES,), minval=5, maxval=10)
b = jax.random.uniform(b_key, shape=(N_SAMPLES,), minval=10, maxval=20)


def quadratic_formula(a, b):
  """The "+" root of a*x**2 + b*x + 1 = 0, computed element-wise.

  This is the quadratic formula with the constant term fixed to c = 1:
      x = (-b + sqrt(b**2 - 4*a)) / (2*a)

  Args:
    a: quadratic coefficient (scalar or array); a == 0 divides by zero.
    b: linear coefficient, broadcastable against `a`.

  Returns:
    The root, with the broadcast shape of `a` and `b`. Where
    b**2 - 4*a < 0 the square root, and therefore the result, is NaN.
  """
  return (-b + jnp.sqrt(b**2 - 4*a)) / (2*a)


# Internally, JAX represents a traced program as a "jaxpr": an intermediate
# representation (IR) modelled on functional languages (somewhat Lisp-like).
# It is a single lambda with no mutation, and each line has the form
# `result = operator arg1 arg2 ...`.
# `make_jaxpr` traces `quadratic_formula` on `a` and `b` and returns that IR.
quadratic_jaxpr = jax.make_jaxpr(quadratic_formula)(a, b)
quadratic_jaxpr

{ lambda ; a:f32[1000000] b:f32[1000000]. let
    c:f32[1000000] = neg b
    d:f32[1000000] = integer_pow[y=2] b
    e:f32[1000000] = mul 4.0 a
    f:f32[1000000] = sub d e
    g:f32[1000000] = sqrt f
    h:f32[1000000] = add c g
    i:f32[1000000] = mul 2.0 a
    j:f32[1000000] = div h i
  in (j,) }

In [4]:
# The IR can itself be transformed, e.g. by `jax.jit`. Tracing the jitted
# function does not produce a new flat list of primitives. Instead, the whole
# original jaxpr appears nested inside a single `pjit` equation, as the
# output below shows.
fast_quadratic_formula = jax.jit(quadratic_formula)
jax.make_jaxpr(fast_quadratic_formula)(a, b)

{ lambda ; a:f32[1000000] b:f32[1000000]. let
    c:f32[1000000] = pjit[
      name=quadratic_formula
      jaxpr={ lambda ; d:f32[1000000] e:f32[1000000]. let
          f:f32[1000000] = neg e
          g:f32[1000000] = integer_pow[y=2] e
          h:f32[1000000] = mul 4.0 d
          i:f32[1000000] = sub g h
          j:f32[1000000] = sqrt i
          k:f32[1000000] = add f j
          l:f32[1000000] = mul 2.0 d
          m:f32[1000000] = div k l
        in (m,) }
    ] a b
  in (c,) }

In [5]:
print("\nOriginal:")
%timeit -n 10 quadratic_formula(a, b)


print("\nJIT (with compile-time):")
%timeit -n 1 fast_quadratic_formula(a, b).block_until_ready()

print("\nJIT (without compile-time):")
%timeit -n 10 fast_quadratic_formula(a, b).block_until_ready()


# Why is the JIT version faster? -> intermediate arrays. In eager mode every
# operation below runs on its own and materializes a full-size result array.
# Under `jax.jit`, XLA can fuse the element-wise steps and avoid most of
# those intermediates.
def pedantic_quadratic_formula(a, b):
  """Eager quadratic formula with every intermediate array spelled out.

  Computes (-b + sqrt(b**2 - 4*a)) / (2*a), the same value as
  `quadratic_formula`, using one jnp call per step.
  """
  minus_b = jnp.negative(b)                       # -b
  b_squared = jnp.square(b)                       # b**2
  four_a = jnp.multiply(4, a)                     # 4*a
  discriminant = jnp.subtract(b_squared, four_a)  # b**2 - 4*a
  del b_squared, four_a  # drop references as soon as they are no longer needed
  root = jnp.sqrt(discriminant)                   # sqrt(b**2 - 4*a)
  del discriminant
  numerator = jnp.add(minus_b, root)              # -b + sqrt(...)
  del minus_b, root
  two_a = jnp.multiply(2, a)                      # 2*a
  return jnp.divide(numerator, two_a)             # numerator / (2*a)


print("\nOriginal - written out:")
%timeit -n 10 pedantic_quadratic_formula(a, b)


Original:
The slowest run took 5.04 times longer than the fastest. This could mean that an intermediate result is being cached.
16.1 ms ± 11.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

JIT (with compile-time):
The slowest run took 90.34 times longer than the fastest. This could mean that an intermediate result is being cached.
13.8 ms ± 31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

JIT (without compile-time):
1.08 ms ± 252 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Original - written out:
10.4 ms ± 2.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [6]:
# The untransformed function, traced (not executed) on scalar inputs. It has
# the same primitives as before, now with shape f32[].
# make_jaxpr only records shapes/dtypes, so the values are arbitrary. They are
# chosen with b**2 - 4*a > 0: at e.g. a=1, b=2 the discriminant is 0, and the
# derivative of sqrt there (0.5 / sqrt(0)) is infinite.
example_a, example_b = 1.0, 3.0
print(jax.make_jaxpr(quadratic_formula)(example_a, example_b))

print()

# Gradient: jax.grad transforms the program (jaxpr) into its derivative.
# By default it differentiates with respect to the first argument, `a`.
grad_quadratic_formula = jax.grad(quadratic_formula)
print(jax.make_jaxpr(grad_quadratic_formula)(example_a, example_b))

{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = neg b
    d:f32[] = integer_pow[y=2] b
    e:f32[] = mul 4.0 a
    f:f32[] = sub d e
    g:f32[] = sqrt f
    h:f32[] = add c g
    i:f32[] = mul 2.0 a
    j:f32[] = div h i
  in (j,) }

{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = neg b
    d:f32[] = integer_pow[y=2] b
    e:f32[] = mul 4.0 a
    f:f32[] = sub d e
    g:f32[] = sqrt f
    h:f32[] = div 0.5 g
    i:f32[] = add c g
    j:f32[] = mul 2.0 a
    _:f32[] = div i j
    k:f32[] = integer_pow[y=-2] j
    l:f32[] = mul 1.0 k
    m:f32[] = mul l i
    n:f32[] = neg m
    o:f32[] = div 1.0 j
    p:f32[] = mul 2.0 n
    q:f32[] = mul o h
    r:f32[] = neg q
    s:f32[] = mul 4.0 r
    t:f32[] = add_any p s
  in (t,) }
