### JAX's JIT
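1. A function passed to `jit` should be pure: its output should depend only on its inputs, and it should have no side effects. JAX traces the function once for each new combination of input shapes and dtypes, then reuses the compiled result. Python side effects such as `print` therefore run only during tracing. Python control flow (`if`, `while`) also cannot branch on the *values* of traced inputs, only on their shapes and dtypes.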

```python
import numpy as onp          # "original" NumPy, used for host-side arrays
import jax.numpy as jnp
import jax
from jax import grad, jit, vmap, make_jaxpr

@jit
def cross_matrix(v):
  # Side effect: runs only while JAX traces the function, not on cached calls
  print("Compile")
  flatten_v = jnp.reshape(v, (-1,))
  # Skew-symmetric matrix [v]_x such that [v]_x @ w == cross(v, w)
  cross_mat = jnp.array(
      [[0.0, -flatten_v[2], flatten_v[1]],
       [flatten_v[2], 0.0, -flatten_v[0]],
       [-flatten_v[1], flatten_v[0], 0.0]])
  return cross_mat

a = onp.array([1,2,3])
b = onp.array([2,3,4])

%time cross_matrix(a)      # first call: trace and compile
%time c = cross_matrix(b)  # same shape and dtype: reuses the compiled version

print(onp.asarray(c))
```

```
Compile
CPU times: user 58.1 ms, sys: 1.54 ms, total: 59.6 ms
Wall time: 58.7 ms
CPU times: user 110 µs, sys: 0 ns, total: 110 µs
Wall time: 81.1 µs
[[ 0. -4.  3.]
 [ 4.  0. -2.]
 [-3.  2.  0.]]
```
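Both consequences of purity can be seen in a small sketch. It assumes a recent JAX release, and the names `trace_log`, `square`, `relu_python_if` and `relu_where` are hypothetical. A Python side effect (appending to a list) runs only when JAX traces the function. A Python `if` on a traced value fails under `jit`, while `jnp.where` expresses the same selection without Python control flow. The exact exception class varies by JAX version: newer releases raise `TracerBoolConversionError`, a subclass of `ConcretizationTypeError`.

```python
import jax
import jax.numpy as jnp

# Python side effects run only while jit traces; record them in a list.
trace_log = []

@jax.jit
def square(x):
    trace_log.append(x.shape)  # side effect: executes at trace time only
    return x * x

square(jnp.array([1.0, 2.0]))       # first call: traced and compiled
square(jnp.array([3.0, 4.0]))       # same shape/dtype: cached, no new trace
square(jnp.array([1.0, 2.0, 3.0]))  # new shape: traced again
print(trace_log)  # [(2,), (3,)]

def relu_python_if(x):
    # A Python `if` needs a concrete bool, but under jit `x` is an abstract tracer.
    if x > 0:
        return x
    return 0.0

@jax.jit
def relu_where(x):
    # jnp.where selects per element without Python control flow, so it traces fine.
    return jnp.where(x > 0, x, 0.0)

try:
    jax.jit(relu_python_if)(1.0)
except jax.errors.ConcretizationTypeError as err:
    # Class name depends on the JAX version (see the note above).
    print("jit failed:", type(err).__name__)

print(relu_where(jnp.array([-2.0, 3.0])))  # [0. 3.]
```

The same reasoning explains the output of `cross_matrix`: `b` has the same shape and dtype as `a`, so the second call reuses the compiled version and `Compile` is printed only once.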


2. Arguments that should not be traced can be marked as static for JIT compilation, for example with `static_argnums`. A static argument is treated as a Python constant during tracing, so it must be hashable, and each new value triggers a recompilation. Every compiled version is cached, however, so calling the function again with a previously seen static value reuses the existing compilation instead of recompiling.

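```python
from functools import partial

# Argument 1 (`neg`) is static: it is a plain Python bool at trace time, so the
# Python `if` below is allowed. Each new value of `neg` triggers a compilation.
@partial(jit, static_argnums=(1,))
def f(x, neg):
    print("Compile")
    return -x if neg else x

%time f(1, True)   # compiles the neg=True version
%time f(1, True)   # cached

%time f(1, False)  # new static value: compiles again
%time f(1, False)  # cached

%time f(1, True)   # neg=True version is still cached: no recompilation
```

```
Compile
CPU times: user 13.4 ms, sys: 0 ns, total: 13.4 ms
Wall time: 14 ms
CPU times: user 67 µs, sys: 0 ns, total: 67 µs
Wall time: 73 µs
Compile
CPU times: user 1.13 ms, sys: 0 ns, total: 1.13 ms
Wall time: 1.23 ms
CPU times: user 475 µs, sys: 0 ns, total: 475 µs
Wall time: 487 µs
CPU times: user 282 µs, sys: 29 µs, total: 311 µs
Wall time: 258 µs

DeviceArray(-1, dtype=int32, weak_type=True)
```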
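The caching behaviour of static arguments can be checked directly by recording which static values trigger a trace. This is a minimal sketch: `int_power` and `traced_powers` are hypothetical names, and it uses `static_argnames`, the keyword-based counterpart of `static_argnums` in `jax.jit`. Because `power` is a concrete Python int during tracing, the Python loop over it is allowed and gets unrolled into the compiled program.

```python
from functools import partial

import jax
import jax.numpy as jnp

traced_powers = []

@partial(jax.jit, static_argnames=("power",))
def int_power(x, power):
    traced_powers.append(power)  # side effect: runs only when a new trace is needed
    # `power` is a concrete Python int here, so a Python loop over it is allowed;
    # the loop is unrolled into the compiled program.
    result = jnp.ones_like(x)
    for _ in range(power):
        result = result * x
    return result

x = jnp.arange(3.0)
int_power(x, power=2)  # traces and compiles for power=2
int_power(x, power=3)  # new static value: traces and compiles again
int_power(x, power=2)  # power=2 is cached: no new trace
print(traced_powers)   # [2, 3]
print(int_power(x, power=3))  # [0. 1. 8.]
```

Since static values are used as cache keys, they must be hashable. A list, for example, cannot be passed as a static argument, but a tuple can.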