Unexpected (?) BatchNorm behavior: model's flattened form changes through iterations #238
Comments
Looks like this is related to #234 (see my last comment there). Essentially, the model's flattened form changes after the first call:

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
mkey, dkey = jr.split(key)

model = eqx.experimental.BatchNorm(input_size=4, axis_name="batch")
x = jr.normal(dkey, (10, 4))

flat_model, treedef = jax.tree_util.tree_flatten(eqx.filter(model, eqx.is_inexact_array))
jax.vmap(model, axis_name="batch")(x)
new_flat_model, new_treedef = jax.tree_util.tree_flatten(eqx.filter(model, eqx.is_inexact_array))
diffs = jax.tree_map(lambda x, y: jnp.abs(x - y), new_flat_model, flat_model)
```

The last `tree_map` line also throws:
I'm just now trying to figure out exactly how …
Just to add something I know: the state of …
So the stateful operations in Equinox (which BatchNorm uses under the hood) do some pretty magical things in order to work. In particular, they do mutate the underlying pytree. (Note that this isn't normally allowed; the stateful stuff is magical.) So I think that whilst surprising, this isn't a bug.
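To see why an in-place state mutation breaks leaf-by-leaf comparison, here is a stdlib-only toy analogy (not Equinox's actual mechanism, and `LazyStateLayer`/`flatten` are hypothetical names): a layer whose hidden state is allocated lazily on the first call flattens to a different number of leaves before and after that call, which is exactly the situation a "List arity mismatch" style error complains about.

```python
class LazyStateLayer:
    def __init__(self):
        self.weights = [1.0, 2.0]
        self.state = None              # running stats not allocated yet

    def __call__(self, x):
        if self.state is None:
            self.state = [0.0, 0.0]    # allocated on first use (a mutation!)
        return x


def flatten(layer):
    """Collect all float leaves, loosely mimicking jax.tree_util.tree_flatten."""
    leaves = list(layer.weights)
    if layer.state is not None:
        leaves.extend(layer.state)
    return leaves


layer = LazyStateLayer()
before = flatten(layer)    # 2 leaves: just the weights
layer(3.0)                 # first call allocates the hidden state
after = flatten(layer)     # 4 leaves: weights plus state

print(len(before), len(after))  # 2 4
# Zipping these leaf lists together (as tree_map would) cannot line up.
```

Once the structure has stabilised (after the first call), consecutive flattenings are comparable again, which matches the `i >= 2` workaround used later in this thread.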
In the case of BN, they don't have to though, right? We know the underlying pytree at the time of `__init__`:

```python
def __init__(self, inference: bool = False, state=None):
    self._obj = _IndexObj()
    self._version = _FixedInt(-1)
    self._state = state
    self.inference = inference
```

and initialize the ...

```python
self.state_index = StateIndex(
    inference=inference,
    state=(
        jnp.zeros((input_size,), dtype=jnp.float32),
        jnp.zeros((input_size,), dtype=jnp.float32),
    ),
)
```

... I've not added this as a PR yet, since they're tiny changes and I was hoping I could get #237 fixed at the same time, but I can submit this on its own as well.
Fair enough, yes. Two follow-up questions though. First: is there any reason why using BatchNorm layers in the way I described above might make my pipelined model give different outputs for the same inputs, even though, in theory, I'm not altering the model at all? If I remove these layers everything works as expected. Is this just a bug on my end, or would you think something's up here? Second: what would be the correct way of filtering a model with BatchNorm so as to keep trainable parameters only? `eqx.is_array` and `eqx.is_inexact_array` seem to give different results for BatchNorm.
Hi @geomlyd

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
mkey, dkey = jr.split(key)

model_pt1 = eqx.nn.Sequential([
    eqx.nn.Linear(in_features=3, out_features=4, key=mkey),
    eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
])
model_pt2 = eqx.nn.Sequential([
    eqx.nn.Linear(in_features=4, out_features=4, key=mkey),
    eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
])
model_combined = [model_pt1, model_pt2]

x = jr.normal(dkey, (10, 3))
flattened_model, _ = jax.tree_util.tree_flatten(eqx.filter(model_combined, eqx.is_inexact_array))

for i in range(10):
    prev_flattened_model = flattened_model
    flattened_model, _ = jax.tree_util.tree_flatten(eqx.filter(model_combined, eqx.is_inexact_array))
    y1 = jax.vmap(model_pt1, axis_name="batch")(x)
    y2 = jax.vmap(model_pt2, axis_name="batch")(y1)
    if i >= 2:
        diffs = jax.tree_map(lambda x, y: jnp.abs(x - y), flattened_model, prev_flattened_model)
        assert eqx.tree_equal(flattened_model, prev_flattened_model)
```
Thanks for the reply @uuirs, I get what's going on. After some experimentation, I think the behavior I'm talking about (the same model giving different outputs for the same inputs) arises from the momentum parameter of BatchNorm. This is somewhat independent of Equinox, but does it make sense to normalize each batch by its own mean and variance, and only use the running statistics during inference? I assume this would then make the model yield the same output for the same input (which doesn't happen when the ever-changing running statistics are used). Do people do that in practice? If so, is there a way to do it in Equinox?
Right, so you're probably seeing the output for a batch change due to the running statistics updating. As for your question: there are numerous different flavours of batch norm. Equinox happens to implement one version; you've described another. IMO the one Equinox implements makes the most sense, since it means the statistics between train and test are most similar. If you wanted one of the other BN variants, then I'd suggest copying Equinox's BN implementation and adjusting it appropriately.
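The two flavours being contrasted can be sketched in plain NumPy (the helper names and the momentum/eps defaults here are illustrative, not Equinox's actual API, and learnable scale/offset are omitted). Variant A normalises by the running statistics, so repeated calls on the identical batch drift as the statistics update; variant B normalises by the batch's own statistics and is deterministic for a fixed input:

```python
import numpy as np

def bn_normalise_by_running(x, state, momentum=0.99, eps=1e-5):
    # Variant A: update the running statistics, then normalise by them.
    mean, var = state
    mean = momentum * mean + (1 - momentum) * x.mean(axis=0)
    var = momentum * var + (1 - momentum) * x.var(axis=0)
    return (x - mean) / np.sqrt(var + eps), (mean, var)

def bn_normalise_by_batch(x, state, momentum=0.99, eps=1e-5):
    # Variant B: normalise by the current batch's own statistics;
    # the running statistics are only tracked for use at inference time.
    mean, var = state
    mean = momentum * mean + (1 - momentum) * x.mean(axis=0)
    var = momentum * var + (1 - momentum) * x.var(axis=0)
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps), (mean, var)

x = np.random.default_rng(0).normal(size=(10, 4))

state = (np.zeros(4), np.ones(4))
y1a, state = bn_normalise_by_running(x, state)
y2a, state = bn_normalise_by_running(x, state)
print(np.allclose(y1a, y2a))  # False: the running stats moved between calls

state = (np.zeros(4), np.ones(4))
y1b, state = bn_normalise_by_batch(x, state)
y2b, state = bn_normalise_by_batch(x, state)
print(np.allclose(y1b, y2b))  # True: output depends only on the batch itself
```

This is why removing the BatchNorm layers (or switching to a batch-statistics variant) makes repeated identical inputs produce identical outputs.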
OK, that makes sense! Sorry for the rather confusing issue; it looks like it all stemmed from the fact that I was only aware of one BatchNorm flavor. Thanks again for your input, I think the issue can be considered solved.
Hi,
This might be something that is already known, or perhaps I'm not using the library as intended. Apologies in advance if that's the case. First some background info:
I'm writing code for a scenario that features a form of pipeline parallelism. I have a model which I split into parts/modules, and each part is run on a different device. The results of each part are passed on to the next in a loop. The model features BatchNorm (I'm trying to reproduce some known results that use it, although I'm now aware that BatchNorm is finicky in JAX).
As a test case, I feed N batches of the exact same samples in the first N iterations, then do some updates on my model. I repeat this procedure with a new batch, which is fed repeatedly for the next N iterations. As a sanity check, in every N consecutive iterations, the model should output the same values. This is not the case, though, and I think BatchNorm might be the issue.
To debug, I thought I'd check whether the model's parameters change during these N iterations, by flattening it and comparing it to its previous version. However, I run into errors regarding "List arity mismatch". I have a very simplified example that exhibits this sort of behavior below. To simulate my use case, the second module/part is only run from the third iteration onward. Even for i = 1, the two model "versions" are not comparable (one was before running anything, the second after running the first module/part).
If I remove the BatchNorm layers there are no errors, which leads me to believe that the fact that it modifies its state is the problem. Am I using something wrong here? If not, how can I work around this, and what could possibly cause my model's output to be different for the same inputs?