Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added de/serialisation of PyTrees #80

Merged
merged 1 commit into from
May 5, 2022
Merged

Added de/serialisation of PyTrees #80

merged 1 commit into from
May 5, 2022

Conversation

patrick-kidger
Copy link
Owner

Closes #46.

This PR adds equinox.tree_{de,}serialise_leaves, primarily for loading/saving models.

@patrick-kidger patrick-kidger merged commit 21a5fbb into v050 May 5, 2022
@patrick-kidger patrick-kidger deleted the serialisation branch May 5, 2022 11:51
patrick-kidger added a commit that referenced this pull request May 6, 2022
* Improved filtered transformations (#71)

* Updated filter_{jit,grad,value_and_grad}

* Documented that BatchNorm supports multiple axis_names

* Doc fix

* added tree_inference

* Added filter_vmap and filter_pmap

* Made filter_{vmap,pmp} non-experimental. Fixed doc issues.

* Added 'inference' call-time argument to 'MultiheadAttention', which was missed when switching over 'Dropout' to use 'inference'.

* Added test for tree_inference

* Doc improvements

* Doc tweaks

* Updated tests

* Two minor bugfixes for filter_{jit,grad,vmap,pmap,value_and_grad}.

1. Applying filter_{jit,vmap,pmap} to a function with *args and **kwargs should now work. Previously inspect.Signature.apply_defaults was not filling these in.
2. The output of all filtered transformations has become a Module. This means that e.g. calling filter_jit(filter_grad(f)) multiple times will not induce any recompilation, as all filter_grad(f)s are PyTrees of the same structure as each other.

* Crash and test fixes

* tidy up

* Fixes and test fixes

* Finished adding tests.

* added is_leaf argument to filter and partition (#72)

* Improved tree_at a lot: (#73)

- Can now substitute arbitrary nodes, not just leaves
- These substituted nodes can now have different structures to each other.
- Will raise an error if `where` depends on the values of the leaves of the PyTree.

In addition, `tree_equal` should now work under JIT.

* Lots of doc fixes and tweaks (#74)

* Added tree_{de,}serialise_leaves (#80)

* Added max/average pooling. (#77)

* Tidy up pooling implementation (#81)

* Version bump

* Fixed tree_at breaking when replacing Nones. (#83)

Co-authored-by: Ben Walker <38265558+Benjamin-Walker@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant