Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add docs for torch.compile(numpy) #109710

Closed
wants to merge 8 commits into from
147 changes: 145 additions & 2 deletions docs/source/torch.compiler_faq.rst
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,8 @@ them by default: ``env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py`` 2.
CUDA graphs with Triton are enabled by default in inductor but removing
them may alleviate some OOM issues: ``torch._inductor.config.triton.cudagraphs = False``.

``torch.func`` works with ``torch.compile`` (for `grad` and `vmap` transforms)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Does ``torch.func`` work with ``torch.compile`` (for `grad` and `vmap` transforms)?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Applying a ``torch.func`` transform to a function that uses ``torch.compile``
does not work:
Expand Down Expand Up @@ -528,6 +528,149 @@ invokes an ``nn.Module``. This is because the outputs now depend on the
parameters of the ``nn.Module``. To get this to work, use
``torch.func.functional_call`` to extract the module state.

Does NumPy work with ``torch.compile``?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Starting in 2.1., ``torch.compile`` understands native NumPy programs that
lezcano marked this conversation as resolved.
Show resolved Hide resolved
work on NumPy arrays, and mixed PyTorch-NumPy programs that convert from PyTorch
to NumPy and back via ``x.numpy()``, ``torch.from_numpy``, and related functions.

.. _nonsupported-numpy-feats:

Which NumPy features does ``torch.compile`` support?
----------------------------------------------------

NumPy within ``torch.compile`` follows the latest NumPy release.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we have some way of getting the version of NumPy that torch.compile works / is tested against?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope. This is more of something we would like to have, and a rule for people to submit issues / for us to answer issues.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool, can you file an issue to follow-up on this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is no issue to follow up on here. I think we should say that we support the last NumPy release, or I can change it to "the same NumPy release that PyTorch supports". I don't think we have any way to get this, but I remember it's written somewhere. @albanD knows where IIRC. I can change the wording and put a link to the relevant place where we track this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I mean is the following:

    1. clearly we are encoding some NumPy behavior based on some version of NumPy, and as written it seems like this can differ from the installed version of NumPy.
    1. "Latest" as a descriptor isn't super helpful for the user - they could be using an old version of PyTorch or NumPy could have just updated between PyTorch releases.
    1. If we know what 1. is, we should just tell the user (in code!) so they can check and do what they will with that information, such as installing a NumPy version to match or upgrading PyTorch.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For iii, do we currently have a way to retrieve the NumPy version we support when using it in eager or when running the tests?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lezcano isn't the north star story here be that PyTorch runtime should match the behavior of the currently installed numpy version (with obviously only a subset of versions being properly supported and some leading to nice unsupported errors).
As of today we're basically targetting a given version of Numpy 2.X right? And I would agree that we should make sure that we have CI running with it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, at the moment we are implementing the last version of NumPy (pretty much, modulo that NEP50 point). Then I guess we'll be able to add support for NumPy 2.0 (while keeping support for 1.X as well) once we support it in core.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds good. I would just not say "the last version of NumPy" but "the Numpy 2.0 pre-release" so that it doesn't become stale if we don't update it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes complete sense. Updating


Generally, ``torch.compile`` is able to trace through most NumPy constructions,
and when it cannot, it fallbacks to eager and lets NumPy execute that piece of
lezcano marked this conversation as resolved.
Show resolved Hide resolved
code. Even then, there are a few features where ``torch.compile`` semantics
slightly deviate from those of NumPy:

- NumPy scalars: We model them as 0-D arrays. That is, ``np.float32(3)`` returns
a 0-D array under ``torch.compile``. For performance, it is best to use this 0-D
array. If this is not possible, one may often recover the original behavior
lezcano marked this conversation as resolved.
Show resolved Hide resolved
by cast the NumPy scalar to the relevant Python scalar (e.g., ``float``).
lezcano marked this conversation as resolved.
Show resolved Hide resolved
- Negative strides: ``np.flip`` and slicing with a negative step return a copy.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not a question for this documentation, but why don't we fall back in this case?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could, but we found that the pattern x[::-1].fn() where fn is an out-of-place op is relatively common (it's in torchbench even) so I think it's worth this small dissonance. If people don't agree, we can always reconsider and fall back, or even consider implementing this ourselves, as inductor does support negative strides.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider inverting

