
Unexpected (?) BatchNorm behavior: model's flattened form changes through iterations #238

Closed
geomlyd opened this issue Nov 25, 2022 · 9 comments
Labels: question (User queries)


geomlyd commented Nov 25, 2022

Hi,
This might be something that is already known, or perhaps I'm not using the library as intended. Apologies in advance if that's the case. First some background info:

I'm writing code for a scenario that features a form of pipeline parallelism. I have a model which I split into parts/modules, and each part is run on a different device. The results of each part are passed on to the next in a loop. The model features BatchNorm (I'm trying to reproduce some known results that use it, although I'm now aware that BatchNorm is finicky in JAX).

As a test case, I feed N batches of the exact same samples in the first N iterations, then do some updates on my model. I repeat this procedure with a new batch, which is fed repeatedly for the next N iterations. As a sanity check, in every N consecutive iterations, the model should output the same values. This is not the case, though, and I think BatchNorm might be the issue.

To debug, I thought I'd check whether the model's parameters change during these N iterations, by flattening it and comparing it to its previous version. However, I run into errors regarding "List arity mismatch". I have a very simplified example that exhibits this sort of behavior below. To simulate my use case, the second module/part is only run from the third iteration onward. Even for i = 1, the two model "versions" are not comparable (one was before running anything, the second after running the first module/part).

If I remove the BatchNorm layers there are no errors, which leads me to believe that the fact that BatchNorm modifies its state is the problem. Am I using something incorrectly here? If not, how can I work around this, and what could possibly cause my model's output to differ for the same inputs?

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
mkey, dkey = jr.split(key)
model_pt1 = eqx.nn.Sequential([
    eqx.nn.Linear(in_features=3, out_features=4, key=mkey),
    eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
])
model_pt2 = eqx.nn.Sequential([   
    eqx.nn.Linear(in_features=4, out_features=4, key=mkey),
    eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
])

model_combined = [model_pt1, model_pt2]

x = jr.normal(dkey, (10, 3))
flattened_model, _ = jax.tree_util.tree_flatten(eqx.filter(model_combined, eqx.is_inexact_array))
for i in range(10):
    prev_flattened_model = flattened_model
    flattened_model, _ = jax.tree_util.tree_flatten(eqx.filter(model_combined, eqx.is_inexact_array))
    
    diffs = jax.tree_map(lambda x, y : jnp.abs(x - y), flattened_model, prev_flattened_model)
    y1 = jax.vmap(model_pt1, axis_name="batch")(x)
    if(i >= 2):
        y2 = jax.vmap(model_pt2, axis_name="batch")(y1)
@ciupakabra

Looks like this is related to #234 (see my last comment there). Essentially, after the first call to BatchNorm the treedef of the BatchNorm module changes, leaving it with more leaves than before, so the two lists of leaves cannot be compared/subtracted. A smaller example than the one above that reproduces this is:

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
mkey, dkey = jr.split(key)
model = eqx.experimental.BatchNorm(input_size=4, axis_name="batch")

x = jr.normal(dkey, (10, 4))

flat_model, treedef = jax.tree_util.tree_flatten(eqx.filter(model, eqx.is_inexact_array))

jax.vmap(model, axis_name="batch")(x)

new_flat_model, new_treedef = jax.tree_util.tree_flatten(eqx.filter(model, eqx.is_inexact_array))

diffs = jax.tree_map(lambda x, y : jnp.abs(x - y), new_flat_model, flat_model)

which also throws:

Traceback (most recent call last):
  File "/Users/andriusovsianas/repos/test/test.py", line 18, in <module>
    diffs = jax.tree_map(lambda x, y : jnp.abs(x - y), new_flat_model, flat_model)
  File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/tree_util.py", line 206, in tree_map
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
  File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/tree_util.py", line 206, in <listcomp>
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: List arity mismatch: 2 != 4; list: [DeviceArray([1., 1., 1., 1.], dtype=float32), DeviceArray([0., 0., 0., 0.], dtype=float32)].

I'm just now trying to figure out exactly how BatchNorm is implemented and what the fix would be, but the maintainers will probably find it faster.

uuirs (Contributor) commented Nov 25, 2022

Just to add what I know: the state of BatchNorm is only fully initialized after the first data feed, since it needs the dtype and structure of the input (which can be any pytree). So this looks expected. After the first call, the treedef should always be the same.
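A quick sketch of that claim, reusing the smaller reproduction above (untested, just to illustrate that only the first call changes the structure):

import equinox as eqx
import jax
import jax.random as jr

key = jr.PRNGKey(0)
mkey, dkey = jr.split(key)
model = eqx.experimental.BatchNorm(input_size=4, axis_name="batch")
x = jr.normal(dkey, (10, 4))

# First call: fills in BatchNorm's state, so the treedef changes once here.
jax.vmap(model, axis_name="batch")(x)
_, treedef_after_first = jax.tree_util.tree_flatten(eqx.filter(model, eqx.is_inexact_array))

# Second call: updates the running statistics in place, but the structure stays fixed.
jax.vmap(model, axis_name="batch")(x)
_, treedef_after_second = jax.tree_util.tree_flatten(eqx.filter(model, eqx.is_inexact_array))

assert treedef_after_first == treedef_after_second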

@patrick-kidger (Owner)

So the stateful operations of Equinox (which BN uses under the hood) do some pretty magical things in order to work. In particular, they do mutate the underlying pytree. (Note that this isn't normally allowed - the stateful stuff is magical.) So I think that whilst expected, this isn't a bug.

patrick-kidger added the question (User queries) label Nov 26, 2022
@ciupakabra

In particular, they do mutate the underlying pytree

In the case of BN, they don't have to, though, right? We know the underlying pytree at the time of __init__. The following change fixes this, as well as the compile-twice issue in #234.

In stateful.py, let the constructor initialize the state for the StateIndex object:

    def __init__(self, inference: bool = False, state=None):
        self._obj = _IndexObj()
        self._version = _FixedInt(-1)
        self._state = state  # new: accept an initial state up front
        self.inference = inference

and initialize the state_index in batch_norm.py correctly:

...
        self.state_index = StateIndex(
            inference=inference,
            state=(
                jnp.zeros((input_size,), dtype=jnp.float32),
                jnp.zeros((input_size,), dtype=jnp.float32),
            ),
        )
...

I've not added this as a PR yet since these are tiny changes and I was hoping to get #237 fixed at the same time, but I can submit this on its own as well.


geomlyd commented Nov 28, 2022

So the stateful operations of Equinox (which BN uses under the hood) do some pretty magical things in order to work. In particular, they do mutate the underlying pytree. (Note that this isn't normally allowed - the stateful stuff is magical.) So I think that whilst expected, this isn't a bug.

Fair enough, yes. Two follow-up questions though: is there any reason why using BatchNorm layers in the way I described above might make my pipelined model give different outputs for the same inputs, even though, in theory, I'm not altering the model at all? If I remove these layers everything works as expected. Is this just a bug on my end, or would you say something's up here?

Secondly, what would be the correct way of filtering a model with BatchNorm so as to keep only the trainable parameters? is_array and is_inexact_array seem to give different results for BatchNorm.

uuirs (Contributor) commented Nov 28, 2022

Hi @geomlyd

  1. Just try the code below; the reason is that the first call to model_pt2 triggers another modification.
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
mkey, dkey = jr.split(key)
model_pt1 = eqx.nn.Sequential([
    eqx.nn.Linear(in_features=3, out_features=4, key=mkey),
    eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
])
model_pt2 = eqx.nn.Sequential([   
    eqx.nn.Linear(in_features=4, out_features=4, key=mkey),
    eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
])

model_combined = [model_pt1, model_pt2]

x = jr.normal(dkey, (10, 3))
flattened_model, _ = jax.tree_util.tree_flatten(eqx.filter(model_combined, eqx.is_inexact_array))
for i in range(10):
    prev_flattened_model = flattened_model
    flattened_model, _ = jax.tree_util.tree_flatten(eqx.filter(model_combined, eqx.is_inexact_array))
    
    y1 = jax.vmap(model_pt1, axis_name="batch")(x)
    y2 = jax.vmap(model_pt2, axis_name="batch")(y1)
    if(i >= 2):
        diffs = jax.tree_map(lambda x, y : jnp.abs(x - y), flattened_model, prev_flattened_model)
        assert eqx.tree_equal(flattened_model, prev_flattened_model)
  2. There's no need to filter out BatchNorm's running statistics: it already blocks gradients through them via jax.lax.stop_gradient. Usually just is_array is enough; see the sketch below for how that fits into a training step.
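A minimal sketch of how that might look in a training step (the quadratic loss and plain SGD update are just placeholders, nothing from this thread):

import equinox as eqx
import jax
import jax.numpy as jnp

def loss_fn(model, x):
    # Placeholder loss, just so there is something to differentiate through.
    y = jax.vmap(model, axis_name="batch")(x)
    return jnp.mean(y ** 2)

@eqx.filter_value_and_grad
def loss_and_grads(model, x):
    return loss_fn(model, x)

def sgd_step(model, x, lr=1e-3):
    loss, grads = loss_and_grads(model, x)
    # BatchNorm's running statistics come back with zero gradient (thanks to the
    # internal stop_gradient), so a plain update over the array leaves is fine.
    model = eqx.apply_updates(model, jax.tree_map(lambda g: -lr * g, grads))
    return model, loss

# If an explicit parameter/static split is needed (e.g. to initialise an optimiser),
# filtering on eqx.is_array is usually enough:
# params, static = eqx.partition(model, eqx.is_array)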


geomlyd commented Nov 29, 2022

Thanks for the reply @uuirs, I get what's going on. After some experimentation, I think the behavior I'm talking about (the same model giving different outputs for the same inputs) arises from the momentum parameter of BatchNorm.

This is somewhat independent of Equinox, but does it make sense to normalize each batch by its own mean and variance, and only use the running statistics during inference? I assume this would then make the model yield the same output for the same input (which doesn't happen when the changing running statistics are used). Do people do that in practice? If so, is there a way to do it in Equinox?

@patrick-kidger (Owner)

Right, so you're probably seeing the output for a given batch change because the running statistics are updating.
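Concretely, the running statistics follow an exponential moving average, roughly along these lines (a sketch, not Equinox's exact code):

import jax.numpy as jnp

def ema(running, batch_stat, momentum=0.99):
    # Each call nudges the running statistic towards the current batch statistic.
    return momentum * running + (1 - momentum) * batch_stat

running_mean = jnp.zeros(4)
batch_mean = jnp.ones(4)
step1 = ema(running_mean, batch_mean)   # 0.01
step2 = ema(step1, batch_mean)          # 0.0199
# Because normalisation uses these moving values, feeding the identical batch twice
# produces two (slightly) different outputs.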

As for your question -- there are numerous different flavours of batch norm. Equinox happens to implement one version; you've described another.

IMO the one Equinox implements makes the most sense, since it means the statistics between train and test are most similar. If you wanted one of the other BN variants then I'd suggest copying Equinox's BN implementation and adjusting it appropriately.
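For reference, a rough sketch of the variant you describe (normalising with the current batch's statistics at train time, gathered across the vmapped axis). This is not Equinox's BatchNorm, and it omits the running statistics you'd still want to track for inference:

import equinox as eqx
import jax
import jax.numpy as jnp

class BatchStatNorm(eqx.Module):
    # Hypothetical layer, not part of Equinox.
    weight: jnp.ndarray
    bias: jnp.ndarray
    axis_name: str = eqx.static_field()
    eps: float = eqx.static_field()

    def __init__(self, input_size, axis_name, eps=1e-5):
        self.weight = jnp.ones((input_size,))
        self.bias = jnp.zeros((input_size,))
        self.axis_name = axis_name
        self.eps = eps

    def __call__(self, x):
        # x is a single sample; the collectives average over the named batch axis
        # provided via jax.vmap(..., axis_name=...).
        mean = jax.lax.pmean(x, self.axis_name)
        var = jax.lax.pmean((x - mean) ** 2, self.axis_name)
        return self.weight * (x - mean) / jnp.sqrt(var + self.eps) + self.bias

# Usage, matching the rest of this thread:
# y = jax.vmap(BatchStatNorm(4, "batch"), axis_name="batch")(x)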


geomlyd commented Nov 29, 2022

OK, that makes sense! Sorry for the rather confusing issue; it looks like it all stemmed from the fact that I was only aware of one BatchNorm flavor. Thanks again for your input, I think the issue can be considered solved.

geomlyd closed this as completed Nov 29, 2022
geomlyd reopened this Nov 29, 2022
geomlyd closed this as completed Nov 29, 2022