<a href="https://colab.research.google.com/github/probml/pyprobml/blob/master/book1/intro/jax_intro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Brief introduction to JAX 

murphyk@gmail.com, Last update: 2021-01-06.

[JAX](https://github.com/google/jax) is a version of NumPy that runs fast on CPU, GPU and TPU, by compiling down to XLA. It also has an excellent automatic differentiation library, extending the earlier [autograd](https://github.com/hips/autograd) package.

The JAX interface is almost identical to NumPy (by design), but with some small differences and additional features, some of which we explain below. (More detail can be found in the official documentation.)




In [1]:
# Standard Python libraries
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import time
import numpy as np
np.set_printoptions(precision=3)
import glob
import matplotlib.pyplot as plt
import PIL
import imageio

from IPython import display
%matplotlib inline

import sklearn


In [2]:

# Load JAX
import jax
import jax.numpy as jnp
from jax import random
#import jax.numpy as np
#import numpy as onp # original numpy
from jax import grad, hessian, jacfwd, jacrev, jit, vmap
print("jax version {}".format(jax.__version__))


jax version 0.2.7


# Random number generation

One of the biggest differences from NumPy is the way JAX treats pseudo random number generation (PRNG).
This is because JAX does not maintain any global state, i.e., it is purely functional.
This design "provides reproducible results invariant to compilation boundaries and backends,
while also maximizing performance by enabling vectorized generation and parallelization across random calls"
(to quote [the official page](https://github.com/google/jax#a-brief-tour)).
                              
Thus, whenever we do anything stochastic, we need to give it a fresh RNG key. We obtain fresh keys by splitting an existing key into pieces, which we can repeat indefinitely, as shown below.

In [3]:
import jax.random as random

key = random.PRNGKey(0)
print(random.normal(key, shape=(3,)))  # [ 1.81608593 -0.48262325  0.33988902]
print(random.normal(key, shape=(3,)))  # [ 1.81608593 -0.48262325  0.33988902]  ## identical results

# To make a new key, we split the current key into two pieces.
key, subkey = random.split(key)
print(random.normal(subkey, shape=(3,)))  # [ 1.1378783  -1.22095478 -0.59153646]

# We can continue to split off new pieces from the global key.
key, subkey = random.split(key)
print(random.normal(subkey, shape=(3,)))  # [-0.06607265  0.16676566  1.17800343]

# We can always use original numpy if we like (although this may interfere with the deterministic behavior of jax)
np.random.seed(42)
print(np.random.randn(3))

[ 1.816 -0.483  0.34 ]
[ 1.816 -0.483  0.34 ]
[ 1.138 -1.221 -0.592]
[-0.066  0.167  1.178]
[ 0.497 -0.138  0.648]
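
When several independent draws are needed at once, the key can be split into more than two pieces in a single call. The following is a minimal sketch; the variable names `subkeys` and `samples` are ours, chosen for illustration.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)

# Split once into a new carry-over key plus three subkeys.
key, *subkeys = random.split(key, num=4)

# Each subkey drives its own independent draw.
samples = jnp.stack([random.normal(k, shape=(3,)) for k in subkeys])
print(samples.shape)
```

Keeping the first piece as the new `key` means later code can keep splitting without ever reusing a subkey.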


# GPU magic

In [3]:
# Check if GPU is available
!nvidia-smi


Mon Jan  4 22:36:19 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.27.04    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P8     9W /  70W |      0MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [4]:

# Check if JAX is using GPU
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))

jax backend gpu


Let's see how JAX can speed up things like matrix-matrix multiplication using a GPU.

First the numpy/CPU version.

In [22]:
# Standard CPU

size = 1000
x = np.random.normal(size=(size, size)).astype(np.float32)
print(type(x))
%timeit -o np.dot(x, x.T)


<class 'numpy.ndarray'>
100 loops, best of 3: 17.7 ms per loop


<TimeitResult : 100 loops, best of 3: 17.7 ms per loop>

In [23]:
res = _ # get result of last cell
time_cpu = res.best
print(time_cpu)

0.01767426187999945


Now the GPU version. We call `block_until_ready` because JAX uses [asynchronous execution](https://jax.readthedocs.io/en/latest/async_dispatch.html) by default, so without it we would only time the dispatch of the computation, not the computation itself.


In [24]:
# GPU version
x = jax.random.normal(key, (size, size), dtype=jnp.float32)
print(type(x))
%timeit -o jnp.dot(x, x.T).block_until_ready() 

<class 'jax.interpreters.xla._DeviceArray'>
The slowest run took 7.90 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 3: 719 µs per loop


<TimeitResult : 1000 loops, best of 3: 719 µs per loop>

In [25]:
res = _
time_gpu = res.best
print('GPU time {:0.6f}, CPU time {:0.6f}, speedup {:0.6f}'.format(
    time_gpu, time_cpu, time_cpu/time_gpu))

GPU time 0.000719, CPU time 0.017674, speedup 24.593663


We can move numpy arrays to the GPU for speed. The result will be transferred back to CPU for printing, saving, etc.

In [34]:
from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
print(type(x))
%timeit np.dot(x, x.T)

x = device_put(x)
print(type(x))
%timeit jnp.dot(x, x.T).block_until_ready()

<class 'numpy.ndarray'>
100 loops, best of 3: 18.3 ms per loop
<class 'jax.interpreters.xla._DeviceArray'>
1000 loops, best of 3: 863 µs per loop


# Vmap <a class="anchor" id="vmap"></a>


To illustrate vmap, consider a binary logistic regression model.

In [99]:
def sigmoid(x): return 0.5 * (jnp.tanh(x / 2.) + 1)

def predict_single(w, x):
    return sigmoid(jnp.dot(w, x)) # <(D) , (D)> = (1) # inner product
  
def predict_batch(w, X):
    return sigmoid(jnp.dot(X, w)) # (N,D) * (D,1) = (N,1) # matrix-vector multiply


D = 2
N = 3

#np.random.state(42)
#w = np.random.randn(D)
#X = np.random.randn(N, D)
#y = np.random.randint(0, 2, N)

w = jax.random.normal(key, shape=(D,))
X = jax.random.normal(key, shape=(N,D))
y = jax.random.choice(key, 2, shape=(N,)) # uniform binary labels

print(X)
print(y)


# We can apply predict_batch to a matrix of data, but we cannot apply predict_single in this way
# because the order of the arguments to np.dot is incorrect.

p1 = predict_batch(w, X)
print(p1)
try:
    p2 = predict_single(w, X)
except:
    print('cannot apply to batch')

[[-1.455  0.973]
 [-0.217  0.691]
 [-1.011  0.401]]
[0 0 1]
[0.938 0.773 0.814]
cannot apply to batch


To avoid having to think about batch shape, it is often easier to write a function that works on single
input vectors. We can then apply this in a loop.

In [101]:
p3 = [predict_single(w, x) for x in X]
assert np.allclose(p1, p3)

Unfortunately, mapping down a list is slow.
Fortunately, JAX provides `vmap`, which has the same effect, but can be parallelized.

We first partially apply the `predict_single` function to its first argument, w, to get a function that only
depends on x. We then vectorize this, and map the resulting modified function along rows (dimension 0)
of the data matrix.

In [41]:
from functools import partial

predict_single_w = partial(predict_single, w)
predict_batch_w = vmap(predict_single_w)
p4 = predict_batch_w(X)
assert np.allclose(p1, p4)

# More concise
p5 = vmap(predict_single, in_axes=(None, 0))(w, X)
assert np.allclose(p1, p5)

p6 = vmap(partial(predict_single, w))(X)
assert np.allclose(p1, p6)
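
The `in_axes` argument can also map over several arguments at once. As a small sketch (the arrays `A`, `B` and the name `rowwise_dot` are illustrative, not part of the example above), we compute one inner product per pair of rows:

```python
import jax.numpy as jnp
from jax import random, vmap

key_a, key_b = random.split(random.PRNGKey(0))
A = random.normal(key_a, shape=(5, 3))
B = random.normal(key_b, shape=(5, 3))

# Map jnp.dot over axis 0 of both arguments: one inner product per row pair.
rowwise_dot = vmap(jnp.dot, in_axes=(0, 0))
out = rowwise_dot(A, B)
print(out.shape)
```

The result agrees with the hand-vectorized expression `jnp.sum(A * B, axis=1)`.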


# Autograd <a class="anchor" id="AD"></a>

In this section, we illustrate automatic differentiation using JAX.



## Simple convex functions

In [42]:
from jax import grad, hessian, jacfwd, jacrev, vmap, jit

Linear function: multi-input, scalar output.

$$
\begin{align}
f(x; a) &= a^T x\\
\nabla_x f(x;a) &= a
\end{align}
$$

In [45]:
# We construct a single output linear function.
# In this case, the Jacobian and gradient are the same.
def fun1d(x):
    return jnp.dot(a, x)[0]

Din = 3; Dout = 1;
a = jax.random.normal(key, shape=(Dout, Din))
x = jax.random.normal(key, shape=(Din,))
g = grad(fun1d)(x)
assert np.allclose(g, a)
J = jacrev(fun1d)(x)
assert np.allclose(J, g)

Linear function: multi-input, multi-output.

$$
\begin{align}
f(x;A) &= A x \\
\nabla_x f(x;A) &= A
\end{align}
$$

In [47]:
# We construct a multi-output linear function.
# We check forward and reverse mode give same Jacobians.


def fun(x):
    return jnp.dot(A, x)

Din = 3; Dout = 4;
A = jax.random.normal(key, shape=(Dout, Din))
x = jax.random.normal(key, shape=(Din,))
Jf = jacfwd(fun)(x)
Jr = jacrev(fun)(x)
assert np.allclose(Jf, Jr)
assert np.allclose(Jf, A)

Quadratic form.

$$
\begin{align}
f(x;A) &= x^T A x \\
\nabla_x f(x;A) &= (A+A^T) x \\
\nabla^2_x f(x;A) &= A + A^T
\end{align}
$$

In [49]:

D = 4
A = jax.random.normal(key, shape=(D,D))
x = jax.random.normal(key, shape=(D,))

quadfun = lambda x: jnp.dot(x, jnp.dot(A, x))

J = jacfwd(quadfun)(x)
assert np.allclose(J, jnp.dot(A+A.T, x))

H1 = hessian(quadfun)(x)
assert np.allclose(H1, A+A.T)

def my_hessian(fun):
  return jacfwd(jacrev(fun))

H2 = my_hessian(quadfun)(x)
assert np.allclose(H1, H2)
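
For large D we often only need the product of the Hessian with a vector, not the full matrix. One common way to get this, sketched below under the assumption that `f` is a scalar-valued function, is to take a forward-mode derivative of the gradient. The helper name `hvp` and the vector `v` are ours.

```python
import jax
import jax.numpy as jnp
from jax import random

k1, k2, k3 = random.split(random.PRNGKey(0), 3)
D = 4
A = random.normal(k1, shape=(D, D))
x = random.normal(k2, shape=(D,))
v = random.normal(k3, shape=(D,))

def quadfun(x):
    return jnp.dot(x, jnp.dot(A, x))

def hvp(f, x, v):
    # Directional derivative of grad(f) at x along v, which equals H(x) @ v.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

Hv = hvp(quadfun, x, v)
print(Hv)
```

For the quadratic form this should match $(A + A^T) v$, without ever materializing the D x D Hessian.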

Chain rule applied to sigmoid function.

$$
\begin{align}
\mu(x;w) &=\sigma(w^T x) \\
\nabla_w \mu(x;w) &= \sigma'(w^T x) x \\
\sigma'(a) &= \sigma(a) * (1-\sigma(a)) 
\end{align}
$$

In [57]:


D = 4
w = jax.random.normal(key, shape=(D,))
x = jax.random.normal(key, shape=(D,))
y = 0 

def sigmoid(x): return 0.5 * (jnp.tanh(x / 2.) + 1)
def mu(w): return sigmoid(jnp.dot(w,x))
def deriv_mu(w): return mu(w) * (1-mu(w)) * x
deriv_mu_jax =  grad(mu)

print(deriv_mu(w))
print(deriv_mu_jax(w))

assert np.allclose(deriv_mu(w), deriv_mu_jax(w), atol=1e-3)



[0.002 0.003 0.004 0.002]
[0.002 0.003 0.004 0.002]


## Binary logistic regression

In [58]:

# negative log likelihood
def loss(weights, inputs, targets):
    preds = predict_batch(weights, inputs)
    logprobs = jnp.log(preds) * targets + jnp.log(1 - preds) * (1 - targets)
    return -jnp.sum(logprobs)


D = 2
N = 3
w = jax.random.normal(key, shape=(D,))
X = jax.random.normal(key, shape=(N,D))
y = jax.random.choice(key, 2, shape=(N,)) # uniform binary labels

print(loss(w, X, y))

# Gradient function
grad_fun = grad(loss)

# Gradient of each example in the batch - 2 different ways
grad_fun_w = partial(grad_fun, w)
grads = vmap(grad_fun_w)(X,y)
print(grads)
assert grads.shape == (N,D)

grads2 = vmap(grad_fun, in_axes=(None, 0, 0))(w, X, y) 
assert np.allclose(grads, grads2)

# Gradient for entire batch
grad_sum = jnp.sum(grads, axis=0)
assert grad_sum.shape == (D,)
print(grad_sum)

4.468591
[[-1.365  0.913]
 [-0.168  0.534]
 [ 0.188 -0.075]]
[-1.345  1.373]


In [59]:
# Textbook implementation of gradient
def NLL_grad(weights, batch):
    X, y = batch
    N = X.shape[0]
    mu = predict_batch(weights, X)
    g = jnp.sum(jnp.dot(jnp.diag(mu - y), X), axis=0)
    return g

grad_sum_batch = NLL_grad(w, (X,y))
print(grad_sum_batch)
assert np.allclose(grad_sum, grad_sum_batch)

[-1.345  1.373]
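
If we only need the full-batch gradient, we can differentiate the summed loss directly instead of summing per-example gradients. The sketch below uses `jax.value_and_grad` to get the loss and its gradient in one pass. It is self-contained, so it redefines `sigmoid` and `loss` and draws its own data with separate keys.

```python
import jax
import jax.numpy as jnp
from jax import random

def sigmoid(a):
    return 0.5 * (jnp.tanh(a / 2.0) + 1)

def loss(weights, inputs, targets):
    preds = sigmoid(jnp.dot(inputs, weights))
    logprobs = jnp.log(preds) * targets + jnp.log(1 - preds) * (1 - targets)
    return -jnp.sum(logprobs)

k_w, k_x, k_y = random.split(random.PRNGKey(0), 3)
w = random.normal(k_w, shape=(2,))
X = random.normal(k_x, shape=(3, 2))
y = random.bernoulli(k_y, 0.5, shape=(3,)).astype(jnp.float32)

# Loss value and gradient with respect to the first argument (the weights).
nll, g = jax.value_and_grad(loss)(w, X, y)
print(nll, g)
```

The gradient should agree with the textbook formula $X^T (\mu - y)$.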


In [60]:
# We can also compute Hessians, as we illustrate below.
from jax import hessian

hessian_fun = hessian(loss)

# Hessian on one example
H0 = hessian_fun(w, X[0,:], y[0])
print('Hessian(example 0)\n{}'.format(H0))

# Hessian for batch
Hbatch = vmap(hessian_fun, in_axes=(None, 0, 0))(w, X, y) 
print('Hbatch shape {}'.format(Hbatch.shape))

Hbatch_sum = jnp.sum(Hbatch, axis=0)
print('Hbatch sum\n {}'.format(Hbatch_sum))

Hessian(example 0)
[[ 0.123 -0.082]
 [-0.082  0.055]]
Hbatch shape (3, 2, 2)
Hbatch sum
 [[ 0.286 -0.17 ]
 [-0.17   0.163]]


In [62]:
# Textbook implementation of Hessian

def NLL_hessian(weights, batch):
  X, y = batch
  mu = predict_batch(weights, X)
  S = jnp.diag(mu * (1-mu))
  H = jnp.dot(jnp.dot(X.T, S), X)
  return H

H2 = NLL_hessian(w, (X,y) )
assert np.allclose(Hbatch_sum, H2, atol=1e-2)

# Vector Jacobian Products

Consider a bilinear mapping $f(x,W) = W x$.
For fixed parameters, we have
$f_1(x) = W x$, so $J(x) = W$, and the vector-Jacobian product is $u^T J(x) = u^T W$, or equivalently (written as a column vector) $J(x)^T u = W^T u$.


In [27]:
n = 3; m = 2;
W = jax.random.normal(key, shape=(m,n))
x = jax.random.normal(key, shape=(n,))
u = jax.random.normal(key, shape=(m,))

def f1(x): return jnp.dot(W,x)

J1 = jacfwd(f1)(x)
print(J1.shape)

assert np.allclose(J1, W)
tmp1 = jnp.dot(u.T, J1)
print(tmp1)

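# jax.vjp returns the primal output and a function mapping u to a tuple of cotangents,
# one per input of f1.
(val, vjp_fun) = jax.vjp(f1, x)
(tmp2,) = vjp_fun(u)
assert np.allclose(tmp1, tmp2)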

tmp3 = np.dot(W.T, u)
assert np.allclose(tmp1, tmp3)




(2, 3)
[ 0.922  1.216 -0.61 ]


For fixed inputs, we have
$f_2(W) = W x$. Here $J(W)$ is a third-order tensor of shape $(m, m, n)$,
but the vector-Jacobian product $u^T J(W)$ is simply the $m \times n$ matrix $u x^T$.

In [30]:

def f2(W): return jnp.dot(W,x)

J2 = jacfwd(f2)(W)
print(J2.shape)

tmp1 = jnp.dot(u.T, J2)
print(tmp1)
print(tmp1.shape)

# As before, the function returned by jax.vjp yields a tuple with one cotangent per input.
(val, vjp_fun) = jax.vjp(f2, W)
(tmp2,) = vjp_fun(u)
assert np.allclose(tmp1, tmp2)

tmp3 = np.outer(u, x)
assert np.allclose(tmp1, tmp3)


(2, 2, 3)
[[-1.425  0.379 -0.267]
 [ 1.555 -0.413  0.291]]
(2, 3)
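
The forward-mode counterpart is the Jacobian-vector product, $J(x) v$, which `jax.jvp` computes without forming the Jacobian. The following sketch mirrors the fixed-parameter case above; the tangent vector `v` is our own addition.

```python
import jax
import jax.numpy as jnp
from jax import random

k_W, k_x, k_v = random.split(random.PRNGKey(0), 3)
n, m = 3, 2
W = random.normal(k_W, shape=(m, n))
x = random.normal(k_x, shape=(n,))
v = random.normal(k_v, shape=(n,))

def f1(x):
    return jnp.dot(W, x)

# jvp takes tuples of primals and tangents, and returns (f1(x), J(x) @ v).
y_out, Jv = jax.jvp(f1, (x,), (v,))
print(Jv)
```

Since $J(x) = W$ here, the result should equal $W v$.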



# JIT (just in time compilation) <a class="anchor" id="JIT"></a>

In this section, we illustrate how to use the JAX JIT compiler to make code go faster (even on a CPU). However, it does not work on arbitrary Python code, as we explain below.




In [63]:
grad_fun_jit = jit(grad_fun) # speedup gradient function
grads_jit = vmap(partial(grad_fun_jit, w))(X,y)
assert np.allclose(grads, grads_jit)


In [64]:
# We can apply JIT to non ML applications as well.

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
%timeit  slow_f(x) 

fast_f = jit(slow_f)
%timeit fast_f(x)  
 
assert np.allclose(slow_f(x), fast_f(x))

100 loops, best of 3: 5 ms per loop
The slowest run took 8.36 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 3: 1.37 ms per loop


We can also add the `@jit` decorator in front of a function.



In [65]:
@jit
def faster_f(x):
  return x * x + x * 2.0
%timeit faster_f(x)
assert np.allclose(faster_f(x), fast_f(x))  

The slowest run took 12.84 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 3: 1.37 ms per loop
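
To see what JIT actually compiles, it can help to inspect the traced program. A minimal sketch using `jax.make_jaxpr` on the same element-wise function:

```python
import jax
import jax.numpy as jnp

def slow_f(x):
    return x * x + x * 2.0

# Trace the function on an example input and print its intermediate representation.
jaxpr = jax.make_jaxpr(slow_f)(jnp.ones((3, 3)))
print(jaxpr)
```

The printed jaxpr lists the primitive operations (two multiplications and an addition) that XLA can then fuse into a single kernel.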


## Static argnum

Note that JIT compilation requires that the control flow through the function can be determined by the shape (but not concrete value) of its inputs. The function below violates this, since when x<3, it takes one branch, whereas when x>=3, it takes the other.

In [66]:
@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

# This will fail!
try:
  print(f(2))
except Exception as e:
  print("ERROR:", e)
  


ERROR: Abstract tracer value encountered where concrete value is expected.

The problem arose with the `bool` function. 

While tracing the function f at <ipython-input-66-3da05647f18e>:1, this concrete value was not available in Python because it depends on the value of the arguments to f at <ipython-input-66-3da05647f18e>:1 at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>


We can fix this by telling JAX to trace the control flow through the function using concrete values of some of its arguments. JAX will then compile different versions, depending on the input values. See below for an example.


In [67]:
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

f2 = jit(f, static_argnums=(0,))

print(f2(5))

-20


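Unfortunately, value-dependent Python control flow also fails under vmap, which traces the function with abstract batched inputs rather than concrete values. In the cell below the `vmap` call raises an error, so we fall back to a Python loop.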

In [68]:

xs = jnp.arange(5)
try:
  ys = vmap(f)(xs)
  print('used vmap')
except:
  ys = jnp.array([f(x) for x in xs])
  print('did not use vmap')
print(ys)



did not use vmap
[  0.   3.  12. -12. -16.]
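
One way to make such a function compatible with both `jit` and `vmap` is to replace the Python `if` with an element-wise select. The sketch below uses `jnp.where`; the name `f_where` is ours, and we cast the input to float so that both branches have the same dtype.

```python
import jax.numpy as jnp
from jax import jit, vmap

@jit
def f_where(x):
    x = jnp.asarray(x, dtype=jnp.float32)
    # Both branches are evaluated; jnp.where then selects between them.
    return jnp.where(x < 3, 3.0 * x ** 2, -4.0 * x)

xs = jnp.arange(5)
ys = vmap(f_where)(xs)
print(ys)
```

The trade-off is that both branches are always computed, which is fine for cheap expressions like these but can be wasteful for expensive ones.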


## Side effects

There are a few other subtleties. If your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside jit'd functions:

In [70]:
def f(x):
  print(x)
  y = 2 * x
  print(y)
  return y
y1 = f(2)
print(y1)

print('jit version follows')
@jit
def f(x):
  print(x)
  y = 2 * x
  print(y)
  return y
y2 = f(2)
print(y2)

print('call jitted function a second time')
y2 = f(2)
print(y2)

2
4
4
jit version follows
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
4
call jitted function a second time
4
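
The output above shows that the `print` calls run only while the function is being traced. The sketch below makes this explicit by recording each trace in a Python list; `trace_log` and `double` are illustrative names.

```python
import jax
import jax.numpy as jnp

trace_log = []  # Python-side state, only touched while tracing

@jax.jit
def double(x):
    trace_log.append(x.shape)  # side effect: runs at trace time only
    return 2 * x

double(jnp.ones(3))
double(jnp.ones(3))  # same shape and dtype: cached version is reused
n_after_repeat = len(trace_log)

double(jnp.ones(4))  # new shape: the function is traced and compiled again
n_after_new_shape = len(trace_log)
print(n_after_repeat, n_after_new_shape)
```

Side effects therefore happen once per distinct input signature, not once per call.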


# Worked example: gradient descent for linear regression

We put some of the above pieces together to show how to implement (batch) gradient descent to minimize squared error on a linear model. The code is based on the [flax JAX tutorial](https://flax.readthedocs.io/en/latest/notebooks/jax_for_the_impatient.html). We choose a simple example because we will modify this code later.

In [9]:


# Create the predict function from a set of parameters
def make_predict_fun(W,b):
  def predict(x):
    return jnp.dot(W,x)+b
  return predict

# Create the loss from the data points set
def make_mse_fun(x_batched,y_batched): # returns fn(W,b)
  def mse(W,b):
    # Define the squared loss for a single pair (x,y)
    def squared_error(x,y):
      y_pred = make_predict_fun(W,b)(x)
      return jnp.inner(y-y_pred,y-y_pred)/2.0
    # We vectorize the previous to compute the average of the loss on all samples.
    return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)
  return jax.jit(mse) # And finally we jit the result.

In [10]:
# Set problem dimensions
N = 20
xdim = 10
ydim = 5

# Generate random ground truth W and b
key = random.PRNGKey(0)
Wtrue = random.normal(key, (ydim, xdim))
btrue = random.normal(key, (ydim,))
true_predict_fun = make_predict_fun(Wtrue, btrue)

# Generate data with additional observation noise
X = random.normal(key, (N, xdim))
Ytrue = jax.vmap(true_predict_fun)(X)
Y = Ytrue + 0.1*random.normal(key, (N, ydim))

# Generate MSE for our samples
mse_fun = make_mse_fun(X, Y)

In [28]:
# Initialize estimated W and b with zeros.
What = jnp.zeros_like(Wtrue)
bhat = jnp.zeros_like(btrue)

alpha = 0.3 # Gradient step size
for i in range(101):
  grad_W = jax.grad(mse_fun,0)(What,bhat)
  grad_b = jax.grad(mse_fun,1)(What,bhat)
  What = What - alpha*grad_W
  bhat = bhat - alpha*grad_b 
  if (i%10==0):
    print("Loss step {}: ".format(i), mse_fun(What,bhat))

assert np.allclose(Wtrue, What, atol=1e-1)
assert np.allclose(btrue, bhat, atol=1e-1)


print('loss with true params {}, loss with estimated params {}'.format(
    mse_fun(Wtrue, btrue), mse_fun(What, bhat)))

Loss step 0:  6.5597453
Loss step 10:  0.17232795
Loss step 20:  0.043397333
Loss step 30:  0.024473595
Loss step 40:  0.01707891
Loss step 50:  0.013489492
Loss step 60:  0.011695366
Loss step 70:  0.0107952645
Loss step 80:  0.0103434585
Loss step 90:  0.010116665
Loss step 100:  0.010002807
loss with true params 0.02229204587638378, loss with estimated params 0.010002806782722473
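
The loop above calls `jax.grad` twice per step, once per parameter. Passing a tuple to `argnums` returns both gradients from a single backward pass. The sketch below uses a small stand-alone loss (`mse` here takes the data as explicit arguments, unlike `mse_fun` above) with made-up inputs.

```python
import jax
import jax.numpy as jnp

def mse(W, b, X, Y):
    preds = jnp.dot(X, W.T) + b
    return jnp.mean(jnp.sum((Y - preds) ** 2, axis=1) / 2.0)

X = jnp.arange(12.0).reshape(4, 3) / 10.0
Y = jnp.ones((4, 2))
W = jnp.ones((2, 3))
b = jnp.zeros((2,))

# Differentiate with respect to arguments 0 and 1 at once.
grad_W, grad_b = jax.grad(mse, argnums=(0, 1))(W, b, X, Y)
print(grad_W.shape, grad_b.shape)
```

Each returned gradient has the same shape as the corresponding parameter.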


# Pytrees

A pytree is a nested data structure, such as a list or tuple, which contains items (e.g., arrays or strings) at its leaves. It is useful for representing hierarchical sets of parameters for DNNs (and other structured data).

We can map functions down a pytree in the same way that we can map a function down a list. We can also combine elements in two pytrees that have the same structure to make a third pytree. We illustrate this below, following the [Flax JAX tutorial](https://flax.readthedocs.io/en/latest/notebooks/jax_for_the_impatient.html).

## Simple example

In [20]:
from jax import tree_util

# a simple pytree
t1 = [1, {"k1": 2, "k2": (3, 4)}, 5]

t2 = tree_util.tree_map(lambda x: x*x, t1)
print(t2)


[1, {'k1': 4, 'k2': (9, 16)}, 25]


In [21]:
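# tree_multimap maps a function over several pytrees that share the same structure.
# (Newer JAX releases removed it; there, tree_map accepts multiple trees.)
t3 = tree_util.tree_multimap(lambda x,y: x+y, t1, t2)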
print(t3)

[2, {'k1': 6, 'k2': (12, 20)}, 30]


## More complex example: linear regression revisited

In [22]:

# Create the predict function from a set of parameters
def make_predict_pytree(params):
  def predict(x):
    return jnp.dot(params['W'],x)+params['b']
  return predict

# Create the loss from the data points set
def make_mse_pytree(x_batched,y_batched): # returns fn(params)->real
  def mse(params):
    # Define the squared loss for a single pair (x,y)
    def squared_error(x,y):
      y_pred = make_predict_pytree(params)(x)
      return jnp.inner(y-y_pred,y-y_pred)/2.0
    # We vectorize the previous to compute the average of the loss on all samples.
    return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)
  return jax.jit(mse) # And finally we jit the result.

In [29]:
# Initialize estimated W and b with zeros.
params = {'W': jnp.zeros_like(Wtrue), 'b': jnp.zeros_like(btrue)}
params_true = {'W': Wtrue, 'b': btrue}

mse_pytree = make_mse_pytree(X, Y)
print(mse_pytree(params_true))
print(mse_pytree(params))

print(jax.grad(mse_pytree)(params))

0.022292046
24.97824
{'W': DeviceArray([[-0.039,  0.755,  0.542,  0.36 ,  0.224,  1.651,  1.534,
              -1.342, -0.15 , -1.638],
             [-0.324,  0.141, -0.402,  0.498,  1.829,  4.308,  2.138,
              -2.43 , -0.381, -2.178],
             [ 1.7  , -0.707, -0.656, -0.568,  1.824, -2.194, -0.477,
               0.96 ,  1.622,  1.408],
             [-0.862,  0.321, -0.388, -0.74 , -0.82 ,  0.441,  0.772,
              -1.713, -1.592, -0.557],
             [ 1.338, -0.632, -0.968, -1.127,  1.775,  0.323,  1.405,
              -0.638,  1.077, -0.739]], dtype=float32), 'b': DeviceArray([ 0.036,  1.092, -0.413, -1.389, -0.862], dtype=float32)}


In [30]:
alpha = 0.3 # Gradient step size
print('Loss for "true" W,b: ', mse_pytree(params_true))
for i in range(101):
  gradients = jax.grad(mse_pytree)(params)
  params = jax.tree_multimap(lambda old,grad: old-alpha*grad, params, gradients)
  if (i%10==0):
    print("Loss step {}: ".format(i), mse_pytree(params))

Loss for "true" W,b:  0.022292046
Loss step 0:  6.5597453
Loss step 10:  0.17232795
Loss step 20:  0.043397333
Loss step 30:  0.024473595
Loss step 40:  0.01707891
Loss step 50:  0.013489492
Loss step 60:  0.011695366
Loss step 70:  0.0107952645
Loss step 80:  0.0103434585
Loss step 90:  0.010116665
Loss step 100:  0.010002807
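
The update step above relies on `tree_multimap`, which newer JAX releases no longer provide. The sketch below shows the same parameter update, assuming a JAX release in which `jax.tree_util.tree_map` accepts several pytrees. The values of `params` and `grads` are made up for illustration.

```python
import jax
import jax.numpy as jnp

params = {"W": jnp.ones((2, 3)), "b": jnp.zeros((2,))}
grads = {"W": jnp.full((2, 3), 0.5), "b": jnp.ones((2,))}
alpha = 0.1

# Apply the update leaf by leaf; both trees must have the same structure.
new_params = jax.tree_util.tree_map(lambda p, g: p - alpha * g, params, grads)

# Flatten the result to a plain list of arrays.
leaves = jax.tree_util.tree_leaves(new_params)
print(new_params)
```

Because the update is written per leaf, the same line works unchanged if `params` grows more layers or nesting.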


# Looping constructs

For loops in Python are slow, even when JIT-compiled. However, there are built-in primitives for loops that are fast, as we illustrate below.

## For loops.

The semantics of the for loop function in JAX is as follows:
```
def fori_loop(lower, upper, body_fun, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fun(i, val)
  return val
```
We see that ```val``` is used to accumulate the results across iterations.

Below is an example.

In [71]:
# sum from 1 to N = N*(N+1)/2

def sum_exact(N):
  return int(N*(N+1)/2)

def sum_slow(N):
  s = 0
  for i in range(1,N+1):
    s += i
  return s

N = 10

assert sum_slow(N) == sum_exact(N)

def sum_fast(N):
  s = jax.lax.fori_loop(1, N+1, lambda i,partial_sum: i+partial_sum, 0)
  return s

assert sum_fast(N) == sum_exact(N) 

In [73]:
N = 1000
%timeit sum_slow(N)
%timeit sum_fast(N)

10000 loops, best of 3: 44.1 µs per loop
10 loops, best of 3: 41 ms per loop


In [75]:
N = 100000
%timeit sum_slow(N)
%timeit sum_fast(N)

100 loops, best of 3: 5.04 ms per loop
1 loop, best of 3: 2.88 s per loop


In [79]:
# Let's do more compute per step of the for loop

D = 10
X = jax.random.normal(key, shape=(D,D))

def sum_slow(N):
  s = jnp.zeros_like(X)
  for i in range(1,N+1):
    s += jnp.dot(X, X)
  return s

def sum_fast(N):
  s = jnp.zeros_like(X)
  s = jax.lax.fori_loop(1, N+1, lambda i,s: s+jnp.dot(X,X), s)
  return s

N = 10
assert np.allclose(sum_fast(N), sum_slow(N))

In [81]:
N = 1000
%timeit sum_slow(N)
%timeit sum_fast(N)

1 loop, best of 3: 482 ms per loop
10 loops, best of 3: 46.3 ms per loop
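
Note that none of the `fori_loop` calls above are wrapped in `jit`. As a sketch of how the two combine, the function below (our own name, `sum_fast_jit`) takes the upper bound as a traced argument, so a single compiled program can be reused for different values of `n`.

```python
import jax
import jax.numpy as jnp

@jax.jit
def sum_fast_jit(n):
    # The loop index is cast to the carry dtype so the carry type stays fixed.
    body = lambda i, s: s + i.astype(s.dtype)
    return jax.lax.fori_loop(1, n + 1, body, jnp.int32(0))

print(sum_fast_jit(10))
print(sum_fast_jit(1000))
```

Because the loop body is compiled once rather than unrolled, compile time does not grow with `n`.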


## While loops

Here are the semantics of the JAX while loop:


```
def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val
```

Below is an example.

In [88]:


def sum_slow_while(N):
  s = 0
  i = 0
  while (i <= N):
    s += i
    i += 1
  return s


def sum_fast_while(N):
  init_val = (0,0)
  def cond_fun(val):
    s,i = val
    return i<=N
  def body_fun(val):
    s,i = val
    s += i
    i += 1
    return (s,i)
  val = jax.lax.while_loop(cond_fun, body_fun, init_val)
  s2 = val[0]
  return s2

N = 10
assert sum_slow_while(N) == sum_exact(N)
assert sum_slow_while(N) == sum_fast_while(N)
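
A while loop is most useful when the number of iterations depends on the data. As an illustrative sketch (not taken from the tutorial above), the function below runs Newton's method for a square root until the residual drops below a tolerance. The name `newton_sqrt` and the default `tol` are our own choices.

```python
import jax
import jax.numpy as jnp

def newton_sqrt(a, tol=1e-4):
    def cond_fun(x):
        # Keep iterating while x*x is still far from a.
        return jnp.abs(x * x - a) > tol
    def body_fun(x):
        # Newton update for solving x^2 - a = 0.
        return 0.5 * (x + a / x)
    return jax.lax.while_loop(cond_fun, body_fun, jnp.float32(a))

root = newton_sqrt(2.0)
print(root)
```

The loop state here is a single scalar, but it can be any pytree, as in the `(s, i)` tuple used above.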

# Mutation of arrays 

Since JAX is functional, you cannot mutate arrays in place: in-place mutation makes program analysis and transformation very difficult, and JAX requires a pure functional expression of a numerical program.
Instead, JAX offers the functional update functions `index_update`, `index_add`, `index_min`, `index_max`, and the `index` helper, as well as the equivalent `.at[...]` property on arrays (e.g., `A.at[1,:].set(42.0)`). These are illustrated below.

Note: If the input values of `index_update` aren't reused, jit-compiled code will perform these operations in-place, rather than making a copy.
    

In [89]:
# You cannot assign directly to elements of an array.

A = jnp.zeros((3,3), dtype=np.float32)

# In place update of JAX's array will yield an error!
try:
  A[1, :] = 1.0
except:
  print('must use index_update')

must use index_update


In [96]:
from jax.ops import index, index_add, index_update

D = 3
A = 2*jnp.ones((D,D))
print("original array:")
print(A)

A2 = index_update(A, index[1, :], 42.0) # A[1,:] = 42
print("original array:")
print(A) # unchanged
print("new array:")
print(A2)

A3 = A.at[1,:].set(42.0) # A3=np.copy(A),  A3[1,:] = 42
print("original array:")
print(A) # unchanged
print("new array:")
print(A3)

A4 = A.at[1,:].mul(42.0) # A4=np.copy(A),  A4[1,:] *= 42
print("original array:")
print(A) # unchanged
print("new array:")
print(A4)



original array:
[[2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]]
original array:
[[2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]]
new array:
[[ 2.  2.  2.]
 [42. 42. 42.]
 [ 2.  2.  2.]]
original array:
[[2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]]
new array:
[[ 2.  2.  2.]
 [42. 42. 42.]
 [ 2.  2.  2.]]
original array:
[[2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]]
new array:
[[ 2.  2.  2.]
 [84. 84. 84.]
 [ 2.  2.  2.]]
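
The `.at[...]` syntax also works inside jit-compiled functions, and its `add` method accumulates over repeated indices. A small sketch, using a hypothetical helper `histogram3` that counts how often each of three bins is hit:

```python
import jax
import jax.numpy as jnp

@jax.jit
def histogram3(idx):
    counts = jnp.zeros(3)
    # Repeated indices accumulate, unlike NumPy's counts[idx] += 1.
    return counts.at[idx].add(1.0)

h = histogram3(jnp.array([0, 0, 2]))
print(h)
```

Bin 0 is hit twice and bin 2 once, so both updates to bin 0 are applied.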


# Implicitly casting lists to vectors

You cannot treat a list of numbers as a vector. Instead you must explicitly create the vector using the `jnp.array()` constructor.


In [97]:
# You cannot treat a list of numbers as a vector. 
try:
  S = jnp.diag([1.0, 2.0, 3.0])
except:
  print('must convert indices to np.array')

must convert indices to np.array


In [98]:
# Instead you should explicitly construct the vector.

S = jnp.diag(jnp.array([1.0, 2.0, 3.0]))

# JAX neural net libraries

JAX is a purely functional library, which differs from TensorFlow and
PyTorch, which are stateful. The main advantages of functional programming
are that we can safely transform the code, and/or run it in parallel, without worrying about
global state changing behind the scenes. The main disadvantage is that code (especially DNNs) can be harder to write.
To simplify the task, various DNN libraries have been designed, as we list below. In this book, we use Flax.

|Name|Description|
|----|----|
|[Stax](https://github.com/google/jax/blob/master/jax/experimental/stax.py)|Barebones library for specifying DNNs|
|[Flax](https://github.com/google/flax)|Library for specifying and training DNNs|
|[Haiku](https://github.com/deepmind/dm-haiku)|Library for specifying DNNs, similar to Sonnet|
|[Trax](https://github.com/google/trax)|Library for specifying and training DNNs, with a focus on sequence models|
|[Objax](https://github.com/google/objax)|Stateful (object-oriented) DNN framework, similar to PyTorch, not compatible with other JAX libraries|


# Other JAX libraries

There are many other useful JAX libraries, some of which we list below.

|Name|Description|
|----|----|
|[NumPyro](https://github.com/pyro-ppl/numpyro)|Library for (deep) probabilistic modeling|
|[Optax](https://github.com/deepmind/optax)|Library for defining gradient-based optimizers|
|[RLax](https://github.com/deepmind/rlax)|Library for reinforcement learning|
|[Chex](https://github.com/deepmind/chex)|Library for debugging and developing reliable JAX code|