Skip to content

Commit

Permalink
Introduce int_oo
Browse files Browse the repository at this point in the history
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 50bb181cf53418b172ff45389036d2cd55086880
Pull Request resolved: #127693
  • Loading branch information
ezyang committed Jun 3, 2024
1 parent 4c981de commit 50514ce
Show file tree
Hide file tree
Showing 5 changed files with 370 additions and 21 deletions.
46 changes: 46 additions & 0 deletions test/test_sympy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis
from torch.utils._sympy.interp import sympy_interp
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity
from sympy.core.relational import is_ge, is_le, is_gt, is_lt
import functools
import torch.fx as fx
Expand Down Expand Up @@ -122,6 +123,51 @@ def generate_range(vals):
yield ValueRanges(a1, a2)


class TestNumbers(TestCase):
def test_int_infinity(self):
self.assertIsInstance(int_oo, IntInfinity)
self.assertIsInstance(-int_oo, NegativeIntInfinity)
self.assertTrue(int_oo.is_integer)
# is tests here are for singleton-ness, don't use it for comparisons
# against numbers
self.assertIs(int_oo + int_oo, int_oo)
self.assertIs(int_oo + 1, int_oo)
self.assertIs(int_oo - 1, int_oo)
self.assertIs(-int_oo - 1, -int_oo)
self.assertIs(-int_oo + 1, -int_oo)
self.assertIs(-int_oo + (-int_oo), -int_oo)
self.assertIs(-int_oo - int_oo, -int_oo)
self.assertIs(1 + int_oo, int_oo)
self.assertIs(1 - int_oo, -int_oo)
self.assertIs(int_oo * int_oo, int_oo)
self.assertIs(2 * int_oo, int_oo)
self.assertIs(int_oo * 2, int_oo)
self.assertIs(-1 * int_oo, -int_oo)
self.assertIs(-int_oo * int_oo, -int_oo)
self.assertIs(2 * -int_oo, -int_oo)
self.assertIs(-int_oo * 2, -int_oo)
self.assertIs(-1 * -int_oo, int_oo)
self.assertIs(int_oo / 2, sympy.oo)
self.assertIs(-(-int_oo), int_oo)
self.assertIs(abs(int_oo), int_oo)
self.assertIs(abs(-int_oo), int_oo)
self.assertIs(int_oo ** 2, int_oo)
self.assertIs((-int_oo) ** 2, int_oo)
self.assertIs((-int_oo) ** 3, -int_oo)
self.assertEqual(int_oo ** -1, 0)
self.assertEqual((-int_oo) ** -1, 0)
self.assertIs(int_oo ** int_oo, int_oo)
self.assertTrue(int_oo == int_oo)
self.assertFalse(int_oo != int_oo)
self.assertTrue(-int_oo == -int_oo)
self.assertFalse(int_oo == 2)
self.assertTrue(int_oo != 2)
self.assertFalse(int_oo == sys.maxsize)
self.assertTrue(int_oo >= sys.maxsize)
self.assertTrue(int_oo >= 2)
self.assertTrue(int_oo >= -int_oo)


