In [1]:
from IPython.core.interactiveshell import InteractiveShell

# Display the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"
# Load (or reload) the autoreload extension and enable mode 2 so that edits to
# imported modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
import jax
import jax.numpy as jnp
from flax import nnx


class Model(nnx.Module):
    """Toy network made of two stacked linear layers (2 -> 3 -> 4 features).

    No activation is applied between the layers.
    """

    def __init__(self, rngs):
        """Create both layers, drawing their initial parameters from ``rngs``."""
        self.linear1 = nnx.Linear(in_features=2, out_features=3, rngs=rngs)
        self.linear2 = nnx.Linear(in_features=3, out_features=4, rngs=rngs)

    def __call__(self, x):
        """Apply ``linear1`` then ``linear2`` to ``x`` and return the result."""
        hidden = self.linear1(x)
        return self.linear2(hidden)


# Fixed-seed toy input: one sample with 2 features (the input size of linear1).
x = jax.random.normal(jax.random.key(0), (1, 2))
# Regression target of all ones with 4 values (the output size of linear2).
y = jnp.ones((1, 4))

# Instantiate the model using an Rngs container seeded with 0.
model = Model(nnx.Rngs(0))

In [3]:
from flax import nnx


def model_state_decay_mask(model):
    """Build a boolean mask with the same nesting as the model's state.

    Args:
        model: Module whose state is read via ``nnx.state(model)``.

    Returns:
        An ``nnx.State`` rebuilt from the flat paths of the model state, where
        each leaf is ``False`` if the last component of its path is exactly
        ``"bias"`` and ``True`` otherwise.
    """
    flat_state = nnx.state(model).flat_state()
    # Use an exact comparison: the previous `not in ("bias")` tested against a
    # plain string (the parentheses do not make a tuple), so it was a substring
    # check that also matched names such as "b" or "ias" and raised TypeError
    # for non-string path components.
    flat_mask = {path: path[-1] != "bias" for path in flat_state.keys()}
    return nnx.State.from_flat_path(flat_mask)


# Preview the mask for the toy model; a later cell passes it to optax.adamw as `mask`.
model_state_decay_mask(model)

State({
  'linear1': {
    'bias': False,
    'kernel': True
  },
  'linear2': {
    'bias': False,
    'kernel': True
  }
})

In [10]:
import optax

# AdamW (lr 1e-3, weight decay 1e-4) with the decay mask built from the model state.
tx = optax.adamw(1e-3, weight_decay=1e-4, mask=model_state_decay_mask(model))
# Pair the model with the transformation; `state.model` is used below for gradients.
state = nnx.Optimizer(model, tx)


def loss_fn(model):
    """Return the mean squared error between ``model(x)`` and ``y``.

    ``x`` and ``y`` are not parameters; they are read from the notebook's
    global namespace at call time.
    """
    residual = model(x) - y
    squared_error = residual ** 2
    return squared_error.mean()


# Loss before the update.
loss_fn(model)
# Differentiate the loss with respect to the optimizer's model and show the gradients.
grads = nnx.grad(loss_fn)(state.model)
grads
# Apply one optimizer step, then re-evaluate the loss on the same batch.
state.update(grads)
loss_fn(model)

Array(1.6668038, dtype=float32)

State({
  'linear1': {
    'bias': VariableState(
      type=Param,
      value=Array([ 0.6486942 , -2.2046442 ,  0.14181204], dtype=float32)
    ),
    'kernel': VariableState(
      type=Param,
      value=Array([[-0.50907314,  1.7301297 , -0.11128926],
             [ 0.55557084, -1.8881562 ,  0.12145419]], dtype=float32)
    )
  },
  'linear2': {
    'bias': VariableState(
      type=Param,
      value=Array([-0.531473  , -0.71881384, -0.6530258 , -0.664232  ], dtype=float32)
    ),
    'kernel': VariableState(
      type=Param,
      value=Array([[ 0.21832745,  0.29528648,  0.26826096,  0.27286443],
             [ 0.24969177,  0.33770654,  0.3067986 ,  0.3120634 ],
             [-0.01679468, -0.02271469, -0.02063578, -0.0209899 ]],      dtype=float32)
    )
  }
})

Array(1.6540173, dtype=float32)

In [50]:
from flax.traverse_util import path_aware_map


def partition_fn(path, x):
    """Assign a partition label to one entry of the flattened model state.

    Args:
        path: Path supplied by ``path_aware_map``. The map is applied to an
            already-flattened state, so ``path[0]`` is presumably the flat key
            tuple (e.g. ``("linear1", "bias")``) and ``path[0][-1]`` the
            parameter name.
        x: The value stored at ``path``; unused.

    Returns:
        ``"decay"`` if the parameter name is exactly ``"bias"``, otherwise
        ``"no_decay"``.
    """
    # NOTE(review): these labels are the opposite of model_state_decay_mask,
    # which excludes biases from decay. Confirm that biases are really meant
    # to land in the "decay" group before using this for training.
    param_name = path[0][-1]
    # Exact comparison: the previous `in ("bias")` tested against a plain
    # string, so it was a substring check that also matched names such as "b"
    # and raised TypeError for non-string path components.
    return "decay" if param_name == "bias" else "no_decay"


