
Better Numpy API (interoperability between ML frameworks) #94779

Open
17 tasks
Conchylicultor opened this issue Feb 13, 2023 · 6 comments
Labels
module: numpy (Related to numpy support, and also numpy compatibility of our operators)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

Conchylicultor commented Feb 13, 2023

Context

For better interoperability between ML frameworks, it would be great if the torch API matched the numpy API more closely (like tf.experimental.numpy and jax.numpy do).

This is a highly requested feature, see for example #2228 (~100 upvotes), #50344, #38349, ... Those issues have since been closed even though many issues remain.

This is even more relevant with the numpy API standard (NEP 47): The goal is to write functions once, then reuse them across frameworks:

def somefunc(x, y):
    xnp = get_numpy_module(x, y)  # Returns `np`, `jnp`, `tf.numpy`, `torch`
    out = xnp.mean(x, axis=0) + 2*xnp.std(y, axis=0)
    return out
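
For illustration, a simplified sketch of how get_numpy_module can be implemented (the module-name detection here is intentionally minimal, not the exact logic our libraries use):

import numpy as np

def get_numpy_module(*arrays):
    # Return the numpy-like module matching the input arrays.
    for x in arrays:
        module = type(x).__module__
        if module.startswith("torch"):
            import torch
            return torch
        if module.startswith("jax"):
            import jax.numpy as jnp
            return jnp
        if module.startswith("tensorflow"):
            import tensorflow.experimental.numpy as tnp
            return tnp
    return np  # plain numpy arrays / Python scalars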

Our team has multiple universal libraries which support numpy, jax and TF (like dataclass_array or visu3d).

We've been experimenting with adding torch support recently but encountered quite a lot of issues (while tf.numpy and jax worked (mostly) out of the box). Here are all the issues we encountered:

numpy API issues

Some common methods (present in np, jnp, tf.numpy) are missing from torch:

Behavior:

Casting:

  • x.mean() currently requires an explicit dtype when x.dtype == int32 (x.mean(float32)). Other frameworks default to the default float type (float32) (see the snippet after this list)
  • torch.allclose(x, y) currently fails if x.dtype != y.dtype, which is inconsistent with np, jnp, and tf (this is very convenient in tests: np.allclose(x, [1, 2, 3]))
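
A minimal repro of the mean difference (exact error messages depend on the torch version):

import numpy as np
import torch

np.arange(6).mean()                        # works: int input is promoted to float64
torch.arange(6).mean()                     # RuntimeError: mean() needs a floating point dtype
torch.arange(6).mean(dtype=torch.float32)  # workaround: pass the dtype explicitly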

Other differences (but not critical to fix):

Testing and experimenting

Those issues have been found in real production code. Our projects have a @enp.testing.parametrize_xnp decorator to run the same unit tests on tf, jax, numpy, and torch to make sure our code works on all backends.

For example: https://github.com/google-research/visu3d/blob/89d2a6c9cb3dee1d63a2f5a8416272beb266510d/visu3d/math/rotation_utils_test.py#L29

In order to make our tests pass with torch, we had to mock torch to fix some of those behaviors: https://github.com/google/etils/blob/main/etils/enp/torch_mock.py
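
For illustration, the kind of patch involved looks roughly like this (a simplified sketch, not the actual torch_mock.py code):

import torch

_orig_allclose = torch.allclose

def _allclose(input, other, **kwargs):
    # Promote both arguments to a common dtype so mixed-dtype comparisons
    # behave like np.allclose (simplified; the real shim covers more cases).
    input = torch.as_tensor(input)
    other = torch.as_tensor(other)
    dtype = torch.promote_types(input.dtype, other.dtype)
    return _orig_allclose(input.to(dtype), other.to(dtype), **kwargs)

torch.allclose = _allclose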

Having a universal standard API that all ML frameworks follow would be a great step forward. I hope this issue is a small step toward this goal.

cc @mruberry @rgommers

vadimkantorov (Contributor) commented Feb 13, 2023

Another NumPy compat issue on supporting string dtypes: #40568 (comment) (useful for porting existing numpy code)

And existing issue for append: #64359

cpuhrsch added the "module: numpy" label on Feb 14, 2023
cpuhrsch (Contributor) commented

cc @jisaacso

cpuhrsch added the "triaged" label on Feb 14, 2023
rgommers (Collaborator) commented

Hi @Conchylicultor, thank you for the detailed write-up! The goal of adding torch support in addition to numpy and jax sounds great.

There are two separate parts here: the big-picture "PyTorch compatibility with NumPy" question, and the presence and behavior of individual objects in the API. For the big-picture part:

  • Let me note that there is no "numpy API standard (NEP 47)"; rather, it's the Python array API standard (https://data-apis.org/array-api/latest/), which is a lot like the core of NumPy but also has significant changes to make it better fit PyTorch, JAX, and other GPU/JIT/etc. libraries. For NumPy, NEP 47 introduces a separate submodule, numpy.array_api, but that's a bit clumsy; for NumPy 2.0 (target Dec '23) we should have the main namespace fully compatible with it.
  • For PyTorch's compatibility with the NumPy API, let me point out that there's an effort ongoing to implement full support via TorchDynamo (graph mode only). See https://github.com/Quansight-Labs/numpy_pytorch_interop/ and Can we rewrite numpy operators to pytorch operators? #93684.
  • For PyTorch's compatibility with the array API standard, there's the array-api-compat package, and full support for PyTorch should land by the end of this week (via PyTorch compatibility layer data-apis/array-api-compat#14). That package is designed to be very lightweight and vendorable, and is probably a good way to add PyTorch support in your projects.
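
For example, with array-api-compat the pattern from the issue description becomes roughly (a sketch based on its array_namespace entry point; details may still change as the PyTorch support lands):

from array_api_compat import array_namespace

def somefunc(x, y):
    xp = array_namespace(x, y)  # returns a standard-compliant namespace for numpy/torch/...
    return xp.mean(x, axis=0) + 2 * xp.std(y, axis=0)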

Having a universal standard API that all ML frameworks follow would be a great step forward. I hope this issue is a small step toward this goal.

+1

Here are all the issues we encountered:

These are a mix of things that should be added to PyTorch, things that would be nice to add/change but may be challenging because of backwards compatibility (array-api-compat and/or the TorchDynamo support for the NumPy API should hopefully help here), and things that are best avoided in your own code. I'll post responses for each item in a follow-up comment.

rgommers (Collaborator) commented

Some common methods (present in np, jnp, tf.numpy) are missing from torch:

  • torch.array: like x = xnp.array([1, 2, 3]) (alias of torch.tensor)

np.array is best avoided in your code; numpy unfortunately has a lot of baggage here, but it's essentially the same as np.asarray. And there's a matching torch.asarray.
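
E.g. this already works the same way in both libraries:

import numpy as np
import torch

np.asarray([1, 2, 3])     # numpy array
torch.asarray([1, 2, 3])  # torch tensor (torch.asarray is available in recent releases)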

  • torch.ndarray: like isinstance(x, xnp.ndarray) (alias of torch.Tensor)

This one I'm not 100% sure about either way. We left naming the array object out of the array API standard on purpose, because there's a bunch of different names floating around. isinstance checks aren't used a lot because they conflict with array subclasses (many libraries apply asarray to their inputs to avoid non-Liskov substitutable subclasses), and also with __array_function__/__torch_function__.

  • Tensor.astype: like x = x.astype(np.uint8)

Yes, this is a gap. I'll note that the array API standard has it as an astype function rather than a method. PyTorch has neither.
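
Until that gap is closed, a small user-side helper works as a stopgap (illustrative sketch, not an existing torch or numpy API):

import numpy as np
import torch

def astype(x, dtype):
    # Function-style astype, mirroring the array API standard's astype(x, dtype).
    if isinstance(x, torch.Tensor):
        return x.to(dtype)
    return x.astype(dtype)

astype(np.zeros(3), np.uint8)        # numpy path
astype(torch.zeros(3), torch.uint8)  # torch path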

Re append: could be added indeed. Not high-prio I think, because append is more often misused than used correctly when needed (and stack is a good alternative), especially by beginning users who treat arrays like lists and append in a loop.

This would be nice to add as an alias indeed, and easy to do.

I'll note that there are quite a few function pairs like this in numpy, e.g. max/amax, range/arange, round/around, etc. The ones without a prepended 'a' should be preferred in general, and I am planning to do a significant cleanup here for numpy 2.0. np.round does work, even though it's also the one where it's not clear that it's the preferred alias. Both PyTorch and the array API standard also have round.

Behavior:

There doesn't seem to be much engagement on gh-40568. It could be done, but it's kind of weird of course. Having the same names should be enough, just like for other objects (like regular functions) in the namespaces. The intended pattern here is xp.func(x, dtype=xp.float32), which works fine today.

  • torch.dtype should be comparable with np.dtype: tf.int32 == np.int32 but torch.int32 != np.int32. This would allow agnostic comparisons: x.dtype == np.uint8 working for all frameworks.

Similar - could be added, and would be fairly pragmatic, but also technically incorrect. With numpy dtypes you can for example create scalars (np.float32(0.5) gives you an object that PyTorch/JAX/etc. do not have).

After a lot of discussion, we settled on isdtype in the array API standard. So comparisons like you want here will be isdtype(x.dtype, xp.int32).
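
Roughly like this (assuming a recent array-api-compat that exposes isdtype; the kind strings are the ones defined by the standard):

import numpy as np
from array_api_compat import array_namespace

x = np.zeros(3, dtype=np.int32)
xp = array_namespace(x)
xp.isdtype(x.dtype, xp.int32)    # exact dtype match
xp.isdtype(x.dtype, "integral")  # kind-based match (any integer dtype)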

This is a fairly annoying UX papercut in PyTorch indeed. It'd be great if all integer tensors could be used for indexing.

This one I've wanted to see fixed for a long time; it's not high-prio as well as bc-breaking unfortunately. The array-api-compat package should help here.

  • torch.ones(shape=()) fails (it expects size=), but xnp.ones(shape=()) works in the other frameworks. Same for torch.zeros, ...

Pass the shape as a positional argument to avoid this - that's more idiomatic anyway, and the array API standard makes these arguments positional-only, so this will not be an issue.
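
I.e. the portable spelling is simply:

import numpy as np
import torch

np.ones((2, 3))     # shape passed positionally
torch.ones((2, 3))  # same call works in torch (size is the first positional argument)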

Related: PyTorch using Tensor.size where NumPy et al. use .shape is perhaps the single biggest compatibility issue; it won't be fixed unfortunately, since use of .size by PyTorch users is too widespread.

Casting:

  • x.mean() currently requires an explicit dtype when x.dtype == int32 (x.mean(float32)). Other frameworks default to the default float type (float32)

Would be nice to fix, but is a minor papercut. I'll also note that the array API standard mandates only floating point dtype support for mean, primarily because TensorFlow in general rejects cross-kind casting.

  • torch.allclose(x, y) currently fails if x.dtype != y.dtype, which is inconsistent with np, jnp, and tf (this is very convenient in tests: np.allclose(x, [1, 2, 3]))

This would be good to fix up indeed.

Other differences (but not critical to fix):

That does work for operators. I'm not sure it can be improved upon without accepting numpy arrays to every PyTorch function (which I don't think is a good idea).

  • Mixing float and double

That seems useful indeed; I think it works in many places, but probably not 100% consistently.

  • Would be nice if torch.asarray supported jax and tf tensors (and vice-versa).

This should be made to work via every array/tensor object supporting __dlpack__. PyTorch does support this (see docs in from_dlpack), so this may be more of a TF/JAX feature request? I'd have to check in more detail to be sure.
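
For example, the JAX-to-PyTorch direction already works along these lines (sketch; requires reasonably recent JAX and torch versions):

import jax.numpy as jnp
import torch

x_jax = jnp.arange(6.0)
x_torch = torch.from_dlpack(x_jax)  # zero-copy exchange via __dlpack__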

Conchylicultor (Author) commented Mar 13, 2023

Thanks for the answer.

I'll note that the array API standard has it as an astype function rather than a method

This feels strange to me: x.astype(np.int32) feels much less verbose than np.astype(x, np.int32). For such a common operation (e.g. casting back to uint8 for images), I wish array.astype had been kept.

Is there a place to report such issues?

The intended pattern here is np.func(x, dtype=xp.float32), which works fine today.

Unfortunately it doesn't work with np.bool_:

xp.zeros((), dtype=xp.bool_)  # torch has no attribute `torch.bool_`

It looks like the numpy API will have np.bool, but in the meantime, this is quite unfortunate.
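
For now we paper over it with a small fallback helper (illustrative, not an existing API):

import numpy as np
import torch

def bool_dtype(xnp):
    # numpy spells it bool_, torch spells it bool; pick whichever exists.
    return getattr(xnp, "bool_", None) or xnp.bool

torch.zeros((), dtype=bool_dtype(torch))
np.zeros((), dtype=bool_dtype(np))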


  • Another API mismatch: np.concatenate in jnp and tnp, but torch.concat / torch.cat in torch.

Overall, I was able to fix those issues by adding compat.astype(), compat.expand_dims(), ... and so on. It's a little unfortunate that this is required, but at least it unblocks us in the meantime.

My feeling is that multi-framework support today is way more complicated than it should be. Because of all those small issues, it actually requires quite a lot of effort to make the code compatible with torch/TF, and often requires custom wrappers/helpers.

rgommers (Collaborator) commented

Is there a place to report such issues?

Yes, the issue tracker of the standard. Here is the relevant PR for reference: data-apis/array-api#290.

Unfortunately it doesn't work with np.bool_:

Yes, indeed - NumPy really needs to reinstate np.bool. That should happen by 2.0 at the end of the year. In the meantime, the array-api-compat package is bridging these differences as much as possible (and it has bool).

Another API mismatch: np.concatenate in jnp and tnp, but torch.concat / torch.cat in torch.

I expect that we'll add concat to NumPy this year (numpy/numpy#16469).

My feeling is that multi-framework support today is way more complicated than it should be. Because of all those small issues, it actually requires quite a lot of effort to make the code compatible with torch/TF, and often requires custom wrappers/helpers.

Yes, you're completely right. That's why we've spent so much effort on standardization; that now needs to be rolled out in the main namespaces. Some issues are going to remain though, due to backwards compatibility concerns. For PyTorch the most prominent one is Tensor.size, which matches .shape in other libraries rather than .size.
