diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 30baf1fdae01..abcd12fc3b18 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1344,6 +1344,28 @@ def f(a): tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(15)) self.assertExpectedInline(show_guards(tensor), """L['a'].size()[0] < 20""") + def test_guard_upperbound_range_refinement_multivariate(self): + def f(a): + assert a.shape[0] > 5 and a.shape[0] > 12 + assert a.shape[1] > 5 and a.shape[1] > a.shape[0] + return a.cos() + tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 20))) + self.assertExpectedInline(show_guards(tensor), """\ +L['a'].size()[1] > L['a'].size()[0] +L['a'].size()[0] > 12""") + + def test_guard_lowerbound_range_refinement_multivariate(self): + def f(a): + assert a.shape[0] < 20 and a.shape[0] < 30 + assert a.shape[1] < 30 and a.shape[1] < a.shape[0] + return a.cos() + tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 5))) + self.assertExpectedInline( + show_guards(tensor), + """\ +L['a'].size()[1] < L['a'].size()[0] +L['a'].size()[0] < 20""") + def test_sym_storage_offset(self): def f(x, y): return x + y diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 3a6548875952..10a8a6616efe 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -34,7 +34,7 @@ ) from torch._guards import ShapeGuard, Source, TracingContext, detect_fake_mode from torch.utils._sympy.interp import sympy_interp -from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges, ValueRangeError +from torch.utils._sympy.value_ranges import PythonValueRangeAnalysis, ValueRangeAnalysis, ValueRanges, ValueRangeError from torch.utils._traceback import format_frame from torch._utils_internal import signpost_event @@ -3434,10 +3434,6 @@ def simplify_until(expr: sympy.Expr, max_iterations: int = 10) -> sympy.Expr: ): continue - # Use only univariate functions. - if len(expr.rhs.free_symbols) > 0: - continue - # Update the value range of the left-hand side, if the # right-hand side provides a better range. symbol = expr.lhs @@ -3445,7 +3441,7 @@ def simplify_until(expr: sympy.Expr, max_iterations: int = 10) -> sympy.Expr: vr = self.var_to_range[symbol] lower, upper = vr.lower, vr.upper - rhs_vr = sympy_interp(ValueRangeAnalysis, self.var_to_range, expr.rhs) # type: ignore[arg-type] + rhs_vr = sympy_interp(PythonValueRangeAnalysis, self.var_to_range, expr.rhs) # type: ignore[arg-type] lower_guard, upper_guard = self.var_to_guards.get(symbol, (None, None)) # Let's suppose that we have a preexisting range for x [0, 100]. diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index f3500b4d5940..565967cf3a4c 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -469,3 +469,15 @@ def floor_ceil(x, fn): def __getattr__(self, name): log.warning("unhandled ValueRange op %s", name) return self.default_handler + + +# Implements Python semantics for 'ValueRangeAnalysis'. +# Reasoning about guards relies on Python operator semantics. +class PythonValueRangeAnalysis(ValueRangeAnalysis): + def __init__(self): + super().__init__() + self.name = "PythonValueRangeAnalysis" + + @staticmethod + def div(a, b): + return PythonValueRangeAnalysis.floordiv(a, b)