From 5cd294dbc9889517ff93c5125b911f33d57896fa Mon Sep 17 00:00:00 2001 From: Avik Chaudhuri Date: Mon, 10 Apr 2023 13:30:23 -0700 Subject: [PATCH 1/3] dynamic range constraint API This diff adds the ability to specify range constraints on dynamic dimensions. (Previously we only supported declaring a dynamic dimension, which gets the default range `[-2, sympy.oo]`.) One point worth calling out: our initial design called for compound expressions like `lower <= dynamic_dim(x, d) <= upper`. However this seems difficult to support, because of a combination of desugaring and overloading semantics for such compound expressions in Python. Rather than silently doing the wrong thing, we explicitly error in this case and recommend users to specify multiple constraints, which is supported. Differential Revision: [D44847318](https://our.internmc.facebook.com/intern/diff/D44847318/) [ghstack-poisoned] --- test/dynamo/test_export.py | 41 +++++++++++++++++++ torch/_dynamo/eval_frame.py | 51 ++++++++++++++++++++++-- torch/_dynamo/variables/builder.py | 18 ++++++++- torch/fx/experimental/symbolic_shapes.py | 5 ++- 4 files changed, 109 insertions(+), 6 deletions(-) diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index cbcd42e94733..ac9e3fc2f167 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -2146,6 +2146,47 @@ def my_dyn_fn(a, b, c): else: torch._dynamo.export(my_dyn_fn, x, y, z, constraints=constraints) + @config.patch(dynamic_shapes=True) + def test_export_dynamic_dim_raise_on_compound_range_constraint(self): + x = torch.ones(6, 4, 4) + with self.assertRaisesRegex(TypeError, "Cannot determine truth value"): + 4 < dynamic_dim(x, 0) <= 6 + + @config.patch(dynamic_shapes=True) + def test_export_dynamic_dim_range_constraint(self): + x = torch.ones(6, 4, 4) + constraints = [ + 4 < dynamic_dim(x, 0), + dynamic_dim(x, 0) <= 6, + ] + + def foo(x): + if x.shape[0] > 3: # ok + return x.sin() + return x.cos() + + torch._dynamo.export( + foo, + x, + constraints=constraints, + aten_graph=True, + tracing_mode="symbolic", + ) + + def bar(x): + if x.shape[0] > 5: # error + return x.sin() + return x.cos() + + with self.assertRaises(ConstraintViolationError): + torch._dynamo.export( + bar, + x, + constraints=constraints, + aten_graph=True, + tracing_mode="symbolic", + ) + @config.patch(dynamic_shapes=True) def test_list_contains(self): def func(x): diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 4faf700c0762..4178d6043bfd 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -60,6 +60,12 @@ null_context = contextlib.nullcontext +from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint +from torch.utils._sympy.value_ranges import ValueRanges +import builtins +import sympy + + # See https://github.com/python/typing/pull/240 class Unset(Enum): token = 0 @@ -596,9 +602,48 @@ class directly; instead, use :func:`torch._export.dynamic_dim`. # TODO: We don't need t_id; we can get it off of w_tensor t_id: int dim: int - constraint_range: Optional[ - torch.fx.experimental.symbolic_shapes.StrictMinMaxConstraint - ] + constraint_range: Optional[StrictMinMaxConstraint] + + @property + def _lower(self): + if self.constraint_range is None: + return 2 + else: + return self.constraint_range.vr.lower + + @property + def _upper(self): + if self.constraint_range is None: + return sympy.oo + else: + return self.constraint_range.vr.upper + + def _clone_with_range(self, lower=2, upper=sympy.oo): + lower, upper = builtins.max(lower, self._lower), builtins.min(upper, self._upper) + return Constraint(self.w_tensor, self.t_id, self.dim, StrictMinMaxConstraint(ValueRanges(lower=lower, upper=upper))) + + def __ge__(self, lower): + return self._clone_with_range(lower=lower) + + def __gt__(self, lower): + return self._clone_with_range(lower=lower+1) + + def __le__(self, upper): + return self._clone_with_range(upper=upper) + + def __lt__(self, upper): + return self._clone_with_range(upper=upper-1) + + def __bool__(self): + # NOTE(avik): We do not support compound expressions like a <= x <= b. + # This is because Python implicitly desugars them into bool(a <= x) and bool(x <= b), + # and moreover, enforces that any overload of __bool__ must return True or False. + # FWIW, sympy also raises TypeError in this case. + raise TypeError( + f"Cannot determine truth value of Constraint. " + "If you are trying to combine Constraints with logical connectives, " + "split it into multiple Constraints instead." + ) def export( diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index a84556d56f0b..d419e967152d 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1167,7 +1167,23 @@ def wrap_to_fake_tensor_and_record( if tx.output.export_constraints: for constraint in tx.output.export_constraints: if constraint.t_id == t_id: - dim2constraint[constraint.dim] = constraint.constraint_range + if constraint.dim in dim2constraint: + other_cr = dim2constraint[constraint.dim] + cr = constraint.constraint_range + if other_cr is None: + dim2constraint[constraint.dim] = cr + elif cr is not None: + from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint + from torch.utils._sympy.value_ranges import ValueRanges + import builtins + dim2constraint[constraint.dim] = StrictMinMaxConstraint( + ValueRanges( + builtins.max(cr.vr.lower, other_cr.vr.lower), + builtins.min(cr.vr.upper, other_cr.vr.upper), + ) + ) + else: + dim2constraint[constraint.dim] = constraint.constraint_range dynamic_dims = None constraint_dims = None diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 6ef3b205edd3..de3b1a4e0439 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1737,7 +1737,7 @@ def produce_guards( # TODO: Make this more efficient by binding all the size/stride/offsets # to locals before performing tests on them. - from torch._dynamo.source import TensorPropertySource, TensorProperty + from torch._dynamo.source import TensorPropertySource, TensorProperty, NegateSource # Actual codegen must be delayed as we don't necessarily know what # the symbol mapping is @@ -1903,7 +1903,7 @@ def hint(): if not (c_vr.lower == r.lower and c_vr.upper == r.upper): record_constraint_violation(lambda: ( f"Could not validate constraint {c.render(sources[0])} as " - f"we actually inferred the valid range to be [{vr.lower}, {vr.upper}]." + f"we actually inferred the valid range to be [{r.lower}, {r.upper}]." "This is actually supposed to be impossible to " "trigger right now as we do not refine ranges; maybe you called " "constrain_range manually, or we forgot to update this error message? " @@ -2256,6 +2256,7 @@ def _add_guard(self, expr: "sympy.Expr") -> None: current_loc = TracingContext.get().loc_in_frame # current_loc describes a line in the current frame user_stack = ''.join(traceback.format_list([*frame_summaries, current_loc])) + guard = ShapeGuard(expr, user_stack) expr = LoggingShapeGuardPrinter(self.var_to_sources).doprint(expr) log.warning(f"Adding shape guard {expr} at \n{user_stack}") log.debug("SHAPE GUARD", stack_info=True) From 1b5c542428a2dd9acb59ba640ba8109d11bccbfc Mon Sep 17 00:00:00 2001 From: Avik Chaudhuri Date: Mon, 10 Apr 2023 13:36:13 -0700 Subject: [PATCH 2/3] Update on "dynamic range constraint API" This diff adds the ability to specify range constraints on dynamic dimensions. (Previously we only supported declaring a dynamic dimension, which gets the default range `[-2, sympy.oo]`.) One point worth calling out: our initial design called for compound expressions like `lower <= dynamic_dim(x, d) <= upper`. However this seems difficult to support, because of a combination of desugaring and overloading semantics for such compound expressions in Python. Rather than silently doing the wrong thing, we explicitly error in this case and recommend users to specify multiple constraints, which is supported. Differential Revision: [D44847318](https://our.internmc.facebook.com/intern/diff/D44847318/) cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire [ghstack-poisoned] --- torch/_dynamo/eval_frame.py | 9 +++++++-- torch/fx/experimental/symbolic_shapes.py | 1 - 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 4178d6043bfd..2de2c5d39b62 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -620,7 +620,12 @@ def _upper(self): def _clone_with_range(self, lower=2, upper=sympy.oo): lower, upper = builtins.max(lower, self._lower), builtins.min(upper, self._upper) - return Constraint(self.w_tensor, self.t_id, self.dim, StrictMinMaxConstraint(ValueRanges(lower=lower, upper=upper))) + return Constraint( + self.w_tensor, + self.t_id, + self.dim, + StrictMinMaxConstraint(ValueRanges(lower=lower, upper=upper)) + ) def __ge__(self, lower): return self._clone_with_range(lower=lower) @@ -642,7 +647,7 @@ def __bool__(self): raise TypeError( f"Cannot determine truth value of Constraint. " "If you are trying to combine Constraints with logical connectives, " - "split it into multiple Constraints instead." + "you can specify them separately instead." ) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index de3b1a4e0439..f52cd5befa76 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -2256,7 +2256,6 @@ def _add_guard(self, expr: "sympy.Expr") -> None: current_loc = TracingContext.get().loc_in_frame # current_loc describes a line in the current frame user_stack = ''.join(traceback.format_list([*frame_summaries, current_loc])) - guard = ShapeGuard(expr, user_stack) expr = LoggingShapeGuardPrinter(self.var_to_sources).doprint(expr) log.warning(f"Adding shape guard {expr} at \n{user_stack}") log.debug("SHAPE GUARD", stack_info=True) From 56f2208f72d33cac0a5b60604df4bb7340c7bdbb Mon Sep 17 00:00:00 2001 From: Avik Chaudhuri Date: Tue, 11 Apr 2023 10:12:22 -0700 Subject: [PATCH 3/3] Update on "dynamic range constraint API" This diff adds the ability to specify range constraints on dynamic dimensions. (Previously we only supported declaring a dynamic dimension, which gets the default range `[2, sympy.oo]`.) One point worth calling out: our initial design called for compound expressions like `lower <= dynamic_dim(x, d) <= upper`. However this seems difficult to support, because of a combination of desugaring and overloading semantics for such compound expressions in Python. Rather than silently doing the wrong thing, we explicitly error in this case and recommend users to specify multiple constraints, which is supported. Differential Revision: [D44847318](https://our.internmc.facebook.com/intern/diff/D44847318/) cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire [ghstack-poisoned] --- torch/_dynamo/eval_frame.py | 27 +++++---------------------- torch/_dynamo/variables/builder.py | 18 ++++-------------- 2 files changed, 9 insertions(+), 36 deletions(-) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 2de2c5d39b62..75a1eb892a91 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -62,7 +62,6 @@ from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint from torch.utils._sympy.value_ranges import ValueRanges -import builtins import sympy @@ -602,30 +601,14 @@ class directly; instead, use :func:`torch._export.dynamic_dim`. # TODO: We don't need t_id; we can get it off of w_tensor t_id: int dim: int - constraint_range: Optional[StrictMinMaxConstraint] - - @property - def _lower(self): - if self.constraint_range is None: - return 2 - else: - return self.constraint_range.vr.lower - - @property - def _upper(self): - if self.constraint_range is None: - return sympy.oo - else: - return self.constraint_range.vr.upper + # NOTE(avik): In the future, this could be Union[StrictMinMaxConstraint, ] + constraint_range: StrictMinMaxConstraint def _clone_with_range(self, lower=2, upper=sympy.oo): - lower, upper = builtins.max(lower, self._lower), builtins.min(upper, self._upper) - return Constraint( - self.w_tensor, - self.t_id, - self.dim, - StrictMinMaxConstraint(ValueRanges(lower=lower, upper=upper)) + constraint_range = StrictMinMaxConstraint( + self.constraint_range.vr & ValueRanges(lower=lower, upper=upper) ) + return Constraint(self.w_tensor, self.t_id, self.dim, constraint_range) def __ge__(self, lower): return self._clone_with_range(lower=lower) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index d419e967152d..01f7f4d64ee1 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1168,20 +1168,10 @@ def wrap_to_fake_tensor_and_record( for constraint in tx.output.export_constraints: if constraint.t_id == t_id: if constraint.dim in dim2constraint: - other_cr = dim2constraint[constraint.dim] - cr = constraint.constraint_range - if other_cr is None: - dim2constraint[constraint.dim] = cr - elif cr is not None: - from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint - from torch.utils._sympy.value_ranges import ValueRanges - import builtins - dim2constraint[constraint.dim] = StrictMinMaxConstraint( - ValueRanges( - builtins.max(cr.vr.lower, other_cr.vr.lower), - builtins.min(cr.vr.upper, other_cr.vr.upper), - ) - ) + from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint + dim2constraint[constraint.dim] = StrictMinMaxConstraint( + constraint.constraint_range.vr & dim2constraint[constraint.dim].vr + ) else: dim2constraint[constraint.dim] = constraint.constraint_range