# 🔪 JAX - The Sharp Bits 🔪

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)

*levskaya@ mattjj@*

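When walking about the countryside of Italy, the locals will not hesitate to tell you that __JAX__ has _"una anima di pura programmazione funzionale"_ ("the soul of pure functional programming").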

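__JAX__ is a language for expressing and composing __transformations__ of numerical programs. __JAX__ can also __compile__ numerical programs to run on CPU, GPU, or TPU. Not every Python function works with these transformations, however; the sections below go through the sharp bits one by one.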

In [2]:
import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import rcParams
rcParams['image.interpolation'] = 'nearest'
rcParams['image.cmap'] = 'viridis'
rcParams['axes.grid'] = False

## 🔪 Pure functions

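JAX transformations and compilation are designed to work only on Python functions that are functionally pure:
* given the same inputs, a pure function always returns the same output;
* a pure function does not **read or write** state outside the function; for example, it does not modify global variables or a pseudo-random number state.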

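Below are some examples of impure functions for which JAX behaves differently from the Python interpreter. JAX does not guarantee correct results for such functions, so use JAX only on pure functions.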

In [3]:
def impure_print_side_effect(x):
  print("Executing function")  # Not allowed: print is a side effect
  return x

# The side effect appears during the first call, while JAX traces the function
print("First call: ", jit(impure_print_side_effect)(4.))

# Later calls with arguments of the same type and shape show no side effect,
# because JAX now runs the cached compiled version of the function
print("Second call: ", jit(impure_print_side_effect)(5.))

# JAX re-traces and re-compiles the function when the type or shape of the argument changes
print("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))

Executing function
First call:  4.0
Second call:  5.0
Executing function
Third call, different type:  [5.]


In [4]:
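g = 0.
def impure_uses_globals(x):
  return x + g  # The result depends on a global, so the same input may not give the same output

# JAX captures the value of the global during the first run
print("First call: ", jit(impure_uses_globals)(4.))
g = 10.  # Update the global

# Subsequent runs may silently use the cached value of the global (0)
print("Second call: ", jit(impure_uses_globals)(5.))

# JAX re-traces the function when the type or shape of the argument changes,
# and this picks up the latest value of the global
print("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))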

First call:  4.0
Second call:  5.0
Third call, different type:  [14.]
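A pure alternative is to pass the value the global was standing in for as an explicit argument. In the sketch below, `add_offset` is a hypothetical name; because `offset` is now a traced argument, a new value is picked up immediately instead of a stale cached one, and as long as its type and shape stay the same no re-compilation is needed.

```python
import jax
import jax.numpy as jnp

@jax.jit
def add_offset(x, offset):
  # Hypothetical pure replacement for impure_uses_globals:
  # everything the function reads comes in through its arguments.
  return x + offset

print(add_offset(4., 0.))   # 4.0
print(add_offset(5., 10.))  # 15.0: the new offset is used right away
```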


In [5]:
g = 0.
def impure_saves_global(x):
  global g
  g = x
  return x

# JAX runs once the transformed function with special Traced values for arguments
print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g)  # Saved global has an internal JAX value

First call:  4.0
Saved global:  Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>


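A Python function can still be functionally pure even if it uses stateful objects internally, as long as it does not read or write external state: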

In [6]:
def pure_uses_internal_state(x):
  state = dict(even=0, odd=0)
  for i in range(10):
    state['even' if i % 2 == 0 else 'odd'] += x
  return state['even'] + state['odd']

print(jit(pure_uses_internal_state)(5.))

50.0


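It is not recommended to use iterators in any JAX function you want to `jit` or in any control-flow primitive. An iterator is a Python object that carries its own state for retrieving the next element, which is incompatible with JAX's functional programming model. Most of the examples below raise errors; the ones that don't give unexpected results.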

In [7]:
import jax.numpy as jnp
import jax.lax as lax
from jax import make_jaxpr

# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

# lax.scan
def func11(arr, extra):
    ones = jnp.ones(arr.shape)  
    def body(carry, aelems):
        ae1, ae2 = aelems
        return (carry + ae1 * ae2 + extra, carry)
    return lax.scan(body, 0., (arr, ones))    
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error

# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error

45
0


## 🔪 In-Place Updates

In NumPy, you're used to doing this:

In [8]:
numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)

# Update the array in place
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)

original array:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
updated array:
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]


If we try to update a JAX array in place, however, we get an __error__! (☉_☉)

In [9]:
jax_array = jnp.zeros((3,3), dtype=jnp.float32)

# In place update of JAX's array will yield an error!
try:
  jax_array[1, :] = 1.0
except Exception as e:
  print("Exception {}".format(e))

Exception '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html


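Allowing variables to be mutated in place would make program analysis and transformation very difficult, and JAX requires functions to be pure. Instead, JAX offers a _functional_ array update through the [`.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at), which returns an updated copy rather than modifying the array in place.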

️⚠️ inside `jit`'d code and `lax.while_loop` or `lax.fori_loop` the __size__ of slices can't be functions of argument _values_ but only functions of argument _shapes_ -- the slice start indices have no such restriction.  See the below __Control Flow__ Section for more information on this limitation.
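As an illustration of this rule, `lax.dynamic_slice` accepts a traced start index but needs the slice size as a static Python integer. In the sketch below, the `window` function and the fixed size 3 are assumptions made for the example; it is compiled once and reused for any start value. Note that `lax.dynamic_slice` clamps the start index so that the slice stays in bounds.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def window(x, start):
  # `start` may be a traced value; the slice size (3,) must be static.
  return lax.dynamic_slice(x, (start,), (3,))

x = jnp.arange(10)
print(window(x, 4))  # elements 4, 5, 6
print(window(x, 9))  # start is clamped to 7, giving elements 7, 8, 9
```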

### Array updates: `x.at[idx].set(y)`

In [10]:
updated_array = jax_array.at[1, :].set(1.0)
print("updated array:\n", updated_array)

updated array:
 [[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]


Note that this returns a new array; the original array is left unchanged.

In [11]:
print("original array unchanged:\n", jax_array)

original array unchanged:
 [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]


However, inside __jit__-compiled code, if the __input value__ `x` of `x.at[idx].set(y)` is not reused afterwards, the compiler will optimize the array update to happen _in place_.

### Array updates with other operations

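Indexed array updates are not limited to overwriting values. For example, the following performs an indexed addition, adding 7 to the selected elements: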

In [12]:
print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)

new_jax_array = jax_array.at[::2, 3:].add(7.)  # add()
print("new array post-addition:")
print(new_jax_array)

original array:
[[1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]]


For more details on indexed array updates, see the [documentation for the `.at` property](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at).
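A few other update methods follow the same pattern. The sketch below uses `.multiply`, `.min`, and `.add` with a repeated index. Unlike NumPy's `x[idx] += y`, all updates to a repeated index are applied with `.add`, so they accumulate; with `.set`, the order in which conflicting updates are applied is not something to rely on.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(5.)                       # 0., 1., 2., 3., 4.
print(x.at[1].multiply(10.))             # x[1] becomes 10.
print(x.at[2].min(1.))                   # x[2] becomes min(2., 1.) = 1.
# Repeated index: both additions are applied, so x[0] becomes 2.
print(x.at[jnp.array([0, 0])].add(1.))

# NumPy's in-place version applies only one of the repeated updates.
y = np.arange(5.)
y[[0, 0]] += 1.
print(y)                                 # y[0] is 1.
```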

## 🔪 Out-of-Bounds Indexing

In NumPy, you are used to an error being raised when you index an array outside of its bounds:

In [13]:
try:
  np.arange(10)[11]
except Exception as e:
  print("Exception {}".format(e))

Exception index 11 is out of bounds for axis 0 with size 10


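However, raising an error from code running on an accelerator can be difficult or even impossible. Therefore, JAX does not raise an error when an index is out of bounds. For retrieval operations something must be returned, so the index is clamped to the bounds of the array; in the example below, the last element is returned: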

In [14]:
jnp.arange(10)[11]

DeviceArray(9, dtype=int32)

Because out-of-bounds indexing does not raise an error, wrong results can go unnoticed, so handle indices with care.
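The sketch below contrasts retrieval and update: an out-of-bounds read is clamped, while an out-of-bounds write through `.at[...]` is skipped, leaving the array unchanged. It also assumes a recent JAX version in which `.at[...].get` accepts `mode="fill"` and `fill_value`, which returns a sentinel instead of a clamped element.

```python
import jax.numpy as jnp

x = jnp.arange(10)
print(x[11])  # retrieval: the index is clamped, so the last element (9) is returned

y = jnp.zeros(3)
print(y.at[5].set(1.))  # update: the out-of-bounds write is skipped, all zeros remain

# Ask for a sentinel value instead of clamping (recent JAX versions)
print(x.at[11].get(mode="fill", fill_value=-1))
```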

## 🔪 Non-array inputs: NumPy vs. JAX

NumPy is generally happy to accept Python lists or tuples as inputs to its API functions:

In [15]:
np.sum([1, 2, 3])

6

JAX, on the other hand, raises an error:

In [16]:
try:
  jnp.sum([1, 2, 3])
except TypeError as e:
  print(f"TypeError: {e}")

TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.


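This is a deliberate design choice: lists and tuples have no shape or dtype of their own, and passing them to traced functions can silently degrade performance, as the next example shows.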

For example, here is a permissive version of `jnp.sum` that accepts list inputs by converting `x` to a JAX array:

In [17]:
def permissive_sum(x):
  return jnp.sum(jnp.array(x))

x = list(range(10))
permissive_sum(x)

DeviceArray(45, dtype=int32)

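This produces the correct result, but it hides a potential performance problem: in JAX's tracing and JIT compilation model, each element of a Python list or tuple is treated as a separate JAX variable and is processed and transferred to the device individually, rather than as a single array. The jaxpr for `permissive_sum` makes this visible: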

In [22]:
make_jaxpr(permissive_sum)(x)

{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:i32[] f:i32[] g:i32[] h:i32[] i:i32[]
    j:i32[]. let
    k:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
    l:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    m:i32[] = convert_element_type[new_dtype=int32 weak_type=False] c
    n:i32[] = convert_element_type[new_dtype=int32 weak_type=False] d
    o:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
    p:i32[] = convert_element_type[new_dtype=int32 weak_type=False] f
    q:i32[] = convert_element_type[new_dtype=int32 weak_type=False] g
    r:i32[] = convert_element_type[new_dtype=int32 weak_type=False] h
    s:i32[] = convert_element_type[new_dtype=int32 weak_type=False] i
    t:i32[] = convert_element_type[new_dtype=int32 weak_type=False] j
    ...

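Each entry of the list is handled as a separate input that is traced and compiled individually, so the number of JAX variables grows linearly with the size of the list, which can be surprising. To avoid hidden costs like this, JAX does not implicitly convert lists and tuples to arrays.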

If you want to pass a list or tuple to a JAX NumPy function, convert it to an array explicitly:

In [23]:
jnp.sum(jnp.array(x))

DeviceArray(45, dtype=int32)
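One quick way to see the difference, sketched below, is to count the inputs of the traced program: `make_jaxpr` returns a closed jaxpr whose `.jaxpr.invars` lists the traced inputs. A list of 10 elements gives 10 inputs, while an array gives one.

```python
import jax.numpy as jnp
from jax import make_jaxpr

def permissive_sum(x):
  return jnp.sum(jnp.array(x))

x = list(range(10))
n_inputs_list = len(make_jaxpr(permissive_sum)(x).jaxpr.invars)              # one per element
n_inputs_array = len(make_jaxpr(permissive_sum)(jnp.array(x)).jaxpr.invars)  # a single array
print(n_inputs_list, n_inputs_array)
```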

## 🔪 Random Numbers

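> _If all scientific papers whose results are in doubt because of bad `rand()`s were to disappear from library shelves, there would be a gap on each shelf about as big as your fist._ - Numerical Recipes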

### RNGs and State

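You're used to _stateful_ pseudorandom number generators (PRNGs) from NumPy and other libraries, which hide the details under the hood and give you a ready supply of pseudorandomness: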

In [24]:
print(np.random.random())
print(np.random.random())
print(np.random.random())

0.3613510116453882
0.5651344398181016
0.019768584014015933


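Under the hood, NumPy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG. Its state is described by __624 32-bit unsigned ints__ plus a __position__ indicating how much of this "entropy" has already been used up: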

In [27]:
np.random.seed(0)
rng_state = np.random.get_state()
#print(rng_state)
# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,
#       2481403966, 4042607538,  337614300, ... 614 more numbers..., 
#       3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)

In [28]:
print(rng_state)

('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 1272276355, 3172048492,
       3267256201, 2332199830, 1975469449,  392443598, 1132453229,
       2900699076, 1998300999, 3847713992,  512669506, 1227792182,
       1629110240,  112303347, 2142631694, 3647635483, 1715036585,
       2508091258, 1355887243, 1884998310, 3906360088,  952450269,
       3647883368, 3962623343, 3077504981, 2023096077, 3791588343,
       3937487744, 3455116780, 1218485897, 1374508007, 2815569918,
       1367263917,  472908318, 2263147545, 1461547499, 4126813079,
       2383504810,   64750479, 2963140275, 1709368

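This state vector is updated automatically behind the scenes every time a random number is needed. Each call to `np.random.uniform()` consumes 2 of the uint32s, as the advancing position in the comments below shows: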

In [31]:
_ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state) 
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
#       ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)

# Let's exhaust the entropy in this PRNG statevector
for i in range(311):
  _ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state) 
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
#       ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)

# Next call iterates the RNG state for a new batch of fake "entropy".
_ = np.random.uniform()
rng_state = np.random.get_state()
# print(rng_state) 
# --> ('MT19937', array([1499117434, 2949980591, 2242547484, 
#      4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)

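The problem with this hidden state is that it's hard to reason about how it is used and updated across threads, processes, and devices, and it's very easy to get things wrong when the details of entropy production and consumption are hidden from the user.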

The Mersenne Twister PRNG has other problems as well: its state is large (2.5 kB), which leads to problematic [initialization issues](https://dl.acm.org/citation.cfm?id=1276928). It also [fails](http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf) the modern BigCrush tests, and it is generally slow.

### JAX PRNG

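JAX instead implements an _explicit_ PRNG: the random state is never updated implicitly, and entropy production and consumption are handled by explicitly passing and advancing PRNG state, fully under the user's control. JAX uses the modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/design_notes/prng.md), which is __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs.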

The random state is described by two unsigned 32-bit integers, which we call a __key__:

In [33]:
from jax import random
key = random.PRNGKey(0)
key

DeviceArray([0, 0], dtype=uint32)

JAX's random functions produce pseudorandom numbers from the PRNG state, but they __do not__ change the state! Only the user can advance it, explicitly.

Reusing the same state therefore returns the same numbers every time, which is not random at all:

In [34]:
print(random.normal(key, shape=(1,)))
print(key)
# No no no!
print(random.normal(key, shape=(1,)))
print(key)

[-0.20584236]
[0 0]
[-0.20584236]
[0 0]


Instead, we __split__ the key to get a fresh __subkey__ every time we need new pseudorandom numbers, and pass that subkey to the random function:

In [35]:
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(r"    \---SPLIT --> new key   ", key)
print(r"             \--> new subkey", subkey, "--> normal", normal_pseudorandom)

old key [0 0]
    \---SPLIT --> new key    [4146024105  967050713]
             \--> new subkey [2718843009 1272950319] --> normal [-1.2515285]


We then propagate the new __key__ and split it again whenever we need another random number:

In [36]:
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(r"    \---SPLIT --> new key   ", key)
print(r"             \--> new subkey", subkey, "--> normal", normal_pseudorandom)

old key [4146024105  967050713]
    \---SPLIT --> new key    [2384771982 3928867769]
             \--> new subkey [1278412471 2182328957] --> normal [-0.5866531]


We can generate more than one __subkey__ at a time:

In [37]:
key, *subkeys = random.split(key, 4)
for subkey in subkeys:
  print(random.normal(subkey, shape=(1,)))

[-0.3753332]
[0.9864523]
[0.14553195]


## 🔪 Control Flow

### ✔ Python control flow + autodiff ✔

If your Python function contains control flow (if-else, for loops), applying `grad` to it works without problems:

In [39]:
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

print(grad(f)(2.))  # ok!
print(grad(f)(4.))  # ok!

12.0
-4.0


### Python control flow + JIT

Using control flow with `jit` is more complicated, and by default it comes with more constraints.

This works:

In [40]:
@jit
def f(x):
  for i in range(3):
    x = 2 * x
  return x

print(f(3))

24


So does this:

In [41]:
@jit
def g(x):
  y = 0.
  for i in range(x.shape[0]):
    y = y + x[i]
  return y

print(g(jnp.array([1., 2., 3.])))

6.0


But the following does not work, at least by default, even though it is a pure function. Why?

In [42]:
@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

# This will fail!
try:
  f(2)
except Exception as e:
  print("Exception {}".format(e))

Exception Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function f at /tmp/ipykernel_358787/2259617891.py:1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError


__What gives!?__

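When we `jit`-compile a function, we usually want to compile a version that works for many different argument values, so that the compiled code can be cached and reused instead of being re-compiled on every call.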

For example, if we evaluate an `@jit` function on `jnp.array([1., 2., 3.], jnp.float32)`, we would like to reuse the compiled code to evaluate it on `jnp.array([4., 5., 6.], jnp.float32)` as well.

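To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs, here all arrays with the same shape and dtype. There are [multiple levels of abstraction](https://github.com/google/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different levels.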

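By default, `jit` traces your code at the `ShapedArray` abstraction level, where each abstract value represents the set of all arrays with a fixed shape and dtype. For example, tracing with the abstract value `ShapedArray((3,), jnp.float32)` gives a view of the function that can be reused for any float32 array of shape (3,).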

But there's a tradeoff here. If we trace a Python function on a `ShapedArray((), jnp.float32)` that isn't committed to a specific concrete value, then when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! With higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace; not just any Python function will do.

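The good news is that you can control this tradeoff yourself. For the `f(x)` above, if you really want to `jit` it, you can mark `x` as static with `static_argnums`, so that JAX traces on its concrete value. Because the concrete value is traced, however, every new value of `x` triggers a re-compilation: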

In [56]:
def f(x):
  print("when you call f(x)")
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

f = jit(f, static_argnums=(0,))

print(f(2.))

when you call f(x)
12.0


In [57]:
f(3.)  # Recompiles, because 3. was passed instead of 2.

when you call f(x)


DeviceArray(-12., dtype=float32, weak_type=True)

In [58]:
f(3.)  # No recompilation needed

DeviceArray(-12., dtype=float32, weak_type=True)

In [59]:
f(4.)  # Recompiles, because 4. was passed instead of 3.

when you call f(x)


DeviceArray(-16., dtype=float32, weak_type=True)

In [60]:
f(3.)  # No recompilation: the version compiled for 3. is already cached

DeviceArray(-12., dtype=float32, weak_type=True)

Here is another example, this time with a `for` loop. The JIT-compiled function is specialized to `n=2`; any change to `n` triggers a re-compilation:

In [61]:
def f(x, n):
  y = 0.
  for i in range(n):
    y = y + x[i]
  return y

f = jit(f, static_argnums=(1,))

f(jnp.array([2., 3., 4.]), 2)

DeviceArray(5., dtype=float32)

️⚠️ **functions with argument-__value__ dependent shapes**

These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok).  As a trivial example, let's make a function whose output happens to depend on the input variable `length`.

In [33]:
def example_fun(length, val):
  return jnp.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))

bad_example_jit = jit(example_fun)
# this will fail:
try:
  print(bad_example_jit(10, 4))
except Exception as e:
  print("Exception {}".format(e))
# static_argnums tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnums=(0,))
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))

[4. 4. 4. 4. 4.]
Exception Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]


`static_argnums` works well if `length` in this example rarely changes, but it becomes costly if `length` changes a lot, since every new value triggers a re-compilation!
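If `length` varies a lot, one common workaround (padding and masking, a different technique from `static_argnums`) is to keep the array shape fixed. The sketch below assumes an upper bound `max_length` is known in advance; only `max_length` is static, while `length` stays a traced value, so changing it does not trigger re-compilation. The name `masked_ones` is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(0,))
def masked_ones(max_length, length, val):
  # Hypothetical variant of example_fun: the output shape depends only on the
  # static max_length; entries at positions >= length are set to 0.
  idx = jnp.arange(max_length)
  return jnp.where(idx < length, val, 0.)

print(masked_ones(10, 5, 4.))  # five 4s followed by five 0s
print(masked_ones(10, 3, 4.))  # same compiled code, different length
```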

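Also, don't rely on `print` inside a function you intend to `jit`: it only runs during tracing, so it prints abstract tracers rather than concrete values.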

In [63]:
@jit
def f(x):
  print(x)
  y = 2 * x
  print(y)
  return y
f(2)

Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>


DeviceArray(4, dtype=int32, weak_type=True)
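To print runtime values from inside a `jit`-compiled function, recent JAX versions provide `jax.debug.print`, which is staged into the compiled computation and prints the concrete values each time it runs. A minimal sketch, assuming a JAX version that includes `jax.debug`:

```python
import jax

@jax.jit
def f(x):
  jax.debug.print("x = {}", x)  # prints the concrete value at run time
  y = 2 * x
  jax.debug.print("y = {}", y)
  return y

f(2)  # prints "x = 2" and "y = 4"
```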

### Structured control flow primitives

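If you want to use control flow inside a function while avoiding re-compilations (and without unrolling large loops), you cannot use Python's control-flow statements; use the structured control flow primitives in `lax` instead: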

 - `lax.cond`  _differentiable_
 - `lax.while_loop` __fwd-mode-differentiable__
 - `lax.fori_loop` __fwd-mode-differentiable__ in general; __fwd and rev-mode differentiable__ if endpoints are static.
 - `lax.scan` _differentiable_

#### cond

Semantically, it is equivalent to the following Python code:

```python
def cond(pred, true_fun, false_fun, operand):
  if pred:
    return true_fun(operand)
  else:
    return false_fun(operand)
```

In [64]:
from jax import lax

operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> array([-1.], dtype=float32)

DeviceArray([-1.], dtype=float32)
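For example, the `f` that failed under `jit` earlier can be rewritten with `lax.cond`, so that both branches are staged out and the choice is made at run time. In this sketch both branches take and return a float scalar (the branches must return values of the same shape and dtype); the function is compiled once and can still be differentiated with `grad`.

```python
import jax
from jax import lax

@jax.jit
def f(x):
  return lax.cond(x < 3,
                  lambda x: 3. * x ** 2,  # branch taken when x < 3
                  lambda x: -4. * x,      # branch taken otherwise
                  x)

print(f(2.))            # 12.0
print(f(4.))            # -16.0
print(jax.grad(f)(2.))  # 12.0, the derivative of 3x**2 at x = 2
```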

#### while_loop

Semantically, it is equivalent to the following Python code:

```python
def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val
```

In [65]:
init_val = 0
cond_fun = lambda x: x<10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)

DeviceArray(10, dtype=int32, weak_type=True)
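Because `cond_fun` is evaluated at run time, the number of iterations may depend on argument values, which a Python `while` loop under `jit` cannot do. The hypothetical `count_doublings` below carries a `(value, steps)` tuple through the loop; the carry must keep the same structure, shapes, and dtypes on every iteration.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def count_doublings(x, limit):
  # Hypothetical example: how many doublings until x reaches limit?
  def cond_fun(state):
    value, _ = state
    return value < limit

  def body_fun(state):
    value, steps = state
    return value * 2, steps + 1

  _, steps = lax.while_loop(cond_fun, body_fun, (x, jnp.int32(0)))
  return steps

print(count_doublings(1., 100.))  # 7, since 1 * 2**7 = 128 >= 100
```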

#### fori_loop
Semantically, it is equivalent to the following Python code:

```python
def fori_loop(start, stop, body_fun, init_val):
  val = init_val
  for i in range(start, stop):
    val = body_fun(i, val)
  return val
```

In [37]:
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)

DeviceArray(45, dtype=int32, weak_type=True)
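Unlike the Python `for` loop in the `static_argnums` example above, the bounds of `lax.fori_loop` may be traced values. In the sketch below (`sum_first_n` is a hypothetical name), `n` is an ordinary argument, so changing it does not trigger re-compilation. Per the list above, with a non-static `n` the loop is only forward-mode differentiable.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def sum_first_n(x, n):
  # The initial carry has the same dtype as x, so the loop carry type stays fixed.
  init = jnp.zeros((), x.dtype)
  return lax.fori_loop(0, n, lambda i, acc: acc + x[i], init)

x = jnp.array([2., 3., 4.])
print(sum_first_n(x, 2))  # 5.0
print(sum_first_n(x, 3))  # 9.0, without re-compiling
```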

#### Summary

$$
\begin{array} {r|rr} 
\hline \
\textrm{construct} 
& \textrm{jit} 
& \textrm{grad} \\
\hline \
\textrm{if} & ❌ & ✔ \\
\textrm{for} & ✔* & ✔\\
\textrm{while} & ✔* & ✔\\
\textrm{lax.cond} & ✔ & ✔\\
\textrm{lax.while_loop} & ✔ & \textrm{fwd}\\
\textrm{lax.fori_loop} & ✔ & \textrm{fwd}\\
\textrm{lax.scan} & ✔ & ✔\\
\hline
\end{array}
$$

<center>

$\ast$ = argument-<b>value</b>-independent loop condition - unrolls the loop

</center>

## 🔪 NaNs

### Debugging NaNs

If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:

* setting the `JAX_DEBUG_NANS=True` environment variable;

* adding `from jax.config import config` and `config.update("jax_debug_nans", True)` near the top of your main file;

* adding `from jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
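As a minimal sketch, the same flag can also be toggled from Python with `jax.config.update`. Here the checker is switched off again in a `finally` block so that the rest of the program is not slowed down by the checks; the `caught` variable exists only for the illustration.

```python
import jax
import jax.numpy as jnp

caught = False
jax.config.update("jax_debug_nans", True)
try:
  jnp.divide(0., 0.)  # 0/0 produces a NaN, so the checker raises
except FloatingPointError as e:
  caught = True
  print("NaN caught:", e)
finally:
  jax.config.update("jax_debug_nans", False)  # switch the checks off again
```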

This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.

There could be tricky situations that arise, like nans that only occur under a `@jit` but don't get produced in de-optimized mode. In that case you'll see a warning message print out but your code will continue to execute.

If the nans are being produced in the backward pass of a gradient evaluation, when an exception is raised several frames up in the stack trace you will be in the backward_pass function, which is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. In the example below, we started an ipython repl with the command line `env JAX_DEBUG_NANS=True ipython`, then ran this:

```
In [1]: import jax.numpy as jnp

In [2]: jnp.divide(0., 0.)
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<ipython-input-2-f2e2c413b437> in <module>()
----> 1 jnp.divide(0., 0.)

.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
    343     return floor_divide(x1, x2)
    344   else:
--> 345     return true_divide(x1, x2)
    346
    347

.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
    332   x1, x2 = _promote_shapes(x1, x2)
    333   return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334                  lax.convert_element_type(x2, result_dtype))
    335
    336

.../jax/jax/lax.pyc in div(x, y)
    244 def div(x, y):
    245   r"""Elementwise division: :math:`x \over y`."""
--> 246   return div_p.bind(x, y)
    247
    248 def rem(x, y):

... stack trace ...

.../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)
    103         py_val = device_buffer.to_py()
    104         if np.any(np.isnan(py_val)):
--> 105           raise FloatingPointError("invalid value")
    106         else:
    107           return DeviceArray(device_buffer, *result_shape)

FloatingPointError: invalid value
```

The nan generated was caught. By running `%debug`, we can get a post-mortem debugger. This also works with functions under `@jit`, as the example below shows.

```
In [4]: from jax import jit

In [5]: @jit
   ...: def f(x, y):
   ...:     a = x * y
   ...:     b = (x + y) / (x - y)
   ...:     c = a + 2
   ...:     return a + b * c
   ...:

In [6]: x = jnp.array([2., 0.])

In [7]: y = jnp.array([3., 0.])

In [8]: f(x, y)
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<ipython-input-8-811b7ddb3300> in <module>()
----> 1 f(x, y)

 ... stack trace ...

<ipython-input-5-619b39acbaac> in f(x, y)
      2 def f(x, y):
      3     a = x * y
----> 4     b = (x + y) / (x - y)
      5     c = a + 2
      6     return a + b * c

.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
    343     return floor_divide(x1, x2)
    344   else:
--> 345     return true_divide(x1, x2)
    346
    347

.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
    332   x1, x2 = _promote_shapes(x1, x2)
    333   return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334                  lax.convert_element_type(x2, result_dtype))
    335
    336

.../jax/jax/lax.pyc in div(x, y)
    244 def div(x, y):
    245   r"""Elementwise division: :math:`x \over y`."""
--> 246   return div_p.bind(x, y)
    247
    248 def rem(x, y):

 ... stack trace ...
```

When this code sees a nan in the output of an `@jit` function, it calls into the de-optimized code, so we still get a clear stack trace. And we can run a post-mortem debugger with `%debug` to inspect all the values to figure out the error.

⚠️ You shouldn't have the NaN-checker on if you're not debugging, as it can introduce lots of device-host round-trips and performance regressions!

⚠️ The NaN-checker doesn't work with `pmap`. To debug nans in `pmap` code, one thing to try is replacing `pmap` with `vmap`.

## 🔪 Double (64bit) precision

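At the moment, JAX uses single-precision numbers by default. This suits most machine-learning applications, but it may catch you by surprise: the `x` below is still single precision, even though `float64` was requested: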

In [66]:
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x.dtype

dtype('float32')

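To use double-precision numbers, you need to set the `jax_enable_x64` configuration variable __at startup__. There are a few ways to do this: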

1. You can set the environment variable `JAX_ENABLE_X64=True`.

2. You can manually set the `jax_enable_x64` configuration flag at startup:

   ```python
   # again, this only works on startup!
   from jax.config import config
   config.update("jax_enable_x64", True)
   ```

3. You can parse command-line flags with `absl.app.run(main)`

   ```python
   from jax.config import config
   config.config_with_absl()
   ```

4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use

   ```python
   from jax.config import config
   if __name__ == '__main__':
     # calls config.config_with_absl() *and* runs absl parsing
     config.parse_flags_with_absl()
   ```

Note that #2-#4 work for _any_ of JAX's configuration options.

We can then confirm that `x64` mode is enabled:

In [67]:
import jax.numpy as jnp
from jax import random
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x.dtype # --> dtype('float64')

dtype('float64')

### Caveats
⚠️ XLA doesn't support 64-bit convolutions on all backends!

## 🔪 Miscellaneous Divergences from NumPy

While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ.
Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge.

- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html) for more details.
- When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX's behavior may be backend dependent, and in general may diverge from NumPy's behavior. NumPy allows control over the result in these scenarios via the `casting` argument (see [`np.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html)); JAX does not provide any such configuration, instead directly inheriting the behavior of [XLA:ConvertElementType](https://www.tensorflow.org/xla/operation_semantics#convertelementtype).

  Here is an example of an unsafe cast with differing results between NumPy and JAX:
  ```python
  >>> np.arange(254.0, 258.0).astype('uint8')                                                
  array([254, 255,   0,   1], dtype=uint8)

  >>> jnp.arange(254.0, 258.0).astype('uint8')                                               
  DeviceArray([254, 255, 255, 255], dtype=uint8)
  ```
  This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.
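As one concrete instance of the type-promotion difference in the first item above, combining a 32-bit integer with a 16-bit float gives `float64` in NumPy, while JAX keeps the result `float16`. A small check, assuming current NumPy and JAX promotion rules:

```python
import numpy as np
import jax.numpy as jnp

print(np.promote_types("int32", "float16"))   # float64
print(jnp.promote_types("int32", "float16"))  # float16

# The same difference shows up in actual arithmetic on arrays.
np_result = np.ones(2, np.int32) + np.ones(2, np.float16)
jnp_result = jnp.ones(2, jnp.int32) + jnp.ones(2, jnp.float16)
print(np_result.dtype, jnp_result.dtype)      # float64 float16
```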


## Fin.

If something's not covered here that has caused you weeping and gnashing of teeth, please let us know and we'll extend these introductory _advisos_!