Skip to content

Commit

Permalink
Introduce int_oo
Browse files Browse the repository at this point in the history
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 58a04544490be9a22abfec273277aebca9068e8d
Pull Request resolved: #127693
  • Loading branch information
ezyang committed Jun 3, 2024
1 parent b5c7138 commit 50bd608
Show file tree
Hide file tree
Showing 6 changed files with 587 additions and 48 deletions.
70 changes: 70 additions & 0 deletions test/test_sympy_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Owner(s): ["oncall: pt2"]

import itertools
import math
import sys

import sympy
Expand All @@ -19,6 +20,7 @@
from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis
from torch.utils._sympy.interp import sympy_interp
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity
from sympy.core.relational import is_ge, is_le, is_gt, is_lt
import functools
import torch.fx as fx
Expand Down Expand Up @@ -122,6 +124,74 @@ def generate_range(vals):
yield ValueRanges(a1, a2)


class TestNumbers(TestCase):
def test_int_infinity(self):
self.assertIsInstance(int_oo, IntInfinity)
self.assertIsInstance(-int_oo, NegativeIntInfinity)
self.assertTrue(int_oo.is_integer)
# is tests here are for singleton-ness, don't use it for comparisons
# against numbers
self.assertIs(int_oo + int_oo, int_oo)
self.assertIs(int_oo + 1, int_oo)
self.assertIs(int_oo - 1, int_oo)
self.assertIs(-int_oo - 1, -int_oo)
self.assertIs(-int_oo + 1, -int_oo)
self.assertIs(-int_oo + (-int_oo), -int_oo)
self.assertIs(-int_oo - int_oo, -int_oo)
self.assertIs(1 + int_oo, int_oo)
self.assertIs(1 - int_oo, -int_oo)
self.assertIs(int_oo * int_oo, int_oo)
self.assertIs(2 * int_oo, int_oo)
self.assertIs(int_oo * 2, int_oo)
self.assertIs(-1 * int_oo, -int_oo)
self.assertIs(-int_oo * int_oo, -int_oo)
self.assertIs(2 * -int_oo, -int_oo)
self.assertIs(-int_oo * 2, -int_oo)
self.assertIs(-1 * -int_oo, int_oo)
self.assertIs(int_oo / 2, sympy.oo)
self.assertIs(-(-int_oo), int_oo) # noqa: B002
self.assertIs(abs(int_oo), int_oo)
self.assertIs(abs(-int_oo), int_oo)
self.assertIs(int_oo ** 2, int_oo)
self.assertIs((-int_oo) ** 2, int_oo)
self.assertIs((-int_oo) ** 3, -int_oo)
self.assertEqual(int_oo ** -1, 0)
self.assertEqual((-int_oo) ** -1, 0)
self.assertIs(int_oo ** int_oo, int_oo)
self.assertTrue(int_oo == int_oo)
self.assertFalse(int_oo != int_oo)
self.assertTrue(-int_oo == -int_oo)
self.assertFalse(int_oo == 2)
self.assertTrue(int_oo != 2)
self.assertFalse(int_oo == sys.maxsize)
self.assertTrue(int_oo >= sys.maxsize)
self.assertTrue(int_oo >= 2)
self.assertTrue(int_oo >= -int_oo)

def test_relation(self):
self.assertIs(sympy.Add(2, int_oo), int_oo)
self.assertFalse(-int_oo > 2)

def test_lt_self(self):
self.assertFalse(int_oo < int_oo)
self.assertIs(min(-int_oo, -4), -int_oo)
self.assertIs(min(-int_oo, -int_oo), -int_oo)

def test_float_cast(self):
self.assertEqual(float(int_oo), math.inf)
self.assertEqual(float(-int_oo), -math.inf)

def test_mixed_oo_int_oo(self):
# Arbitrary choice
self.assertTrue(int_oo < sympy.oo)
self.assertFalse(int_oo > sympy.oo)
self.assertTrue(sympy.oo > int_oo)
self.assertFalse(sympy.oo < int_oo)
self.assertIs(max(int_oo, sympy.oo), sympy.oo)
self.assertTrue(-int_oo > -sympy.oo)
self.assertIs(min(-int_oo, -sympy.oo), -sympy.oo)


