
## Random numbers in NumPy

In NumPy, pseudo-random number generation is based on a global `state`, which can be set to a deterministic initial value with `np.random.seed`:

In [1]:
import numpy as np
np.random.seed(0)  # general form: np.random.seed(SEED)

In [2]:
def print_truncated_random_state():
  """Print the first few hundred characters of NumPy's global random state."""
  full_random_state = np.random.get_state()
  print(str(full_random_state)[:460], "...")

In [3]:
print_truncated_random_state()

('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ...


The `state` is updated by each call to a random function:

In [4]:
np.random.seed(0)

print_truncated_random_state()

_ = np.random.uniform()

print_truncated_random_state()

('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ...
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
       3904844661,  676747479, 2085143622, 1056793272, 3812477442,
       2168787041,  275552121, 2696932952, 3432054210, 1657102335,
       3518946594,  962584079, 1051271004, 3806145045, 1414436097,
       2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
        696824676, 2399811678, 3992505346,  569184356, 2626558620,
        136797809, 4273176064,  296167901, 343 ...
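Since every draw is a function of this global state, snapshotting it with `np.random.get_state()` and restoring it with `np.random.set_state()` replays exactly the same numbers. A minimal sketch (the variable names are illustrative):

```python
import numpy as np

np.random.seed(0)
saved_state = np.random.get_state()      # snapshot of the global MT19937 state

first_draw = np.random.uniform(size=3)   # this call advances the global state

np.random.set_state(saved_state)         # rewind the global state to the snapshot
replayed_draw = np.random.uniform(size=3)

print(first_draw)
print(replayed_draw)
print(np.array_equal(first_draw, replayed_draw))  # True
```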


In [6]:
np.random.seed(0)
print(np.random.uniform(size=3))

[0.5488135  0.71518937 0.60276338]


NumPy provides a sequential equivalence guarantee, meaning that sampling N numbers in a row individually or sampling a vector of N numbers results in the same pseudo-random sequence:

In [8]:
np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))

np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))

individually: [0.5488135  0.71518937 0.60276338]
all at once:  [0.5488135  0.71518937 0.60276338]


## Random numbers in JAX

Why is it different? NumPy's PRNG design makes it hard to simultaneously guarantee a number of desirable properties for JAX, specifically that code must be:
1. reproducible
2. parallelizable
3. vectorizable

In [9]:
import numpy as np

np.random.seed(0)

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo(): return bar() + 2 * baz()

print(foo())

1.9791922366721637


The result of `foo` depends on `bar` drawing from the global state before `baz` does. Making this code reproducible in JAX would therefore require enforcing this specific order of execution, which would violate requirement #2: JAX should be able to parallelize `bar` and `baz` when jitting, since these functions don't actually depend on each other.
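The order dependence is easy to see in NumPy. In the sketch below, a hypothetical `foo_swapped` evaluates `baz` before `bar`; from the same seed it returns a different value, because each function now receives a different draw from the global state:

```python
import numpy as np

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo(): return bar() + 2 * baz()   # Python evaluates bar() first, then baz()

def foo_swapped():
    # hypothetical variant: same formula, but baz() is evaluated before bar()
    baz_value = baz()
    bar_value = bar()
    return bar_value + 2 * baz_value

np.random.seed(0)
r1 = foo()
np.random.seed(0)
r2 = foo_swapped()
print(r1, r2)   # different values from the same seed
```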

So, JAX does not use a global state. Instead, random functions explicitly consume the state, which is referred to as a `key`.
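As a sketch of the explicit-key style, here are hypothetical JAX versions of `bar`, `baz`, and `foo` that take the key they consume. The sketch uses `random.split` (introduced later in this notebook) to derive one independent subkey per function, so the result no longer depends on evaluation order:

```python
from jax import random

# hypothetical JAX counterparts of bar/baz: each receives the key it consumes
def bar(key): return random.uniform(key)
def baz(key): return random.uniform(key)

def foo(key):
    # split once so bar and baz each get their own independent subkey
    bar_key, baz_key = random.split(key)
    return bar(bar_key) + 2 * baz(baz_key)

def foo_reordered(key):
    # same computation, but baz is evaluated before bar
    bar_key, baz_key = random.split(key)
    baz_value = baz(baz_key)
    bar_value = bar(bar_key)
    return bar_value + 2 * baz_value

key = random.key(0)
print(foo(key), foo_reordered(key))  # identical values
```

Because each function's randomness is fixed by the key it is given, `jit` is free to schedule `bar` and `baz` in any order or in parallel.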

In [10]:
from jax import random
key = random.key(42)

print(key) # a single key is an array of scalar shape () and key element type

Array((), dtype=key<fry>) overlaying:
[ 0 42]
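Under the hood, a typed key wraps raw `uint32` words: `random.key_data` exposes them, and the older `random.PRNGKey` API returns such a raw array directly. A short check, assuming the default threefry implementation (the `fry` in the dtype above):

```python
import numpy as np
from jax import random

key = random.key(42)
print(key.shape)                          # (): a single key is a scalar

# key_data exposes the raw uint32 words that back the typed key
raw = np.asarray(random.key_data(key))
print(raw)                                # [ 0 42]

# the legacy random.PRNGKey API returns those raw words directly
legacy = np.asarray(random.PRNGKey(42))
print(np.array_equal(raw, legacy))
```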


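'Random key' is essentially just another word for 'random seed'. However, instead of being set once globally as in NumPy, a key must be passed explicitly to every random function call in JAX. Random functions consume the key but do not modify it, so feeding the same key to a random function always produces the same sample: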

In [11]:
print(random.normal(key))
print(random.normal(key))

-0.18471177
-0.18471177
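Random functions never modify the key they receive (JAX arrays are immutable), which is why the two calls above agree. A quick check that the key's underlying words are unchanged after sampling:

```python
import numpy as np
from jax import random

key = random.key(42)
words_before = np.asarray(random.key_data(key))

sample_1 = random.normal(key)
sample_2 = random.normal(key)   # same key, so same sample

words_after = np.asarray(random.key_data(key))
print(np.array_equal(words_before, words_after))  # True: the key was not mutated
print(float(sample_1) == float(sample_2))         # True
```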


**Rule of thumb**
 - Never reuse keys unless you want identical values
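One way to follow this rule mechanically is to carry the key forward and only ever hand out fresh subkeys. The generator below, `key_stream`, is a hypothetical convenience helper written for illustration, not part of `jax.random`:

```python
from jax import random

def key_stream(key):
    """Hypothetical helper: yields a fresh subkey each time next() is called."""
    while True:
        key, subkey = random.split(key)  # carry the new key forward, hand out the subkey
        yield subkey

keys = key_stream(random.key(42))
a = random.normal(next(keys))
b = random.normal(next(keys))   # a different subkey, so a different value
print(a, b)

# recreating the stream from the same seed replays the same subkeys
replay = key_stream(random.key(42))
print(float(random.normal(next(replay))) == float(a))  # True
```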

In [12]:
print('old key', key)
new_key, subkey = random.split(key)
del key # old key must be discarded
normal_sample = random.normal(subkey)

print(r"    \--SPLIT --> new key   ", new_key)
print(r"            \--> new subkey", subkey, '--> normal', normal_sample)
del subkey

key = new_key

old key Array((), dtype=key<fry>) overlaying:
[ 0 42]
    \--SPLIT --> new key    Array((), dtype=key<fry>) overlaying:
[2465931498 3679230171]
            \--> new subkey Array((), dtype=key<fry>) overlaying:
[255383827 267815257] --> normal 1.3694694
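The old key must be discarded because `random.split` is a pure, deterministic function of its input: splitting the same key again returns exactly the same children, and therefore the same samples. A sketch that compares raw key words via `random.key_data` (the `words` helper is hypothetical):

```python
import numpy as np
from jax import random

key = random.key(42)

# split is a pure function of its input key
new_key_a, subkey_a = random.split(key)
new_key_b, subkey_b = random.split(key)

def words(k):
    # hypothetical helper: raw uint32 words behind a typed key, as a NumPy array
    return np.asarray(random.key_data(k))

print(np.array_equal(words(subkey_a), words(subkey_b)))    # True: same children
print(np.array_equal(words(new_key_a), words(subkey_a)))   # False: the two children differ
print(float(random.normal(subkey_a)) == float(random.normal(subkey_b)))  # True
```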


In [None]:
# concise, idiomatic form of the above: overwrite `key` and keep only the subkey
key, subkey = random.split(key)

In [13]:
key, *forty_two_subkeys = random.split(key, num=43)

In [15]:
len(forty_two_subkeys)

42
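A batch of subkeys like this pairs naturally with `jax.vmap`: mapping a sampler over the keys draws one value per subkey in a single vectorized call, and it matches a Python loop over the same keys. A minimal sketch with 5 subkeys:

```python
import jax
import numpy as np
from jax import random

key = random.key(42)
subkeys = random.split(key, num=5)
print(subkeys.shape)                 # (5,): a batch of 5 keys

# one sample per subkey, drawn in a single vectorized call
batched = jax.vmap(random.normal)(subkeys)

# the same samples drawn one key at a time
looped = np.stack([random.normal(k) for k in subkeys])

print(batched.shape)                 # (5,)
print(np.allclose(batched, looped))
```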

## Note
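As in NumPy, JAX's random module also allows sampling of vectors of numbers. However, JAX does not provide a sequential equivalence guarantee, because doing so would interfere with vectorization on SIMD hardware. In the cell below, three values drawn from three split subkeys differ from a vector of three values drawn from the original key: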

In [16]:
key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print(sequence)

key = random.key(42)
print(random.normal(key, shape=(3, )))

[-0.04838832  0.10796154 -1.2226542 ]
[ 0.18693547 -1.2806505  -1.5593132 ]
