# JAX

Website: https://jax.readthedocs.io/en/latest/index.html

Main idea: JAX is the bridge between Python and XLA.

1. Python code -> `jaxpr` language 
2. `jaxpr` language -> XLA compiled code
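The two stages listed above can be inspected directly. This is a minimal sketch. It assumes a recent JAX version, where `jax.jit(f).lower(...)` returns a `Lowered` object. Its `as_text()` shows the module handed to XLA (StableHLO in recent releases), and its `compile()` produces an XLA executable.

```python
import jax
import jax.numpy as jnp

def log2(x):
  return jnp.log(x) / jnp.log(2.0)

x = jnp.float32(3.0)

# Stage 1: trace the Python function into a jaxpr
print(jax.make_jaxpr(log2)(x))

# Stage 2a: lower the traced program to the IR handed to XLA
lowered = jax.jit(log2).lower(x)
hlo_text = lowered.as_text()
print(hlo_text[:300])

# Stage 2b: XLA compiles it for the current backend; the result is callable
compiled = lowered.compile()
print(compiled(x))
```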

```python
import jax
import jax.numpy as jnp
```

## `jaxpr` language

Jaxprs are JAX's internal intermediate representation (IR) of programs. They are explicitly typed, functional, first-order, and in A-normal form (ANF).
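"First-order" means that Python control flow never reaches the jaxpr. A Python loop runs once at trace time, and only the primitive operations it emits are recorded. The sketch below uses a hypothetical `cube` function to show this. It also assumes that `make_jaxpr` returns a `ClosedJaxpr` whose `.jaxpr.eqns` list holds one entry per primitive.

```python
import jax
import jax.numpy as jnp

def cube(x):
  # Python-level loop: executed while tracing, so it is unrolled in the jaxpr
  y = x
  for _ in range(2):
    y = y * x
  return y

closed = jax.make_jaxpr(cube)(3.0)
print(closed)

# One equation per primitive; the loop has disappeared into two `mul`s
names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(names)
```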

```python
def log2(x):
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2)(3.0))
```

```
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }
```


## XLA: Accelerated Linear Algebra

XLA is an open-source compiler for machine learning. The XLA compiler takes models from popular frameworks such as PyTorch, TensorFlow, and **JAX**, and optimizes the models for high-performance execution across different hardware platforms including **GPUs, CPUs, and ML accelerators**.
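A quick way to see which XLA backend JAX is targeting and to place data on a device explicitly is sketched below. The printed backend and device list depend on the installed `jaxlib` and the available hardware.

```python
import jax
import jax.numpy as jnp

print(jax.default_backend())  # e.g. 'cpu', 'gpu' or 'tpu', depending on the install
devices = jax.devices()
print(devices)

# Put an array on the first device; the jitted function runs on that backend
x = jax.device_put(jnp.arange(4.0), devices[0])
y = jax.jit(lambda v: v * 2.0)(x)
print(y)
```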

## 1. `jax.jit()` Just-in-time compilation

### Example: SELU

$$
{\rm SELU}(x) = \lambda({\rm max}(0,x) + {\rm min}(0,\alpha (e^x-1)))
$$
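The `max`/`min` form of the formula and the `jnp.where` form used in the code are the same function. For $x > 0$ the `min` term is zero. For $x \le 0$ the `max` term is zero and the `min` term equals $\alpha(e^x - 1)$. The sketch below checks this numerically with the notebook's rounded constants. The SELU paper's values are $\alpha \approx 1.6733$ and $\lambda \approx 1.0507$.

```python
import jax.numpy as jnp

alpha, lambda_ = 1.67, 1.05
x = jnp.linspace(-5.0, 5.0, 11)

# Formula as written: lambda * (max(0, x) + min(0, alpha * (exp(x) - 1)))
selu_formula = lambda_ * (jnp.maximum(0.0, x) + jnp.minimum(0.0, alpha * (jnp.exp(x) - 1.0)))

# Branch form: x where x > 0, otherwise alpha * exp(x) - alpha
selu_where = lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

print(jnp.allclose(selu_formula, selu_where))
```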

```python
# Without jit

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()
```

```
1.81 ms ± 62.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```


```python
# With jit

selu_jit = jax.jit(selu)
%timeit selu_jit(x).block_until_ready()
```

```
386 μs ± 67.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
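The first call to a jitted function traces and compiles it. Later calls with the same input shapes and dtypes reuse the cached executable, so a common habit is to call the function once before timing it. The sketch below makes this visible with a Python counter. The counter is only an illustration: Python side effects run during tracing, not on every call.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def selu(x, alpha=1.67, lambda_=1.05):
  global trace_count
  trace_count += 1  # Python side effect: runs only while JAX traces the function
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
x = jnp.arange(1_000_000, dtype=jnp.float32)

selu_jit(x).block_until_ready()        # first call: trace + XLA compile + run
selu_jit(x + 1.0).block_until_ready()  # same shape/dtype: cached executable, no retrace
after_same_shape = trace_count

selu_jit(x[:10]).block_until_ready()   # new shape: retrace and recompile
after_new_shape = trace_count
print(after_same_shape, after_new_shape)
```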


## 2. `jax.vmap()` Automatic vectorization

```python
def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

x = jnp.array([0., 1., 2., 3., 4.])
w = jnp.array([2., 3., 4.])
convolve(x, w)
```

```
Array([11., 20., 29.], dtype=float32)
```

```python
auto_batch_convolve = jax.vmap(convolve)

xs = jnp.stack([x, x])
ws = jnp.stack([w, w])
auto_batch_convolve(xs, ws)
```

```
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)
```
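`vmap` does not require every argument to be batched. With `in_axes=(0, None)`, only `x` is mapped over its first axis and the same `w` is shared across the batch, so there is no need to stack copies of `w`. The sketch below also compares the result with an explicit Python loop over the batch.

```python
import jax
import jax.numpy as jnp

def convolve(x, w):
  output = []
  for i in range(1, len(x) - 1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

x = jnp.array([0., 1., 2., 3., 4.])
w = jnp.array([2., 3., 4.])
xs = jnp.stack([x, 2 * x])

# Map over axis 0 of x only; None broadcasts the same w to every example
batched = jax.vmap(convolve, in_axes=(0, None))(xs, w)

# Reference: explicit Python loop over the batch
looped = jnp.stack([convolve(xi, w) for xi in xs])
print(batched)
```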

```python
jax.make_jaxpr(convolve)(x, w)
```

```
{ lambda ; a:f32[5] b:f32[3]. let
    c:f32[3] = slice[limit_indices=(3,) start_indices=(0,) strides=None] a
    d:f32[] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] c b
    e:f32[3] = slice[limit_indices=(4,) start_indices=(1,) strides=None] a
    f:f32[] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] e b
    g:f32[3] = slice[limit_indices=(5,) start_indices=(2,) strides=None] a
    h:f32[] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] g b
    i:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] d
    j:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] f
    k:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] h
    l:f32[3] = concatenate[dimension=0] i j k
  in (l,) }
```
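The jaxpr above contains one `slice`/`dot_general`/`broadcast_in_dim` group per loop iteration, because the Python loop is unrolled during tracing. As a result, the traced program grows with `len(x)`. The sketch below shows this growth. As an aside, it also checks that `jnp.correlate` in `"valid"` mode computes the same sliding dot product in a single library call.

```python
import jax
import jax.numpy as jnp

def convolve(x, w):
  output = []
  for i in range(1, len(x) - 1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

x = jnp.array([0., 1., 2., 3., 4.])
w = jnp.array([2., 3., 4.])

# Same sliding dot product without a Python loop
corr = jnp.correlate(x, w, mode="valid")
print(corr)

# Number of jaxpr equations for the loop version at two input lengths
eqn_counts = {}
for n in (5, 50):
  xn = jnp.arange(float(n))
  eqn_counts[n] = len(jax.make_jaxpr(convolve)(xn, w).jaxpr.eqns)
print(eqn_counts)
```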

```python
jax.make_jaxpr(auto_batch_convolve)(xs, ws)
```

```
{ lambda ; a:f32[2,5] b:f32[2,3]. let
    c:f32[2,3] = slice[limit_indices=(2, 3) start_indices=(0, 0) strides=None] a
    d:f32[2] = dot_general[
      dimension_numbers=(([1], [1]), ([0], [0]))
      preferred_element_type=float32
    ] c b
    e:f32[2,3] = slice[limit_indices=(2, 4) start_indices=(0, 1) strides=None] a
    f:f32[2] = dot_general[
      dimension_numbers=(([1], [1]), ([0], [0]))
      preferred_element_type=float32
    ] e b
    g:f32[2,3] = slice[limit_indices=(2, 5) start_indices=(0, 2) strides=None] a
    h:f32[2] = dot_general[
      dimension_numbers=(([1], [1]), ([0], [0]))
      preferred_element_type=float32
    ] g b
    i:f32[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] d
    j:f32[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] f
    k:f32[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] h
    l:f32[2,3] = concatenate[dimension=1] i j k
  in (l,) }
```

## 3. `jax.grad()` Automatic differentiation

`jax.grad()` differentiates a function by applying the chain rule to each primitive operation in its jaxpr:

$$
\frac{df(g(x))}{dx} = \frac{df}{dg}\frac{dg}{dx}
$$
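As a small check of the chain rule, take the hypothetical composition $h(x) = f(g(x))$ with $g(x) = x^2$ and $f(u) = \sin u$. Then $\frac{dh}{dx} = \cos(x^2)\cdot 2x$, and `jax.grad` should reproduce this value.

```python
import jax
import jax.numpy as jnp

def g(x):
  return x ** 2

def f(u):
  return jnp.sin(u)

def h(x):
  return f(g(x))  # h(x) = sin(x^2)

x = 1.5
autodiff = jax.grad(h)(x)

# Chain rule by hand: df/dg * dg/dx = cos(x^2) * 2x
by_hand = jnp.cos(x ** 2) * 2 * x
print(autodiff, by_hand)
```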

```python
dlog2dx = jax.grad(log2)

jax.make_jaxpr(dlog2dx)(3.0)
```

```
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    _:f32[] = div b c
    d:f32[] = div 1.0 c
    e:f32[] = div d a
  in (e,) }
```
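The jaxpr computes $d = 1/\ln 2$ and then $e = d/a$, which is the closed form $\frac{d}{dx}\log_2 x = \frac{1}{x \ln 2}$. The original value of `log2` is computed but discarded (bound to `_`). The sketch below checks the first and second derivatives against their closed forms. It also shows that the transformations compose, here as a jit-compiled, vectorized derivative.

```python
import jax
import jax.numpy as jnp

def log2(x):
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

x = 3.0
value, slope = jax.value_and_grad(log2)(x)   # primal value and first derivative
curvature = jax.grad(jax.grad(log2))(x)      # second derivative by nesting grad

# Closed forms: 1 / (x ln 2) and -1 / (x^2 ln 2)
slope_exact = 1.0 / (x * jnp.log(2.0))
curvature_exact = -1.0 / (x ** 2 * jnp.log(2.0))
print(slope, slope_exact)
print(curvature, curvature_exact)

# Transformations compose: jit(vmap(grad(f)))
dlog2_batch = jax.jit(jax.vmap(jax.grad(log2)))
xs = jnp.array([1.0, 2.0, 4.0])
print(dlog2_batch(xs))
```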

## Astro-related implications of JAX

### MCMC sampling

Example: https://rlouf.github.io/post/jax-random-walk-metropolis/

![mcmc](mcmc.png)
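The sketch below is not the code from the linked post. It is a minimal random-walk Metropolis sampler showing how the pieces above combine. `lax.scan` runs the chain inside a single compiled program, `vmap` runs several independent chains, and `jit` compiles the whole sampler with XLA. The target density (a standard normal), the step size, and the chain length are hypothetical choices made for illustration.

```python
import jax
import jax.numpy as jnp

N_STEPS = 5_000   # hypothetical chain length
STEP_SIZE = 0.8   # hypothetical proposal scale

def log_prob(x):
  # Hypothetical target: standard normal in len(x) dimensions
  return -0.5 * jnp.sum(x ** 2)

def rwm_step(state, key):
  x, lp = state
  key_prop, key_acc = jax.random.split(key)
  proposal = x + STEP_SIZE * jax.random.normal(key_prop, x.shape)
  lp_prop = log_prob(proposal)
  # Metropolis rule: accept with probability min(1, p(proposal) / p(x))
  accept = jnp.log(jax.random.uniform(key_acc)) < lp_prop - lp
  x = jnp.where(accept, proposal, x)
  lp = jnp.where(accept, lp_prop, lp)
  return (x, lp), x

def run_chain(key, x0):
  keys = jax.random.split(key, N_STEPS)
  _, samples = jax.lax.scan(rwm_step, (x0, log_prob(x0)), keys)
  return samples  # shape (N_STEPS, dim)

# vmap runs independent chains; jit compiles the whole sampler
n_chains, dim = 4, 2
chain_keys = jax.random.split(jax.random.PRNGKey(0), n_chains)
x0s = jnp.zeros((n_chains, dim))
samples = jax.jit(jax.vmap(run_chain))(chain_keys, x0s)
print(samples.shape, samples.mean(), samples.std())
```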

### Spectral fitting

Example: `jaxspec` arXiv:2409.05757

![jaxspec](jaxspec.png)
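The sketch below does not use the `jaxspec` API. It is a self-contained illustration of why autodiff helps spectral fitting: `jax.grad` and `jax.hessian` of a likelihood give exact derivatives for free, here used in a plain Newton optimizer. The power-law model, the energy grid, the noiseless "observed" counts, and the starting guess are all hypothetical. The objective is the Cash statistic $C = 2\sum_i (m_i - d_i \ln m_i)$, which equals the Poisson $-2\ln L$ up to a data-only constant.

```python
import jax
import jax.numpy as jnp

# Hypothetical energy grid (keV) and noiseless power-law "data"
energy = jnp.linspace(1.0, 10.0, 50)
true_params = jnp.array([jnp.log(100.0), 2.0])  # (log K, photon index Gamma)

def model_counts(params, energy):
  log_k, gamma = params
  # Power law N(E) = K * E^(-Gamma), evaluated in log space
  return jnp.exp(log_k - gamma * jnp.log(energy))

counts = model_counts(true_params, energy)

def cash(params):
  # Cash statistic: Poisson -2 log-likelihood up to a data-only constant
  m = model_counts(params, energy)
  return 2.0 * jnp.sum(m - counts * jnp.log(m))

grad_fn = jax.jit(jax.grad(cash))
hess_fn = jax.jit(jax.hessian(cash))

params = jnp.array([4.0, 1.5])  # hypothetical starting guess
for _ in range(20):
  # Newton step using the exact gradient and Hessian from autodiff
  params = params - jnp.linalg.solve(hess_fn(params), grad_fn(params))

print(params)
```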

### Stellar streams

Example: `StreamSculpter` arXiv:2410.21174

![stream](stream.png)