Source: https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html

### Just In Time Compilation with JAX

In [None]:
import jax.numpy as jnp

##### Example 1

In [None]:
import jax

In [None]:
def _pow(x, n):
    """Return ``x`` raised to the power ``n`` using the ``**`` operator."""
    powered = x ** n
    return powered

In [None]:
# Exercise the plain-Python `_pow` defined above, not the builtin `pow`.
_pow(2, 2), _pow(2, 3)

(4, 8)

In [None]:
# Three-element float vector; JAX stores it as float32 by default.
x = jnp.array((1.0, 2.0, 3.0))

In [None]:
x  # bare expression: the notebook displays the array as this cell's output

Array([1., 2., 3.], dtype=float32)

In [None]:
# Wrap the user-defined `_pow` (not the builtin `pow`) with `jax.jit`.
jit_pow = jax.jit(_pow)

In [None]:
jit_pow(2, 3)  # calls the jitted function; the result is a JAX Array, not a Python int

Array(8, dtype=int32)

In [None]:
def f(x):
    """Return ``x`` unchanged when it is positive, otherwise return ``2 * x``.

    The branch is ordinary Python control flow on the value of ``x``, so when
    this function is wrapped with ``jax.jit`` the argument ``x`` has to be
    marked static.
    """
    return x if x > 0 else 2 * x

In [None]:
# `f` branches on `x` with a Python `if`, so its only argument (position 0)
# must be static; index 1 would refer to a parameter `f` does not have.
f_jit = jax.jit(f, static_argnums=0)

In [None]:
f_jit(10, 2)  # NOTE(review): `f` accepts a single argument, so passing two raises the TypeError below

TypeError: f() takes 1 positional argument but 2 were given

##### Example 3

In [None]:
import jax

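Compile the function `func` to XLA-optimized machine code with `jax.jit`. Because the Python `while` loop compares against `n`, `n` must be marked static so that its concrete value is available during tracing.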

In [None]:
def func(x, n):
    """Return ``x`` plus the number of unit steps taken while counting up to ``n``.

    The count starts at 0 and increments while it is below ``n``. The loop
    bound comes from ``n`` via a Python ``while``, so ``n`` must be static
    when the function is jitted.
    """
    steps = 0
    while steps < n:
        steps += 1
    return x + steps

In [None]:
# `n` controls the Python while-loop in `func`, so it is declared static.
func_jit = jax.jit(func, static_argnames=("n",))

In [None]:
func_jit(10, 20)  # the loop only involves the static `n`, so it runs as plain Python while tracing; result is 10 + 20

Array(30, dtype=int32, weak_type=True)

##### Example 4

In [None]:
from functools import partial

In [None]:
def demo(x, n):
    """Return the sum ``x + n`` (plain Python, not jitted)."""
    total = x + n
    return total

JIT-compile the function `demo` using `jax.jit` as a decorator (via `functools.partial`), again marking `n` as static.

In [None]:
@partial(jax.jit, static_argnames=("n",))
def demo(x, n):
    """Return ``x + n`` as a jitted computation.

    ``n`` is declared static; ``x`` is traced. A positional ``n`` is still
    treated as static, because ``jax.jit`` maps ``static_argnames`` onto
    positions using the function signature.
    """
    return x + n

In [None]:
demo(1, 2)  # n=2 is passed positionally but is still treated as static

Array(3, dtype=int32, weak_type=True)