#### JIT in JAX
1. JAX uses XLA to compile ordinary Python functions so that they run **more efficiently** on the target device (CPU, GPU, or TPU).
2. In the following example, two jitted functions are each called twice. For both, the first call takes much longer because JAX is **`tracing`** the function: based on the shapes and dtypes of the inputs, tracers convert the Python code into an intermediate representation called a **`jaxpr`**, which XLA then compiles. Subsequent calls with inputs of the same shape and dtype reuse the compiled code.

In [14]:
import os
import time
import requests

import jax
import jax.numpy as jnp
from jax import jit, grad, random

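def apply_activation(x):
    # This print runs only while JAX traces the function, not on cached calls
    print("Compile")
    return jnp.maximum(0.0, x)

def get_dot_product(W, X):
    # This print runs only while JAX traces the function, not on cached calls
    print("Compile")
    return jnp.dot(W, X)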

# Always use a seed
key = random.PRNGKey(1234)

# Never reuse a key: split it so that each random draw gets its own subkey
key, w_key, x_key = random.split(key, 3)
W = random.normal(key=w_key, shape=[100, 1000], dtype=jnp.float32)
X = random.normal(key=x_key, shape=[1000, 2000], dtype=jnp.float32)

# JIT the functions we have
dot_product_jit = jit(get_dot_product)
activation_jit = jit(apply_activation)

%time Z = dot_product_jit(W, X).block_until_ready()
%time A = activation_jit(Z).block_until_ready()

%time Z = dot_product_jit(W, X).block_until_ready()
%time A = activation_jit(Z).block_until_ready()

Compile
CPU times: user 48.9 ms, sys: 630 µs, total: 49.5 ms
Wall time: 58.7 ms
Compile
CPU times: user 66.3 ms, sys: 0 ns, total: 66.3 ms
Wall time: 96.6 ms
CPU times: user 25.2 ms, sys: 0 ns, total: 25.2 ms
Wall time: 13.5 ms
CPU times: user 1.06 ms, sys: 0 ns, total: 1.06 ms
Wall time: 917 µs
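
The caching behaviour described above depends on the shape and dtype of the inputs, not on their values. The following is a minimal sketch of that idea; `relu` and `trace_log` are hypothetical names introduced only for illustration, and the list append is a deliberate Python side effect used to count how often tracing happens.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: records the shape seen at each trace

@jax.jit
def relu(x):
    # This Python side effect runs only while JAX traces the function.
    trace_log.append(x.shape)
    return jnp.maximum(0.0, x)

relu(jnp.ones((2, 3)))   # first call: trace and compile
relu(jnp.zeros((2, 3)))  # same shape and dtype: cached code is reused
relu(jnp.ones((4, 3)))   # new shape: the function is traced again

print(trace_log)  # [(2, 3), (4, 3)]
```

Only two entries are logged for three calls, which suggests that the second call never re-ran the Python body.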


##### Jaxprs
1. A **`jaxpr`** represents a function with one or more typed parameters (input variables) and one or more typed results.
2. The inputs and outputs have **`types`** and are represented as abstract values.
3. Not every Python program can be represented as a **`jaxpr`**, but many scientific computations can be.

In [15]:
# Make jaxpr for the activation function
print(jax.make_jaxpr(activation_jit)(Z))

Compile
{ lambda ; a:f32[100,2000]. let
    b:f32[100,2000] = xla_call[
      call_jaxpr={ lambda ; c:f32[100,2000]. let
          d:f32[100,2000] = max 0.0 c
        in (d,) }
      name=apply_activation
    ] a
  in (b,) }
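
Besides printing a jaxpr, its typed inputs, outputs, and equations can be inspected programmatically. The sketch below assumes a small hypothetical function `scale_and_shift`; the exact primitives listed may differ between JAX versions, so the example prints them rather than relying on specific names.

```python
import jax
import jax.numpy as jnp

def scale_and_shift(x):
    # Hypothetical example function, not part of the notebook above.
    return 2.0 * x + 1.0

closed_jaxpr = jax.make_jaxpr(scale_and_shift)(jnp.ones(3, dtype=jnp.float32))

print(closed_jaxpr)            # the full jaxpr
print(closed_jaxpr.in_avals)   # abstract values (shape and dtype) of the inputs
print(closed_jaxpr.out_avals)  # abstract values of the outputs

# Each equation applies one primitive to typed variables.
for eqn in closed_jaxpr.jaxpr.eqns:
    print(eqn.primitive.name)
```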


##### Printing
1. Nothing stops you from jitting an impure function, but its side effects run only while the function is being traced. Below is an interesting example:

In [16]:
def number_squared(num):
    print("Received: ", num)
    return num ** 2

# Compiled version
number_squared_jit = jit(number_squared)

# Make the jaxpr
print(jax.make_jaxpr(number_squared_jit)(2))

Received:  Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/1)>
{ lambda ; a:i32[]. let
    b:i32[] = xla_call[
      call_jaxpr={ lambda ; c:i32[]. let d:i32[] = integer_pow[y=2] c in (d,) }
      name=number_squared
    ] a
  in (b,) }


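2. Notice that *num* is printed as a tracer rather than a number: the Python body, including the `print` call, runs only while JAX traces the function. On later calls with the same input type, the cached compiled code runs instead, so nothing is printed and the **`jaxpr`** stays the same.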

In [17]:
for i, num in enumerate([2,4,8]):
    print("Iteration: ", i+1)
    print("Result: ", number_squared_jit(num))
    print("="*50)

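Iteration:  1
Received:  Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
Result:  4
==================================================
Iteration:  2
Result:  16
==================================================
Iteration:  3
Result:  64
==================================================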
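
If the goal is to see the runtime value on every call, `jax.debug.print` can be used instead of Python's `print`. The sketch below is a variant of the function above under the hypothetical name `number_squared_debug`; the debug print is staged into the compiled computation, so it is not limited to trace time.

```python
import jax

@jax.jit
def number_squared_debug(num):
    # Unlike Python's print, this is part of the compiled computation
    # and reports the concrete value each time the function runs.
    jax.debug.print("Received: {}", num)
    return num ** 2

results = [int(number_squared_debug(n)) for n in [2, 4, 8]]
print(results)  # [4, 16, 64]
```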


##### JIT and Python Control Flow
1. There are several situations where JIT fails. One of the most common is Python control flow that depends on the value of an input, for example an `if` conditioned on an argument.

In [18]:
def square_or_cube(x):
    if x % 2 == 0:
        return x ** 2
    else:
        return x * x * x

# JIT transformation
square_or_cube_jit = jit(square_or_cube) # no error here

# Run the jitted function
try:
    val = square_or_cube_jit(2)
except Exception as ex:
    print(type(ex).__name__, ex)


ConcretizationTypeError Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function square_or_cube at /tmp/ipykernel_188426/2554669772.py:1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
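
Two common ways around this error can be sketched as follows. The function names are hypothetical. The first variant replaces the Python `if` with `jnp.where`, which evaluates both branches and selects the result; the second marks the argument as static with `static_argnums`, so the `if` sees a concrete Python value at the cost of a recompilation for every new value.

```python
from functools import partial

import jax
import jax.numpy as jnp

@jax.jit
def square_or_cube_where(x):
    # Both branches are computed; jnp.where picks one based on the traced condition.
    return jnp.where(x % 2 == 0, x ** 2, x * x * x)

@partial(jax.jit, static_argnums=0)
def square_or_cube_static(x):
    # x is a compile-time constant here, so a Python `if` is allowed.
    # Each distinct value of x triggers a new trace and compilation.
    if x % 2 == 0:
        return x ** 2
    return x * x * x

print(square_or_cube_where(2), square_or_cube_where(3))    # 4 27
print(square_or_cube_static(2), square_or_cube_static(3))  # 4 27
```

The `jnp.where` form keeps a single compiled function for all values, which is usually preferable when the argument changes often.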


2. Let's go through what's happening here. When we jit a function, we want a compiled version that can be cached and reused for different input values. To achieve this, JAX traces the function on abstract values that represent sets of possible inputs. The condition `x % 2 == 0` therefore has no concrete truth value during tracing, and Python's `if` cannot choose a branch.
3. Tracing can use different levels of abstraction, and the level used depends on the transformation applied. By default, jit traces code at the **`ShapedArray`** level.
4. For example, if we trace with the abstract value **`ShapedArray((3,), jnp.float32)`**, we get a view of the function that can be reused for any concrete value of the same shape and dtype.
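
To illustrate tracing on shapes and dtypes alone, the sketch below uses `jax.eval_shape` with a `jax.ShapeDtypeStruct` placeholder. This is an assumption-level analogy to the **`ShapedArray`** abstraction rather than a look at jit's internals: no concrete array is ever created, yet JAX can still determine the shape and dtype of the result. The function `normalize` is hypothetical.

```python
import jax
import jax.numpy as jnp

def normalize(x):
    # Hypothetical example function: depends on x only through array operations.
    return (x - x.mean()) / (x.std() + 1e-6)

# Describe an input by shape and dtype only, with no concrete values.
abstract_input = jax.ShapeDtypeStruct((3,), jnp.float32)

# eval_shape traces the function abstractly and returns the output's shape and dtype.
abstract_output = jax.eval_shape(normalize, abstract_input)
print(abstract_output.shape, abstract_output.dtype)  # (3,) float32
```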