In [None]:
import numpy as np

import jax
from jax import jit
from jax import lax
from jax import random
import jax.numpy as jnp

# 🔪 Pure functions

In [None]:
def impure_print_side_effect(x):
  """Identity function with a deliberate side effect (a print).

  Under `jit` the print runs only while JAX traces the Python function,
  not when an already-compiled version is re-executed.
  """
  print("Executing function")  # side effect
  return x

jitted_print_fn = jit(impure_print_side_effect)

# Tracing happens on the first call, so the side effect is visible here
print("First call: ", jitted_print_fn(4.))

# Same argument type/shape -> JAX may reuse the cached compilation and skip the print
print("Second call: ", jitted_print_fn(5.))

# A new argument shape forces a re-trace, so the Python body (and its print) runs again
print("Third call, different type: ", jitted_print_fn(jnp.array([5.])))

In [None]:
g = 0.
def impure_uses_globals(x):
  """Add the module-level global `g` to `x`.

  Under `jit`, `g` is read at trace time and baked into the compiled code;
  later rebinding of `g` is only picked up after a re-trace.
  """
  return x + g

jitted_globals_fn = jit(impure_uses_globals)

# The first call traces the function while g == 0.
print("First call: ", jitted_globals_fn(4.))
g = 10.  # rebind the global after the first trace

# Cache hit: the stale traced value of g may still be used here.
print("Second call: ", jitted_globals_fn(5.))

# A new argument shape triggers a re-trace, which reads the current value of g.
print("Third call, different type: ", jitted_globals_fn(jnp.array([4.])))

In [None]:
g = 0.
def impure_saves_global(x):
  global g
  g = x
  return x

# JAX runs once the transformed function with special Traced values for arguments
print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g)  # Saved global has an internal JAX value

In [None]:
def pure_uses_internal_state(x):
  """Add `x` ten times, alternating between two local counters.

  The dict is created fresh on every call and never escapes, so the mutable
  state is purely internal and the function stays safe to `jit`.
  """
  counters = {'even': 0, 'odd': 0}
  for step in range(10):
    bucket = 'odd' if step % 2 else 'even'
    counters[bucket] += x
  return counters['even'] + counters['odd']

print(jit(pure_uses_internal_state)(5.))

In [None]:
import jax.numpy as jnp
from jax import make_jaxpr

# lax.fori_loop traces its body function instead of running Python once per iteration
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda idx, acc: acc + array[idx], 0))  # expected result 45
iterator = iter(range(10))
# next() is evaluated while tracing, not once per loop iteration, so the same
# value (0) is baked into every iteration -> unexpected result 0
print(lax.fori_loop(0, 10, lambda idx, acc: acc + next(iterator), 0))

# lax.scan: thread a running total (the carry) through the (arr, ones) pairs
def func11(arr, extra):
    """Scan over ``arr``, adding ``arr[i] * 1 + extra`` to a float carry.

    Returns ``(final_carry, carries)``: ``carries[i]`` is the carry value
    *before* step ``i`` (the per-step output emitted by ``body``).
    """
    ones = jnp.ones(arr.shape)

    def body(carry, aelems):
        ae1, ae2 = aelems
        return (carry + ae1 * ae2 + extra, carry)

    return lax.scan(body, 0., (arr, ones))

print(make_jaxpr(func11)(jnp.arange(16), 5.))
# Final carry = sum(arange(16)) + 16 * extra = 120 + 80 = 200;
# the second element holds the per-step carries [0., 5., 11., ...].
print(func11(jnp.arange(16), 5.))  # expected final carry 200
# An iterator is not a valid JAX input, so tracing fails before the body runs.
try:
    make_jaxpr(func11)(iter(range(16)), 5.)
except TypeError as err:
    print(f"{type(err).__name__}: {err}")

# lax.cond: operands must be JAX arrays (or pytrees of arrays)
array_operand = jnp.array([0.])
print(lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand))  # [1.]
iter_operand = iter(range(10))
# An iterator is not a valid JAX operand, so lax.cond rejects it.
try:
    lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand)
except TypeError as err:
    print(f"{type(err).__name__}: {err}")

In [None]:
# The accumulation function I tried: add every element of `arr` to `i` via lax.scan
@jit
def my_sum(i: float, arr: jnp.ndarray):
    """Return ``i + sum(arr)``, computed with ``lax.scan``.

    The ``print`` is a Python side effect, so it only runs while JAX traces
    the function (the first call for a given argument shape/dtype). Later
    calls reuse the cached compilation and print nothing.
    """
    print("JIT compiled")

    def accumulate(carry, elem):
        # scan bodies return (new_carry, per_step_output); no per-step output needed
        return elem + carry, None

    total, _ = lax.scan(accumulate, i, arr)
    return total

print(my_sum(0., jnp.arange(10)))  # expected result 45 (preceded by "JIT compiled")
print(my_sum(0., jnp.arange(10)))  # 45 again; cached, so "JIT compiled" is not printed

In [None]:
# lax.cond with a concrete True predicate selects the first branch -> [3.]
array_operand = jnp.array([2.])
lax.cond(True, lambda operand: operand + 1, lambda operand: operand - 1, array_operand)

# 🔪 In-place updates

In [None]:
jax_array = jnp.zeros((3,3), dtype=jnp.float32)

# In-place update of a JAX array raises a TypeError (JAX arrays are immutable).
# Catch and show it so "Restart & Run All" continues past this demonstration.
try:
    jax_array[1, :] = 1.0
except TypeError as err:
    print(f"{type(err).__name__}: {err}")

In [None]:
# Aliasing followed by `+=` behaves differently for immutable JAX arrays and
# mutable NumPy arrays.

# JAX: `+=` cannot modify the shared array, so it computes a new array and
# rebinds only the name on the left-hand side.
jax_array = jnp.array([10, 20])
jax_array_new = jax_array  # alias: both names refer to the same array object
jax_array_new += 10
print(jax_array_new)  # `jax_array_new` is rebound to a new value [20, 30], but...
print(jax_array)      # the original value is unmodified: still [10, 20]!

# NumPy: `+=` updates the shared buffer in place, so the alias sees the change.
numpy_array = np.array([10, 20])
numpy_array_new = numpy_array  # alias: both names refer to the same ndarray
numpy_array_new += 10
print(numpy_array_new)  # `numpy_array_new is numpy_array`, and it was updated
print(numpy_array)      # in-place, so both are [20, 30]!

In [None]:
jax_array = jnp.zeros((3,3), dtype=jnp.float32)
original_array = jax_array  # keep a handle on the original to show it is not modified
print(id(jax_array))  # id of the original array object (the whole array, not a row)
# `.at[...].set(...)` returns a brand-new array; the name `jax_array` is rebound to it.
jax_array = jax_array.at[1, :].set(1.0)
print(id(jax_array))  # id of the new array object; differs from the first id
print("updated array:\n", jax_array)
print("original array (unchanged):\n", original_array)

However, inside `jit`-compiled code, if the input value `x` of `x.at[idx].set(y)` is not reused afterwards, the compiler will optimize the array update to occur in place.

In [None]:
print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)

# `::2` is standard Python slicing syntax (every second row, starting at row 0);
# `3:` selects columns 3 to the end. `.at[...].add()` returns a new array.
every_other_row, trailing_cols = slice(None, None, 2), slice(3, None)
new_jax_array = jax_array.at[every_other_row, trailing_cols].add(7.)
print("new array post-addition:")
print(new_jax_array)

# 🔪 Out-of-bounds indexing

In [None]:
# With mode='fill', an out-of-bounds read returns fill_value (-1 here) instead of a clamped element.
jnp.arange(10.0).at[11].get(fill_value=-1, mode='fill')

In [None]:
# Same out-of-bounds read, but signal the miss with NaN instead of a sentinel number.
jnp.arange(10.0).at[11].get(fill_value=jnp.nan, mode='fill')

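JAX does not raise an error for out-of-bounds indices. By default, out-of-bounds reads are clamped into the valid index range, and out-of-bound updates are silently skipped, unless another behavior (such as `mode='fill'` above) is requested. Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of undefined behavior.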

# 🔪 Non-array inputs: NumPy vs. JAX

In [None]:
print(np.sum([1, 2, 3]))  # NumPy accepts a plain Python list -> 6

# jax.numpy deliberately rejects list inputs (passing lists into traced code
# can silently degrade performance); show the error instead of crashing.
try:
    jnp.sum([1, 2, 3])
except TypeError as err:
    print(f"{type(err).__name__}: {err}")

In [None]:
from jax import make_jaxpr

def permissive_sum(x):
  """Sum `x` after explicitly converting it (e.g. a Python list) to a JAX array."""
  as_array = jnp.array(x)
  return jnp.sum(as_array)

x = list(range(10))
print(permissive_sum(x))
# The jaxpr shows the cost: the list is a pytree, so every element becomes a separate input.
make_jaxpr(permissive_sum)(x)

# 🔪 Dynamic shapes

In [None]:
def nansum(x):
  """Sum the non-NaN entries of `x` using boolean-mask indexing.

  Boolean indexing produces an array whose length depends on the data, so
  this works eagerly but cannot be traced under `jax.jit`.
  """
  keep = jnp.logical_not(jnp.isnan(x))  # True where the entry is not NaN
  return x[keep].sum()

In [None]:
x = jnp.asarray([1, 2, jnp.nan, 3, 4])  # the NaN forces a floating-point dtype
print(nansum(x))

In [None]:
@jax.jit
def nansum_2(x):
  """Jit-friendly NaN masking: return `x` with every NaN replaced by 0.

  Unlike boolean indexing, `jnp.where` keeps the output shape equal to the
  input shape, so it can be traced. The summation is left to the caller.
  """
  return jnp.where(jnp.isnan(x), 0, x)

zeroed = nansum_2(x)
print(zeroed, jnp.sum(zeroed))

# 🔪 Double (64bit) precision

In [None]:
x = random.uniform(random.key(0), (1000,), dtype=jnp.float32)
print(x.dtype)
# With jax_enable_x64 off (the default), requesting dtype=jnp.float64 does not
# raise: JAX warns that float64 is unavailable and falls back to float32.
jax.config.update("jax_enable_x64", True)
try:
    y = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
    print(y.dtype)
finally:
    # Always switch 64-bit mode back off so later cells run in the default
    # 32-bit mode even if the code above raised.
    jax.config.update("jax_enable_x64", False)

# 🔪 Miscellaneous divergences from NumPy

In [None]:
# 256. and 257. do not fit in uint8; NumPy's result for them is platform-dependent.
np.arange(254.0, 258.0).astype(np.uint8)

In [None]:
# Same out-of-range cast in JAX; the last two values may differ from NumPy's.
jnp.arange(254.0, 258.0).astype(jnp.uint8)

In [None]:
import jax.numpy as jnp
# 1e-45 is far below float32's smallest normal value (about 1.18e-38), so it can
# only be stored as a subnormal (denormal) number, i.e. an extremely small value.
subnormal = jnp.float32(1e-45)
print(subnormal)              # subnormals are representable...
print(jnp.add(subnormal, 0))  # ...but may be flushed to zero inside operations

# 🔪 NaNs

In [None]:
import jax
jax.config.update("jax_debug_nans", True)
import jax.numpy as jnp

# With jax_debug_nans on, producing a NaN raises FloatingPointError instead of
# silently returning nan. Catch and show it so "Run All" is not halted.
try:
    jnp.divide(0., 0.)
except FloatingPointError as err:
    print(f"{type(err).__name__}: {err}")

In [None]:
import jax
jax.config.update("jax_debug_nans", True)
from jax import jit
import jax.numpy as jnp

@jit
def f(x, y):
    """Return ``x*y + (x+y)/(x-y) * (x*y + 2)`` elementwise.

    With ``jax_debug_nans`` enabled, a NaN in the output of this compiled
    function raises ``FloatingPointError``.
    """
    a = x * y
    b = (x + y) / (x - y)  # 0./0. -> nan when x == y == 0
    c = a + 2
    return a + b * c

x = jnp.array([2., 0.])

y = jnp.array([3., 0.])

# The second elements are both 0, so b = 0./0. = nan and the NaN checker raises.
try:
    f(x, y)
except FloatingPointError as err:
    print(f"{type(err).__name__}: {err}")

# Note: 1./0. is inf,
# whereas 0./0. is nan