Support Buffer Donation in filter_jit #235

Merged
merged 5 commits into from Dec 7, 2022

Contributor

@uuirs uuirs commented Nov 23, 2022

Try to close #3

Hi @patrick-kidger

JAX supports buffer donation on CPU with jaxlib > 0.3.22, and donation is needed for many in-place updates.

I've drafted a proposal in here.
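For readers unfamiliar with buffer donation, here is a minimal sketch in plain JAX (not the API proposed in this PR). It assumes a hypothetical `sgd_step` function and uses `jax.jit`'s `donate_argnums` to let XLA reuse the input buffers of `params` for the updated values.

```python
from functools import partial

import jax
import jax.numpy as jnp


# donate_argnums=0 marks every array inside `params` as donated, so XLA
# may write the updated parameters into the same memory.
@partial(jax.jit, donate_argnums=0)
def sgd_step(params, grads):
    # Hypothetical update rule with a fixed learning rate of 0.1.
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)


params = {"w": jnp.ones((1000,)), "b": jnp.zeros((10,))}
grads = jax.tree_util.tree_map(jnp.ones_like, params)

# Rebind the name: the old `params` arrays should not be used again.
params = sgd_step(params, grads)
```

Rebinding `params` to the result is the usual pattern, since the donated inputs may be invalidated by the call.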

Details:

  • Adds donate_default, donate_args, donate_kwargs, and donate_fn for control of buffer donation;
  • Adds Tests.

Considerations:

  • It would be trivial to make these arguments work only with parameters marked as traced.
  • We need the ability to control buffer donation in function, args and kwargs, which is similar to handling Tracing.

More (not in this PR):
Support buffer donation in filter_pmap.

@patrick-kidger
Owner

This is a great feature to have.

This has prompted me to think about making some larger-scale changes to Equinox. So besides simply accepting this PR, I'd be curious to know what you think about each of the following other options.

  1. Right now we specify a spec based around the dichotomy of True and False, for array-vs-non-array. Should we extend this to a trichotomy True, False, donate? Since a donated argument is necessarily also traced. Otherwise, the API proposed in this PR makes it possible to have a donated static arg, which is meaningless. (donate would be some special sentinel value.)

  2. Should we make it so that all arrays are donated by default, unless specified otherwise? (So a trichotomy of True, False, no_donate.) This is probably a silent performance improvement for most users, and will be a loud error (i.e. not a silent error) for those rare cases in which an array really is used after JIT'ing.

  3. Should we make it so that all arrays are always traced, and all non-arrays are always static? Essentially, simplify the filter_jit API by removing a lot of its complexity. I think it's really rare to want to deviate from the default. (Something I didn't appreciate when I first wrote it -- if I could start from scratch then I think I would simply not give it any filtering arguments at all.)

  4. If we take option 3: do we either (a) keep arguments that specify whether to donate an array, or (b) forcibly donate, and simply expect a user to perform a manual copy (before the filter_jit call) in the rare case they want to keep a copy around?
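To make options 2 and 4(b) concrete, here is a small sketch in plain JAX, assuming a hypothetical `step` function. It shows what happens when a donated array is touched after the call, and how a manual copy made beforehand stays valid. Whether the buffer is really donated depends on the backend and jaxlib version, so the sketch checks `is_deleted()` rather than assuming it.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, donate_argnums=0)
def step(x):
    return x + 1


x = jnp.zeros((100,))
backup = jnp.copy(x)  # option (b): copy manually before the call
y = step(x)

if x.is_deleted():
    # The buffer was donated, so any further use of x fails loudly.
    try:
        x + 1
        reused_ok = True
    except Exception as err:  # a RuntimeError in the versions I have seen
        reused_ok = False
        print(f"Using the donated array failed: {type(err).__name__}")
else:
    # This backend did not donate, so x is still usable.
    reused_ok = True

# The copy is a separate buffer and is unaffected either way.
print(float(backup.sum()))
```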

For the sake of API elegance I quite like the approach of donating all arrays and static'ing all non-arrays, but that might be a mildly breaking change for some small number of use cases.

Other musings:

  • I'd be very happy to delete the old-style filter_spec argument either way, that's long since deprecated.

  • Should we consider removing the filter arguments from all the other filtered transformations? Once again it's very rare to need a non-default for any of these, and in these cases then it can easily be replicated using eqx.partition and eqx.combine. We then get the benefits of a much simpler API throughout all of Equinox: nice docs, easier new user onboarding, code maintainability, etc. Indeed several of the more unusual filtered transformations that I added later (filter_{make_jaxpr, closure_convert, vjp, ...}) already don't take any filter arguments at all.

Since you've taken the time to implement donations for filter_jit, I'm guessing you may have a stake in the answers to some of these questions. WDYT?

@uuirs
Contributor Author

uuirs commented Nov 24, 2022

Thanks for your input! Sorry for the late reply.

1

the API proposed in this PR makes it possible to have a donated static arg

In this implementation, every donated pytree is partitioned out of the dynamic pytree. As mentioned, donation only applies to parameters marked as traced, so a donated static arg should not be possible.

a test:

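import jax
import jax.numpy as jnp
import equinox as eqx

# donate_args as proposed in this PR; True is requested for y, which is static
@eqx.filter_jit(donate_args=(False, True))
def f(x, y):
    return x + 1

x = jnp.zeros((100,))
y = 10  # a Python int, so it is not traced
closed_jaxpr = eqx.filter_make_jaxpr(f)(x, y)[0]
xla_call = closed_jaxpr.eqns[0]
# Only x is a traced input, and it is not donated
assert xla_call.params["donated_invars"] == (False,)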
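The priority rule being tested can be sketched without Equinox. The helper below, `split_args`, is hypothetical and works per argument rather than per pytree leaf, so it is a simplification of what the PR does. It only shows the ordering: the static/dynamic split comes first, and the donate flag is consulted for dynamic arguments only.

```python
import jax
import jax.numpy as jnp


def split_args(args, donate_flags):
    """Hypothetical helper: sort arguments into donated, kept and static."""
    donated, kept, static = [], [], []
    for arg, donate in zip(args, donate_flags):
        if isinstance(arg, jax.Array):
            # Dynamic argument: the donate flag decides where it goes.
            (donated if donate else kept).append(arg)
        else:
            # Static argument: the donate flag is ignored.
            static.append(arg)
    return donated, kept, static


x = jnp.zeros((100,))
y = 10
donated, kept, static = split_args((x, y), (False, True))
```

With these inputs nothing is donated: `x` is dynamic but not flagged, and `y` is flagged but static.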

Should we extend this to a trichotomy True, False, donate?

I did consider adding an enum (static, dynamic, donate), along with a chain_partition or similar that returns several pytrees.

But it also seems to work to give the donated partition a lower priority than the original partition and to base it on the dynamic pytree. I think usually we just need to donate the model, without any additional specs (args, kwargs, etc.).

I personally prefer
eqx.filter_jit(fun, donate_args=(False, False, True,)) (args[0:3] of fun all can be a normal module)
eqx.filter_jit(fun, donate_kwargs=dict(model=True)) (model of fun can be a normal module)
over
eqx.filter_jit(fun, args=(False, True, eqx.donate)) (I don't care if args[0:2] are static or dynamic, and usually would not use like this)
eqx.filter_jit(fun, args=(eqx.is_array, eqx.is_array, eqx.is_array_donated)) (So it might become like this, something I guess would be)
eqx.filter_jit(fun, kwargs=dict(model=eqx.is_array_donated)) (works)

From this point of view, I would like to keep these donate_* arguments and remove the other arguments (kwargs, args, fn, etc.).

2

Should we make it so that all arrays are donated by default, unless specified otherwise?

I totally agree, if the default behavior can be changed easily via donate_default=False or something.

Some small issues:

  1. The input pytree is usually mixed (a dynamic part, a static part, a donated part, parts to reuse, etc.). Because we need to return the arrays that will be reused, for convenience we may end up returning the entire pytree. JAX performs some DCE on the jaxpr, such as pruning unused inputs (a related resolved issue: XLA creates unnecessary copies of the function arguments, google/jax#5145). If we now return all inputs, this DCE may not work properly.
  2. There may be some annoying warnings at compile time, complaining that some donated buffers were not usable.
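The warning in point 2 can be observed with plain JAX. The sketch below assumes a hypothetical `total` function whose scalar output cannot reuse the donated input buffer, and records whatever warnings JAX emits while compiling. The exact message text depends on the JAX version and backend, so none is shown here.

```python
import warnings
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, donate_argnums=0)
def total(x):
    # The scalar result cannot reuse the (100,)-shaped input buffer.
    return jnp.sum(x)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = total(jnp.ones((100,)))

# Print any warnings raised during compilation.
for w in caught:
    print(w.message)
```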

3

Should we make it so that all arrays are always traced, and all non-arrays are always static?

I think it really makes things easier. I usually use filter_jit without any arguments and it works fine. But sometimes I need it to make sure the parameters are treated as dynamic, which is easier than checking when calling. Many existing Python functions do not return jnp arrays. We might add another wrapper just to avoid accidentally static inputs.

4

(a) keep arguments that specify whether to donate an array, or (b) forcibly donate, and simply expect a user to perform a manual copy

I prefer (a). As with the type wrapper above, I like to determine this when defining the jitted function, not when calling it.
In rare cases, such as large device arrays that I don't own (which means I can't easily reassign the reference), copying is expensive.

IMO, if big changes are allowed, I'd like to remove all the filter arguments, keep donate_* in jit and pmap, and add a type wrapper for users to mark an argument as dynamic.

@patrick-kidger
Owner

patrick-kidger commented Nov 26, 2022

These are some really interesting thoughts.

which may make these DCE not work properly.

I'm not sure it will. I can believe that passing arrays in and out will prevent them from getting DCE'd, but (a) isn't that already happening independently of this change, and (b) wouldn't this change actually improve things here, because we would be donating and thus reusing the input buffer?

There may be some annoying warning when compiling

Good point here. We could possibly suppress these warnings.

But sometimes I need it to make sure the parameters are treated as dynamic, which is easier than checking when calling. Many existing Python functions do not return jnp arrays.

That's an interesting use-case. How frequently does that happen? It would be easy enough to make this happen with a wrapper function instead, though:

import equinox as eqx
import jax

@eqx.filter_jit
def foo_impl(x, y):
    ...

def foo(x, y):
    # Convert every leaf to a JAX array so that filter_jit traces it
    x, y = jax.tree_util.tree_map(jax.numpy.asarray, (x, y))
    return foo_impl(x, y)

I like to determine this when defining jitting, not when calling jitted.

Again, I think this could be handled with a wrapper function:

import equinox as eqx
import jax.numpy as jnp

@eqx.filter_jit  # donates args by default
def foo_impl(x, y):
    ...

def foo(x, y):
    x = jnp.copy(x)  # donate the copy, so the caller's x stays valid
    return foo_impl(x, y)

add a type wrapper for users to specify argument as dynamic.

I think this is what jnp.array already is.


All in all, I'm thinking of perhaps introducing a simple API filter_jit(fn, donate), where we can have either:

  • donate='arrays': the default, donate all arrays and suppress all warnings about unused buffers;
  • donate='warn': as above, but don't suppress unused buffer warnings;
  • donate='none': no donation (current behaviour).
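A rough sketch of how these three modes could behave, written against plain `jax.jit` rather than Equinox internals. The helper `jit_with_donate` is hypothetical, handles positional arguments only, and assumes that JAX's unused-buffer warning starts with the text "Some donated buffers were not usable". It is not the implementation merged in this PR.

```python
import functools
import warnings

import jax
import jax.numpy as jnp


def jit_with_donate(fn, donate="arrays"):
    """Hypothetical helper mimicking donate='arrays' | 'warn' | 'none'."""
    if donate not in {"arrays", "warn", "none"}:
        raise ValueError(f"Unknown donate mode: {donate!r}")
    if donate == "none":
        return jax.jit(fn)

    jitted_by_arity = {}  # one jitted function per number of arguments

    @functools.wraps(fn)
    def wrapper(*args):
        n = len(args)
        if n not in jitted_by_arity:
            # Donate every positional argument.
            jitted_by_arity[n] = jax.jit(fn, donate_argnums=tuple(range(n)))
        jitted = jitted_by_arity[n]
        if donate == "warn":
            return jitted(*args)
        with warnings.catch_warnings():
            # Assumed message prefix; adjust it if your JAX version differs.
            warnings.filterwarnings(
                "ignore", message="Some donated buffers were not usable"
            )
            return jitted(*args)

    return wrapper


def add_one(x):
    return x + 1


out_arrays = jit_with_donate(add_one, donate="arrays")(jnp.zeros((4,)))
out_none = jit_with_donate(add_one, donate="none")(jnp.zeros((4,)))
```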

And for cases which can be handled with simple wrapper functions, we could say that that is the recommended way to approach them.

WDYT? Am I missing a use-case that the above simplified API wouldn't be able to support?

@uuirs
Contributor Author

uuirs commented Nov 27, 2022

Refactored, seems to be ok.
Todo: documentation and example fixes.

@uuirs
Contributor Author

uuirs commented Dec 2, 2022

Hi @patrick-kidger, how does it look now? I've checked the examples on Colab and they seem to be OK.
Is there anything else I haven't considered?

@patrick-kidger
Owner

(I'm at NeurIPS at the moment, I'll give a review next week :) )

@uuirs
Contributor Author

uuirs commented Dec 2, 2022

Thanks, have a nice day!

@patrick-kidger patrick-kidger changed the base branch from main to dev December 5, 2022 19:49
Owner

@patrick-kidger patrick-kidger left a comment

This looks excellent. I've gone through; my comments are just small nits.

Other thoughts:

  • Is there anywhere in the documentation that needs updating to reflect this change? I don't think so but I thought I'd double-check with you?
  • We should think about similarly updating the other filtered transformations. Don't feel obliged to do this if you don't want to; if not then I'll get that done at some point and then do a new release. (In separate PRs, either way.)

Thank you for making this change happen!

Review comments (resolved) on: equinox/jit.py, examples/frozen_layer.ipynb, tests/test_jit.py

@uuirs
Contributor Author

uuirs commented Dec 6, 2022

Thank you for double-checking with me. > <

We should think about similarly updating the other filtered transformations.

I would like to do this, but am afraid of not being able to cover all aspects of these changes in this PR.

@patrick-kidger patrick-kidger merged commit f9ee2ae into patrick-kidger:dev Dec 7, 2022
@patrick-kidger
Owner

Alright, merged! Thank you for the PR, this is a great simplification to have. It's on the dev branch for now -- once we've unified things (with the other filtered transformations) then I'll merge it and do a new release.

If you are happy to handle the other filtered transformations -- in other PRs -- then I'd be very happy to accept them. I'm thinking that we could:

  • Remove the arg parameter from filter_{grad, value_and_grad}.
  • Perform the same simplification to filter_pmap as we've done here.
  • Figure out if there's a way to simplify the filter_{vmap,pmap} interfaces, since they still remain quite complicated.

@uuirs uuirs deleted the donate_args_support branch December 8, 2022 00:15
@uuirs
Contributor Author

uuirs commented Dec 8, 2022

Thank you for your patience and support! I'll try to submit them next week.

patrick-kidger added a commit that referenced this pull request Feb 20, 2023
* filter_custom_jvp now supports keyword arguments

* added link to eqxvision

* Support Buffer Donation in filter_jit (#235)

* Adds donate_default, donate_args,
donate_kwargs, and donate_fn in filter_jit;
Adds Tests.

* filter_jit with buffer donation V2;
Tests fixes;

* simplify filter_jit; example fixes;

* Fix review comments

* tweak docs (#242)

* add str2jax test

* Added ONNX export; finalise_jaxpr, nontraceable (#243)

* Fixed up the finalise transformation to handle higher-order primitives (#244)

* feat: Add PReLU activation (#249)

* Update activations.py

* Started to add a checkpointed while loop

* while fwd

* at_set

* no more at

* Added checkpointed while loop.

* Fix Black

* Many doc improvements. Exposed a few new functions.

New functions:
- eqx.module_update_wrapper
- eqx.tree_pprint
- eqx.internal.tree_pp

* Added combine(..., is_leaf=...)

* Silenced most test warnings

* Minor wart fixed in Module

* Onboarding docs made more concise. (Less repetition of the word 'PyTree'!)

* Added docstrings

* filter_vmap(out=...) crash fix

* Wrapper modules now preserve __name__, __module__ etc. when flattened and unflattened

* Now guard against non-ShapedArray abstract values

* Tweaks

* Simplified filtered transformation APIs.

This is to (a) make these operations more usable, (b) simplify new-user
onboarding, and (c) improve maintainability / bus-factor. The affected
operations are `filter_{grad,value_and_grad,vmap,pmap}`.

Also:

- Added `equinox.if_array` and `equinox.internal.if_mapped`.
- `eqx.filter_{*}` docs now generally explain what they do directly
  instead of referring back to JAX.

* version bump

* typos

* Faster hashable partition/combine

* Added advanced tricks to docs

* Added nicer error message for inconsistent axis size

* Various tweaks (all the best commit messages are descriptive)

* filter_closure_convert no longer has jaxprs as static

* Added while loop speed tests (incomplete). Switched default to filter_jit(donate=none). Moved all nondifferentiable, nonbatchable etc into a single file.

* Buffers in checkpointed_while_loop + other stuff

- checkpointed_while_loop now supports a new `eqxi.Buffer`, which
  allows for outputting results without saving them into the
  checkpoints.
- added speed tests for checkpointed_while_loop
- finalise_jaxpr now traces through custom_jvp, xla_call and
  stop_gradient, to aid its use as a debugging tool.
- Switched to 88-length line cutoff in flake8.
- filter_closure_convert now demands that you call it with arguments
  that look like its example arguments.

* Renamed tree_pp -> pp

* Minor bugfixes for checkpointed_while_loop with max_steps==0

* Added tags to buffers so we can distinguish multilevel checkpointed while loops

* Working CWL?

* Nested CWL now working!

* Bugfix for CWL(buffers=<returns nontrivial pytrees>)

* Moved eqxi.checkpointed_while_loop -> eqxi.while_loop, which also supports lax or bounded modes.

* Fixed issue with transpose'd noinline

* Removed old xla.lower_fun usage

* Update versions

* vmap/pmap in_axes can now be a dictionary

* - Updated to Python 3.8 so renamed __self->self etc.
- filter_custom_vjp now checks that all non-vjp_args are not differentiated.

* doc tweaks

* Tweak test tolerance

* Fixed default checkpoints being wrong

---------

Co-authored-by: uuirs <3000cl@gmail.com>
Co-authored-by: Enver Menadjiev <33788423+enver1323@users.noreply.github.com>
Successfully merging this pull request may close these issues.

Support a donate_spec argument in filter_jit