Revert "Wrap indirect indexing on CUDA (#105055)"
This reverts commit 85c673e.

Reverted #105055 on behalf of https://github.com/peterbell10 due to Causes failure in inductor_torchbench ([comment](#105055 (comment)))
pytorchmergebot committed Aug 22, 2023
1 parent d59a686 commit b282787
Showing 8 changed files with 19 additions and 147 deletions.
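For context, the reverted change made Inductor wrap possibly-negative indirect indices into the valid range on CUDA, matching PyTorch's negative-indexing semantics. A minimal sketch of the wrapping rule (plain Python, illustrative only):

```python
def wrap_index(idx: int, size: int) -> int:
    # PyTorch semantics: a negative index counts from the end, so -1 is the
    # last element. The reverted codegen emitted the equivalent of
    # tl.where(idx < 0, idx + size, idx) whenever idx could be negative.
    return idx + size if idx < 0 else idx

assert wrap_index(-1, 1024) == 1023  # last element
assert wrap_index(5, 1024) == 5      # non-negative indices pass through
```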
57 changes: 0 additions & 57 deletions test/inductor/test_torchinductor.py
@@ -7194,63 +7194,6 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):

self.assertTrue(max_live_tensors == 2)

@skipIfRocm
def test_neg_index(self):
def test(fn, inps, has_assert: bool, has_wrapping=True):
for dynamic in (True, False):
fn_opt = torch.compile(dynamic=dynamic)(fn)
code = run_and_get_triton_code(fn_opt, *inps)
self.assertTrue(("tl.where" in code) is has_wrapping)
self.assertTrue(("device_assert" in code) is has_assert)
self.assertEqual(fn(*inps), fn_opt(*inps))

def indirect(a, b):
return a[b - 1]

a = torch.rand(1024, device="cuda")
b = torch.zeros(4, dtype=torch.long, device="cuda")
test(indirect, (a, b), has_assert=True)

def direct(x):
return x[:, -1]

a = torch.rand(1, 64, 32, device="cuda")
test(direct, (a,), has_assert=False, has_wrapping=False)

def flip(a, b):
return a[b]

a = torch.rand(1024, device="cuda")
b = torch.arange(start=-1, end=-a.numel() - 1, step=-1, device="cuda")
test(flip, (a, b), has_assert=True)

# Constant-propagate a constant that's negative
def flip_with_index_constant(a):
b = torch.arange(start=-1, end=-a.numel() - 1, step=-1, device="cuda")
return a[b]

# Wrapping is constant-folded
test(flip_with_index_constant, (a,), has_assert=False, has_wrapping=False)

# Operation where we can't prove that the index is always positive or negative
def pos_and_neg(a):
b = torch.arange(start=1, end=-a.numel() - 1, step=-1, device="cuda")
return a[b]

# It has wrapping but no assert
test(pos_and_neg, (a,), has_assert=False, has_wrapping=True)

# We currently don't do constant propagation with float constants
def flip_with_index(a):
b = 1.0 * torch.arange(
start=-1, end=-a.numel() - 1, step=-1, device="cuda"
)
b = b.int()
return a[b]

# Even though the float constant blocks propagation, the bounds still
# prove the index is always negative, so wrapping and the assert fold away.
test(flip_with_index, (a,), has_assert=False, has_wrapping=False)

# See https://github.com/pytorch/pytorch/issues/100348
def test_inductor_detach_view(self):
def fn(x: torch.Tensor) -> torch.Tensor:
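The deleted test drives torch.compile and then greps the generated kernel source; a standalone sketch of that pattern (assumes a CUDA build and mirrors the deleted helpers):

```python
import torch
from torch._inductor.utils import run_and_get_triton_code

def indirect(a, b):
    return a[b - 1]

a = torch.rand(1024, device="cuda")
b = torch.zeros(4, dtype=torch.long, device="cuda")

fn_opt = torch.compile(indirect)
code = run_and_get_triton_code(fn_opt, a, b)
# Before this revert, a possibly-negative index was wrapped with tl.where
# and guarded by a device-side assert; after it, neither appears.
print("tl.where" in code, "device_assert" in code)
```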
6 changes: 0 additions & 6 deletions torch/_inductor/codegen/common.py
@@ -314,12 +314,6 @@ def _print_Mod(self, expr):
def _print_CleanDiv(self, expr):
return self._print_FloorDiv(expr) # type: ignore[attr-defined]

def _print_GreaterThan(self, expr):
# GreaterThan: >=
# StrictlyGreaterThan: >
# Go figure...
return " >= ".join(map(self.paren, map(self._print, expr.args)))


class PythonPrinter(ExprPrinter):
def _print_ModularIndexing(self, expr):
6 changes: 0 additions & 6 deletions torch/_inductor/codegen/cpp.py
@@ -294,12 +294,6 @@ class CppPrinter(ExprPrinter):
def _print_Integer(self, expr):
return f"{int(expr)}L"

def _print_Where(self, expr):
c = self.paren(self.doprint(expr.args[0]))
p = self.paren(self.doprint(expr.args[1]))
q = self.paren(self.doprint(expr.args[2]))
return f"{c} ? {p} : {q}"

def _print_ModularIndexing(self, expr):
x, div, mod = expr.args
x = self.paren(self.doprint(x))
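The removed hook lowered the sympy Where function to C++'s conditional operator. Schematically (a hypothetical free function, not the real printer class):

```python
def cpp_where(c: str, p: str, q: str) -> str:
    # Mirrors the removed CppPrinter._print_Where: parenthesize each
    # printed argument and emit a C++ ternary expression.
    return f"({c}) ? ({p}) : ({q})"

print(cpp_where("x >= 0L", "x", "x + 1024L"))  # (x >= 0L) ? (x) : (x + 1024L)
```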
30 changes: 0 additions & 30 deletions torch/_inductor/codegen/triton.py
@@ -65,12 +65,6 @@ def _print_floor(self, expr):
def _helper_sqrt(self, expr):
return f"tl.math.sqrt({self.paren(self._print(expr))}.to(tl.float32))"

def _print_Where(self, expr):
c = self.doprint(expr.args[0])
p = self.doprint(expr.args[1])
q = self.doprint(expr.args[2])
return f"tl.where({c}, {p}, {q})"

def _print_Min(self, expr):
nargs = len(expr.args)
if len(expr.args) == 1:
@@ -1192,8 +1186,6 @@ def mask_loads(self, mask):
self._load_mask = prior

def indirect_indexing(self, var, size, check=True):
# TODO(lezcano) This code should be lifted to codegen/common.py.
# This should be easy, as now CSE variables carry bounds info
class IndirectAssertLine(DeferredLineBase):
def __init__(self, line, var, mask, size_map):
self.var = var
@@ -1231,28 +1223,6 @@ def __call__(self):
def _new_line(self, line):
return IndirectAssertLine(line, self.var, self.mask, self.size_map)

if var.bounds.lower < 0:
new_bounds = ValueRanges.unknown()
if var.bounds != ValueRanges.unknown() and isinstance(size, sympy.Number):
# Take the negative part of the bound and add size to it
# Then take union of that and the positive part
# This is a tighter bound than that of a generic ops.where, as we have info on the cond
neg = var.bounds & ValueRanges(-sympy.oo, -1)
new_bounds = ValueRanges(neg.lower + size, neg.upper + size)
# We don't have a good way of representing the empty range
if var.bounds.upper >= 0:
pos = var.bounds & ValueRanges(0, sympy.oo)
new_bounds = new_bounds | pos

stm = f"{var} + {self.index_to_str(size)}"
# Mixed negative and non-negative
if var.bounds.upper >= 0:
stm = f"tl.where({var} < 0, {stm}, {var})"
new_var = self.cse.generate(self.compute, stm, bounds=new_bounds)

new_var.update_on_args("index_wrap", (var,), {})
var = new_var

generate_assert = (
(check or config.debug_index_asserts)
and config.triton.assert_indirect_indexing
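To make the removed bound computation concrete, here is the same arithmetic in plain Python, with hypothetical example bounds:

```python
# Suppose the indirect index has known bounds [-3, 5] and the dim size is 10.
lower, upper, size = -3, 5, 10

# Negative part of the range, shifted by size: [-3, -1] + 10 -> [7, 9].
neg_lower, neg_upper = lower, min(upper, -1)        # bounds & (-inf, -1)
new_lower, new_upper = neg_lower + size, neg_upper + size

# The range is mixed-sign, so union with the non-negative part [0, 5].
if upper >= 0:
    pos_lower, pos_upper = max(lower, 0), upper     # bounds & (0, inf)
    new_lower, new_upper = min(new_lower, pos_lower), max(new_upper, pos_upper)

print(new_lower, new_upper)  # 0 9: the wrapped index always lands in [0, size)
```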
9 changes: 2 additions & 7 deletions torch/_inductor/index_propagation.py
@@ -27,7 +27,7 @@

import torch
from torch._prims_common import is_boolean_dtype, is_integer_dtype
-from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
+from torch.utils._sympy.functions import FloorDiv, ModularIndexing


@dataclass
@@ -229,12 +229,7 @@ def inner(*args: Any, **kwargs: Any) -> Union[Any, IndexPropVar]:
def indirect_indexing(
self, index: Union[Any, IndexPropVar], size: Any, check: bool = True
) -> Any:
-        # nb. We do index + Where(...) rather than Where(idx >= 0, idx, idx + sz) because we don't have CSE
-        # for SymPy expressions, so we don't want to repeat idx too much
-
         # indirect_indexing returns a sympy value, so no need to wrap in IndexPropVar here
         if isinstance(index, IndexPropVar) and index.is_symbolic:
-            # If we are turning an indirect indexing into direct, we need to wrap it.
-            index = index.value.expr
-            return index + Where(index >= 0, 0, size)
+            return index.value.expr
return self.fallback("indirect_indexing", (index, size, check), {}).value
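The removed return value implements the same wrapping at the sympy level: for a concrete negative index the Where folds away. A self-contained check, inlining the Where definition deleted below:

```python
import sympy

class Where(sympy.Function):
    """Ternary operator, as defined in torch/utils/_sympy/functions.py before this revert."""
    nargs = (3,)

    @classmethod
    def eval(cls, c, p, q):
        if c == sympy.true:
            return p
        elif c == sympy.false:
            return q

idx, size = sympy.Integer(-1), sympy.Integer(1024)
# (-1 >= 0) evaluates to sympy.false, so Where folds to size, and the
# wrapped index is -1 + 1024 == 1023 -- the last element.
assert idx + Where(idx >= 0, 0, size) == 1023
```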
13 changes: 0 additions & 13 deletions torch/utils/_sympy/functions.py
@@ -137,19 +137,6 @@ def eval(cls, base, divisor, modulus):
if isinstance(base, FloorDiv):
return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)

class Where(sympy.Function):
"""
Good ol' ternary operator
"""

nargs = (3,)

@classmethod
def eval(cls, c, p, q):
if c == sympy.true:
return p
elif c == sympy.false:
return q

class Mod(sympy.Function):
"""
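Note that eval only folds when the condition is a concrete boolean; with a symbolic condition it returns None, so the Where node survives into codegen, where the printers removed above rendered it as tl.where (Triton) or a ternary (C++). Reusing the inline definition from the previous sketch:

```python
import sympy

i = sympy.Symbol("i", integer=True)
expr = Where(i >= 0, 0, 1024)  # Where as defined in the previous sketch
# i >= 0 stays symbolic, so eval returns None and the call is unevaluated.
assert isinstance(expr, Where)
assert expr.args == (sympy.Ge(i, 0), sympy.Integer(0), sympy.Integer(1024))
```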
3 changes: 1 addition & 2 deletions torch/utils/_sympy/interp.py
@@ -14,7 +14,7 @@
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom

import torch
-from .functions import CleanDiv, FloorDiv, Mod, ModularIndexing, Where
+from .functions import CleanDiv, FloorDiv, Mod, ModularIndexing


# TODO: Dedupe this with SYMPY_INTERP
@@ -42,7 +42,6 @@ def handlers():
TrueDiv: "truediv",
FloorDiv: "floordiv",
CleanDiv: "div",
-        Where: "where",
sympy.Add: "add",
sympy.Mul: "mul",
Pow: "pow",
42 changes: 16 additions & 26 deletions torch/utils/_sympy/value_ranges.py
@@ -82,12 +82,9 @@ def __contains__(self, x):
x = simple_sympify(x)
return sympy_generic_le(self.lower, x) and sympy_generic_le(x, self.upper)

-    def tighten(self, other) -> "ValueRanges":
+    def tighten(self, other: "ValueRanges"):
         """Given two ValueRanges, returns their intersection"""
-        return self & other
 
-    # Intersection
-    def __and__(self, other) -> "ValueRanges":
         # Some invariants
         if other == ValueRanges.unknown():
             return self
         if self == ValueRanges.unknown():
@@ -99,16 +96,9 @@ def __and__(self, other) -> "ValueRanges":
         range = ValueRanges(sympy.Max(self.lower, other.lower), sympy.Min(self.upper, other.upper))
         return range
 
-    # Union
-    def __or__(self, other) -> "ValueRanges":
-        if ValueRanges.unknown() in (self, other):
-            return ValueRanges.unknown()
-        assert self.is_bool == other.is_bool, (self, other)
-        if self.is_bool:
-            range = ValueRanges(sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper))
-        else:
-            range = ValueRanges(sympy.Min(self.lower, other.lower), sympy.Max(self.upper, other.upper))
-        return range
+    # Intersection
+    def __and__(self, other):
+        return ValueRanges(lower=max(self.lower, other.lower), upper=min(self.upper, other.upper))
 
     def is_singleton(self) -> bool:
         return self.lower == self.upper
@@ -445,17 +435,6 @@ def sqrt(x):
return ValueRanges.unknown()
return ValueRanges.increasing_map(x, sympy.sqrt)

-    @staticmethod
-    def where(a, b, c):
-        b = ValueRanges.wrap(b)
-        c = ValueRanges.wrap(c)
-        assert a.is_bool
-        assert b.is_bool == c.is_bool
-        if b.is_bool:
-            return ValueRanges(sympy.And(b.lower, c.lower), sympy.Or(b.upper, c.upper))
-        else:
-            return ValueRanges(sympy.Min(b.lower, c.lower), sympy.Max(b.upper, c.upper))


class ValueRangeAnalysis(SymPyValueRangeAnalysis):
def __init__(self):
@@ -549,6 +528,17 @@ def trunc(x):
def sub(cls, a, b):
return cls.add(a, cls.neg(b))

+    @staticmethod
+    def where(a, b, c):
+        b = ValueRanges.wrap(b)
+        c = ValueRanges.wrap(c)
+        assert a.is_bool
+        assert b.is_bool == c.is_bool
+        if b.is_bool:
+            return ValueRanges(sympy.And(b.lower, c.lower), sympy.Or(b.upper, c.upper))
+        else:
+            return ValueRanges(sympy.Min(b.lower, c.lower), sympy.Max(b.upper, c.upper))

def __getattr__(self, name):
log.debug("unhandled ValueRange op %s", name)
return self.default_handler
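Whichever class hosts it, the where rule is simply the union of the two branch ranges, since either branch may be selected at runtime. A quick check against the relocated method (post-revert API assumed):

```python
import sympy
from torch.utils._sympy.value_ranges import ValueRanges, ValueRangeAnalysis

cond = ValueRanges(sympy.false, sympy.true)  # condition with unknown truth value
b, c = ValueRanges(1, 3), ValueRanges(2, 5)
# Either branch can be taken, so the result covers both: [1, 5].
assert ValueRangeAnalysis.where(cond, b, c) == ValueRanges(1, 5)
```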