class TestValueRanges(TestCase):
@parametrize("fn", UNARY_OPS)
@parametrize("dtype", ("int", "float"))
Expand Down
25 changes: 14 additions & 11 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils._sympy.functions import FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._traceback import format_frame, CapturedTraceback
Expand Down Expand Up @@ -869,9 +870,9 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
for N=1.
"""
if min is None:
min = -sys.maxsize - 1
min = -int_oo
if max is None:
max = sys.maxsize - 1
max = int_oo

if max < min:
raise ValueError(
Expand Down Expand Up @@ -1977,7 +1978,7 @@ def _check_same_range(c, dim):
self._is_dim(dim)
and ("min" in c or "max" in c)
and (dim.min < 2 or dim.min == c.get("min", 2)) # let pass if min < 2
and dim.max == c.get("max", sys.maxsize - 1)
and dim.max == c.get("max", int_oo)
)

# 1) newly introduced roots
Expand All @@ -2000,7 +2001,7 @@ def _check_same_range(c, dim):
modulus, remainder = sympy.polys.polytools.div(c["eq"], root)
c_min = c.get("min", 2)
min_ = math.ceil((c_min - remainder) / modulus)
c_max = c.get("max", sys.maxsize - 1)
c_max = c.get("max", int_oo)
max_ = math.floor((c_max - remainder) / modulus)
# create result & dim
results[str(root)] = {"min": min_, "max": max_}
Expand Down Expand Up @@ -2746,7 +2747,7 @@ def _constrain_range_for_size(self, a: sympy.Symbol, min: Optional[int] = None,
if min is None:
min = 0
if max is None:
max = sys.maxsize - 1
max = int_oo

if max < min:
raise ValueError(
Expand Down Expand Up @@ -4088,7 +4089,8 @@ def issue_guard(guard: ShapeGuard) -> None:
# Note that you can be off by a pretty large constant and it
# won't matter because sizes in practice will be no where near
# the 64-bit limit.
if r.upper != sympy.oo and r.upper < sys.maxsize - 1:
# TODO: update this
if r.upper != sympy.oo and r.upper is not int_oo:
if any(is_dim(source) for source in sources):
self.dim_constraints.add(sympy.Le(symbol, r.upper))
# nontrivial upper bound is always interesting
Expand Down Expand Up @@ -4286,7 +4288,7 @@ def bound_sympy(self, expr: sympy.Expr, size_oblivious: bool = False) -> ValueRa
# Clamp values of size-like variables
for x in self.size_like & var_to_range.keys():
if var_to_range[x] is not None:
var_to_range[x] = ValueRanges(2, sys.maxsize - 1)
var_to_range[x] = ValueRanges(2, int_oo)
assert var_to_range[x].is_int
return bound_sympy(expr, var_to_range)

Expand Down Expand Up @@ -4411,7 +4413,7 @@ def _maybe_evaluate_static(
# Also don't do anything if we asked only to simplify unbacked
# SymInt
if (
lower < (-sys.maxsize - 1) // 2 or
lower is -int_oo or
(unbacked_only and k in self.var_to_val) or
not vr.is_int
):
Expand Down Expand Up @@ -4665,7 +4667,7 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No
# to do this truncation automaticaly (to avoid doing
# bigint compute in range analysis), but right now it doesn't
# so we need to get rid of some unnecessary precision.
int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1)
int_range = ValueRanges(-int_oo, int_oo)

def issubset(x, y):
if x.is_int and y.is_int:
Expand Down Expand Up @@ -4822,6 +4824,7 @@ def _smart_symbol_sort(x):
has_only_ephemeral_sources = (
x in self.var_to_sources and all(s.is_ephemeral() for s in self.var_to_sources[x])
)
# NB: size_hint is int, not sympy.Expr, do not use int_oo here
size = self.size_hint(x, allow_none=True) or sys.maxsize
name = x.name
# 1 puts ephemeral sourced symbols first when sorting in reverse
Expand Down Expand Up @@ -4923,10 +4926,10 @@ def trivial_solve(lhs, rhs):
# anywhere near the max 64-bit integer anyway.
def _default_value_range(self) -> ValueRanges:
lower = 2 if self.specialize_zero_one else 0
return ValueRanges(lower, sys.maxsize - 1)
return ValueRanges(lower, int_oo)

def _default_unspecified_value_range(self) -> ValueRanges:
return ValueRanges(-sys.maxsize - 1, sys.maxsize)
return ValueRanges(-int_oo, int_oo)

@_lru_cache
def _simplify_floor_div(self, expr):
Expand Down
22 changes: 14 additions & 8 deletions torch/utils/_sympy/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import sympy
from sympy import S

from .numbers import int_oo

__all__ = [
"FloorDiv",
"ModularIndexing",
Expand Down Expand Up @@ -399,6 +401,7 @@ def safe_pow(base, exp):
return sign * _safe_pow(base, exp)


# Prevent people from overflowing pow
def _safe_pow(base, exponent):
if exponent < 0:
raise ValueError("Exponent must be non-negative.")
Expand All @@ -407,17 +410,20 @@ def _safe_pow(base, exponent):
return 1

half_exp = safe_pow(base, exponent // 2)
if half_exp > sys.maxsize - 1:
return sys.maxsize - 1
if half_exp is int_oo:
return int_oo

# TODO: microoptimization is to avoid overflowing into arbitrary precision
# and detect overflow prior to doing operations

result = half_exp * half_exp
if result > sys.maxsize - 1:
return sys.maxsize - 1
if result > sys.maxsize:
return int_oo

if exponent % 2 == 1:
result *= base
if result > sys.maxsize - 1:
return sys.maxsize - 1
if result > sys.maxsize:
return int_oo

return result

Expand Down Expand Up @@ -555,9 +561,9 @@ class TruncToInt(sympy.Function):
def eval(cls, number):
# assert number.is_integer is not True, number
if number == sympy.oo:
return sympy.Integer(sys.maxsize - 1)
return int_oo
if number == -sympy.oo:
return sympy.Integer(-sys.maxsize - 1)
return -int_oo
if isinstance(number, sympy.Number):
return sympy.Integer(math.trunc(float(number)))

Expand Down
Loading

0 comments on commit 50514ce

Please sign in to comment.