## Imports

In [1]:
from flax.nnx import Rngs, RngStream

## Rngs-handling

Every stream held by an `nnx.Rngs`, as well as a standalone `nnx.RngStream`, carries a `tag`:

#### `nnx.Rngs`

In [2]:
# Build an `Rngs` from an integer seed, then show the whole object followed by
# the tag of its `default` stream.
default_rngs = Rngs(1)
print(
    f"Rngs: \n{default_rngs}",
    f"\n tag: {default_rngs.default.tag}",
    sep="\n",
)

Rngs: 
[38;2;79;201;177mRngs[0m[38;2;255;213;3m([0m[38;2;105;105;105m # RngState: 2 (12 B)[0m
  [38;2;156;220;254mdefault[0m[38;2;212;212;212m=[0m[38;2;79;201;177mRngStream[0m[38;2;255;213;3m([0m[38;2;105;105;105m # RngState: 2 (12 B)[0m
    [38;2;156;220;254mtag[0m[38;2;212;212;212m=[0m[38;2;207;144;120m'default'[0m,
    [38;2;156;220;254mkey[0m[38;2;212;212;212m=[0m[38;2;79;201;177mRngKey[0m[38;2;255;213;3m([0m[38;2;105;105;105m # 1 (8 B)[0m
      [38;2;156;220;254mvalue[0m[38;2;212;212;212m=[0mArray((), dtype=key<fry>) overlaying:
      [0 1],
      [38;2;156;220;254mtag[0m[38;2;212;212;212m=[0m[38;2;207;144;120m'default'[0m
    [38;2;255;213;3m)[0m,
    [38;2;156;220;254mcount[0m[38;2;212;212;212m=[0m[38;2;79;201;177mRngCount[0m[38;2;255;213;3m([0m[38;2;105;105;105m # 1 (4 B)[0m
      [38;2;156;220;254mvalue[0m[38;2;212;212;212m=[0mArray(0, dtype=uint32),
      [38;2;156;220;254mtag[0m[38;2;212;212;212m=[0m[38;2;207;144

#### `nnx.RngStream`

In [3]:
# Build a standalone stream with an explicit tag, then show the stream and
# read the tag back from it.
default_rng_stream = RngStream(1, tag="default")
print(
    f"\n RngStream: \n {default_rng_stream}",
    f"\n tag: {default_rng_stream.tag}",
    sep="\n",
)


 RngStream: 
 [38;2;79;201;177mRngStream[0m[38;2;255;213;3m([0m[38;2;105;105;105m # RngState: 2 (12 B)[0m
  [38;2;156;220;254mtag[0m[38;2;212;212;212m=[0m[38;2;207;144;120m'default'[0m,
  [38;2;156;220;254mkey[0m[38;2;212;212;212m=[0m[38;2;79;201;177mRngKey[0m[38;2;255;213;3m([0m[38;2;105;105;105m # 1 (8 B)[0m
    [38;2;156;220;254mvalue[0m[38;2;212;212;212m=[0mArray((), dtype=key<fry>) overlaying:
    [0 1],
    [38;2;156;220;254mtag[0m[38;2;212;212;212m=[0m[38;2;207;144;120m'default'[0m
  [38;2;255;213;3m)[0m,
  [38;2;156;220;254mcount[0m[38;2;212;212;212m=[0m[38;2;79;201;177mRngCount[0m[38;2;255;213;3m([0m[38;2;105;105;105m # 1 (4 B)[0m
    [38;2;156;220;254mvalue[0m[38;2;212;212;212m=[0mArray(0, dtype=uint32),
    [38;2;156;220;254mtag[0m[38;2;212;212;212m=[0m[38;2;207;144;120m'default'[0m
  [38;2;255;213;3m)[0m
[38;2;255;213;3m)[0m

 tag: default


#### user input

If users are allowed to pass their own rngs as `rnglib.Rngs` or `rnglib.RngStream` (see the [official flax nnx dropout documentation](https://flax.readthedocs.io/en/latest/_modules/flax/nnx/nn/stochastic.html#Dropout)), these tags need to be changed. An `int` key seed is accepted as input as well. The following helper functions perform this conversion:

In [4]:
from flax import nnx
from flax.nnx import rnglib


def _retag_rngs(rngs: rnglib.Rngs | rnglib.RngStream | int, tag: str) -> nnx.Rngs:
    """Wrap ``rngs`` in a new ``nnx.Rngs`` that receives it under the keyword ``tag``.

    Shared implementation of ``dropout_rngs`` and ``dropconnect_rngs``, which
    differ only in the keyword they pass to ``nnx.Rngs``.

    Args:
        rngs: An existing ``Rngs``, a single ``RngStream``, or an ``int`` seed.
        tag: Keyword name under which the value is handed to ``nnx.Rngs``.

    Returns:
        ``nnx.Rngs(**{tag: value})`` for the three supported input kinds. Any
        other input is returned unchanged.
    """
    if isinstance(rngs, rnglib.Rngs):
        # NOTE(review): the result of `get_metadata()` is used as the stream
        # value; what it returns for an `Rngs` is not visible here -- confirm it
        # is a valid seed/key for `nnx.Rngs` and whether the call changes the
        # state of the caller's `rngs`.
        return nnx.Rngs(**{tag: rngs.get_metadata()})
    if isinstance(rngs, (rnglib.RngStream, int)):
        # Streams and integer seeds are forwarded to `nnx.Rngs` as they are.
        return nnx.Rngs(**{tag: rngs})
    # Unsupported types are passed through untouched.
    return rngs


def dropout_rngs(rngs: rnglib.Rngs | rnglib.RngStream | int = 1) -> nnx.Rngs:
    """Convert ``rngs`` into an ``nnx.Rngs`` built with a ``dropout`` keyword.

    Args:
        rngs: An existing ``Rngs``, a single ``RngStream``, or an ``int`` seed.
            Defaults to the seed ``1``.

    Returns:
        ``nnx.Rngs(dropout=...)`` for the supported input kinds; any other
        input is returned unchanged.
    """
    return _retag_rngs(rngs, "dropout")


def dropconnect_rngs(rngs: rnglib.Rngs | rnglib.RngStream | int = 1) -> nnx.Rngs:
    """Convert ``rngs`` into an ``nnx.Rngs`` built with a ``dropconnect`` keyword.

    Args:
        rngs: An existing ``Rngs``, a single ``RngStream``, or an ``int`` seed.
            Defaults to the seed ``1``.

    Returns:
        ``nnx.Rngs(dropconnect=...)`` for the supported input kinds; any other
        input is returned unchanged.
    """
    return _retag_rngs(rngs, "dropconnect")

Here's an example with a `key_seed`:

In [5]:
# Feed the same integer seed through both converters and show, for each result,
# the full `Rngs` followed by the tag of the stream it exposes.
key_seed = 1
dropout_rngs_seed = dropout_rngs(key_seed)
dropconnect_rngs_seed = dropconnect_rngs(key_seed)
print(
    f"Rngs: \n {dropout_rngs_seed}",
    f"\n tag: {dropout_rngs_seed.dropout.tag} \n",
    f"Rngs: \n {dropconnect_rngs_seed}",
    f"\n tag: {dropconnect_rngs_seed.dropconnect.tag}",
    sep="\n",
)

Rngs: 
 [38;2;79;201;177mRngs[0m[38;2;255;213;3m([0m[38;2;105;105;105m # RngState: 2 (12 B)[0m
  [38;2;156;220;254mdropout[0m[38;2;212;212;212m=[0m[38;2;79;201;177mRngStream[0m[38;2;255;213;3m([0m[38;2;105;105;105m # RngState: 2 (12 B)[0m
    [38;2;156;220;254mtag[0m[38;2;212;212;212m=[0m[38;2;207;144;120m'dropout'[0m,
    [38;2;156;220;254mkey[0m[38;2;212;212;212m=[0m[38;2;79;201;177mRngKey[0m[38;2;255;213;3m([0m[38;2;105;105;105m # 1 (8 B)[0m
      [38;2;156;220;254mvalue[0m[38;2;212;212;212m=[0mArray((), dtype=key<fry>) overlaying:
      [0 1],
      [38;2;156;220;254mtag[0m[38;2;212;212;212m=[0m[38;2;207;144;120m'dropout'[0m
    [38;2;255;213;3m)[0m,
    [38;2;156;220;254mcount[0m[38;2;212;212;212m=[0m[38;2;79;201;177mRngCount[0m[38;2;255;213;3m([0m[38;2;105;105;105m # 1 (4 B)[0m
      [38;2;156;220;254mvalue[0m[38;2;212;212;212m=[0mArray(0, dtype=uint32),
      [38;2;156;220;254mtag[0m[38;2;212;212;212m=[0m[38;2;207;14

Assertions covering all three input variants (`Rngs`, `RngStream`, and an `int` key seed):

In [6]:
# One representative of each accepted input kind: an `Rngs`, an `RngStream`
# and a plain integer seed.
raw_rngs = Rngs(1)
rng_stream = RngStream(2, tag="default")
key_seed = 3
rngs_conversion = [raw_rngs, rng_stream, key_seed]

# Both converters must expose a stream carrying the expected tag for every
# input kind; the assertion message names the kind that failed.
for rngs_input in rngs_conversion:
    input_kind = type(rngs_input).__name__
    assert dropout_rngs(rngs_input).dropout.tag == "dropout", input_kind  # noqa:S101
    assert dropconnect_rngs(rngs_input).dropconnect.tag == "dropconnect", input_kind  # noqa:S101

### This may later allow these `Rngs` to be configured within a layer via the appropriate `tag`, e.g. for ensembling, but this needs further investigation.

## flax `dropconnect` and `dropout` demonstration

#### dropconnect

In [7]:
from flax import nnx
from flax.nnx import Linear, Rngs, Sequential

from probly.transformation.dropconnect import dropconnect
from probly.transformation.dropconnect.flax import replace_flax_dropconnect

# A single `Rngs` (seed 1) is shared by every layer constructed in this cell,
# so the statements below are kept in this exact order.
rngs = Rngs(1)
flax_linear = Linear(1, 2, rngs=rngs)
flax_sequential = Sequential(Linear(1, 2, rngs=rngs), Linear(2, 1, rngs=rngs))

# dropconnect call: the transformation applied to the whole two-layer model,
# passing nothing but the model itself.
flax_dropconnect = dropconnect(flax_sequential)
print("dropconnect call:")
nnx.display(flax_dropconnect)

# replace_flax_dropconnect call: the flax-specific function applied directly to
# a single `Linear` layer, with `p` and `rngs` given explicitly.
flax_dropconnect_linear = replace_flax_dropconnect(flax_linear, p=0.1, rngs=rngs)
print("replace_flax_dropconnect call:")
nnx.display(flax_dropconnect_linear)

dropconnect call:


replace_flax_dropconnect call:


#### dropout

In [8]:
from flax import nnx
from flax.nnx import Linear, Rngs, Sequential

from probly.transformation.dropout import dropout
from probly.transformation.dropout.flax import prepend_flax_dropout

# A single `Rngs` (seed 1) is shared by every layer constructed in this cell,
# so the statements below are kept in this exact order.
rngs = Rngs(1)
flax_linear = Linear(1, 2, rngs=rngs)
flax_sequential = Sequential(Linear(1, 2, rngs=rngs), Linear(2, 1, rngs=rngs))

# dropout call: the transformation applied to the whole two-layer model,
# passing nothing but the model itself.
flax_dropout = dropout(flax_sequential)
print("dropout call:")
nnx.display(flax_dropout)

# prepend_flax_dropout call: the flax-specific function applied directly to a
# single `Linear` layer, with `p` and `rngs` given explicitly.
flax_dropout_linear = prepend_flax_dropout(flax_linear, p=0.1, rngs=rngs)
print("preprend_flax_dropout call:")
nnx.display(flax_dropout_linear)

dropout call:


preprend_flax_dropout call:


## Influence on Torch

Adding an `rngs: Any` parameter and the appropriate noqa comments was necessary in `probly.transformation.dropconnect.torch` and `probly.transformation.dropout.torch`.

Calls to the `dropconnect` and `dropout` transformations are not affected by these changes:

In [9]:
from torch import nn

from probly.transformation.dropconnect import dropconnect
from probly.transformation.dropout import dropout

# Base torch model: two linear layers, 1 -> 2 -> 1.
torch_sequential = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 1))

# Apply each transformation to the same base model without passing any `rngs`
# argument, then print both results for comparison.
torch_dropconnect = dropconnect(torch_sequential)
torch_dropout = dropout(torch_sequential)
print(torch_dropconnect, "\n", torch_dropout)

Sequential(
  (0): Linear(in_features=1, out_features=2, bias=True)
  (1): DropConnectLinear(in_features=2, out_features=1, bias=True)
) 
 Sequential(
  (0): Linear(in_features=1, out_features=2, bias=True)
  (1_0): Dropout(p=0.25, inplace=False)
  (1_1): Linear(in_features=2, out_features=1, bias=True)
)


Note that direct calls to `replace_torch_dropconnect` and `prepend_torch_dropout` must pass `rngs=None` explicitly:

In [10]:
from torch import nn

from probly.transformation.dropconnect.torch import replace_torch_dropconnect
from probly.transformation.dropout.torch import prepend_torch_dropout

# A single torch layer that is handed to both backend-specific functions.
torch_linear = nn.Linear(1, 2)

# The direct torch entry points are called with an explicit `rngs=None`.
torch_dropconnect_linear = replace_torch_dropconnect(
    torch_linear,
    p=0.1,
    rngs=None,
)
torch_dropout_linear = prepend_torch_dropout(
    torch_linear,
    p=0.1,
    rngs=None,
)
# The " \n " separator reproduces the spacing of printing "\n" as its own argument.
print(torch_dropconnect_linear, torch_dropout_linear, sep=" \n ")

DropConnectLinear(in_features=1, out_features=2, bias=True) 
 Sequential(
  (0): Dropout(p=0.1, inplace=False)
  (1): Linear(in_features=1, out_features=2, bias=True)
)