Suggested change
- Negative strides: ``np.flip`` and slicing with a negative step return a copy.
- Negative strides: views that result in negative strides will instead copy, e.g. ``np.flip`` or slicing with a negative step.

- Type promotion: NumPy's type promotion will change in NumPy 2.0. The new rules
are described in `NEP 50 <https://numpy.org/neps/nep-0050-scalar-promotion.html)>`__.
``torch.compile`` implements NEP 50 rather than the current soon-to-be deprecated rules.
- ``{tril,triu}_indices_from`` return arrays rather than lists of tuples.
lezcano marked this conversation as resolved.
Show resolved Hide resolved
peterbell10 marked this conversation as resolved.
Show resolved Hide resolved

There are other feature for which we do not support tracing and we gracefully
lezcano marked this conversation as resolved.
Show resolved Hide resolved
fallback to NumPy for their execution:

- Non-numeric dtypes like datetimes, strings, chars, void, structured dtypes and recarrays.
- Long dtypes like ``np.float128`` and ``np.complex256``.
lezcano marked this conversation as resolved.
Show resolved Hide resolved
- ``ndarray`` subclasses.
- Masked arrays.
- Ufunc machinery like ``axes=[(n,k),(k,m)->(n,m)]``, ``np.add.reduce``, etc.
lezcano marked this conversation as resolved.
Show resolved Hide resolved
- Fortran orders and, in general, any ``order=`` different to ``C``.
lezcano marked this conversation as resolved.
Show resolved Hide resolved
- Sorting / ordering ``complex64/complex128`` arrays.
- NumPy ``np.poly1d`` and ``np.polynomial``.
- Positional ``out1, out2`` args in functions with 2 or more returns (``out=tuple`` does work).
- ``__array_function__``, ``__array_interface__`` and ``__array_wrap__``.
- ``ndarray.ctypes`` attribute not supported.
lezcano marked this conversation as resolved.
Show resolved Hide resolved

Can I execute NumPy code on CUDA via ``torch.compile``?
-------------------------------------------------------

Yes you can! To do so, you may simply execute your code under ``torch.device("cuda")``
lezcano marked this conversation as resolved.
Show resolved Hide resolved

.. code-block:: python

@torch.compile
def numpy_fn(X: ndarray, Y: ndarray): -> ndarray
# Compute the ndarray Z here
return Z
lezcano marked this conversation as resolved.
Show resolved Hide resolved


X = np.random.randn(1000, 1000)
Y = np.random.randn(1000, 1000)
with torch.device("cuda"):
Z = numpy_fn(X, Y)

In this example, ``numpy_fn`` will be executed in CUDA. For this to be
possible, ``torch.compile`` automatically moves ``X`` and ``Y`` from CPU
to CUDA, and then it moves the result ``Z`` from CUDA to CPU. If we are
executing this function several times in the same program run, we may want
to avoid all these rather expensive memory copies. To do so, we just need
to tweak our ``numpy_fn`` so that it accepts cuda Tensors and returns tensors:

.. code-block:: python

@torch.compile
def numpy_fn(X: Tensor, Y: Tensor): -> Tensor
X = X.numpy()
Y = X.numpy()
# Compute Z here
Z = torch.from_numpy(Z)
return Z

X = torch.randn(1000, 1000, device="cuda")
Y = torch.randn(1000, 1000, device="cuda")
with torch.device("cuda"):
Z = numpy_fn(X, Y)

By doing this, we explicitly create the tensors in CUDA memory, and we keep
them there. Note that the original program would not run on eager mode now.
lezcano marked this conversation as resolved.
Show resolved Hide resolved
If you want to run it in eager mode, you would need to call ``.numpy(force=True)``
and perhaps doing ``Z = Z.cuda()`` before returning ``Z``. Of course, doing
lezcano marked this conversation as resolved.
Show resolved Hide resolved
this would execute the program on eager mode NumPy, and on CPU.


