Skip to content

Commit

Permalink
Update on "[inductor] Fix nested indirect indexing case for index_pro…
Browse files Browse the repository at this point in the history
…pagation"

Tries to fix #127677.

# Context

Just as peterbell10 pointed out, we have the following scenario:
```
a = ops.indirect_indexing(...)
b = ops.index_expr(a, ...)
c = ops.indirect_indexing(b, ...)
```

We can repro this as:
```
def forward(self, arg0_1, arg1_1, arg2_1):
    iota = torch.ops.prims.iota.default(arg0_1, start = 0, step = 1, index=0),
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(arg1_1);
    index = torch.ops.aten.index.Tensor(iota, [repeat_interleave]);
    index_1 = torch.ops.aten.index.Tensor(arg2_1, [index]);
    return (index_1,)
```

which should generate a JIT py file like this:
```
def triton_poi_fused_index_select_0(in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    ...
    tmp0 = tl.load(in_ptr0 + (x1), xmask, eviction_policy='evict_last')
    tmp1 = ks0
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < 0
    tmp4 = tl.where(tmp3, tmp2, tmp0)
    # check_bounds()
    tl.device_assert(((0 <= tmp4) & (tmp4 < ks0)) | ~(xmask), "index out of bounds: 0 <= tmp4 < ks0")

def call():
  arg0_1, arg1_1, arg2_1 = args
  buf1 = aten.repeat_interleave.Tensor(arg1_1)
  buf4 = empty_strided_cuda((u0, 64), (64, 1))
  triton_poi_fused_index_select_0.run(
    buf1, arg2_1, buf4, s0, 
    triton_poi_fused_index_select_0_xnumel, 
    grid=grid(triton_poi_fused_index_select_0_xnumel), 
    stream=stream0)
```

# Issue
In our `IndexPropagation.indirect_indexing()` call we have `expr=indirect0` which is spawned in `LoopBodyBlock.indirect_indexing()`.
https://github.com/pytorch/pytorch/blob/3b555ba47713d489975a9bb6cb6c31975f805e3f/torch/_inductor/ir.py#L8154-L8160

When we try to see if we can prove its bounds, we fail because `indirect0` isn't in `var_ranges`.

# Approach
When creating `indirect` symbols from fallback, specify its range to be `[-size, size -1]` to avoid a lookup error with `indirectX`.




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 amjames desertfire chauhang

[ghstack-poisoned]
  • Loading branch information
ColinPeppler committed Jun 13, 2024
1 parent 994586b commit 6bcfd5e
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions torch/_inductor/index_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,6 @@ def wrap_expr(expr):
assert (
indirect_var not in self.var_to_range
), f"{indirect_var} should've been created in the fallback."

lower, upper = -upper_bound(size), upper_bound(size) - 1
indirect_range = (indirect_var, ValueRanges(lower, upper))
indirect_range = (indirect_var, ValueRanges(0, upper_bound(size) - 1))
self.var_to_range = self.var_to_range + (indirect_range,)
return indirect_var

0 comments on commit 6bcfd5e

Please sign in to comment.