Added ONNX export; finalise_jaxpr, nontraceable #243
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
New features:
equinox.internal.to_onnx
for export to ONNX.equinox.internal.{finalise_jaxpr,finalise_fn,finalise_eval_jaxpr,finalise_jaxpr_as_fun}
. The main one here isfinalise_jaxpr
, which is essentially a jaxpr-to-jaxpr transformation that rewrites all custom primitives in terms of theirimpl
rule. This is useful prior to ONNX export: this can be the final transformation applied to a jaxpr, so that it is now written in terms of primitives that have ONNX export rules. (But naturally this will break anything further jaxpr processing via vmap/grad/etc.)equinox.internal.nontraceable
is an operation that cannot be vmap'd, grad'd etc. (Useful to check that there were no closed-over tracers at the end of a final-style higher order primitive.)