In [1]:
import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp
import matplotlib as mpl
from matplotlib import rcParams

# Notebook-wide matplotlib defaults: un-smoothed pixels, viridis colormap, no grid.
rcParams.update({
    'image.interpolation': 'nearest',
    'image.cmap': 'viridis',
    'axes.grid': False,
})

# Pure Functions

A pure function receives all of its input data through its parameters and returns all of its results through its return values.

In [2]:
def impure_print_side_effect(x):
    """Return ``x`` unchanged, printing a message as a deliberate side effect."""
    print("Executing function")  # This is a side-effect
    return x

# The side effect appears during the first run, when the function is traced and compiled
print("First call: ", jit(impure_print_side_effect)(4.))

# Subsequent runs with parameters of the same type and
# shape may not show the side effect:
# JAX invokes a cached compilation of the function.
print("Second call: ", jit(impure_print_side_effect)(5.))

# JAX re-runs the Python function
# when the type or shape of the argument changes.
print("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))

Executing function
First call:  4.0
Second call:  5.0
Executing function
Third call, different type:  [5.]


In [3]:
g = 0.
def impure_uses_globals(x):
    """Return ``x + g``, reading the module-level global ``g`` (impure)."""
    return x + g

# JAX captures the value of the global during the first run.
print("First call: ", jit(impure_uses_globals)(4.))
g = 10. # Update the global

# Subsequent runs may silently use the cached value of the global
print("Second call: ", jit(impure_uses_globals)(5.))

# A different argument type triggers a re-trace, which
# ends up reading the latest value of the global
print("Third call, different type: ", jit(impure_uses_globals)(jnp.array([5.])))

First call:  4.0
Second call:  5.0
Third call, different type:  [15.]


In [4]:
g = 0.
def impure_saves_global(x):
    """Return ``x`` after storing it in the global ``g`` (impure: writes global state)."""
    global g
    g = x
    return x

# Under jit the function body runs with a tracer standing in for `x`, so the
# global ends up holding that leaked tracer rather than the concrete value.
print("First call: ", jit(impure_saves_global)(4.))
print("Saved global: ", g)

First call:  4.0
Saved global:  Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>


In [5]:
def pure_uses_internal_state(x):
    """Accumulate ``x`` ten times through a function-private dict and return the sum.

    The mutable dict never escapes the function, so the function is still
    pure from the caller's point of view and is safe to ``jit``.
    """
    totals = {'even': 0, 'odd': 0}
    for step in range(10):
        bucket = 'odd' if step % 2 else 'even'
        totals[bucket] += x
    return totals['even'] + totals['odd']

print(jit(pure_uses_internal_state)(5.))

50.0


It is not recommended to use iterators in any JAX function you want to JIT or in any control-flow primitive.

In [6]:
import jax.numpy as jnp
import jax.lax as lax
from jax import make_jaxpr

# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
# The loop body is traced rather than executed once per iteration, so
# next(iterator) is evaluated at trace time and is not advanced on every step.
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

45
0


In [7]:
# lax.scan
def func11(arr, extra):
    """Scan over ``arr`` paired with an array of ones of the same shape.

    Each step adds ``ae1 * ae2 + extra`` to the carry and emits the previous
    carry as that step's output. Returns the result of ``lax.scan`` started
    from an initial carry of ``0.``.
    """
    ones = jnp.ones(arr.shape)
    def body(carry, aelems):
        ae1, ae2 = aelems
        return (carry + ae1 * ae2 + extra, carry)
    return lax.scan(body, 0., (arr, ones))
# Passing an array works; the iterator variant below is kept only as a failing example.
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error

{ lambda  ; a b.
  let c = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(16,) ] 1.0
      d e = scan[ jaxpr={ lambda  ; a b c d.
                          let e = convert_element_type[ new_dtype=float32
                                                        weak_type=False ] c
                              f = mul e d
                              g = convert_element_type[ new_dtype=float32
                                                        weak_type=False ] b
                              h = add g f
                              i = convert_element_type[ new_dtype=float32
                                                        weak_type=False ] a
                              j = add h i
                          in (j, b) }
                  length=16
                  linear=(False, False, False, False)
                  num_carry=1
                  num_consts=1
                  reverse=False
                  unroll=1 ] b 0.0 a c
  in (d, e) }

