# JAX
[jit](https://docs.jax.dev/en/latest/jit-compilation.html)

A Python framework for mathematical computation.

As explained before, JAX lets the same code run on CPU, GPU, or TPU. This section looks at **JIT compiling a function** with `jax.jit`.
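The sketch below illustrates that portability. It uses only the public `jax.default_backend()`, `jax.devices()` and `jax.jit` calls. `scaled_sum` is a hypothetical function made up for illustration. The printed backend and device list depend on the machine, but the array code itself does not change.

```python
import jax
import jax.numpy as jnp

# Report which backend ("cpu", "gpu" or "tpu") and devices JAX picked on this machine
print(jax.default_backend(), jax.devices())

def scaled_sum(x):
    # hypothetical example function: identical jnp code on every backend
    return jnp.sum(2.0 * x)

scaled_sum_jit = jax.jit(scaled_sum)  # XLA compiles it for the default device
result = scaled_sum_jit(jnp.arange(4.0))
print(result)  # 12.0
```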

In [1]:
```python
import jax
import jax.numpy as jnp
```

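In [2]:
```python
global_list = []

def log2(x):
    global_list.append(x)  # Python side effect: records whatever value is passed in
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2

# Trace log2 with a scalar float input and print its jaxpr (JAX's intermediate representation)
print(jax.make_jaxpr(log2)(3.0))
```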

```
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0:f32[]
    d:f32[] = div b c
  in (d,) }
```


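In [3]:
```python
global_list
```
```
[JitTracer<~float32[]>]
```

The list holds a tracer, not the value `3.0`. To build the jaxpr, JAX calls `log2` with an abstract tracer in place of the argument, so the `append` runs with that tracer. The side effect itself does not appear in the jaxpr above.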
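The same thing happens under `jax.jit`: Python side effects run only while the function is traced, not on every call. The sketch below shows this. `calls` and `log2_impure` are hypothetical names that mirror `global_list` and `log2`. The second call has the same shape and dtype as the first, so it reuses the cached compiled version and the `append` does not run again.

```python
import jax
import jax.numpy as jnp

calls = []  # hypothetical global log, analogous to global_list

def log2_impure(x):
    calls.append(x)  # Python side effect: executes only during tracing
    return jnp.log(x) / jnp.log(2.0)

log2_jit = jax.jit(log2_impure)
a = log2_jit(8.0)  # first call: trace + compile, append runs once (with a tracer)
b = log2_jit(4.0)  # same shape/dtype: cached compiled code, append is skipped
print(a, b, len(calls))
```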

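## Example: Self-Normalizing Neural Networks

The next example computes the Scaled Exponential Linear Unit (SELU), the activation function introduced in the paper *Self-Normalizing Neural Networks*.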
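For reference, SELU is defined piecewise. The paper's constants are λ ≈ 1.0507 and α ≈ 1.6733, which the `selu` function in the next cell rounds to `1.05` and `1.67`.

```latex
\operatorname{selu}(x) = \lambda
\begin{cases}
  x, & x > 0 \\
  \alpha e^{x} - \alpha, & x \le 0
\end{cases}
```

For large negative x, the output saturates at −λα, which is about −1.75 with the rounded constants.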

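In [4]:
```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  # lambda_ * x for x > 0, lambda_ * (alpha * exp(x) - alpha) otherwise
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
# Time the un-jitted version; block_until_ready() waits for the asynchronously dispatched result
%timeit selu(x).block_until_ready()
```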



```
11.9 ms ± 1.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
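`%timeit` is an IPython magic, so it is not available in a plain Python script. The sketch below approximates it with `time.perf_counter`. `time_call` is a hypothetical helper, not a JAX API. JAX dispatches work asynchronously, so each call is followed by `block_until_ready()`; without it, the loop would only measure how quickly work is queued. A warm-up call runs first so that one-time compilation is not counted. The same helper can time `jax.jit(selu)`.

```python
import time

import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

def time_call(fn, *args, repeats=10):
    """Hypothetical helper: average wall-clock seconds per call of fn(*args)."""
    fn(*args).block_until_ready()  # warm-up call so one-time compilation is not timed
    start = time.perf_counter()
    for _ in range(repeats):
        # without block_until_ready() only the asynchronous dispatch would be timed
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats

x = jnp.arange(1000000)
eager_seconds = time_call(selu, x)
print(f"eager selu: {eager_seconds * 1e3:.2f} ms per call")
```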


In [5]:
```python
selu_jit = jax.jit(selu)

# Pre-compile before timing: the first call traces and compiles the function
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()
```

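```
263 μs ± 142 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```

On this run the jitted version is roughly 45 times faster (263 μs vs. 11.9 ms). Without `jit`, each `jnp` operation in `selu` is dispatched to the device separately. With `jit`, XLA compiles the whole function as one program and can fuse the element-wise operations.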
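The pre-compile step matters because `jax.jit` traces and compiles once for each new input shape and dtype, then reuses the cached result. The sketch below makes this visible by incrementing a hypothetical `trace_count` counter inside the function. Because that is a Python side effect, it runs only when JAX traces. Calling again with the same shape reuses the compiled version; a new shape triggers a retrace.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only when JAX traces

def selu(x, alpha=1.67, lambda_=1.05):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
selu_jit(jnp.arange(10.0)).block_until_ready()  # new shape (10,): trace + compile
selu_jit(jnp.arange(10.0)).block_until_ready()  # same shape/dtype: cached, no trace
selu_jit(jnp.arange(20.0)).block_until_ready()  # new shape (20,): retrace + recompile
print(trace_count)  # 2
```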
