4 changes: 1 addition & 3 deletions docs/source/export.ir_spec.rst
@@ -66,11 +66,9 @@ Some notable attributes of the :class:`torch.export.ExportedProgram` class are:
- ``state_dict`` (``Dict[str, Union[torch.Tensor, torch.nn.Parameter]]``): Data
structure containing the parameters and buffers.
- ``range_constraints`` (``Dict[sympy.Symbol, RangeConstraint]``): For programs
that are exported with data dependent behavior, the metadata on each node will
that are exported with shape/data dependent behavior, the metadata on each node will
contain symbolic shapes (which look like ``s0``, ``i0``). This attribute maps
the symbolic shapes to their lower/upper ranges.
- ``equality_constraints`` (``List[Tuple[InputDim, InputDim]]``): A list of
nodes in the graph and dimensions that have the same shape.

Graph
-----
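The ``range_constraints`` attribute documented above can be inspected directly on an exported program. A minimal sketch, assuming a toy module; the symbol name (``s0``) and the exact bounds in the printed mapping depend on the traced model and PyTorch version:

```python
import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def forward(self, x):
        return x * 2

# Mark dimension 0 of "x" as dynamic with explicit bounds.
batch = Dim("batch", min=2, max=64)
ep = export(M(), (torch.randn(8, 16),), dynamic_shapes={"x": {0: batch}})

# Maps each symbolic shape (e.g. s0) to its lower/upper range.
print(ep.range_constraints)
# e.g. {s0: RangeConstraint(min_val=2, max_val=64)} (exact repr may vary)
```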
18 changes: 5 additions & 13 deletions docs/source/export.rst
@@ -61,7 +61,6 @@ serialized.
assertion_dep_token=None,
)
Range constraints: {}
Equality constraints: []

``torch.export`` produces a clean intermediate representation (IR) with the
following invariants. More specifications about the IR can be found
@@ -329,16 +328,16 @@ run. Such dimensions must be specified by using the
assertion_dep_token=None,
)
Range constraints: {s0: RangeConstraint(min_val=2, max_val=9223372036854775806)}
Equality constraints: [(InputDim(input_name='arg5_1', dim=0), InputDim(input_name='arg6_1', dim=0))]

Some additional things to note:

* Through the :func:`torch.export.Dim` API and the ``dynamic_shapes`` argument, we specified the first
dimension of each input to be dynamic. Looking at the inputs ``arg5_1`` and
``arg6_1``, they have a symbolic shape of (s0, 64) and (s0, 128), instead of
the (32, 64) and (32, 128) shaped tensors that we passed in as example inputs.
``s0`` is a symbol representing that this dimension can be a range
of values.
Here ``s0`` is a symbol representing that ``arg5_1`` dimension 0 and ``arg6_1``
dimension 0 can have a range
of values, but are required to be equal.

* ``exported_program.range_constraints`` describes the ranges of each symbol
appearing in the graph. In this case, we see that ``s0`` has the range
@@ -348,13 +347,6 @@ Some additional things to note:
`The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`_
for an in-depth discussion of this topic.

* ``exported_program.equality_constraints`` describes which dimensions are
required to be equal. Since we specified in the constraints that the first
dimension of each argument is equivalent,
(``dynamic_dim(example_args[0], 0) == dynamic_dim(example_args[1], 0)``),
we see in the equality constraints the tuple specifying that ``arg5_1``
dimension 0 and ``arg6_1`` dimension 0 are equal.
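
As a sketch of the point above: reusing one :func:`torch.export.Dim` object for both inputs is what makes ``arg5_1`` and ``arg6_1`` share the single symbol ``s0``, so no separate equality constraint is needed. The module and tensor sizes below are illustrative:

```python
import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def forward(self, a, b):
        return a.sum(dim=1) + b.sum(dim=1)

batch = Dim("batch", min=2)
ep = export(
    M(),
    (torch.randn(32, 64), torch.randn(32, 128)),
    # The same Dim object for both inputs means both dimension-0 sizes are
    # represented by one symbol (e.g. s0) and must be equal at runtime.
    dynamic_shapes={"a": {0: batch}, "b": {0: batch}},
)
print(ep.range_constraints)  # e.g. {s0: RangeConstraint(min_val=2, max_val=...)}
```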

(A legacy mechanism for specifying dynamic shapes
involves marking and constraining dynamic dimensions with the
:func:`torch.export.dynamic_dim` API and passing them into :func:`torch.export.export`
@@ -394,7 +386,7 @@ Input shapes

As mentioned before, by default, ``torch.export`` will trace the program
specializing on the input tensors' shapes, unless a dimension is specified as
dynamic via the :func:`torch.export.dynamic_dim` API. This means that if there
dynamic via the :func:`torch.export.Dim` API. This means that if there
exists shape-dependent control flow, ``torch.export`` will specialize on the
branch that is being taken with the given sample inputs. For example:

@@ -426,7 +418,7 @@ The conditional of (``x.shape[0] > 5``) does not appear in the
shape of (10, 2). Since ``torch.export`` specializes on the inputs' static
shapes, the else branch (``x - 1``) will never be reached. To preserve the dynamic
branching behavior based on the shape of a tensor in the traced graph,
:func:`torch.export.dynamic_dim` will need to be used to specify the dimension
:func:`torch.export.Dim` will need to be used to specify the dimension
of the input tensor (``x.shape[0]``) to be dynamic, and the source code will
need to be :ref:`rewritten <Data/Shape-Dependent Control Flow>`.
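
A minimal sketch of this specialization behavior; the placeholder name and exact error wording come from the runtime input checks added in this PR and may differ across versions:

```python
import torch

class Branchy(torch.nn.Module):
    def forward(self, x):
        if x.shape[0] > 5:  # shape-dependent control flow
            return x + 1
        return x - 1

# Tracing with a (10, 2) input specializes the graph on the `x + 1` branch.
ep = torch.export.export(Branchy(), (torch.ones(10, 2),))

# Because dim 0 was not marked dynamic, calling with a different first
# dimension is rejected rather than silently running the specialized branch;
# with this PR the message reads, e.g.:
#   RuntimeError: expected input arg0_1.shape[0] to be equal to 10, but got 4
try:
    ep(torch.ones(4, 2))
except RuntimeError as e:
    print(e)
```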

2 changes: 1 addition & 1 deletion test/dynamo/test_export.py
@@ -2332,7 +2332,7 @@ def foo(x, y):

example_inputs = (copy(x), y)
ep = torch._export._export(foo, example_inputs, constraints=constraints)
with self.assertRaisesRegex(RuntimeError, "Input.*shape.*specialized at 2"):
with self.assertRaisesRegex(RuntimeError, "input.*shape.*to be equal to 2"):
ep(torch.randn(3), y)

dim0_x, dim0_y = torch.export.dims("dim0_x", "dim0_y")
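For context on the ``torch.export.dims`` call used above: it is a convenience that creates several independent :func:`torch.export.Dim` objects at once. A small sketch with an illustrative module:

```python
import torch
from torch.export import dims, export

# Equivalent to creating two separate Dim("...") objects.
dim0_x, dim0_y = dims("dim0_x", "dim0_y")

class M(torch.nn.Module):
    def forward(self, x, y):
        return x.sum() + y.sum()

ep = export(
    M(),
    (torch.randn(4), torch.randn(6)),
    # x and y get independent symbols, so their sizes may vary separately.
    dynamic_shapes={"x": {0: dim0_x}, "y": {0: dim0_y}},
)
```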
11 changes: 6 additions & 5 deletions test/export/test_export.py
@@ -217,7 +217,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
torch.export.export(m, (a,), dynamic_shapes=dynamic_shapes)
em = torch.export.export(m, (a,))
x = torch.randn(3, 5)
with self.assertRaisesRegex(RuntimeError, "\\[1\\] is specialized at 4"):
with self.assertRaisesRegex(RuntimeError, "\\[1\\] to be equal to 4"):
em(x)

def test_not_correct_dim(self):
@@ -1138,24 +1138,25 @@ def f(x, y):
torch.allclose(exported(torch.ones(8, 5), 5), f(torch.ones(8, 5), 5))
)
with self.assertRaisesRegex(
RuntimeError, "Input arg1_1 is specialized to be 5 at tracing time"
RuntimeError, "expected input arg1_1 to be equal to 5, but got 6"
):
_ = exported(torch.ones(8, 5), 6)

exported = torch.export.export(f, (tensor_inp, 5.0), dynamic_shapes=dynamic_shapes)
with self.assertRaisesRegex(
RuntimeError, "Input arg1_1 is specialized to be 5.0 at tracing time"
RuntimeError, "expected input arg1_1 to be equal to 5.0, but got 6.0"
):
_ = exported(torch.ones(7, 5), 6.0)

def test_runtime_assert_for_prm_str(self):

def g(a, b, mode):
return torch.div(a, b, rounding_mode=mode)

inps = (torch.randn(4, 4), torch.randn(4), "trunc")
exported = torch._export.export(g, inps)
with self.assertRaisesRegex(RuntimeError, "Input arg2_1 is specialized to be trunc at"):
with self.assertRaisesRegex(
RuntimeError, "expected input arg2_1 to be equal to trunc, but got floor"
):
_ = exported(torch.randn(4, 4), torch.randn(4), "floor")
self.assertTrue(torch.allclose(exported(*inps), g(*inps)))

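The updated assertion in ``test_runtime_assert_for_prm_str`` reflects that non-tensor example inputs (ints, floats, strings) are specialized to their traced values. A sketch mirroring that test; the placeholder name ``arg2_1`` and exact wording may vary:

```python
import torch

def g(a, b, mode):
    return torch.div(a, b, rounding_mode=mode)

inps = (torch.randn(4, 4), torch.randn(4), "trunc")
ep = torch._export.export(g, inps)  # same private entry point as the test

# "trunc" was baked in at trace time; a different string fails at call time,
# and with this PR the message reads, e.g.:
#   RuntimeError: expected input arg2_1 to be equal to trunc, but got floor
try:
    ep(torch.randn(4, 4), torch.randn(4), "floor")
except RuntimeError as e:
    print(e)
```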
20 changes: 11 additions & 9 deletions test/export/test_passes.py
@@ -76,10 +76,12 @@ def forward(self, x):
dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
ep = torch.export.export(M(), (x,), dynamic_shapes={"x": {1: dim1_x}})

with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
with self.assertRaisesRegex(RuntimeError, "input arg0_1"):
ep(torch.zeros(2, 7, 3))

self.assertEqual(ep(torch.ones(2, 4, 3)), M().forward(torch.ones(2, 4, 3)))
self.assertTrue(
torch.allclose(ep(torch.ones(2, 4, 3)), M().forward(torch.ones(2, 4, 3)))
)

def test_runtime_assert_multiple_dims(self) -> None:
class M(torch.nn.Module):
@@ -99,10 +101,10 @@ def forward(self, x, y):
M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": {0: dim0_y}}
)

with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
with self.assertRaisesRegex(RuntimeError, "input arg0_1"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

with self.assertRaisesRegex(RuntimeError, "Input arg1_1"):
with self.assertRaisesRegex(RuntimeError, "input arg1_1"):
ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

def test_runtime_assert_some_dims_not_specified(self) -> None:
@@ -123,12 +125,12 @@ def forward(self, x, y):
M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": None}
)

with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
with self.assertRaisesRegex(RuntimeError, "input arg0_1"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

# y is specialized to 5
with self.assertRaisesRegex(
RuntimeError, r"Input arg1_1.shape\[0\] is specialized at 5"
RuntimeError, r"expected input arg1_1.shape\[0\] to be equal to 5, but got 2"
):
ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

@@ -152,12 +154,12 @@ def forward(self, x, y):
dim1_y = torch.export.Dim("dim1_y", min=3, max=6)
ep = torch.export.export(M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}})

with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
with self.assertRaisesRegex(RuntimeError, "input arg0_1"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

# y is specialized to 5
with self.assertRaisesRegex(
RuntimeError, r"Input arg1_1.shape\[0\] is specialized at 5"
RuntimeError, r"expected input arg1_1.shape\[0\] to be equal to 5, but got 2"
):
ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

@@ -322,7 +324,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
x = torch.rand(3, 5)
y = torch.rand(3, 6)
with self.assertRaisesRegex(
RuntimeError, r"Input arg0_1.shape\[1\] is not equal to input arg1_1.shape\[1\]"
RuntimeError, r"expected input arg1_1.shape\[1\] to be equal to 5, but got 6"
):
exported(x, y)

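The rewritten assertion above shows how the new check reports a mismatch between two inputs whose dimensions share a symbol: the first input binds the symbol to a concrete size and the second must agree. A hedged sketch of that failure mode (module and Dim names are illustrative):

```python
import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y  # requires x.shape == y.shape

d = Dim("d", min=2, max=10)
ep = export(
    M(),
    (torch.rand(3, 5), torch.rand(3, 5)),
    # Sharing one Dim for dim 1 of both inputs encodes the equality.
    dynamic_shapes={"x": {1: d}, "y": {1: d}},
)

# x binds the symbol to 5; y's size 6 then trips the unification check, e.g.:
#   RuntimeError: expected input arg1_1.shape[1] to be equal to 5, but got 6
try:
    ep(torch.rand(3, 5), torch.rand(3, 6))
except RuntimeError as e:
    print(e)
```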
1 change: 1 addition & 0 deletions torch/_export/serde/schema.py
@@ -301,6 +301,7 @@ class ExportedProgram:
graph_module: GraphModule
opset_version: Dict[str, int]
range_constraints: Dict[str, RangeConstraint]
# TODO(avik): remove equality_constraints because it is redundant
equality_constraints: List[Tuple[Tuple[str, int], Tuple[str, int]]]
schema_version: int
example_inputs: Optional[Tuple[List[bytes], Dict[str, bytes]]]
101 changes: 83 additions & 18 deletions torch/export/exported_program.py
@@ -1,5 +1,6 @@
import copy
import dataclasses
import math
from enum import auto, Enum
from typing import (
Any,
@@ -516,6 +517,9 @@ def range_constraints(self):
@property
@compatibility(is_backward_compatible=False)
def equality_constraints(self):
"""
NOTE: This property will be removed in the future.
"""
return self._equality_constraints

@property
@@ -570,7 +574,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
ordered_buffers = tuple(
self.state_dict[name] for name in self.graph_signature.buffers
)
self._check_input_constraints(*ordered_params, *ordered_buffers, *args)
self._check_input_constraints(*args)

# NOTE: calling convention is first params, then buffers, then args as user supplied them.
# See: torch/_functorch/aot_autograd.py#L1034
@@ -616,7 +620,6 @@ def __str__(self) -> str:
f" {graph_module}\n"
f"Graph signature: {self.graph_signature}\n"
f"Range constraints: {self.range_constraints}\n"
f"Equality constraints: {self.equality_constraints}\n"
)
return string

@@ -903,24 +906,86 @@ def make_argument_spec(node) -> ArgumentSpec:

def _check_input_constraints(self, *args):
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
_AddRuntimeAssertionsForConstraintsPass,
_convert_range_to_int,
)

# TODO(zhxchen17) Don't generate a runtime graph on the fly.
_assertion_graph = torch.fx.GraphModule({}, torch.fx.Graph())
for p in self.graph.nodes:
if p.op != "placeholder":
continue
new_p = _assertion_graph.graph.placeholder(p.name)
new_p.meta = p.meta
_assertion_graph.graph.output(())
_assertion_graph_res = _AddRuntimeAssertionsForConstraintsPass(
self.range_constraints,
self.equality_constraints,
)(_assertion_graph)
assert _assertion_graph_res is not None
_assertion_graph = _assertion_graph_res.graph_module
_assertion_graph(*args)
def check(cond, msg):
if not cond:
# TODO(avik): maybe add more context, e.g., graph signature
raise RuntimeError(msg)

placeholders = [p for p in self.graph.nodes if p.op == "placeholder"]
inputs = [
p
for p, s in zip(placeholders, self.graph_signature.input_specs)
if s.kind == InputKind.USER_INPUT
]
n_args, n_inputs = len(args), len(inputs)
check(
n_args == n_inputs,
f"unexpected number of inputs (expected {n_inputs}, got {n_args})",
)
# NOTE: export already guarantees that the same symbol is used in metadata
# for all InputDims related by equality constraints, so we can just unify
# symbols with given input dimension values to check equality constraints.
# TODO(avik): remove equality constraints from ExportedProgram
unification_map: Dict[sympy.Symbol, Any] = {}
for arg, p in zip(args, inputs):
p_val = p.meta["val"]
if (
isinstance(p_val, torch.Tensor)
and "tensor_meta" in p.meta
and p.meta["tensor_meta"] is not None
):
p_shape = p.meta["tensor_meta"].shape
check(
isinstance(arg, torch.Tensor),
f"expected input {p.name} to be a tensor, but got {type(arg)}",
)
n_arg_shape, n_p_shape = len(arg.shape), len(p_shape)
check(
n_arg_shape == n_p_shape,
f"unexpected number of dimensions in input {p.name}.shape "
f"(expected {n_p_shape}, got {n_arg_shape})",
)
for j, (arg_dim, p_dim) in enumerate(zip(arg.shape, p_shape)):
if isinstance(p_dim, torch.SymInt):
if p_dim.node.expr in unification_map:
existing_dim = unification_map[p_dim.node.expr]
check(
arg_dim == existing_dim,
f"expected input {p.name}.shape[{j}] to be equal to "
f"{existing_dim}, but got {arg_dim}",
)
else:
unification_map[p_dim.node.expr] = arg_dim
min_val, max_val = _convert_range_to_int(
self.range_constraints[p_dim.node.expr]
)
# NOTE: we allow dimensions to be 0/1 at runtime
if min_val > 2:
check(
arg_dim >= min_val,
f"expected input {p.name}.shape[{j}] to be >= "
f"{min_val}, but got {arg_dim}",
)
if max_val < math.inf:
check(
arg_dim <= max_val,
f"expected input {p.name}.shape[{j}] to be <= "
f"{max_val}, but got {arg_dim}",
)
else:
check(
arg_dim == p_dim,
f"expected input {p.name}.shape[{j}] to be equal to "
f"{p_dim}, but got {arg_dim}",
)
elif isinstance(p_val, (int, float, str)):
check(
type(arg) == type(p_val) and arg == p_val,
f"expected input {p.name} to be equal to {p_val}, but got {arg}",
)

def _validate(self):
from torch._export.verifier import Verifier, verify_exported_program_signature
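The heart of the new ``_check_input_constraints`` logic is the unification map: the first concrete value seen for a symbol is recorded, and every later occurrence must agree, which is what makes separately stored equality constraints redundant. A standalone, simplified sketch of that idea (not the actual PyTorch implementation; symbolic dimensions are plain strings here):

```python
from typing import Any, Dict, List, Tuple

def check_dims(
    expected: List[Tuple[Any, ...]],     # per-input expected shapes; symbolic dims as strings
    actual: List[Tuple[int, ...]],       # per-input runtime shapes
    ranges: Dict[str, Tuple[int, int]],  # symbol -> (min, max)
) -> None:
    unification_map: Dict[str, int] = {}
    for i, (exp_shape, act_shape) in enumerate(zip(expected, actual)):
        for j, (exp_dim, act_dim) in enumerate(zip(exp_shape, act_shape)):
            if isinstance(exp_dim, str):  # symbolic dimension
                if exp_dim in unification_map:
                    # Seen before: later occurrences must match the bound value.
                    if act_dim != unification_map[exp_dim]:
                        raise RuntimeError(
                            f"expected input {i}.shape[{j}] to be equal to "
                            f"{unification_map[exp_dim]}, but got {act_dim}"
                        )
                else:
                    # First occurrence: bind the symbol and range-check it.
                    unification_map[exp_dim] = act_dim
                    lo, hi = ranges[exp_dim]
                    if not (lo <= act_dim <= hi):
                        raise RuntimeError(
                            f"expected input {i}.shape[{j}] to be within "
                            f"[{lo}, {hi}], but got {act_dim}"
                        )
            elif act_dim != exp_dim:  # static dimension: exact match required
                raise RuntimeError(
                    f"expected input {i}.shape[{j}] to be equal to "
                    f"{exp_dim}, but got {act_dim}"
                )

# Both inputs share "s0" for dim 0: (32, 64) and (32, 128) unify fine;
# changing the second 32 to 16 would raise the equality error above.
check_dims(
    expected=[("s0", 64), ("s0", 128)],
    actual=[(32, 64), (32, 128)],
    ranges={"s0": (2, 2**62)},
)
```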