Update dynamic shapes documentation
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 6da57e6a83233b9404734279df3883aeeb23feb7
Pull Request resolved: #109764
ezyang committed Sep 21, 2023
1 parent 5aae979 commit 0108d0d
Showing 1 changed file with 23 additions and 14 deletions.
docs/source/torch.compiler_dynamic_shapes.rst
In supporting dynamic shapes, we chose not to support dynamic rank programs, e.g., programs whose input tensors change in rank (number of dimensions).
Abridged public API
-------------------

The default dynamic behavior in PyTorch 2.1 is:

- PT2 assumes everything is static by default

- If we recompile because a size changed, we will instead attempt to recompile
  that size as being dynamic (sizes that have changed are likely to change in
  the future). This generalization may fail (e.g., because user code does a
  conditional branch on the size in question, or because PT2 is missing dynamic
  shapes support for some operation). If you are trying to understand why PT2
  has overspecialized some code, run with ``TORCH_LOGS=dynamic`` and look for
  "eval" entries that say when guards are added and why. A sketch of this
  behavior follows this list.

- If you know ahead of time something will be dynamic, you can skip the first
  recompile with ``torch._dynamo.mark_dynamic(tensor, dim)``.

- If you say ``torch.compile(dynamic=False)``, we will turn off automatic
  dynamic shapes on recompiles and always recompile for each distinct size.
  Conversely, if you say ``torch.compile(dynamic=True)``, we will try to make
  everything as dynamic as possible. This is mostly useful for small
  operators; if you try it on a big model it will (1) probably crash PT2 and
  (2) run slow for no good reason.
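
Putting the points above together, here is a minimal sketch of the default
behavior (the comments about which calls trigger a recompile follow the
description above and are an assumption about typical behavior, not a
guarantee):

.. code-block:: python

    import torch

    def f(x):
        return x * 2

    cf = torch.compile(f)       # static by default
    cf(torch.randn(4))          # first compile, specialized to size 4
    cf(torch.randn(8))          # size changed: recompile, treating this size as dynamic
    cf(torch.randn(16))         # expected to reuse the dynamic compilation

    # If you know ahead of time that a dimension will vary, skip the first recompile:
    x = torch.randn(4)
    torch._dynamo.mark_dynamic(x, 0)    # treat dim 0 as dynamic from the start
    torch.compile(f)(x)

    # Force one behavior or the other explicitly:
    always_dynamic = torch.compile(f, dynamic=True)    # make as much dynamic as possible
    always_static = torch.compile(f, dynamic=False)    # always specialize; recompile per size
    always_dynamic(torch.randn(3))
    always_static(torch.randn(3))

Run the script with ``TORCH_LOGS=dynamic python script.py`` to see when and
why guards are added.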

The Guard Model
---------------
For data-dependent values (for example, the result of an ``.item()`` call), PT2 allocates unbacked symbolic integers, which cannot be guarded on. Naively implemented, this is too restrictive: most PyTorch programs will immediately fail. The most important mitigations are:
- On tensor creation, PyTorch precomputes a lot of data about a tensor; for example, if you use ``empty_strided`` to create a tensor, we will eagerly sort the strides and determine if the tensor is non-overlapping and dense. Sorts produce a lot of guards. However, it is more common to produce a tensor directly with a higher-level API like ``empty``, which is guaranteed to produce a non-overlapping and dense tensor. We modified PyTorch to avoid needlessly recomputing these properties.
- Even if nontrivial compute is needed, sometimes a property is never actually queried at all. Making these precomputed properties lazy allows us to avoid guarding on an unbacked symbolic integer unless it is actually needed.
- The data in an integer tensor is generally not known to be non-negative. However, we provide an API ``constrain_range`` whereby a user can specify that a size is bounded above and below by known limits; a sketch follows this list.
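
As an illustration, here is a minimal sketch of constraining a data-dependent
value (the import path of ``constrain_range``, its keyword arguments, and
whether it is traceable under ``torch.compile`` in your version are
assumptions; consult the API reference for your release):

.. code-block:: python

    import torch
    # Assumed location of the helper; this may differ across releases.
    from torch.fx.experimental.symbolic_shapes import constrain_range

    torch._dynamo.config.capture_scalar_outputs = True  # allow .item() to be captured

    @torch.compile(fullgraph=True)
    def bucketize(counts):
        n = counts.max().item()               # unbacked symbolic integer
        constrain_range(n, min=0, max=4096)   # assumed call: declare known bounds for n
        return torch.zeros(n)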

In future versions of PT2 (beyond PT2.1), we will extend our reasoning system
to infer that an unbacked symbolic integer is size-like based on usage. For
example, if you pass the result of an ``.item()`` call to a factory function
like ``torch.empty``, we will automatically infer that the result is a size
(because if it were not, the call would fail). This assumption would then be
validated at runtime, raising an error if it does not hold.
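
For reference, the usage pattern described above looks like this (a sketch
only; the automatic size-like inference is the future behavior described
above, and ``capture_scalar_outputs`` is an assumption about the config needed
to trace ``.item()``):

.. code-block:: python

    import torch

    torch._dynamo.config.capture_scalar_outputs = True  # assumed: lets .item() be traced

    @torch.compile(fullgraph=True)
    def make_buffer(lengths):
        n = lengths.sum().item()   # n is an unbacked symbolic integer
        # Passing n to a factory function implies n is a valid size (n >= 0);
        # future PT2 would infer this automatically and check it at runtime.
        return torch.empty(n)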
