Skip to content

Commit 68d86dc

Browse files
authored
Ensure compatibility NNX 0.11.2 (#2067)
* ensure compatibility * rm xfail
1 parent 445f548 commit 68d86dc

File tree

2 files changed

+1
-4
lines changed

2 files changed

+1
-4
lines changed

numpyro/contrib/module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ def apply_fn(params, *call_args, **call_kwargs):
498498
if mutable_holder:
499499
nnx.replace_by_pure_dict(mutable_state, mutable_holder["state"])
500500

501-
model = nnx.merge(graph_def, params_state, mutable_state)
501+
model = nnx.merge(graph_def, params_state, mutable_state, copy=True)
502502

503503
model_call = model(*call_args, **call_kwargs)
504504

test/contrib/test_module.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -385,9 +385,6 @@ def nnx_model_eager(x, y):
385385
@pytest.mark.parametrize(
386386
argnames="batchnorm", argvalues=[True, False], ids=["batchnorm", "no_batchnorm"]
387387
)
388-
@pytest.mark.xfail(
389-
reason="Temporary marking to pass CI. Bug fixed in https://github.com/pyro-ppl/numpyro/pull/2067"
390-
)
391388
def test_nnx_state_dropout_smoke(dropout, batchnorm):
392389
from flax import nnx
393390

0 commit comments

Comments
 (0)