How do I debug my ``torch.compile``d NumPy code?
------------------------------------------------

Debugging JIT compiled code is challenging, given the complexity of modern
compilers and the daunting errors that they raise.
`The tutorial on how to diagnose runtime errors within torch.compile <https://pytorch.org/docs/main/torch.compiler_troubleshooting.html#diagnosing-runtime-errors>`__
contains a few tips and tricks on how to tackle this task.

If the above is not enough to pinpoint the origin of the issue, there are still
a few other NumPy specific tools we can use. If we find a bug that is blocking
our development and we are tracing through mixed PyTorch <> NumPy code where
the NumPy part is not particularly bulky, we can simply deactivate the NumPy
tracing altogether by doing
peterbell10 marked this conversation as resolved.
Show resolved Hide resolved

.. code-block:: python

from torch._dynamo import config config.trace_numpy = False
lezcano marked this conversation as resolved.
Show resolved Hide resolved

This moves back to the behavior in 2.0 and will avoid tracing through any NumPy
function.
lezcano marked this conversation as resolved.
Show resolved Hide resolved

Of course, this is not a satisfactory answer if our program is mostly composed
lezcano marked this conversation as resolved.
Show resolved Hide resolved
of NumPy code. In these cases, we can try to execute eagerly (without
lezcano marked this conversation as resolved.
Show resolved Hide resolved
``torch.compile``) the NumPy code on PyTorch by importing ``import torch._numpy
as np``. This should just be used for **debugging purposes** and is in no way a
replacement for the PyTorch API, as it is **much less performant** and, as a
private API, **may change without notice**. At any rate, ``torch._numpy`` is a
Python implementation of NumPy in terms of PyTorchand it is used internally to
lezcano marked this conversation as resolved.
Show resolved Hide resolved
transform NumPy code into Pytorch code.. It is rather easy to read and modify,
lezcano marked this conversation as resolved.
Show resolved Hide resolved
so if you find any bug in it feel free to submit a PR fixing it!
lezcano marked this conversation as resolved.
Show resolved Hide resolved

If the program does work when importing ``torch._numpy as np``, chances are
that the bug is in TorchDynamo. If this is the case, please feel open an issue
with a minimal reproducer.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Link to docs on running the minifier?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know whether the minifier works with NumPy, but sure.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds worth checking. If it doesn't then open an issue?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, but let me punt on this for now and do it later in the week, as I just don't have the bandwidth now.


I ``torch.compile`d a NumPy function and I did not see any speed-up.
--------------------------------------------------------------------

The best place to start is the
`tutorial with general advice for how to debug this sort of torch.compile issues <https://pytorch.org/docs/main/torch.compiler_faq.html#why-am-i-not-seeing-speedups>`__.
lezcano marked this conversation as resolved.
Show resolved Hide resolved

Some graph breaks may happen because of the use of unsupported features. See
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: both of these features exist in PT as well, so I don't see why this advice should exist in NumPy-specific portion of the documentation.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intuition behind this (and something that I tried to follow both in post and the docs) is that the reader may be someone that comes from NumPy and does not necessarily have too much PyTorch experience. As such, I tried to crossref quite a bit and avoid assuming too much PyTorch knowledge.

:ref:`nonsupported-numpy-feats`. More generally, it is useful to keep in mind
that some widely used NumPy features do not play well with compilers. For
example, in-place modifications make reasoning difficult, so the compiler
removes them in a pass called "functionalization". As such, it is best to avoid
in-place ops, or the use of the ``out=`` parameter, and instead simply use
out-of-place ops and let ``torch.compile`` optimize the memory use. Same goes
lezcano marked this conversation as resolved.
Show resolved Hide resolved
for data-dependent ops like masked indexing through boolean masks, or
data-dependent control flow like `if` or `while` constructions.


Which API to use for fine grain tracing?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down