Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Version 0.5.0 #82

Merged
merged 10 commits into from
May 6, 2022
Merged

Version 0.5.0 #82

merged 10 commits into from
May 6, 2022

Conversation

patrick-kidger
Copy link
Owner

@patrick-kidger patrick-kidger commented May 6, 2022

This is a big update.

Exciting new features!

Refactoring for nicer APIs

  • filter_{jit,grad,value_and_grad} now have an easier-to-use API for specifying which arguments have what behaviour.

    • Instead of having to specify (args, kwargs) as a single PyTree, then you can specify a default, args, kwargs separately. In particular this avoids doing messy stuff like filter_spec=((...), {}) when you had no kwargs.
    • You no longer have to match up the filter specification for args and kwargs against their runtime values. Both the runtime values, and the filter specification, are matched up against the function signature.
      e.g. you can do filter_jit(lambda x: x, kwargs=dict(x=True))(1), using a keyword argument at JIT-time and a positional argument at call time.
    • Currying is available: both filter_jit(fun) and filter_jit(default=...)(fun) will work.
    • The old API is still available for backward compatibility, of course.
    • (Closes Neater interface to filtered transformations? #48.)
  • tree_at can now replace subtrees, and not just leaves.

  • filter, partition now support an is_leaf argument.

Miscellaneous

patrick-kidger and others added 8 commits April 25, 2022 11:48
* Updated filter_{jit,grad,value_and_grad}

* Documented that BatchNorm supports multiple axis_names

* Doc fix

* added tree_inference

* Added filter_vmap and filter_pmap

* Made filter_{vmap,pmp} non-experimental. Fixed doc issues.

* Added 'inference' call-time argument to 'MultiheadAttention', which was missed when switching over 'Dropout' to use 'inference'.

* Added test for tree_inference

* Doc improvements

* Doc tweaks

* Updated tests

* Two minor bugfixes for filter_{jit,grad,vmap,pmap,value_and_grad}.

1. Applying filter_{jit,vmap,pmap} to a function with *args and **kwargs should now work. Previously inspect.Signature.apply_defaults was not filling these in.
2. The output of all filtered transformations has become a Module. This means that e.g. calling filter_jit(filter_grad(f)) multiple times will not induce any recompilation, as all filter_grad(f)s are PyTrees of the same structure as each other.

* Crash and test fixes

* tidy up

* Fixes and test fixes

* Finished adding tests.
- Can now substitute arbitrary nodes, not just leaves
- These substituted nodes can now have different structures to each other.
- Will raise an error if `where` depends on the values of the leaves of the PyTree.

In addition, `tree_equal` should now work under JIT.
@patrick-kidger patrick-kidger changed the base branch from main to delete-me May 6, 2022 19:30
@patrick-kidger patrick-kidger changed the base branch from delete-me to main May 6, 2022 19:30
@patrick-kidger patrick-kidger merged commit 291c4d7 into main May 6, 2022
@patrick-kidger patrick-kidger deleted the v050 branch May 6, 2022 21:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
2 participants