
## How do JAX transforms work?
 - JAX first converts the Python function into a simple intermediate language called `jaxpr`.
 - The transformations then operate on that jaxpr representation.

In [1]:
import jax
import jax.numpy as jnp

global_list = []

def log2(x):
  global_list.append(x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2)(3.0))

{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }


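Note that the jaxpr does not capture the side effect of the function: nothing in it corresponds to `global_list.append(x)`. JAX transformations are designed for functionally pure code. Impure functions can still be written and even run, but JAX gives no guarantees about their behavior once they are converted to jaxpr.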
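Here is a minimal sketch of that missing guarantee in practice. The names `calls`, `log2_impure` and `log2_jit` are hypothetical. Under `jax.jit`, the Python side effect runs only while JAX traces the function. A second call with the same input shape and dtype reuses the compiled code and never touches the list.

```python
import jax
import jax.numpy as jnp

calls = []  # hypothetical list that records every time the Python body runs

def log2_impure(x):
  calls.append(x)  # side effect: only executed while JAX traces the function
  return jnp.log(x) / jnp.log(2.0)

log2_jit = jax.jit(log2_impure)

log2_jit(3.0)  # first call: trace, compile with XLA, then run
log2_jit(4.0)  # same shape and dtype: cached executable, Python body not re-run

print(len(calls))  # 1: the append happened once, during tracing
```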

In [2]:
def log2_with_print(x):
  print("printed x:", x) # printed x is a Traced object
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2_with_print)(3.0))

printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }
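A plain `print` fires at trace time, so it shows the tracer rather than the data. To see the runtime value, JAX offers `jax.debug.print`. It is staged into the computation and runs each time the compiled function executes. The following is a small sketch, and `log2_debug` is a hypothetical name:

```python
import jax
import jax.numpy as jnp

@jax.jit
def log2_debug(x):
  # Staged into the compiled computation: prints the concrete value on every call.
  jax.debug.print("runtime x: {x}", x=x)
  return jnp.log(x) / jnp.log(2.0)

result = log2_debug(8.0)
result.block_until_ready()
```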


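If the function contains a Python conditional, the jaxpr records only the branch taken during tracing. Here the condition depends on `x.ndim`, which is known at trace time, so the `if` can be evaluated in Python: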

In [3]:
def log2_if_rank_2(x):
  if x.ndim == 2:
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2
  else:
    return x

In [4]:
print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))

{ lambda ; a:i32[3]. let  in (a,) }
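Because `x.ndim` is a static property of the traced input, the branch is resolved while tracing. Tracing the same function with a rank-2 input should therefore produce a jaxpr that contains the `log` and `div` operations. The rank-1 trace contains none of them. The sketch below checks this, and the `jaxpr_1d`/`jaxpr_2d` names are only for illustration:

```python
import jax
import jax.numpy as jnp

def log2_if_rank_2(x):
  if x.ndim == 2:  # ndim is known at trace time, so a plain Python `if` works
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2
  else:
    return x

jaxpr_1d = jax.make_jaxpr(log2_if_rank_2)(jnp.array([1, 2, 3]))
jaxpr_2d = jax.make_jaxpr(log2_if_rank_2)(jnp.ones((2, 2)))

print(jaxpr_1d)  # identity: no equations at all
print(jaxpr_2d)  # contains log and div equations
```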


## JIT compiling a function

In [5]:
## code for computing a Scaled Exponential Linear Unit (SELU)
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1_000_000)
%timeit selu(x).block_until_ready()

4.75 ms ± 713 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


The code above sends one operation at a time to the accelerator, which limits the XLA compiler's ability to optimize our functions.
 - JAX provides the `jax.jit` transformation, which JIT-compiles a JAX-compatible function.

In [6]:
selu_jit = jax.jit(selu)

# Warm-up call: JAX traces selu into a jaxpr,
# XLA compiles the jaxpr into efficient code for the current backend (CPU, GPU or TPU),
# and the compiled code is executed to satisfy the call
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()

1.42 ms ± 344 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
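The warm-up call matters because the first invocation pays for tracing and compilation. The rough sketch below times the first and second calls separately with `time.perf_counter`, and the exact numbers depend on the hardware. It also checks that the compiled function agrees with the eager one.

```python
import time
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1_000_000)
selu_jit = jax.jit(selu)

t0 = time.perf_counter()
first = selu_jit(x).block_until_ready()   # trace + compile + run
t1 = time.perf_counter()
second = selu_jit(x).block_until_ready()  # cached executable: run only
t2 = time.perf_counter()

print(f"first call:  {t1 - t0:.4f} s")
print(f"second call: {t2 - t1:.4f} s")

# The compiled version should match the eager version up to float rounding.
print(jnp.allclose(first, selu(x)))
```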


## Why can't we JIT everything?

In [7]:
def f(x):
  if x > 0:
    return x
  else:
    return 2 * x

f_jit = jax.jit(f)
f_jit(10)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at <ipython-input-7-11fafa505842>:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
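The error arises because, under `jax.jit`, `x` is an abstract tracer with a shape and dtype but no value, so Python's `if` cannot decide which branch to take. One fix is to mark `x` as static with `static_argnums`. Another option, not covered in this notebook, is to express the branch with array operations. The sketch below uses `jnp.where`, the same trick the SELU example relies on, and `f_branchless` is a hypothetical name:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f_branchless(x):
  # Both branches are computed; jnp.where selects elementwise, so no Python bool is needed.
  return jnp.where(x > 0, x, 2 * x)

print(f_branchless(10))   # 10
print(f_branchless(-3))   # -6
```

Unlike `static_argnums=0`, this version compiles once and works for any value of `x`, including arrays. The trade-off is that both branches are always evaluated.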

In [8]:
def g(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

g_jit = jax.jit(g)
g_jit(10, 20)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function g at <ipython-input-8-ec75af77164d>:1 for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
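The `while` loop fails for the same reason: `i < n` compares against a tracer. If `n` must stay dynamic, one alternative not covered in this notebook is JAX's structured loop primitive `jax.lax.fori_loop(lower, upper, body_fun, init_val)`. It accepts a traced upper bound inside `jax.jit`. The following is a sketch, and `g_fori` is a hypothetical name:

```python
import jax

@jax.jit
def g_fori(x, n):
  # body_fun(index, carry) -> new carry; here it just counts iterations.
  i = jax.lax.fori_loop(0, n, lambda _, count: count + 1, 0)
  return x + i

print(g_fori(10, 20))  # 30
print(g_fori(10, 5))   # 15, same compiled code since n is traced, not static
```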

In [9]:
@jax.jit
def loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted(x, n):
  i = 0
  while i < n:
    i = loop_body(i)
  return x + i

g_inner_jitted(10, 20)

Array(30, dtype=int32, weak_type=True)

In [10]:
f_jit_correct = jax.jit(f, static_argnums=0)
print(f_jit_correct(10))

10


In [11]:
g_jit_correct = jax.jit(g, static_argnames=['n'])
print(g_jit_correct(10, 20))

30


In [13]:
from functools import partial

@partial(jax.jit, static_argnames=['n'])
def g_jit_decorated(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

print(g_jit_decorated(10, 20))

30


## When to use JIT

`jax.jit` introduces some overhead of its own. It usually only saves time if the compiled function is complex and you will run it many times.

This is common in ML, where we tend to compile a large model and run it for millions of iterations.

In [14]:
print("g jitted: ")
%timeit g_jit_correct(10, 20).block_until_ready()

print("g: ")
%timeit g(10, 20)

g jitted: 
7.66 µs ± 919 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
g: 
2.21 µs ± 103 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
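This sketch helps explain why the jitted `g` is slower here. With `n` static, the Python `while` loop runs entirely during tracing, so the compiled program is just a single addition. That leaves almost nothing for XLA to optimize, and the per-call dispatch overhead of `jax.jit` dominates. `jax.make_jaxpr` also accepts `static_argnums`, which makes this visible. Note that `g` is redefined so the block stands alone.

```python
import jax

def g(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

# Treat n (argument index 1) as static, as static_argnames=['n'] does for jax.jit.
jaxpr_g = jax.make_jaxpr(g, static_argnums=1)(10, 20)
print(jaxpr_g)  # a single `add` of x and the constant 20; the loop itself is gone
```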


## Caching behaviour of `jax.jit`

Suppose `f = jax.jit(g)`. When I first invoke `f`, it gets compiled and the resulting XLA code is cached. Subsequent calls of `f` reuse the cached code, as long as the non-static arguments have the same shapes and dtypes.

If I specify `static_argnums` (or `static_argnames`), then the cached code is used **only** for the same values of the arguments labelled as static. A new value triggers retracing and recompilation.
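Both caching rules can be observed by counting traces with a Python side effect, which runs only when JAX (re)traces. The following is a minimal sketch, and `add_n` and `trace_count` are hypothetical names:

```python
from functools import partial
import jax

trace_count = 0  # hypothetical counter, bumped only when JAX (re)traces add_n

@partial(jax.jit, static_argnames=['n'])
def add_n(x, n):
  global trace_count
  trace_count += 1  # Python side effect: runs during tracing only
  return x + n

add_n(1, n=2)    # trace 1: first call with n=2 and an int scalar x
add_n(5, n=2)    # same static value and same shape/dtype of x: cache hit
add_n(1, n=3)    # new static value n=3: trace 2
add_n(1.0, n=2)  # x is now a float, a new dtype: trace 3

print(trace_count)  # 3
```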

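Avoid calling `jax.jit` on temporary functions created inside loops, such as `functools.partial` objects or lambdas. Each one is a new function object with a different hash, so JAX cannot find a cached compilation and recompiles on every iteration.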

In [15]:
from functools import partial

def unjitted_loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted_partial(x, n):
  i = 0
  while i < n:
    # Don't do this! Each call to partial returns a new function object with a different hash
    i = jax.jit(partial(unjitted_loop_body))(i)
  return x + i

def g_inner_jitted_lambda(x, n):
  i = 0
  while i < n:
    # Don't do this! Each lambda is also a new function object with a different hash
    i = jax.jit(lambda x: unjitted_loop_body(x))(i)
  return x + i

def g_inner_jitted_normal(x, n):
  i = 0
  while i < n:
    # this is okay, since JAX can find the cached, compiled function
    i = jax.jit(unjitted_loop_body)(i)
  return x + i

print('jit called in a loop with partials')
%timeit g_inner_jitted_partial(10, 20).block_until_ready()

print('jit called in a loop with lambda')
%timeit g_inner_jitted_lambda(10, 20).block_until_ready()

print('jit called in a loop with caching')
%timeit g_inner_jitted_normal(10, 20).block_until_ready()


jit called in a loop with partials
376 ms ± 5.27 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with lambda
374 ms ± 6.79 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with caching
3.46 ms ± 753 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
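A related fix is to call `jax.jit` once, outside the loop, and reuse the resulting function. The sketch below counts traces with a Python side effect to compare the two patterns. The helpers `body`, `loop_fresh_lambda`, `loop_hoisted` and `trace_log` are hypothetical.

```python
import jax

trace_log = []  # hypothetical: gets one entry each time JAX traces `body`

def body(i):
  trace_log.append(i)  # runs only during tracing
  return i + 1

def loop_fresh_lambda(n):
  i = 0
  for _ in range(n):
    # A new lambda object every iteration: cache miss, so JAX retraces each time.
    i = jax.jit(lambda v: body(v))(i)
  return i

body_jit = jax.jit(body)  # wrapped once, outside the loop

def loop_hoisted(n):
  i = 0
  for _ in range(n):
    i = body_jit(i)  # same jitted function: compiled code is reused
  return i

loop_fresh_lambda(5)
lambda_traces = len(trace_log)

trace_log.clear()
loop_hoisted(5)
hoisted_traces = len(trace_log)

print(lambda_traces, hoisted_traces)
```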
