eqx.filter_{vmap,pmap}(out=...) not experimental #124
1. Previously, using a callable `out` parameter was experimental for `filter_vmap` -- because it monkey-patched JAX internals -- and unavailable for `filter_pmap`. It has now been updated to work for both, using only the public JAX API. (Hurrah.)
2. Drive-by: Added `eqx.filter_eval_shape`, as it looked like this was going to be useful as part of implementing the previous feature. In the end this wasn't the case, but we get a new feature anyway.
3. Drive-by: Fixed a crash bug when filter-jit'ing a partial-wrapped function whilst using a PyTree as its filter spec (`fn`).
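To make item 1 concrete, here is a minimal sketch of what a callable output spec means, written against the public JAX API only. It is not Equinox's implementation, and the helper name `vmap_with_out_fn` is hypothetical. It assumes that every output leaf is a JAX array, that only the first argument is batched (along axis 0), and that the callable receives each output leaf's `jax.ShapeDtypeStruct` and returns an output axis or `None`. The callable is resolved by running `jax.eval_shape` on one unbatched example, and the resulting pytree is then passed to `jax.vmap(..., out_axes=...)`.

```python
import jax
import jax.numpy as jnp


def vmap_with_out_fn(fun, out):
    """Hypothetical helper: vmap `fun` over its first argument (axis 0),
    choosing each output leaf's axis by calling `out` on that leaf's shape."""

    def wrapped(xs, *rest):
        # Describe one unbatched example (drop the leading batch axis)
        # without doing any real computation.
        example = jax.tree_util.tree_map(
            lambda x: jax.ShapeDtypeStruct(x.shape[1:], x.dtype), xs
        )
        out_shapes = jax.eval_shape(lambda e: fun(e, *rest), example)
        # Turn the callable into an ordinary pytree of out_axes (int or None).
        out_axes = jax.tree_util.tree_map(out, out_shapes)
        # Batch the first argument; broadcast all remaining arguments.
        in_axes = (0,) + (None,) * len(rest)
        return jax.vmap(fun, in_axes=in_axes, out_axes=out_axes)(xs, *rest)

    return wrapped


def f(x, w):
    return {"y": x @ w, "norm": jnp.sum(w ** 2)}


# Scalars are left unbatched; everything else is stacked along axis 0.
batched = vmap_with_out_fn(f, lambda s: None if len(s.shape) == 0 else 0)
res = batched(jnp.ones((5, 3)), jnp.ones((3, 2)))
print(res["y"].shape, res["norm"].shape)
```

Because `out_axes=None` tells `jax.vmap` that an output is unbatched, JAX raises an error if the callable returns `None` for a leaf that actually depends on the batched input, so that check comes for free. Non-array outputs, which Equinox's filtered transforms are designed to allow, are not handled by this sketch.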
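The `eqx.filter_eval_shape` from item 2 follows the usual filtering idea: JAX arrays are traced, everything else is treated as static. The sketch below shows one way such a function could be assembled from `jax.eval_shape` and `jax.tree_util`. It is an illustration under the assumption that "dynamic" means `isinstance(leaf, jax.Array)`, not Equinox's actual code, and the name `filtered_eval_shape` is hypothetical. Non-array outputs are passed out of the trace through a closure and returned unchanged.

```python
import jax
import jax.numpy as jnp


def filtered_eval_shape(fun, *args):
    """Hypothetical sketch: like jax.eval_shape, but non-array leaves of the
    inputs and outputs are treated as static Python values."""
    leaves, treedef = jax.tree_util.tree_flatten(args)
    in_mask = [isinstance(x, jax.Array) for x in leaves]
    dynamic_in = [x for x, m in zip(leaves, in_mask) if m]
    static_out = {}  # filled in while `_dynamic_fun` is being traced

    def _dynamic_fun(dynamic):
        # Re-insert the traced arrays among the static (non-array) leaves.
        it = iter(dynamic)
        full = [next(it) if m else x for x, m in zip(leaves, in_mask)]
        out = fun(*jax.tree_util.tree_unflatten(treedef, full))
        out_leaves, out_def = jax.tree_util.tree_flatten(out)
        out_mask = [isinstance(x, jax.Array) for x in out_leaves]
        # Record only static outputs; tracers must not escape the trace.
        static_out["treedef"] = out_def
        static_out["mask"] = out_mask
        static_out["leaves"] = [None if m else x for x, m in zip(out_leaves, out_mask)]
        return [x for x, m in zip(out_leaves, out_mask) if m]

    shapes = iter(jax.eval_shape(_dynamic_fun, dynamic_in))
    out_leaves = [
        next(shapes) if m else x
        for x, m in zip(static_out["leaves"], static_out["mask"])
    ]
    return jax.tree_util.tree_unflatten(static_out["treedef"], out_leaves)


def g(x, n, name):
    # `n` is used for list repetition, so it must stay a Python int.
    return {"y": jnp.concatenate([x] * n), "label": name}


out = filtered_eval_shape(g, jnp.ones((2, 3)), 2, "hello")
print(out["label"], out["y"].shape)
```

Here `n` and `name` never enter the trace, which is why `[x] * n` works; plain `jax.eval_shape` would instead reject the string output.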
Looks good, I just tested with my code and it does the job :) Interesting, I also thought
Question: does
Excellent, I'm glad that this works for you. By definition all types are PyTrees -- potentially they're just stumps of a single non-JAX object. So I don't think there's a distinction between "any Python object" and "pytrees with potentially non-JAX leaves"?
I guess, but it would not recurse into a non-pytree even if it holds JAX arrays somewhere inside, which the wording might be read to imply. That's just a nitpick though.
Ah I see. So that's the same as passing a JAX array in via closure. And I think that should be totally fine / won't break.
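The point in this exchange can be checked directly with `jax.tree_util`: an object whose type is not registered as a PyTree is a single leaf, and JAX does not look inside it even if it holds an array. `Holder` below is a hypothetical class used only for illustration.

```python
import jax
import jax.numpy as jnp


class Holder:
    """Hypothetical container that is NOT registered as a JAX PyTree."""

    def __init__(self, arr):
        self.arr = arr


h = Holder(jnp.ones(3))
leaves = jax.tree_util.tree_leaves({"a": jnp.zeros(2), "h": h})

# The unregistered object is one opaque leaf; its inner array is not visited.
print(len(leaves))                       # 2
print(leaves[1] is h)                    # True
print(isinstance(leaves[1], jax.Array))  # False
```

Under an "is it an array?" filter, `h` therefore lands on the static side, so `h.arr` behaves like an array captured by closure, as the last reply says.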
In case you're still waiting for some input from me, LGTM :)
Aha, ty! Nope, I just wasn't in any hurry to merge this -- I'm waiting to see how/if #126 plays out and then merging them both together.
Merging this now as I've just discovered a need for
Closes #115.
CC @jatentaki WDYT?