-
-
Notifications
You must be signed in to change notification settings - Fork 130
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Big release: version 10! #260
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
* Adds donate_default, donate_args, donate_kwargs, and donate_fn in filter_jit; Adds Tests. * filter_jit with buffer donation V2; Tests fixes; * simplify filter_jit; example fixes; * Fix review comments
New functions: - eqx.module_update_wrapper - eqx.tree_pprint - eqx.internal.tree_pp
This is to (a) make these operations more usable, (b) simplify new-user onboarding, and (c) improve maintability / bus-factor. The affected operations are `filter_{grad,value_and_grad,vmap,pmap}`. Also: - Added `equinox.if_array` and `equinox.internal.if_mapped`. - `eqx.filter_{*}` docs now generally explain what they do directly instead of referring back to JAX.
…_jit(donate=none). Moved all nondifferentiable, nonbatchable etc into a single file.
- checkpointed_while_loop now supports a new `eqxi.Buffer`, which allows for outputting results without saving them into the checkpoints. - added speed tests for checkpointed_while_loop - finalise_jaxpr now traces through custom_jvp, xla_call and stop_gradient, to aid its use as a debugging tool. - Switched to 88-length line cutoff in flake8. - filter_closure_convert now demands that you call it with arguments that look like its example arguments.
…ports lax or bounded modes.
- filter_custom_vjp now checks that all non-vjp_args are not differentiated.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Highlights
A dramatically simplified API for
equinox.{filter_jit, filter_grad, filter_value_and_grad, filter_vmap, filter_pmap}
. This is a backward-incompatible change.equinox.internal.while_loop
, which is a reverse-mode autodifferentiable while loop, using recursive checkpointing.Full change list
New features
Some new relatively minor new features available in this release.
eqx.{filter_jit, filter_pmap}
.eqx.nn.PRelu
.eqx.tree_pprint
.eqx.module_update_wrapper
.eqx.filter_custom_jvp
now supports keyword arguments (which are always treated as nondifferentiable).New
internal
featuresIntroducing a slew of new features for the advanced JAX user.
These are all available in the
equinox.internal
namespace. Note that these comes without stability guarantees, as they often depend on functionality that JAX doesn't make fully public.eqxi.abstractattribute
, for marking abstract instance attributes of abstract Equinox modules.eqxi.tree_pp
, for producing a pretty-print doc of an object. (This is what is then formatted to a particular width in e.g.eqx.tree_pformat
.) In addition classes can now have custom pretty behaviour when used witheqx.{tree_pp, tree_pformat, tree_pprint}
, by setting a__tree_pp__
method.eqxi.if_mapped
, as an alternative to the usualeqx.if_array
passed toeqx.{filter_vmap, filter_pmap}(out_axes=...)
.eqxi.{finalise_jaxpr, finalise_fn}
for tracing through custom primitivesimpl
rules (so that the custom primitive no longer appears in the jaxpr). This is useful for replacing such custom primitives prior to offloading a jaxpr to some other IR, e.g. viajax2tf
.eqxi.{nonbatchable, nondifferentiable, nondifferentiable_backward, nontraceable}
for asserting that an operation is never batched, differentiated, or subject to any transform at all.eqxi.to_onnx
for exporting to ONNX.eqxi.while_loop
for reverse-mode autodifferentiable while loops; in particular making use of recursive checkpointing. (A la treeverse.)Backward-incompatible changes
equinox.{filter_jit, filter_grad, filter_value_and_grad, filter_vmap, filter_pmap}
has been dramatically simplified. If you were using the extra arguments to these functions (i.e. not just calling@eqx.filter_jit
etc. directly) then this is a backward-incompatible change; see the discussion below for more details.equinox.nn.{AvgPool1D, AvgPool2D, AvgPool3D, MaxPool1D, MaxPool2D, MaxPool3D}
. UseAvgPool1d
etc. (lower-case "d") instead. (These were backward-compatiblity stubs that have now been removed.)equinox.Module.{tree_flatten, tree_unflatten}
. These were never technically public API; usejax.tree_util.{tree_flatten, tree_unflatten}
instead.equinox.filter_closure_convert
now asserts that you call it with argments compatible with those it was closure-converted with.Other
filter_jit
orfilter_pmap
boundary should now be much reduced.eqx.tree_inference
now runs faster.Filtered transformation API changes (AKA: "my code isn't working any more?")
These APIs have been simplified and made much easier to understand. No functionality has been lost, things might just need tweaking.
filter_jit
This previously took
default
,args
,kwargs
,out
,fn
arguments, for controlling what should be traced and what should be held static.In practice all JAX arrays and NumPy arrays always had to be traced, and everything that wasn't a JAXable type (JAX array, NumPy array,
bool
,int
,float
,complex
) had to be held static. So these arguments just weren't that useful: pretty much the only thing you could do with them was to specify that you'd like to trace abool
/int
/float
/complex
.This minor use-case wasn't worth complicating such an important API for, which is why these arguments have been removed.
If after this change you still want to trace with respect to
bool
/int
/float
/complex
, then do so simply by wrapping them into JAX arrays or NumPy arrays first:np.asarray(x)
.filter_grad
andfilter_value_and_grad
These previously took an
arg
argument, for controlling what parts of the first argument should be differentiated.This was useful occasionally -- e.g. when freezing parts of a layer -- but in practice it still wasn't used that often. As such it this argument has been removed for the sake of simplicity.
If after this change you want to replicate the previous behaviour, then it is simple to do so using
partition
andcombine
:See also the updated frozen layer example for a demonstration.
filter_vmap
This previously took
default
,args
,kwargs
,out
,fn
arguments, for controlling what axes should be vectorised over.In practice this API was just a bit more complicated than it really needed to be. The only useful feature relative to
jax.vmap
waskwargs
, for easily specifying just a few named arguments that should behave differently.The new API instead accepts
in_axes
andout_axes
arguments, just likejax.vmap
. To replacekwargs
, one extra feature is supported:in_axes
may be a dictionary of named argments, e.g.All arguments not named in
kwargs
will have the default value ofeqx.if_array(0) -> 0 if is_array(x) else None
applied to them.On which note, a new
eqx.if_array(i)
now exists, to make it easier to specify values forin_axes
andout_axes
.If you were using the old
fn
argument, then this can be replicated by instead decorating a function that accepts the callable:filter_pmap
.This previously took
default
,args
,kwargs
,out
,fn
arguments, for controlling what axes should be parallelised over, and which arguments should be traced vs static.This was a fiendishly complicated API merging together both the
filter_jit
andfilter_vmap
APIs.The JIT part of it is now handled automatically, as with
filter_jit
: all arrays are traced, everything else is static.The vmap part of it is now handled in the same way as
filter_vmap
, usingin_axes
andout_axes
arguments.