[RFC] Syntax changes #13
CC @leycec for curiosity, as someone who evidently has a lot of strong opinions on type-checking. (Also you might like the trick through which |
Thanks for the heads up @patrick-kidger. I'm OK with the look of both proposed changes, though I do particularly like the explicitness of the former. To answer your specific questions in order:
As an aside, it would be cool to see more mathematically inspired typing systems e.g., |
Would |
A common use case I can think of: when the dataset is too large to fit into TPU memory, we usually load the entire dataset into CPU memory as a NumPy array first, then slice it into small batches and convert them to JAX arrays on TPU. In this case it would be necessary to distinguish between […]. Besides, I am also wondering if it is possible to make |
Thank you both for your thoughts. :) Personally I agree with both of you; I'd much rather have […]. This is necessary since I'd quite like to support static type checking as well -- it's a pretty useful thing, when it works.

@thomaspinder - regarding point 4, you mean that you would rather see […]

@ayaka14732 - indeed, specifying the numpy/JAX/etc. backend is also desirable. It's a lesser priority, but supporting this is also a goal of the proposed new syntax. (Note that it's quite difficult to have optional arguments for static type checkers; off the top of my head the only things which support this are |
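The reason the array type has to be spelled out for static checkers can be sketched as follows. This is a hypothetical mini-reimplementation, not jaxtyping's code.

It makes three simplifying assumptions:

- The static side is handled by aliasing `Float`/`Shaped` to `typing.Annotated` under `typing.TYPE_CHECKING`. A checker treats `Annotated[T, ...]` as plain `T`.
- NumPy's `np.ndarray` stands in for `jnp.ndarray`, so the sketch runs without JAX.
- The runtime check covers only the array type, the dtype kind, and the rank.

```python
import typing

import numpy as np

if typing.TYPE_CHECKING:
    # What a static checker sees: Float[T, "dims"] is Annotated[T, "dims"],
    # which it treats as plain T. So Float[np.ndarray, "b c"] degrades to
    # np.ndarray instead of being an unknown, incompatible type.
    from typing import Annotated as Float, Annotated as Shaped
else:
    # What runs at runtime: subscripting returns a class whose metaclass
    # implements isinstance() by checking array type, dtype and rank.
    def _make_hint(name, dtype_ok):
        class _Meta(type):
            def __instancecheck__(cls, obj):
                return (
                    isinstance(obj, cls.array_type)
                    and dtype_ok(obj.dtype)
                    and obj.ndim == len(cls.dims)
                )

        class _Hint:
            def __class_getitem__(cls, item):
                array_type, dim_str = item
                # Only the number of named dims is checked here; binding
                # names to sizes across arguments would be a further step.
                return _Meta(
                    f"{name}[{array_type.__name__}, {dim_str!r}]",
                    (),
                    {"array_type": array_type, "dims": tuple(dim_str.split())},
                )

        return _Hint

    # Float accepts floating-point dtypes of any precision; Shaped accepts any dtype.
    Float = _make_hint("Float", lambda dtype: np.issubdtype(dtype, np.floating))
    Shaped = _make_hint("Shaped", lambda dtype: True)


x = np.zeros((2, 3), dtype=np.float32)
print(isinstance(x, Float[np.ndarray, "batch channels"]))                   # True
print(isinstance(x.astype(np.int32), Float[np.ndarray, "batch channels"]))  # False
print(isinstance(x.astype(np.int32), Shaped[np.ndarray, "batch channels"])) # True
```

A bare `Float["batch channels"]` has no underlying type for a checker to fall back to. That is the motivation for passing the array type as the first argument.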
Interesting. Thanks for the clarification. If it enables static type checking, then the […]

@patrick-kidger Yes - I, personally, would prefer |
Taking inspiration from the official `typing-extensions` Python package, maybe rename |
Thanks for this RFC @patrick-kidger! Regarding the naming for types, I'd consider "adherence to the NumPy API" to be one axis on which to evaluate syntax changes. JAX's own potential stems from "it's just NumPy", and this would help onboard new users as quickly as possible. By that line of reasoning:
I'll add this: I'm enjoying using this library; the multi-argument runtime type-checking over tensors is a brilliant idea and makes me go "how was I not checking this before?". The fact that it just works with all the other JAX transforms ( |
- `Float32` introduced to replace `f32` (etc.).
- Old `f32` aliases were previously around for backward compatibility, but the code is pretty hideous, and all stakeholders are now on board through RFC #13.
- Overall feedback was that `Int{Sign,Unsign}` was too long and not enough like NumPy. Changed to `Int`, `UInt`, and `Integer`.
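The meaning of the renamed dtype categories can be illustrated with a small lookup table. It covers three distinctions:

- signed vs. unsigned vs. any integer;
- any float precision vs. one specific precision.

The dtype-name sets and the `accepts` helper are hypothetical, written for this sketch only. They are not jaxtyping's own definitions.

```python
# Illustrative dtype-name sets (an assumption for this sketch, not
# jaxtyping's own tables).
SIGNED_INTS = {"int8", "int16", "int32", "int64"}
UNSIGNED_INTS = {"uint8", "uint16", "uint32", "uint64"}
FLOATS = {"bfloat16", "float16", "float32", "float64"}

CATEGORIES = {
    "Int": SIGNED_INTS,                      # signed integers only
    "UInt": UNSIGNED_INTS,                   # unsigned integers only
    "Integer": SIGNED_INTS | UNSIGNED_INTS,  # any integer
    "Float": FLOATS,                         # any floating-point precision
    "Float32": {"float32"},                  # one specific precision (formerly f32)
}


def accepts(category: str, dtype_name: str) -> bool:
    """Return True if an array with this dtype name matches the category."""
    return dtype_name in CATEGORIES[category]


for name in ("Int", "UInt", "Integer"):
    print(name, accepts(name, "int32"), accepts(name, "uint8"))
# Int True False
# UInt False True
# Integer True True
```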
Thanks all for your feedback. The new version has now been merged: #16. (In particular I appreciate all the positive feedback to the effect of "jaxtyping is super useful, thanks" -- this warms a library author's heart.) To respond to the last round of comments raised:
Once again, thank you everyone for your engagement!
@patrick-kidger |
The multi-argument checking is in reference to checking that, across multiple arguments to a function, the sizes of multiple arrays should agree. For example |
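The cross-argument size check described above can be sketched in plain Python as a decorator. `check_dims` is a hypothetical name. Its keyword-argument interface is an assumption made to keep the sketch self-contained; in jaxtyping the dimension names come from annotations such as `Float[jnp.ndarray, "batch length channels"]`. Arrays are duck-typed here, so anything with a `.shape` works.

```python
import functools
import inspect


def check_dims(**specs):
    """Hypothetical decorator: maps argument names to space-separated dim names.

    Arguments are checked in the order given to the decorator. The first one
    that uses a dim name fixes its size for this call; every later argument
    using the same name must match that size.
    """

    def decorator(fn):
        sig = inspect.signature(fn)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            sizes = {}  # dim name -> size seen so far, reset on every call
            for arg_name, spec in specs.items():
                shape = tuple(bound.arguments[arg_name].shape)
                dims = spec.split()
                if len(shape) != len(dims):
                    raise TypeError(
                        f"{arg_name}: expected {len(dims)} dims {dims}, got shape {shape}"
                    )
                for dim, size in zip(dims, shape):
                    expected = sizes.setdefault(dim, size)
                    if expected != size:
                        raise TypeError(
                            f"{arg_name}: dim {dim!r} has size {size}, "
                            f"but an earlier argument fixed it to {expected}"
                        )
            return fn(*args, **kwargs)

        return wrapper

    return decorator


# Usage with any objects exposing .shape (NumPy or JAX arrays, for instance):
#
#   @check_dims(x="a b", y="b c")
#   def matmul(x, y):
#       return x @ y
#
#   matmul(np.ones((2, 3)), np.ones((3, 4)))  # passes: b is 3 in both
#   matmul(np.ones((2, 3)), np.ones((5, 4)))  # raises TypeError (b: 3 vs 5)
```

Binding sizes per call, rather than globally, means each invocation can use different concrete sizes. Within one invocation, the sizes must still agree across arguments.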
There's a planned rewrite to change the syntax for `jaxtyping`. This is available here:
https://github.com/google/jaxtyping/tree/rewrite
In short, the changes are:
- `jaxtyping.Float` rather than just `jaxtyping.f` to denote precision-independent types. (Means we don't have names like `i` and `f` that are commonly-used variable names, and a bit opaque as to what they actually mean.)
- […] `Float["batch length channels"]` to `Float[jnp.ndarray, "batch length channels"]`.
- […] `jaxtyping.Array["foo bar"]`, then this is now `jaxtyping.Shaped[jnp.ndarray, "foo bar"]`.
- `jaxtyping.Array = jnp.ndarray`, so that you can use the nicer-looking `Array` instead of `jnp.ndarray`, if you wish.

Regarding this latter change:

Pros:
- `Float[jnp.ndarray, "foo"]` will now smoothly fall back to being treated as just `jnp.ndarray` by static type checkers, instead of being hopelessly incompatible.
- […] `jax.typing` namespace, which has more limited aims of static type checking support; the plan is for `jaxtyping` to be a superset of `jax.typing`.

Cons:
Particular questions I'd welcome feedback on:
1. […] `jax.typing` and `jaxtyping` is a bit confusing. (But perhaps not too bad, if `jaxtyping` is a strict superset.)
2. […] `IntSign` and `IntUnsign`? (`SInt`? `IntS`?)
3. […] `Float`, e.g. just `F`?
4. […] `f32`, e.g. `Float32`? (At present these are still the short ones.)
5. […] `Float["foo"] -> Float[Array, "foo"]` change mentioned above? (I've already chatted with a few Google-internal people who have strong opinions both for and against this, haha.)

CC @thomaspinder @heytanay @daniel-dodd @ayaka14732 as the folks GitHub currently lists as using this project in public repos.