This notebook follows the official JAX tutorial: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html

# 1. JAX vs. NumPy

The `jax.numpy` API closely mirrors NumPy. The key difference is that **JAX arrays are always immutable**.

In [1]:
import jax
import numpy as np
import jax.numpy as jnp

In [2]:
jax.default_backend()

'gpu'

In [3]:
x_np = np.linspace(0, 10, 1000)
x_jnp = jnp.linspace(0, 10, 1000)

In [4]:
print(type(x_np))
print(type(x_jnp))

<class 'numpy.ndarray'>
<class 'jaxlib.xla_extension.ArrayImpl'>


In [5]:
# Numpy: mutable array
x = np.arange(10)
x[0] = 10
print(x)

[10  1  2  3  4  5  6  7  8  9]


In [7]:
# JAX: immutable array
x = jnp.arange(10)
# x[0] = 10  <- raises TypeError: JAX arrays do not support in-place item assignment

In [8]:
y = x.at[0].set(10)  # out-of-place update: returns a new array, x is unchanged
print(x)
print(y)

[0 1 2 3 4 5 6 7 8 9]
[10  1  2  3  4  5  6  7  8  9]
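Besides `.set`, the `.at[...]` property offers other out-of-place update methods. The sketch below shows a few of them (`add`, `multiply`, `max`); each call returns a fresh array and leaves `x` untouched.

```python
import jax.numpy as jnp

x = jnp.arange(5)

# Every .at[...] method returns a NEW array; x itself is never modified.
y_set = x.at[1].set(100)       # replace element 1        -> values [0, 100, 2, 3, 4]
y_add = x.at[1:3].add(10)      # add 10 to elements 1, 2  -> values [0, 11, 12, 3, 4]
y_mul = x.at[-1].multiply(2)   # double the last element  -> values [0, 1, 2, 3, 8]
y_max = x.at[0].max(3)         # max(x[0], 3) at index 0  -> values [3, 1, 2, 3, 4]

print(x)  # [0 1 2 3 4]
```

Outside of `jit` each update produces a copy; under `jit`, XLA can often perform such updates in place when the old array is no longer needed.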


# 2. JAX API Layering

- `jax.numpy` is a high-level wrapper that provides a familiar, NumPy-like interface.
- `jax.lax` is a lower-level API that is stricter and often more powerful.

In [9]:
import jax.numpy as jnp
import jax.lax as lax

In [10]:
# jax.numpy API implicitly promotes mixed types
jnp.add(1, 1.0)

Array(2., dtype=float32, weak_type=True)

In [11]:
# lax.add(1, 1.0)  <- raises TypeError: lax requires both operands to have the same dtype

In [12]:
lax.add(jnp.float32(1), 1.0)

Array(2., dtype=float32)
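When both operands are already JAX arrays of different dtypes, the promotion has to be spelled out. A minimal sketch using `lax.convert_element_type` to cast an `int32` value to `float32` before calling the `lax` primitive:

```python
import jax.numpy as jnp
from jax import lax

a = jnp.int32(1)
b = jnp.float32(1.0)

# lax.add(a, b) would raise a TypeError because the dtypes differ (int32 vs float32).
# Cast explicitly so both operands are float32 before calling the lax primitive.
c = lax.add(lax.convert_element_type(a, jnp.float32), b)
print(c)        # 2.0
print(c.dtype)  # float32
```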

In [13]:
x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)

Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

In [14]:
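# The equivalent computation with the lower-level lax API:
result = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(jnp.float32),  # explicit promotion; lhs layout: (batch, in_channels, width)
    y.reshape(1, 1, 10),                     # kernel layout: (out_channels, in_channels, width)
    window_strides=(1, ),
    padding=[(len(y) - 1, len(y) - 1)]       # equivalent of mode='full' in np.convolve
)
result[0][0]  # drop the batch and channel dimensions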

Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)
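One subtlety: `lax.conv_general_dilated` computes a cross-correlation, i.e. it does not flip the kernel. The example above matches `jnp.convolve` only because both `x` and `y` are symmetric. The sketch below uses a non-symmetric kernel and flips it explicitly to reproduce `jnp.convolve`; the helper name `lax_full_conv` is hypothetical.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0])
k = jnp.array([0.0, 1.0, 0.5])    # deliberately NOT symmetric

expected = jnp.convolve(x, k)     # NumPy-style 'full' convolution, length 3 + 3 - 1 = 5

def lax_full_conv(x, k):
    # conv_general_dilated cross-correlates (no kernel flip),
    # so reverse k to obtain a true convolution.
    out = lax.conv_general_dilated(
        x.reshape(1, 1, -1),                 # lhs: (batch, in_channels, width)
        k[::-1].reshape(1, 1, -1),           # rhs: (out_channels, in_channels, width)
        window_strides=(1,),
        padding=[(len(k) - 1, len(k) - 1)],  # same as mode='full'
    )
    return out[0, 0]

result = lax_full_conv(x, k)
print(expected)  # values [0, 1, 2.5, 4, 1.5]
print(result)    # same values
```

Without the `k[::-1]` flip, the first output element would be `0.5` instead of `0`.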

# 3. To JIT or not to JIT

- By default, JAX executes operations one at a time, in sequence.
- Using a just-in-time (JIT) compilation decorator such as `jax.jit`, a sequence of operations can be optimized together and run at once.
- Not all JAX code can be JIT compiled, because JIT compilation requires array shapes to be static and known at compile time.

In [15]:
import jax.numpy as jnp
from jax import jit

In [16]:
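def norm(X):
    # Standardize each column: zero mean, unit standard deviation
    X = X - X.mean(0)
    return X / X.std(0)

norm_compiled = jit(norm)

In [17]:
np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)

True

In [19]:
# block_until_ready() waits for the asynchronously dispatched computation to finish
%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()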

1.06 ms ± 124 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1 ms ± 81.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


`jax.jit` does have limitations: in particular, it requires **all arrays to have static shapes**.

In [20]:
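def get_negatives(x):
    # keeps entries below 0.5 (np.random.rand samples are non-negative);
    # the length of the result depends on the values in x
    return x[x < 0.5]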
x = jnp.array(np.random.rand(10))
# print(x)
get_negatives(x)

Array([0.34902555, 0.30862594, 0.16273628, 0.08442569, 0.19423719],      dtype=float32)

In [21]:
# jit(get_negatives)(x)  <- raises NonConcreteBooleanIndexError: the output shape depends on the data
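A common workaround is to keep the output shape fixed, for example by masking with `jnp.where` instead of boolean indexing. This is a minimal sketch rather than the only option; the function names `zero_out_large` and `sum_below_half` are hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def zero_out_large(x):
    # Output has the same (static) shape as x, so this traces fine under jit.
    return jnp.where(x < 0.5, x, 0.0)

@jax.jit
def sum_below_half(x):
    # Reducing over a mask also keeps every shape static.
    return jnp.sum(jnp.where(x < 0.5, x, 0.0))

x = jnp.array([0.1, 0.7, 0.3, 0.9])
print(zero_out_large(x))   # values [0.1, 0, 0.3, 0]
print(sum_below_half(x))   # approximately 0.4
```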

# 4. JIT mechanics: tracing and static variables

In [22]:
@jit
def f(x, y):
    print("Running f():")
    print(f"\tx = {x}")
    print(f"\ty = {y}")
    result = (x + 1) @ (y + 1)
    print(f"\tresult = {result}")
    return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)

Running f():
	x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
	y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
	result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>


Array([ 3.9697716,  4.905651 , 10.853258 ], dtype=float32)

When we call the compiled function again on inputs with the same shapes and dtypes, no re-compilation is required, and nothing is printed because the result is computed in compiled XLA rather than in Python:

In [23]:
x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)

Array([1.1132462, 8.516753 , 3.0157707], dtype=float32)
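The compilation cache is keyed on the shapes and dtypes of the arguments, so a new shape triggers a fresh trace. The sketch below makes this visible with a Python side effect that only runs during tracing; `trace_count` and `g` are hypothetical names used for illustration.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only while JAX traces g

@jax.jit
def g(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on every call
    return x * 2

g(jnp.ones(3))   # first call with float32[3]: traced and compiled
g(jnp.zeros(3))  # same shape and dtype: cached, no retrace
g(jnp.ones(4))   # new shape float32[4]: traced again
print(trace_count)  # 2
```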

If there are variables that you would not like to be traced, they can be marked as static for the purposes of JIT compilation:

In [24]:
from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
    return -x if neg else x

f(1, True)

Array(-1, dtype=int32, weak_type=True)

In [25]:
f(1, False)

Array(1, dtype=int32, weak_type=True)
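Static arguments are passed to the function as ordinary Python values, so they must be hashable, and every distinct value causes a separate trace and compilation. Without marking `neg` as static, `-x if neg else x` would fail, because a traced value cannot be converted to a Python `bool`. The sketch below uses `static_argnames` (the keyword-based counterpart of `static_argnums`) and counts traces; `maybe_negate` and `compiles` are hypothetical names.

```python
from functools import partial

import jax
import jax.numpy as jnp

compiles = 0  # hypothetical counter, incremented only while tracing

@partial(jax.jit, static_argnames=("neg",))
def maybe_negate(x, neg):
    global compiles
    compiles += 1              # runs only when JAX (re)traces
    return -x if neg else x    # `neg` is a plain Python bool here

maybe_negate(jnp.float32(1.0), neg=True)   # trace for neg=True
maybe_negate(jnp.float32(2.0), neg=True)   # cache hit: same shape, dtype and static value
maybe_negate(jnp.float32(1.0), neg=False)  # new static value: retrace
print(compiles)  # 2
```

Marking an argument static is therefore best reserved for values that take only a few distinct settings, such as flags or modes.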

# 5. Static vs Traced Operations

- Just as values can be either static or traced, operations can be static or traced.

- Static operations are evaluated at compile-time in Python; traced operations are compiled & evaluated at run-time in XLA.

- Use `numpy` for operations that you want to be static; use `jax.numpy` for operations that you want to be traced.
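To see this split concretely, `jax.make_jaxpr` prints the primitives JAX records while tracing. In the sketch below (the names `scale_static` and `scale_traced` are hypothetical), `np.sqrt` is evaluated in Python on the static shape, so only its result appears as a constant, while `jnp.sqrt` is recorded as a `sqrt` operation.

```python
import numpy as np
import jax
import jax.numpy as jnp

def scale_static(x):
    # np.sqrt runs in Python at trace time on the static shape; 2.0 is baked in
    return x / np.sqrt(x.shape[0])

def scale_traced(x):
    # jnp.sqrt is staged out: it shows up as a `sqrt` primitive in the jaxpr
    return x / jnp.sqrt(x.shape[0])

x = jnp.ones(4)
jaxpr_static = jax.make_jaxpr(scale_static)(x)
jaxpr_traced = jax.make_jaxpr(scale_traced)(x)
print(jaxpr_static)
print(jaxpr_traced)
```

Both functions return the same values; they differ only in which work happens at trace time and which is compiled into XLA.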

In [26]:
@jit
def f(x):
  return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
# f(x)  <- raises TypeError: reshape needs a static shape, but jnp.array(...).prod() is traced

In [27]:
@jit
def f(x):
  print(f"x = {x}")
  print(f"x.shape = {x.shape}")
  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
  # this line is commented out because it raises the error above:
  # return x.reshape(jnp.array(x.shape).prod())

f(x)

x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=1/0)>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>


Notice that although `x` is traced, `x.shape` is a static value. However, when we use `jnp.array` and `jnp.prod` on this static value, it becomes a traced value, at which point it cannot be used in a function like `reshape()` that requires a static input (recall: array shapes must be static).

A useful pattern is to use `numpy` for operations that should be static (i.e. done at compile-time), and use `jax.numpy` for operations that should be traced (i.e. compiled and executed at run-time). For this function, it might look like this:

In [28]:
@jit
def f(x):
  return x.reshape((np.prod(x.shape),))

f(x)

Array([1., 1., 1., 1., 1., 1.], dtype=float32)
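The same pattern extends to more general shape arithmetic: compute target shapes from `x.shape` with plain Python or `numpy`, and use `jax.numpy` only on the data. A minimal sketch; `flatten_trailing` is a hypothetical helper that merges all trailing axes and then sums each row.

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def flatten_trailing(x):
    # x.shape is a static tuple; np.prod on it runs in Python at trace time
    n = int(np.prod(x.shape[1:]))
    flat = x.reshape((x.shape[0], n))   # static target shape: fine under jit
    return jnp.sum(flat, axis=1)        # traced computation on the array data

x = jnp.ones((2, 3, 4))
print(flatten_trailing(x))  # each row sums 3 * 4 = 12 ones -> values [12, 12]
```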