Add docs for torch.compile(numpy)
ghstack-source-id: 619ce0282316e24bcf26a0664315652c39958dcd
Pull Request resolved: #109710
lezcano committed Sep 20, 2023
1 parent 2c1554a commit 234ddb0
Showing 1 changed file with 158 additions and 2 deletions: docs/source/torch.compiler_faq.rst
@@ -317,8 +317,8 @@
them by default: ``env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py``
2. CUDA graphs with Triton are enabled by default in inductor but removing
them may alleviate some OOM issues: ``torch._inductor.config.triton.cudagraphs = False``.

Does ``torch.func`` work with ``torch.compile`` (for `grad` and `vmap` transforms)?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Applying a ``torch.func`` transform to a function that uses ``torch.compile``
does not work:
@@ -528,6 +528,162 @@
invokes an ``nn.Module``. This is because the outputs now depend on the
parameters of the ``nn.Module``. To get this to work, use
``torch.func.functional_call`` to extract the module state.

Does NumPy work with ``torch.compile``?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Starting in 2.1, ``torch.compile`` understands native NumPy programs that
work on NumPy arrays, and mixed PyTorch-NumPy programs that convert from PyTorch
to NumPy and back via ``x.numpy()``, ``torch.from_numpy``, and related functions.
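
For illustration, here is a minimal sketch of both flavors (the function bodies
and shapes are illustrative assumptions, not part of the original docs):

.. code-block:: python

    import numpy as np
    import torch

    @torch.compile
    def pure_numpy(x: np.ndarray) -> np.ndarray:
        # A native NumPy program: traced end to end by torch.compile
        return np.sin(x) ** 2 + np.cos(x) ** 2

    @torch.compile
    def mixed(t: torch.Tensor) -> torch.Tensor:
        # A mixed PyTorch-NumPy program: convert, operate, convert back
        x = t.numpy()
        y = np.log1p(np.abs(x))
        return torch.from_numpy(y)

    print(pure_numpy(np.random.randn(4)))
    print(mixed(torch.randn(4)))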

.. _nonsupported-numpy-feats:

Which NumPy features does ``torch.compile`` support?
----------------------------------------------------

NumPy within ``torch.compile`` follows the latest NumPy release.

Generally, ``torch.compile`` is able to trace through most NumPy constructs,
and when it cannot, it falls back to eager mode and lets NumPy execute that piece
of code. Even then, there are a few features where ``torch.compile`` semantics
slightly deviate from those of NumPy:

- NumPy scalars: We model them as 0-D arrays. That is, ``np.float32(3)`` returns
  a 0-D array under ``torch.compile``. To avoid a graph break, it is best to keep
  using this 0-D array. If this is not possible, one can often recover the original
  behavior by casting the NumPy scalar to the relevant Python scalar type
  ``bool/int/float`` (see the sketch after this list).

- Negative strides: ``np.flip`` and slicing with a negative step return a copy.

- Type promotion: NumPy's type promotion will change in NumPy 2.0. The new rules
  are described in `NEP 50 <https://numpy.org/neps/nep-0050-scalar-promotion.html>`__.
  ``torch.compile`` implements NEP 50 rather than the current, soon-to-be-deprecated rules.

- ``{tril,triu}_indices_from`` return arrays rather than lists of tuples.
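
Here is a minimal sketch of the scalar-casting workaround mentioned above (the
function and its inputs are illustrative assumptions):

.. code-block:: python

    import numpy as np
    import torch

    @torch.compile
    def scale(x: np.ndarray) -> np.ndarray:
        # Under torch.compile, np.float32(3.0) behaves like a 0-D array.
        # Casting it to a Python float recovers plain-scalar semantics.
        factor = float(np.float32(3.0))
        return x * factor

    print(scale(np.ones(3)))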

There are other features for which we do not support tracing; for these, we
gracefully fall back to NumPy for their execution:

- Non-numeric dtypes like datetimes, strings, chars, void, structured dtypes and recarrays.

- Long dtypes ``np.float128/np.complex256`` and some unsigned dtypes ``np.uint16/np.uint32/np.uint64``.

- ``ndarray`` subclasses.

- Masked arrays.

- Esoteric ufunc machinery like ``axes=[(n,k),(k,m)->(n,m)]`` and ufunc methods (e.g., ``np.add.reduce``).

- Sorting / ordering ``complex64/complex128`` arrays.

- NumPy ``np.poly1d`` and ``np.polynomial``.

- Positional ``out1, out2`` args in functions with 2 or more returns (``out=tuple`` does work).

- ``__array_function__``, ``__array_interface__`` and ``__array_wrap__``.

- The ``ndarray.ctypes`` attribute.
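
To illustrate the graceful fallback, the following hypothetical sketch (not from
the original docs) still compiles and runs even though masked arrays are not
traced; ``torch.compile`` simply graph-breaks and lets NumPy execute that piece
eagerly:

.. code-block:: python

    import numpy as np
    import torch

    @torch.compile
    def masked_mean(x: np.ndarray) -> np.ndarray:
        # Masked arrays are not traced: this piece falls back to eager NumPy
        m = np.ma.masked_less(x, 0.0)
        return np.asarray(m.mean())

    print(masked_mean(np.array([-1.0, 1.0, 3.0])))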

Can I execute NumPy code on CUDA via ``torch.compile``?
-------------------------------------------------------

Yes, you can! To do so, you may simply execute your code under the ``torch.device("cuda")`` context manager:

.. code-block:: python

    import numpy as np
    import torch
    from numpy import ndarray

    @torch.compile
    def numpy_fn(X: ndarray, Y: ndarray) -> ndarray:
        # Compute the ndarray Z here (illustrative placeholder computation)
        Z = np.matmul(X, Y)
        return Z

    X = np.random.randn(1000, 1000)
    Y = np.random.randn(1000, 1000)
    with torch.device("cuda"):
        Z = numpy_fn(X, Y)

In this example, ``numpy_fn`` will be executed on CUDA. For this to be
possible, ``torch.compile`` automatically moves ``X`` and ``Y`` from CPU
to CUDA, and then moves the result ``Z`` from CUDA back to CPU. If we are
executing this function several times in the same program run, we may want
to avoid all these rather expensive memory copies. To do so, we just need
to tweak ``numpy_fn`` so that it accepts CUDA tensors and returns tensors:

.. code-block:: python

    import numpy as np
    import torch
    from torch import Tensor

    @torch.compile
    def numpy_fn(X: Tensor, Y: Tensor) -> Tensor:
        X = X.numpy()
        Y = Y.numpy()
        # Compute Z here (illustrative placeholder computation)
        Z = np.matmul(X, Y)
        Z = torch.from_numpy(Z)
        return Z

    X = torch.randn(1000, 1000, device="cuda")
    Y = torch.randn(1000, 1000, device="cuda")
    with torch.device("cuda"):
        Z = numpy_fn(X, Y)

By doing this, we explicitly create the tensors in CUDA memory and keep
them there. In this case, ``X.numpy()`` and ``from_numpy()`` are hints to the
compiler, but no real data movement happens. Note that the original program
would no longer run in eager mode. If you want to run it in eager mode, you
would need to call ``.numpy(force=True)`` and do ``Z = Z.cuda()`` before
returning ``Z``. Of course, doing this would execute the program on eager-mode
NumPy, and on CPU.
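
For reference, here is a sketch of that eager-compatible variant (assuming the
same illustrative computation as above):

.. code-block:: python

    import numpy as np
    import torch
    from torch import Tensor

    def numpy_fn_eager(X: Tensor, Y: Tensor) -> Tensor:
        # force=True copies the CUDA tensors to CPU so .numpy() works eagerly
        Z = np.matmul(X.numpy(force=True), Y.numpy(force=True))
        # Move the result back to CUDA before returning
        return torch.from_numpy(Z).cuda()

    X = torch.randn(8, 8, device="cuda")
    Y = torch.randn(8, 8, device="cuda")
    Z = numpy_fn_eager(X, Y)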


How do I debug my ``torch.compile``d NumPy code?
------------------------------------------------

Debugging JIT-compiled code is challenging, given the complexity of modern
compilers and the daunting errors that they raise.
`The tutorial on how to diagnose runtime errors within torch.compile <https://pytorch.org/docs/main/torch.compiler_troubleshooting.html#diagnosing-runtime-errors>`__
contains a few tips and tricks on how to tackle this task.

If the above is not enough to pinpoint the origin of the issue, there are still
a few other NumPy-specific tools we can use. If we find a bug that is blocking
our development and we are tracing through mixed PyTorch <> NumPy code where
the NumPy part is not particularly bulky, we can simply deactivate NumPy
tracing altogether by doing:

.. code-block:: python

    from torch._dynamo import config

    config.trace_numpy = False

This reverts to the 2.0 behavior and will avoid tracing through any NumPy
function.

Of course, this is not a satisfactory answer if our program is mostly composed
of NumPy code. In these cases, we can try to execute the NumPy code eagerly
(without ``torch.compile``) on PyTorch via ``import torch._numpy as np``.
This should just be used for **debugging purposes** and is in no way a
replacement for the PyTorch API, as it is **much less performant** and, as a
private API, **may change without notice**. At any rate, ``torch._numpy`` is a
Python implementation of NumPy in terms of PyTorch, and it is used internally to
transform NumPy code into PyTorch code. It is rather easy to read and modify,
so if you find any bug in it, feel free to submit a PR fixing it!
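
As a minimal sketch of this debugging swap (the computation is an illustrative
assumption; since ``torch._numpy`` mirrors the NumPy API, the only change is
the import):

.. code-block:: python

    # Debugging only: run the NumPy program eagerly on top of PyTorch.
    # torch._numpy is a private API and may change without notice.
    import torch._numpy as np

    x = np.arange(6.0).reshape(2, 3)
    print((x * x).sum(axis=0))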

If the program does work when importing ``torch._numpy as np``, chances are
that the bug is in TorchDynamo. If this is the case, please feel free to open
an issue with a minimal reproducer.

I ``torch.compile``d a NumPy function and I did not see any speed-up.
--------------------------------------------------------------------

The best place to start is the
`tutorial with general advice for how to debug these sorts of torch.compile issues <https://pytorch.org/docs/main/torch.compiler_faq.html#why-am-i-not-seeing-speedups>`__.
Some graph breaks may happen because of the use of unsupported features. See
:ref:`nonsupported-numpy-feats`. More generally, it is useful to keep in mind
that some widely used NumPy features do not play well with compilers. For
example, in-place modifications make reasoning difficult, so the compiler
removes them in a pass called "functionalization". As such, it is best to avoid
in-place ops and the ``out=`` parameter, and instead use out-of-place ops and
let ``torch.compile`` optimize the memory use, as in the sketch below. The same
goes for data-dependent ops like masked indexing through boolean masks, and for
data-dependent control flow like ``if`` or ``while`` constructions.
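
A hypothetical sketch of this advice (the functions are illustrative, not from
the original docs):

.. code-block:: python

    import numpy as np
    import torch

    @torch.compile
    def inplace_version(x: np.ndarray) -> np.ndarray:
        # In-place update via out=: the compiler must undo this during
        # functionalization, which makes reasoning harder
        np.add(x, 1.0, out=x)
        return x

    @torch.compile
    def out_of_place_version(x: np.ndarray) -> np.ndarray:
        # Preferred: out-of-place op; torch.compile optimizes memory use
        return x + 1.0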


Which API to use for fine grain tracing?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