In [8]:
# lax.cond
array_operand = jnp.arange(16)
print(lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand))
# An iterator is not accepted as a lax.cond operand, so that call is left commented out.
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error

[ 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16]


# In-Place Updates

In [9]:
# NumPy arrays are mutable: item assignment edits the existing buffer.
numpy_array = np.zeros(shape=(3, 3), dtype=np.float32)
print("original array:", numpy_array, sep="\n")

# Overwrite the whole middle row in place.
numpy_array[1, :] = 1.0
print("updated array:", numpy_array, sep="\n")

original array:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
updated array:
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]


In [10]:
jax_array = jnp.zeros((3, 3), dtype=jnp.float32)

# JAX arrays are immutable, so in-place item assignment raises a TypeError.
# Catch only that error so that any unrelated failure still surfaces.
try:
    jax_array[1, :] = 1.0
except TypeError as e:
    print("Exception {}".format(e))

Exception '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?


## index_update

In [11]:
from jax.ops import index, index_add, index_update

In [12]:
jax_array = jnp.zeros((3, 3))
print("original array:")
print(jax_array)

# `.at[...].set(...)` returns an updated copy and leaves the input untouched.
# It replaces the `jax.ops.index_update` helper, which newer JAX releases
# no longer provide.
new_jax_array = jax_array.at[1, :].set(1.)

print("old array unchanged:")
print(jax_array)

print("new array:")
print(new_jax_array)

original array:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
old array unchanged:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
new array:
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]


## index_add

In [13]:
print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)

# `.at[...].add(...)` returns a copy with 7 added to every other row, columns 3
# onward. It replaces the `jax.ops.index_add` helper, which newer JAX releases
# no longer provide.
new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)

original array:
[[1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]]


## Out-of-Bounds Indexing

In [14]:
# NumPy rejects out-of-bounds indices with an IndexError.
# Catch only that error so that any unrelated failure still surfaces.
try:
    np.arange(10)[11]
except IndexError as e:
    print("Exception {}".format(e))

Exception index 11 is out of bounds for axis 0 with size 10


In [15]:
# Unlike NumPy, JAX does not raise here: the out-of-bounds index is clamped
# to the valid range, so the last element is returned.
jnp.arange(10)[11]

DeviceArray(9, dtype=int32)

Note that due to this behavior, jnp.nanargmin and jnp.nanargmax return -1 for slices consisting of NaNs, whereas NumPy would throw an error.

# Random Numbers

## RNGs and State

In [16]:
# Three draws from NumPy's implicit global RNG, printed one per line.
print(*(np.random.random() for _ in range(3)), sep="\n")

0.008195634639773797
0.22403312940214648
0.864401510752375


In [17]:
# Seed the global generator, then inspect the hidden state it carries around.
np.random.seed(0)
rng_state = np.random.get_state()
print(rng_state)

('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 1272276355, 3172048492,
       3267256201, 2332199830, 1975469449,  392443598, 1132453229,
       2900699076, 1998300999, 3847713992,  512669506, 1227792182,
       1629110240,  112303347, 2142631694, 3647635483, 1715036585,
       2508091258, 1355887243, 1884998310, 3906360088,  952450269,
       3647883368, 3962623343, 3077504981, 2023096077, 3791588343,
       3937487744, 3455116780, 1218485897, 1374508007, 2815569918,
       1367263917,  472908318, 2263147545, 1461547499, 4126813079,
       2383504810,   64750479, 2963140275, 1709368

In [18]:
# Every draw silently advances NumPy's global RNG; the state is re-read after
# each stage so it can be inspected.
_ = np.random.uniform()
rng_state = np.random.get_state()

# Let's exhaust the entropy in this PRNG state vector
for i in range(311):
    _ = np.random.uniform()
rng_state = np.random.get_state()

# Next call iterates the RNG state for a new batch of fake "entropy"
_ = np.random.uniform()
rng_state = np.random.get_state()

The problem with magic PRNG state is that it’s hard to reason about how it’s being used and updated across different threads, processes, and devices, and it’s very easy to screw up when the details of entropy production and consumption are hidden from the end user.

## JAX PRNG

In [2]:
from jax import random
# In JAX the PRNG state is an explicit value: a key built from the seed.
key = random.PRNGKey(0)
key

DeviceArray([0, 0], dtype=uint32)

JAX’s random functions produce pseudorandom numbers from the PRNG state, but do not change the state!

In [20]:
# Sampling does not modify the key...
print(random.normal(key, shape=(1, )))
print(key)
# No no no! Reusing the same key reproduces the same "random" number.
print(random.normal(key, shape=(1, )))
print(key)

[-0.20584226]
[0 0]
[-0.20584226]
[0 0]


Instead, we split the PRNG to get usable subkeys every time we need a new pseudorandom number:

In [21]:
print("old key", key)
# Split the key: keep one half as the new key, spend the other half on a single draw.
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
# Raw strings keep the backslashes of the ASCII-art arrows literal; "\-" in a
# normal string is an invalid escape sequence that Python warns about.
print(r"    \---SPLIT --> new key   ", key)
print(r"             \--> new subkey", subkey, "--> normal", normal_pseudorandom)

old key [0 0]
    \---SPLIT --> new key    [4146024105  967050713]
             \--> new subkey [2718843009 1272950319] --> normal [-1.2515389]


In [22]:
print("old key", key)
# Split again: the key carried forward yields a fresh key and a fresh subkey.
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
# Raw strings keep the backslashes of the ASCII-art arrows literal; "\-" in a
# normal string is an invalid escape sequence that Python warns about.
print(r"    \---SPLIT --> new key   ", key)
print(r"             \--> new subkey", subkey, "--> normal", normal_pseudorandom)

old key [4146024105  967050713]
    \---SPLIT --> new key    [2384771982 3928867769]
             \--> new subkey [1278412471 2182328957] --> normal [-0.58665055]


In [23]:
# Split into four keys at once: the first becomes the new key, and each of the
# remaining three subkeys is used for exactly one draw.
key, *subkeys = random.split(key, 4)
for subkey in subkeys:
    print(random.normal(subkey, shape=(1,)))

[-0.37533438]
[0.98645043]
[0.14553197]


# Control Flow

## python control_flow + autodiff

In [24]:
def f(x):
    """Piecewise function: ``3 * x**2`` when ``x < 3``, otherwise ``-4 * x``."""
    return 3. * x ** 2 if x < 3 else -4 * x

# Plain Python branching on the value of x works under grad (no jit involved).
print(grad(f)(2.))
print(grad(f)(4.))

12.0
-4.0


## python control_flow + JIT

Using control flow with jit is more complicated, and by default it has more constraints.

In [25]:
# This works:
@jit
def f(x):
    """Double ``x`` three times in a row."""
    # The trip count is a plain Python constant, so the loop does not depend on x's value.
    for _ in range(3):
        x = 2 * x
    return x

print(f(3))

24


In [26]:
# So does this:
@jit
def g(x):
    """Sum the entries of ``x`` along its leading axis with a Python loop."""
    # The loop bound comes from x.shape, not from the values stored in x.
    total = 0.
    for idx in range(x.shape[0]):
        total = total + x[idx]
    return total

print(g(jnp.array([1., 2., 3.])))

6.0


In [27]:
# But this doesn't, at least by default
@jit
def f(x):
    """Piecewise function whose Python ``if`` branches on the value of ``x``."""
    # Under jit, x is an abstract tracer, so `x < 3` has no concrete truth value.
    if x < 3:
        return 3. * x ** 2
    else:
        return -4 * x

# This will fail!
try:
    f(2)
except Exception as e:
    print("Exception {}".format(e))

Exception Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function f at <ipython-input-27-dc2e5a3b9c1a>:2, transformed by jit., this concrete value was not available in Python because it depends on the value of the arguments to f at <ipython-input-27-dc2e5a3b9c1a>:2, transformed by jit. at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).
 (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError)


