Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for treepath-dependent sizes. #136

Merged
merged 13 commits into from
Nov 27, 2023

Conversation

patrick-kidger
Copy link
Owner

No description provided.

Phew, this ended up being a pretty complicated change!
The basic summary is that we now support the syntax
```
@jaxtyped(typechecker=typechecker)
def f(...): ...
```
and when using this, we now get pretty error messages about what went
wrong.

(
The old syntax, i.e.
```
@jaxtyped
@typechecker
def f(...): ...
```
is still supported, but doesn't give much information.
)

The internals of this do quite a lot of magic! In particular we
dynamically create quite a lot of functions and test the provided
arguments against their signatures. The overhead should still be
minimal under `jax.jit`, though.
(TODO: what's the overhead like in non-jit situations, e.g. PyTorch?
I've tried to minimise the overhead throughout just to be sure, but
perhaps PyTorch users should stick to the old syntax?)
…his avoids edge-case crash when using pytree-path dependent sizes
@patrick-kidger patrick-kidger merged commit a7edf26 into structure-matching Nov 27, 2023
2 checks passed
@patrick-kidger patrick-kidger deleted the path-matching branch November 27, 2023 17:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant