![Py4Eng](img/logo.png)

# JAX: automatic differentiation, just-in-time compilation, and acceleration
## Yoav Ram

[JAX](http://jax.readthedocs.io/) is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.

It combines automatic differentiation with [XLA](https://www.tensorflow.org/xla), a compiler designed for machine-learning workloads.

Benefits of JAX:
- a familiar NumPy-style API for ease of adoption,
- composable function transformations for **compilation**, **batching/vectorization**, **automatic differentiation**, and **parallelization**,
- the same code executes on multiple backends, including CPU, GPU, and TPU (Google's Tensor Processing Unit).

JAX allows us to just-in-time compile functions and, importantly, to compute gradients automatically. We will see these features as we proceed.
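As a first taste of both features, here is a minimal sketch that differentiates a function with `jax.grad` and then compiles the derivative with `jax.jit`. The function `f` and the names `df`, `fast_df`, and `grad_at_2` are made up for this illustration and are not part of the notebook.

```python
import jax

def f(x):
    # f(x) = x^3 + 2x, so analytically f'(x) = 3x^2 + 2
    return x**3 + 2.0 * x

df = jax.grad(f)        # derivative of f with respect to its first argument
fast_df = jax.jit(df)   # transformations compose: compile the derivative

grad_at_2 = float(fast_df(2.0))
print(grad_at_2)  # analytically, f'(2) = 3*4 + 2 = 14
```

Note that `jax.grad` expects a floating-point input, which is why the function is called with `2.0` rather than `2`.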

In [19]:
%matplotlib inline
import matplotlib.pyplot as plt
import jax
import jax.numpy as np
print('JAX', jax.__version__, "on", jax.default_backend())

JAX 0.5.2 on cpu


# Differences from standard NumPy

When using JAX we can mostly use the NumPy API, with some specific differences:
- JAX arrays are immutable, so we cannot use item assignment (`arr[i] = x` is not allowed);
- random number generation requires us to provide a random key at every call, because the random number generator is stateless.

## Setting elements of an array

Since JAX arrays are immutable, we cannot assign to an element of an array as we would in NumPy, that is, we cannot do
```python
arr[i] = 5
```
Instead, we can use a JAX equivalent
```python
arr = arr.at[i].set(5)
```
This creates a new array with 5 at index `i` and binds the name `arr` to the new array.
However, inside a `jit`-compiled function, the compiler can decide to apply the update in place instead of creating a new array.
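To make the copy semantics concrete, here is a small sketch showing that `.at[...].set(...)` leaves the original array untouched, that related methods such as `.at[...].add(...)` follow the same pattern, and that plain item assignment fails. The variable names are chosen for this illustration only.

```python
import jax.numpy as np  # same alias as in the rest of this notebook

arr = np.zeros(5)
new_arr = arr.at[2].set(5)     # returns an updated copy
added = new_arr.at[2].add(1)   # other updates work the same way

print(arr)      # the original array is unchanged
print(new_arr)
print(added)

# Item assignment is rejected because JAX arrays are immutable
assignment_failed = False
try:
    arr[2] = 5
except TypeError as err:
    assignment_failed = True
    print("TypeError:", err)
```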

Here's an example in a function that generates the first 1000 terms of a Fibonacci-type sequence (the function is named `bernoulli` in the code below), defined by
$$
a_i = a_{i-1} + a_{i-2}, \quad
a_0 = 1, \quad
a_1 = 2
$$

In [61]:
def bernoulli():
    n = 1000
    a = np.zeros(n)
    a = a.at[0].set(1)
    a = a.at[1].set(2)
    for i in range(2, n):
        a = a.at[i].set(a[i-1] + a[i-2])
    return a

%timeit bernoulli();

243 ms ± 4.05 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Compiling this, we can see a huge improvement:

In [62]:
bernoulli = jax.jit(bernoulli)
bernoulli();
%timeit bernoulli();

3.13 μs ± 30.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
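One caveat when reading such timings: JAX dispatches computations asynchronously, so a call may return before the result is ready. A common way to time only the computation is to call the function once to trigger compilation and then wait for each result with `.block_until_ready()`. The sketch below uses the standard-library `timeit` module and a hypothetical function `sum_of_squares` so that it runs outside a notebook as well.

```python
import timeit
import jax
import jax.numpy as np

@jax.jit
def sum_of_squares(x):
    return (x * x).sum()

x = np.arange(1000.0)
sum_of_squares(x).block_until_ready()  # warm-up call: triggers compilation

# Time the compiled function, waiting for the result on every call
n_calls = 1000
total = timeit.timeit(lambda: sum_of_squares(x).block_until_ready(), number=n_calls)
print(f"{total / n_calls * 1e6:.2f} μs per call")
```

In a notebook, the equivalent would be `%timeit sum_of_squares(x).block_until_ready()`.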


# Static arguments

If we want the `bernoulli` function to take the value of `n` as an argument, we cannot compile it simply by applying `jax.jit`:

In [75]:
@jax.jit
def bernoulli(n):
    a = np.zeros(n)
    a = a.at[0].set(1)
    a = a.at[1].set(2)
    for i in range(2, n):
        a = a.at[i].set(a[i-1] + a[i-2])
    return a

bernoulli(10)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function bernoulli at /var/folders/vc/qm7741c57dsg9f7wyrtrdrrm0000gq/T/ipykernel_53351/1599531239.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument n.

This occurs because the shape of an array (in `np.zeros(n)`) must be a concrete value, whereas the arguments of a compiled function are, by default, traced values. So we can set `n` to be a static argument. This means, however, that the function is compiled for a specific *value* of `n` rather than a specific *type* of `n`, and it is compiled again for every new value. So if we are not going to call the function on the same value of `n` more than once, this defeats the purpose.

In [80]:
def bernoulli(n):
    a = np.zeros(n)
    a = a.at[0].set(1)
    a = a.at[1].set(2)
    for i in range(2, n):
        a = a.at[i].set(a[i-1] + a[i-2])
    return a

bernoulli = jax.jit(bernoulli, static_argnames='n')
bernoulli(1000);

In [81]:
%timeit bernoulli(1000);

3.28 μs ± 140 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
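The Python `for` loop in the function above is unrolled during tracing, so the traced program grows with `n`. A possible alternative, sketched here under the assumption that `n` is still a static argument, is to express the loop with `jax.lax.fori_loop`, which keeps the loop as a single operation in the compiled program. The names `sequence` and `body` are made up for this example.

```python
from functools import partial
import jax
import jax.numpy as np

@partial(jax.jit, static_argnames='n')
def sequence(n):
    # a_0 = 1, a_1 = 2, as in the definition above
    a = np.zeros(n).at[0].set(1).at[1].set(2)

    def body(i, a):
        # one loop step: a_i = a_{i-1} + a_{i-2}
        return a.at[i].set(a[i-1] + a[i-2])

    # run body for i = 2, ..., n-1, carrying the array a along
    return jax.lax.fori_loop(2, n, body, a)

first_ten = sequence(10)
print(first_ten)
```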


## Random number generation

We use an example from the Numba session: estimating $\pi$ using sampling.

We copy the NumPy implementation, but we need to use JAX's random number generator, which requires a key. This is easy enough: we create a key using `jax.random.key(k)`, where `k` is an integer seed. Given a value of `k`, the sequence of random numbers is determined and reproducible.
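Because the generator is stateless, reusing a key reproduces the same numbers; to draw different numbers we derive new keys with `jax.random.split`. The following is a minimal sketch, with variable names chosen for this illustration.

```python
import jax

key = jax.random.key(12)
a = jax.random.uniform(key, shape=(3,))
b = jax.random.uniform(key, shape=(3,))   # same key: same numbers as a

key, subkey = jax.random.split(key)       # derive two fresh keys from the old one
c = jax.random.uniform(subkey, shape=(3,))  # different key: different numbers

same_key_same_draw = bool((a == b).all())
new_key_same_draw = bool((a == c).all())
print(same_key_same_draw, new_key_same_draw)
```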

In [49]:
def estimate_π(key, n):
    x, y = jax.random.uniform(key, shape=(2, n))
    accept = x*x + y*y < 1
    return accept.mean()*4

In [50]:
n = 1000000
key = jax.random.key(12)
%timeit estimate_π(key, n)

3.66 ms ± 47 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


This is already faster than NumPy, which took about 10 ms.
In this case, jit-compiling the function will not improve runtime.
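To check this yourself, the function can be compiled as in the sketch below. Since `n` determines an array shape, it has to be declared static. The ASCII name `estimate_pi` and the smaller sample size are choices made for this example; timings will depend on the machine and backend.

```python
from functools import partial
import jax

@partial(jax.jit, static_argnames='n')
def estimate_pi(key, n):
    # sample n points uniformly in the unit square
    x, y = jax.random.uniform(key, shape=(2, n))
    # fraction of points inside the quarter unit circle, times 4
    accept = x*x + y*y < 1
    return accept.mean() * 4

key = jax.random.key(12)
estimate = float(estimate_pi(key, 100_000))
print(estimate)
```

In a notebook, `%timeit estimate_pi(key, n).block_until_ready()` can then be compared against the uncompiled version.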

# References

- [Comparison of NumPy, Numba, JAX, and C on Mandelbrot's fractal](https://gist.github.com/jpivarski/da343abd8024834ee8c5aaba691aafc7)

# Colophon
This notebook was written by [Yoav Ram](http://python.yoavram.com).

This work is licensed under a CC BY-NC-SA 4.0 International License.

![Python logo](https://www.python.org/static/community_logos/python-logo.png)