Skip to content

Commit

Permalink
Add docs for torch.compile(numpy)
Browse files Browse the repository at this point in the history
ghstack-source-id: 3e29b38d0bc574ab5f35eee34ebb37fa6238de7e
Pull Request resolved: #109710
  • Loading branch information
lezcano committed Sep 21, 2023
1 parent 5aae979 commit 1c85031
Showing 1 changed file with 156 additions and 2 deletions.
158 changes: 156 additions & 2 deletions docs/source/torch.compiler_faq.rst
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,8 @@ them by default: ``env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py`` 2.
CUDA graphs with Triton are enabled by default in inductor but removing
them may alleviate some OOM issues: ``torch._inductor.config.triton.cudagraphs = False``.

``torch.func`` works with ``torch.compile`` (for `grad` and `vmap` transforms)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Does ``torch.func`` work with ``torch.compile`` (for `grad` and `vmap` transforms)?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Applying a ``torch.func`` transform to a function that uses ``torch.compile``
does not work:
Expand Down Expand Up @@ -528,6 +528,160 @@ invokes an ``nn.Module``. This is because the outputs now depend on the
parameters of the ``nn.Module``. To get this to work, use
``torch.func.functional_call`` to extract the module state.

Does NumPy work with ``torch.compile``?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Starting in 2.1, ``torch.compile`` understands native NumPy programs that
work on NumPy arrays, and mixed PyTorch-NumPy programs that convert from PyTorch
to NumPy and back via ``x.numpy()``, ``torch.from_numpy``, and related functions.

.. _nonsupported-numpy-feats:

Which NumPy features does ``torch.compile`` support?
----------------------------------------------------

NumPy within ``torch.compile`` follows NumPy 2.0 pre-release.

Generally, ``torch.compile`` is able to trace through most NumPy constructions,
and when it cannot, it falls back to eager and lets NumPy execute that piece of
code. Even then, there are a few features where ``torch.compile`` semantics
slightly deviate from those of NumPy:

- NumPy scalars: We model them as 0-D arrays. That is, ``np.float32(3)`` returns
a 0-D array under ``torch.compile``. To avoid a graph break, it is best to use this 0-D
array. If this breaks your code, you can workaround this by casting the NumPy scalar
to the relevant Python scalar type ``bool/int/float``.

- Negative strides: ``np.flip`` and slicing with a negative step return a copy.

- Type promotion: NumPy's type promotion will change in NumPy 2.0. The new rules
are described in `NEP 50 <https://numpy.org/neps/nep-0050-scalar-promotion.html)>`__.
``torch.compile`` implements NEP 50 rather than the current soon-to-be deprecated rules.

- ``{tril,triu}_indices_from/{tril,triu}_indices`` return arrays rather than a tuple of arrays.

There are other features for which we do not support tracing and we gracefully
fallback to NumPy for their execution:

- Non-numeric dtypes like datetimes, strings, chars, void, structured dtypes and recarrays.

- Long dtypes ``np.float128/np.complex256`` and some unsigned dtypes ``np.uint16/np.uint32/np.uint64``.

- ``ndarray`` subclasses.

- Masked arrays.

- Esoteric ufunc machinery like ``axes=[(n,k),(k,m)->(n,m)]`` and ufunc methods (e.g., ``np.add.reduce``).

- Sorting / ordering ``complex64/complex128`` arrays.

- NumPy ``np.poly1d`` and ``np.polynomial``.

- Positional ``out1, out2`` args in functions with 2 or more returns (``out=tuple`` does work).

- ``__array_function__``, ``__array_interface__`` and ``__array_wrap__``.

- ``ndarray.ctypes`` attribute.

Can I execute NumPy code on CUDA via ``torch.compile``?
-------------------------------------------------------

Yes you can! To do so, you may simply execute your code within a ``torch.device("cuda")``
context. Consider the example

