Update on "[inductor] Fix nested indirect indexing case for index_propagation"

Tries to fix #127677.

# Context

As peterbell10 pointed out, we have the following scenario:
```
a = ops.indirect_indexing(...)
b = ops.index_expr(a, ...)
c = ops.indirect_indexing(b, ...)
```

We can repro this as:
```
def forward(self, arg0_1, arg1_1, arg2_1):
    iota = torch.ops.prims.iota.default(arg0_1, start = 0, step = 1, index=0)
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(arg1_1)
    index = torch.ops.aten.index.Tensor(iota, [repeat_interleave])
    index_1 = torch.ops.aten.index.Tensor(arg2_1, [index])
    return (index_1,)
```
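In eager terms, the graph above amounts to the following (a minimal sketch; the concrete sizes and values in the example call are illustrative, not taken from the issue):

```python
import torch

def fn(n, repeats, values):
    iota = torch.arange(n)                  # prims.iota
    idx = torch.repeat_interleave(repeats)  # output length is repeats.sum(): data-dependent
    gathered = iota[idx]                    # first indirect indexing
    return values[gathered]                 # nested indirect indexing

out = fn(3, torch.tensor([1, 2, 0]), torch.arange(10, 20))
```

Under `torch.compile`, the data-dependent length of `idx` is what introduces the unbacked symbol that feeds the nested indirect indexing.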

which should generate an output (JIT) Python file, Triton kernel plus wrapper, like this:
```
def triton_poi_fused_index_select_0(in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    ...
    tmp0 = tl.load(in_ptr0 + (x1), xmask, eviction_policy='evict_last')
    tmp1 = ks0
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < 0
    tmp4 = tl.where(tmp3, tmp2, tmp0)
    # check_bounds()
    tl.device_assert(((0 <= tmp4) & (tmp4 < ks0)) | ~(xmask), "index out of bounds: 0 <= tmp4 < ks0")

def call():
  arg0_1, arg1_1, arg2_1 = args
  buf1 = aten.repeat_interleave.Tensor(arg1_1)
  buf4 = empty_strided_cuda((u0, 64), (64, 1))
  triton_poi_fused_index_select_0.run(
    buf1, arg2_1, buf4, s0, 
    triton_poi_fused_index_select_0_xnumel, 
    grid=grid(triton_poi_fused_index_select_0_xnumel), 
    stream=stream0)
```
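The `tmp2`/`tmp3`/`tmp4` sequence above is the usual negative-index wrap, and the `device_assert` is the runtime bounds check. A plain-Python sketch of that per-element logic (the function name is made up for illustration):

```python
def wrap_and_check(idx, size):
    # tmp2 = i + size; tmp3 = i < 0; tmp4 = where(tmp3, tmp2, i)
    wrapped = [i + size if i < 0 else i for i in idx]
    # mirrors tl.device_assert(0 <= tmp4 < ks0)
    assert all(0 <= w < size for w in wrapped), "index out of bounds"
    return wrapped

out = wrap_and_check([-1, 0, 2], 4)
```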

# Issue
In our `IndexPropagation.indirect_indexing()` call, we have `expr=indirect0`, which is created in `LoopBodyBlock.indirect_indexing()`.
https://github.com/pytorch/pytorch/blob/3b555ba47713d489975a9bb6cb6c31975f805e3f/torch/_inductor/ir.py#L8154-L8160

When we try to see if we can prove its bounds, we fail because `indirect0` isn't in `var_ranges`.

# Approach
We can skip trying to prove `indirect0`'s bounds here, since the only purpose of that proof is to decide whether the bounds check can be elided; otherwise a device-side assert covers it, and we already emit that assert in the codegen pass.
https://github.com/pytorch/pytorch/blob/3b555ba47713d489975a9bb6cb6c31975f805e3f/torch/_inductor/codegen/common.py#L1730-L1733
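Concretely, the fix assigns each `indirect*` symbol a default (effectively unbounded) range rather than failing the lookup, so bound proving degrades gracefully and the device-side assert remains the real check. A simplified, dictionary-based sketch (the real code uses sympy symbols and `ValueRanges`; these names are illustrative):

```python
UNBOUNDED = (float("-inf"), float("inf"))

def add_default_ranges(free_syms, var_to_range):
    # Any indirect* symbol without a known range gets the default one,
    # so bound proving no longer trips over a missing range.
    extra = {
        s: UNBOUNDED
        for s in free_syms
        if s.startswith("indirect") and s not in var_to_range
    }
    return {**var_to_range, **extra}

ranges = add_default_ranges({"indirect0", "x1"}, {"x1": (0, 63)})
```

With an unbounded range, the prover simply cannot show `0 <= indirect0 < size`, so the assert is kept rather than elided.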





cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 amjames desertfire chauhang

[ghstack-poisoned]
ColinPeppler committed Jun 12, 2024
1 parent 95a5637 commit 7207271
Showing 1 changed file with 13 additions and 7 deletions.
torch/_inductor/index_propagation.py

```
@@ -28,8 +28,9 @@

 import torch
 from torch._prims_common import dtype_to_type, is_integer_dtype
+from torch.fx.experimental.symbolic_shapes import free_symbols
 from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
-from torch.utils._sympy.symbol import free_symbol_is_type, SymT
+from torch.utils._sympy.symbol import symbol_is_type, SymT
 from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
 from .utils import generate_assert

@@ -323,12 +324,17 @@ def indirect_indexing(

         expr = sympy.sympify(index.value.expr)

-        if free_symbol_is_type(expr, SymT.INDIRECT):
-            # This is the nested indirect indexing case:
-            # a = ops.indirect_indexing(...)
-            # b = ops.index_expr(a, ...)
-            # c = ops.indirect_indexing(b, ...)
-            return Where(expr < 0, expr + size, expr)
+        # Handle nested indirect indexing, by providing a default
+        # range for indirect symbols. A nested example is:
+        # a = ops.indirect_indexing(...)
+        # b = ops.index_expr(a, ...)
+        # c = ops.indirect_indexing(b, ...)
+        indirect_var_to_default_range = tuple(
+            (sym, self.shape_env._default_unspecified_value_range())
+            for sym in free_symbols(expr)
+            if symbol_is_type(sym, SymT.INDIRECT)
+        )
+        self.var_to_range = self.var_to_range + indirect_var_to_default_range

         # TODO Perhaps move this logic to the simplify indexing pass
         def wrap_expr(expr):
```
