In [4]:
from flax import nnx

In [5]:
class Net(nnx.Module):
    """Minimal NNX module that wraps a single ``nnx.Linear`` layer.

    Args:
        din: number of input features of the wrapped layer.
        dout: number of output features of the wrapped layer.
        rngs: ``nnx.Rngs`` forwarded to ``nnx.Linear`` for parameter init.
    """

    def __init__(self, din: int, dout: int, rngs: nnx.Rngs):
        self.linear = nnx.Linear(din, dout, rngs=rngs)

    def __call__(self, x):
        """Apply the wrapped linear layer to ``x`` and return its output."""
        return self.linear(x)

# Fixed seed 0 so the Linear layer's initial weights are reproducible.
net = Net(din=3, dout=4, rngs=nnx.Rngs(0))
nnx.display(net)

In [29]:
import jax 
import jax.numpy as jnp

class Count(nnx.Variable):
  """Custom ``nnx.Variable`` subtype used to hold a call counter.

  Because it is its own subclass (not ``nnx.Param``), its state appears under
  a separate Variable type and can be selected or excluded by type filters.
  """

class StatefulLinear(nnx.Module):
  """Affine layer computing ``x @ w + b`` that also counts its forward calls.

  Args:
    din: number of input features (rows of ``w``).
    dout: number of output features (columns of ``w``, length of ``b``).
    rngs: ``nnx.Rngs``; the key drawn with ``rngs()`` seeds the uniform
      initialisation of ``w``.
  """

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    # Non-Param state: a uint32 scalar counter starting at 0.
    self.count = Count(jnp.array(0, dtype=jnp.uint32))

  def __call__(self, x: jax.Array) -> jax.Array:
    """Increment the call counter and return ``x @ w + b``.

    NOTE(review): assumes the last dimension of ``x`` equals ``din``;
    this is not checked here.
    """
    # Update through ``.value`` explicitly instead of relying on in-place
    # operator overloads on ``nnx.Variable`` (not supported in all Flax
    # versions); this also guarantees ``self.count`` stays a ``Count``.
    self.count.value += 1
    return x @ self.w.value + self.b.value

model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0))
# Distinct name so this result is not confused with the RMSNorm `y` below.
y_linear = model(jnp.ones((1, 3)))
# One forward pass: output is (1, dout) and the counter advanced exactly once.
assert y_linear.shape == (1, 5)
assert model.count.value == 1

nnx.display(model)

In [31]:
# Separate the module into its static structure (GraphDef) and its Variables (State).
graphdef, state = nnx.split(model)

In [35]:
# Render both halves produced by nnx.split side by side.
nnx.display(graphdef, state)

In [45]:
# Explicit float32 input: normalisation layers act on real-valued activations,
# and integer inputs would make the manual check below square in int32.
x = jnp.arange(30, dtype=jnp.float32).reshape(5, 6)
layer = nnx.RMSNorm(6, rngs=nnx.Rngs(0))
# nnx.display shows the layer including its state (the `scale` Param).
nnx.display(layer)

In [46]:
y = layer(x)  # RMSNorm applied row-wise over the last axis (6 features)

In [47]:
y  # normalised activations from nnx.RMSNorm, shape (5, 6)

Array([[0.        , 0.3302891 , 0.6605782 , 0.99086726, 1.3211564 ,
        1.6514455 ],
       [0.6920518 , 0.8073938 , 0.92273575, 1.0380777 , 1.1534197 ,
        1.2687616 ],
       [0.8219049 , 0.890397  , 0.95888907, 1.0273812 , 1.0958732 ,
        1.1643653 ],
       [0.8750175 , 0.9236296 , 0.9722417 , 1.0208538 , 1.0694659 ,
        1.118078  ],
       [0.90378547, 0.94144315, 0.9791009 , 1.0167586 , 1.0544163 ,
        1.092074  ]], dtype=float32)

In [49]:
# Reference RMSNorm over the last axis: x / sqrt(mean(x**2) + eps) * scale.
RMS_EPS = 1e-6  # matches nnx.RMSNorm's default epsilon
y_ref = x / jnp.sqrt(jnp.mean(x**2, axis=-1, keepdims=True) + RMS_EPS) * layer.scale.value
# Assert instead of eyeballing: the two agree only up to float32 rounding.
assert jnp.allclose(y, y_ref), "nnx.RMSNorm output does not match the reference formula"
y_ref

Array([[0.        , 0.3302891 , 0.6605782 , 0.9908673 , 1.3211564 ,
        1.6514455 ],
       [0.69205177, 0.8073937 , 0.92273575, 1.0380777 , 1.1534196 ,
        1.2687616 ],
       [0.82190496, 0.890397  , 0.9588891 , 1.0273812 , 1.0958732 ,
        1.1643654 ],
       [0.8750175 , 0.92362964, 0.9722417 , 1.0208538 , 1.0694659 ,
        1.118078  ],
       [0.9037854 , 0.94144315, 0.9791009 , 1.0167586 , 1.0544163 ,
        1.092074  ]], dtype=float32)

In [50]:
layer.scale  # the RMSNorm's learnable per-feature scale Param

Param(
  value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
)

In [53]:
# Keep only the nnx.Param entries of the layer's state (here just `scale`).
param = nnx.state(layer).filter(nnx.Param)
# An assignment alone produces no cell output; display the result explicitly.
param

State({
  'scale': VariableState(
    type=Param,
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})