Introduce int_oo
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: b6d509ef0eddc17f4c1bc9bb10512d1dd08c9898
Pull Request resolved: #127693
ezyang committed Jun 4, 2024
1 parent d5c4214 commit 0152d9b
Showing 10 changed files with 612 additions and 81 deletions.
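At a glance, this commit replaces the sys.maxsize-based sentinels used for symbolic shape bounds with a dedicated integer infinity, int_oo. A minimal sketch of its semantics, distilled from the tests added below (illustrative only):

```python
import sys

from torch.utils._sympy.numbers import int_oo

assert int_oo.is_integer              # unlike sympy.oo, it is an integer
assert int_oo + 1 is int_oo           # arithmetic saturates at infinity
assert -int_oo - 1 is -int_oo
assert int_oo >= sys.maxsize          # compares above any finite int
assert float(int_oo) == float("inf")  # casts to float infinity
```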
4 changes: 2 additions & 2 deletions test/dynamo/test_misc.py
@@ -9458,8 +9458,8 @@ def test_shape_env_equal_evaluate_expr_refinement(self):
> Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
==> var_to_range: values don't match.
> Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)}
> Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)}
> Left: {s0: ValueRanges(lower=3, upper=IntInfinity(), is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=IntInfinity(), is_bool=False, is_int=True, is_float=False)}
> Right: {s0: ValueRanges(lower=2, upper=IntInfinity(), is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=IntInfinity(), is_bool=False, is_int=True, is_float=False)}
""",
)
self._replay_and_check(main)
4 changes: 3 additions & 1 deletion test/test_proxy_tensor.py
@@ -1201,7 +1201,9 @@ def f(src_tokens):
batch_size = 4
src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size))
gm = make_fx(f, tracing_mode="symbolic")(src_tokens)
self.assertEqual(len(gm.shape_env.guards), 0)
# Guards to rule out batch_size == sys.maxsize (wobbling between 2 and
# 1 ok)
self.assertEqual(len(gm.shape_env.guards), 2)

@unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
def test_cpu_scalar_cuda(self):
70 changes: 70 additions & 0 deletions test/test_sympy_utils.py
@@ -1,6 +1,7 @@
# Owner(s): ["oncall: pt2"]

import itertools
import math
import sys

import sympy
@@ -19,6 +20,7 @@
from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis
from torch.utils._sympy.interp import sympy_interp
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity
from sympy.core.relational import is_ge, is_le, is_gt, is_lt
import functools
import torch.fx as fx
@@ -122,6 +124,74 @@ def generate_range(vals):
yield ValueRanges(a1, a2)


class TestNumbers(TestCase):
def test_int_infinity(self):
self.assertIsInstance(int_oo, IntInfinity)
self.assertIsInstance(-int_oo, NegativeIntInfinity)
self.assertTrue(int_oo.is_integer)
# is tests here are for singleton-ness, don't use it for comparisons
# against numbers
self.assertIs(int_oo + int_oo, int_oo)
self.assertIs(int_oo + 1, int_oo)
self.assertIs(int_oo - 1, int_oo)
self.assertIs(-int_oo - 1, -int_oo)
self.assertIs(-int_oo + 1, -int_oo)
self.assertIs(-int_oo + (-int_oo), -int_oo)
self.assertIs(-int_oo - int_oo, -int_oo)
self.assertIs(1 + int_oo, int_oo)
self.assertIs(1 - int_oo, -int_oo)
self.assertIs(int_oo * int_oo, int_oo)
self.assertIs(2 * int_oo, int_oo)
self.assertIs(int_oo * 2, int_oo)
self.assertIs(-1 * int_oo, -int_oo)
self.assertIs(-int_oo * int_oo, -int_oo)
self.assertIs(2 * -int_oo, -int_oo)
self.assertIs(-int_oo * 2, -int_oo)
self.assertIs(-1 * -int_oo, int_oo)
self.assertIs(int_oo / 2, sympy.oo)
self.assertIs(-(-int_oo), int_oo) # noqa: B002
self.assertIs(abs(int_oo), int_oo)
self.assertIs(abs(-int_oo), int_oo)
self.assertIs(int_oo ** 2, int_oo)
self.assertIs((-int_oo) ** 2, int_oo)
self.assertIs((-int_oo) ** 3, -int_oo)
self.assertEqual(int_oo ** -1, 0)
self.assertEqual((-int_oo) ** -1, 0)
self.assertIs(int_oo ** int_oo, int_oo)
self.assertTrue(int_oo == int_oo)
self.assertFalse(int_oo != int_oo)
self.assertTrue(-int_oo == -int_oo)
self.assertFalse(int_oo == 2)
self.assertTrue(int_oo != 2)
self.assertFalse(int_oo == sys.maxsize)
self.assertTrue(int_oo >= sys.maxsize)
self.assertTrue(int_oo >= 2)
self.assertTrue(int_oo >= -int_oo)

def test_relation(self):
self.assertIs(sympy.Add(2, int_oo), int_oo)
self.assertFalse(-int_oo > 2)

def test_lt_self(self):
self.assertFalse(int_oo < int_oo)
self.assertIs(min(-int_oo, -4), -int_oo)
self.assertIs(min(-int_oo, -int_oo), -int_oo)

def test_float_cast(self):
self.assertEqual(float(int_oo), math.inf)
self.assertEqual(float(-int_oo), -math.inf)

def test_mixed_oo_int_oo(self):
# Arbitrary choice
self.assertTrue(int_oo < sympy.oo)
self.assertFalse(int_oo > sympy.oo)
self.assertTrue(sympy.oo > int_oo)
self.assertFalse(sympy.oo < int_oo)
self.assertIs(max(int_oo, sympy.oo), sympy.oo)
self.assertTrue(-int_oo > -sympy.oo)
self.assertIs(min(-int_oo, -sympy.oo), -sympy.oo)


class TestValueRanges(TestCase):
@parametrize("fn", UNARY_OPS)
@parametrize("dtype", ("int", "float"))
@@ -9,6 +9,7 @@
import torch
import torch.fx
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._sympy.numbers import int_oo
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.fx.passes.infra.pass_base import PassBase, PassResult

@@ -22,9 +23,9 @@ class InputDim(NamedTuple):

def _convert_to_int(val):
# Convert simple sympy Integers into concrete int
if val == sympy.oo:
if val in (sympy.oo, int_oo):
return math.inf
if val == -sympy.oo:
if val in (-sympy.oo, -int_oo):
return -math.inf
if isinstance(val, sympy.Integer):
return int(val)
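Put together, the updated helper now reads roughly as follows (a self-contained sketch of the hunk above; the enclosing file's name is not shown in this view, and any branches below the visible diff are elided):

```python
import math

import sympy

from torch.utils._sympy.numbers import int_oo


def _convert_to_int(val):
    # Convert simple sympy Integers into concrete int; either flavor of
    # infinity (sympy.oo or the new int_oo) maps to a float infinity.
    if val in (sympy.oo, int_oo):
        return math.inf
    if val in (-sympy.oo, -int_oo):
        return -math.inf
    if isinstance(val, sympy.Integer):
        return int(val)
    # (handling of any remaining cases is elided in this diff view)
```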
65 changes: 26 additions & 39 deletions torch/fx/experimental/symbolic_shapes.py
@@ -63,6 +63,7 @@
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils._sympy.functions import FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._traceback import format_frame, CapturedTraceback
@@ -869,9 +870,9 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
for N=1.
"""
if min is None:
min = -sys.maxsize - 1
min = -int_oo
if max is None:
max = sys.maxsize - 1
max = int_oo

if max < min:
raise ValueError(
@@ -1380,6 +1381,7 @@ def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt:
'PythonMod': operator.mod,
'FloorDiv': operator.floordiv,
'TrueDiv': operator.truediv,
'PowByNatural': operator.pow,
'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense,
'floor': math.floor,
'ceiling': math.ceil,
@@ -1987,7 +1989,7 @@ def _check_same_range(c, dim):
self._is_dim(dim)
and ("min" in c or "max" in c)
and (dim.min < 2 or dim.min == c.get("min", 2)) # let pass if min < 2
and dim.max == c.get("max", sys.maxsize - 1)
and dim.max == c.get("max", int_oo)
)

# 1) newly introduced roots
@@ -2010,7 +2012,7 @@ def _check_same_range(c, dim):
modulus, remainder = sympy.polys.polytools.div(c["eq"], root)
c_min = c.get("min", 2)
min_ = math.ceil((c_min - remainder) / modulus)
c_max = c.get("max", sys.maxsize - 1)
c_max = c.get("max", int_oo)
max_ = math.floor((c_max - remainder) / modulus)
# create result & dim
results[str(root)] = {"min": min_, "max": max_}
@@ -2766,7 +2768,7 @@ def _constrain_range_for_size(self, a: sympy.Symbol, min: Optional[int] = None,
if min is None:
min = 0
if max is None:
max = sys.maxsize - 1
max = int_oo

if max < min:
raise ValueError(
@@ -4094,22 +4096,15 @@ def issue_guard(guard: ShapeGuard) -> None:

assert sources
bounds = []
if r.lower != -sympy.oo:
if r.lower not in (-sympy.oo, -int_oo):
if any(is_dim(source) for source in sources):
self.dim_constraints.add(sympy.Ge(symbol, r.lower))
# Only print lower bound in simplified mode if it is not the
# default
if not _simplified or r.lower != self._default_value_range().lower:
bounds.append(str(r.lower))
bounds.append(source_ref(sources[0]))
# NB: This looks like an off-by-one error but it's not: the
# upper bound may be sys.maxsize - 1 because we intentionally
# exclude sys.maxsize from our bounds to deal with direct
# == INT_MAX guards, but it's still dumb to actually test it.
# Note that you can be off by a pretty large constant and it
# won't matter because sizes in practice will be no where near
# the 64-bit limit.
if r.upper != sympy.oo and r.upper < sys.maxsize - 1:
if r.upper not in (sympy.oo, int_oo):
if any(is_dim(source) for source in sources):
self.dim_constraints.add(sympy.Le(symbol, r.upper))
# nontrivial upper bound is always interesting
@@ -4121,9 +4116,8 @@ def issue_guard(guard: ShapeGuard) -> None:
constraints = symbol_to_constraints[symbol]
for c in constraints:
if isinstance(c, StrictMinMaxConstraint):
# NB: By default, we have a restrictive range
# 2 <= s0 <= sys.maxsize - 1. But export users generally
# expect to be able to specify nice ranges like [0, oo]
# TODO: With int_oo, I think this condition is a noop
# now
if not (c.vr & self._default_value_range()).issubset(r):
source = sources[0]

@@ -4196,9 +4190,9 @@ def issue_guard(guard: ShapeGuard) -> None:
# Reason: '_maybe_evaluate_static' may eliminate guards based on the
# refined value ranges.
for sym, vr in self.var_to_range.items():
if vr.lower != -sympy.oo:
if vr.lower not in (-sympy.oo, -int_oo):
self._add_target_expr(sympy.Le(vr.lower, sym))
if vr.upper != sympy.oo:
if vr.upper not in (sympy.oo, int_oo):
self._add_target_expr(sympy.Le(sym, vr.upper))

# Before validating, populate the input of the validator with the
@@ -4305,9 +4299,16 @@ def bound_sympy(self, expr: sympy.Expr, size_oblivious: bool = False) -> ValueRanges:
var_to_range = {x: self.var_to_range.get(x, None) for x in expr.free_symbols}
if size_oblivious:
# Clamp values of size-like variables
# NB: discarding the old upper bound is intentional, per
# https://github.com/pytorch/pytorch/pull/123675
for x in self.size_like & var_to_range.keys():
if var_to_range[x] is not None:
var_to_range[x] = ValueRanges(2, sys.maxsize - 1)
# TODO: Maybe we should preserve the lower bound here?
# NB: Upper bound the range with some large amount
# (281 TB here was arbitrarily chosen to match addressable
# virtual space on x86_64) so that comparisons against
# sys.maxsize can no-op
var_to_range[x] = ValueRanges(2, 2 ** 48)
assert var_to_range[x].is_int
return bound_sympy(expr, var_to_range)
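A small illustration of the effect of this clamp (a sketch that assumes bound_sympy propagates affine bounds in the usual way; the symbol name is arbitrary):

```python
import sys

import sympy

from torch.utils._sympy.value_ranges import ValueRanges, bound_sympy

s0 = sympy.Symbol("s0", positive=True, integer=True)
# Size-oblivious clamp chosen in this hunk: [2, 2**48]
ranges = {s0: ValueRanges(2, 2 ** 48)}
# The resulting bound of s0 - sys.maxsize is entirely negative, so a guard
# like s0 == sys.maxsize can be statically resolved to False instead of
# being retained as a runtime check.
print(bound_sympy(s0 - sys.maxsize, ranges))
```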

@@ -4436,7 +4437,7 @@ def _maybe_evaluate_static(
# Also don't do anything if we asked only to simplify unbacked
# SymInt
if (
lower < (-sys.maxsize - 1) // 2 or
lower is -int_oo or
(unbacked_only and k in self.var_to_val) or
not vr.is_int
):
@@ -4686,20 +4687,8 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> None:
if a in self.var_to_range:
src_bound = self.var_to_range[a]

# If you have x in [2, maxint], then 2*x in [4, 2*maxint].
# But we don't really care that the max bound says we can
# go beyond the maximum integer size, because we aren't
# using bigints anyway. Arguably, ValueRanges should know
# to do this truncation automaticaly (to avoid doing
# bigint compute in range analysis), but right now it doesn't
# so we need to get rid of some unnecessary precision.
int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1)

def issubset(x, y):
if x.is_int and y.is_int:
return (x & int_range).issubset(y & int_range)
else:
return x.issubset(y)
return x.issubset(y)

# First, refine the value range of a based on the computed value range
# of tgt. This is always OK to do, even if we decide not to do the
@@ -4850,6 +4839,7 @@ def _smart_symbol_sort(x):
has_only_ephemeral_sources = (
x in self.var_to_sources and all(s.is_ephemeral() for s in self.var_to_sources[x])
)
# NB: size_hint is int, not sympy.Expr, do not use int_oo here
size = self.size_hint(x, allow_none=True) or sys.maxsize
name = x.name
# 1 puts ephemeral sourced symbols first when sorting in reverse
@@ -4946,15 +4936,12 @@ def trivial_solve(lhs, rhs):
return

# See: Note - On 0/1 specialization
# NB: sys.maxsize is NOT allowed for sizes, because we use MAX_INT
# as a sentinel sometimes. Your sizevar isn't going to be
# anywhere near the max 64-bit integer anyway.
def _default_value_range(self) -> ValueRanges:
lower = 2 if self.specialize_zero_one else 0
return ValueRanges(lower, sys.maxsize - 1)
return ValueRanges(lower, int_oo)

def _default_unspecified_value_range(self) -> ValueRanges:
return ValueRanges(-sys.maxsize - 1, sys.maxsize)
return ValueRanges(-int_oo, int_oo)

@_lru_cache
def _simplify_floor_div(self, expr):
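With these changes the default ranges are no longer capped at sys.maxsize - 1. A sketch of the new defaults, mirroring _default_value_range and _default_unspecified_value_range above:

```python
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.value_ranges import ValueRanges

# Size-like symbols under 0/1 specialization: [2, int_oo]
default_size_range = ValueRanges(2, int_oo)
# Unspecified (plain backed) integers: [-int_oo, int_oo]
default_unspecified_range = ValueRanges(-int_oo, int_oo)
```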
3 changes: 3 additions & 0 deletions torch/fx/passes/runtime_assert.py
@@ -15,6 +15,7 @@
from torch.fx._utils import lazy_format_graph_code
from torch.fx.experimental.sym_node import SymNode
from torch.fx.graph_module import GraphModule
from torch.utils._sympy.numbers import int_oo

log = logging.getLogger(__name__)
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
@@ -366,6 +367,8 @@ def go(node, keypath):
# (refinement should not be necessary once runtime
# asserts cause refinement, but that's NYI)
def convert(s):
if s in (int_oo, -int_oo):
return None
try:
return int(s)
except TypeError:
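For context, the updated conversion helper reads roughly as follows (a sketch; the body below the visible diff is assumed to fall back to None, i.e. "no refinement"):

```python
from torch.utils._sympy.numbers import int_oo


def convert(s):
    # Infinite bounds carry no refinement information, so treat them as
    # "no bound" rather than trying to build a concrete int out of them.
    if s in (int_oo, -int_oo):
        return None
    try:
        return int(s)
    except TypeError:
        return None  # assumed: non-integral expressions are skipped
```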