Lack of `out=` breaks use of `filter_pmap(fn)` when output of `fn` involves non-JAX types #115
I just realized this can be solved by marking the […]
In particular, […]
All of the above is expected behaviour.

**filter_vmap**

In this case I want to be able to give strong commitments about the stability of Equinox, so any funny business like that isn't enabled by default.

-- Side note --

In case you're curious/concerned, the full list of "funny business" is: […]

-- End side note --

FWIW I've had some thoughts on how this functionality might be accomplished without this slightly-concerning-looking monkey patching, so at some point this could graduate from being experimental to being the default.

**filter_pmap**

In this case, the same monkey-patching trick doesn't work, as the internals of […]. As a result I left this as a […]. In particular, the main use case I know of for filtering the output of […]. (FWIW, my thoughts on de-experimental-ifying […].)

**non-static nonlinearities**

This sometimes surprises people! See also this recent issue, and a similar result of this discussed in the FAQ. In short, it would be wrong to assume that the nonlinearities are static, because they might be learnt, and then we'd like the parameters to be part of the PyTree. Handling examples like this is actually the whole point of filtering: we want to determine static-vs-not-static when we call JIT -- because in general we can't determine this where the PyTree is defined. And whilst activation functions are one of the most common examples of this, filtering of course extends to handle numerous other use cases as well: freezing some weights in a layer; determining which inputs to vectorise-vs-broadcast in a […].

**static_field**

So as the above block emphasises, filtering is really the appropriate way to determine what's static and what's traced. (And equivalently for the other transforms: what's vectorised and what's broadcasted, what's differentiated and what's not differentiated, ...) Indeed, marking a field as always-static actually strictly reduces expressivity, because you could always have filtered it out instead. (The correct thing to do in 99% of cases.) So what's up with `static_field`? […] But... you should almost never need this functionality.
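To make the filtering point above concrete, here is a minimal sketch (not from the thread; `Model` and all names are illustrative) of how `eqx.partition` splits a module into traced and static parts at call time:

```python
import jax
import equinox as eqx
from typing import Callable

class Model(eqx.Module):
    linear: eqx.nn.Linear
    activation: Callable  # might be jax.nn.relu -- or a learnt eqx.Module

    def __call__(self, x):
        return self.activation(self.linear(x))

model = Model(eqx.nn.Linear(3, 3, key=jax.random.PRNGKey(0)), jax.nn.relu)

# Decided at call time, not at class definition: array leaves (weight, bias)
# are traced/vectorised; everything else (use_bias, the activation function,
# ...) is treated as static.
params, static = eqx.partition(model, eqx.is_array)
model_again = eqx.combine(params, static)  # lossless inverse
```

The activation lands in `static` when it is a plain function, but its arrays would land in `params` if it were a learnt `eqx.Module` -- which is exactly why it cannot be assumed static up front.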
Thanks for your reply, I think I understand your rationale for this design. Still, it sounds at odds with my mental picture of parallel training. This is my first attempt at that in JAX, but as I mentioned in the previous issue, I believe the way to train on multiple GPUs is to shard-replicate the model and optimizer state across GPUs, shard-split the training batch, insert […]. What would be your approach to parallel training with Equinox? (The standard pattern I have in mind is sketched after the traceback below.)

Since making this issue I noticed that using an explicit `out=` mask built along these lines still fails:

```python
_0_if_array = lambda v: 0 if eqx.is_array(v) else None
mask = jax.tree_map(_0_if_array, v_model)
eqx.filter_vmap(fn, out=mask)(v_model)
```

which raises:

```
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Input In [61], in <cell line: 1>()
----> 1 eqx.filter_vmap(fn, out=mask)(v_model)
File ~/Programs/equinox/equinox/vmap_pmap.py:146, in _VmapWrapper.__call__(***failed resolving arguments***)
144 else: # `out` of type ResolvedAxisSpec
145 out_axes = _map_axes(__self._out)
--> 146 vmapd, nonvmapd = jax.vmap(
147 _fun_wrapper, in_axes=in_axes, out_axes=out_axes, **__self._vmapkwargs
148 )(__self._fun, bound.args, bound.kwargs)
149 return combine(vmapd, nonvmapd.value)
[... skipping hidden 5 frame]
File /opt/homebrew/Caskroom/miniforge/base/lib/python3.9/site-packages/jax/_src/api_util.py:310, in flatten_axes(name, treedef, axis_tree, kws, tupled_args)
306 else:
307 hint += (f" In particular, you're passing in a single argument which "
308 f"means that {name} might need to be wrapped in "
309 f"a singleton tuple.")
--> 310 raise ValueError(f"{name} specification must be a tree prefix of the "
311 f"corresponding value, got specification {axis_tree} "
312 f"for value tree {treedef}.{hint}") from None
313 axes = [None if a is proxy else a for a in axes]
314 assert len(axes) == treedef.num_leaves
ValueError: vmap out_axes specification must be a tree prefix of the corresponding value, got specification Model(
linear=Linear(weight=0, bias=0, in_features=3, out_features=3, use_bias=True)
) for value tree PyTreeDef((CustomNode(<class '__main__.Model'>[(('linear',), (), ())], [CustomNode(<class 'equinox.nn.linear.Linear'>[(('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (3, 3, True))], [*, *])]), CustomNode(<class 'equinox.compile_utils.Static'>[((), ('value',), (Model(
linear=Linear(
weight=None,
bias=None,
in_features=3,
out_features=3,
use_bias=True
)
),))], []))).
```
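For reference, a hedged sketch of the replicate-the-model / split-the-batch / average-the-gradients pattern described in the comment above, written with plain `jax.pmap` over a dict of parameters rather than an Equinox module (all names and shapes here are illustrative):

```python
from functools import partial

import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@partial(jax.pmap, axis_name="batch")
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average gradients across devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="batch")
    params = jax.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return params, loss

n = jax.local_device_count()
params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros(())}
# Shard-replicate the parameters: add a leading device axis to every leaf.
params = jax.tree_map(lambda p: jnp.broadcast_to(p, (n,) + p.shape), params)
# Shard-split the data: leading axis = number of devices.
x = jnp.ones((n, 8, 3))
y = jnp.ones((n, 8, 1))
params, loss = train_step(params, x, y)
```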
More poking around: is it intended that here you don't return an extra […]?
As for the error you're seeing, and your comment about the missing […]
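For context on that `ValueError`, here is a small Equinox-free sketch of JAX's tree-prefix rule for `out_axes`. The traceback above shows the internal output is a 2-tuple of the vmap'd PyTree and a `Static` node, so a mask shaped like the user-visible `Model` is not a prefix of it:

```python
import jax
import jax.numpy as jnp

def f(x):
    return {"a": x, "b": (x, x)}

x = jnp.arange(6).reshape(3, 2)

# A single axis is a prefix of any output tree: applied to every leaf.
jax.vmap(f, out_axes=0)(x)

# A tree spec must be a prefix of the output structure: here the `0` under
# "b" applies to both elements of the tuple.
jax.vmap(f, out_axes={"a": 0, "b": 0})(x)

# Not a prefix (wrong structure under "b") -- raises the same
# "out_axes specification must be a tree prefix" ValueError as above:
# jax.vmap(f, out_axes={"a": 0, "b": (0, 0, 0)})(x)
```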
The repro is here. In short, functions which return PyTrees involving non-JAX types break `filter_{vmap,pmap}`. With `vmap` this can be fixed by specifying the `out=` kwarg, but it doesn't work for `pmap`. Is this just a TODO or a limitation of the "networks as callable PyTrees" model that equinox takes?
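The repro link above is elided; as a hedged guess at its shape (reconstructed from `v_model` and `fn` appearing earlier in the thread, not the author's actual code):

```python
import jax
import equinox as eqx

class Model(eqx.Module):
    linear: eqx.nn.Linear

    def __call__(self, x):
        return self.linear(x)

def fn(model):
    # Output is a PyTree whose leaves include non-JAX types
    # (in_features, out_features, use_bias, ...).
    return model

keys = jax.random.split(jax.random.PRNGKey(0), 4)
# Ensemble of models: array leaves gain a leading batch axis of size 4.
v_model = eqx.filter_vmap(lambda k: Model(eqx.nn.Linear(3, 3, key=k)))(keys)
out = eqx.filter_vmap(fn)(v_model)
```

Here `fn`'s output mixes arrays with Python ints and bools, which is what a plain integer out-axis spec cannot describe; recent Equinox versions handle such outputs by default, but in the version discussed in this issue the `pmap` case had no working workaround.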