# How JAX transforms work

In [1]:
import jax
import jax.numpy as jnp

global_list = []

def log2(x):
    """Base-2 logarithm of ``x`` via the change-of-base formula ln(x) / ln(2).

    Also appends ``x`` to ``global_list``. That append is a plain Python side
    effect: it runs while ``make_jaxpr`` traces the function, but it does not
    appear in the printed jaxpr.
    """
    global_list.append(x)
    # Left operand is evaluated first, so the jaxpr still reads: log a, log 2.0, div.
    return jnp.log(x) / jnp.log(2.0)

print(jax.make_jaxpr(log2)(3.0))

{ lambda  ; a.
  let b = log a
      c = log 2.0
      d = div b c
  in (d,) }


In [2]:
def log2_with_print(x):
    """Base-2 logarithm of ``x`` that also prints its argument.

    Under ``make_jaxpr`` the print fires once, during tracing, and shows the
    abstract tracer rather than a concrete number.
    """
    print("printed x:", x)
    return jnp.log(x) / jnp.log(2.0)

print(jax.make_jaxpr(log2_with_print)(3.0))

printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
{ lambda  ; a.
  let b = log a
      c = log 2.0
      d = div b c
  in (d,) }


In [3]:
def log2_if_rank_2(x):
    """Return log2(x) elementwise for rank-2 input; return ``x`` unchanged otherwise.

    ``x.ndim`` is static shape information, so this Python branch is resolved
    at trace time and only the taken branch appears in the jaxpr.
    """
    if x.ndim != 2:
        return x
    return jnp.log(x) / jnp.log(2.0)


print(jax.make_jaxpr(log2_if_rank_2)(jnp.array([1, 2, 3])))

{ lambda  ; a.
  let 
  in (a,) }


# JIT compiling a function

In [4]:
import jax 
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    """Scaled exponential linear unit, applied elementwise.

    Computes ``lambda_ * x`` where ``x > 0`` and ``lambda_ * alpha * (exp(x) - 1)``
    elsewhere.

    Args:
        x: array-like input.
        alpha: scale of the negative (exponential) branch.
        lambda_: overall output scale.

    Returns:
        An array with the broadcast shape of ``x``.
    """
    # jnp.expm1 computes exp(x) - 1 accurately for x near 0, where
    # alpha * exp(x) - alpha cancels to exactly 0 in float32.
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.expm1(x))

x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()

3.02 ms ± 541 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
selu_jit = jax.jit(selu)

# Warm up: the first call traces and compiles selu for this input, so run it
# once outside the timing loop.
selu_jit(x).block_until_ready()

# Times execution of the already-compiled function; block_until_ready() waits
# for JAX's asynchronous dispatch to finish.
%timeit selu_jit(x).block_until_ready()

31.5 µs ± 357 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


# Why can't we just JIT everything?

In [6]:
# Condition on value of x.

def f(x):
    """Return ``x`` if it is positive, otherwise ``2 * x``, using a Python ``if``."""
    if x > 0:
        return x
    else:
        return 2 * x

f_jit = jax.jit(f)

# Under jit, ``x`` is an abstract tracer, so ``x > 0`` has no concrete truth
# value and the Python ``if`` cannot be evaluated. The error is the point of
# this cell: catch it and show it so the notebook still runs top to bottom.
try:
    f_jit(10)
except jax.errors.ConcretizationTypeError as err:
    print(f"{type(err).__name__}: {err}")

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function f at <ipython-input-6-ab3aead2e7bf>:3, transformed by jit., this concrete value was not available in Python because it depends on the value of the arguments to f at <ipython-input-6-ab3aead2e7bf>:3, transformed by jit. at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).
 (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError)

In [None]:
# While loop whose condition ``i < n`` depends on the argument ``n``.

def g(x, n):
    """Count ``i`` up from 0 in plain Python until ``i >= n``, then return ``x + i``.

    For a non-negative integer ``n`` this is ``x + n``; for ``n <= 0`` it is ``x``.
    """
    i = 0
    while i < n:
        i += 1
    return x + i

g_jit = jax.jit(g)

# Under jit, ``n`` is an abstract tracer, so ``i < n`` has no concrete truth
# value and the Python ``while`` cannot decide whether to continue. The error
# is the point of this cell: catch it and show it so Run All still completes.
try:
    g_jit(10, 20)
except jax.errors.ConcretizationTypeError as err:
    print(f"{type(err).__name__}: {err}")

In [None]:
# Python while loop on concrete values; only the loop body is jit-compiled.

@jax.jit
def loop_body(prev_i):
    """Return ``prev_i + 1``; compiled on first use and then reused from jit's cache."""
    return 1 + prev_i

def g_inner_jitted(x, n):
    """Return ``x`` plus the counter value reached by repeated ``loop_body`` calls.

    The ``while`` condition runs in ordinary Python on concrete values, so no
    tracer ever reaches the comparison; only ``loop_body`` itself is jitted.
    """
    counter = 0
    while counter < n:
        counter = loop_body(counter)
    return x + counter

g_inner_jitted(10, 20)

In [None]:
# Mark positional argument 0 (``x``) as static: jit sees its concrete value,
# so the Python ``if x > 0`` can run. Each new value of ``x`` triggers recompilation.
f_jit_correct = jax.jit(f, static_argnums=(0,))
print(f_jit_correct(10))

In [None]:
# Mark positional argument 1 (``n``) as static so the ``while i < n`` loop
# runs in Python at trace time. ``x`` stays traced.
g_jit_correct = jax.jit(g, static_argnums=(1,))
print(g_jit_correct(10, 20))

# When to use JIT

In [None]:
# g_jit_correct(10, 20) was already called in the previous cell, so this times
# the cached compiled version; block_until_ready() waits for the result.
print("g jitted:")
%timeit g_jit_correct(10, 20).block_until_ready()

# Plain Python baseline: g returns a Python int, so there is nothing to block on.
print("g:")
%timeit g(10, 20)

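# Caching

`jax.jit` caches compiled code keyed on the function object it wraps. Wrapping a freshly created function (for example a `lambda` or `functools.partial`) inside a loop therefore forces a new trace and compilation on every iteration.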

In [None]:
def unjitted_loop_body(prev_i):
    """Return ``prev_i + 1`` (plain Python function, not jitted by itself)."""
    return prev_i + 1

def g_inner_jitted_poorly(x, n):
    """Same result as ``g_inner_jitted``, but with jit applied inside the loop.

    Each iteration wraps a brand-new lambda in ``jax.jit``. A new function
    object misses jit's compilation cache, so the body is traced and compiled
    again on every pass. Calling ``jax.jit(unjitted_loop_body)`` on the same
    function object each time would instead hit the cache, and would not
    demonstrate the problem.
    """
    i = 0
    while i < n:
        # Don't do this! A fresh lambda defeats jit's cache on every iteration.
        i = jax.jit(lambda prev_i: unjitted_loop_body(prev_i))(i)
    return x + i
print("jit called outside the loop:")
%timeit g_inner_jitted(10, 20).block_until_ready()

print("jit called insed the loop:")
%timeit g_inner_jitted_poorly(10, 20).block_until_ready()