diff --git a/docs/source/export.ir_spec.rst b/docs/source/export.ir_spec.rst
index 1bfa9b89c4e9..b434d376e1a7 100644
--- a/docs/source/export.ir_spec.rst
+++ b/docs/source/export.ir_spec.rst
@@ -66,11 +66,9 @@ Some notable attributes of the :class:`torch.export.ExportedProgram` class are:
 - ``state_dict`` (``Dict[str, Union[torch.Tensor, torch.nn.Parameter]]``): Data
   structure containing the parameters and buffers.
 - ``range_constraints`` (``Dict[sympy.Symbol, RangeConstraint]``): For programs
-  that are exported with data dependent behavior, the metadata on each node will
+  that are exported with shape/data dependent behavior, the metadata on each node will
   contain symbolic shapes (which look like ``s0``, ``i0``). This attribute maps
   the symbolic shapes to their lower/upper ranges.
-- ``equality_constraints`` (``List[Tuple[InputDim, InputDim]]``): A list of
-  nodes in the graph and dimensions that have the same shape.

 Graph
 -----
diff --git a/docs/source/export.rst b/docs/source/export.rst
index f5b342c19457..9658c3863a66 100644
--- a/docs/source/export.rst
+++ b/docs/source/export.rst
@@ -61,7 +61,6 @@ serialized.
             assertion_dep_token=None,
         )
     Range constraints: {}
-    Equality constraints: []

 ``torch.export`` produces a clean intermediate representation (IR) with the
 following invariants. More specifications about the IR can be found
@@ -329,7 +328,6 @@ run. Such dimensions must be specified by using the
             assertion_dep_token=None,
         )
     Range constraints: {s0: RangeConstraint(min_val=2, max_val=9223372036854775806)}
-    Equality constraints: [(InputDim(input_name='arg5_1', dim=0), InputDim(input_name='arg6_1', dim=0))]

 Some additional things to note:

@@ -337,8 +335,9 @@ Some additional things to note:
   dimension of each input to be dynamic. Looking at the inputs ``arg5_1`` and
   ``arg6_1``, they have a symbolic shape of (s0, 64) and (s0, 128), instead of
   the (32, 64) and (32, 128) shaped tensors that we passed in as example inputs.
-  ``s0`` is a symbol representing that this dimension can be a range
-  of values.
+  Here ``s0`` is a symbol representing that ``arg5_1`` dimension 0 and ``arg6_1``
+  dimension 0 can have a range
+  of values, but are required to be equal.

 * ``exported_program.range_constraints`` describes the ranges of each symbol
   appearing in the graph. In this case, we see that ``s0`` has the range
@@ -348,13 +347,6 @@ Some additional things to note:
   `The 0/1 Specialization Problem `_
   for an in-depth discussion of this topic.

-* ``exported_program.equality_constraints`` describes which dimensions are
-  required to be equal. Since we specified in the constraints that the first
-  dimension of each argument is equivalent,
-  (``dynamic_dim(example_args[0], 0) == dynamic_dim(example_args[1], 0)``),
-  we see in the equality constraints the tuple specifying that ``arg5_1``
-  dimension 0 and ``arg6_1`` dimension 0 are equal.
-
 (A legacy mechanism for specifying dynamic shapes involves marking and
 constraining dynamic dimensions with the :func:`torch.export.dynamic_dim` API
 and passing them into :func:`torch.export.export`
@@ -394,7 +386,7 @@ Input shapes

 As mentioned before, by default, ``torch.export`` will trace the program
 specializing on the input tensors' shapes, unless a dimension is specified as
-dynamic via the :func:`torch.export.dynamic_dim` API. This means that if there
+dynamic via the :func:`torch.export.Dim` API. This means that if there
 exists shape-dependent control flow, ``torch.export`` will specialize on the
 branch that is being taken with the given sample inputs.

 For example:

@@ -426,7 +418,7 @@ The conditional of (``x.shape[0] > 5``) does not appear in the
 shape of (10, 2). Since ``torch.export`` specializes on the inputs' static
 shapes, the else branch (``x - 1``) will never be reached. To preserve the
 dynamic branching behavior based on the shape of a tensor in the traced graph,
-:func:`torch.export.dynamic_dim` will need to be used to specify the dimension
+:func:`torch.export.Dim` will need to be used to specify the dimension
 of the input tensor (``x.shape[0]``) to be dynamic, and the source code will
 need to be :ref:`rewritten `.
diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index c18658d29d80..cce47edcb657 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -2332,7 +2332,7 @@ def foo(x, y):
         example_inputs = (copy(x), y)
         ep = torch._export._export(foo, example_inputs, constraints=constraints)

-        with self.assertRaisesRegex(RuntimeError, "Input.*shape.*specialized at 2"):
+        with self.assertRaisesRegex(RuntimeError, "input.*shape.*to be equal to 2"):
             ep(torch.randn(3), y)

         dim0_x, dim0_y = torch.export.dims("dim0_x", "dim0_y")
diff --git a/test/export/test_export.py b/test/export/test_export.py
index 6f13ef15270d..ad484c61068e 100644
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -217,7 +217,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             torch.export.export(m, (a,), dynamic_shapes=dynamic_shapes)
         em = torch.export.export(m, (a,))
         x = torch.randn(3, 5)
-        with self.assertRaisesRegex(RuntimeError, "\\[1\\] is specialized at 4"):
+        with self.assertRaisesRegex(RuntimeError, "\\[1\\] to be equal to 4"):
             em(x)

     def test_not_correct_dim(self):
@@ -1138,24 +1138,25 @@ def f(x, y):
             torch.allclose(exported(torch.ones(8, 5), 5), f(torch.ones(8, 5), 5))
         )
         with self.assertRaisesRegex(
-            RuntimeError, "Input arg1_1 is specialized to be 5 at tracing time"
+            RuntimeError, "expected input arg1_1 to be equal to 5, but got 6"
         ):
             _ = exported(torch.ones(8, 5), 6)

         exported = torch.export.export(f, (tensor_inp, 5.0), dynamic_shapes=dynamic_shapes)
         with self.assertRaisesRegex(
-            RuntimeError, "Input arg1_1 is specialized to be 5.0 at tracing time"
+            RuntimeError, "expected input arg1_1 to be equal to 5.0, but got 6.0"
         ):
             _ = exported(torch.ones(7, 5), 6.0)

     def test_runtime_assert_for_prm_str(self):
-
         def g(a, b, mode):
             return torch.div(a, b, rounding_mode=mode)

         inps = (torch.randn(4, 4), torch.randn(4), "trunc")
         exported = torch._export.export(g, inps)
-        with self.assertRaisesRegex(RuntimeError, "Input arg2_1 is specialized to be trunc at"):
+        with self.assertRaisesRegex(
+            RuntimeError, "expected input arg2_1 to be equal to trunc, but got floor"
+        ):
             _ = exported(torch.randn(4, 4), torch.randn(4), "floor")

         self.assertTrue(torch.allclose(exported(*inps), g(*inps)))
diff --git a/test/export/test_passes.py b/test/export/test_passes.py
index 2ee83e78d147..aa6d86358e9d 100644
--- a/test/export/test_passes.py
+++ b/test/export/test_passes.py
@@ -76,10 +76,12 @@ def forward(self, x):
         dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
         ep = torch.export.export(M(), (x,), dynamic_shapes={"x": {1: dim1_x}})

-        with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
+        with self.assertRaisesRegex(RuntimeError, "input arg0_1"):
             ep(torch.zeros(2, 7, 3))

-        self.assertEqual(ep(torch.ones(2, 4, 3)), M().forward(torch.ones(2, 4, 3)))
+        self.assertTrue(
+            torch.allclose(ep(torch.ones(2, 4, 3)), M().forward(torch.ones(2, 4, 3)))
+        )

     def test_runtime_assert_multiple_dims(self) -> None:
         class M(torch.nn.Module):
@@ -99,10 +101,10 @@ def forward(self, x, y):
             M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": {0: dim0_y}}
         )

-        with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
+        with self.assertRaisesRegex(RuntimeError, "input arg0_1"):
             ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

-        with self.assertRaisesRegex(RuntimeError, "Input arg1_1"):
+        with self.assertRaisesRegex(RuntimeError, "input arg1_1"):
             ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

     def test_runtime_assert_some_dims_not_specified(self) -> None:
@@ -123,12 +125,12 @@ def forward(self, x, y):
             M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": None}
         )

-        with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
+        with self.assertRaisesRegex(RuntimeError, "input arg0_1"):
             ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

         # y is specialized to 5
         with self.assertRaisesRegex(
-            RuntimeError, r"Input arg1_1.shape\[0\] is specialized at 5"
+            RuntimeError, r"expected input arg1_1.shape\[0\] to be equal to 5, but got 2"
         ):
             ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

@@ -152,12 +154,12 @@ def forward(self, x, y):
         dim1_y = torch.export.Dim("dim1_y", min=3, max=6)
         ep = torch.export.export(M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}})

-        with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
+        with self.assertRaisesRegex(RuntimeError, "input arg0_1"):
             ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

         # y is specialized to 5
         with self.assertRaisesRegex(
-            RuntimeError, r"Input arg1_1.shape\[0\] is specialized at 5"
+            RuntimeError, r"expected input arg1_1.shape\[0\] to be equal to 5, but got 2"
         ):
             ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

@@ -322,7 +324,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         x = torch.rand(3, 5)
         y = torch.rand(3, 6)
         with self.assertRaisesRegex(
-            RuntimeError, r"Input arg0_1.shape\[1\] is not equal to input arg1_1.shape\[1\]"
+            RuntimeError, r"expected input arg1_1.shape\[1\] to be equal to 5, but got 6"
         ):
             exported(x, y)
diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py
index 041e48143ebb..8b18b26707db 100644
--- a/torch/_export/serde/schema.py
+++ b/torch/_export/serde/schema.py
@@ -301,6 +301,7 @@ class ExportedProgram:
     graph_module: GraphModule
     opset_version: Dict[str, int]
     range_constraints: Dict[str, RangeConstraint]
+    # TODO(avik): remove equality_constraints because it is redundant
     equality_constraints: List[Tuple[Tuple[str, int], Tuple[str, int]]]
     schema_version: int
     example_inputs: Optional[Tuple[List[bytes], Dict[str, bytes]]]
diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py
index f0079f2e5812..94bfb445d080 100644
--- a/torch/export/exported_program.py
+++ b/torch/export/exported_program.py
@@ -1,5 +1,6 @@
 import copy
 import dataclasses
+import math
 from enum import auto, Enum
 from typing import (
     Any,
@@ -516,6 +517,9 @@ def range_constraints(self):
     @property
     @compatibility(is_backward_compatible=False)
     def equality_constraints(self):
+        """
+        NOTE: This property will be removed in the future.
+        """
         return self._equality_constraints

     @property
@@ -570,7 +574,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
             ordered_buffers = tuple(
                 self.state_dict[name] for name in self.graph_signature.buffers
             )
-            self._check_input_constraints(*ordered_params, *ordered_buffers, *args)
+            self._check_input_constraints(*args)

             # NOTE: calling convention is first params, then buffers, then args as user supplied them.
             # See: torch/_functorch/aot_autograd.py#L1034
@@ -616,7 +620,6 @@ def __str__(self) -> str:
             f" {graph_module}\n"
             f"Graph signature: {self.graph_signature}\n"
             f"Range constraints: {self.range_constraints}\n"
-            f"Equality constraints: {self.equality_constraints}\n"
         )
         return string

@@ -903,24 +906,86 @@ def make_argument_spec(node) -> ArgumentSpec:

     def _check_input_constraints(self, *args):
         from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
-            _AddRuntimeAssertionsForConstraintsPass,
+            _convert_range_to_int,
         )
-        # TODO(zhxchen17) Don't generate a runtime graph on the fly.
-        _assertion_graph = torch.fx.GraphModule({}, torch.fx.Graph())
-        for p in self.graph.nodes:
-            if p.op != "placeholder":
-                continue
-            new_p = _assertion_graph.graph.placeholder(p.name)
-            new_p.meta = p.meta
-        _assertion_graph.graph.output(())
-        _assertion_graph_res = _AddRuntimeAssertionsForConstraintsPass(
-            self.range_constraints,
-            self.equality_constraints,
-        )(_assertion_graph)
-        assert _assertion_graph_res is not None
-        _assertion_graph = _assertion_graph_res.graph_module
-        _assertion_graph(*args)
+        def check(cond, msg):
+            if not cond:
+                # TODO(avik): maybe add more context, e.g., graph signature
+                raise RuntimeError(msg)
+
+        placeholders = [p for p in self.graph.nodes if p.op == "placeholder"]
+        inputs = [
+            p
+            for p, s in zip(placeholders, self.graph_signature.input_specs)
+            if s.kind == InputKind.USER_INPUT
+        ]
+        n_args, n_inputs = len(args), len(inputs)
+        check(
+            n_args == n_inputs,
+            f"unexpected number of inputs (expected {n_inputs}, got {n_args})",
+        )
+        # NOTE: export already guarantees that the same symbol is used in metadata
+        # for all InputDims related by equality constraints, so we can just unify
+        # symbols with given input dimension values to check equality constraints.
+        # TODO(avik): remove equality constraints from ExportedProgram
+        unification_map: Dict[sympy.Symbol, Any] = {}
+        for arg, p in zip(args, inputs):
+            p_val = p.meta["val"]
+            if (
+                isinstance(p_val, torch.Tensor)
+                and "tensor_meta" in p.meta
+                and p.meta["tensor_meta"] is not None
+            ):
+                p_shape = p.meta["tensor_meta"].shape
+                check(
+                    isinstance(arg, torch.Tensor),
+                    f"expected input {p.name} to be a tensor, but got {type(arg)}",
+                )
+                n_arg_shape, n_p_shape = len(arg.shape), len(p_shape)
+                check(
+                    n_arg_shape == n_p_shape,
+                    f"unexpected number of dimensions in input {p.name}.shape "
+                    f"(expected {n_p_shape}, got {n_arg_shape})",
+                )
+                for j, (arg_dim, p_dim) in enumerate(zip(arg.shape, p_shape)):
+                    if isinstance(p_dim, torch.SymInt):
+                        if p_dim.node.expr in unification_map:
+                            existing_dim = unification_map[p_dim.node.expr]
+                            check(
+                                arg_dim == existing_dim,
+                                f"expected input {p.name}.shape[{j}] to be equal to "
+                                f"{existing_dim}, but got {arg_dim}",
+                            )
+                        else:
+                            unification_map[p_dim.node.expr] = arg_dim
+                            min_val, max_val = _convert_range_to_int(
+                                self.range_constraints[p_dim.node.expr]
+                            )
+                            # NOTE: we allow dimensions to be 0/1 at runtime
+                            if min_val > 2:
+                                check(
+                                    arg_dim >= min_val,
+                                    f"expected input {p.name}.shape[{j}] to be >= "
+                                    f"{min_val}, but got {arg_dim}",
+                                )
+                            if max_val < math.inf:
+                                check(
+                                    arg_dim <= max_val,
+                                    f"expected input {p.name}.shape[{j}] to be <= "
+                                    f"{max_val}, but got {arg_dim}",
+                                )
+                    else:
+                        check(
+                            arg_dim == p_dim,
+                            f"expected input {p.name}.shape[{j}] to be equal to "
+                            f"{p_dim}, but got {arg_dim}",
+                        )
+            elif isinstance(p_val, (int, float, str)):
+                check(
+                    type(arg) == type(p_val) and arg == p_val,
+                    f"expected input {p.name} to be equal to {p_val}, but got {arg}",
+                )

     def _validate(self):
         from torch._export.verifier import Verifier, verify_exported_program_signature
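
For readers following the behavior change, here is a minimal usage sketch (not part of the patch); it assumes the ``torch.export.Dim``/``dynamic_shapes`` API referenced in the docs changes above, and the module name ``M`` and the exact placeholder name (``arg0_1``) in the error text are illustrative assumptions:

```python
# Minimal sketch of the reworked runtime input check; placeholder names in the
# error message (e.g. arg0_1) are illustrative and may differ in practice.
import torch


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


dim0 = torch.export.Dim("dim0", min=3, max=10)
ep = torch.export.export(M(), (torch.randn(4, 5),), dynamic_shapes={"x": {0: dim0}})

ep(torch.randn(7, 5))  # ok: dim 0 is dynamic and 3 <= 7 <= 10

try:
    ep(torch.randn(4, 6))  # dim 1 was static (5) at export time
except RuntimeError as e:
    print(e)  # e.g. "expected input arg0_1.shape[1] to be equal to 5, but got 6"
```

The check now compares user inputs directly against placeholder metadata, unifying symbolic dimensions on the fly, instead of building and running a throwaway assertion graph; this is why equality constraints no longer need to be materialized separately.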