Update on "[inductor] Fix nested indirect indexing case for index_propagation"

Tries to fix #127677.

# Context

Just as peterbell10 pointed out, we have the following scenario:

```
a = ops.indirect_indexing(...)
b = ops.index_expr(a, ...)
c = ops.indirect_indexing(b, ...)
```

We can repro this as:

```
def forward(self, arg0_1, arg1_1, arg2_1):
    iota = torch.ops.prims.iota.default(arg0_1, start = 0, step = 1, index=0);
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(arg1_1);
    index = torch.ops.aten.index.Tensor(iota, [repeat_interleave]);
    index_1 = torch.ops.aten.index.Tensor(arg2_1, [index]);
    return (index_1,)
```

which should generate a JIT py file like this:

```
def triton_poi_fused_index_select_0(in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    ...
    tmp0 = tl.load(in_ptr0 + (x1), xmask, eviction_policy='evict_last')
    tmp1 = ks0
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < 0
    tmp4 = tl.where(tmp3, tmp2, tmp0)
    # check_bounds()
    tl.device_assert(((0 <= tmp4) & (tmp4 < ks0)) | ~(xmask), "index out of bounds: 0 <= tmp4 < ks0")

def call():
    arg0_1, arg1_1, arg2_1 = args
    buf1 = aten.repeat_interleave.Tensor(arg1_1)
    buf4 = empty_strided_cuda((u0, 64), (64, 1))
    triton_poi_fused_index_select_0.run(
        buf1, arg2_1, buf4, s0,
        triton_poi_fused_index_select_0_xnumel,
        grid=grid(triton_poi_fused_index_select_0_xnumel),
        stream=stream0)
```

# Issue

In our `IndexPropagation.indirect_indexing()` call we have `expr=indirect0`, which is created in `LoopBodyBlock.indirect_indexing()`:

https://github.com/pytorch/pytorch/blob/3b555ba47713d489975a9bb6cb6c31975f805e3f/torch/_inductor/ir.py#L8154-L8160

When we try to prove its bounds, we fail because `indirect0` isn't in `var_ranges`.

# Approach

We can probably skip the attempt to prove `indirect0`'s bounds, because the only purpose of that proof is to decide whether to add a bounds check as a device-side assert. Thankfully, we already emit that assert in the codegen pass.
https://github.com/pytorch/pytorch/blob/3b555ba47713d489975a9bb6cb6c31975f805e3f/torch/_inductor/codegen/common.py#L1730-L1733

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 amjames desertfire chauhang

[ghstack-poisoned]
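
# Illustration

To make the runtime behavior concrete, here is a minimal pure-Python sketch of what the generated kernel does for the repro: wrap negative indices, assert the bound at run time, and feed the result of one indirect lookup into the next. It does not use Inductor or Triton. The names `wrap_and_check`, `repeat_interleave`, and `nested_index` are hypothetical helpers invented for this illustration, and plain lists stand in for tensors.

```python
def wrap_and_check(idx, size):
    """Model of tmp0..tmp4 plus the device assert in the kernel above."""
    # tl.where(tmp0 < 0, tmp0 + ks0, tmp0): wrap negative indices.
    wrapped = idx + size if idx < 0 else idx
    # Stand-in for tl.device_assert: the bound is checked at run time,
    # so nothing has to be proven about the index ahead of time.
    assert 0 <= wrapped < size, f"index out of bounds: 0 <= {wrapped} < {size}"
    return wrapped


def repeat_interleave(repeats):
    """List version of single-argument repeat_interleave: i appears repeats[i] times."""
    out = []
    for i, r in enumerate(repeats):
        out.extend([i] * r)
    return out


def nested_index(n, repeats, table):
    """List version of the repro graph: table[iota[repeat_interleave(repeats)]]."""
    iota = list(range(n))
    rep = repeat_interleave(repeats)
    # First indirect lookup: the index is loaded data, not a loop variable.
    index = [iota[wrap_and_check(i, len(iota))] for i in rep]
    # Second indirect lookup: its index comes from the first lookup.
    return [table[wrap_and_check(i, len(table))] for i in index]


if __name__ == "__main__":
    print(nested_index(3, [2, 0, 1], [10, 20, 30]))
```

Because the second lookup's index is only known once data has been loaded, a compile-time range for it is unavailable in general, which is why relying on the runtime assert is a reasonable fallback in this sketch.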