Skip to content

Commit

Permalink
Revert "Introduce int_oo (#127693)"
Browse files Browse the repository at this point in the history
This reverts commit 9cab598.

Reverted #127693 on behalf of https://github.com/clee2000 due to sorry executorch CI is a bit weird regarding pins, I'll make a chat with mergen with the choices of what to do and how it'll affect executorch CI, reverting for now to prevent more divergences in the meantime ([comment](#127693 (comment)))
  • Loading branch information
pytorchmergebot committed Jun 11, 2024
1 parent c9c1fed commit 5d8c7f3
Show file tree
Hide file tree
Showing 19 changed files with 145 additions and 746 deletions.
9 changes: 7 additions & 2 deletions test/dynamo/test_exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def fn(x, shape):
==> (>= 0 s1)
==> (>= 0 s2)
==> (>= 0 s3)
==> (>= 9223372036854775806 s0)
Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
Expand Down Expand Up @@ -286,14 +287,14 @@ def fn(x, shape):
Model:
==> L['shape'][0]: 1
==> L['shape'][1]: 1
==> L['shape'][2]: 0
==> L['shape'][2]: 2
==> L['x'].size()[0]: 3
==> L['x'].storage_offset(): 0
==> L['x'].stride()[0]: 1
==> s0: 3
==> s1: 1
==> s2: 1
==> s3: 0
==> s3: 2
Assertions:
==> (== 0 L['x'].storage_offset())
Expand All @@ -317,6 +318,10 @@ def fn(x, shape):
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0)
==> (> s0 0)
==> (>= 9223372036854775806 s0)
==> (>= 9223372036854775807 s1)
==> (>= 9223372036854775807 s2)
==> (>= 9223372036854775807 s3)
Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
Expand Down
1 change: 1 addition & 0 deletions test/dynamo/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -3473,6 +3473,7 @@ def forward(self, pred, x):
]
false_guard_code = [
"Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
"-9223372036854775808 <= cast_symbool_to_symint_guardless(L['pred'])",
]
test_symbool_guards(
f,
Expand Down
12 changes: 6 additions & 6 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9309,7 +9309,7 @@ def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self):
> Left: {0: 0, 1: 1, 2: s1, 3: s0}
> Right: {0: 0, 1: 1}
==> var_to_range: values don't match.
> Left: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
> Left: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]}
> Right: {}
==> var_to_sources: values don't match.
> Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]}
Expand Down Expand Up @@ -9343,7 +9343,7 @@ def test_shape_env_equal_unbacked(self):
> Left: 2
> Right: 0
==> var_to_range: values don't match.
> Left: {u0: VR[-int_oo, int_oo], u1: VR[0, 1], zuf0: VR[-oo, oo]}
> Left: {u0: VR[-9223372036854775808, 9223372036854775807], u1: VR[0, 1], zuf0: VR[-oo, oo]}
> Right: {}
""",
)
Expand Down Expand Up @@ -9420,8 +9420,8 @@ def test_shape_env_equal_evaluate_expr_replacement(self):
> Left: {s0: 3}
> Right: {}
==> var_to_range: values don't match.
> Left: {s0: VR[3, 3], s1: VR[2, int_oo]}
> Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
> Left: {s0: VR[3, 3], s1: VR[2, 9223372036854775806]}
> Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]}
""",
)
self._replay_and_check(main)
Expand Down Expand Up @@ -9458,8 +9458,8 @@ def test_shape_env_equal_evaluate_expr_refinement(self):
> Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
==> var_to_range: values don't match.
> Left: {s0: VR[3, int_oo], s1: VR[2, int_oo]}
> Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
> Left: {s0: VR[3, 9223372036854775806], s1: VR[2, 9223372036854775806]}
> Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]}
""",
)
self._replay_and_check(main)
Expand Down
15 changes: 1 addition & 14 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,19 +201,6 @@ def forward(self, x):
dynamic_shapes={"x": {0: dim_x}},
)

def test_export_slice_maxsize(self):
class Slice(torch.nn.Module):
def forward(self, *args):
return torch.ops.aten.slice.Tensor(*args)

inp = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807)
dynamic_shapes = (({0: Dim("dim")}, None, None, None),)
torch.export.export(
Slice(),
inp,
dynamic_shapes=dynamic_shapes,
)

def test_export_constraints_error(self):
class ConflictingConstraints(torch.nn.Module):
def forward(self, x):
Expand Down Expand Up @@ -5196,7 +5183,7 @@ def forward(self, x):
}
export(f, (inputs,), dynamic_shapes=dynamic_shapes)

def test_disable_forced_specializations_ok(self):
def test_disable_forced_specializations(self):
# check that _disable_forced_specializations and _allow_complex_guards_as_runtime_asserts flags
# both behave correctly, avoiding forced specializations and deferring to runtime.
# case 1: modulo guards
Expand Down
4 changes: 4 additions & 0 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,10 @@ def forward(self, x):
func, (torch.randn(3, 4),)
)

@pytorch_test_common.xfail_if_model_type_is_exportedprogram(
error_message="Unsupported FX nodes: {'call_function': ['aten._assert_async.msg']}.",
reason="https://github.com/pytorch/pytorch/issues/112622",
)
def test_operator_with_scalar_output(self):
class Foo(torch.nn.Module):
def forward(self, x, y):
Expand Down
11 changes: 0 additions & 11 deletions test/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,17 +385,6 @@ def test_size_expressions(self):
self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))

def test_floordiv_static(self):
shape_env = ShapeEnv()
s0 = create_symint(shape_env, 8)
# This was extracted from
# python test/inductor/test_cuda_cpp_wrapper.py -k
# DynamicShapesCudaWrapperCudaTests.test_insignificant_strides_cuda_dynamic_shapes_cuda_wrapper
bool(s0 % 2 == 0)
bool(s0 % (s0 // 2) == 0)
bool(2 * (s0 // 2) == s0)
self.assertTrue(statically_known_true(s0 // (s0 // 2) == 2))

def test_numel(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
Expand Down
4 changes: 1 addition & 3 deletions test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,9 +1201,7 @@ def f(src_tokens):
batch_size = 4
src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size))
gm = make_fx(f, tracing_mode="symbolic")(src_tokens)
# Guards to rule out batch_size == sys.maxsize (wobbling between 2 and
# 1 ok)
self.assertEqual(len(gm.shape_env.guards), 1)
self.assertEqual(len(gm.shape_env.guards), 0)

@unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
def test_cpu_scalar_cuda(self):
Expand Down
70 changes: 0 additions & 70 deletions test/test_sympy_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Owner(s): ["oncall: pt2"]

import itertools
import math
import sys

import sympy
Expand All @@ -20,7 +19,6 @@
from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis
from torch.utils._sympy.interp import sympy_interp
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity
from sympy.core.relational import is_ge, is_le, is_gt, is_lt
import functools
import torch.fx as fx
Expand Down Expand Up @@ -124,74 +122,6 @@ def generate_range(vals):
yield ValueRanges(a1, a2)


class TestNumbers(TestCase):
def test_int_infinity(self):
self.assertIsInstance(int_oo, IntInfinity)
self.assertIsInstance(-int_oo, NegativeIntInfinity)
self.assertTrue(int_oo.is_integer)
# is tests here are for singleton-ness, don't use it for comparisons
# against numbers
self.assertIs(int_oo + int_oo, int_oo)
self.assertIs(int_oo + 1, int_oo)
self.assertIs(int_oo - 1, int_oo)
self.assertIs(-int_oo - 1, -int_oo)
self.assertIs(-int_oo + 1, -int_oo)
self.assertIs(-int_oo + (-int_oo), -int_oo)
self.assertIs(-int_oo - int_oo, -int_oo)
self.assertIs(1 + int_oo, int_oo)
self.assertIs(1 - int_oo, -int_oo)
self.assertIs(int_oo * int_oo, int_oo)
self.assertIs(2 * int_oo, int_oo)
self.assertIs(int_oo * 2, int_oo)
self.assertIs(-1 * int_oo, -int_oo)
self.assertIs(-int_oo * int_oo, -int_oo)
self.assertIs(2 * -int_oo, -int_oo)
self.assertIs(-int_oo * 2, -int_oo)
self.assertIs(-1 * -int_oo, int_oo)
self.assertIs(int_oo / 2, sympy.oo)
self.assertIs(-(-int_oo), int_oo) # noqa: B002
self.assertIs(abs(int_oo), int_oo)
self.assertIs(abs(-int_oo), int_oo)
self.assertIs(int_oo ** 2, int_oo)
self.assertIs((-int_oo) ** 2, int_oo)
self.assertIs((-int_oo) ** 3, -int_oo)
self.assertEqual(int_oo ** -1, 0)
self.assertEqual((-int_oo) ** -1, 0)
self.assertIs(int_oo ** int_oo, int_oo)
self.assertTrue(int_oo == int_oo)
self.assertFalse(int_oo != int_oo)
self.assertTrue(-int_oo == -int_oo)
self.assertFalse(int_oo == 2)
self.assertTrue(int_oo != 2)
self.assertFalse(int_oo == sys.maxsize)
self.assertTrue(int_oo >= sys.maxsize)
self.assertTrue(int_oo >= 2)
self.assertTrue(int_oo >= -int_oo)

def test_relation(self):
self.assertIs(sympy.Add(2, int_oo), int_oo)
self.assertFalse(-int_oo > 2)

def test_lt_self(self):
self.assertFalse(int_oo < int_oo)
self.assertIs(min(-int_oo, -4), -int_oo)
self.assertIs(min(-int_oo, -int_oo), -int_oo)

def test_float_cast(self):
self.assertEqual(float(int_oo), math.inf)
self.assertEqual(float(-int_oo), -math.inf)

def test_mixed_oo_int_oo(self):
# Arbitrary choice
self.assertTrue(int_oo < sympy.oo)
self.assertFalse(int_oo > sympy.oo)
self.assertTrue(sympy.oo > int_oo)
self.assertFalse(sympy.oo < int_oo)
self.assertIs(max(int_oo, sympy.oo), sympy.oo)
self.assertTrue(-int_oo > -sympy.oo)
self.assertIs(min(-int_oo, -sympy.oo), -sympy.oo)


class TestValueRanges(TestCase):
@parametrize("fn", UNARY_OPS)
@parametrize("dtype", ("int", "float"))
Expand Down
9 changes: 1 addition & 8 deletions torch/_decomp/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,11 +734,6 @@ def slice_forward(
end: Optional[int] = None,
step: int = 1,
):
from torch.fx.experimental.symbolic_shapes import (
guard_size_oblivious,
statically_known_true,
)

ndim = self.dim()
if ndim == 0:
raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")
Expand All @@ -765,9 +760,7 @@ def slice_forward(

if end_val < start_val:
end_val = start_val
elif statically_known_true(end_val == sys.maxsize) or guard_size_oblivious(
end_val > sizes[dim]
):
elif end_val > sizes[dim]:
end_val = sizes[dim]

storage_offset = self.storage_offset() + start_val * strides[dim]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import torch
import torch.fx
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._sympy.numbers import int_oo
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.fx.passes.infra.pass_base import PassBase, PassResult

Expand All @@ -24,9 +23,9 @@ class InputDim(NamedTuple):

def _convert_to_int(val):
# Convert simple sympy Integers into concrete int
if val in (sympy.oo, int_oo):
if val == sympy.oo:
return math.inf
if val in (-sympy.oo, -int_oo):
if val == -sympy.oo:
return -math.inf
if isinstance(val, sympy.Integer):
return int(val)
Expand Down
11 changes: 5 additions & 6 deletions torch/_export/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
from torch.utils import _pytree as pytree
from torch.utils._pytree import treespec_dumps, treespec_loads
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._sympy.numbers import int_oo

from .schema import ( # type: ignore[attr-defined]
Argument,
Expand Down Expand Up @@ -322,9 +321,9 @@ def deserialize_torch_artifact(serialized: Union[Dict[str, Any], Tuple[Any, ...]

def _sympy_int_to_int(val: sympy.Expr, adjust: str):
# Convert simple sympy Integers into concrete int
if val in (sympy.oo, int_oo):
if val == sympy.oo:
return math.inf
if val in (-sympy.oo, -int_oo):
if val == -sympy.oo:
return -math.inf
if isinstance(val, sympy.Integer):
return int(val)
Expand All @@ -347,9 +346,9 @@ def _sympy_int_to_int(val: sympy.Expr, adjust: str):
def _int_to_sympy_int(val) -> sympy.Expr:
# Convert concrete int into simple sympy Integers
if val == math.inf:
return int_oo
return sympy.oo
if val == -math.inf:
return -int_oo
return -sympy.oo
return sympy.Integer(val)


Expand Down Expand Up @@ -1827,7 +1826,7 @@ def deserialize(
self.symbol_name_to_range = {}
if symbol_name_to_range:
for k, vr in symbol_name_to_range.items():
lower = vr.lower
lower = int(vr.lower)
if vr.upper >= 2: # max is >= 2, not sym bool range
lower = max(2, lower)
self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower), vr.upper)
Expand Down
14 changes: 5 additions & 9 deletions torch/_inductor/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
SymTypes,
)
from torch.utils._mode_utils import no_dispatch
from torch.utils._sympy.numbers import int_oo

from . import config, ir
from .codegen.common import (
Expand Down Expand Up @@ -1428,21 +1427,18 @@ def format_buffers():
vr = shape_env.var_to_range[i0]
if not shape_env._default_unspecified_value_range().issubset(vr):

def is_convertible(s):
if s in (int_oo, -int_oo):
return False
def convert(s):
try:
int(s)
return True
return int(s)
except TypeError:
return False
return None

if is_convertible(vr.lower):
if (lower := convert(vr.lower)) is not None:
self.register_buffer(
ir.AssertScalar(i0 >= vr.lower, f"{i0} >= {vr.lower}"),
set_name=True,
)
if is_convertible(vr.upper):
if (upper := convert(vr.upper)) is not None:
self.register_buffer(
ir.AssertScalar(i0 <= vr.upper, f"{i0} <= {vr.upper}"),
set_name=True,
Expand Down
Loading

0 comments on commit 5d8c7f3

Please sign in to comment.