Support Buffer Donation in filter_jit #235
Adds donate_default, donate_args, donate_kwargs, and donate_fn in filter_jit; adds tests.
This is a great feature to have. This has prompted me to think about making some larger-scale changes to Equinox. So besides simply accepting this PR, I'd be curious to know what you think about each of the following other options.
For the sake of API elegance I quite like the approach of donating all arrays and static'ing all non-arrays, but that might be a mildly breaking change for some small number of use cases. Other musings:
Since you've taken the time to implement donations for …
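For comparison, here is a rough sketch of the "donate all arrays, make all non-arrays static" behaviour written directly against `jax.jit`. The helper name `donate_arrays_jit` is hypothetical, and this is not how Equinox implements `filter_jit`; it only illustrates the partitioning idea. Non-array leaves must be hashable, because they become part of a static argument.

```python
import functools

import jax
import jax.numpy as jnp


def donate_arrays_jit(fn):
    """Hypothetical helper: JIT `fn`, tracing and donating every jax.Array leaf
    of its arguments and treating every other leaf as static (must be hashable)."""

    @functools.partial(jax.jit, static_argnums=1, donate_argnums=0)
    def _call(arrays, static):
        treedef, mask, others = static
        arrays_iter = iter(arrays)
        # Re-interleave the traced arrays and the static leaves in their original order.
        leaves = [next(arrays_iter) if is_array else other
                  for is_array, other in zip(mask, others)]
        args, kwargs = jax.tree_util.tree_unflatten(treedef, leaves)
        return fn(*args, **kwargs)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        leaves, treedef = jax.tree_util.tree_flatten((args, kwargs))
        mask = tuple(isinstance(leaf, jax.Array) for leaf in leaves)
        # Arrays go into the donated argument; everything else becomes static.
        arrays = [leaf for leaf, is_array in zip(leaves, mask) if is_array]
        others = tuple(None if is_array else leaf for leaf, is_array in zip(leaves, mask))
        return _call(arrays, (treedef, mask, others))

    return wrapper


@donate_arrays_jit
def scale(x, factor, *, shift=0):
    return x * factor + shift


x = jnp.ones(3)
y = scale(x, 2, shift=1)  # x's buffer is donated: do not use x after this call
```

The design choice is the one discussed above: the user never says what to donate, and changing a non-array argument such as `factor` triggers a recompile rather than being traced.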
Thanks for your input! Sorry for the late reply.

1.
In this implementation, all donated pytrees are partitioned out of the dynamic pytree. As mentioned, donation only applies to parameters that are traced, so a static argument will never be donated. A test:

```python
import jax
import jax.numpy as jnp
import equinox as eqx

@eqx.filter_jit(donate_args=(False, True))  # True on a static arg: y is a Python int
def f(x, y):
    return x + 1

x = jnp.zeros((100,))
y = 10
closed_jaxpr = eqx.filter_make_jaxpr(f)(x, y)[0]
xla_call = closed_jaxpr.eqns[0]
# Only x is traced, so there is a single invar, and it is not donated.
assert xla_call.params['donated_invars'] == (False,)
```
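The same property can be checked in plain JAX, without this PR's draft `donate_args` API. A minimal sketch, assuming a recent JAX in which `jax.make_jaxpr` of a jitted function yields a single jit equation carrying a `donated_invars` parameter:

```python
import jax
import jax.numpy as jnp


def f(x, y):
    return x + 1


# x (position 0) is traced and donated; y (position 1) is static, so it is not a jaxpr input at all.
jitted = jax.jit(f, static_argnums=1, donate_argnums=0)
# static_argnums=1 here too, so make_jaxpr passes y through as a concrete Python int.
closed_jaxpr = jax.make_jaxpr(jitted, static_argnums=1)(jnp.zeros((100,)), 10)
jit_eqn = closed_jaxpr.eqns[0]
print(jit_eqn.params["donated_invars"])
```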
I did consider adding an enum (static, dynamic, donate), and adding `chain_partition` or something similar to return several pytrees. But it also seems to work to just give the donated partition a lower priority than the original partition, and to base it on the dynamic pytree. I think usually we just need to donate the model, without any additional specs (args, kwargs, etc.). I personally prefer … From this point of view, I would like to keep these.

2.
I totally agree, provided the default behavior can easily be changed via …

Some small issues:
3.
I think it really makes things easier. Usually I use …

4.
I prefer option (a). As with the type-wrapper case above, I'd like to decide this when defining the jitted function, not when calling it. IMO, if big changes are allowed, I'd like to remove all the filter arguments, keep `donate_*` in jit and pmap, and add a type wrapper so that users can mark an argument as dynamic.
These are some really interesting thoughts.
I'm not sure it will. I can believe that passing array in-and-out will prevent them from getting DCE'd, but (a) isn't that already happening independent of this change, and (b) wouldn't this change actually improve things here, because we would be donating and thus reusing the input buffer?
Good point here. We could possibly suppress these warnings.
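One way to do that is a scoped filter on Python's `warnings` module. A minimal sketch, assuming JAX's warning message starts with "Some donated buffers were not usable"; `call_quietly` is a hypothetical helper, not part of JAX or Equinox:

```python
import warnings

import jax
import jax.numpy as jnp

DONATION_WARNING = "Some donated buffers were not usable"


def call_quietly(fn, *args):
    """Hypothetical helper: call fn, hiding only JAX's unusable-donation warning."""
    with warnings.catch_warnings():
        # `message` is a regex matched against the start of the warning text.
        warnings.filterwarnings("ignore", message=DONATION_WARNING)
        return fn(*args)


# The output shape () differs from the donated (3,) input, so XLA cannot reuse that buffer.
total = jax.jit(lambda x: x.sum(), donate_argnums=0)
result = call_quietly(total, jnp.ones(3))
```

Scoping the filter with `catch_warnings` keeps every other warning visible and leaves the global filter list untouched after the call.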
That's an interesting use case. How frequently does that happen? It would be easy enough to make this happen with a wrapper function instead, though:

```python
import jax
import equinox as eqx

@eqx.filter_jit
def foo_impl(x, y):
    ...

def foo(x, y):
    # Convert Python scalars etc. into arrays, so they are traced rather than treated as static.
    x, y = jax.tree_util.tree_map(jax.numpy.asarray, (x, y))
    return foo_impl(x, y)
```
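Again, I think this could be handled with a wrapper function:

```python
import jax.numpy as jnp
import equinox as eqx

@eqx.filter_jit  # under this proposal, donates args by default
def foo_impl(x, y):
    ...

def foo(x, y):
    # Donate a copy, so that the caller's `x` buffer is not invalidated.
    x = jnp.copy(x)
    return foo_impl(x, y)
```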
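The copy trick can be checked with plain `jax.jit(..., donate_argnums=0)` standing in for the proposed default-donating `filter_jit` (an assumption made so the sketch runs without this PR's API); `step_impl` and `step` are hypothetical names:

```python
import jax
import jax.numpy as jnp

# Stand-in for a jitted function that donates its first argument.
step_impl = jax.jit(lambda x, y: x * y, donate_argnums=0)


def step(x, y):
    # Only the fresh copy is donated; the caller's `x` keeps its own buffer.
    return step_impl(jnp.copy(x), y)


x = jnp.arange(3.0)
out = step(x, 2.0)
```

The cost is one extra device copy per call, which is why this is better as an opt-in wrapper than as the default.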
I think this is what …

All in all, I'm thinking of perhaps introducing a simple API:
And for cases which can be handled with simple wrapper functions, we could say that that is the recommended way to approach them. WDYT? Am I missing a use case that the above simplified API wouldn't be able to support?
Test fixes. Refactored; it seems to be OK.
Hi @patrick-kidger, how does it look now? I've checked the examples on Colab and they seem to be OK.
(I'm at NeurIPS at the moment; I'll give a review next week :) )
Thanks, have a nice day!
This looks excellent. I've gone through; my comments are just small nits.
Other thoughts:
- Is there anywhere in the documentation that needs updating to reflect this change? I don't think so, but I thought I'd double-check with you.
- We should think about similarly updating the other filtered transformations. Don't feel obliged to do this if you don't want to; if not then I'll get that done at some point and then do a new release. (In separate PRs, either way.)
Thank you for making this change happen!
Thank you for double-checking with me. > <
I would like to do this, but I'm afraid I won't be able to cover every aspect of these changes in this PR.
Alright, merged! Thank you for the PR; this is a great simplification to have. It's on the …

If you are happy to handle the other filtered transformations (in other PRs), then I'd be very happy to accept them. I'm thinking that we could:
Thank you for your patience and support! I'll try to submit them next week.
* filter_custom_jvp now supports keyword arguments
* added link to eqxvision
* Support Buffer Donation in filter_jit (#235)
* Adds donate_default, donate_args, donate_kwargs, and donate_fn in filter_jit; Adds Tests.
* filter_jit with buffer donation V2; Test fixes;
* simplify filter_jit; example fixes;
* Fix review comments
* tweak docs (#242)
* add str2jax test
* Added ONNX export; finalise_jaxpr, nontraceable (#243)
* Fixed up the finalise transformation to handle higher-order primitives (#244)
* feat: Add PReLU activation (#249)
* Update activations.py
* Started to add a checkpointed while loop
* while fwd
* at_set
* no more at
* Added checkpointed while loop.
* Fix Black
* Many doc improvements. Exposed a few new functions. New functions:
  - eqx.module_update_wrapper
  - eqx.tree_pprint
  - eqx.internal.tree_pp
* Added combine(..., is_leaf=...)
* Silenced most test warnings
* Minor wart fixed in Module
* Onboarding docs made more concise. (Less repetition of the word 'PyTree'!)
* Added docstrings
* filter_vmap(out=...) crash fix
* Wrapper modules now preserve `__name__`, `__module__` etc. when flattened and unflattened
* Now guard against non-ShapedArray abstract values
* Tweaks
* Simplified filtered transformation APIs. This is to (a) make these operations more usable, (b) simplify new-user onboarding, and (c) improve maintainability / bus-factor. The affected operations are `filter_{grad,value_and_grad,vmap,pmap}`. Also:
  - Added `equinox.if_array` and `equinox.internal.if_mapped`.
  - `eqx.filter_{*}` docs now generally explain what they do directly instead of referring back to JAX.
* version bump
* typos
* Faster hashable partition/combine
* Added advanced tricks to docs
* Added nicer error message for inconsistent axis size
* Various tweaks (all the best commit messages are descriptive)
* filter_closure_convert no longer has jaxprs as static
* Added while loop speed tests (incomplete). Switched default to filter_jit(donate=none). Moved all nondifferentiable, nonbatchable etc into a single file.
* Buffers in checkpointed_while_loop + other stuff
  - checkpointed_while_loop now supports a new `eqxi.Buffer`, which allows for outputting results without saving them into the checkpoints.
  - added speed tests for checkpointed_while_loop
  - finalise_jaxpr now traces through custom_jvp, xla_call and stop_gradient, to aid its use as a debugging tool.
  - Switched to 88-length line cutoff in flake8.
  - filter_closure_convert now demands that you call it with arguments that look like its example arguments.
* Renamed tree_pp -> pp
* Minor bugfixes for checkpointed_while_loop with max_steps==0
* Added tags to buffers so we can distinguish multilevel checkpointed while loops
* Working CWL?
* Nested CWL now working!
* Bugfix for CWL(buffers=<returns nontrivial pytrees>)
* Moved eqxi.checkpointed_while_loop -> eqxi.while_loop, which also supports lax or bounded modes.
* Fixed issue with transpose'd noinline
* Removed old xla.lower_fun usage
* Update versions
* vmap/pmap in_axes can now be a dictionary
* Updated to Python 3.8 so renamed `__self`->`self` etc.; filter_custom_vjp now checks that all non-vjp_args are not differentiated.
* doc tweaks
* Tweak test tolerance
* Fixed default checkpoints being wrong

---------

Co-authored-by: uuirs <3000cl@gmail.com>
Co-authored-by: Enver Menadjiev <33788423+enver1323@users.noreply.github.com>
Tries to close #3.
Hi @patrick-kidger
JAX supports buffer donation on CPU in jaxlib versions after 0.3.22, and donation is necessary for many in-place updates.
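For context, the underlying JAX mechanism is the `donate_argnums` argument of `jax.jit`. A minimal sketch of an in-place-style parameter update; `update` and its arguments are illustrative only:

```python
import jax
import jax.numpy as jnp


def update(params, grads, lr):
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


# Donate `params` (argument 0), so XLA may write the new parameters into the old buffers.
update_jit = jax.jit(update, donate_argnums=0)

params = {"w": jnp.ones((4,)), "b": jnp.zeros(())}
grads = {"w": jnp.full((4,), 0.5), "b": jnp.ones(())}
# Rebind `params`: the donated input buffers must not be used after this call.
params = update_jit(params, grads, 0.1)
```

On backends where donation is supported, the old `params` arrays are invalidated by the call, which is exactly why the result is rebound to the same name.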
I've drafted a proposal here.
Details:

- `donate_default`, `donate_args`, `donate_kwargs`, and `donate_fn` for control of buffer donation.

Considerations:
More (not in this PR):

- Support buffer donation in `filter_pmap`.