# Why NNX?

## Intro
* Flax strengths
* Flax weaknesses


## NNX is Pythonic
* Example of building a Module

In [2]:
from flax.experimental import nnx
import jax
import jax.numpy as jnp

class Count(nnx.Variable):
    """Custom Variable type for counters (here: number of forward calls).

    Being a distinct subclass (not ``nnx.Param``) lets type-based filters
    tell this state apart from trainable parameters.
    """

class Linear(nnx.Module):
    """Dense layer computing ``x @ w + b`` that also counts how often it is called.

    Args:
        din: Input feature size (rows of ``w``).
        dout: Output feature size (columns of ``w`` and length of ``b``).
        ctx: NNX context; a "params" RNG key is drawn from it to initialise ``w``.
    """

    def __init__(self, din: int, dout: int, *, ctx: nnx.Context):
        # Plain Python attributes, assigned in the same order as before.
        self.din, self.dout = din, dout
        # Weights drawn from jax.random.uniform's default range [0, 1).
        w_key = ctx.make_rng("params")
        self.w = nnx.Param(jax.random.uniform(w_key, (din, dout)))
        # Bias starts at zero.
        self.b = nnx.Param(jnp.zeros((dout,)))
        # Non-Param state, bumped once per forward call.
        self.count = Count(0)

    def __call__(self, x) -> jax.Array:
        """Apply the affine map to ``x`` and increment the call counter."""
        self.count += 1
        projected = x @ self.w
        return projected + self.b

# A 5 -> 2 layer; nnx.context(0) is passed as `ctx`, from which
# Linear.__init__ draws its "params" RNG key.
model = Linear(5, 2, ctx=nnx.context(0))

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [3]:
x = jnp.ones(shape=(1, 5))
y = model(x)  # each forward call increments model.count

# Show the call counter, the parameters, and the module repr.
# `!r` matches the repr conversion that f"{expr=}" applies by default.
for attr_name in ("count", "w", "b"):
    print(f"model.{attr_name}={getattr(model, attr_name)!r}")
print(f"model={model!r}")

model.count=1
model.w=Array([[0.0779959 , 0.8061936 ],
       [0.05617034, 0.55959475],
       [0.3948189 , 0.5856023 ],
       [0.82162833, 0.27394366],
       [0.07696676, 0.8982161 ]], dtype=float32)
model.b=Array([0., 0.], dtype=float32)
model=Linear(
  din=5,
  dout=2
)




## NNX is friendly for beginners
* Example of training in eager mode


In [6]:
import numpy as np

# Fixed seed so the synthetic data and the minibatch sampling (and hence the
# printed losses / final weights) are reproducible under Restart & Run All.
rng = np.random.default_rng(0)

NUM_SAMPLES = 1000
BATCH_SIZE = 32
LEARNING_RATE = 0.1
NUM_STEPS = 500

# Synthetic 1-D regression data: y = 0.8 * x + 0.4 + Gaussian noise (std 0.1).
X = rng.uniform(size=(NUM_SAMPLES, 1))
Y = 0.8 * X + 0.4 + rng.normal(scale=0.1, size=(NUM_SAMPLES, 1))

model = Linear(1, 1, ctx=nnx.context(0))

for step in range(NUM_STEPS):
    idx = rng.integers(0, NUM_SAMPLES, size=(BATCH_SIZE,))
    x, y = X[idx], Y[idx]

    def loss_fn(model: Linear):
        # Mean squared error on the current minibatch (x, y are closed over).
        y_pred = model(x)
        return jnp.mean((y_pred - y) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn, wrt=nnx.Param)(model)

    # Plain SGD step applied leaf-wise: p <- p - lr * g.
    # jax.tree_util.tree_map replaces the deprecated jax.tree_map.
    params = model.filter(nnx.Param)
    params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    model.update_state(params)

    if step % 100 == 0:
        # Report the loss over the full dataset (after this step's update);
        # kept separate from the minibatch `loss` above.
        y_pred = model(X)
        full_loss = np.mean((y_pred - Y) ** 2)
        print(f"Step {step}: loss={full_loss:.4f}")

print(f"\n{model.w = }")
print(f"{model.b = }")

Step 0: loss=0.2734
Step 100: loss=0.0108
Step 200: loss=0.0105
Step 300: loss=0.0105
Step 400: loss=0.0107

model.w = Array([[0.8045632]], dtype=float32)
model.b = Array([0.39513135], dtype=float32)



## NNX is friendly for advanced users
* Example of a manual scan over layers



## Parameter surgery is intuitive
* Simple parameter surgery example

## TODO: section title



## What about Pytree-based libraries?