# JIT in JAX

Suppose that we have a function that is extremely computationally demanding.

```python
def expensive_func(x):
    # a bunch of giga-heavy computations
    # pure Python is slow, and so is NumPy for many small operations
    return ...
```

JAX can make the function above faster by compiling it into machine code with just-in-time (JIT) compilation. We request the **compilation** by calling `jax.jit` (the actual compilation happens the first time the returned function is called):

```python
import jax
jitted_func = jax.jit(expensive_func)
# jitted_func has the same signature as expensive_func; it is compiled on its first call
```

Equivalently, we can put a decorator over the function definition (similar to `tf.function` in TensorFlow):

```python
@jax.jit
def expensive_func(x):
    ...
    return ...
```
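
A minimal sketch of the two equivalent forms, assuming a hypothetical stand-in body for `expensive_func` (the original leaves it elided). Both the wrapped and the decorated versions should agree with the plain function up to floating-point rounding:

```python
import jax
import jax.numpy as jnp

def expensive_func(x):
    # Hypothetical stand-in for a heavy computation, for illustration only
    return jnp.sum(jnp.tanh(x) ** 2)

jitted_func = jax.jit(expensive_func)  # explicit wrapping

@jax.jit                               # decorator form, same effect
def expensive_func_decorated(x):
    return jnp.sum(jnp.tanh(x) ** 2)

x = jnp.linspace(-1.0, 1.0, 1000)
print(expensive_func(x), jitted_func(x), expensive_func_decorated(x))
```

Exact bitwise equality is not guaranteed, since XLA may reorder or fuse operations, so compare with a tolerance.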

# Example

In [None]:
import numpy as np
import jax
import jax.numpy as jnp

Recall that JAX uses `float32` by default, for the best compatibility with GPUs. To enable `float64` globally, you need to do the following manually:

In [None]:
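# Run at start-up, before any JAX arrays are created
# (the old `from jax.config import config` import no longer works in recent JAX versions)
jax.config.update("jax_enable_x64", True)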
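
A quick way to check that the flag took effect, as a minimal sketch: with `jax_enable_x64` on, new floating-point arrays default to `float64`, while an explicitly requested `dtype` is still respected:

```python
import jax
import jax.numpy as jnp

# Enable 64-bit types; best done at start-up, before any arrays are created
jax.config.update("jax_enable_x64", True)

x = jnp.ones(3)                     # default float dtype is now float64
y = jnp.ones(3, dtype=jnp.float32)  # an explicit float32 request is kept
print(x.dtype, y.dtype)             # float64 float32
```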

In [None]:
def func_np(x):
    return np.diff(np.diff(np.diff(np.diff(np.diff(x)))))

@jax.jit
def func_jax(x):
    return jnp.diff(jnp.diff(jnp.diff(jnp.diff(jnp.diff(x)))))

x = np.random.randn(100000)
jx = jnp.asarray(x)

In [None]:
%timeit func_np(x)

In [None]:
# First call: trace and compile (kept out of the timing below)
func_jax(jx)

%timeit func_jax(jx).block_until_ready()

Around 4 times faster here (the exact speed-up depends on the hardware and the array size).

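NumPy uses OpenBLAS, MKL, BLIS, etc. as its backend, so why is it still slow? Every NumPy call is dispatched separately from Python and allocates a new intermediate array, and NumPy cannot fuse consecutive operations. For many small operations this per-call overhead dominates, whereas XLA compiles the whole function into one optimised program and can fuse operations together.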

A more illustrative example, which shows how NumPy's per-call overhead adds up in a loop of small operations:

In [None]:
A, B = np.eye(50), np.eye(50)

def func_np(x):
    for i in range(100):
        x = B @ x + np.linalg.solve(A, x) + np.linalg.norm(x)
    return x


Aj, Bj = jnp.eye(50), jnp.eye(50)

@jax.jit
def func_jax(x):
    def scan_body(carry, _):
        # One loop iteration: carry holds the current x; no per-step output is needed
        x = carry
        return Bj @ x + jnp.linalg.solve(Aj, x) + jnp.linalg.norm(x), None
    # lax.scan returns (final_carry, stacked_outputs); we want the final x
    final_x, _ = jax.lax.scan(scan_body, x, None, length=100)
    return final_x

In [None]:
%timeit func_np(np.ones((50, )))

In [None]:
# First call: trace and compile (kept out of the timing below)
func_jax(jnp.ones((50, )))

%timeit func_jax(jnp.ones((50, ))).block_until_ready()
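
A sanity check, as a sketch: the jitted loop should reproduce the NumPy result. Here the loop is written with `jax.lax.fori_loop` instead of `scan`, an alternative that fits when no per-iteration outputs are needed. `float64` is enabled because the values grow beyond 1e95 over 100 iterations, which would overflow `float32`:

```python
import numpy as np
import jax
import jax.numpy as jnp

# Match NumPy's float64; float32 would overflow to inf here
jax.config.update("jax_enable_x64", True)

A, B = np.eye(50), np.eye(50)

def func_np(x):
    for _ in range(100):
        x = B @ x + np.linalg.solve(A, x) + np.linalg.norm(x)
    return x

Aj, Bj = jnp.asarray(A), jnp.asarray(B)

@jax.jit
def func_jax_fori(x):
    # Same loop via lax.fori_loop: body(i, x) returns the next x
    def body(i, x):
        return Bj @ x + jnp.linalg.solve(Aj, x) + jnp.linalg.norm(x)
    return jax.lax.fori_loop(0, 100, body, x)

x0 = np.ones((50,))
print(np.allclose(func_np(x0), np.asarray(func_jax_fori(jnp.asarray(x0)))))
```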

# What happened inside JIT?

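1. When Python calls `func_jax` for the first time, `jax.jit` *traces* it: the Python code runs with abstract placeholder values (tracers) that carry only shapes and dtypes, and every JAX operation is recorded into an intermediate representation called a jaxpr.

2. These recorded operations are then compiled by **XLA (Accelerated Linear Algebra)** into optimised machine code for the target device (CPU, GPU, or TPU). Think of this as compiling C code into an executable file: no numerical computation is done at this stage!

3. The compiled executable then carries out the numerical computations on the actual inputs. Later calls with inputs of the same shapes and dtypes skip steps 1 and 2 and run the cached executable directly.

4. After the computations are done at the machine level, the results are returned to Python as JAX arrays. JAX dispatches work asynchronously, which is why the timings above call `.block_until_ready()`.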
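
The steps above can be observed directly. In this sketch, the function `f` and the list `trace_count` are hypothetical, for illustration: a Python side effect inside a jitted function runs only while JAX traces it, so it reveals when (re-)tracing happens. This is a diagnostic trick only; side effects in jitted code should normally be avoided. `jax.make_jaxpr` prints the operations recorded in step 1:

```python
import jax
import jax.numpy as jnp

trace_count = []  # hypothetical helper: records the input shape on every trace

@jax.jit
def f(x):
    trace_count.append(x.shape)  # Python side effect: runs only during tracing
    return jnp.sin(x) * 2.0

f(jnp.ones(3))        # 1st call: trace, compile, run
f(jnp.ones(3) * 2.0)  # same shape and dtype: cached executable, no re-trace
f(jnp.ones(4))        # new shape: trace and compile again
print(trace_count)    # [(3,), (4,)]

# The jaxpr recorded during tracing (step 1)
jaxpr = jax.make_jaxpr(lambda x: jnp.sin(x) * 2.0)(jnp.ones(3))
print(jaxpr)
```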

JAX basically uses Python as a "metaprogramming language" that specifies how to build an XLA program (quote: [Patrick Kidger](https://kidger.site/thoughts/jax-vs-julia)). 

# When/where to JIT?

Usually, we JIT

1. the part with the largest possible scope, so that the compiler can see and optimise more of your programme at once,
2. or the function(s) that are called repeatedly, for instance, the objective function in optimisation:

```python
@jax.jit
def objective_func(params):
    ...
    return ...
```
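
A minimal sketch combining both rules, assuming a hypothetical quadratic objective and plain gradient descent (neither is from the original): the whole update step, i.e. loss, gradient, and parameter update, is jitted as one unit, and it is compiled once and reused on every iteration:

```python
import jax
import jax.numpy as jnp

# Hypothetical objective for illustration: a quadratic minimised at `target`
target = jnp.array([1.0, -2.0, 3.0])
learning_rate = 0.1

def objective_func(params):
    return jnp.sum((params - target) ** 2)

@jax.jit
def update(params):
    # Largest scope: loss, gradient and update are compiled together
    loss, grad = jax.value_and_grad(objective_func)(params)
    return params - learning_rate * grad, loss

params = jnp.zeros(3)
for _ in range(200):
    params, loss = update(params)  # compiled on the first iteration, reused afterwards
```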

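Remember: when we write JAX code under `jax.jit`, we are describing a computation flow (a graph of operations) for XLA to compile, rather than executing each line immediately.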

# Will these work?

In [None]:
@jax.jit
def my_func(x):
    return np.exp(x)

my_func(jnp.ones((2, )))

In [None]:
@jax.jit
def my_func(x):
    return x + np.array([1., 2.])

my_func(jnp.ones((2, )))
my_func(np.ones((2, )))

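Note: this works because the NumPy array is a concrete constant, which JAX embeds into the compiled program. It does not work in some early JAX versions.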

In [None]:
# Another example

class MyClass:

    def __init__(self):
        self.y = jnp.array(1.)

    @jax.jit
    def my_method(self, x):
        return x + self.y
    
obj = MyClass()
obj.my_method(jnp.array(2.))

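We can make it work by marking `self` as static with the `static_argnums` option. This tells JAX to treat `self` as a compile-time constant: it must be hashable, and JAX assumes it never changes, that is, that it is immutable.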

In [None]:
from functools import partial

class MyClass:

    def __init__(self):
        self.y = jnp.array(1.)

    @partial(jax.jit, static_argnums=(0, ))
    def my_method(self, x):
        return x + self.y
    
obj = MyClass()
obj.my_method(jnp.array(2.))
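
The word *immutable* matters. In this sketch (the `MyClass` from the cell above, redefined so the block runs on its own), mutating `obj.y` after the first call does not change the result: `obj` still hashes to the same cache entry, so JAX reuses the executable in which the old `self.y` was baked in as a constant:

```python
from functools import partial

import jax
import jax.numpy as jnp

class MyClass:
    def __init__(self):
        self.y = jnp.array(1.)

    @partial(jax.jit, static_argnums=(0,))
    def my_method(self, x):
        return x + self.y

obj = MyClass()
first = obj.my_method(jnp.array(2.))   # traced with self.y == 1. baked in
obj.y = jnp.array(10.)                 # mutate the "static" object
second = obj.my_method(jnp.array(2.))  # cache hit: the stale self.y is reused
print(first, second)
```

Both calls return `3.0`, silently. If `self` has to change, it should not be marked static.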

Similarly, this will not work out of the box either, because a Python function is not an array that JAX can trace as an input:

```python
from typing import Callable

import jax

@jax.jit
def f(x, g: Callable):
    return x + g(x)
```

We have to make it clear that the argument `g` is static:

```python
from functools import partial

@partial(jax.jit, static_argnums=(1, ))
def f(x, g: Callable):
    return x + g(x)
```

A more concise way is to define `g` in an outer scope and let `f` close over it, so that JAX treats `g` as a fixed part of the computation:

```python
g = ... # Definition of g from outer scope

@jax.jit
def f(x):
    return x + g(x)
```

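Note: whenever the value of a static argument changes, JAX re-traces and recompiles the function, creating another XLA executable (compiled versions are cached per static value). Changing the shape or dtype of a non-static argument also triggers recompilation.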
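
A small sketch of this rule, using two hypothetical functions `g1` and `g2` as the static argument and a list `n_traces` that records each trace:

```python
from functools import partial

import jax
import jax.numpy as jnp

n_traces = []  # hypothetical helper: one entry per (re)trace

def g1(x):
    return jnp.sin(x)

def g2(x):
    return x ** 2

@partial(jax.jit, static_argnums=(1,))
def f(x, g):
    n_traces.append(g)  # Python side effect: runs only when JAX traces f
    return x + g(x)

x = jnp.ones(3)
f(x, g1)  # first static value: trace and compile
f(x, g1)  # same static value: cached executable
f(x, g2)  # new static value: trace and compile again
print(len(n_traces))  # 2
```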

JAX is designed around **immutable** objects, whereas instances of ordinary Python classes are **mutable**.

For instance, 

In [None]:
x = jnp.ones((2, ))
x[0] = 1.

# does not work.

Well, doesn't JAX suck if we cannot even assign/update values to variables?

No, in my opinion this is a feature, not a problem. Immutable objects keep functions pure, which is what lets JAX trace, compile, and differentiate them safely. For projects that need mutable objects, we can rewrite them in terms of immutable ones.
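
A sketch of what such a rewrite can look like. `x.at[...]` is JAX's functional-update syntax: it returns a new array and leaves the original untouched (inside `jax.jit`, XLA can often perform such updates in place). For state that would otherwise live in a mutable class, one option, assumed here for illustration, is an immutable `NamedTuple` (which JAX treats as a pytree) that a jitted function returns updated rather than modifies. The names `State` and `advance` are hypothetical:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

x = jnp.ones((2,))
try:
    x[0] = 1.                      # in-place assignment is not supported
    assignment_failed = False
except TypeError:
    assignment_failed = True

x_new = x.at[0].set(5.)            # functional update: new array, x unchanged

class State(NamedTuple):           # hypothetical immutable state
    position: jax.Array
    step: jax.Array

@jax.jit
def advance(state: State, velocity):
    # Return a new State instead of mutating the old one
    return state._replace(position=state.position + velocity, step=state.step + 1)

s0 = State(position=jnp.zeros(2), step=jnp.array(0))
s1 = advance(s0, jnp.array([1., 2.]))
print(x, x_new, s1)
```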