# 04 - Pseudo-random numbers in JAX

Sequences generated by a pseudo-random number generator (PRNG) are not truly random: they are fully determined by an initial value, typically called the `seed`, and each sampling step is a deterministic function of some state carried over from one sample to the next.
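To make the idea of a seed and a carried-over state concrete, here is a minimal sketch of a toy linear congruential generator. This is an illustration only: neither NumPy nor JAX uses this algorithm. The class name `ToyLCG` is hypothetical, and its constants are the classic Numerical Recipes LCG parameters, chosen just for the example.

```python
# A toy linear congruential generator (LCG), for illustration only.
# Neither NumPy nor JAX uses this algorithm.
class ToyLCG:
    def __init__(self, seed):
        self.state = seed % 2**32  # the seed is simply the initial state

    def next_uniform(self):
        # Each step is a deterministic function of the previous state...
        self.state = (1664525 * self.state + 1013904223) % 2**32
        # ...and the output is derived from the new state, scaled to [0, 1).
        return self.state / 2**32


a = ToyLCG(seed=36)
b = ToyLCG(seed=36)
seq_a = [a.next_uniform() for _ in range(5)]
seq_b = [b.next_uniform() for _ in range(5)]
print(seq_a == seq_b)  # True: same seed, same "random" sequence
```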

Let's look at how pseudo-random number generation works in JAX and compare it with NumPy. JAX tries to stay compatible with NumPy in most cases, but random number generation is a notable exception.

## Random numbers in Numpy

Let's start with NumPy, since it sets up the context for why JAX's PRNG is designed differently.

NumPy supports PRNGs natively through the `numpy.random` module.

In [1]:
import numpy as np

In NumPy, the PRNG draws from a global state, which can be set to a deterministic initial condition with `np.random.seed(SEED)`.

In [2]:
np.random.seed(36)

In fact, the state can be inspected.

In [3]:
def print_truncated_random_state():
  """To avoid spamming the outputs, print only part of the state."""
  full_random_state = np.random.get_state()
  print(len(str(full_random_state)))
  print(str(full_random_state)[:500], '...')

print_truncated_random_state()

8403
('MT19937', array([        36,  823087669,  696044523, 3737794234, 2820642817,
       2833617972, 1415470932, 1100159376, 1991159613, 1767259829,
       1568739598, 3245720054,   32809653, 2139754102, 3430310913,
       2588444377, 3267716983, 1080747221, 1349368758, 3401645638,
       1444560461, 3351592977, 3097094448, 2866987729,  494848087,
       2862389612, 3085330048, 2152727397, 4248073919, 2288088137,
       3350987957, 1420640749, 2115391036, 2388034354, 1252422162,
       2839293090,  ...


Each call to a random function such as `np.random.uniform` updates the global state. Eww!
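Since the state is global, the only way to replay a draw is to put the global state back. A minimal sketch using `np.random.get_state` and `np.random.set_state` (both part of `numpy.random`); the variable names are illustrative:

```python
import numpy as np

np.random.seed(23)
saved_state = np.random.get_state()  # snapshot of the global state

first = np.random.uniform()   # advances the global state
second = np.random.uniform()  # so this draw comes from a different state

np.random.set_state(saved_state)  # rewind the global state
replayed = np.random.uniform()

print(first == replayed)  # True: same state, same draw
print(first == second)    # False: the state moved on between calls
```

Any code anywhere in the program that calls `np.random` moves this shared state forward, which is exactly what makes the global design fragile.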

In [4]:
np.random.seed(23)

print_truncated_random_state()

_ = np.random.uniform()

print_truncated_random_state()

8403
('MT19937', array([        23, 3031259156, 3173090992, 1445815869, 1732473264,
       2638910426,  577956926, 3044434557,  937677603, 2194999256,
       3459629836, 3389127670,   38896821, 2267101302, 2457688530,
        355778335, 3857710667, 2059131513, 2370791786,   26557979,
       2497377211, 2631870738,  812958566, 4201426261, 4210911494,
       2656115730, 1651810922, 3621541714, 2749352977,   67546780,
       3227017130,  397291212, 3864502940, 1291010012, 2217768787,
       2251320344,  ...
8401
('MT19937', array([3522410815, 1382385284, 3929362074,   89751983, 1112872297,
       3468941588, 1499619521, 1355273773,   75167798, 1269066014,
        527286388, 2177645083, 4278068079, 2034671876, 1876867660,
       3028544392, 1394522484,  217078226, 3444565392, 3482965672,
       3497498513,  518145367, 2574711433, 2728720787,  214702893,
       3157849866,   55286010, 1769700016, 1797375436, 3031781153,
       1878523871, 1975386855, 3598775579, 4134989655, 1925354563,
    

The NumPy API can sample either a scalar or a vector of random numbers in a single function call.

In [5]:
np.random.seed(23)
print(np.random.uniform(size=3))

[0.51729788 0.9469626  0.76545976]


Most importantly, NumPy provides a sequential equivalence guarantee: sampling N numbers one at a time and sampling a vector of N numbers produce the same pseudo-random sequence:

In [6]:
np.random.seed(14)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))

np.random.seed(14)
print("all at once: ", np.random.uniform(size=3))

individually: [0.51394334 0.77316505 0.87042769]
all at once:  [0.51394334 0.77316505 0.87042769]


## Random numbers in JAX

Random number generation in JAX differs significantly from NumPy, in keeping with JAX's functional style. Moreover, NumPy's global-state design makes it harder to get random number generation that is:

- reproducible,
- parallelizable,
- vectorizable.

Let's first look at some of the problems caused by a global random state in NumPy.

The code below combines two scalars sampled from a uniform distribution, weighting the second one by 2.

In [7]:
import numpy as np

np.random.seed(0)

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo(): return bar() + 2 * baz()

print(foo())

1.9791922366721637


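This function is only reproducible if we assume a specific order of execution for `bar` and `baz`: both draw from the same global state, and because the result of `baz` is doubled, swapping the order in which they run changes the output.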

This is not a problem in plain NumPy and Python, which always evaluate `bar()` before `baz()` here, but it becomes one in JAX, where we want to JIT-compile and parallelize functions that don't depend on each other.
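To see the order dependence concretely, here is a small sketch that calls the same `bar` and `baz` with the same seed, changing only the order in which they are evaluated:

```python
import numpy as np

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

# The order foo() uses: Python evaluates bar() before baz().
np.random.seed(0)
b = bar()
z = baz()
bar_first = b + 2 * z

# Same functions, same seed, opposite evaluation order.
np.random.seed(0)
z = baz()
b = bar()
baz_first = b + 2 * z

print(bar_first)  # same value as foo() in the cell above
print(baz_first)  # differs from bar_first
```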

To avoid this issue, JAX does not use a global state. Instead, random functions explicitly consume a state, referred to as a `key`. NumPy's global `seed` and `state` are thus replaced by a `key` passed to each function call.

In [8]:
from jax import random

key = random.PRNGKey(36)

print(key)

[ 0 36]


A key is simply an array of shape `(2,)` with dtype `uint32`.
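A quick check of what the key looks like, assuming the default threefry-based implementation behind `random.PRNGKey`. The printed values match the cell above; if JAX is configured with a different PRNG implementation, keys could have a different shape.

```python
from jax import random

key = random.PRNGKey(36)

print(key.shape)     # (2,)
print(key.dtype)     # uint32
print(key.tolist())  # [0, 36]: the seed ends up in the key's second word
```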

Random functions consume the key, but do not modify it. Feeding the same key to a random function will always result in the same sample being generated.

In [9]:
print(random.normal(key))
print(random.normal(key))
print(random.normal(key))
print(random.normal(key))

-0.5886091
-0.5886091
-0.5886091
-0.5886091


### Never reuse keys 

⛔️ Feeding the same key to different random functions can result in correlated outputs, which is generally undesirable.

In order to generate different and independent samples, the key must be `split()` whenever you want to call a random function:

In [10]:
print("old key", key)
new_key, subkey = random.split(key)
del key  # The old key is discarded -- we must never use it again.
normal_sample = random.normal(subkey)
print(r"    \---SPLIT --> new key   ", new_key)
print(r"             \--> new subkey", subkey, "--> normal", normal_sample)
del subkey  # The subkey is also discarded after use.

old key [ 0 36]
    \---SPLIT --> new key    [ 929388900 1617983450]
             \--> new subkey [1033240276 1828762053] --> normal 0.82549185
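Revisiting the `bar`/`baz`/`foo` example from the NumPy section with explicit keys gives a version whose result no longer depends on evaluation order. This is a minimal sketch that assumes giving each callee its own subkey is an acceptable design; the function names are reused from that example only for comparison.

```python
import jax
import jax.numpy as jnp
from jax import random

def bar(key): return random.uniform(key)
def baz(key): return random.uniform(key)

def foo(key):
    # Each callee gets its own subkey, so neither depends on a shared
    # global state or on the order in which the calls are evaluated.
    bar_key, baz_key = random.split(key)
    return bar(bar_key) + 2 * baz(baz_key)

key = random.PRNGKey(0)

# Evaluating baz before bar gives the same result, because the inputs
# (the subkeys) are the same.
bar_key, baz_key = random.split(key)
z = baz(baz_key)
b = bar(bar_key)

print(foo(key) == b + 2 * z)                      # True
print(jnp.allclose(foo(key), jax.jit(foo)(key)))  # True: safe to JIT
```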


**The crucial point is that you never use the same PRNGKey twice.**
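The cost of reusing a key can be measured. In the sketch below, one key is fed to both `random.uniform` and `random.normal`. With JAX's current implementation, `random.normal` appears to be computed from the same random bits as `random.uniform` for a given key and shape, so the two outputs come out strongly correlated; this is an implementation detail, not a documented guarantee. Splitting first removes the correlation.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
n = 10_000

# Reusing one key for two different random functions.
u_reused = random.uniform(key, (n,))
z_reused = random.normal(key, (n,))

# Splitting first, so each function gets its own subkey.
key_u, key_z = random.split(key)
u_split = random.uniform(key_u, (n,))
z_split = random.normal(key_z, (n,))

corr_reused = jnp.corrcoef(u_reused, z_reused)[0, 1]
corr_split = jnp.corrcoef(u_split, z_split)[0, 1]
print(corr_reused)  # close to 1
print(corr_split)   # close to 0
```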

It doesn't really matter which output of the split is called the key and which is called the subkey; they are all pseudorandom numbers of equal status. By convention, subkeys are consumed immediately, while keys are kept for further splitting.

Following this convention, assigning the result of `split` back to the name `key` discards the old key automatically:

In [11]:
key = random.PRNGKey(36)
key

DeviceArray([ 0, 36], dtype=uint32)

In [12]:
key, subkey = random.split(key)
key

array([ 929388900, 1617983450], dtype=uint32)
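The same convention carries over to loops: split, consume the subkey, keep the new key. The helper `sample_normals` below is hypothetical, a minimal sketch of threading a key through sequential draws; it also returns the final key so the caller can continue without reusing anything.

```python
import jax.numpy as jnp
from jax import random

def sample_normals(key, n):
    """Draw n normal samples one at a time, splitting the key before each draw."""
    samples = []
    for _ in range(n):
        key, subkey = random.split(key)  # rebinding `key` drops the old one
        samples.append(random.normal(subkey))
    # Return the latest key as well, so the caller can keep generating.
    return jnp.stack(samples), key

samples, key = sample_normals(random.PRNGKey(36), 5)
print(samples.shape)  # (5,)
```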

`split` can create as many keys as needed, not just 2.

In [13]:
key, *twenty_five_subkeys = random.split(key, num=26)

In [14]:
len(twenty_five_subkeys)

25
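Splitting into many subkeys pairs naturally with vectorization. A sketch, assuming `jax.vmap` over a stacked batch of raw keys, that compares one-at-a-time sampling with a batched call:

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(36)
key, *subkeys = random.split(key, num=4)  # one fresh key + three subkeys

# Loop: one scalar sample per subkey.
looped = jnp.stack([random.normal(k) for k in subkeys])

# Vectorized: jax.vmap maps random.normal over a stacked batch of keys.
batched_keys = jnp.stack(subkeys)  # shape (3, 2)
vmapped = jax.vmap(random.normal)(batched_keys)

print(batched_keys.shape)             # (3, 2)
print(jnp.allclose(looped, vmapped))  # True
```

Because each sample depends only on its own subkey, the loop and the vectorized version agree, which is what lets JAX batch random sampling without a shared sequential state.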

Another difference between NumPy's and JAX's random modules concerns the sequential equivalence guarantee. JAX does not provide one, because doing so would interfere with vectorization on SIMD hardware. Sampling three numbers one subkey at a time and sampling a vector of three numbers from the original key give different results:

In [15]:
import jax.numpy as jnp

key = random.PRNGKey(71)
subkeys = random.split(key, 3)
sequence = jnp.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)

key = random.PRNGKey(71)
print("all at once: ", random.normal(key, shape=(3,)))

individually: [ 1.1654526  -0.22058551  1.3822567 ]
all at once:  [-0.26694784 -0.9594315  -1.9316671 ]
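Without sequential equivalence, the first entries of a longer draw need not match a shorter draw from the same key. If you want sample `i` to stay the same no matter how many samples are drawn, one option (a pattern chosen for this sketch, not the only one) is to derive one key per index with `random.fold_in`. The helper `normals_by_index` is hypothetical.

```python
import jax.numpy as jnp
from jax import random

def normals_by_index(key, n):
    # The i-th key depends only on the base key and the index i,
    # so sample i does not depend on how many samples are drawn.
    return jnp.stack([random.normal(random.fold_in(key, i)) for i in range(n)])

key = random.PRNGKey(71)
three = normals_by_index(key, 3)
five = normals_by_index(key, 5)
print(jnp.array_equal(three, five[:3]))  # True
```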


This is all for today. Tomorrow, we'll look into working with nested data structures in JAX.