
[RFC] Syntax changes #13

Closed
patrick-kidger opened this issue Sep 1, 2022 · 12 comments

@patrick-kidger
Owner

patrick-kidger commented Sep 1, 2022

There's a planned rewrite to change the syntax for jaxtyping.

This is available here:
https://github.com/google/jaxtyping/tree/rewrite

In short, the changes are:

  • jaxtyping.Float rather than just jaxtyping.f to denote precision-independent types. (This means we no longer have names like i and f, which are commonly used as variable names and are a bit opaque as to what they actually mean.)
    • Likewise:
       i -> IntSign
       u -> IntUnsign
       t -> Int
       x -> Inexact
       c -> Complex
       n -> Num
      
  • Changing from Float["batch length channels"] to Float[jnp.ndarray, "batch length channels"].
    • If you're not specifying the dtype at all, e.g. jaxtyping.Array["foo bar"], then this is now jaxtyping.Shaped[jnp.ndarray, "foo bar"].
    • For the sake of neat syntax we now have jaxtyping.Array = jnp.ndarray, so that you can use the nicer-looking Array instead of jnp.ndarray, if you wish.

Regarding this latter change:
Pros:

  • Partial compatibility with static type checking. (Hurrah!) Float[jnp.ndarray, "foo"] will now smoothly fall back to being treated as just jnp.ndarray by static type checkers, instead of just being hopelessly incompatible.
  • The ability to specify non-JAX-array types! Including NumPy/TensorFlow/PyTorch.
  • Compatibility with the upcoming jax.typing namespace, which has the more limited aim of supporting static type checking; the plan is for jaxtyping to be a superset of jax.typing.
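The first pro relies on static type checkers reading Float[jnp.ndarray, "foo"] as plain jnp.ndarray. A minimal sketch of one way to get that behaviour, assumed here for illustration and not jaxtyping's actual code: expose Float as an alias of typing.Annotated under typing.TYPE_CHECKING, while a runtime class keeps the dtype category and shape string as metadata. Array is a hypothetical placeholder for jnp.ndarray so the snippet runs without JAX.

```python
from typing import TYPE_CHECKING, Annotated, get_args, get_origin


class Array:
    """Hypothetical placeholder for jnp.ndarray, so this sketch runs without JAX."""


if TYPE_CHECKING:
    # What a static checker sees: Float[X, "..."] is Annotated[X, "..."],
    # and Annotated[X, ...] is treated as plain X (the metadata is ignored).
    from typing import Annotated as Float
else:

    class Float:
        """Toy runtime version: keeps the dtype category and shape string as metadata."""

        def __class_getitem__(cls, params):
            array_type, dims = params
            return Annotated[array_type, ("float", dims)]


def normalise(x: Float[Array, "batch channels"]) -> Float[Array, "batch channels"]:
    return x  # body is irrelevant; only the annotations matter here


hint = Float[Array, "batch length channels"]
print(get_origin(hint) is Annotated)  # True: at runtime it is an ordinary Annotated type
print(get_args(hint)[0] is Array)     # True: whose underlying type is just Array
```

A runtime checker can then read get_args(hint)[1] to recover ("float", "batch length channels"), while a static checker only ever sees Array.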

Cons:

  • A little more verbose.

Particular questions I'd welcome feedback on:

  • Alternate project names? Having both jax.typing and jaxtyping is a bit confusing. (But perhaps not too bad, if jaxtyping is a strict superset.)
  • Alternate names for IntSign and IntUnsign? (SInt? IntS?)
  • Alternate names for Float, e.g. just F?
  • Alternate names for f32, e.g. Float32? (At present these are still the short ones.)
  • Any strong opinions about the Float["foo"] -> Float[Array, "foo"] change mentioned above? (I've already chatted with a few Google-internal people who have strong opinions both for and against this, haha.)

CC @thomaspinder @heytanay @daniel-dodd @ayaka14732 as the folks GitHub currently lists as using this project in public repos.

@patrick-kidger
Owner Author

CC @leycec for curiosity, as someone who evidently has a lot of strong opinions on type-checking. (Also you might like the trick through which Float[Array, "foo"] is made to work with static type checkers.)

@thomaspinder

Thanks for the heads up @patrick-kidger. I'm OK with the look of both proposed changes, though I do particularly like the explicitness of the former.

To answer your specific questions in order:

  1. Personally, I'm fine with this.
  2. I like the verbosity of IntSign and IntUnsign. It is a few more characters, but it mitigates any confusion i.e., IntS could be misconstrued as multiple Ints.
  3. For my above reasoning, I like Float
  4. As above.
  5. Whilst by no means a strong opinion, I prefer Float["foo"] over Float[Array, "foo"]. Everything in GPJax where we use JaxTyping is an Array, so the additional line is somewhat superfluous.

As an aside, it would be cool to see more mathematically inspired typing systems, e.g. ColumnVector['N'] would type an N×1 vector. Extensions would be RowVector, Scalar, Matrix and Tensor. I have no idea how feasible this is or how you feel about it, but I'm happy to open a separate issue to discuss it further if it is of interest.

@ayaka14732

ayaka14732 commented Sep 1, 2022

Would Float[jnp.ndarray, "batch length channels"] be a bit confusing? I think something like jnp.ndarray[Float, "batch length channels"] or FloatArray[jnp.ndarray, "batch length channels"] would look more intuitive.

@ayaka14732

@thomaspinder

Whilst by no means a strong opinion, I prefer Float["foo"] over Float[Array, "foo"]. Everything in GPJax where we use JaxTyping is an Array, so the additional line is somewhat superfluous.

A common use case I can think of is that when the dataset is too large to fit into the TPU memory, we usually load the entire dataset into the CPU memory as a NumPy array first, then slice the dataset into small batches and convert them to JAX arrays on TPU. In this case it would be necessary to distinguish between np.array and jnp.array. (See google/jax#8933 (comment))

Besides, I am also wondering if it is possible to make np.array/jnp.array an optional second argument, so that we can just omit it if we are indifferent to the type.

@patrick-kidger
Owner Author

patrick-kidger commented Sep 1, 2022

Thank you both for your thoughts. :)

Personally I agree with both of you; I'd much rather have Float["foo bar"] or Array[Float, "foo bar"] as well. However the Float[Array, "foo bar"] syntax is the best approach I've been able to find that has any compatibility with static type checkers.

This is necessary since jax.typing will happen at some point. They're happy to work with jaxtyping to ensure we don't bifurcate into incompatible approaches, but static type checking is their goal, so we need to work with them in turn. (I'm all ears to any suggestions that will pass both mypy+pytype.)

[+I'd quite like to support static type checking as well -- it's a pretty useful thing, when it works.]

@thomaspinder - regarding point 4, you mean that you would rather see Float32 over f32? (I don't have a strong preference myself, so I'd be happy to change this.)

@ayaka14732 - indeed, specifying numpy/JAX/etc. backend is also desirable. It's a lesser priority, but supporting this is also a goal of the proposed new syntax. (Note that it's quite difficult to have optional arguments for static type checkers; off the top of my head the only things which support this are typing.{Literal, Annotated, Tuple}, and variadic generics. An Annotated solution is probably the most likely to work.)

@thomaspinder

thomaspinder commented Sep 1, 2022

Interesting. Thanks for the clarification. If it enables static type checking, then the Float[Array, "foo bar"] syntax seems like a necessary sacrifice.

@patrick-kidger Yes - I, personally, would prefer Float32 over f32.

@PhilipVinc

Alternate project names? Having both jax.typing and jaxtyping is a bit confusing. (But perhaps not too bad, if jaxtyping is a strict superset.)

Taking inspiration from the official typing-extensions Python package, maybe rename jaxtyping to jaxtyping-extensions?

@ayaka14732

jaxtyping and jax.typing

  1. I hope that static type checking is not a main goal, if it would make things more complicated. I also think that compatibility is not necessary, since it would often increase the complexity. And having two separate projects would not be a bad thing, because we are free to choose between the two projects.
  2. As a consequence, the project name would have to be changed. Moreover, considering that we will be supporting not only jnp.array but also NumPy/PyTorch/TensorFlow, the project name should perhaps not contain the word JAX.
  3. I suggest tensortyper as the new project name. It is easy to pronounce, since both words start with t and end with r.

Syntax changes

  1. I think longer names like Float are better than shorter ones like f because 'explicit is better than implicit'. And if we really want to use shorter names, we can always do from jaxtyping import Float as f.
  2. If we are going to use the syntax Float[Array, "foo bar"], I hope that it could be FloatArray[Array, "foo bar"]. This is because the former sounds like a certain type of float rather than an array, just as List[int] is a certain type of list.
  3. I prefer IntSign/SignInt to IntS/SInt because 'explicit is better than implicit'. Besides, I think it should be IntSigned/SignedInt because 'signed' is the adjective.

@irhum

irhum commented Sep 7, 2022

Thanks for this RFC @patrick-kidger! Regarding the naming for types, I'd consider "adherence to the NumPy API" to be one axis on which to evaluate syntax changes. JAX's own potential stems from "it's just NumPy", and this would help onboard new users as quickly as possible. By that line of reasoning:

  • Verbosity: Types should be more verbose; not just NumPy, but also Python's own "explicit is better than implicit" plays into Float32 feeling more "pythonic" than f32.
  • Specific Cases: Extending this reasoning, I'd prefer to see syntax such as Int and UInt over IntSign and IntUnsign, since they're closer to the established np.int and np.uint.
  • Imports: A question I've been thinking about is: should types be direct imports (e.g. from typing import Tuple) or module-level imports (e.g. import numpy as np, then use np.float32)? I could easily see either Float or jxt.float style syntax working (where jxt would be import jaxtyping as jxt). After discussion, I personally would lean towards the former (direct imports), since these classes exist to aid type-checking.
  • Open Question: How challenging would it be under Python to have a type such as Array[Float, "b t d"] as opposed to Float[Array, "b t d"] as proposed here? I'm curious whether this is a design choice or a reflection of more fundamental limitations in the language, since the former would be closer to how other languages (such as Julia's Array{Float32}) present their types.

I'll add this: I'm enjoying using this library; the multi-argument runtime type-checking over tensors is a brilliant idea, makes me go "how was I not checking this before?". The fact it just works with all the other JAX transforms (jit, pmap, etc.) is equally neat. The syntax changes shouldn't take away from the underlying achievement that, even as the library is right now, it works well! As this RFC is elaborated on, I'm curious to see the fleshing out of the "default, recommended" way to use jaxtyping going forward.

patrick-kidger added a commit that referenced this issue Sep 7, 2022
- Float32 introduced to replace f32 (etc.)
- Old f32 aliases were previously around for backward-compatibility, but
  the code is pretty hideous, and all stakeholders are now on board
  through RFC #13.
- Overall feedback was that Int{Sign,Unsign} was too long and not enough
  like numpy. Changed to Int, UInt, and Integer.
@patrick-kidger patrick-kidger mentioned this issue Sep 7, 2022
@patrick-kidger
Owner Author

Thanks all for your feedback. The new version has now been merged: #16.

(In particular I appreciate all the positive feedback to the effect of "jaxtyping is super useful, thanks" -- this warms a library author's heart.)

To respond to the last round of comments raised:

  • I liked the suggestion of @PhilipVinc to rename to jaxtyping-extensions, and came close to doing this. In the end I decided against this just because a shorter name is honestly a little more marketable. (+one fewer backward-incompatible change)
  • The general feedback was a preference for Float32 over f32, so that has now happened.
  • Lots of different ideas on what to call the integer types. In the end I decided to follow NumPy: Int for signed integers, UInt for unsigned integers, and Integer for both.
  • With apologies to @ayaka14732, the goal of these changes was in large part to provide compatibility with static typing, which is important for compatibility with core JAX.
  • @irhum - indeed Array[Float, "foo bar"] would have been a nice syntax, but unfortunately this is incompatible with static typing.
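The final integer naming follows NumPy's abstract dtype hierarchy, in which each annotation name accepts a set of concrete dtypes and broader names are unions of narrower ones. A minimal sketch of that idea, assuming dtypes are identified by plain name strings; the sets below are illustrative and not copied from jaxtyping, and CATEGORIES and dtype_matches are hypothetical names:

```python
# Illustrative dtype sets (assumed, not taken from jaxtyping's source), grouped
# the way NumPy's abstract scalar hierarchy groups them.
INT = frozenset({"int8", "int16", "int32", "int64"})       # like np.signedinteger
UINT = frozenset({"uint8", "uint16", "uint32", "uint64"})  # like np.unsignedinteger
INTEGER = INT | UINT                                        # like np.integer
FLOAT = frozenset({"float16", "float32", "float64"})       # like np.floating
COMPLEX = frozenset({"complex64", "complex128"})           # like np.complexfloating
INEXACT = FLOAT | COMPLEX                                   # like np.inexact
NUM = INTEGER | INEXACT                                     # like np.number

CATEGORIES = {
    "Int": INT,
    "UInt": UINT,
    "Integer": INTEGER,
    "Float": FLOAT,
    "Float32": frozenset({"float32"}),  # precision-specific name
    "Complex": COMPLEX,
    "Inexact": INEXACT,
    "Num": NUM,
}


def dtype_matches(category: str, dtype_name: str) -> bool:
    """Return True if an array with dtype `dtype_name` fits annotation `category`."""
    return dtype_name in CATEGORIES[category]


print(dtype_matches("Integer", "uint8"))  # True: Integer covers signed and unsigned
print(dtype_matches("Int", "uint8"))      # False: Int is signed only
```

Building the wide categories as unions of the narrow ones keeps them consistent by construction: anything accepted by Int or UInt is automatically accepted by Integer and Num.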

Once again, thank you everyone for your engagement!

@riven314

riven314 commented Sep 8, 2022

@patrick-kidger
I am sorry if my question is out of scope, but I have a question about multi-argument type checking, since I see this concept mentioned in the thread:
in the context of tensor shapes, what is special about multi-argument type checking (vs. single-argument type checking)? Would you mind briefly explaining how it works differently with multiple arguments? (I couldn't find related references with a few searches.) And why does it have to be decorated by @jaxtyped and @typechecker?

@patrick-kidger
Owner Author

The multi-argument checking is in reference to checking that, across multiple arguments to a function, the sizes of multiple arrays should agree.

For example, def foo(x: Shaped[Array, "bar"], y: Shaped[Array, "bar"]) should be called with two arrays of the same size.
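The idea of sizes agreeing across arguments can be made concrete with a toy decorator. This is a minimal sketch under stated assumptions, not jaxtyping's implementation: check_shapes, ShapeError and FakeArray are hypothetical names, and Annotated[FakeArray, "bar"] stands in for Shaped[Array, "bar"]. The key ingredient is one dictionary of dimension sizes per call: the first argument that uses "bar" binds its size, and every later use must match. My understanding is that jaxtyping splits this between @jaxtyped (the per-call scope) and the runtime type checker (the per-argument checks); here a single decorator does both, as a simplification.

```python
import functools
import inspect
from typing import Annotated, get_args, get_origin, get_type_hints


class ShapeError(TypeError):
    """Raised when an argument's type or shape disagrees with its annotation."""


def check_shapes(fn):
    """Toy multi-argument checker (hypothetical; not jaxtyping's implementation).

    Each parameter annotated as Annotated[SomeType, "dim1 dim2 ..."] must be an
    instance of SomeType, have one axis per name, and use the same size for a
    given name as every other argument in the same call.
    """
    hints = get_type_hints(fn, include_extras=True)
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        sizes = {}  # fresh for every call: dimension name -> size bound so far
        for name, value in bound.arguments.items():
            hint = hints.get(name)
            if get_origin(hint) is not Annotated:
                continue  # parameter without a shape annotation: not checked
            array_type, dims = get_args(hint)[0], get_args(hint)[1].split()
            if not isinstance(value, array_type):
                raise ShapeError(f"{name}: expected {array_type.__name__}")
            if len(value.shape) != len(dims):
                raise ShapeError(f"{name}: expected {len(dims)} axes, got shape {value.shape}")
            for dim, size in zip(dims, value.shape):
                expected = sizes.setdefault(dim, size)  # first use binds the size
                if size != expected:
                    raise ShapeError(f"{name}: axis {dim!r} has size {size}, expected {expected}")
        return fn(*args, **kwargs)

    return wrapper


class FakeArray:
    """Minimal stand-in for an array type (hypothetical): it only carries a shape."""

    def __init__(self, *shape):
        self.shape = tuple(shape)


@check_shapes
def foo(x: Annotated[FakeArray, "bar"], y: Annotated[FakeArray, "bar"]) -> int:
    return x.shape[0]


foo(FakeArray(3), FakeArray(3))    # passes: "bar" is 3 for both arguments
# foo(FakeArray(3), FakeArray(4))  # would raise ShapeError: "bar" is bound to 3, then 4
```

Because sizes is created inside wrapper, bindings never leak between calls: foo(FakeArray(3), FakeArray(3)) followed by foo(FakeArray(5), FakeArray(5)) both pass. A single-argument checker could verify each array's rank, but only a shared per-call scope like this can catch the mismatch between x and y.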


6 participants