Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ONNX export; finalise_jaxpr, nontraceable #243

Merged
merged 1 commit into from
Dec 9, 2022
Merged

Conversation

patrick-kidger
Copy link
Owner

New features:

  • equinox.internal.to_onnx for export to ONNX.
  • equinox.internal.{finalise_jaxpr,finalise_fn,finalise_eval_jaxpr,finalise_jaxpr_as_fun}. The main one here is finalise_jaxpr, which is essentially a jaxpr-to-jaxpr transformation that rewrites all custom primitives in terms of their impl rule. This is useful prior to ONNX export: this can be the final transformation applied to a jaxpr, so that it is now written in terms of primitives that have ONNX export rules. (But naturally this will break anything further jaxpr processing via vmap/grad/etc.)
  • equinox.internal.nontraceable is an operation that cannot be vmap'd, grad'd etc. (Useful to check that there were no closed-over tracers at the end of a final-style higher order primitive.)

@patrick-kidger patrick-kidger merged commit 7e61900 into dev Dec 9, 2022
patrick-kidger added a commit that referenced this pull request Feb 20, 2023
* filter_custom_jvp now supports keyword arguments

* added link to eqxvision

* Support Buffer Donation in filter_jit (#235)

* Adds donate_default, donate_args,
donate_kwargs, and donate_fn in filter_jit;
Adds Tests.

* filter_jit with buffer donation V2;
Tests fixes;

* simplify filter_jit; example fixes;

* Fix review comments

* tweak docs (#242)

* add str2jax test

* Added ONNX export; finalise_jaxpr, nontraceable (#243)

* Fixed up the finalise transformation to handle higher-order primitives (#244)

* feat: Add PReLU activaion (#249)

* Update activations.py

* Started to add a checkpointed while loop

* while fwd

* at_set

* no more at

* Added checkpointed while loop.

* Fix Black

* Many doc improvements. Exposed a few new functions.

New functions:
- eqx.module_update_wrapper
- eqx.tree_pprint
- eqx.internal.tree_pp

* Added combine(..., is_leaf=...)

* Silenced most test warnings

* Minor wart fixed in Module

* Onboarding docs made more concise. (Less repetition of the word 'PyTree'!)

* Added docstrings

* filter_vmap(out=...) crash fix

* Wrapper modules now preserve __name__, __module__ etc. when flattened and unflattened

* Now guard against non-ShapedArray abstract values

* Tweaks

* Simplified filtered transformation APIs.

This is to (a) make these operations more usable, (b) simplify new-user
onboarding, and (c) improve maintability / bus-factor. The affected
operations are `filter_{grad,value_and_grad,vmap,pmap}`.

Also:

- Added `equinox.if_array` and `equinox.internal.if_mapped`.
- `eqx.filter_{*}` docs now generally explain what they do directly
  instead of referring back to JAX.

* version bump

* typos

* Faster hashable partition/combine

* Added advanced tricks to docs

* Added nicer error message for inconsistent axis size

* Various tweaks (all the best commit messages are descriptive)

* filter_closure_convert no longer has jaxprs as static

* Added while loop speed tests (incomplete). Switched default to filter_jit(donate=none). Moved all nondifferentiable, nonbatchable etc into a single file.

* Buffers in checkpointed_while_loop + other stuff

- checkpointed_while_loop now supports a new `eqxi.Buffer`, which
  allows for outputting results without saving them into the
  checkpoints.
- added speed tests for checkpointed_while_loop
- finalise_jaxpr now traces through custom_jvp, xla_call and
  stop_gradient, to aid its use as a debugging tool.
- Switched to 88-length line cutoff in flake8.
- filter_closure_convert now demands that you call it with arguments
  that look like its example arguments.

* Renamed tree_pp -> pp

* Minor bugfixes for checkpointed_while_loop with max_steps==0

* Added tags to buffers so we can distinguish multilevel checkpointed whiel loops

* Working CWL?

* Nested CWL now working!

* Bugfix for CWL(buffers=<returns nontrivial pytrees>)

* Moved eqxi.checkpointed_while_loop -> eqxi.while_loop, which also supports lax or bounded modes.

* Fixed issue with transpose'd noinline

* Removed old xla.lower_fun usage

* Update versions

* vmap/pmap in_axes can be now be a dictionary

* - Updated to Python 3.8 so renamed __self->self etc.
- filter_custom_vjp now checks that all non-vjp_args are not differentiated.

* doc tweaks

* Tweak test tolerance

* Fixed default checkpoints being wrong

---------

Co-authored-by: uuirs <3000cl@gmail.com>
Co-authored-by: Enver Menadjiev <33788423+enver1323@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant