Skip to content

Commit

Permalink
Wrap more constraint violation cases to UserError (#100897)
Browse files Browse the repository at this point in the history
Cases covered in this PR:
 - Example inputs conflict with input constraints
 - Example inputs conflict with inline constraints
 - Suggest users to use `constrain_as_*()` when trying to export with data-dependent operations

Differential Revision: [D45666627](https://www.internalfb.com/diff/D45666627)

Pull Request resolved: #100897
Approved by: https://github.com/avikchaudhuri
  • Loading branch information
guangy10 authored and pytorchmergebot committed May 9, 2023
1 parent b179d34 commit 0e08a9b
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 17 deletions.
38 changes: 34 additions & 4 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
import torch._dynamo as torchdynamo
from torch._dynamo.eval_frame import is_dynamo_supported
from torch._export import _export, export
from torch._export import _export, export, dynamic_dim
from torch._export.trace import do_not_use_experimental_export
from torch._export.constraints import constrain_as_size
from torch._export.graph_module import get_export_meta
Expand Down Expand Up @@ -111,15 +111,29 @@ def invalid_size(x):
with self.assertRaisesRegex(torchdynamo.exc.UserError, "Unable to set min size"):
_export(invalid_size, inp)

def invalid_input(x):
def invalid_input_conflict_with_inline_constraints(x):
b = x.item()
constrain_as_size(b, min=2, max=5)
return torch.full((b, 1), 1)

inp = (torch.tensor([6]),)
with self.assertRaisesRegex(torchdynamo.exc.UserError, "Invalid value 6 for range"):
_export(invalid_input_conflict_with_inline_constraints, inp)

def invalid_input_conflict_with_input_constraints(x):
return x + 1

inp = torch.zeros([3])
inp_constraints = [
dynamic_dim(inp, 0) > 5,
]
with self.assertRaisesRegex(torchdynamo.exc.UserError, "not in range"):
_export(
invalid_input_conflict_with_input_constraints,
(inp,),
constraints=inp_constraints,
)

with self.assertRaisesRegex(torch.utils._sympy.value_ranges.ValueRangeError, "Invalid value 6 for range"):
_export(invalid_input, inp)

def conflicting_constraints(x):
b = x.item()
Expand Down Expand Up @@ -523,6 +537,22 @@ def method1(self, x: torch.Tensor) -> torch.Tensor:

self.assertTrue(torch.allclose(eager_results, exported_results))

@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
def test_raise_user_error_when_guard_on_data_dependent_operation(self):
def fn_ddo(x):
y = x.nonzero()
z = y.shape[0]
if z > 2:
return x.cos()
else:
return x.sin()

with self.assertRaisesRegex(
torchdynamo.exc.UserError,
"trying to get a value out of symbolic int"
):
_ = _export(fn_ddo, (torch.tensor([2, 3, 5]),), constraints=None)


if __name__ == '__main__':
run_tests()
6 changes: 5 additions & 1 deletion torch/_dynamo/convert_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
import torch._logging
from torch._guards import tracing
from torch._utils_internal import signpost_event
from torch.fx.experimental.symbolic_shapes import ConstraintViolationError
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
GuardOnDataDependentSymNode,
)
from torch.fx.graph_module import _forward_from_src as original_forward_from_src

from . import config, exc
Expand Down Expand Up @@ -498,6 +501,7 @@ def log_bytecode(prefix, name, filename, line_no, code):
BackendCompilerFailed,
AssertionError,
ConstraintViolationError,
GuardOnDataDependentSymNode,
) as e:
exception_handler(e, code, frame)
raise
Expand Down
34 changes: 23 additions & 11 deletions torch/_export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@

import torch.utils._pytree as pytree
from torch._export.pass_base import PassType
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
GuardOnDataDependentSymNode,
StrictMinMaxConstraint,
)
from torch._dynamo.exc import UserError, UserErrorType
from torch.fx._compatibility import compatibility
from torch.fx.passes.pass_manager import PassManager
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._sympy.value_ranges import ValueRanges, ValueRangeError

Value = Any

Expand Down Expand Up @@ -99,15 +104,22 @@ def _export(
constraints = []

with torch._dynamo.config.patch(dataclasses.asdict(ExportDynamoConfig())): # type: ignore[attr-defined]
gm, _ = torch._dynamo.export(
f,
*args,
aten_graph=True,
tracing_mode="symbolic",
decomposition_table=DECOMP_TABLE,
constraints=constraints,
assume_static_by_default=True,
)
try:
gm, _ = torch._dynamo.export(
f,
*args,
aten_graph=True,
tracing_mode="symbolic",
decomposition_table=DECOMP_TABLE,
constraints=constraints,
assume_static_by_default=True,
)
except (ConstraintViolationError, ValueRangeError) as e:
raise UserError(UserErrorType.CONSTRAIN_VIOLATION, str(e))
except GuardOnDataDependentSymNode as e:
raise UserError(
UserErrorType.ANTI_PATTERN,
f"Consider annotating your code using constrain_as_*(). {str(e)}")

flat_args, in_spec = pytree.tree_flatten(args)
out_spec = (
Expand Down
2 changes: 1 addition & 1 deletion torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2035,7 +2035,7 @@ def create_symbol(

vr = self.var_to_range[sympy_expr]
if val not in vr:
raise RuntimeError(f"{val} not in range [{vr.lower}, {vr.upper}]")
raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]")

r = sympy_expr
else:
Expand Down

0 comments on commit 0e08a9b

Please sign in to comment.