# Flat view: one label per key of the flattened model state.
path_aware_map(
    partition_fn,
    nnx.state(model).flat_state(),
)

# The same labels rebuilt into a nested State keyed like the model.
nnx.State.from_flat_path(
    path_aware_map(
        partition_fn,
        nnx.state(model).flat_state(),
    )
)

{('linear1', 'bias'): 'decay',
 ('linear1', 'kernel'): 'no_decay',
 ('linear2', 'bias'): 'decay',
 ('linear2', 'kernel'): 'no_decay'}

State({
  'linear1': {
    'bias': decay,
    'kernel': no_decay
  },
  'linear2': {
    'bias': decay,
    'kernel': no_decay
  }
})

In [8]:
from flax.traverse_util import path_aware_map

# nnx.state(model).flat_state()


def partition_fn(path, x):
    """Assign a partition label to one entry of the flattened model state.

    Args:
        path: Path supplied by ``path_aware_map``. The map is applied to an
            already-flattened state, so ``path[0]`` is presumably the flat key
            tuple (e.g. ``("linear1", "bias")``) and ``path[0][-1]`` the
            parameter name.
        x: The value stored at ``path``; unused.

    Returns:
        ``"decay"`` if the parameter name is exactly ``"bias"``, otherwise
        ``"no_decay"``.
    """
    # NOTE(review): these labels are the opposite of model_state_decay_mask,
    # which excludes biases from decay. Confirm that biases are really meant
    # to land in the "decay" group before using this for training.
    param_name = path[0][-1]
    # Exact comparison: the previous `in ("bias")` tested against a plain
    # string, so it was a substring check that also matched names such as "b"
    # and raised TypeError for non-string path components.
    return "decay" if param_name == "bias" else "no_decay"


# param_partitions = flax.core.freeze(
#     path_aware_map(
#         partition_fn,
#         nnx.state(model).flat_state(),
#     )
# )

# Label every entry of the flattened model state with partition_fn, then rebuild
# the labels into a nested State so they can be handed to optax.multi_transform.
param_partitions = nnx.State.from_flat_path(
    path_aware_map(
        partition_fn,
        nnx.state(model).flat_state(),
    )
)


def get_optimizer(decay):
    """Return an AdamW transformation with a fixed 1e-3 learning rate.

    Args:
        decay: Value forwarded to ``optax.adamw`` as ``weight_decay``.
    """
    adamw_kwargs = {"learning_rate": 1e-3, "weight_decay": decay}
    return optax.adamw(**adamw_kwargs)


# One AdamW instance per partition label: 0.1 weight decay for "decay", none for "no_decay".
partition_optimizers = {
    "decay": get_optimizer(0.1),
    "no_decay": get_optimizer(0.0),
}

# Route each parameter to the optimizer named by its label in param_partitions.
tx = optax.multi_transform(partition_optimizers, param_partitions)

# Pair the model with the partitioned transformation.
state = nnx.Optimizer(model, tx)


def loss_fn(model):
    """Return the mean squared error between ``model(x)`` and ``y``.

    ``x`` and ``y`` are not parameters; they are read from the notebook's
    global namespace at call time.
    """
    residual = model(x) - y
    squared_error = residual ** 2
    return squared_error.mean()


# Current loss on the toy batch.
loss_fn(model)
# Differentiate the loss with respect to the optimizer's model and show the gradients.
grads = nnx.grad(loss_fn)(state.model)
grads

# Would a label function in place of param partitions help?

Array(1.6796587, dtype=float32)

State({
  'linear1': {
    'bias': VariableState(
      type=Param,
      value=Array([ 0.64893687, -2.2157297 ,  0.14647551], dtype=float32)
    ),
    'kernel': VariableState(
      type=Param,
      value=Array([[-0.5092636,  1.7388293, -0.114949 ],
             [ 0.5557786, -1.8976502,  0.1254482]], dtype=float32)
    )
  },
  'linear2': {
    'bias': VariableState(
      type=Param,
      value=Array([-0.53430504, -0.7224043 , -0.65456724, -0.666221  ], dtype=float32)
    ),
    'kernel': VariableState(
      type=Param,
      value=Array([[ 0.21807964,  0.29485342,  0.26716533,  0.2719219 ],
             [ 0.25243354,  0.34130144,  0.30925167,  0.31475753],
             [-0.01829538, -0.02473617, -0.02241333, -0.02281238]],      dtype=float32)
    )
  }
})