class TestValueRanges(TestCase):
@parametrize("fn", UNARY_OPS)
@parametrize("dtype", ("int", "float"))
Expand Down
25 changes: 14 additions & 11 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils._sympy.functions import FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._traceback import format_frame, CapturedTraceback
Expand Down Expand Up @@ -869,9 +870,9 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
for N=1.
"""
if min is None:
min = -sys.maxsize - 1
min = -int_oo
if max is None:
max = sys.maxsize - 1
max = int_oo

if max < min:
raise ValueError(
Expand Down Expand Up @@ -1977,7 +1978,7 @@ def _check_same_range(c, dim):
self._is_dim(dim)
and ("min" in c or "max" in c)
and (dim.min < 2 or dim.min == c.get("min", 2)) # let pass if min < 2
and dim.max == c.get("max", sys.maxsize - 1)
and dim.max == c.get("max", int_oo)
)

# 1) newly introduced roots
Expand All @@ -2000,7 +2001,7 @@ def _check_same_range(c, dim):
modulus, remainder = sympy.polys.polytools.div(c["eq"], root)
c_min = c.get("min", 2)
min_ = math.ceil((c_min - remainder) / modulus)
c_max = c.get("max", sys.maxsize - 1)
c_max = c.get("max", int_oo)
max_ = math.floor((c_max - remainder) / modulus)
# create result & dim
results[str(root)] = {"min": min_, "max": max_}
Expand Down Expand Up @@ -2746,7 +2747,7 @@ def _constrain_range_for_size(self, a: sympy.Symbol, min: Optional[int] = None,
if min is None:
min = 0
if max is None:
max = sys.maxsize - 1
max = int_oo

if max < min:
raise ValueError(
Expand Down Expand Up @@ -4088,7 +4089,8 @@ def issue_guard(guard: ShapeGuard) -> None:
# Note that you can be off by a pretty large constant and it
# won't matter because sizes in practice will be no where near
# the 64-bit limit.
if r.upper != sympy.oo and r.upper < sys.maxsize - 1:
# TODO: update this
if r.upper != sympy.oo and r.upper is not int_oo:
if any(is_dim(source) for source in sources):
self.dim_constraints.add(sympy.Le(symbol, r.upper))
# nontrivial upper bound is always interesting
Expand Down Expand Up @@ -4286,7 +4288,7 @@ def bound_sympy(self, expr: sympy.Expr, size_oblivious: bool = False) -> ValueRa
# Clamp values of size-like variables
for x in self.size_like & var_to_range.keys():
if var_to_range[x] is not None:
var_to_range[x] = ValueRanges(2, sys.maxsize - 1)
var_to_range[x] = ValueRanges(2, int_oo)
assert var_to_range[x].is_int
return bound_sympy(expr, var_to_range)

Expand Down Expand Up @@ -4411,7 +4413,7 @@ def _maybe_evaluate_static(
# Also don't do anything if we asked only to simplify unbacked
# SymInt
if (
lower < (-sys.maxsize - 1) // 2 or
lower is -int_oo or
(unbacked_only and k in self.var_to_val) or
not vr.is_int
):
Expand Down Expand Up @@ -4665,7 +4667,7 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No
# to do this truncation automaticaly (to avoid doing
# bigint compute in range analysis), but right now it doesn't
# so we need to get rid of some unnecessary precision.
int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1)
int_range = ValueRanges(-int_oo, int_oo)

def issubset(x, y):
if x.is_int and y.is_int:
Expand Down Expand Up @@ -4822,6 +4824,7 @@ def _smart_symbol_sort(x):
has_only_ephemeral_sources = (
x in self.var_to_sources and all(s.is_ephemeral() for s in self.var_to_sources[x])
)
# NB: size_hint is int, not sympy.Expr, do not use int_oo here
size = self.size_hint(x, allow_none=True) or sys.maxsize
name = x.name
# 1 puts ephemeral sourced symbols first when sorting in reverse
Expand Down Expand Up @@ -4923,10 +4926,10 @@ def trivial_solve(lhs, rhs):
# anywhere near the max 64-bit integer anyway.
def _default_value_range(self) -> ValueRanges:
lower = 2 if self.specialize_zero_one else 0
return ValueRanges(lower, sys.maxsize - 1)
return ValueRanges(lower, int_oo)

def _default_unspecified_value_range(self) -> ValueRanges:
return ValueRanges(-sys.maxsize - 1, sys.maxsize)
return ValueRanges(-int_oo, int_oo)

@_lru_cache
def _simplify_floor_div(self, expr):
Expand Down
78 changes: 64 additions & 14 deletions torch/utils/_sympy/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import sympy
from sympy import S

from .numbers import int_oo

__all__ = [
"FloorDiv",
"ModularIndexing",
Expand Down Expand Up @@ -107,6 +109,20 @@ def eval(cls, base, divisor):
return base
if base.is_integer and divisor == -1:
return sympy.Mul(base, -1)
if (
isinstance(base, sympy.Number)
and isinstance(divisor, sympy.Number)
and (base in (int_oo, -int_oo) or divisor in (int_oo, -int_oo))
):
r = float(base) / float(divisor)
if r == math.inf:
return int_oo
elif r == -math.inf:
return -int_oo
elif math.isnan(r):
return sympy.nan
else:
return sympy.Integer(math.floor(r))
if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
return sympy.Integer(int(base) // int(divisor))
if isinstance(base, FloorDiv):
Expand Down Expand Up @@ -399,6 +415,7 @@ def safe_pow(base, exp):
return sign * _safe_pow(base, exp)


# Prevent people from overflowing pow
def _safe_pow(base, exponent):
if exponent < 0:
raise ValueError("Exponent must be non-negative.")
Expand All @@ -407,17 +424,20 @@ def _safe_pow(base, exponent):
return 1

half_exp = safe_pow(base, exponent // 2)
if half_exp > sys.maxsize - 1:
return sys.maxsize - 1
if half_exp is int_oo:
return int_oo

# TODO: microoptimization is to avoid overflowing into arbitrary precision
# and detect overflow prior to doing operations

result = half_exp * half_exp
if result > sys.maxsize - 1:
return sys.maxsize - 1
if result > sys.maxsize:
return int_oo

if exponent % 2 == 1:
result *= base
if result > sys.maxsize - 1:
return sys.maxsize - 1
if result > sys.maxsize:
return int_oo

return result

Expand All @@ -431,14 +451,23 @@ def eval(cls, base, exp):
# have concluded this externally from Sympy assumptions, so we can't
# assert the nonnegative
assert exp.is_integer, exp
if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number):
return sympy.Integer(safe_pow(base, exp))
if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer):
r = safe_pow(base, exp)
if r in (-int_oo, int_oo):
return r
return sympy.Integer(r)
if isinstance(exp, sympy.Integer):
# Translate power into iterated multiplication
# (Rely on this for base is int_oo case too.)
r = sympy.Integer(1)
for _ in range(int(exp)):
r *= base
return r
if exp is int_oo:
if base.is_nonnegative:
return int_oo
elif base.is_negative:
return sympy.zoo # this is apparently what (-2)**sympy.oo does
# NB: do NOT translate into sympy.Pow, we will lose knowledge that exp
# is a natural number if we do

Expand All @@ -451,6 +480,11 @@ class FloatPow(sympy.Function):

@classmethod
def eval(cls, base, exp):
# NB: These test sympy.Number, not sympy.Float, because:
# - Sometimes we may have sympy.oo or int_oo, and that's not a Float
# (but coerces to math.Inf)
# - Sometimes Float(0.0) will unpredictably decay to Integer(0),
# but we should still accept it in floatey contexts
if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number):
return sympy.Float(float(base) ** float(exp))
# NB: do not do any nontrivial reasoning
Expand Down Expand Up @@ -497,7 +531,15 @@ def eval(cls, base, divisor):
if divisor.is_zero:
raise ZeroDivisionError("division by zero")

if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number):
if (
isinstance(base, sympy.Number)
and isinstance(divisor, sympy.Number)
and (base in (int_oo, -int_oo) or divisor in (int_oo, -int_oo))
):
# Don't have to worry about precision here, you're getting zero or
# inf from the division
return sympy.Float(float(base) / float(divisor))
if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
return sympy.Float(int(base) / int(divisor))


Expand Down Expand Up @@ -554,10 +596,10 @@ class TruncToInt(sympy.Function):
@classmethod
def eval(cls, number):
# assert number.is_integer is not True, number
if number == sympy.oo:
return sympy.Integer(sys.maxsize - 1)
if number == -sympy.oo:
return sympy.Integer(-sys.maxsize - 1)
if number in (sympy.oo, int_oo):
return int_oo
if number in (-sympy.oo, -int_oo):
return -int_oo
if isinstance(number, sympy.Number):
return sympy.Integer(math.trunc(float(number)))

Expand Down Expand Up @@ -614,6 +656,10 @@ def eval(cls, number):

if isinstance(number, sympy.Integer):
return sympy.Float(int(number))
if number is int_oo:
return sympy.oo
if number is -int_oo:
return -sympy.oo


def make_opaque_unary_fn(name):
Expand Down Expand Up @@ -644,7 +690,11 @@ def eval(cls, a):
# weird objects but ask silly questions, get silly answers
except OverflowError:
return getattr(sympy, name)(a)
elif a in [sympy.oo, -sympy.oo, sympy.zoo, -sympy.zoo]:
elif a in [sympy.oo, -sympy.oo, sympy.zoo, -sympy.zoo, int_oo, -int_oo]:
if a is int_oo:
a = sympy.oo
if a is -int_oo:
a = -sympy.oo
return getattr(sympy, name)(a)
return None

Expand Down
24 changes: 16 additions & 8 deletions torch/utils/_sympy/interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import functools
import logging
from typing import Any, Dict, Union

import sympy
Expand All @@ -34,6 +35,9 @@
)


log = logging.getLogger(__name__)


# TODO: Dedupe this with SYMPY_INTERP


Expand Down Expand Up @@ -150,11 +154,15 @@ def sympy_interp(
else:
handler_name = handlers()[expr.func]
handler = getattr(analysis, handler_name)
if handler_name in ASSOCIATIVE_OPS:
assert len(args) > 1
acc = handler(args[0], args[1])
for i in range(2, len(args)):
acc = handler(acc, args[i])
return acc
else:
return handler(*args)
try:
if handler_name in ASSOCIATIVE_OPS:
assert len(args) > 1
acc = handler(args[0], args[1])
for i in range(2, len(args)):
acc = handler(acc, args[i])
return acc
else:
return handler(*args)
except Exception:
log.warning("failed while executing %s(%s)", handler_name, args)
raise
Loading

0 comments on commit 50bd608

Please sign in to comment.