### JAX's JIT
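1. A function passed to `jit` must be pure: its output should depend only on its inputs, and it should have no side effects. Side effects such as `print` run only while the function is being traced, not on every call. That is why "Compile" appears on the first call below but not on the second: `b` has the same shape and dtype as `a`, so the cached compilation is reused. Also, Python control flow inside a jitted function may depend on the *shapes* and *dtypes* of its array arguments, but not on their *values*.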

In [7]:
import numpy as onp
import jax.numpy as jnp
import jax
from jax import grad, jit, vmap, make_jaxpr, random

@jit
def cross_matrix(v):
  """Return the 3x3 skew-symmetric cross-product matrix of a 3-vector.

  For a 3-vector v, ``cross_matrix(v) @ w`` equals ``cross(v, w)``.

  Args:
    v: array with exactly 3 elements; any shape that flattens to (3,) is
      accepted, e.g. (3,), (1, 3) or (3, 1).

  Returns:
    A (3, 3) array [[0, -v2, v1], [v2, 0, -v0], [-v1, v0, 0]].

  Raises:
    ValueError: if v does not contain exactly 3 elements.
  """
  # Deliberate side effect: under jit the Python body only runs while tracing
  # (first call, or a call with a new input shape/dtype), so this shows when
  # a (re)compilation happens.
  print("Compile")
  flatten_v = jnp.reshape(v, (-1,))
  # Shapes are static during tracing, so this check costs nothing at run time.
  # Without it a longer vector would silently use only its first 3 entries.
  if flatten_v.shape != (3,):
    raise ValueError(
        f"cross_matrix expects a 3-element vector, got shape {jnp.shape(v)}")
  cross_mat = jnp.array(
      [[0.0, -flatten_v[2], flatten_v[1]],
       [flatten_v[2], 0.0, -flatten_v[0]],
       [-flatten_v[1], flatten_v[0], 0.0]])
  return cross_mat

a = onp.array([1,2,3])
b = onp.array([2,3,4])

%time cross_matrix(a)
%time c = cross_matrix(b)

print(onp.asarray(c))

Compile
CPU times: user 637 ms, sys: 0 ns, total: 637 ms
Wall time: 667 ms
CPU times: user 3.05 ms, sys: 426 µs, total: 3.47 ms
Wall time: 3.01 ms
[[ 0. -4.  3.]
 [ 4.  0. -2.]
 [-3.  2.  0.]]


2. Arguments that should not be traced can be marked as static for the purposes of JIT compilation (here with `static_argnums`). Because a static argument is an ordinary Python value at trace time, you can branch on it with a Python `if`. Changing a static argument's value triggers a recompilation. However, every compiled version is cached, so calling again with a previously seen value reuses its compilation and stays fast.

In [6]:
from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
    """Return -x when `neg` is truthy, otherwise x.

    `neg` (positional argument 1) is static. It is a plain Python value at trace
    time, so branching on it with a Python `if` is allowed, and each distinct
    value gets its own compiled version.
    """
    # Runs only while tracing, i.e. once per new compilation.
    print("Compile")
    if not neg:
        return x
    return -x

%time f(1, True)
%time f(1, True)

%time f(1, False)
%time f(1, False)

%time f(1, True)

Compile
CPU times: user 10.7 ms, sys: 0 ns, total: 10.7 ms
Wall time: 8.77 ms
CPU times: user 405 µs, sys: 0 ns, total: 405 µs
Wall time: 278 µs
Compile
CPU times: user 677 µs, sys: 0 ns, total: 677 µs
Wall time: 640 µs
CPU times: user 262 µs, sys: 0 ns, total: 262 µs
Wall time: 269 µs
CPU times: user 48 µs, sys: 0 ns, total: 48 µs
Wall time: 54.1 µs


DeviceArray(-1, dtype=int32, weak_type=True)

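3. Automatic vectorization (`vmap`) is another very useful feature of JAX. First, a note on how JAX handles randomness. Random functions take an explicit PRNG key, and the same key always produces the same draw. To get new, independent keys, use `jax.random.split`. The next cell demonstrates this, and the one after it uses `vmap` to draw once per key.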

In [16]:
key = jax.random.PRNGKey(100)

# Draw directly from the root key.
a = jax.random.normal(key)
print(a)

# Split the root key into two new keys; each yields its own draw.
k1, k2 = jax.random.split(key)
print(*(jax.random.normal(sub) for sub in (k1, k2)), sep="\n")

# Split again. NOTE(review): k1 was already consumed by normal() above; JAX's
# guidance is never to reuse a key, so prefer `key, subkey = jax.random.split(key)`
# in real code.
k1, k2 = jax.random.split(k1)
print(*(jax.random.normal(sub) for sub in (k1, k2)), sep="\n")

-0.9812892
-0.55406797
-0.4960655
-1.1210011
-1.5675282


In [24]:
key = jax.random.PRNGKey(100)
# One independent subkey per draw, stacked along axis 0.
ks = jax.random.split(key, num=20)
print(ks.shape)
# vmap maps the single-key scalar sampler over the leading axis of ks,
# giving one draw per key.
draws = jax.vmap(jax.random.normal)(ks)
print(draws.shape)

(20, 2)
(20,)


In [27]:
import cv2
import numpy as onp
import jax.numpy as jnp
import jax
from jax import grad, jit, vmap, make_jaxpr, random

# Fixed seed so the random input vector below is reproducible across runs.
key = jax.random.PRNGKey(seed=100)

def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit, applied elementwise.

  selu(x) = lmbda * x                      for x > 0
          = lmbda * alpha * (exp(x) - 1)   for x <= 0

  The defaults look like rounded versions of the published SELU constants
  (alpha ~ 1.6733, lambda ~ 1.0507) -- pass exact values if that matters.

  Args:
    x: array input.
    alpha: scale of the negative (exponential) branch.
    lmbda: overall output scale.

  Returns:
    Array with the same shape as x (for scalar alpha and lmbda).
  """
  positive = x > 0
  # "Double where": give the exponential branch only non-positive values.
  # jnp.where differentiates both branches, so exp(x) overflowing to inf for
  # large positive x would otherwise make the gradient 0 * inf = nan.
  safe_x = jnp.where(positive, jnp.zeros_like(x), x)
  # expm1(x) is more accurate than exp(x) - 1 for x close to 0.
  return lmbda * jnp.where(positive, x, alpha * jnp.expm1(safe_x))

v = random.normal(key, (10,))
%time selu(v).block_until_ready()

selu_jit = jax.jit(selu)
selu_jit(v)
%time selu_jit(v).block_until_ready()

img = cv2.imread("plant.bmp", cv2.IMREAD_GRAYSCALE)
img2 = onp.array(img[0,0:10])

print(v.shape)
print(img2.shape)

%time selu_jit(img2).block_until_ready()

CPU times: user 30.2 ms, sys: 3.99 ms, total: 34.2 ms
Wall time: 33 ms
CPU times: user 53 µs, sys: 0 ns, total: 53 µs
Wall time: 58.7 µs
(10,)
(10,)
CPU times: user 24.5 ms, sys: 0 ns, total: 24.5 ms
Wall time: 24.2 ms


DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32, weak_type=True)