In [28]:
def f(x):
    """Piecewise function: ``3 * x**2`` when ``x < 3``, otherwise ``-4 * x``."""
    if x < 3:
        return 3. * x ** 2
    else:
        return -4 * x

# Baseline timing of the plain Python function.
%timeit f(2.)

# Marking argument 0 as static lets the Python `if` branch on its value under jit.
f = jit(f, static_argnums=(0,))

%timeit f(2.)

132 ns ± 0.236 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
112 µs ± 5.54 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [29]:
def f(x, n):
    """Return the sum of the first ``n`` entries of ``x``."""
    y = 0.
    for i in range(n):
        y = y + x[i]
    return y

# Baseline timing of the plain Python function.
%timeit f(jnp.array([2., 3., 4.]), 2)

# `n` drives a Python range(), so it is marked static (argument position 1).
f = jit(f, static_argnums=(1,))

%timeit f(jnp.array([2., 3., 4.]), 2)

2.77 ms ± 378 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
133 µs ± 1.79 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


## functions with argument-value dependent shapes

In [30]:
def example_fun(length, val):
    """Return a 1-D array of ``length`` entries, each equal to ``val``.

    The output shape depends on the value of ``length``.
    """
    return jnp.ones((length,)) * val

# un-jit'd works fine
print(example_fun(5, 4))

bad_example_jit = jit(example_fun)
# this will fail: `length` is traced, so it cannot be used as a shape
try:
    print(bad_example_jit(10, 4))
except Exception as e:
    print("Exception {}".format(e))

# static_argnums tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnums=(0,))

# first compile
print(good_example_jit(10, 4))

# recompile
print(good_example_jit(5, 4))

[4. 4. 4. 4. 4.]
Exception The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError)
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]


## Structured control flow primitives

### lax.cond: differentiable

In [31]:
from jax import lax

operand = jnp.array([0.])
# Exercise both branches: a true predicate adds one, a false one subtracts one.
for predicate in (True, False):
    print(lax.cond(predicate, lambda x: x+1, lambda x: x-1, operand))

[1.]
[-1.]


### while_loop: fwd-mode differentiable

In [32]:
init_val = 0

def cond_fun(x):
    """Keep looping while the counter is below 10."""
    return x < 10

def body_fun(x):
    """Advance the counter by one."""
    return x + 1

lax.while_loop(cond_fun, body_fun, init_val)

DeviceArray(10, dtype=int32)

### fori_loop

In [33]:
start, stop = 0, 10
init_val = 0

def body_fun(i, x):
    """Add the loop index ``i`` to the running total ``x``."""
    return x + i

lax.fori_loop(start, stop, body_fun, init_val)

DeviceArray(45, dtype=int32)

# NaNs

## Debugging NaNs

If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:

In [34]:
import os

# NOTE(review): JAX presumably reads JAX_DEBUG_NANS from the environment when
# it is imported, so this assignment likely needs to happen before the first
# `import jax` (or be exported in the shell) to take effect — confirm.
os.environ['JAX_DEBUG_NANS'] = 'True'

In [35]:
import jax

# `jax.config` is the supported handle on JAX's configuration object; the
# `from jax.config import config` spelling is deprecated in newer releases.
jax.config.update("jax_debug_nans", True) # near the top of your main file

In [36]:
import jax

# Add this call to your main file, then enable the checker from the command
# line with the flag --jax_debug_nans=True.
# `jax.config` is used directly because the `from jax.config import config`
# spelling is deprecated in newer releases.
jax.config.parse_flags_with_absl()

In [4]:
import jax.numpy as jnp
# 0/0 evaluates to NaN; this is the kind of value the NaN-checker is meant to flag.
jnp.divide(0., 0.)

DeviceArray(nan, dtype=float32)

# Double (64bit) precision

In [5]:
# Requesting float64 is not enough on its own: as the recorded output shows,
# the resulting array still has dtype float32.
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x.dtype

dtype('float32')