In [2]:
## Standard libraries
import os
import math
import numpy as np
import time

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgba
import seaborn as sns
sns.set()

## Progress bar
from tqdm.auto import tqdm

  set_matplotlib_formats('svg', 'pdf') # For export
  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import jax
import jax.numpy as jnp
print(f"Using jax {jax.__version__}")

Using jax 0.4.27


In [4]:
# Arrays created through jax.numpy are JAX arrays.
a = jnp.zeros((2, 5), dtype=jnp.float32)
print(a)

b = jnp.arange(6)
print(b)
print(type(b))  # a JAX array class, not numpy.ndarray

# jax.device_get hands back a NumPy copy of the JAX array (see the printed class).
b_cpu = jax.device_get(b)
print(type(b_cpu))

print(jax.devices())  # devices visible to JAX in this process

[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
[0 1 2 3 4 5]
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'numpy.ndarray'>
[CpuDevice(id=0)]


# Immutable tensors

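JAX arrays are immutable: you cannot modify them in place (e.g. `b[0] = 1` raises an error), because in-place mutation would break JAX's pure-function programming model. Instead, the `.at[...]` syntax returns an updated copy and leaves the original array unchanged.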

In [5]:
# JAX arrays are immutable: `.at[idx].set(value)` returns a new, updated array
# and leaves `b` untouched.
b_new = b.at[0].set(1)
for label, array in (('Original array:', b), ('Changed array:', b_new)):
    print(label, array)

Original array: [0 1 2 3 4 5]
Changed array: [1 1 2 3 4 5]


# Pseudorandom numbers

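Because JAX requires all functions to be pure, its PRNG has no hidden global state that advances between calls (unlike `np.random.seed` + `np.random.normal`). Instead, you pass an explicit key to every sampling function, and the same key always produces the same numbers.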

In [6]:
# Root PRNG key for the notebook. JAX sampling functions take an explicit key
# argument; later cells derive fresh keys from this one with jax.random.split.
rng = jax.random.PRNGKey(42)

In [30]:
# A non-desirable way of generating pseudo-random numbers...
# Sampling twice with the *same* key is deterministic, so both draws are identical.
jax_random_number_1, jax_random_number_2 = (jax.random.normal(rng) for _ in range(2))
for i, value in enumerate((jax_random_number_1, jax_random_number_2), start=1):
    print(f'JAX - Random number {i}:', value)

# Typical random numbers in NumPy: one global, stateful generator whose state
# advances on every call, so consecutive draws differ.
np.random.seed(42)
np_random_number_1, np_random_number_2 = (np.random.normal() for _ in range(2))
for i, value in enumerate((np_random_number_1, np_random_number_2), start=1):
    print(f'NumPy - Random number {i}:', value)

JAX - Random number 1: -0.18471177
JAX - Random number 2: -0.18471177
NumPy - Random number 1: 0.4967141530112327
NumPy - Random number 2: -0.13826430117118466


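To get new random numbers in JAX, you split the PRNG key: `jax.random.split(rng, num)` takes a key and returns `num` new, independent keys. Split every time you need fresh randomness. Overwrite `rng` with one of the new keys so the next split starts from a new state, and use the others (the subkeys) for sampling.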

In [40]:
# Each run of this cell overwrites `rng` with a freshly split key, so re-running
# the cell produces different numbers. subkey1/subkey2 are consumed for sampling;
# the new `rng` is kept only to be split again later.
rng, subkey1, subkey2 = jax.random.split(rng, num=3)  # We create 3 new keys
print(rng, subkey1, subkey2)
jax_random_number_1, jax_random_number_2 = (jax.random.normal(key) for key in (subkey1, subkey2))
for i, value in enumerate((jax_random_number_1, jax_random_number_2), start=1):
    print(f'JAX new - Random number {i}:', value)

[2913658946 2980543260] [ 906713395 3891066739] [ 724597107 1867688990]
JAX new - Random number 1: 0.72477084
JAX new - Random number 2: -0.98226595


# Function transformation

In [7]:
def simple_graph(x):
    """Toy computation: mean((x + 2) ** 2 + 3).

    Used in the following cells to inspect the traced jaxpr, take gradients
    and benchmark jax.jit.
    """
    squared = (x + 2) ** 2
    return (squared + 3).mean()

inp = jnp.arange(0, 3, dtype=jnp.float32)
print('Input', inp)
print('Output', simple_graph(inp))

Input [0. 1. 2.]
Output 12.666667


In [8]:
# Trace simple_graph with `inp` as an abstract f32[3] input and display the
# resulting jaxpr (JAX's intermediate representation of the computation).
jax.make_jaxpr(simple_graph)(inp)

{ lambda ; a:f32[3]. let
    b:f32[3] = add a 2.0
    c:f32[3] = integer_pow[y=2] b
    d:f32[3] = add c 3.0
    e:f32[] = reduce_sum[axes=(0,)] d
    f:f32[] = div e 3.0
  in (f,) }

In [65]:
global_list = []

# Invalid function with side-effect: the append is plain Python, so JAX only
# executes it while tracing (e.g. once inside make_jaxpr) and never records it
# in the resulting jaxpr.
def norm(x):
    """Return the Euclidean (L2) norm of `x`, i.e. sqrt(sum(x ** 2)).

    Side effect (deliberately violating JAX's purity requirement): appends the
    argument it receives to the module-level `global_list`. Under a JAX
    transformation that argument is a tracer rather than a concrete array.
    """
    global_list.append(x)
    x = x ** 2
    n = x.sum()
    n = jnp.sqrt(n)
    return n  # previously `return ndd`, an undefined name (NameError)

# The jaxpr below contains only the numeric ops; global_list.append is plain
# Python that runs during tracing and is not recorded.
jax.make_jaxpr(norm)(inp)

{ lambda ; a:f32[3]. let
    b:f32[3] = integer_pow[y=2] a
    c:f32[] = reduce_sum[axes=(0,)] b
    d:f32[] = sqrt c
  in (d,) }

# Automatic differentiation

In [9]:
# jax.grad builds a new function returning d(simple_graph)/dx; argnums=0 (the
# default) differentiates with respect to the first argument.
grad_function = jax.grad(simple_graph, argnums=0)
gradients = grad_function(inp)
print('Gradient', gradients)

Gradient [1.3333334 2.        2.6666667]


In [10]:
# Jaxpr of the gradient function: the forward ops of simple_graph followed by
# the backward ops (mul 2.0, div 1.0 3.0, broadcast_in_dim, mul) giving d(mean)/dx.
jax.make_jaxpr(grad_function)(inp)

{ lambda ; a:f32[3]. let
    b:f32[3] = add a 2.0
    c:f32[3] = integer_pow[y=2] b
    d:f32[3] = integer_pow[y=1] b
    e:f32[3] = mul 2.0 d
    f:f32[3] = add c 3.0
    g:f32[] = reduce_sum[axes=(0,)] f
    _:f32[] = div g 3.0
    h:f32[] = div 1.0 3.0
    i:f32[3] = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] h
    j:f32[3] = mul i e
  in (j,) }

In [11]:
# NOTE(review): this cell is an exact duplicate of the previous one and
# produces the same jaxpr; consider deleting it together with its output.
jax.make_jaxpr(grad_function)(inp)

{ lambda ; a:f32[3]. let
    b:f32[3] = add a 2.0
    c:f32[3] = integer_pow[y=2] b
    d:f32[3] = integer_pow[y=1] b
    e:f32[3] = mul 2.0 d
    f:f32[3] = add c 3.0
    g:f32[] = reduce_sum[axes=(0,)] f
    _:f32[] = div g 3.0
    h:f32[] = div 1.0 3.0
    i:f32[3] = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] h
    j:f32[3] = mul i e
  in (j,) }

# Just-in-time compilation

In [12]:
# Wrap simple_graph with jax.jit; compilation is triggered by the first call
# (the warm-up in the next cell), not by this line.
jitted_function = jax.jit(simple_graph)

In [13]:
# Create a new random subkey for generating new random values
rng, normal_rng = jax.random.split(rng)
large_input = jax.random.normal(normal_rng, (1000,))
# Run the jitted function once to start compilation
_ = jitted_function(large_input)

In [14]:
%%timeit
# Baseline: the un-jitted function. block_until_ready() waits for JAX's
# asynchronous result so the timing includes the actual computation.
simple_graph(large_input).block_until_ready()

153 µs ± 9.32 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [15]:
%%timeit
# Jitted version: compilation already happened in the warm-up call, so this
# measures execution of the compiled function only.
jitted_function(large_input).block_until_ready()

10.6 µs ± 1.18 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
