BatchNorm raises "TypeError: Expected a callable value, got inf" #234
Comments
As @uuirs points out, you've redefined the variable `…`. That said, it's a little odd that you're only hitting this when using `…`.
Yeah, it's odd that this happened in such a way, and you're right -- I checked, and with `…` …
Had some time to debug this tonight. I think it's because the state is initialized to `…` (equinox/equinox/experimental/batch_norm.py, line 115 in 039f6bf), but once it calculates the running means and vars it sets it to a tuple of arrays (equinox/equinox/experimental/batch_norm.py, line 175 in 039f6bf). This changes the `…`. I'm assuming the fix is a one-line change in `…`.
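For illustration of why a change like that would force recompilation (this is a minimal sketch, not equinox's actual code): under `jax.jit`, a leaf going from `None` to an array changes the pytree's treedef itself, because JAX treats `None` as an empty subtree rather than as data, and a new treedef triggers a retrace.

```python
import jax
import jax.numpy as jnp

traces = []

@jax.jit
def f(tree):
    traces.append(1)  # Python side effect: runs only while jit is tracing
    return jax.tree_util.tree_map(lambda a: a + 1, tree)

f({"stats": None, "x": jnp.ones(3)})
f({"stats": None, "x": jnp.ones(3)})          # same treedef and shapes: cached, no retrace
f({"stats": jnp.zeros(3), "x": jnp.ones(3)})  # None -> array changes the treedef: retrace
print(len(traces))  # 2
```

So if the state really does switch from `None` to a tuple of arrays between calls, a retrace on the second call would be expected.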
Aren't these different state indices?
Sorry, you're right -- for initialization I meant to reference equinox/equinox/experimental/stateful.py, line 107 in 039f6bf.
To be fair, even then I don't understand some of the behaviour. A small working example of the bug above is:

```python
import jax
import jax.random as jrandom

import equinox as eqx


@eqx.filter_jit
def fun(bn, inp):
    print("Compiling!")
    return jax.vmap(bn, axis_name="batch")(inp)


def info(bn):
    print(f"bn.state_index._state: {bn.state_index._state}")  # prints None all the time
    # children should be dynamic_field_values, so children[0] should be the value of _state
    children, aux = bn.state_index.tree_flatten()
    print(f"children[0] of bn.state_index flatten: {children[0]}")  # prints a tuple of arrays after the 1st call to fun


if __name__ == "__main__":
    bn = eqx.experimental.BatchNorm(10, axis_name="batch")
    inp = jrandom.normal(jrandom.PRNGKey(0), (32, 10))
    info(bn)
    fun(bn, inp)
    info(bn)
    fun(bn, inp)
    info(bn)
    fun(bn, inp)
```

which outputs …
So you can clearly see that, at least when accessing the flattened children, the state does get updated. I don't understand why `…` (equinox/equinox/experimental/stateful.py, line 183 in 039f6bf), but I guess it gets deleted somewhere in between the two prints by `…`.
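The `_state`-vs-flatten discrepancy above can be reproduced in miniature. This is a hypothetical sketch, not equinox's implementation: a registered pytree node whose `tree_flatten` reads its children from an external store, so the flattened children can disagree with the stale `_state` attribute. The store, class name, and key are all made up for illustration.

```python
import jax

_store = {}  # hypothetical external store (stands in for wherever the real state lives)

class ToyStateIndex:
    """A pytree node whose flattened children come from `_store`, not from `_state`."""
    def __init__(self, key):
        self.key = key
        self._state = None  # never written in place, so it always prints None

    def tree_flatten(self):
        # Children are looked up at flatten time, so they reflect updates
        # that the `_state` attribute never sees.
        return (_store.get(self.key),), self.key

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(aux)

jax.tree_util.register_pytree_node(
    ToyStateIndex, lambda n: n.tree_flatten(), ToyStateIndex.tree_unflatten
)

idx = ToyStateIndex("bn")
_store["bn"] = (1.0, 2.0)       # e.g. running mean/var written after a first call
children, _ = idx.tree_flatten()
print(idx._state, children[0])  # None (1.0, 2.0)
```

If something like this is going on, the attribute staying `None` while `tree_flatten` returns arrays is consistent, not contradictory.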
Closing as …
I'm getting weird errors with `BatchNorm`. One example is the code below, where ODE parameters are optimized and the drift is a neural network with some `BatchNorm` layers. The error thrown is `TypeError: Expected a callable value, got inf` when vmapping a loss function.

I wasn't able to reduce this to something without `diffrax` -- double vmaps or vmapping a scan doesn't raise exceptions. I'm guessing this is an equinox issue since without batchnorm (`bn=False`) there are no problems.