## Imports

In [1]:
from flax import nnx
import jax.numpy as jnp
import pytest

## PRNG handling

For more information see:
- [Flax NNX - Randomness](https://flax.readthedocs.io/en/latest/guides/randomness.html#) for `nnx.Rngs`
- [JAX - random numbers - explicit random state](https://docs.jax.dev/en/latest/random-numbers.html#explicit-random-state) for `keys`

### Initialization

The RngStream `params` is used for initialization of `nnx.Linear`, `nnx.Conv`, `nnx.ConvTranspose` and `nnx.Embed`.

In [2]:
rngs = nnx.Rngs(0)
linear = nnx.Linear(1, 2, rngs=rngs)
nnx.display(linear)

#### `rng_collection`

`nnx.Dropout` has an `rng_collection` attribute that names the RngStream it looks up (for Dropout, the default is `dropout`).

`nnx.Dropout` initialization uses, in this order:
- the _dropout_ stream,
- the _default_ stream, if the _dropout_ stream was not found,
- `None`, if neither the _dropout_ nor the _default_ stream was found, or if `rngs` was not provided.

In [3]:
rngs1 = nnx.Rngs(0, params=1, dropout=2)  # default, params and dropout stream
dropout1 = nnx.Dropout(rate=0.5, rngs=rngs1)  # uses dropout stream
rngs2 = nnx.Rngs(0, params=1)  # default, params stream
dropout2 = nnx.Dropout(rate=0.5, rngs=rngs2)  # no dropout stream, fallback to default stream
rngs3 = None  # or not provided
dropout3 = nnx.Dropout(rate=0.5, rngs=rngs3)
nnx.display(dropout1, dropout2, dropout3)

Note the `rngs.tag` values in the output.

### _call_ time

The `nnx.Dropout` module also requires a random state, but at _call_ time rather than at initialization. The random state is used to create a key, which is then used to create the dropout mask.

If no random state is provided at call time, the module falls back to the `rngs` attribute that was stored on the object at initialization.

If no random state was provided at initialization either, that attribute is `nnx.data(None)` and a `ValueError` is raised:

In [4]:
x = jnp.ones(1)
msg = """`deterministic` is False, but no `rngs` argument was provided to Dropout
          as either a __call__ argument or class attribute."""
with pytest.raises(ValueError, match=msg):
    dropout3(x)

The `rngs` provided at call time must be one of the following types:
- `rnglib.Rngs`
- `rnglib.RngStream`
- `jax.Array`
- `None`

### Fallback

Flax NNX provides a `default` stream that serves as a fallback when a requested stream is not found. It can also be called directly via `rngs()`. Here's an example:

In [5]:
rngs = nnx.Rngs(0, params=1)  # default and params stream

key1 = rngs.params()  # Call params stream.
key2 = rngs.dropout()  # Fallback to the default stream.
key3 = rngs()  # Call the default stream directly.

nnx.display(rngs)

Note the **RngCount** for the _params_ and _default_ streams.