.. code-block:: python
import torch
import numpy as np
@torch.compile
def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))
X = np.random.randn(1024, 64)
Y = np.random.randn(1024, 64)
with torch.device("cuda"):
Z = numpy_fn(X, Y)
In this example, ``numpy_fn`` will be executed in CUDA. For this to be
possible, ``torch.compile`` automatically moves ``X`` and ``Y`` from CPU
to CUDA, and then it moves the result ``Z`` from CUDA to CPU. If we are
executing this function several times in the same program run, we may want
to avoid all these rather expensive memory copies. To do so, we just need
to tweak our ``numpy_fn`` so that it accepts cuda Tensors and returns tensors:

.. code-block:: python
@torch.compile
def numpy_fn(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
X, Y = X.numpy(), Y.numpy()
Z = np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))
return torch.from_numpy(Z)
X = torch.randn(1024, 64, device="cuda")
Y = torch.randn(1024, 64, device="cuda")
with torch.device("cuda"):
Z = numpy_fn(X, Y)
By doing this, we explicitly create the tensors in CUDA memory, and we keep
them there. In this case ``X.numpy()`` and ``from_numpy()`` are hints to the compiler
but no real data movement happens. Note that the original program would not run
on eager mode now. If you want to run it in eager mode, you would need to call
``.numpy(force=True)`` doing ``Z = Z.cuda()`` before returning
``Z``. Of course, doing this would execute the program on eager mode NumPy, and
on CPU.


How do I debug NumPy code under ``torch.compile``?
--------------------------------------------------

Debugging JIT compiled code is challenging, given the complexity of modern
compilers and the daunting errors that they raise.
`The tutorial on how to diagnose runtime errors within torch.compile <https://pytorch.org/docs/main/torch.compiler_troubleshooting.html#diagnosing-runtime-errors>`__
contains a few tips and tricks on how to tackle this task.

If the above is not enough to pinpoint the origin of the issue, there are still
a few other NumPy-specific tools we can use. We can discern whether the bug
is entirely in the PyTorch code by disabling tracing through NumPy functions:


.. code-block:: python
from torch._dynamo import config
config.trace_numpy = False
If the bug lies in the traced NumPy code, we can execute the NumPy code eagerly (without ``torch.compile``)
using PyTorch as a backend by importing ``import torch._numpy as np``.
This should just be used for **debugging purposes** and is in no way a
replacement for the PyTorch API, as it is **much less performant** and, as a
private API, **may change without notice**. At any rate, ``torch._numpy`` is a
Python implementation of NumPy in terms of PyTorch and it is used internally by ``torch.compile`` to
transform NumPy code into Pytorch code. It is rather easy to read and modify,
so if you find any bug in it feel free to submit a PR fixing it or simply open
an issue.

If the program does work when importing ``torch._numpy as np``, chances are
that the bug is in TorchDynamo. If this is the case, please feel open an issue
with a `minimal reproducer <https://pytorch.org/docs/2.1/torch.compiler_troubleshooting.html>`__.

I ``torch.compile`` some NumPy code and I did not see any speed-up.
-------------------------------------------------------------------

The best place to start is the
`tutorial with general advice for how to debug these sort of torch.compile issues <https://pytorch.org/docs/main/torch.compiler_faq.html#why-am-i-not-seeing-speedups>`__.

Some graph breaks may happen because of the use of unsupported features. See
:ref:`nonsupported-numpy-feats`. More generally, it is useful to keep in mind
that some widely used NumPy features do not play well with compilers. For
example, in-place modifications make reasoning difficult within the compiler and
often yield worse performance than their out-of-place counterparts.As such, it is best to avoid
them. Same goes for the use of the ``out=`` parameter. Instead, prefer
out-of-place ops and let ``torch.compile`` optimize the memory use. Same goes
for data-dependent ops like masked indexing through boolean masks, or
data-dependent control flow like ``if`` or ``while`` constructions.


Which API to use for fine grain tracing?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down

0 comments on commit 1c85031

Please sign in to comment.