eqx.filter_{vmap,pmap}(out=...) not experimental #124

Merged

patrick-kidger merged 2 commits into main from no-experimental-filter-vmap on Jul 5, 2022


@patrick-kidger (Owner) commented Jun 30, 2022

1. Previously, using a callable `out` parameter was experimental for
   `filter_vmap` (because it monkey-patched JAX internals) and
   unavailable for `filter_pmap`. It has now been updated to work for both,
   using only the public JAX API. (Hurrah.)

2. Drive-by: Added `eqx.filter_eval_shape`, as it looked like this was
   going to be useful for implementing the previous feature. In the
   end it wasn't, but we get a new feature anyway.

3. Drive-by: Fixed a crash when filter-jit'ing a partial-wrapped
   function whilst using a PyTree as its filter spec (`fn`).
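The description doesn't include code, but the general idea behind item 1 (letting a vmapped function return non-array leaves using only the public JAX API) can be sketched as below. This is a hypothetical helper, not Equinox's implementation: `vmap_static_out` is an illustrative name, and the sketch assumes every array output is batched along axis 0 while every non-array output is returned unbatched.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def _is_array(x):
    # Tracers created inside jax.vmap are also instances of jax.Array.
    return isinstance(x, jax.Array)


def vmap_static_out(fun, in_axes=0):
    """Hypothetical helper: like jax.vmap, but outputs may contain non-array leaves.

    Array leaves of the output are batched along axis 0. Every other leaf is
    recorded while tracing and returned as-is (unbatched).
    """

    def wrapped(*args):
        static_out = {}  # filled in as a side effect while `inner` is traced

        def inner(*inner_args):
            out_leaves, out_treedef = jtu.tree_flatten(fun(*inner_args))
            static_out["treedef"] = out_treedef
            # None marks array positions; None is never itself a PyTree leaf,
            # so it cannot collide with a genuine non-array output.
            static_out["others"] = [None if _is_array(x) else x for x in out_leaves]
            # Only array leaves are passed back through jax.vmap.
            return [x if _is_array(x) else None for x in out_leaves]

        batched = jax.vmap(inner, in_axes=in_axes)(*args)
        merged = [
            b if o is None else o for b, o in zip(batched, static_out["others"])
        ]
        return jtu.tree_unflatten(static_out["treedef"], merged)

    return wrapped


def f(x):
    # jax.vmap only accepts arrays as output leaves, so plain jax.vmap(f) fails here.
    return {"value": 2 * x, "name": "doubled", "activation": jnp.tanh}


out = vmap_static_out(f)(jnp.arange(3.0))
```

Because `jax.vmap` traces `inner` only once per call, each non-array output is computed a single time and shared across the whole batch, which is why it can be returned unbatched.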

Closes #115.

CC @jatentaki WDYT?

@jatentaki (Contributor)

Looks good; I just tested it with my code and it does the job :) Interesting, I also thought `filter_eval_shape` would be a necessary building block towards the solution.

@jatentaki (Contributor)

Question: does `filter_eval_shape` really work with any Python object (as per the documentation), or only with PyTrees that potentially have non-JAX leaves?

@patrick-kidger (Owner, Author) commented Jul 1, 2022

Excellent, I'm glad that this works for you.

By definition, all types are PyTrees; potentially they're just stumps consisting of a single non-JAX object. So I don't think there's a distinction between "any Python object" and "PyTrees with potentially non-JAX leaves"?
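To illustrate the "stump" point: JAX treats any unregistered Python object as a single PyTree leaf, so a filtered `eval_shape` only needs to split leaves into arrays and everything else. The following is a minimal sketch under that assumption; `eval_shape_filtered`, `Config` and `f` are hypothetical names, and this is not Equinox's actual `filter_eval_shape`.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def _is_array(x):
    # Assumption for this sketch: only JAX arrays (and tracers) are dynamic.
    return isinstance(x, jax.Array)


def eval_shape_filtered(fun, *args):
    """Hypothetical helper, NOT Equinox's `filter_eval_shape` implementation.

    Array leaves of `args` are abstracted by `jax.eval_shape`; every other leaf
    (strings, functions, arbitrary objects) is passed to `fun` unchanged.
    Array outputs become `jax.ShapeDtypeStruct`s; other outputs are returned as-is.
    """
    leaves, treedef = jtu.tree_flatten(args)
    dynamic = [x if _is_array(x) else None for x in leaves]
    static = [None if _is_array(x) else x for x in leaves]
    static_out = {}  # filled in while `inner` is traced

    def inner(dynamic_leaves):
        merged = [d if s is None else s for d, s in zip(dynamic_leaves, static)]
        out = fun(*jtu.tree_unflatten(treedef, merged))
        out_leaves, out_treedef = jtu.tree_flatten(out)
        static_out["treedef"] = out_treedef
        static_out["others"] = [None if _is_array(x) else x for x in out_leaves]
        return [x if _is_array(x) else None for x in out_leaves]

    shapes = jax.eval_shape(inner, dynamic)
    merged_out = [
        sh if o is None else o for sh, o in zip(shapes, static_out["others"])
    ]
    return jtu.tree_unflatten(static_out["treedef"], merged_out)


class Config:
    # An unregistered class: JAX treats each instance as one PyTree leaf.
    scale = 3.0


def f(x, cfg, name):
    return {"y": x * cfg.scale, "label": name}


out = eval_shape_filtered(f, jnp.ones((2, 3)), Config(), "demo")
```

Here `out["y"]` is a `jax.ShapeDtypeStruct` of shape `(2, 3)`, while the `Config` instance and the string never pass through tracing at all.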

@jatentaki (Contributor)

I guess, but it would not recurse into a non-PyTree object even if it holds JAX arrays somewhere inside, which the wording may seem to imply. That's just a nitpick though.

@patrick-kidger (Owner, Author)

Ah, I see. So that's equivalent to passing a JAX array in via a closure, and I think that should be totally fine / won't break.
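The closure equivalence can be checked with plain JAX: an unregistered object holding an array is one opaque leaf, and inside `jax.eval_shape` its array behaves like any other closed-over constant. `Holder` and `apply` are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


class Holder:
    # Hypothetical class that is NOT registered as a PyTree node.
    def __init__(self, array):
        self.array = array


holder = Holder(jnp.ones((4, 2)))

# JAX does not recurse into unregistered objects: the whole Holder is one leaf.
leaves = jtu.tree_leaves(holder)


def apply(x):
    # holder.array is closed over, so JAX treats it as a constant, not a traced input.
    return x @ holder.array


shape = jax.eval_shape(apply, jnp.zeros((3, 4)))
```

The array inside `holder` is never abstracted to a shape; only `x` is, which matches how a closed-over array is handled.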

@jatentaki (Contributor)

In case you're still waiting for some input from me, LGTM :)

@patrick-kidger (Owner, Author)

Aha, ty! Nope, I just wasn't in any hurry to merge this; I'm waiting to see how/if #126 plays out and then merging them both together.

patrick-kidger merged commit 7924841 into main on Jul 5, 2022
patrick-kidger deleted the no-experimental-filter-vmap branch on July 5, 2022 15:29
@patrick-kidger (Owner, Author)
Owner Author

Merging this now, as I've just discovered a need for `filter_eval_shape` in the upcoming version of Diffrax.

Development

Successfully merging this pull request may close these issues:

Lack of `out=` breaks use of `filter_pmap(fn)` when output of `fn` involves non-JAX types
2 participants