![Py4Eng](img/logo.png)

# JAX: automatic differentiation, just-in-time compilation, and acceleration
## Yoav Ram

[JAX](http://jax.readthedocs.io/) is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.

JAX combines automatic differentiation and a machine-learning specific compiler ([XLA](https://www.tensorflow.org/xla)) for high-performance numerical computing.

Benefits of JAX:
- a familiar NumPy-style API for ease of adoption,
- composable function transformations for **compilation**, **batching/vectorization**, **automatic differentiation**, and **parallelization**,
- the same code runs on multiple backends, including CPU, GPU, and TPU (Google's Tensor Processing Unit, a custom machine-learning accelerator).

JAX allows us to just-in-time compile functions and, importantly, to compute gradients automatically. We will see these features as we proceed.

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import jax
import jax.numpy as np
print('JAX', jax.__version__, "on", jax.default_backend())

JAX 0.5.2 on cpu


# Example: Learning XOR with a neural network

As an example, we will build a neural network that learns the Exclusive OR (XOR) function, following a [tutorial by Colin Raffel](https://colinraffel.com/blog/you-don-t-know-jax.html).

Reminder: XOR takes two binary inputs and returns a single binary output: 0 if the inputs are equal (0 and 0, or 1 and 1) and 1 if they differ (0 and 1, or 1 and 0).

We'll use a small neural network that, for an input $x$, computes
$$
z = \tanh(x w_1 + b_1)
$$
$$
y = \sigma(z w_2 + b_2)
$$
where $y$ is the output and $z$ is the hidden layer.
The activation functions are the hyperbolic tangent, `tanh`, in the hidden layer and the sigmoid (also called the logistic or expit function), $\sigma(x)=\frac{1}{1+e^{-x}}$, in the output layer.

The size of $x$ is 2, the size of $y$ is 1, and we set the size of $z$ to 3.
Therefore, the parameters are $w_1$ of size 2x3, $b_1$ of size 3, $w_2$ of size 3 (a 3x1 matrix stored as a 1D array), and $b_2$ of size 1.

In [10]:
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def nn(params, x):
    w1, b1, w2, b2 = params
    z = np.tanh(x @ w1 + b1)
    return sigmoid(z @ w2 + b2)
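To see why $w_2$ is stored as a 1D array of length 3 rather than as a 3x1 matrix, here is a quick shape check. It uses all-zero placeholder parameters and a placeholder batch of four inputs, which are assumptions for illustration only; the notebook itself uses random parameters and the XOR data.

```python
import jax.numpy as np

X = np.zeros((4, 2))              # placeholder batch: 4 inputs (rows), each of size 2
Y = np.zeros(4)                   # placeholder targets: one per input
w1, b1 = np.zeros((2, 3)), np.zeros(3)
w2, b2 = np.zeros(3), np.zeros(1)

Z = np.tanh(X @ w1 + b1)          # (4, 3): b1 is broadcast over the batch
out = Z @ w2 + b2                 # (4,): one output per input, same shape as Y
print(Z.shape, out.shape)

# with a 3x1 matrix instead, the output has shape (4, 1), and combining it
# with Y of shape (4,) silently broadcasts to (4, 4)
out_col = Z @ w2.reshape(3, 1) + b2
print(out_col.shape, (out_col - Y).shape)
```

The silent broadcast in the last line is exactly the kind of bug that would make the loss below average over 16 mismatched pairs instead of 4.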

We need to initialize the parameters by drawing them from a random distribution (here, the standard normal distribution).

JAX uses a stateless pseudorandom number generator (PRNG), because JAX's program transformations (parallelization, compilation, and differentiation) work on functions without hidden state. Therefore, we have to keep track of the state of the PRNG ourselves; this state is called a **key**. We create a new key from an integer seed of our choice with `jax.random.key(k)`, and split an existing key into `n` new keys with `jax.random.split(key, n)`. The random functions then consume a key to create arrays of random numbers.

Note: Given a seed `k`, the sequence of random numbers is deterministic and reproducible.
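A minimal sketch of these rules (the seed 0 and the array shapes are arbitrary choices): reusing a key reproduces exactly the same numbers, while a key obtained by splitting gives a fresh draw.

```python
import jax

key = jax.random.key(0)                    # create a key from an integer seed
a = jax.random.normal(key, shape=(3,))
b = jax.random.normal(key, shape=(3,))     # same key: exactly the same numbers
print((a == b).all())                      # True

key, subkey = jax.random.split(key)        # split into two new keys
c = jax.random.normal(subkey, shape=(3,))  # a new key gives different numbers
print((a == c).all())                      # False

subkeys = jax.random.split(key, 4)         # or split into n keys at once
print(subkeys.shape)                       # (4,)
```

In practice, the usual pattern is to split a key whenever new random numbers are needed and never to reuse a key that has already been consumed.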


In [11]:
def init_params(key):
    subkeys = jax.random.split(key, 4)
    w1 = jax.random.normal(subkeys[0], shape=(2, 3))
    b1 = jax.random.normal(subkeys[1], shape=(3,))
    w2 = jax.random.normal(subkeys[2], shape=(3,))
    b2 = jax.random.normal(subkeys[3], shape=(1,))
    return w1, b1, w2, b2

key = jax.random.key(9)
subkey, key = jax.random.split(key)
params = init_params(subkey)

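We train the neural network by minimizing the cross-entropy loss function (i.e., the negative log-likelihood of the Bernoulli distribution) via gradient descent. Since the XOR dataset has only four examples, we use all of them at every step (full-batch gradient descent) rather than random mini-batches.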

In [12]:
def loss(params, X, Y, ϵ=1e-10):
    Yhat = nn(params, X)
    return -(Y * np.log(Yhat + ϵ) + (1 - Y) * np.log(1 - Yhat + ϵ)).mean()
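Written out, the loss computed by `loss` for $N$ examples with targets $y_i$ and predictions $\hat{y}_i$ is the following, where the small constant $\epsilon$ guards against taking $\log 0$:

$$
L = -\frac{1}{N} \sum_{i=1}^{N} \left[ y_i \log(\hat{y}_i + \epsilon) + (1 - y_i) \log(1 - \hat{y}_i + \epsilon) \right]
$$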

To use gradient descent, we need the gradient of the loss function with respect to the network's parameters. JAX computes it automatically with `jax.grad`, which takes a function `f` and returns a new function that computes the gradient of `f` with respect to one of its arguments (by default, the first argument).

In [13]:
loss_grad = jax.grad(loss)

That was easy!
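To see what `jax.grad` returns, here is a sketch with toy functions whose derivatives we can check by hand (`f` and `g` are illustrative only). The `argnums` parameter selects which argument to differentiate with respect to, and when the argument is a tuple of arrays, as `params` is in our network, the gradient comes back with the same structure:

```python
import jax
import jax.numpy as np

def f(x, y):
    return x ** 2 * y

df_dx = jax.grad(f)              # derivative w.r.t. the first argument: 2*x*y
df_dy = jax.grad(f, argnums=1)   # derivative w.r.t. the second argument: x**2
print(df_dx(3.0, 2.0))           # 12.0
print(df_dy(3.0, 2.0))           # 9.0

def g(params):
    w, b = params
    return np.sum(w ** 2) + b ** 2

# the gradient has the same (w, b) structure as the input tuple
gw, gb = jax.grad(g)((np.array([1.0, 2.0]), 3.0))
print(gw, gb)                    # gw = [2, 4], gb = 6
```

This is why `loss_grad(params, X, Y)` returns one gradient per parameter, which we can `zip` with `params` in the update step.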

Now we implement gradient descent to minimize the loss function: at each step, we update every parameter using the gradient descent rule
$$
w \leftarrow w - \eta \cdot \frac{\partial L}{\partial w}(w, X, Y)
$$
where $w$ is a parameter, $\eta$ is the _learning rate_, and $L$ is the loss function.
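As a minimal sketch of the rule on its own, here it is applied to a one-parameter toy loss $L(w) = (w-3)^2$, whose minimum is at $w=3$ (the toy loss, the starting point $w=0$, and $\eta=0.1$ are arbitrary choices for illustration):

```python
import jax

def L(w):
    return (w - 3.0) ** 2         # toy loss with its minimum at w = 3

dL_dw = jax.grad(L)
w, eta = 0.0, 0.1                 # initial parameter and learning rate
for step in range(100):
    w = w - eta * dL_dw(w)        # the gradient descent rule
print(w)                          # very close to 3.0
```

Here each step multiplies the distance to the minimum by $1 - 2\eta = 0.8$, so after 100 steps the remaining error is negligible.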

Usually, gradient descent is repeated until some threshold is met or for a fixed number of iterations. Here, we continue until the network has learned all four input combinations, so we write a `test` function that checks whether the network classifies every input correctly.

In [14]:
# generate the data
X = np.array([[0,0], [0,1], [1,0], [1,1]])
Y = np.bitwise_xor(X[:,0], X[:, 1])

# test function
def test(params, X, Y):
    Yhat = (nn(params, X) > 0.5).astype(int)
    return (Yhat == Y).all()

test(params, X, Y)

Array(False, dtype=bool)

Let the training begin.

In [15]:
key = jax.random.key(9)
subkey, key = jax.random.split(key)
params = init_params(subkey)
print("Loss:", loss(params, X, Y))
print("Test:", test(params, X, Y))
η = 1.0

Loss: 0.7860451
Test: False


In [16]:
%%time
t = 0
while True:
    t += 1
    grads = loss_grad(params, X, Y)
    params = [
        p - η * g 
        for p, g in zip(params, grads)
    ]
    if t % 10 == 0:
        print("{:d}: {:.5f}".format(t, loss(params, X, Y)))
        if test(params, X, Y):
            break
            

10: 0.69145
20: 0.68096
30: 0.65436
40: 0.60349
50: 0.54253
60: 0.49203
70: 0.45673
80: 0.43273
90: 0.41603
100: 0.40400
110: 0.39503
120: 0.38810
130: 0.38259
140: 0.37808
150: 0.37426
160: 0.37089
170: 0.36778
180: 0.36465
190: 0.36108
200: 0.35601
210: 0.34557
CPU times: user 1.58 s, sys: 49.2 ms, total: 1.63 s
Wall time: 1.63 s


In [17]:
print((nn(params, X) > 0.5).astype(int))
print(Y)

[0 1 1 0]
[0 1 1 0]


So we have trained a neural network on the XOR function successfully, and rather fast. But we can make it even faster if we just-in-time (JIT) compile the computationally intensive part. First, we refactor the code so that the gradient computation and the parameter update are in a single function, and then we compile that function with `jax.jit`.

In [92]:
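def gradient_descent(params, X, Y):
    # one gradient descent step; note that η is read from the global scope,
    # so once the function is jit-compiled, the value of η at tracing time is baked in
    grads = loss_grad(params, X, Y)
    params = [
        p - η * g 
        for p, g in zip(params, grads)
    ]
    return params

# time the step before compilation
%timeit gradient_descent(params, X, Y)
gradient_descent = jax.jit(gradient_descent)
# warm-up: the first call traces and compiles the function
gradient_descent(params, X, Y)
# time the step after compilation
%timeit gradient_descent(params, X, Y)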

5.06 ms ± 160 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
12.6 μs ± 189 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


That's a significant improvement of about 400-fold.
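Why the warm-up call before timing? `jax.jit` traces and compiles a function the first time it is called with a given combination of input shapes and dtypes, and reuses the compiled version afterwards. The sketch below makes this visible by recording each trace in a Python list (the function `f` and the list `traces` are illustrative only); a Python side effect such as `append` runs only while tracing, not when the compiled code executes.

```python
import jax
import jax.numpy as np

traces = []  # records the input shape each time f is traced

@jax.jit
def f(x):
    traces.append(x.shape)        # Python side effect: runs only during tracing
    return np.sin(x) ** 2 + np.cos(x) ** 2

f(np.ones(3))     # first call with shape (3,): trace and compile
f(np.zeros(3))    # same shape and dtype: reuse the compiled code
f(np.ones(4))     # new shape (4,): trace and compile again
print(traces)     # [(3,), (4,)]
```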

Now run the training again.

In [93]:
key = jax.random.key(9)
subkey, key = jax.random.split(key)
params = init_params(subkey)
print("Loss:", loss(params, X, Y))
print("Test:", test(params, X, Y))

Loss: 0.7860451
Test: False


In [94]:
%%time
η = 1.0
t = 0
while True:
    t += 1
    params = gradient_descent(params, X, Y)
    if t % 10 == 0:
        print("{:d}: {:.5f}".format(t, loss(params, X, Y)))
        if test(params, X, Y):
            break

10: 0.69145
20: 0.68096
30: 0.65436
40: 0.60349
50: 0.54253
60: 0.49203
70: 0.45673
80: 0.43273
90: 0.41603
100: 0.40400
110: 0.39503
120: 0.38810
130: 0.38259
140: 0.37808
150: 0.37426
160: 0.37089
170: 0.36778
180: 0.36465
190: 0.36108
200: 0.35601
210: 0.34557
CPU times: user 134 ms, sys: 3.51 ms, total: 138 ms
Wall time: 76.1 ms


In [95]:
print((nn(params, X) > 0.5).astype(int))
print(Y)

[0 1 1 0]
[0 1 1 0]


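So the overall training time was reduced from 1.63 s to 76 ms, about a 21-fold improvement. This is much less than the 400-fold speedup per step, probably because the training loop still runs in Python and calls the uncompiled `loss` and `test` functions every 10 iterations. With larger datasets and neural networks, we can run the same code on a multicore CPU, GPU, or TPU without modifying it, to get even larger improvements.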

# Differences from standard NumPy

When using JAX, we can mostly use the NumPy API, with a few notable differences:
- JAX arrays are immutable, so item assignment (`arr[i] = x`) is not allowed.
- random number generation requires an explicit random key at every call, because the random number generator is stateless.
- Python control flow (`if`, `while`, etc.) that depends on the *values* of a jit-compiled function's arguments does not work, because `jax.jit` traces the function with abstract values that carry only shapes and dtypes; instead, use structured control-flow primitives such as `jax.lax.cond` (see more [in the docs](https://docs.jax.dev/en/latest/control-flow.html#control-flow)).

A more comprehensive list is available at the JAX docs under ["The Sharp Bits"](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html).
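A minimal sketch of the control-flow issue (the function names are illustrative): a Python `if` on a traced argument fails under `jax.jit`, while `jax.lax.cond` compiles both branches and chooses one at run time.

```python
import jax

@jax.jit
def python_if(x):
    if x > 0:                     # needs a concrete bool, but x is a tracer here
        return x ** 2
    return -x

@jax.jit
def with_cond(x):
    # jax.lax.cond(pred, true_fun, false_fun, *operands)
    return jax.lax.cond(x > 0, lambda v: v ** 2, lambda v: -v, x)

failed = False
try:
    python_if(3.0)
except jax.errors.ConcretizationTypeError as err:
    failed = True
    print("Python if failed with", type(err).__name__)

print(with_cond(3.0), with_cond(-2.0))   # 9.0 2.0
```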

## Changing an array

Since JAX arrays are immutable, we cannot assign to an element of an array as we would in NumPy; that is, we cannot write
```python
arr[i] = 5
```
Instead, we use the JAX equivalent:
```python
arr = arr.at[i].set(5)
```
This creates a new array with 5 at index `i` and binds it to the name `arr`.
However, inside a jit-compiled function, when the old array is not used again, the compiler can apply the update in place instead of creating a new array.
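A short sketch of the difference (the array size and values are arbitrary): item assignment raises a `TypeError`, while `.at[...]` returns an updated copy and leaves the original untouched. Besides `set`, `.at` also offers other updates such as `add` and `multiply`.

```python
import jax.numpy as np

arr = np.zeros(5)

# NumPy-style item assignment is rejected
try:
    arr[2] = 5.0
    assignment_worked = True
except TypeError:
    assignment_worked = False
print("item assignment worked:", assignment_worked)   # False

new = arr.at[2].set(5.0)   # a new array with 5 at index 2
new = new.at[0].add(1.0)   # functional version of new[0] += 1
print(arr)                 # all zeros: the original is unchanged
print(new)                 # 1 at index 0, 5 at index 2, zeros elsewhere
```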

Here's an example: a function (named `bernoulli` in the code below) that generates the first 1000 Fibonacci numbers, starting from 1 and 2, defined by
$$
a_i = a_{i-1} + a_{i-2}, \quad
a_0 = 1, \quad
a_1 = 2
$$

In [96]:
def bernoulli():
    n = 1000
    a = np.zeros(n)
    a = a.at[0].set(1)
    a = a.at[1].set(2)
    for i in range(2, n):
        a = a.at[i].set(a[i-1] + a[i-2])
    return a

%timeit bernoulli();

266 ms ± 1.99 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Compiling this, we get a huge improvement (roughly 78,000-fold):

In [97]:
bernoulli = jax.jit(bernoulli)
bernoulli();
%timeit bernoulli();

3.4 μs ± 77.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


## Static arguments

If we want `bernoulli` to take the value of `n` as an argument, we cannot simply compile it as before:

In [98]:
@jax.jit
def bernoulli(n):
    a = np.zeros(n)
    a = a.at[0].set(1)
    a = a.at[1].set(2)
    for i in range(2, n):
        a = a.at[i].set(a[i-1] + a[i-2])
    return a

bernoulli(10)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function bernoulli at /var/folders/vc/qm7741c57dsg9f7wyrtrdrrm0000gq/T/ipykernel_60395/1599531239.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument n.

This error occurs because, under `jax.jit`, function arguments are by default replaced with *traced* values that carry only a shape and a dtype, while the shape of an array (here, in `np.zeros(n)`) must be a concrete integer. The solution is to declare `n` a *static* argument. This means, however, that the function is compiled separately for each specific *value* of `n`, rather than once for the shape and dtype of `n`. So if we call the function with each value of `n` only once, compiling it defeats the purpose.

In [99]:
def bernoulli(n):
    a = np.zeros(n)
    a = a.at[0].set(1)
    a = a.at[1].set(2)
    for i in range(2, n):
        a = a.at[i].set(a[i-1] + a[i-2])
    return a

bernoulli = jax.jit(bernoulli, static_argnames='n')
bernoulli(1000);

In [100]:
%timeit bernoulli(1000);

3.47 μs ± 27.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
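One caveat with the versions above: during tracing, the Python `for` loop is unrolled, so the compiled program contains a separate update for every iteration and compilation takes longer as `n` grows. Below is a sketch of an alternative using `jax.lax.fori_loop`, which keeps the loop inside the compiled program. The name `fib_like` is ours, and `n` stays static because it determines the array's shape.

```python
from functools import partial

import jax
import jax.numpy as np

@partial(jax.jit, static_argnames='n')
def fib_like(n):
    a = np.zeros(n)
    a = a.at[0].set(1)
    a = a.at[1].set(2)

    def step(i, a):
        # one application of a_i = a_{i-1} + a_{i-2}; i is a traced loop index
        return a.at[i].set(a[i - 1] + a[i - 2])

    # run step for i = 2, ..., n-1 without unrolling the loop
    return jax.lax.fori_loop(2, n, step, a)

print(fib_like(10))   # 1, 2, 3, 5, 8, 13, 21, 34, 55, 89
```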


# Mapping

JAX allows us to vectorize a function by mapping it over input dimensions.

Consider the following function that computes the Hamming distance between two 1D arrays (i.e., the number of elements in which the two arrays differ).

In [19]:
def hamming1D(x, y):
    return (x!=y).sum()

x = np.array([1, 2, 3, 4, 5])
y = np.array([5, 4, 3, 2, 1])
hamming1D(x, y)

Array(4, dtype=int32)

What if we want to compute the distance between a single array `x` and all the rows of a matrix `Y`? We would need a for loop over the rows.

With JAX, we can use `jax.vmap` instead. The argument `in_axes` specifies, for each argument, the axis over which to map. Here we don't want to map over `x`, so we pass `None`, and we want to map over the first axis of `Y` (its rows), so we pass `0`.

In [106]:
Y = jax.random.randint(jax.random.key(4), shape=(10, 5), minval=1, maxval=5)
hamming2D = jax.jit(jax.vmap(hamming1D, in_axes=[None, 0]))
hamming2D(x, Y)

Array([5, 5, 4, 4, 5, 5, 5, 4, 5, 2], dtype=int32)
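To check that `vmap` computes the same thing as the explicit loop, here is a sketch on a small hand-made matrix (its values are arbitrary). Because `vmap` is a function transformation, it also composes with itself: mapping a second time, over the rows of the first argument, yields the full matrix of pairwise distances.

```python
import jax
import jax.numpy as np

def hamming1D(x, y):
    return (x != y).sum()

x = np.array([1, 2, 3, 4, 5])
Y = np.array([[1, 2, 3, 4, 5],
              [5, 4, 3, 2, 1],
              [1, 1, 1, 1, 1]])

# map over the rows of Y (axis 0) while passing the same x to every call
hamming2D = jax.vmap(hamming1D, in_axes=(None, 0))
print(hamming2D(x, Y))                              # [0 4 4]

# the equivalent explicit Python loop
print(np.array([hamming1D(x, row) for row in Y]))   # [0 4 4]

# map again over the rows of the first argument: D[i, j] = hamming1D(Y[i], Y[j])
pairwise = jax.vmap(hamming2D, in_axes=(0, None))
D = pairwise(Y, Y)
print(D)
```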

# References
- [You don't know JAX by Colin Raffel](https://colinraffel.com/blog/you-don-t-know-jax.html)
- [Comparison of NumPy, Numba, JAX, and C on Mandelbrot's fractal](https://gist.github.com/jpivarski/da343abd8024834ee8c5aaba691aafc7)

# Colophon
This notebook was written by [Yoav Ram](http://python.yoavram.com).

This work is licensed under a CC BY-NC-SA 4.0 International License.

![Python logo](https://www.python.org/static/community_logos/python-logo.png)