[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/jax_for_the_impatient.ipynb)
[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/notebooks/jax_for_the_impatient.ipynb)

# JAX for the Impatient
**JAX is NumPy on the CPU, GPU, and TPU, with powerful automatic differentiation, for high-performance machine learning research.**

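Here we cover the JAX basics you need to get started with Flax. Once you are familiar with them, we still recommend going through the JAX documentation [here](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html).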

## NumPy API

JAX provides a NumPy-like API, so let's start there.

In [1]:
import jax
from jax import numpy as jnp, random

import numpy as np # We import the standard NumPy library 

`jax.numpy` is the NumPy-like API. We also need `jax.random` to generate random numbers, because JAX's random number generation works completely differently from NumPy's.

Let's start with a matrix multiplication example:

In [2]:
m = jnp.ones((4,4)) # A 4x4 matrix of ones
n = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [5.0, 6.0, 7.0, 8.0]]) # A 2x4 matrix
m

DeviceArray([[1., 1., 1., 1.],
             [1., 1., 1., 1.],
             [1., 1., 1., 1.],
             [1., 1., 1., 1.]], dtype=float32)

Arrays in JAX are of type `DeviceArray`, and matrix multiplication works just as it does in NumPy:

In [3]:
jnp.dot(n, m).block_until_ready() # Note: yields the same result as np.dot(n, m)

DeviceArray([[10., 10., 10., 10.],
             [26., 26., 26., 26.]], dtype=float32)

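JAX uses asynchronous dispatch by default, so DeviceArray objects are effectively futures ([more here](https://jax.readthedocs.io/en/latest/async_dispatch.html)). The Python call can return before the matrix multiplication has actually finished, which is why we call `block_until_ready()` to wait until the final result has been computed.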

JAX DeviceArrays and NumPy ndarrays interoperate seamlessly; for example, `jnp` functions accept NumPy arrays directly:

In [5]:
x = np.random.normal(size=(4,4)) # Create a NumPy array
jnp.dot(x, m)

DeviceArray([[ 2.       ,  2.       ,  2.       ,  2.       ],
             [ 1.7832031,  1.7832031,  1.7832031,  1.7832031],
             [-1.3183594, -1.3183594, -1.3183594, -1.3183594],
             [ 1.9140625,  1.9140625,  1.9140625,  1.9140625]],            dtype=float32)

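If you run JAX on a GPU or TPU, passing NumPy arrays directly may cause repeated copies from CPU memory to GPU/TPU memory. To avoid this, create JAX arrays directly, or call `jax.device_put` once to move the NumPy array to the accelerator. Computations on JAX arrays (DeviceArrays) run on the device without any extra data transfer; for example, `jnp.dot(long_vector, long_vector)` only transfers the final result (a single scalar) from device to host.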

In [6]:
x = np.random.normal(size=(4,4))
x = jax.device_put(x)  # Move the NumPy array to the default device
x

DeviceArray([[-1.9078977 , -0.14710547, -0.72077036, -0.9940185 ],
             [ 0.86262065, -1.0833409 ,  0.13059273, -0.5004832 ],
             [-1.2784858 ,  1.0578346 , -0.71898067,  1.2214077 ],
             [ 0.04497718,  1.4795924 , -0.17639156, -1.4458165 ]],            dtype=float32)
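`jax.device_put` also accepts an explicit target device. A minimal sketch, assuming you simply want the first device reported by `jax.devices()` (on a machine without accelerators, this is the CPU):

```python
import numpy as np
import jax
import jax.numpy as jnp

# A host-side NumPy array.
x_np = np.random.normal(size=(4, 4)).astype(np.float32)

# Explicitly place it on the first available device (CPU, GPU, or TPU).
device = jax.devices()[0]
x_dev = jax.device_put(x_np, device)

# jnp operations on x_dev run on that device without another host-to-device copy.
y = jnp.dot(x_dev, x_dev)

# Converting back to NumPy is where the device-to-host copy happens.
y_host = np.asarray(y)
print(device, y_host.shape)
```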

Converting a JAX array back to a NumPy array is just as easy:

In [7]:
x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [5.0, 6.0, 7.0, 8.0]])
np.array(x)

array([[1., 2., 3., 4.],
       [5., 6., 7., 8.]], dtype=float32)

## Immutability

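JAX is functional at its core, and as a result JAX arrays are immutable: in-place assignment and slice assignment are not possible. More generally, functions should not read or write global state.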

In [8]:
x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [5.0, 6.0, 7.0, 8.0]])
updated = x.at[0, 0].set(3.0) # x[0, 0] = 3.0 would raise an error
print("x: \n", x) # Note that x has not been modified
print("updated: \n", updated)

x: 
 [[1. 2. 3. 4.]
 [5. 6. 7. 8.]]
updated: 
 [[3. 2. 3. 4.]
 [5. 6. 7. 8.]]


Besides `set()`, `at[]` also supports `add()`, `multiply()`, `min()`, and `max()`, among others.
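A short sketch of these update methods on the same 2x4 matrix; the indices and values are arbitrary choices for illustration. Each call returns a new array and leaves `x` unchanged, while direct item assignment raises a `TypeError`:

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [5.0, 6.0, 7.0, 8.0]])

# In-place item assignment is not supported on JAX arrays.
try:
    x[0, 0] = 3.0
except TypeError as err:
    print("TypeError:", err)

# Each .at[...] method returns a new array; x itself is left untouched.
added = x.at[0, :].add(10.0)        # row 0 becomes [11, 12, 13, 14]
scaled = x.at[:, 1].multiply(2.0)   # column 1 becomes [4, 12]
clipped = x.at[1, 2:].min(6.5)      # elementwise min with 6.5: [7, 8] -> [6.5, 6.5]
raised = x.at[0, 0].max(9.0)        # max(1, 9) = 9
```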

## Random numbers

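Random numbers are one of the biggest differences between JAX and NumPy. For details, we recommend reading the JAX documentation [here](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Random-Numbers); we paraphrase the key point below: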

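*JAX implements pseudo-random number generation differently from NumPy. Most importantly, JAX never modifies the random state implicitly: the user has to update it explicitly.*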


Because it uses a more modern PRNG algorithm, JAX's random state is very simple: a vector of two unsigned 32-bit integers (uint32), called a key.

In [9]:
key = random.PRNGKey(0)
key

DeviceArray([0, 0], dtype=uint32)

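What happens if you generate random numbers with the same key, i.e., with the same random state? The same random function keeps returning the same numbers! You have to update the random state, i.e., the key, yourself. How? Simply by splitting it with `random.split`.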

In [10]:
for i in range(3):
    print("Printing the random number using key: ", key, " gives: ", random.normal(key,shape=(1,))) # Same key, same function: same result

Printing the random number using key:  [0 0]  gives:  [-0.20584236]
Printing the random number using key:  [0 0]  gives:  [-0.20584236]
Printing the random number using key:  [0 0]  gives:  [-0.20584236]


In [11]:
print("old key", key, "--> normal", random.normal(key, shape=(1,)))
key, subkey = random.split(key)  # Split into a new key and a subkey
print(r"    \---SPLIT --> new key   ", key, "--> normal", random.normal(key, shape=(1,)) )
print(r"             \--> new subkey", subkey, "--> normal", random.normal(subkey, shape=(1,)) )

old key [0 0] --> normal [-0.20584236]
    \---SPLIT --> new key    [4146024105  967050713] --> normal [0.14389051]
             \--> new subkey [2718843009 1272950319] --> normal [-1.2515285]


`split` can also produce several subkeys at once:

In [12]:
key, *subkeys = random.split(key, 4)
key, subkeys

(DeviceArray([3306097435, 3899823266], dtype=uint32),
 [DeviceArray([147607341, 367236428], dtype=uint32),
  DeviceArray([2280136339, 1907318301], dtype=uint32),
  DeviceArray([ 781391491, 1939998335], dtype=uint32)])
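A common pattern is to split once per consumer and use each subkey exactly once. A minimal sketch, where the number of subkeys and the sample shapes are arbitrary:

```python
import jax
from jax import random

key = random.PRNGKey(0)
print(key.shape, key.dtype)  # a legacy key: two uint32 values

# Keep `key` for later splits and hand out the subkeys.
key, *subkeys = random.split(key, 4)

# Each subkey is used exactly once, so the three draws are independent.
samples = [random.normal(k, shape=(2,)) for k in subkeys]

# Reusing a key reproduces exactly the same numbers.
again = random.normal(subkeys[0], shape=(2,))
```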

## Gradients and autodiff

For a full overview of JAX's automatic differentiation system, see the JAX [Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html).

Even though, in theory, a VJP (vector-Jacobian product, reverse-mode autodiff) and a JVP (Jacobian-vector product, forward-mode autodiff) are similar, since both compute a product of a Jacobian and a vector, they differ in computational cost. In short, when a function has many inputs and few outputs (for example, a loss over many parameters, whose Jacobian is a wide matrix), a VJP is more efficient than a JVP; conversely, a JVP is more efficient when the Jacobian is a tall matrix (few inputs, many outputs). You can read more in the JAX cookbook sections on [JVPs](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobian-vector-products-jvps-aka-forward-mode-autodiff) and [VJPs](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#vector-jacobian-products-vjps-aka-reverse-mode-autodiff) mentioned above.
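Both `jax.jacfwd` (forward mode, built on JVPs) and `jax.jacrev` (reverse mode, built on VJPs) return the same Jacobian; they differ only in how many products they need. A small sketch, with two arbitrary example functions chosen only to produce a wide and a tall Jacobian:

```python
import jax
import jax.numpy as jnp

# A "wide" map R^1000 -> R^2: many inputs, few outputs.
def wide(x):
    return jnp.stack([jnp.sum(x ** 2), jnp.sum(jnp.sin(x))])

# A "tall" map R^2 -> R^1000: few inputs, many outputs.
def tall(x):
    t = jnp.linspace(0.0, 1.0, 1000)
    return x[0] * t + x[1] * t ** 2

x_wide = jnp.ones(1000)
x_tall = jnp.array([1.0, 2.0])

# jacrev conceptually uses one VJP per output (2 here), batched with vmap;
# jacfwd uses one JVP per input (1000 here). Reverse mode wins for wide Jacobians.
J_rev = jax.jacrev(wide)(x_wide)
J_fwd = jax.jacfwd(wide)(x_wide)

# For the tall map the situation is reversed: 2 JVPs versus 1000 VJPs.
T_fwd = jax.jacfwd(tall)(x_tall)
T_rev = jax.jacrev(tall)(x_tall)
```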

### Gradients

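JAX provides powerful automatic differentiation. This is one of the strengths of its functional design: differentiating a pure function is inherently stateless.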

Consider a simple function $f:\mathbb{R}^n\rightarrow\mathbb{R}$ defined by

$$f(x) = \frac{1}{2} x^T x$$

Its gradient is:
$$\nabla f(x) = x$$

In [13]:
key = random.PRNGKey(0)
def f(x):
  return jnp.dot(x.T,x)/2.0

v = jnp.ones((4,))
f(v)

DeviceArray(2., dtype=float32)

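Computing the gradient in JAX is as simple as calling `jax.grad`. Note that `jax.grad` only works on functions with a scalar output.

Let's compute the gradient of f: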

In [14]:
v = random.normal(key,(4,))
print("Original v:")
print(v)
print("Gradient of f taken at point v")
print(jax.grad(f)(v)) # should be equal to v !

Original v:
[ 1.8160864  -0.75487745  0.3398885  -0.53483075]
Gradient of f taken at point v
[ 1.8160864  -0.75487745  0.3398885  -0.53483075]


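As mentioned above, `jax.grad` only works on scalar-output functions. That is enough for most deep learning use cases, but JAX also handles general vector-valued functions through two more powerful primitives: the Jacobian-vector product (`jax.jvp`) and the vector-Jacobian product (`jax.vjp`).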
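A quick sketch of that restriction, using an arbitrary element-wise function: `jax.grad` raises a `TypeError` on a vector output, and reducing the output to a scalar first (here with a sum, an illustrative choice) makes it applicable again:

```python
import jax
import jax.numpy as jnp

def f_vec(x):
    return x * x / 2.0  # vector output: R^n -> R^n

x = jnp.arange(3.0)

try:
    jax.grad(f_vec)(x)
except TypeError as err:
    print("jax.grad failed:", err)

# The gradient of sum_i x_i^2 / 2 is x itself.
g = jax.grad(lambda x: jnp.sum(f_vec(x)))(x)
```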

### Jacobian-Vector product

Let's consider a map $f:\mathbb{R}^n\rightarrow\mathbb{R}^m$. As a reminder, the differential of f is the map $df:\mathbb{R}^n \rightarrow \mathcal{L}(\mathbb{R}^n,\mathbb{R}^m)$ where $\mathcal{L}(\mathbb{R}^n,\mathbb{R}^m)$ is the space of linear maps from $\mathbb{R}^n$ to $\mathbb{R}^m$ (hence $df(x)$ is often represented as a Jacobian matrix). The linear approximation of f at point $x$ reads:
$$f(x+v) = f(x) + df(x)\bullet v + o(v)$$

The $\bullet$ operator means you are applying the linear map $df(x)$ to the vector v.

Even though you are rarely interested in computing the full Jacobian matrix representing the linear map $df(x)$ in a standard basis, you are often interested in the quantity $df(x)\bullet v$. This is exactly what `jax.jvp` is for, and `jax.jvp(f, (x,), (v,))` returns the tuple:
$$(f(x), df(x)\bullet v)$$

Let's use a simple function as an example: $f(x) = \frac{1}{2}({x_1}^2, {x_2}^2, \ldots, {x_n}^2)$ where we know that $df(x)\bullet h = (x_1h_1, x_2h_2,\ldots,x_nh_n)$. Hence using `jax.jvp` with $h= (1,1,\ldots,1)$ should return $x$ as an output.

In [16]:
def f(x):
  return jnp.multiply(x,x)/2.0

x = random.normal(key, (5,))
v = jnp.ones(5)
print("(x,f(x))")
print((x,f(x)))
print("jax.jvp(f, (x,),(v,))")
print(jax.jvp(f, (x,),(v,)))

(x,f(x))
(DeviceArray([ 0.18784378, -1.2833427 , -0.27109176,  1.2490592 ,
              0.24446994], dtype=float32), DeviceArray([0.01764264, 0.82348424, 0.03674537, 0.7800744 , 0.02988278],            dtype=float32))
jax.jvp(f, (x,),(v,))
(DeviceArray([0.01764264, 0.82348424, 0.03674537, 0.7800744 , 0.02988278],            dtype=float32), DeviceArray([ 0.18784378, -1.2833427 , -0.27109176,  1.2490592 ,
              0.24446994], dtype=float32))


### Vector-Jacobian product
Keeping our $f:\mathbb{R}^n\rightarrow\mathbb{R}^m$, it is often the case (for example, when you are working with a scalar loss function) that you are interested in the composition $x\rightarrow\phi\circ f(x)$ where $\phi :\mathbb{R}^m\rightarrow\mathbb{R}$. In that case, the gradient reads:
$$\nabla(\phi\circ f)(x) = J_f(x)^T\nabla\phi(f(x))$$

where $J_f(x)$ is the Jacobian matrix of f evaluated at x, meaning that $df(x)\bullet v = J_f(x)v$.

`jax.vjp(f,x)` returns the tuple:
$$(f(x),v\rightarrow v^TJ_f(x))$$

Keeping the same example as previously, using $v=(1,\ldots,1)$, applying the VJP function returned by JAX should return the $x$ value:

In [17]:
(val, vjp_fun) = jax.vjp(f,x)
print("x = ", x)
print("v^T Jf(x) = ", vjp_fun(jnp.ones((5,)))[0])

x =  [ 0.18784378 -1.2833427  -0.27109176  1.2490592   0.24446994]
v^T Jf(x) =  [ 0.18784378 -1.2833427  -0.27109176  1.2490592   0.24446994]
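The gradient formula $\nabla(\phi\circ f)(x) = J_f(x)^T\nabla\phi(f(x))$ can be checked directly. A minimal sketch, assuming the illustrative loss $\phi(y)=\sum_i y_i^2$: the VJP applied to $\nabla\phi(f(x))$ should match `jax.grad` of the composition, which here is $x^3$ element-wise:

```python
import jax
import jax.numpy as jnp
from jax import random

def f(x):
    return jnp.multiply(x, x) / 2.0

def phi(y):
    return jnp.sum(y ** 2)  # an illustrative scalar "loss" on the outputs of f

x = random.normal(random.PRNGKey(0), (5,))

# Gradient of the composition, computed directly.
direct = jax.grad(lambda x: phi(f(x)))(x)

# Same gradient via J_f(x)^T grad phi(f(x)), using a VJP.
y, vjp_fun = jax.vjp(f, x)
cotangent = jax.grad(phi)(y)     # grad phi evaluated at f(x)
(via_vjp,) = vjp_fun(cotangent)  # vjp_fun returns a tuple, one entry per primal input
```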


## Accelerating code with jit and vmap



### Jit

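JAX uses the XLA compiler under the hood, and lets you JIT-compile functions with `jax.jit` (also usable as the `@jax.jit` decorator) for further speedups.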

In [15]:
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

v = random.normal(key, (1000000,))
%timeit selu(v).block_until_ready()

1.69 ms ± 612 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


With JIT compilation:

In [16]:
selu_jit = jax.jit(selu)
%timeit selu_jit(v).block_until_ready()

240 µs ± 13.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
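Part of the speedup comes from the fact that a jitted function is traced and compiled once per input shape and dtype, then reused. A small sketch that makes this visible with a hypothetical `trace_count` counter: the Python side effect runs only while JAX traces the function, which is also why jitted functions should not rely on global state:

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, for illustration only

@jax.jit
def selu_traced(x, alpha=1.67, lmbda=1.05):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.linspace(-3.0, 3.0, 7)
first = selu_traced(x)          # traces and compiles
second = selu_traced(x + 1.0)   # same shape and dtype: reuses the compiled version
```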


---
### Vectorization

JAX lets you write a function that processes a single example and then batch it automatically with `jax.vmap`.

In [17]:
mat = random.normal(key, (15, 10))
batched_x = random.normal(key, (5, 10)) # Batch dimension is axis 0
single = random.normal(key, (10,))

def apply_matrix(v):
  return jnp.dot(mat, v)

print("Single apply shape: ", apply_matrix(single).shape)
print("Batched example shape: ", jax.vmap(apply_matrix)(batched_x).shape)

Single apply shape:  (15,)
Batched example shape:  (5, 15)
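`jax.vmap` can also map over only some arguments via `in_axes`. A sketch, assuming we prefer to pass the matrix explicitly instead of reading it from a global: `in_axes=(None, 0)` keeps `m` fixed and maps over axis 0 of `v`, which should match a single hand-written matrix product:

```python
import jax
import jax.numpy as jnp
from jax import random

k1, k2 = random.split(random.PRNGKey(0))
mat = random.normal(k1, (15, 10))
batched_x = random.normal(k2, (5, 10))

# The matrix is an explicit argument here instead of a global.
def apply_matrix(m, v):
    return jnp.dot(m, v)

# in_axes=(None, 0): broadcast `m` unchanged, map over axis 0 of `v`.
batched = jax.vmap(apply_matrix, in_axes=(None, 0))(mat, batched_x)

# Equivalent manual batching with a single matrix product.
manual = jnp.dot(batched_x, mat.T)
```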


## Linear regression example

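Let's implement linear regression. Given a training set $\{(x_i,y_i), i\in \{1,\ldots, k\}, x_i\in\mathbb{R}^n,y_i\in\mathbb{R}^m\}$, we want to find the parameters $W\in \mathcal{M}_{m,n}(\mathbb{R}), b\in\mathbb{R}^m$ that minimize the mean squared error between the predictions $f_{W,b}(x)=Wx+b$ and the labels:
$$\mathcal{L}(W,b)=\frac{1}{k}\sum_{i=1}^{k} \frac{1}{2}\|y_i-f_{W,b}(x_i)\|^2_2$$

In the code below, `W` is stored with shape `(n, m)` and applied as `jnp.dot(x, W)`, i.e., it holds the transpose of the matrix in the formula.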


In [18]:
# Linear feed-forward.
def predict(W, b, x):
  return jnp.dot(x, W) + b

# Loss function: Mean squared error.
def mse(W, b, x_batched, y_batched):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x, y):
    y_pred = predict(W, b, x)
    return jnp.inner(y-y_pred, y-y_pred) / 2.0
  # We vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
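Because this loss is quadratic, its gradients can also be derived by hand. With residuals $r_i = y_i - (W^T x_i + b)$, using the code's convention that `W` has shape `(n, m)`:
$$\nabla_W\mathcal{L} = -\frac{1}{k}\sum_{i=1}^{k} x_i r_i^T,\qquad \nabla_b\mathcal{L} = -\frac{1}{k}\sum_{i=1}^{k} r_i$$
The sketch below, on random data chosen only for illustration, checks that `jax.grad` agrees with these formulas:

```python
import jax
import jax.numpy as jnp
from jax import random

def predict(W, b, x):
    return jnp.dot(x, W) + b

def mse(W, b, x_batched, y_batched):
    def squared_error(x, y):
        y_pred = predict(W, b, x)
        return jnp.inner(y - y_pred, y - y_pred) / 2.0
    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

k_w, k_b, k_x, k_y = random.split(random.PRNGKey(0), 4)
W = random.normal(k_w, (10, 5))
b = random.normal(k_b, (5,))
x = random.normal(k_x, (20, 10))
y = random.normal(k_y, (20, 5))

# Hand-derived gradients: residuals R = Y - (XW + b), k samples,
# dL/dW = -(1/k) X^T R  and  dL/db = -(1/k) sum_i R_i.
R = y - predict(W, b, x)
k = x.shape[0]
grad_W_manual = -jnp.dot(x.T, R) / k
grad_b_manual = -jnp.sum(R, axis=0) / k

# Autodiff with respect to both W (argnum 0) and b (argnum 1) in one call.
grad_W, grad_b = jax.grad(mse, argnums=(0, 1))(W, b, x, y)
```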

In [19]:
# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = predict(W, b, x_samples) + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

x shape: (20, 10) ; y shape: (20, 5)


In [20]:
# Initialize estimated W and b with zeros.
W_hat = jnp.zeros_like(W)
b_hat = jnp.zeros_like(b)

# Ensure we jit the largest-possible jittable block.
@jax.jit
def update_params(W, b, x, y, lr):
  grad_W, grad_b = jax.grad(mse, argnums=(0, 1))(W, b, x, y)
  W, b = W - lr * grad_W, b - lr * grad_b
  return W, b

learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse(W, b, x_samples, y_samples))
for i in range(101):
  # Perform one gradient update.
  W_hat, b_hat = update_params(W_hat, b_hat, x_samples, y_samples, learning_rate)
  if (i % 5 == 0):
    print(f"Loss step {i}: ", mse(W_hat, b_hat, x_samples, y_samples))

Loss for "true" W,b:  0.02363973
Loss step 0:  10.971258
Loss step 5:  1.0783367
Loss step 10:  0.37938032
Loss step 15:  0.17768295
Loss step 20:  0.09448608
Loss step 25:  0.05415666
Loss step 30:  0.034192722
Loss step 35:  0.023958156
Loss step 40:  0.018536853
Loss step 45:  0.015428396
Loss step 50:  0.01385448
Loss step 55:  0.012978616
Loss step 60:  0.012341755
Loss step 65:  0.01210101
Loss step 70:  0.011895995
Loss step 75:  0.011840537
Loss step 80:  0.011756034
Loss step 85:  0.011716946
Loss step 90:  0.011748294
Loss step 95:  0.01174049
Loss step 100:  0.011691327


This is obviously only an approximate solution to the linear regression problem (solving it exactly would require a bit more work!), but here you have all the tools you would need to do it properly.

## Refining a bit with pytrees

Here we refine our previous example using JAX's pytree data structures.

### Pytrees basics

Pytrees are everywhere in JAX, and in Flax as well. For more details, see the JAX [pytree page](https://jax.readthedocs.io/en/latest/pytrees.html):

*In JAX, a pytree is a container of leaf elements and/or more pytrees. Containers include lists, tuples, and dicts (JAX can be extended to consider other container types as pytrees, see Extending pytrees below). A leaf element is anything that’s not a pytree, e.g. an array. In other words, a pytree is just a possibly-nested standard or user-registered Python container. If nested, note that the container types do not need to match. A single “leaf”, i.e. a non-container object, is also considered a pytree.*

```python
[1, "a", object()] # 3 leaves: 1, "a" and object()

(1, (2, 3), ()) # 3 leaves: 1, 2 and 3

[1, {"k1": 2, "k2": (3, 4)}, 5] # 5 leaves: 1, 2, 3, 4, 5
```
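The leaf counts in the comments above can be checked with `tree_util.tree_leaves`, which flattens a pytree into a list of its leaves. Strings and arbitrary objects count as leaves, while an empty tuple contributes none:

```python
from jax import tree_util

examples = [
    [1, "a", object()],
    (1, (2, 3), ()),
    [1, {"k1": 2, "k2": (3, 4)}, 5],
]

for tree in examples:
    leaves = tree_util.tree_leaves(tree)
    print(len(leaves), leaves)
```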

JAX provides a few utilities for working with pytrees in the `jax.tree_util` module.

In [24]:
from jax import tree_util

t = [1, {"k1": 2, "k2": (3, 4)}, 5]

You will often come across the `tree_map` function, which applies a function f to every leaf of a tree and returns a new tree with the same structure.

In [25]:
tree_util.tree_map(lambda x: x*x, t)

[1, {'k1': 4, 'k2': (9, 16)}, 25]

Instead of applying a function to the leaves of a single tree, you can also pass additional trees with the same structure as the first one; the function then receives the corresponding leaf from each tree as its arguments.

In [26]:
t2 = tree_util.tree_map(lambda x: x*x, t)
tree_util.tree_map(lambda x,y: x+y, t, t2)

[2, {'k1': 6, 'k2': (12, 20)}, 30]
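Under the hood, this relies on separating a tree's leaves from its structure. A short sketch with `tree_util.tree_flatten` and `tree_util.tree_unflatten`, where doubling the leaves is an arbitrary illustration:

```python
from jax import tree_util

t = [1, {"k1": 2, "k2": (3, 4)}, 5]

# tree_flatten separates the leaves from a description of the structure (a treedef).
leaves, treedef = tree_util.tree_flatten(t)
print(leaves)  # [1, 2, 3, 4, 5]

# tree_unflatten rebuilds a tree of the same structure from new leaves.
doubled = tree_util.tree_unflatten(treedef, [2 * x for x in leaves])
print(doubled)  # [2, {'k1': 4, 'k2': (6, 8)}, 10]
```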

### Linear regression with Pytrees

Our previous example works fine, but as things get more complicated (as they will with neural networks), managing each model parameter separately, as we did, becomes harder.

Here we show an alternative based on pytrees, using the same data from the previous example.
Now, our `params` is a pytree containing both the `W` and `b` entries.

In [27]:
# Linear feed-forward that takes a params pytree.
def predict_pytree(params, x):
  return jnp.dot(x, params['W']) + params['b']

# Loss function: Mean squared error.
def mse_pytree(params, x_batched,y_batched):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x,y):
    y_pred = predict_pytree(params, x)
    return jnp.inner(y-y_pred, y-y_pred) / 2.0
  # We vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

# Initialize estimated W and b with zeros. Store in a pytree.
params = {'W': jnp.zeros_like(W), 'b': jnp.zeros_like(b)}

The great thing is that JAX is able to handle differentiation with respect to pytree parameters:

In [28]:
jax.grad(mse_pytree)(params, x_samples, y_samples)

{'W': DeviceArray([[-1.9287349e+00,  4.2963755e-01,  7.1613449e-01,
                2.1056123e+00,  5.0405121e-01, -2.4983375e+00,
               -6.3854176e-01, -2.2620213e+00, -1.3365206e+00,
               -2.0426039e-01],
              [ 1.1999468e+00, -9.4563609e-01, -1.0878400e+00,
               -7.0340711e-01,  3.3224609e-01,  1.7538791e+00,
               -7.1916544e-01,  1.0927428e+00, -1.4491037e+00,
                5.9715635e-01],
              [-1.4826509e+00, -7.6116532e-01,  2.2319858e-01,
               -3.0391946e-01,  3.0397055e+00, -3.8419428e-01,
               -1.8290073e+00, -2.3353369e+00, -1.1087127e+00,
               -7.7453995e-01],
              [ 8.2374442e-01, -9.9650609e-01, -7.6030111e-01,
                6.3919222e-01, -6.0864899e-02, -1.0859716e+00,
                1.2923398e+00, -4.9342898e-01, -1.4711156e-03,
                1.2977618e+00],
              [-4.5656446e-01, -1.3063025e-01, -3.9179009e-01,
                2.1743817e+00, -5.3948693e-02,  ... (output truncated)
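Printing full gradient arrays is verbose; a more compact way to inspect a gradient pytree is to map `jnp.shape` over its leaves. A sketch on synthetic all-ones data (an assumption for illustration, chosen so every gradient entry is exactly -1 at zero-initialized parameters):

```python
import jax
import jax.numpy as jnp

def predict_pytree(params, x):
    return jnp.dot(x, params['W']) + params['b']

def mse_pytree(params, x_batched, y_batched):
    def squared_error(x, y):
        y_pred = predict_pytree(params, x)
        return jnp.inner(y - y_pred, y - y_pred) / 2.0
    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

x_dim, y_dim, n_samples = 10, 5, 20
params = {'W': jnp.zeros((x_dim, y_dim)), 'b': jnp.zeros((y_dim,))}
x = jnp.ones((n_samples, x_dim))
y = jnp.ones((n_samples, y_dim))

grads = jax.grad(mse_pytree)(params, x, y)
# The gradient is a pytree with the same keys and leaf shapes as `params`.
shapes = jax.tree_util.tree_map(jnp.shape, grads)
print(shapes)
```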

Now, using our tree of params, we can write gradient descent more simply with `jax.tree_util.tree_map`:

In [29]:
# Always remember to jit!
@jax.jit
def update_params_pytree(params, learning_rate, x_samples, y_samples):
  params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params,
        jax.grad(mse_pytree)(params, x_samples, y_samples))
  return params

learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse_pytree({'W': W, 'b': b}, x_samples, y_samples))
for i in range(101):
  # Perform one gradient update.
  params = update_params_pytree(params, learning_rate, x_samples, y_samples)
  if (i % 5 == 0):
    print(f"Loss step {i}: ", mse_pytree(params, x_samples, y_samples))

Loss for "true" W,b:  0.023639774
Loss step 0:  11.096583
Loss step 5:  1.1743388
Loss step 10:  0.32879353
Loss step 15:  0.1398177
Loss step 20:  0.07359565
Loss step 25:  0.04415301
Loss step 30:  0.029408678
Loss step 35:  0.021554656
Loss step 40:  0.017227933
Loss step 45:  0.014798875
Loss step 50:  0.013420242
Loss step 55:  0.0126327025
Loss step 60:  0.0121810865
Loss step 65:  0.011921468
Loss step 70:  0.011771992
Loss step 75:  0.011685831
Loss step 80:  0.011636148
Loss step 85:  0.011607475
Loss step 90:  0.011590928
Loss step 95:  0.011581394
Loss step 100:  0.011575883


Besides `jax.grad()`, another useful function is `jax.value_and_grad()`, which returns the value of the input function and of its gradient.

To switch from `jax.grad()` to `jax.value_and_grad()`, replace the training loop above with the following:

In [None]:
# Using jax.value_and_grad instead:
loss_grad_fn = jax.value_and_grad(mse_pytree)
for i in range(101):
  # Note that here the loss is computed before the param update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  if (i % 5 == 0):
    print(f"Loss step {i}: ", loss_val)
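The `value_and_grad` loop runs without `jit`. A possible refinement, sketched here on synthetic data generated only for illustration (`train_step` is a hypothetical helper name), folds the loss, gradient, and update into a single jitted function:

```python
import jax
import jax.numpy as jnp
from jax import random

def predict_pytree(params, x):
    return jnp.dot(x, params['W']) + params['b']

def mse_pytree(params, x_batched, y_batched):
    def squared_error(x, y):
        y_pred = predict_pytree(params, x)
        return jnp.inner(y - y_pred, y - y_pred) / 2.0
    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

@jax.jit
def train_step(params, x, y, learning_rate):
    # A single call returns both the loss value and the gradient pytree.
    loss_val, grads = jax.value_and_grad(mse_pytree)(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return params, loss_val

k_w, k_x = random.split(random.PRNGKey(0))
W_true = random.normal(k_w, (10, 5))
x = random.normal(k_x, (20, 10))
y = jnp.dot(x, W_true)

params = {'W': jnp.zeros((10, 5)), 'b': jnp.zeros((5,))}
losses = []
for i in range(101):
    params, loss_val = train_step(params, x, y, 0.3)
    losses.append(float(loss_val))  # loss measured before this step's update
```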

That's all you needed to know to get started with Flax! To dive deeper, we very much recommend checking the JAX [docs](https://jax.readthedocs.io/en/latest/index.html).