Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
14e5d6e
Value range refinement using multi-variate expressions.
ysiraichi Mar 30, 2023
6ef710b
Update on "[WIP] Value range refinement using multi-variate expressio…
ysiraichi Mar 31, 2023
9e8e614
Rebased. on "[WIP] Value range refinement using multi-variate express…
ysiraichi Mar 31, 2023
0fee70a
Update on "[WIP] Value range refinement using multi-variate expressio…
ysiraichi May 30, 2023
0f98cc6
Fix + using sympy_interp. on "[WIP] Value range refinement using mult…
ysiraichi Jun 2, 2023
3069ed8
Temporarily turing translation validation on by default. on "[WIP] Va…
ysiraichi Jun 2, 2023
adb40b0
Fix CI issues. on "[WIP] Value range refinement using multi-variate e…
ysiraichi Jun 7, 2023
a5d1678
Simplify FX graph creation + CI fixes. on "[WIP] Value range refineme…
ysiraichi Jun 8, 2023
5f58d55
Fix lint errors. on "[WIP] Value range refinement using multi-variate…
ysiraichi Jun 8, 2023
f1ffd68
Add comment. on "[WIP] Value range refinement using multi-variate exp…
ysiraichi Jun 8, 2023
04a500f
Fix CI issues. on "[WIP] Value range refinement using multi-variate e…
ysiraichi Jun 8, 2023
5ed8d56
Improve python types. on "[WIP] Value range refinement using multi-va…
ysiraichi Jun 8, 2023
57b4b11
Fix lint + checking for guard suppression. on "[WIP] Value range refi…
ysiraichi Jun 9, 2023
fce21ac
Rebased. on "[WIP] Value range refinement using multi-variate express…
ysiraichi Jun 9, 2023
d1e5137
Add tests. on "[WIP] Value range refinement using multi-variate expre…
ysiraichi Jun 9, 2023
955365e
Add warning logs temporarily. on "[WIP] Value range refinement using …
ysiraichi Jun 9, 2023
58e5430
Update on "[WIP] Value range refinement using multi-variate expressio…
ysiraichi Jun 13, 2023
ee81edd
Revert translation validation by default. on "[WIP] Value range refin…
ysiraichi Jun 14, 2023
cef846a
Fix lint issues. on "[WIP] Value range refinement using multi-variate…
ysiraichi Jun 14, 2023
6f83c75
Fix + comments. on "[WIP] Value range refinement using multi-variate …
ysiraichi Jun 14, 2023
0e91059
Move Z3 CI dependency to first PR. on "[WIP] Value range refinement u…
ysiraichi Jun 16, 2023
f4b8a95
Extend `ValueRangeAnalysis` for Python semantics. on "[WIP] Value ran…
ysiraichi Jun 16, 2023
d72ed8b
Move validator into its own file. on "[WIP] Value range refinement us…
ysiraichi Jun 19, 2023
355e08b
Fix lint issues. on "[WIP] Value range refinement using multi-variate…
ysiraichi Jun 19, 2023
12a0561
Rebased. on "[WIP] Value range refinement using multi-variate express…
ysiraichi Jun 19, 2023
965bdda
Add test for specialization. on "[WIP] Value range refinement using m…
ysiraichi Jun 19, 2023
173064c
Move tests to #103611. on "[WIP] Value range refinement using multi-v…
ysiraichi Jun 20, 2023
3af2467
Handle unknowns correctly. on "[WIP] Value range refinement using mul…
ysiraichi Jun 20, 2023
94ae6a1
Fix success flag on z3.unknown. on "[WIP] Value range refinement usin…
ysiraichi Jun 21, 2023
c1503c3
Skipping high-latency benchmarks. on "Value range refinement using mu…
ysiraichi Jun 24, 2023
5fa2e88
Skipping high-latency TIMM benchmarks. on "Value range refinement usi…
ysiraichi Jun 24, 2023
de98d10
Disable TV for tests that timeout. on "Value range refinement using m…
ysiraichi Jun 26, 2023
b8e8aa1
Fix lint issues. on "Value range refinement using multi-variate expre…
ysiraichi Jun 26, 2023
68b1377
Fix lint issues. on "Value range refinement using multi-variate expre…
ysiraichi Jun 26, 2023
82684ed
Rebased. on "Value range refinement using multi-variate expressions."
ysiraichi Jun 27, 2023
1b6258d
Small fixes. on "Value range refinement using multi-variate expressio…
ysiraichi Jun 27, 2023
583c817
Kill temporary debugging code. on "Value range refinement using multi…
ysiraichi Jun 28, 2023
d98aa5e
Add Z3 to Windows CI dependency. on "Value range refinement using mul…
ysiraichi Jun 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,28 @@ def f(a):
tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(15))
self.assertExpectedInline(show_guards(tensor), """L['a'].size()[0] < 20""")

def test_guard_upperbound_range_refinement_multivariate(self):
def f(a):
assert a.shape[0] > 5 and a.shape[0] > 12
assert a.shape[1] > 5 and a.shape[1] > a.shape[0]
return a.cos()
tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 20)))
self.assertExpectedInline(show_guards(tensor), """\
L['a'].size()[1] > L['a'].size()[0]
L['a'].size()[0] > 12""")

def test_guard_lowerbound_range_refinement_multivariate(self):
def f(a):
assert a.shape[0] < 20 and a.shape[0] < 30
assert a.shape[1] < 30 and a.shape[1] < a.shape[0]
return a.cos()
tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 5)))
self.assertExpectedInline(
show_guards(tensor),
"""\
L['a'].size()[1] < L['a'].size()[0]
L['a'].size()[0] < 20""")

def test_sym_storage_offset(self):
def f(x, y):
return x + y
Expand Down
8 changes: 2 additions & 6 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
)
from torch._guards import ShapeGuard, Source, TracingContext, detect_fake_mode
from torch.utils._sympy.interp import sympy_interp
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges, ValueRangeError
from torch.utils._sympy.value_ranges import PythonValueRangeAnalysis, ValueRangeAnalysis, ValueRanges, ValueRangeError
from torch.utils._traceback import format_frame
from torch._utils_internal import signpost_event

Expand Down Expand Up @@ -3434,18 +3434,14 @@ def simplify_until(expr: sympy.Expr, max_iterations: int = 10) -> sympy.Expr:
):
continue

# Use only univariate functions.
if len(expr.rhs.free_symbols) > 0:
continue

# Update the value range of the left-hand side, if the
# right-hand side provides a better range.
symbol = expr.lhs

vr = self.var_to_range[symbol]
lower, upper = vr.lower, vr.upper

rhs_vr = sympy_interp(ValueRangeAnalysis, self.var_to_range, expr.rhs) # type: ignore[arg-type]
rhs_vr = sympy_interp(PythonValueRangeAnalysis, self.var_to_range, expr.rhs) # type: ignore[arg-type]
lower_guard, upper_guard = self.var_to_guards.get(symbol, (None, None))

# Let's suppose that we have a preexisting range for x [0, 100].
Expand Down
12 changes: 12 additions & 0 deletions torch/utils/_sympy/value_ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,3 +469,15 @@ def floor_ceil(x, fn):
def __getattr__(self, name):
log.warning("unhandled ValueRange op %s", name)
return self.default_handler


# Implements Python semantics for 'ValueRangeAnalysis'.
# Reasoning about guards relies on Python operator semantics.
class PythonValueRangeAnalysis(ValueRangeAnalysis):
def __init__(self):
super().__init__()
self.name = "PythonValueRangeAnalysis"

@staticmethod
def div(a, b):
return PythonValueRangeAnalysis.floordiv(a, b)