Add docs for torch.compile(numpy) #109710

Closed
wants to merge 8 commits
49 changes: 32 additions & 17 deletions docs/source/torch.compiler_faq.rst
@@ -531,7 +531,7 @@ parameters of the ``nn.Module``. To get this to work, use
Does NumPy work with ``torch.compile``?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Starting in 2.1., ``torch.compile`` understands native NumPy programs that
Starting in 2.1, ``torch.compile`` understands native NumPy programs that
work on NumPy arrays, and mixed PyTorch-NumPy programs that convert from PyTorch
to NumPy and back via ``x.numpy()``, ``torch.from_numpy``, and related functions.
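
As a rough illustration (a minimal sketch, not taken from this PR; the function name and shapes are made up), a mixed PyTorch-NumPy program of this kind might look like:

.. code-block:: python

   import numpy as np
   import torch

   @torch.compile
   def mixed_fn(x: torch.Tensor) -> torch.Tensor:
       # Convert to NumPy, do NumPy work, convert back; torch.compile can
       # trace through this round trip starting with PyTorch 2.1.
       a = x.numpy()
       b = np.sin(a) ** 2 + np.cos(a) ** 2
       return torch.from_numpy(b)

   out = mixed_fn(torch.randn(8))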

@@ -543,33 +543,46 @@ Which NumPy features does ``torch.compile`` support?
NumPy within ``torch.compile`` follows the latest NumPy release.
Contributor:
do we have some way of getting the version of NumPy that torch.compile works / is tested against?

Collaborator Author:
Nope. This is more of something we would like to have, and a rule for people to submit issues / for us to answer issues.

Contributor:
cool, can you file an issue to follow up on this?

Collaborator Author:
I think there is no issue to follow up on here. I think we should say that we support the last NumPy release, or I can change it to "the same NumPy release that PyTorch supports". I don't think we have any way to get this, but I remember it's written somewhere. @albanD knows where IIRC. I can change the wording and put a link to the relevant place where we track this.

Contributor:
What I mean is the following:

    1. Clearly we are encoding some NumPy behavior based on some version of NumPy, and as written it seems like this can differ from the installed version of NumPy.
    2. "Latest" as a descriptor isn't super helpful for the user - they could be using an old version of PyTorch, or NumPy could have just updated between PyTorch releases.
    3. If we know what 1. is, we should just tell the user (in code!) so they can check and do what they will with that information, such as installing a NumPy version to match or upgrading PyTorch.

Collaborator Author:
For point 3, do we currently have a way to retrieve the NumPy version we support when using it in eager or when running the tests?

Collaborator:
@lezcano isn't the north star here that the PyTorch runtime should match the behavior of the currently installed NumPy version (with obviously only a subset of versions being properly supported and some leading to nice unsupported errors)?
As of today we're basically targeting a given version of NumPy 2.X, right? And I would agree that we should make sure that we have CI running with it.

Collaborator Author:
Yes, at the moment we are implementing the last version of NumPy (pretty much, modulo that NEP50 point). Then I guess we'll be able to add support for NumPy 2.0 (while keeping support for 1.X as well) once we support it in core.

Collaborator:
That sounds good. I would just not say "the last version of NumPy" but "the Numpy 2.0 pre-release" so that it doesn't become stale if we don't update it.

Collaborator Author:
That makes complete sense. Updating


Generally, ``torch.compile`` is able to trace through most NumPy constructions,
and when it cannot, it fallbacks to eager and lets NumPy execute that piece of
and when it cannot, it falls back to eager and lets NumPy execute that piece of
code. Even then, there are a few features where ``torch.compile`` semantics
slightly deviate from those of NumPy:

- NumPy scalars: We model them as 0-D arrays. That is, ``np.float32(3)`` returns
a 0-D array under ``torch.compile``. For performance, it is best to use this 0-D
a 0-D array under ``torch.compile``. To avoid a graph break, it is best to use this 0-D
array. If this is not possible, one may often recover the original behavior
by cast the NumPy scalar to the relevant Python scalar (e.g., ``float``).
by casting the NumPy scalar to the relevant Python scalar type ``bool/int/float``.
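
A short sketch of the scalar point above (illustrative only; the functions are made up):

.. code-block:: python

   import numpy as np
   import torch

   @torch.compile
   def scale(x):
       s = np.float32(0.5)      # modeled as a 0-D array under torch.compile
       return x * s             # stays inside the traced graph

   @torch.compile
   def scale_scalar(x):
       s = float(np.float32(0.5))  # casting to a Python float recovers the
       return x * s                # usual scalar behavior when needed

   y = scale(np.ones(4, dtype=np.float32))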

- Negative strides: ``np.flip`` and slicing with a negative step return a copy.
Contributor:
not a question for this documentation, but why don't we fall back in this case?

Collaborator Author:
We could, but we found that the pattern x[::-1].fn() where fn is an out-of-place op is relatively common (it's in torchbench even) so I think it's worth this small dissonance. If people don't agree, we can always reconsider and fall back, or even consider implementing this ourselves, as inductor does support negative strides.

Collaborator:
Consider inverting

Suggested change
- Negative strides: ``np.flip`` and slicing with a negative step return a copy.
- Negative strides: views that result in negative strides will instead copy, e.g. ``np.flip`` or slicing with a negative step.
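
To make the copy-vs-view point concrete, here is a small sketch (a hypothetical function; it assumes the documented copy semantics and that the whole body is traced without a graph break):

.. code-block:: python

   import numpy as np
   import torch

   @torch.compile
   def flip_first(x):
       y = x[::-1]   # a view in eager NumPy, but a copy under torch.compile
       y[0] = 42     # mutates x in eager NumPy; leaves x untouched when compiled
       return x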


- Type promotion: NumPy's type promotion will change in NumPy 2.0. The new rules
are described in `NEP 50 <https://numpy.org/neps/nep-0050-scalar-promotion.html>`__.
``torch.compile`` implements NEP 50 rather than the current soon-to-be deprecated rules.
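
One place where the two rule sets visibly differ (a sketch only, assuming the NEP 50 behavior described above is applied; not an example from the PR):

.. code-block:: python

   import numpy as np
   import torch

   @torch.compile
   def promote(x):
       return x + np.float64(1.0)

   x = np.ones(3, dtype=np.float32)
   # Following NEP 50, the NumPy scalar's dtype is respected and the result
   # should be float64; NumPy's legacy value-based rules would keep float32.
   print(promote(x).dtype)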

- ``{tril,triu}_indices_from`` return arrays rather than lists of tuples.

There are other feature for which we do not support tracing and we gracefully
There are other features for which we do not support tracing and we gracefully
fall back to NumPy for their execution:

- Non-numeric dtypes like datetimes, strings, chars, void, structured dtypes and recarrays.
- Long dtypes like ``np.float128`` and ``np.complex256``.

- Long dtypes ``np.float128/np.complex256`` and some unsigned dtypes ``np.uint16/np.uint32/np.uint64``.

- ``ndarray`` subclasses.

- Masked arrays.
- Ufunc machinery like ``axes=[(n,k),(k,m)->(n,m)]``, ``np.add.reduce``, etc.
- Fortran orders and, in general, any ``order=`` different to ``C``.

- Esoteric ufunc machinery like ``axes=[(n,k),(k,m)->(n,m)]`` and ufunc methods (e.g., ``np.add.reduce``).

- Fortran ordered arrays and, in general, any ``order=`` different to ``C``.
Collaborator:
(sorry, did not catch it on the first read).
Strictly speaking, some functions have defaults different from 'C' ('A' or 'K'). This does not matter in practice, and we graph break / raise NotImplementedError for non-default order arguments.

Maybe just say any non-default values of the order= arguments?

Collaborator Author:
I'm just going to remove this line as we also graph break on a few other inputs and it'd be a pain to list them all.


- Sorting / ordering ``complex64/complex128`` arrays.

- NumPy ``np.poly1d`` and ``np.polynomial``.

- Positional ``out1, out2`` args in functions with 2 or more returns (``out=tuple`` does work).

- ``__array_function__``, ``__array_interface__`` and ``__array_wrap__``.

- ``ndarray.ctypes`` attribute not supported.
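
As a rough sketch of what the graceful fallback means in practice (a hypothetical example; it assumes ``np.poly1d`` triggers a graph break and runs in plain NumPy, as the list above describes):

.. code-block:: python

   import numpy as np
   import torch

   @torch.compile
   def fit_and_eval(x):
       # np.poly1d is in the list above, so this line is not traced;
       # torch.compile should break the graph and let plain NumPy run it.
       p = np.poly1d([1.0, 0.0, -1.0])
       return np.sum(p(x))

   y = fit_and_eval(np.linspace(0.0, 1.0, 16))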

Can I execute NumPy code on CUDA via ``torch.compile``?
@@ -602,7 +615,7 @@ to tweak our ``numpy_fn`` so that it accepts cuda Tensors and returns tensors:
@torch.compile
def numpy_fn(X: Tensor, Y: Tensor) -> Tensor:
    X = X.numpy()
    Y = X.numpy()
    Y = Y.numpy()
    # Compute Z here
    Z = torch.from_numpy(Z)
    return Z
@@ -613,10 +626,12 @@ to tweak our ``numpy_fn`` so that it accepts cuda Tensors and returns tensors:
Z = numpy_fn(X, Y)

By doing this, we explicitly create the tensors in CUDA memory, and we keep
them there. Note that the original program would not run on eager mode now.
If you want to run it in eager mode, you would need to call ``.numpy(force=True)``
and perhaps doing ``Z = Z.cuda()`` before returning ``Z``. Of course, doing
this would execute the program on eager mode NumPy, and on CPU.
them there. In this case ``X.numpy()`` and ``from_numpy()`` are hints to the compiler
but no real data movement happens. Note that the original program would not run
on eager mode now. If you want to run it in eager mode, you would need to call
``.numpy(force=True)`` and do ``Z = Z.cuda()`` before returning
``Z``. Of course, doing this would execute the program on eager mode NumPy, and
on CPU.
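
Putting the whole pattern together (a sketch along the lines of the snippet above; the shapes and the ``# Compute Z here`` body are made up):

.. code-block:: python

   import torch
   from torch import Tensor

   @torch.compile
   def numpy_fn(X: Tensor, Y: Tensor) -> Tensor:
       X = X.numpy()
       Y = Y.numpy()
       Z = X @ Y            # stand-in for "Compute Z here"
       return torch.from_numpy(Z)

   if torch.cuda.is_available():
       X = torch.randn(128, 64, device="cuda")
       Y = torch.randn(64, 128, device="cuda")
       Z = numpy_fn(X, Y)   # the NumPy-looking code is compiled to run on CUDA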


How do I debug my ``torch.compile``d NumPy code?
@@ -642,12 +657,12 @@ function.

Of course, this is not a satisfactory answer if our program is mostly composed
of NumPy code. In these cases, we can try to execute eagerly (without
``torch.compile``) the NumPy code on PyTorch by importing ``import torch._numpy
as np``. This should just be used for **debugging purposes** and is in no way a
``torch.compile``) the NumPy code on PyTorch by importing ``import torch._numpy as np``.
Collaborator:
Just a thought, but this may be a useful tool in and of itself. If you could enable a flag and dynamo will not inline from torch._numpy but will do the conversion from numpy to torch._numpy. That way you don't have to edit any code at all to do this test.

Collaborator Author:
The point of changing the imports is that, in most cases, you can execute the code without using dynamo at all. Then, if you see the same bug while doing this, the bug is certainly in the decomposition.

I found this very helpful when debugging issues, as when doing this you are dealing with very simple Python code and debugging stuff is as easy as it gets.

This should just be used for **debugging purposes** and is in no way a
replacement for the PyTorch API, as it is **much less performant** and, as a
private API, **may change without notice**. At any rate, ``torch._numpy`` is a
Python implementation of NumPy in terms of PyTorchand it is used internally to
transform NumPy code into Pytorch code.. It is rather easy to read and modify,
Python implementation of NumPy in terms of PyTorch and it is used internally to
transform NumPy code into PyTorch code. It is rather easy to read and modify,
so if you find any bug in it feel free to submit a PR fixing it!
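
Concretely, the debugging trick boils down to swapping the import (a sketch only; ``torch._numpy`` is a private module and its surface may change):

.. code-block:: python

   # import numpy as np          # the original import
   import torch._numpy as np     # debug-only swap: run the same NumPy code
                                  # eagerly on PyTorch, without torch.compile

   def numpy_fn(a, b):
       # Unchanged NumPy code, now executed by torch._numpy for debugging.
       return np.mean(a * b)

   print(numpy_fn(np.ones(3), np.arange(3.0)))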

If the program does work when importing ``torch._numpy as np``, chances are