-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add docs for torch.compile(numpy) #109710
Changes from all commits
c191a37
30cb708
b19de98
231962a
499326d
270dcf7
f32be6a
c528d51
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -317,8 +317,8 @@ them by default: ``env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py`` 2. | |||||
CUDA graphs with Triton are enabled by default in inductor but removing | ||||||
them may alleviate some OOM issues: ``torch._inductor.config.triton.cudagraphs = False``. | ||||||
|
||||||
``torch.func`` works with ``torch.compile`` (for `grad` and `vmap` transforms) | ||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||||
Does ``torch.func`` work with ``torch.compile`` (for `grad` and `vmap` transforms)? | ||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||||
|
||||||
Applying a ``torch.func`` transform to a function that uses ``torch.compile`` | ||||||
does not work: | ||||||
|
@@ -528,6 +528,160 @@ invokes an ``nn.Module``. This is because the outputs now depend on the | |||||
parameters of the ``nn.Module``. To get this to work, use | ||||||
``torch.func.functional_call`` to extract the module state. | ||||||
|
||||||
Does NumPy work with ``torch.compile``? | ||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||||
|
||||||
Starting in 2.1, ``torch.compile`` understands native NumPy programs that | ||||||
work on NumPy arrays, and mixed PyTorch-NumPy programs that convert from PyTorch | ||||||
to NumPy and back via ``x.numpy()``, ``torch.from_numpy``, and related functions. | ||||||
|
||||||
.. _nonsupported-numpy-feats: | ||||||
|
||||||
Which NumPy features does ``torch.compile`` support? | ||||||
---------------------------------------------------- | ||||||
|
||||||
NumPy within ``torch.compile`` follows NumPy 2.0 pre-release. | ||||||
|
||||||
Generally, ``torch.compile`` is able to trace through most NumPy constructions, | ||||||
and when it cannot, it falls back to eager and lets NumPy execute that piece of | ||||||
code. Even then, there are a few features where ``torch.compile`` semantics | ||||||
slightly deviate from those of NumPy: | ||||||
|
||||||
- NumPy scalars: We model them as 0-D arrays. That is, ``np.float32(3)`` returns | ||||||
a 0-D array under ``torch.compile``. To avoid a graph break, it is best to use this 0-D | ||||||
array. If this breaks your code, you can workaround this by casting the NumPy scalar | ||||||
to the relevant Python scalar type ``bool/int/float``. | ||||||
|
||||||
- Negative strides: ``np.flip`` and slicing with a negative step return a copy. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider inverting
Suggested change
|
||||||
|
||||||
- Type promotion: NumPy's type promotion will change in NumPy 2.0. The new rules | ||||||
are described in `NEP 50 <https://numpy.org/neps/nep-0050-scalar-promotion.html)>`__. | ||||||
``torch.compile`` implements NEP 50 rather than the current soon-to-be deprecated rules. | ||||||
|
||||||
- ``{tril,triu}_indices_from/{tril,triu}_indices`` return arrays rather than a tuple of arrays. | ||||||
|
||||||
There are other features for which we do not support tracing and we gracefully | ||||||
fallback to NumPy for their execution: | ||||||
|
||||||
- Non-numeric dtypes like datetimes, strings, chars, void, structured dtypes and recarrays. | ||||||
|
||||||
- Long dtypes ``np.float128/np.complex256`` and some unsigned dtypes ``np.uint16/np.uint32/np.uint64``. | ||||||
|
||||||
- ``ndarray`` subclasses. | ||||||
|
||||||
- Masked arrays. | ||||||
|
||||||
- Esoteric ufunc machinery like ``axes=[(n,k),(k,m)->(n,m)]`` and ufunc methods (e.g., ``np.add.reduce``). | ||||||
|
||||||
- Sorting / ordering ``complex64/complex128`` arrays. | ||||||
|
||||||
- NumPy ``np.poly1d`` and ``np.polynomial``. | ||||||
|
||||||
- Positional ``out1, out2`` args in functions with 2 or more returns (``out=tuple`` does work). | ||||||
|
||||||
- ``__array_function__``, ``__array_interface__`` and ``__array_wrap__``. | ||||||
|
||||||
- ``ndarray.ctypes`` attribute. | ||||||
|
||||||
Can I execute NumPy code on CUDA via ``torch.compile``? | ||||||
------------------------------------------------------- | ||||||
|
||||||
Yes you can! To do so, you may simply execute your code within a ``torch.device("cuda")`` | ||||||
context. Consider the example | ||||||
|
||||||
.. code-block:: python | ||||||
|
||||||
import torch | ||||||
import numpy as np | ||||||
|
||||||
@torch.compile | ||||||
def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray: | ||||||
return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)) | ||||||
|
||||||
X = np.random.randn(1024, 64) | ||||||
Y = np.random.randn(1024, 64) | ||||||
with torch.device("cuda"): | ||||||
Z = numpy_fn(X, Y) | ||||||
|
||||||
|
||||||
In this example, ``numpy_fn`` will be executed in CUDA. For this to be | ||||||
possible, ``torch.compile`` automatically moves ``X`` and ``Y`` from CPU | ||||||
to CUDA, and then it moves the result ``Z`` from CUDA to CPU. If we are | ||||||
executing this function several times in the same program run, we may want | ||||||
to avoid all these rather expensive memory copies. To do so, we just need | ||||||
to tweak our ``numpy_fn`` so that it accepts cuda Tensors and returns tensors: | ||||||
|
||||||
.. code-block:: python | ||||||
|
||||||
@torch.compile | ||||||
def numpy_fn(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor: | ||||||
X, Y = X.numpy(), Y.numpy() | ||||||
Z = np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)) | ||||||
return torch.from_numpy(Z) | ||||||
|
||||||
X = torch.randn(1024, 64, device="cuda") | ||||||
Y = torch.randn(1024, 64, device="cuda") | ||||||
with torch.device("cuda"): | ||||||
Z = numpy_fn(X, Y) | ||||||
|
||||||
By doing this, we explicitly create the tensors in CUDA memory, and we keep | ||||||
them there. In this case ``X.numpy()`` and ``from_numpy()`` are hints to the compiler | ||||||
but no real data movement happens. Note that the original program would not run | ||||||
on eager mode now. If you want to run it in eager mode, you would need to call | ||||||
``.numpy(force=True)`` doing ``Z = Z.cuda()`` before returning | ||||||
``Z``. Of course, doing this would execute the program on eager mode NumPy, and | ||||||
on CPU. | ||||||
|
||||||
|
||||||
How do I debug NumPy code under ``torch.compile``? | ||||||
-------------------------------------------------- | ||||||
|
||||||
Debugging JIT compiled code is challenging, given the complexity of modern | ||||||
compilers and the daunting errors that they raise. | ||||||
`The tutorial on how to diagnose runtime errors within torch.compile <https://pytorch.org/docs/main/torch.compiler_troubleshooting.html#diagnosing-runtime-errors>`__ | ||||||
contains a few tips and tricks on how to tackle this task. | ||||||
|
||||||
If the above is not enough to pinpoint the origin of the issue, there are still | ||||||
a few other NumPy-specific tools we can use. We can discern whether the bug | ||||||
is entirely in the PyTorch code by disabling tracing through NumPy functions: | ||||||
|
||||||
|
||||||
.. code-block:: python | ||||||
|
||||||
from torch._dynamo import config | ||||||
config.trace_numpy = False | ||||||
|
||||||
If the bug lies in the traced NumPy code, we can execute the NumPy code eagerly (without ``torch.compile``) | ||||||
using PyTorch as a backend by importing ``import torch._numpy as np``. | ||||||
This should just be used for **debugging purposes** and is in no way a | ||||||
replacement for the PyTorch API, as it is **much less performant** and, as a | ||||||
private API, **may change without notice**. At any rate, ``torch._numpy`` is a | ||||||
Python implementation of NumPy in terms of PyTorch and it is used internally by ``torch.compile`` to | ||||||
transform NumPy code into Pytorch code. It is rather easy to read and modify, | ||||||
so if you find any bug in it feel free to submit a PR fixing it or simply open | ||||||
an issue. | ||||||
|
||||||
If the program does work when importing ``torch._numpy as np``, chances are | ||||||
that the bug is in TorchDynamo. If this is the case, please feel open an issue | ||||||
with a `minimal reproducer <https://pytorch.org/docs/2.1/torch.compiler_troubleshooting.html>`__. | ||||||
|
||||||
I ``torch.compile`` some NumPy code and I did not see any speed-up. | ||||||
------------------------------------------------------------------- | ||||||
|
||||||
The best place to start is the | ||||||
`tutorial with general advice for how to debug these sort of torch.compile issues <https://pytorch.org/docs/main/torch.compiler_faq.html#why-am-i-not-seeing-speedups>`__. | ||||||
|
||||||
Some graph breaks may happen because of the use of unsupported features. See | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: both of these features exist in PT as well, so I don't see why this advice should exist in NumPy-specific portion of the documentation. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The intuition behind this (and something that I tried to follow both in post and the docs) is that the reader may be someone that comes from NumPy and does not necessarily have too much PyTorch experience. As such, I tried to crossref quite a bit and avoid assuming too much PyTorch knowledge. |
||||||
:ref:`nonsupported-numpy-feats`. More generally, it is useful to keep in mind | ||||||
that some widely used NumPy features do not play well with compilers. For | ||||||
example, in-place modifications make reasoning difficult within the compiler and | ||||||
often yield worse performance than their out-of-place counterparts.As such, it is best to avoid | ||||||
them. Same goes for the use of the ``out=`` parameter. Instead, prefer | ||||||
out-of-place ops and let ``torch.compile`` optimize the memory use. Same goes | ||||||
lezcano marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
for data-dependent ops like masked indexing through boolean masks, or | ||||||
data-dependent control flow like ``if`` or ``while`` constructions. | ||||||
|
||||||
|
||||||
Which API to use for fine grain tracing? | ||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||||
|
||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not a question for this documentation, but why don't we fall back in this case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could, but we found that the pattern
x[::-1].fn()
wherefn
is an out-of-place op is relatively common (it's in torchbench even) so I think it's worth this small dissonance. If people don't agree, we can always reconsider and fall back, or even consider implementing this ourselves, as inductor does support negative strides.