Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions test/dynamo/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -2146,6 +2146,47 @@ def my_dyn_fn(a, b, c):
else:
torch._dynamo.export(my_dyn_fn, x, y, z, constraints=constraints)

@config.patch(dynamic_shapes=True)
def test_export_dynamic_dim_raise_on_compound_range_constraint(self):
    """A chained comparison on a dynamic_dim Constraint must raise TypeError."""
    t = torch.ones(6, 4, 4)
    # Python desugars `4 < c <= 6` into `(4 < c) and (c <= 6)`, which forces
    # bool() on the first Constraint — and Constraint.__bool__ always raises.
    with self.assertRaisesRegex(TypeError, "Cannot determine truth value"):
        4 < dynamic_dim(t, 0) <= 6

@config.patch(dynamic_shapes=True)
def test_export_dynamic_dim_range_constraint(self):
    """Separate `<` and `<=` constraints on one dim combine into a single range."""
    t = torch.ones(6, 4, 4)
    # Declared range for dim 0 is (4, 6], expressed as two simple constraints
    # (compound expressions like `4 < d <= 6` are unsupported — see __bool__).
    constraints = [
        4 < dynamic_dim(t, 0),
        dynamic_dim(t, 0) <= 6,
    ]

    def branch_above_three(x):
        if x.shape[0] > 3:  # ok
            return x.sin()
        return x.cos()

    # The `> 3` guard holds for every size in (4, 6], so export succeeds.
    torch._dynamo.export(
        branch_above_three,
        t,
        constraints=constraints,
        aten_graph=True,
        tracing_mode="symbolic",
    )

    def branch_above_five(x):
        if x.shape[0] > 5:  # error
            return x.sin()
        return x.cos()

    # The `> 5` guard splits the declared range, so export must reject it.
    with self.assertRaises(ConstraintViolationError):
        torch._dynamo.export(
            branch_above_five,
            t,
            constraints=constraints,
            aten_graph=True,
            tracing_mode="symbolic",
        )

@config.patch(dynamic_shapes=True)
def test_list_contains(self):
def func(x):
Expand Down
39 changes: 36 additions & 3 deletions torch/_dynamo/eval_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@
null_context = contextlib.nullcontext


from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
from torch.utils._sympy.value_ranges import ValueRanges
import sympy


# See https://github.com/python/typing/pull/240
class Unset(Enum):
token = 0
Expand Down Expand Up @@ -596,9 +601,37 @@ class directly; instead, use :func:`torch._export.dynamic_dim`.
# TODO: We don't need t_id; we can get it off of w_tensor
t_id: int
dim: int
constraint_range: Optional[
torch.fx.experimental.symbolic_shapes.StrictMinMaxConstraint
]
# NOTE(avik): In the future, this could be Union[StrictMinMaxConstraint, <other kinds>]
constraint_range: StrictMinMaxConstraint

def _clone_with_range(self, lower=2, upper=sympy.oo):
    """Return a copy of this Constraint whose value range has been
    intersected with [lower, upper]; the original is left untouched."""
    narrowed = self.constraint_range.vr & ValueRanges(lower=lower, upper=upper)
    return Constraint(
        self.w_tensor, self.t_id, self.dim, StrictMinMaxConstraint(narrowed)
    )

def __ge__(self, lower):
    """Record an inclusive lower bound: x >= lower."""
    return self._clone_with_range(lower=lower)

def __gt__(self, lower):
    """Record a strict lower bound: x > lower.

    Dimension sizes are integral, so this is stored as x >= lower + 1.
    """
    return self._clone_with_range(lower=lower + 1)

def __le__(self, upper):
    """Record an inclusive upper bound: x <= upper."""
    return self._clone_with_range(upper=upper)

def __lt__(self, upper):
    """Record a strict upper bound: x < upper.

    Dimension sizes are integral, so this is stored as x <= upper - 1.
    """
    return self._clone_with_range(upper=upper - 1)

def __bool__(self):
# NOTE(avik): We do not support compound expressions like a <= x <= b.
# This is because Python implicitly desugars them into bool(a <= x) and bool(x <= b),
# and moreover, enforces that any overload of __bool__ must return True or False.
# FWIW, sympy also raises TypeError in this case.
raise TypeError(
f"Cannot determine truth value of Constraint. "
"If you are trying to combine Constraints with logical connectives, "
"you can specify them separately instead."
)


def export(
Expand Down
8 changes: 7 additions & 1 deletion torch/_dynamo/variables/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,7 +1167,13 @@ def wrap_to_fake_tensor_and_record(
if tx.output.export_constraints:
for constraint in tx.output.export_constraints:
if constraint.t_id == t_id:
dim2constraint[constraint.dim] = constraint.constraint_range
if constraint.dim in dim2constraint:
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
dim2constraint[constraint.dim] = StrictMinMaxConstraint(
constraint.constraint_range.vr & dim2constraint[constraint.dim].vr
)
else:
dim2constraint[constraint.dim] = constraint.constraint_range

dynamic_dims = None
constraint_dims = None
Expand Down
4 changes: 2 additions & 2 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1737,7 +1737,7 @@ def produce_guards(
# TODO: Make this more efficient by binding all the size/stride/offsets
# to locals before performing tests on them.

from torch._dynamo.source import TensorPropertySource, TensorProperty
from torch._dynamo.source import TensorPropertySource, TensorProperty, NegateSource

# Actual codegen must be delayed as we don't necessarily know what
# the symbol mapping is
Expand Down Expand Up @@ -1903,7 +1903,7 @@ def hint():
if not (c_vr.lower == r.lower and c_vr.upper == r.upper):
record_constraint_violation(lambda: (
f"Could not validate constraint {c.render(sources[0])} as "
f"we actually inferred the valid range to be [{vr.lower}, {vr.upper}]."
f"we actually inferred the valid range to be [{r.lower}, {r.upper}]."
"This is actually supposed to be impossible to "
"trigger right now as we do not refine ranges; maybe you called "
"constrain_range manually, or we forgot to update this error message? "
Expand Down