
Lack of out= breaks use of filter_pmap(fn) when output of fn involves non-JAX types #115

Closed
jatentaki opened this issue Jun 15, 2022 · 6 comments · Fixed by #124
Labels
feature New feature next Higher-priority items

Comments

@jatentaki
Contributor

jatentaki commented Jun 15, 2022

The repro is here. In short, functions which return PyTrees involving non-JAX types break filter_{vmap,pmap}. With filter_vmap this can be fixed by specifying the out= kwarg, but there is no equivalent for filter_pmap. Is this just a TODO, or a limitation of the "networks as callable PyTrees" model that equinox takes?
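For concreteness, a minimal sketch of the kind of setup I mean (my own reconstruction here, not the linked repro; Model, fn and v_model are illustrative names):

from typing import Callable

import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx

class Model(eqx.Module):
    linear: eqx.nn.Linear
    activation: Callable  # a non-JAX leaf of the PyTree

def fn(model):
    return model  # the output PyTree contains the non-JAX activation leaf

keys = jr.split(jr.PRNGKey(0), jax.local_device_count())
models = [Model(eqx.nn.Linear(3, 3, key=k), jnp.tanh) for k in keys]
# Stack the array leaves along a leading axis; keep non-array leaves as-is.
v_model = jax.tree_map(
    lambda *leaves: jnp.stack(leaves) if eqx.is_array(leaves[0]) else leaves[0],
    *models,
)

# filter_vmap can be made to work by specifying out= for the non-array leaves;
# filter_pmap has no equivalent, so this breaks:
eqx.filter_pmap(fn)(v_model)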

@jatentaki
Contributor Author

jatentaki commented Jun 15, 2022

I just realized this can be solved by marking the non-JAX fields with equinox.static_field(); however, this is a technique I only learned from reading the source, and I don't think it's discussed in the docs?
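For reference, I mean something like this (just a sketch):

from typing import Callable

import equinox as eqx

class Model(eqx.Module):
    linear: eqx.nn.Linear
    activation: Callable = eqx.static_field()  # excluded from flattening, so vmap/pmap never sees it

    def __call__(self, x):
        return self.activation(self.linear(x))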

@jatentaki
Contributor Author

In particular, pmap is broken for eqx.nn.MLP due to nonlinearities not having been declared as static.

@patrick-kidger
Owner

All of the above is expected behaviour.

filter_vmap

In this case out=0 is the default. Using a callable for out is experimental behaviour, because it relies on monkey-patching JAX internals.

I want to be able to give strong commitments about the stability of Equinox, so any funny business like that isn't enabled by default.

-- Side note --

In case you're curious/concerned, the full list of "funny business" is:

  • callable out for filter_vmap;
  • eqx.experimental.{StateIndex,get_state,set_state,BatchNorm,SpectralNorm};
  • eqx.tree_pformat, which is also used in eqx.Module.__repr__. This is really just about pretty-printing strings, and if that part of JAX ever gets changed it'd be easy enough to just copy over the one-file pretty-printer they built, so it gets an exemption from being marked experimental.

-- End side note --

FWIW I've had some thoughts on how this functionality might be accomplished without this slightly-concerning-looking monkey patching, so at some point this could graduate from being experimental to being the default.

filter_pmap

In this case, the same monkey-patching trick doesn't work, as the internals of jax.pmap are very different to the internals of jax.vmap.

As a result I left this as a NotImplementedError. If you have a use-case for filtering the output of a pmap, I'd be curious to know what it is. To my knowledge most use-cases for pmap are things like parallelising neural network training (and so on), for which the output is always just a JAX array.

In particular the main use case I know of for filtering the output of vmap is creating model ensembles, which isn't something that would ever be necessary for a pmap.
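(For concreteness, ensembling via filter_vmap looks something like the sketch below. It relies on the experimental callable out= discussed above, and the exact form of that callable is something to double-check against the docs; the names and hyperparameters are just illustrative.)

import jax.random as jr
import equinox as eqx

def make_mlp(key):
    return eqx.nn.MLP(in_size=2, out_size=2, width_size=8, depth=2, key=key)

# Stack array leaves along a new leading axis; leave non-array leaves
# (e.g. the activation function) unbatched.
make_ensemble = eqx.filter_vmap(
    make_mlp, out=lambda leaf: 0 if eqx.is_array(leaf) else None
)
ensemble = make_ensemble(jr.split(jr.PRNGKey(0), 4))  # parameters of four MLPs, stacked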

(FWIW my thoughts on de-experimental-ifying filter_vmap should apply here too.)

non-static nonlinearities

This sometimes surprises people! See also this recent issue, and a similar consequence of this discussed in the FAQ. In short, it would be wrong to assume that the nonlinearities are static: they might be learnt, in which case we'd like their parameters to be part of the PyTree.

Handling examples like this is actually the whole point of filtering: we want to determine static-vs-not-static when we call JIT -- because in general we can't determine this where the PyTree is defined.

And whilst activation functions are one of the most common examples of this, filtering of course extends to handle numerous other use cases as well: freezing some weights in a layer; determining which inputs to vectorise-vs-broadcast in a vmap; and so on.
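(To make that concrete: the following sketch shows the usual pattern, in which the filtered transforms decide at call time that the array leaves are traced/differentiated and everything else -- including the activation callables inside the MLP -- is static. Names and hyperparameters here are just illustrative.)

import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx

model = eqx.nn.MLP(in_size=2, out_size=1, width_size=8, depth=2, key=jr.PRNGKey(0))

@eqx.filter_jit             # array leaves are traced; the activation callables are static
@eqx.filter_value_and_grad  # differentiate with respect to the floating-point array leaves only
def loss(model, x, y):
    pred = jax.vmap(model)(x)
    return jnp.mean((pred - y) ** 2)

x, y = jnp.ones((16, 2)), jnp.zeros((16, 1))
value, grads = loss(model, x, y)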

static_field

So as the above block emphasises, filtering is really the appropriate way to determine what's static and what's traced. (And equivalently for the other transforms: what's vectorised and what's broadcasted, what's differentiated and what's not differentiated, ...)

Indeed, marking a field as always-static actually strictly reduces expressivity, because you could always have filtered it out instead. (The correct thing to do in 99% of cases.)

So what's up with static_field? Well, mostly it's just an internal tool: it's part of how filtering is implemented, and as a user you should essentially never need it. Nonetheless this kind of discussion comes up every now and again, and a few advanced users would like to be able to mark fields as static themselves. If nothing else I think it's bad for a library/language/framework to possess privileged operations (in this case filtering) that couldn't be replicated in user-space.

But... you should almost never need this functionality.

@jatentaki
Contributor Author

Thanks for your reply, I think I understand your rationale for this design. Still, it sounds at odds with my mental picture of parallel training. This is my first attempt at that in JAX, but as I mentioned in the previous issue, I believe the way to train on multiple GPUs is to shard-replicate the model and optimizer state across GPUs, shard-split the training batch, insert grad = jax.lax.pmean(grad, axis_name='device') inside the update function and then pmap it. This means the example I linked above only differs from ensembling models by the pmean operation inside, which is obviously orthogonal to the issue we're discussing here. Therefore my use-case is indeed similar to ensembling models, and since my update returns an entire neural network (with all the non-array leaves it may have) I need to filter that.
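Concretely, the kind of update step I have in mind looks roughly like this (just a sketch, with a plain SGD step standing in for the real optimizer and 'device' as the pmap axis name):

import jax
import jax.numpy as jnp
import equinox as eqx

def loss_fn(model, x, y):
    pred = jax.vmap(model)(x)
    return jnp.mean((pred - y) ** 2)

def update(model, x, y, lr=1e-3):
    grads = eqx.filter_grad(loss_fn)(model, x, y)
    grads = jax.lax.pmean(grads, axis_name='device')  # average gradients across devices
    updates = jax.tree_map(lambda g: -lr * g, grads)  # plain SGD step
    model = eqx.apply_updates(model, updates)
    return model  # a full Module: array leaves plus non-array leaves such as activations

# p_update = eqx.filter_pmap(update, axis_name='device')  # the call that needs output filtering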

What would be your approach to parallel training with pmap? To pmap-reduce the gradient computation, perform the model update on a single GPU, and then broadcast the updated model? I'm not very confident with equinox yet, but I have the impression that I would also run into that kind of problem there.

Since opening this issue I noticed that using static_field indeed breaks some things downstream, so I experimented with pre-computing a non-callable out= value. This should be possible since my output signature is known ahead of time, but for some reason JAX complains about a structure mismatch...

# Map array leaves along axis 0; leave everything else unmapped.
_0_if_array = lambda v: 0 if eqx.is_array(v) else None
mask = jax.tree_map(_0_if_array, v_model)
eqx.filter_vmap(fn, out=mask)(v_model)

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [61], in <cell line: 1>()
----> 1 eqx.filter_vmap(fn, out=mask)(v_model)

File ~/Programs/equinox/equinox/vmap_pmap.py:146, in _VmapWrapper.__call__(***failed resolving arguments***)
    144 else:  # `out` of type ResolvedAxisSpec
    145     out_axes = _map_axes(__self._out)
--> 146 vmapd, nonvmapd = jax.vmap(
    147     _fun_wrapper, in_axes=in_axes, out_axes=out_axes, **__self._vmapkwargs
    148 )(__self._fun, bound.args, bound.kwargs)
    149 return combine(vmapd, nonvmapd.value)

    [... skipping hidden 5 frame]

File /opt/homebrew/Caskroom/miniforge/base/lib/python3.9/site-packages/jax/_src/api_util.py:310, in flatten_axes(name, treedef, axis_tree, kws, tupled_args)
    306       else:
    307         hint += (f" In particular, you're passing in a single argument which "
    308                  f"means that {name} might need to be wrapped in "
    309                  f"a singleton tuple.")
--> 310   raise ValueError(f"{name} specification must be a tree prefix of the "
    311                    f"corresponding value, got specification {axis_tree} "
    312                    f"for value tree {treedef}.{hint}") from None
    313 axes = [None if a is proxy else a for a in axes]
    314 assert len(axes) == treedef.num_leaves

ValueError: vmap out_axes specification must be a tree prefix of the corresponding value, got specification Model(
  linear=Linear(weight=0, bias=0, in_features=3, out_features=3, use_bias=True)
) for value tree PyTreeDef((CustomNode(<class '__main__.Model'>[(('linear',), (), ())], [CustomNode(<class 'equinox.nn.linear.Linear'>[(('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (3, 3, True))], [*, *])]), CustomNode(<class 'equinox.compile_utils.Static'>[((), ('value',), (Model(
  linear=Linear(
    weight=None,
    bias=None,
    in_features=3,
    out_features=3,
    use_bias=True
  )
),))], []))).

@jatentaki
Contributor Author

More poking around: is it intended that here you don't return an extra None, to broadcast the Static(nonvmapd) part of _fun_wrapper?

@patrick-kidger
Owner

filter_pmap: ah, I see what you're getting at. Good point - that's also a reasonable way to do parallel training, so ideally filter_pmap should support this. I'll definitely have a look at switching things up to make these use cases work.

As for the error you're seeing, and your comment about the missing None: I think you're correct; there's a missing None, and that's what's raising the error you're seeing.
