Default serialisation fails for BatchNorm #166

Hi, the default serialisation fails when a model with BatchNorm is serialised. A small test script, executed on the dev branch, fails with the error shown in the thread below.

Comments
Hmm. The issue here is that the model hasn't been called yet, so its extra state hasn't been set yet. Fixed in #167.
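For context, a minimal sketch of the pattern this comment describes (my own illustration against the experimental API of that era, not code from the thread): calling the model once under vmap, with the matching axis name, populates BatchNorm's extra state, after which serialisation has concrete values to save.

import jax
import jax.numpy as jnp
import jax.random as jrandom
import equinox as eqx

net = eqx.nn.Sequential(
    [
        eqx.nn.Linear(3, 5, key=jrandom.PRNGKey(0)),
        eqx.experimental.BatchNorm(5, axis_name="batch"),
    ]
)
x = jnp.arange(12.0).reshape(4, 3)
jax.vmap(net, axis_name="batch")(x)  # first call sets the running statistics
eqx.tree_serialise_leaves("/tmp/net.eqx", net)  # now has state to save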
Doing a forward pass on the deserialised model leads to an error. Possibly the reason for this failure is the same as in your comment above?

def test_serialise_bn(getkey):
    import jax
    import jax.numpy as jnp
    import equinox as eqx

    x = jnp.arange(12).reshape(4, 3)
    net = eqx.nn.Sequential(
        [
            eqx.nn.Linear(3, 5, key=getkey()),
            eqx.experimental.BatchNorm(5, axis_name="batch"),
        ]
    )
    # jax.vmap(net, axis_name="batch")(x)  # A
    eqx.tree_serialise_leaves('/tmp/net.eqx', net)
    net_2 = eqx.tree_deserialise_leaves('/tmp/net.eqx', net)
    net_2 = eqx.tree_inference(net_2, True)  # B
    jax.vmap(net_2, axis_name="batch")(x)
So that's definitely expected. If you haven't run the model yet then the batch statistics aren't available; i.e. you have to train your BatchNorm before you can use it at inference. I suppose we could provide dummy initial statistics for the BatchNorm. I don't know if that's better than raising an error, though. Alternatively we could make the error message more descriptive.
But it fails with A uncommented and B commented as well. So do you mean something like the following?
def test_serialise_bn(getkey):
    import jax
    import jax.numpy as jnp
    import optax
    import equinox as eqx

    x = jnp.arange(12).reshape(4, 3)
    net = eqx.nn.Sequential(
        [
            eqx.nn.Linear(3, 5, key=getkey()),
            eqx.experimental.BatchNorm(5, axis_name="batch"),
        ]
    )
    # Uncomment to save
    # @eqx.filter_value_and_grad
    # def compute_loss(model, x):
    #     output = jax.vmap(model, axis_name="batch")(x)
    #     return output.mean()
    #
    # # Important for efficiency whenever you use JAX: wrap everything into a single JIT region.
    # @eqx.filter_jit
    # def make_step(model, x, optimizer, opt_state):
    #     loss, grads = compute_loss(model, x)
    #     updates, opt_state = optimizer.update(grads, opt_state)
    #     model = eqx.apply_updates(model, updates)
    #     return loss, model, opt_state
    #
    # optimizer = optax.adam(learning_rate=0.01)
    # opt_state = optimizer.init(eqx.filter(net, eqx.is_array))
    # _, net, _ = make_step(net, x, optimizer, opt_state)
    #
    # eqx.tree_serialise_leaves('/tmp/net.eqx', net)
    net_2 = eqx.tree_deserialise_leaves('/tmp/net.eqx', net)
    net_2 = eqx.tree_inference(net_2, True)
    jax.vmap(net_2, axis_name="batch")(x)
Right. Looks like the de/serialisation routines were implicitly assuming that the value behind the StateIndex was a single array. For now I've gone for the simple approach of generalising to handle arrays or tuples of arrays, which is all we need for BatchNorm. (And raising a more descriptive error when the state hasn't been set yet.)
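A hypothetical sketch of that generalisation (my own illustration, not Equinox's actual implementation; serialise_leaf is an invented name): the routine writes either a single array or each element of a tuple of arrays, which covers BatchNorm's running statistics.

import numpy as np

def serialise_leaf(f, value):
    # The value behind a StateIndex may be a single array or a tuple of
    # arrays; handle both cases when writing to the open file handle.
    if isinstance(value, tuple):
        for v in value:
            np.save(f, np.asarray(v))
    else:
        np.save(f, np.asarray(value))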
Thanks for the fix!
Hi, how can I load and set the StateIndex without a dummy forward pass?
A minor update: I can load and set StateIndex without a dummy pass. I simply load the weights for everything except the running stats. After that, similar to how deserialisation works for StateIndex, I set the values using tree_map. Thanks a lot for fixing the de/serialisation!
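A sketch of the approach being described (my own illustration, assuming the era's experimental set_state(index, value) API, and assuming BatchNorm stores its state as a (running_mean, running_var) tuple; both are assumptions, not confirmed by the thread):

import jax
import jax.numpy as jnp
import jax.random as jrandom
import equinox as eqx

net = eqx.nn.Sequential(
    [
        eqx.nn.Linear(3, 5, key=jrandom.PRNGKey(0)),
        eqx.experimental.BatchNorm(5, axis_name="batch"),
    ]
)

# Running statistics recovered from the checkpoint by other means; the
# (running_mean, running_var) layout is assumed for illustration.
saved_stats = (jnp.zeros(5), jnp.ones(5))

def is_index(x):
    return isinstance(x, eqx.experimental.StateIndex)

def restore(leaf):
    # Write the saved statistics behind every StateIndex leaf; all other
    # leaves pass through unchanged.
    if is_index(leaf):
        eqx.experimental.set_state(leaf, saved_stats)
    return leaf

jax.tree_util.tree_map(restore, net, is_leaf=is_index)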
Excellent. And btw you definitely shouldn't set [...]. Use [...].
Sorry to reopen this issue. I tried saving a model mid-training and loading the checkpoint to resume it. It seems the behaviour is broken on the fixed branch (#172). Sharing a small script to reproduce the behaviour below. The script needs to be run twice: once with LOAD = False to save a checkpoint, and then with LOAD = True to load it, which fails with:

loss, grads = compute_loss(model, x, y, keys)
File "/tmp/equinox/equinox/grad.py", line 30, in __call__
return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs)
File "/tmp/equinox/equinox/grad.py", line 27, in fun_value_and_grad
return __self._fun(_x, *_args, **_kwargs)
File "adv.py", line 13, in compute_loss
logits = jax.vmap(model, axis_name=('batch'))(x, key=keys)
File "/tmp/equinox/equinox/nn/composed.py", line 129, in __call__
x = layer(x, key=key)
File "/tmp/equinox/equinox/experimental/batch_norm.py", line 161, in __call__
lambda: get_state(
File "/opt/conda/lib/python3.7/site-packages/jax/experimental/host_callback.py", line 1334, in _outside_call_jvp_rule
raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
NotImplementedError: JVP rule is implemented only for id_tap, not for call.
import jax
import jax.numpy as jnp
import jax.random as jrandom
import optax
import numpy as np
import equinox as eqx

LOAD = False

@eqx.filter_value_and_grad
def compute_loss(model, x, y, keys):
    logits = jax.vmap(model, axis_name='batch')(x, key=keys)
    one_hot_actual = jax.nn.one_hot(y, num_classes=5)
    return optax.softmax_cross_entropy(logits, one_hot_actual).mean()

@eqx.filter_jit
def make_step(model, x, y, keys, optimizer, opt_state):
    loss, grads = compute_loss(model, x, y, keys)
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

net = eqx.nn.Sequential(
    [
        eqx.nn.Linear(3, 5, key=jrandom.PRNGKey(0)),
        eqx.experimental.BatchNorm(5, axis_name='batch')
    ]
)

x = jnp.asarray(np.random.rand(10, 3))
y = jnp.asarray(np.random.randint(0, 5, 10))  # labels in [0, 5) to match num_classes=5
key = jrandom.split(jrandom.PRNGKey(0), 10)

if LOAD:
    net = eqx.tree_deserialise_leaves('/tmp/net.eqx', net)

optimizer = optax.adam(learning_rate=0.1)
opt_state = optimizer.init(eqx.filter(net, eqx.is_array))
_, net, _ = make_step(net, x, y, key, optimizer, opt_state)
eqx.tree_serialise_leaves('/tmp/net.eqx', net)
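For what it's worth, the traceback bottoms out in a general JAX limitation of that era rather than in Equinox itself. A distilled illustration (my own, not code from the thread; host_callback has since been deprecated in JAX):

import jax
import jax.numpy as jnp
from jax.experimental import host_callback

def double_on_host(x):
    # Round-trips the value through a host-side Python callback.
    return host_callback.call(lambda v: v * 2, x, result_shape=x)

# Differentiating through host_callback.call fails at trace time, because
# only id_tap has a JVP rule; this is the same error raised from
# BatchNorm's get_state above.
jax.grad(lambda x: double_on_host(x).sum())(jnp.ones(3))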