# Generate random numbers

Generating random numbers in JAX requires the programmer to stay alert: the random state is explicit and has to be managed by hand.

This is going to look painful in the beginning.

But in the end we will see that it is very helpful for controlling the flow of randomness.

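Think about how we generate random samples in NumPy, for example:

```python
import numpy as np

np.random.seed(999)
samples1 = np.random.randn(5)
samples2 = np.random.uniform(size=5)    # `size=5` gives 5 samples; `uniform(5)` would set low=5
samples3 = np.random.gamma(1., size=5)  # shape parameter 1, 5 samples
```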

The results are reproducible under a fixed random seed.

Can we do

```python
jnp.random.seed(999)
...
```

?

No. In JAX we have to manage the random keys manually, much as one passes an explicit generator object around in C/C++/Rust.

```python
# The JAX equivalent of the NumPy code

import jax

key = jax.random.PRNGKey(999)

# Split before every draw: `subkey` goes to the sampler, `key` is carried forward
key, subkey = jax.random.split(key)
samples1 = jax.random.normal(subkey, shape=(5, ))

key, subkey = jax.random.split(key)
samples2 = jax.random.uniform(subkey, shape=(5, ))

key, subkey = jax.random.split(key)
samples3 = jax.random.gamma(subkey, shape=(5, ), a=1.)
```

Explicitly splitting random keys may seem fussy and verbose, but it is **standard** practice for asynchronous, parallel, and vectorisable programmes. To see why, look at the NumPy code again:

```python
np.random.seed(999)
samples1 = np.random.randn(5)
samples2 = np.random.uniform(size=5)
samples3 = np.random.gamma(1., size=5)
```

Why are the results reproducible?

Because 1) the random seed is fixed, and 2) the order of execution is fixed.

If `samples1`, `samples2`, and `samples3` are computed in parallel, can you guarantee that the results stay the same? No, because all three draws share one hidden global state and you do not know the order in which they will run. This essentially introduces another source of randomness (at the hardware level) that is hard to control.
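
As a minimal sketch of how explicit keys remove this dependence on execution order, the snippet below draws the same three samples in two different orders and compares them. The helper `draw_all` and the names `k1`, `k2`, `k3` are illustrative only. Each key is deliberately used in both calls, purely to show that the same key always yields the same values.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(999)
# One independent key per draw
k1, k2, k3 = jax.random.split(key, num=3)

def draw_all(order):
    # Illustrative helper: run the three draws in the requested order
    samplers = {
        "normal": lambda: jax.random.normal(k1, shape=(5,)),
        "uniform": lambda: jax.random.uniform(k2, shape=(5,)),
        "gamma": lambda: jax.random.gamma(k3, a=1.0, shape=(5,)),
    }
    return {name: samplers[name]() for name in order}

forward = draw_all(["normal", "uniform", "gamma"])
backward = draw_all(["gamma", "uniform", "normal"])

# Each draw depends only on its own key, not on when it was executed
same = all(bool(jnp.array_equal(forward[n], backward[n])) for n in forward)
print(same)
```

Because every sampler receives its own key, there is no shared state for the execution order to disturb.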

```python
key = jax.random.PRNGKey(999)
print(key)

# Randomness drawn from the two new keys is statistically independent of the parent key
key, subkey = jax.random.split(key)
print(key, subkey)

key, subkey = jax.random.split(key)
print(key, subkey)

# Split into any number of keys at once
keys = jax.random.split(key, num=5)
print(keys)
```

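A key should be consumed only once, either by a sampling function or by `split`. The usual convention is therefore to pass `subkey` to the sampler and keep `key` for the next split. When you only need to advance to a fresh key, the second output can be discarded:

`key, _ = jax.random.split(key)`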
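
The short sketch below illustrates why a key should not be reused: passing the same key twice gives identical samples, while a key obtained from `split` gives different ones. The variable names `a`, `b`, `c` are illustrative.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(999)

# Same key, same call: identical samples (usually not what we want)
a = jax.random.normal(key, shape=(3,))
b = jax.random.normal(key, shape=(3,))
print(jnp.array_equal(a, b))

# A subkey obtained by splitting gives a different draw
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, shape=(3,))
print(jnp.array_equal(a, c))

# Splitting into several keys returns them stacked along the first axis
keys = jax.random.split(key, num=5)
print(keys.shape[0])
```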

# Example

Imagine that we have an algorithm and we would like to test it on a synthetic model for 100 **independent** Monte Carlo runs, then average the results.

```python
num_mcs = 100
for i in range(num_mcs):
    np.random.seed(i)
    data = generate_data()       # placeholder for your data simulator
    result = my_algorithm(data)  # placeholder for the algorithm under test
```

The implementation above is a common mistake, as the streams produced by the consecutive seeds 0, 1, ..., 99 are **not guaranteed to be independent**. How can we make the randomness in each iteration independent of the others? With the global `np.random.seed` interface used here, there is no obvious way to do so.
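
One possible NumPy-side answer, assuming NumPy 1.17 or later, is the newer `Generator` interface: `np.random.SeedSequence.spawn` derives child seeds that are designed to give independent streams. This is only a sketch, and the names `root`, `children`, and `rngs` are illustrative.

```python
import numpy as np

num_mcs = 100

# One root seed, from which independent child seeds are spawned
root = np.random.SeedSequence(999)
children = root.spawn(num_mcs)

# One generator per Monte Carlo run, each with its own stream
rngs = [np.random.default_rng(child) for child in children]

# Example: the first standard-normal draw of every run
first_draws = np.array([rng.standard_normal() for rng in rngs])
print(first_draws.shape)
```

Note that each generator must then be passed explicitly to the code that draws samples, which is the same discipline JAX enforces with keys.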

In JAX, it's convenient:

```python
num_mcs = 100
key = jax.random.PRNGKey(1)
for i in range(num_mcs):
    # A fresh, independent subkey for every Monte Carlo run
    key, subkey = jax.random.split(key)
    data = generate_data(subkey)
    result = my_algorithm(data)
```
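
Since every run only depends on its own key, the loop can also be vectorised. The following is a sketch under stated assumptions: `generate_data` and `my_algorithm` are hypothetical stand-ins (noisy observations around a true mean of 2.0, and a sample-mean estimator), not the functions from the text above.

```python
import jax
import jax.numpy as jnp

def generate_data(key, n=50):
    # Hypothetical simulator: n noisy observations around a true mean of 2.0
    return 2.0 + jax.random.normal(key, shape=(n,))

def my_algorithm(data):
    # Hypothetical algorithm: estimate the mean
    return jnp.mean(data)

num_mcs = 100
key = jax.random.PRNGKey(1)

# One independent key per Monte Carlo run, created in a single split
keys = jax.random.split(key, num=num_mcs)

# Run all Monte Carlo runs at once by mapping over the leading axis of `keys`
results = jax.vmap(lambda k: my_algorithm(generate_data(k)))(keys)
average = jnp.mean(results)
print(results.shape, average)
```

The result does not depend on how the runs are scheduled, because each one is fully determined by its key.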

# Exercise

Generate two random positive definite matrices of size 6 × 6, add them, and print the eigenvalues of the sum.

NumPy implementation:

```python
np.random.seed(999)

rand = np.random.randn(6)
psd_matrix_1 = np.outer(rand, rand) + np.eye(6)

rand = np.random.randn(6)
psd_matrix_2 = np.outer(rand, rand) + np.eye(6)

print(np.linalg.eigh(psd_matrix_1 + psd_matrix_2)[0])
```

JAX template (replace each `?`):

```python
import jax.numpy as jnp

key = ?
key, subkey = ?
rand = ?
psd_matrix_1 = jnp.outer(rand, rand) + jnp.eye(6)

key, subkey = ?
rand = ?
psd_matrix_2 = jnp.outer(rand, rand) + jnp.eye(6)

print(jnp.linalg.eigh(psd_matrix_1 + psd_matrix_2)[0])
```

## Solution

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(999)

# Draw the first vector from a subkey and carry `key` forward
key, subkey = jax.random.split(key)
rand = jax.random.normal(subkey, (6, ))
psd_matrix_1 = jnp.outer(rand, rand) + jnp.eye(6)

key, subkey = jax.random.split(key)
rand = jax.random.normal(subkey, (6, ))
psd_matrix_2 = jnp.outer(rand, rand) + jnp.eye(6)

print(jnp.linalg.eigh(psd_matrix_1 + psd_matrix_2)[0])
```
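
As a quick sanity check, the same construction can be wrapped in a function that takes the key as an argument. The helper name `random_pd_matrix` is hypothetical. Each summand is an identity plus a rank-one outer product, so its eigenvalues are at least 1 and the eigenvalues of the sum should be at least 2, up to floating-point error.

```python
import jax
import jax.numpy as jnp

def random_pd_matrix(key, n):
    # Hypothetical helper: rank-one outer product plus the identity
    rand = jax.random.normal(key, (n,))
    return jnp.outer(rand, rand) + jnp.eye(n)

key = jax.random.PRNGKey(999)
# One independent key per matrix
key1, key2 = jax.random.split(key)

total = random_pd_matrix(key1, 6) + random_pd_matrix(key2, 6)

# eigvalsh returns only the eigenvalues of a symmetric matrix
eigenvalues = jnp.linalg.eigvalsh(total)
print(eigenvalues)
```

Passing the key in as an argument keeps the function pure: the same key always produces the same matrix.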