Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap more constraint violation cases to UserError #100897

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 34 additions & 4 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
import torch._dynamo as torchdynamo
from torch._dynamo.eval_frame import is_dynamo_supported
from torch._export import _export, export
from torch._export import _export, export, dynamic_dim
from torch._export.trace import do_not_use_experimental_export
from torch._export.constraints import constrain_as_size
from torch._export.graph_module import get_export_meta
Expand Down Expand Up @@ -111,15 +111,29 @@ def invalid_size(x):
with self.assertRaisesRegex(torchdynamo.exc.UserError, "Unable to set min size"):
_export(invalid_size, inp)

def invalid_input(x):
def invalid_input_conflict_with_inline_constraints(x):
b = x.item()
constrain_as_size(b, min=2, max=5)
return torch.full((b, 1), 1)

inp = (torch.tensor([6]),)
with self.assertRaisesRegex(torchdynamo.exc.UserError, "Invalid value 6 for range"):
_export(invalid_input_conflict_with_inline_constraints, inp)

def invalid_input_conflict_with_input_constraints(x):
return x + 1

inp = torch.zeros([3])
inp_constraints = [
dynamic_dim(inp, 0) > 5,
]
with self.assertRaisesRegex(torchdynamo.exc.UserError, "not in range"):
_export(
invalid_input_conflict_with_input_constraints,
(inp,),
constraints=inp_constraints,
)

with self.assertRaisesRegex(torch.utils._sympy.value_ranges.ValueRangeError, "Invalid value 6 for range"):
_export(invalid_input, inp)

def conflicting_constraints(x):
b = x.item()
Expand Down Expand Up @@ -523,6 +537,22 @@ def method1(self, x: torch.Tensor) -> torch.Tensor:

self.assertTrue(torch.allclose(eager_results, exported_results))

@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
def test_raise_user_error_when_guard_on_data_dependent_operation(self):
def fn_ddo(x):
y = x.nonzero()
z = y.shape[0]
if z > 2:
return x.cos()
else:
return x.sin()

with self.assertRaisesRegex(
torchdynamo.exc.UserError,
"trying to get a value out of symbolic int"
):
_ = _export(fn_ddo, (torch.tensor([2, 3, 5]),), constraints=None)


if __name__ == '__main__':
run_tests()
6 changes: 5 additions & 1 deletion torch/_dynamo/convert_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
import torch._logging
from torch._guards import tracing
from torch._utils_internal import signpost_event
from torch.fx.experimental.symbolic_shapes import ConstraintViolationError
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
GuardOnDataDependentSymNode,
)
from torch.fx.graph_module import _forward_from_src as original_forward_from_src

from . import config, exc
Expand Down Expand Up @@ -498,6 +501,7 @@ def log_bytecode(prefix, name, filename, line_no, code):
BackendCompilerFailed,
AssertionError,
ConstraintViolationError,
GuardOnDataDependentSymNode,
) as e:
exception_handler(e, code, frame)
raise
Expand Down
34 changes: 23 additions & 11 deletions torch/_export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,16 @@
from torch._dynamo.eval_frame import Constraint

import torch.utils._pytree as pytree
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
GuardOnDataDependentSymNode,
StrictMinMaxConstraint,
)
from torch._dynamo.exc import UserError, UserErrorType
from torch.fx._compatibility import compatibility
from torch.fx.passes.pass_manager import PassManager
from torch.fx.passes.infra.pass_base import PassResult
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._sympy.value_ranges import ValueRanges, ValueRangeError

Value = Any

Expand Down Expand Up @@ -101,15 +106,22 @@ def _export(
constraints = []

with torch._dynamo.config.patch(dataclasses.asdict(ExportDynamoConfig())): # type: ignore[attr-defined]
gm, _ = torch._dynamo.export(
f,
*args,
aten_graph=True,
tracing_mode="symbolic",
decomposition_table=DECOMP_TABLE,
constraints=constraints,
assume_static_by_default=True,
)
try:
gm, _ = torch._dynamo.export(
f,
*args,
aten_graph=True,
tracing_mode="symbolic",
decomposition_table=DECOMP_TABLE,
constraints=constraints,
assume_static_by_default=True,
)
except (ConstraintViolationError, ValueRangeError) as e:
raise UserError(UserErrorType.CONSTRAIN_VIOLATION, str(e))
except GuardOnDataDependentSymNode as e:
raise UserError(
UserErrorType.ANTI_PATTERN,
f"Consider annotating your code using constrain_as_*(). {str(e)}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please link to an exportdb example? Or do you plan to do it separately?

Copy link
Contributor Author

@guangy10 guangy10 May 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me see if we have one for it in the export db. @zhxchen17 do you know which one to use?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like we don't have one yet. Let's do it separately


flat_args, in_spec = pytree.tree_flatten(args)
out_spec = (
Expand Down
2 changes: 1 addition & 1 deletion torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2036,7 +2036,7 @@ def create_symbol(

vr = self.var_to_range[sympy_expr]
if val not in vr:
raise RuntimeError(f"{val} not in range [{vr.lower}, {vr.upper}]")
raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]")

r = sympy_expr
else:
Expand Down