Default serialisation fails for BatchNorm. #166

Closed
paganpasta opened this issue Aug 3, 2022 · 10 comments · Fixed by #172

Labels
bug Something isn't working

Comments

@paganpasta
Contributor

Hi,

The default serialisation fails when a model containing BatchNorm is serialised. Here is a small test script, executed on the dev branch:

def test_serialise_bn(getkey):
    net = eqx.nn.Sequential(
        [
            eqx.experimental.BatchNorm(3, axis_name="batch"),
        ]
    )

    eqx.tree_serialise_leaves('/tmp/net.eqx', net)

    assert True

with the error

_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../equinox/serialisation.py:183: in tree_serialise_leaves
    jtu.tree_map(_serialise, filter_spec, pytree)
../../../miniconda3/envs/equinox/lib/python3.8/site-packages/jax/_src/tree_util.py:201: in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
../../../miniconda3/envs/equinox/lib/python3.8/site-packages/jax/_src/tree_util.py:201: in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
../equinox/serialisation.py:181: in _serialise
    jtu.tree_map(__serialise, x, is_leaf=is_leaf)
../../../miniconda3/envs/equinox/lib/python3.8/site-packages/jax/_src/tree_util.py:201: in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
../../../miniconda3/envs/equinox/lib/python3.8/site-packages/jax/_src/tree_util.py:201: in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
../equinox/serialisation.py:179: in __serialise
    spec(f, y)
../equinox/serialisation.py:50: in default_serialise_filter_spec
    value, _, _ = x.unsafe_get()
../equinox/experimental/stateful.py:112: in unsafe_get
    return _state_cache[self._obj]
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <WeakKeyDictionary at 0x7fc6deff8d90>
key = <equinox.experimental.stateful._IndexObj object at 0x7fc6deec12b0>

    def __getitem__(self, key):
>       return self.data[ref(key)]
E       KeyError: <weakref at 0x7fc736f7e590; to '_IndexObj' at 0x7fc6deec12b0>

../../../miniconda3/envs/equinox/lib/python3.8/weakref.py:383: KeyError
@patrick-kidger
Owner

Hmm. The issue here is that the model hasn't been called yet, so its extra state hasn't been set yet.

Fixed in #167.

patrick-kidger added the bug (Something isn't working) label on Aug 3, 2022
@paganpasta
Contributor Author

paganpasta commented Aug 4, 2022

Doing a forward pass on the deserialised model leads to jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: RuntimeError: Cannot get state before it has been set

Possibly the reason for this failure is the same as in your comment above?
Example to reproduce it (only commenting out both A and B works; every other combination fails):

def test_serialise_bn(getkey):
    import jax
    import jax.numpy as jnp
    x = jnp.arange(12).reshape(4, 3)

    net = eqx.nn.Sequential(
        [
            eqx.nn.Linear(3, 5, key=getkey()),
            eqx.experimental.BatchNorm(5, axis_name="batch"),
        ]
    )

    # jax.vmap(net, axis_name="batch")(x) # A
    eqx.tree_serialise_leaves('/tmp/net.eqx', net)

    net_2 = eqx.tree_deserialise_leaves('/tmp/net.eqx', net)

    net_2 = eqx.tree_inference(net_2, True) # B
    jax.vmap(net_2, axis_name="batch")(x)

@patrick-kidger
Owner

So that's definitely expected. If you haven't run the model yet then the batch statistics aren't available; i.e. you have to train your BatchNorm before you can use it at inference.

I suppose we could provide dummy initial statistics for the BatchNorm. I don't know if that's better than raising an error though. Alternatively we could make the error message more descriptive.
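
To make that concrete, here is a minimal sketch based on the reproduction above (assuming the #172 fix so the tuple-valued BatchNorm state serialises; jax.random.PRNGKey(0) stands in for the getkey() test fixture): run the model under vmap at least once before serialising and before switching to inference mode, so the running statistics exist.

import jax
import jax.numpy as jnp
import equinox as eqx

x = jnp.arange(12.0).reshape(4, 3)
net = eqx.nn.Sequential(
    [
        eqx.nn.Linear(3, 5, key=jax.random.PRNGKey(0)),
        eqx.experimental.BatchNorm(5, axis_name="batch"),
    ]
)

jax.vmap(net, axis_name="batch")(x)  # training-mode pass: populates the running statistics
eqx.tree_serialise_leaves("/tmp/net.eqx", net)

net_2 = eqx.tree_deserialise_leaves("/tmp/net.eqx", net)
net_2 = eqx.tree_inference(net_2, True)  # inference mode now has statistics to use
jax.vmap(net_2, axis_name="batch")(x)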

@paganpasta
Contributor Author

paganpasta commented Aug 4, 2022

But it fails with A uncommented and B commented as well. So do you mean that net_2 needs a dummy forward before loading the weights?
Another extension I tried:

  1. training net for one step.
  2. Saving it.
  3. Loading it with net_2.
  4. Forward pass, with and without the tree_inference call.

This failed at deserialisation with an XlaRuntimeError, but with a different message: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: TypeError: Callback func <jax.experimental.host_callback._CallbackWrapper object at 0x7f2f884660d0> should have returned a result with pytree PyTreeDef((*, *)) but returned PyTreeDef(*)
def test_serialise_bn(getkey):
    import jax
    import jax.numpy as jnp
    import optax

    x = jnp.arange(12).reshape(4, 3)

    net = eqx.nn.Sequential(
        [
            eqx.nn.Linear(3, 5, key=getkey()),
            eqx.experimental.BatchNorm(5, axis_name="batch"),
        ]
    )
    # Uncomment to save
    # @eqx.filter_value_and_grad
    # def compute_loss(model, x):
    #     output = jax.vmap(model, axis_name=('batch'))(x)
    #     return output.mean()
    #
    # # Important for efficiency whenever you use JAX: wrap everything into a single JIT region.
    # @eqx.filter_jit
    # def make_step(model, x, optimizer, opt_state):
    #     loss, grads = compute_loss(model, x)
    #     updates, opt_state = optimizer.update(grads, opt_state)
    #     model = eqx.apply_updates(model, updates)
    #     return loss, model, opt_state
    #
    # optimizer = optax.adam(learning_rate=0.01)
    # opt_state = optimizer.init(eqx.filter(net, eqx.is_array))
    # _, net, _ = make_step(net, x, optimizer, opt_state)
    #
    # eqx.tree_serialise_leaves('/tmp/net.eqx', net)

    net_2 = eqx.tree_deserialise_leaves('/tmp/net.eqx', net)
    net_2 = eqx.tree_inference(net_2, True)
    jax.vmap(net_2, axis_name="batch")(x)

@patrick-kidger
Owner

Right. Looks like the de/serialisation routines were implicitly assuming that the value behind the StateIndex was just an array, rather than a PyTree.

For now I've gone for the simple approach of generalising to handle arrays or tuples of arrays, which is all we need for BatchNorm. (And raise a NotImplementedError otherwise.) Stateful operations are already very edge-case so this should be fine for now. (Eventually I'm hoping that algebraic effects will mean that this whole API can be deprecated in favour of better alternatives in core JAX, hence why I'm happy not to worry too much about writing the complicated general case here.)
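
Roughly, the shape of that fix looks like the following (a sketch only, with a made-up helper name, not the actual Equinox code): the serialisation hook inspects the value stored behind a StateIndex and handles either a single array or a tuple of arrays, raising NotImplementedError for anything else.

import numpy as np
import jax.numpy as jnp

def _serialise_state(f, value):
    # Hypothetical helper: write the value stored behind a StateIndex to the
    # open file handle `f`.
    if isinstance(value, (np.ndarray, jnp.ndarray)):
        np.save(f, np.asarray(value))
    elif isinstance(value, tuple) and all(
        isinstance(v, (np.ndarray, jnp.ndarray)) for v in value
    ):
        # BatchNorm stores a tuple of arrays: (running_mean, running_var).
        for v in value:
            np.save(f, np.asarray(v))
    else:
        raise NotImplementedError(
            "Stateful values other than arrays or tuples of arrays are not supported."
        )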

@paganpasta
Contributor Author

Thanks for the fix!
I'll try it out as soon as I can.

@paganpasta
Contributor Author

Hi,

How can I use running_mean and running_var loaded from a saved (non-eqx) checkpoint?
At the moment I do a dummy pass on the eqx model to initialise the BatchNorm, followed by replacing weight, bias, state_index._state[0] and state_index._state[1] with the new values. However, during inference state_index.inference is still False, and the latest cached state is that of the dummy pass rather than the loaded checkpoint. Is there a way to force an update to the _state_cache so that it uses the loaded running_mean and running_var?

@paganpasta
Contributor Author

A minor update: I can load and set the StateIndex without a dummy pass. I simply load the weights for everything except the running stats; after that, similarly to how deserialisation works for StateIndex, I set those values using tree_map.

Thanks a lot for fixing the de/serialisation!!

@patrick-kidger
Owner

Excellent. And btw you definitely shouldn't set _state. This is part of some deep magic to make inference mode work without the cost of looking the value up at runtime.

Use eqx.experimental.set_state if you ever want to modify the state manually.
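
For instance, a minimal sketch of loading external running statistics (assuming the state behind BatchNorm's state_index is the tuple (running_mean, running_var), as the PyTreeDef((*, *)) earlier in this thread suggests; the checkpoint values here are placeholders):

import jax.numpy as jnp
import equinox as eqx

# Placeholder running statistics, e.g. loaded from a non-eqx checkpoint.
running_mean = jnp.zeros(5)
running_var = jnp.ones(5)

bn = eqx.experimental.BatchNorm(5, axis_name="batch")
# Assumption: BatchNorm stores (running_mean, running_var) behind state_index.
eqx.experimental.set_state(bn.state_index, (running_mean, running_var))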

@paganpasta
Contributor Author

Sorry to reopen this issue.

I tried saving a model mid-training and loading the checkpoint to resume it. It seems the behaviour breaks on the fixed branch (#172).

Sharing a small script to reproduce the behaviour. The script needs to be run twice.

Once with LOAD=False and then with LOAD=True.
With LOAD=False, the script works as intended and the net is serialised to disk.
With LOAD=True, I get the error

    loss, grads = compute_loss(model, x, y, keys)
  File "/tmp/equinox/equinox/grad.py", line 30, in __call__
    return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs)
  File "/tmp/equinox/equinox/grad.py", line 27, in fun_value_and_grad
    return __self._fun(_x, *_args, **_kwargs)
  File "adv.py", line 13, in compute_loss
    logits = jax.vmap(model, axis_name=('batch'))(x, key=keys)
  File "/tmp/equinox/equinox/nn/composed.py", line 129, in __call__
    x = layer(x, key=key)
  File "/tmp/equinox/equinox/experimental/batch_norm.py", line 161, in __call__
    lambda: get_state(
  File "/opt/conda/lib/python3.7/site-packages/jax/experimental/host_callback.py", line 1334, in _outside_call_jvp_rule
    raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
NotImplementedError: JVP rule is implemented only for id_tap, not for call.
import jax
import jax.numpy as jnp
import jax.random as jrandom
import optax

import numpy as np
import equinox as eqx

LOAD = False

@eqx.filter_value_and_grad
def compute_loss(model, x, y, keys):
    logits = jax.vmap(model, axis_name=('batch'))(x, key=keys)
    one_hot_actual = jax.nn.one_hot(y, num_classes=5)
    return optax.softmax_cross_entropy(logits, one_hot_actual).mean()
        

@eqx.filter_jit
def make_step(model, x, y, keys, optimizer, opt_state):
    loss, grads = compute_loss(model, x, y, keys)
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

net = eqx.nn.Sequential(
            [
                eqx.nn.Linear(3, 5, key=jrandom.PRNGKey(0)),
                eqx.experimental.BatchNorm(5, axis_name='batch')
            ]
        )

x = jnp.asarray(np.random.rand(10, 3))
y = jnp.asarray(np.random.randint(0, 9, 10))
key = jrandom.split(jrandom.PRNGKey(0), 10)
if LOAD:
    net = eqx.tree_deserialise_leaves('/tmp/net.eqx', net)

optimizer = optax.adam(learning_rate=0.1)
opt_state = optimizer.init(eqx.filter(net, eqx.is_array))
_, net, _ = make_step(net, x, y, key, optimizer, opt_state)
eqx.tree_serialise_leaves('/tmp/net.eqx', net)
