[inductor] Fix nested indirect indexing case for index_propagation #128378
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128378
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures, 2 Cancelled Jobs, 3 Unrelated Failures as of commit 6bcfd5e with merge base c58d3af.
NEW FAILURES - The following jobs have failed.
CANCELLED JOBS - The following jobs were cancelled; please retry.
UNSTABLE - The following jobs failed, but likely due to flakiness present on trunk, and have been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: df9a4bf82658f0fe1f6d1ed4b821477c1889bdc5 Pull Request resolved: #128378
So, if we still include the bounds check instead of returning…
Thinking about whether it makes sense to avoid the bounds check from the codegen pass and use the one from IndexPropagation.
torch/_inductor/index_propagation.py
Outdated
```
# a = ops.indirect_indexing(...)
# b = ops.index_expr(a, ...)
# c = ops.indirect_indexing(b, ...)
```
In this case, it doesn't seem like we can meaningfully propagate a sympy expression, so don't we want to fall back?
No, we should be able to propagate the sympy expression without issue. We should just fix the `_maybe_evaluate_static` call so it can handle `indirect0`.
Ideally we would keep track of the bounds for `ops.indirect_indexing`, since these do have known bounds. But at the very least we should treat it as unbounded.
For example:

```
a = ops.indirect_indexing(...)
b = ops.index_expr(a, torch.int32)
c = ops.add(b, ops.constant(1, torch.int32))
d = ops.indirect_indexing(c, ...)  # d = indirect1
```

even though we know that `indirect1 = indirect0 + 1`. We can simplify to:

```
a = ops.indirect_indexing(...)
d = a + 1  # d = indirect0 + 1
```
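For illustration, the intended propagation can be mimicked directly in sympy (the symbol names are stand-ins for what index propagation tracks, not the inductor API itself):

```python
import sympy

# Stand-in for the symbol produced by the first ops.indirect_indexing(...).
# After wrapping, the index is a nonnegative integer.
indirect0 = sympy.Symbol("indirect0", integer=True, nonnegative=True)

# Propagating b = index_expr(a, ...) and c = add(b, constant(1)) symbolically
# yields a plain sympy expression instead of an opaque value:
c = indirect0 + 1

# Rather than allocating a fresh indirect1 symbol for d, the pass can
# reuse the expression directly:
d = c
print(d)  # indirect0 + 1
```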
To do this we should:
- Make sure that the shapes of the two `indirect_indexing`s are the same
- Make sure we are not wrapping the variable twice (I think we may currently be wrapping it twice).

It's not clear to me whether we can do the first thing in the IndexPropagation pass, though.
We should propagate the bounds from the first `indirect0` into the reasoning, so we might be able to elide the wrapping. In general, though, we might need two wrappings and two different sizes, because each `ops.indirect_indexing` call might have come from a different tensor with a different shape, and each is wrapped independently.
`indirect0` variables barely ever have bounds, as loads do not have bounds, and indirect indexing is often performed on values that come from a load.
But yeah, I guess that in general it makes sense for the first indirect to be wrapped and so on at the codegen level, and then we can have the second indirect wrapped here.
> indirect0 variables barely ever have bounds

On the contrary, `indirect0` always has a bound: `ops.indirect_indexing(var, size)` has the bound `[0, size)`, as checked by the runtime assert.
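To make that concrete, here is a plain-Python sketch (not the actual Triton codegen) of what the wrap-plus-assert does: any input in `[-size, size)` lands in `[0, size)`, and anything else trips the runtime assert.

```python
def wrap_index(expr: int, size: int) -> int:
    """Mirror of Where(expr < 0, expr + size, expr): negative indices
    wrap around, so any input in [-size, size) lands in [0, size)."""
    wrapped = expr + size if expr < 0 else expr
    # The device-side assert emitted by codegen checks exactly this bound:
    assert 0 <= wrapped < size, f"index out of bounds: 0 <= {wrapped} < {size}"
    return wrapped

print(wrap_index(-1, 8))  # 7
print(wrap_index(3, 8))   # 3
```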
ugh, right.
Just use `wrap_expr`, but everything else SGTM.
torch/_inductor/index_propagation.py
Outdated
```
# a = ops.indirect_indexing(...)
# b = ops.index_expr(a, ...)
# c = ops.indirect_indexing(b, ...)
return Where(expr < 0, expr + size, expr)
```
Use `wrap_expr` to cover the (rare) case when we can do better than wrapping.
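A hedged sketch of what such a `wrap_expr` could do (the real helper lives in inductor and its exact signature may differ): elide the `Where` when the expression is provably nonnegative, and otherwise fall back to wrapping.

```python
import sympy

def wrap_expr(expr, size):
    # If the expression is provably nonnegative, wrapping is a no-op
    # and we can return it unchanged (the "rare" better case).
    if sympy.ask(sympy.Q.nonnegative(expr)):
        return expr
    # Otherwise emit the usual Where(expr < 0, expr + size, expr),
    # modeled here with sympy's Piecewise.
    return sympy.Piecewise((expr + size, expr < 0), (expr, True))

i = sympy.Symbol("i", integer=True, nonnegative=True)
j = sympy.Symbol("j", integer=True)  # sign unknown
s = sympy.Symbol("s", positive=True)
print(wrap_expr(i, s))  # i  (no wrapping needed)
print(wrap_expr(j, s))  # Piecewise((j + s, j < 0), (j, True))
```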
Actually, we still need to emit the check_bounds here unless we can prove that the sizes are the same, no? And in that case, we probably also want to emit it and just CSE it at codegen level.
Tries to fix #127677.

# Context

Just as peterbell10 pointed out, we have the following scenario:

```
a = ops.indirect_indexing(...)
b = ops.index_expr(a, ...)
c = ops.indirect_indexing(b, ...)
```

We can repro this as:

```
def forward(self, arg0_1, arg1_1, arg2_1):
    iota = torch.ops.prims.iota.default(arg0_1, start=0, step=1, index=0)
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(arg1_1)
    index = torch.ops.aten.index.Tensor(iota, [repeat_interleave])
    index_1 = torch.ops.aten.index.Tensor(arg2_1, [index])
    return (index_1,)
```

which should generate a JIT py file like this:

```
def triton_poi_fused_index_select_0(in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    ...
    tmp0 = tl.load(in_ptr0 + (x1), xmask, eviction_policy='evict_last')
    tmp1 = ks0
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < 0
    tmp4 = tl.where(tmp3, tmp2, tmp0)
    # check_bounds()
    tl.device_assert(((0 <= tmp4) & (tmp4 < ks0)) | ~(xmask), "index out of bounds: 0 <= tmp4 < ks0")

def call():
    arg0_1, arg1_1, arg2_1 = args
    buf1 = aten.repeat_interleave.Tensor(arg1_1)
    buf4 = empty_strided_cuda((u0, 64), (64, 1))
    triton_poi_fused_index_select_0.run(
        buf1, arg2_1, buf4, s0, triton_poi_fused_index_select_0_xnumel,
        grid=grid(triton_poi_fused_index_select_0_xnumel), stream=stream0)
```

# Issue

In our `IndexPropagation.indirect_indexing()` call we have `expr=indirect0`, which is spawned in `LoopBodyBlock.indirect_indexing()`.

https://github.com/pytorch/pytorch/blob/3b555ba47713d489975a9bb6cb6c31975f805e3f/torch/_inductor/ir.py#L8154-L8160

When we try to see if we can prove its bounds, we fail because `indirect0` isn't in `var_ranges`.

# Approach

We can probably skip the check to prove `indirect0`'s bounds, because its purpose is to add a bounds check as a device-side assert. Thankfully, we already do this in the codegen pass:

https://github.com/pytorch/pytorch/blob/3b555ba47713d489975a9bb6cb6c31975f805e3f/torch/_inductor/codegen/common.py#L1730-L1733
ghstack-source-id: 2695b4716a8bec7087b5e27fceb1986af97f009a Pull Request resolved: #128378
We can't prove the sizes are the same, i.e. `s0 == 32`. So now we'll emit another `check_bounds()`, which gives us two of them.
IMO the double `check_bounds` on…
torch/_inductor/index_propagation.py
Outdated
```
# b = ops.index_expr(a, ...)
# c = ops.indirect_indexing(b, ...)
indirect_var_to_default_range = tuple(
    (sym, self.shape_env._default_unspecified_value_range())
```
So, in the pattern above, you are going to hit this function twice. The first time, you will not be able to remove the `indirect_indexing`, but you will be able to set proper ranges for it. Then, the second time, the ranges will already be set, and you will be able to just use those ranges.
The invariant here is that, if you fall back, you will be able to get the result of the fallback (which is a `sympy.Symbol`) and add its range as `[-size, size - 1]` (as it's not wrapped at this point). Then any call after it should work as expected.
You can even assert that the returned symbol is not in `var_to_range` before adding it.
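The bookkeeping this invariant describes can be sketched as follows (a toy simplification: plain strings and a dict stand in for sympy symbols and the real `var_to_range`, and `fallback_indirect_indexing` is a hypothetical name):

```python
var_to_range: dict[str, tuple[int, int]] = {}
_counter = 0

def fallback_indirect_indexing(size: int) -> str:
    """Allocate a fresh indirectN and record its unwrapped range [-size, size - 1]."""
    global _counter
    sym = f"indirect{_counter}"
    _counter += 1
    # Fresh symbol: it must not already have a registered range.
    assert sym not in var_to_range
    var_to_range[sym] = (-size, size - 1)
    return sym

a = fallback_indirect_indexing(32)  # first hit: range gets registered
lo, hi = var_to_range[a]            # any later hit: range is simply looked up
print(a, lo, hi)  # indirect0 -32 31
```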
I realize that… (pytorch/torch/_inductor/utils.py, lines 596 to 605 in 3bc2004)
So, I tossed that assumption out since…
torch/_inductor/index_propagation.py
Outdated
```
if indirect_var not in self.var_to_range:
    lower, upper = -upper_bound(size), upper_bound(size) - 1
    indirect_range = (indirect_var, ValueRanges(lower, upper))
    self.var_to_range = self.var_to_range + (indirect_range,)
```
Can this be turned into

```diff
-if indirect_var not in self.var_to_range:
-    lower, upper = -upper_bound(size), upper_bound(size) - 1
-    indirect_range = (indirect_var, ValueRanges(lower, upper))
-    self.var_to_range = self.var_to_range + (indirect_range,)
+assert indirect_var not in self.var_to_range
+lower, upper = -upper_bound(size), upper_bound(size) - 1
+indirect_range = (indirect_var, ValueRanges(lower, upper))
+self.var_to_range = self.var_to_range + (indirect_range,)
```

?
sure!
The approach largely looks good. Just one question.
Also, it's a pity that, in general, we have to drop the `is_positive` assumption from SymPy, but it is true that it was incorrect as written.
torch/_inductor/ir.py
Outdated
```
from torch.utils._sympy.symbol import make_symbol

# Note: indirect index variables can be negative, positive or 0.
var = make_symbol(SymT.INDIRECT, len(self.indirect_vars), integer=True)
```
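The difference between the old `positive=True` symbol and an `integer=True` one is visible directly in sympy (a standalone illustration, not inductor code): with `positive=True`, sympy folds comparisons like `indirect0 >= 0` to `True`, which is unsound for an unwrapped index that may be negative.

```python
import sympy

# Old (incorrect) assumption: positive=True lets sympy evaluate
# the comparison even though an unwrapped index may be negative.
i_pos = sympy.Symbol("indirect0", positive=True)
print(i_pos >= 0)  # True

# New assumption: integer only, sign unknown, so the comparison
# stays symbolic and bounds must come from var_to_range instead.
i_int = sympy.Symbol("indirect0", integer=True)
print(i_int >= 0)  # indirect0 >= 0
```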
How can an indirect variable be negative? The wrapping of negative values should happen before the real value is assigned to the symbolic value.
Ah, right, we do the wrapping within `indirect_indexing` in codegen, so I guess the positive condition is correct here.
I think it should be "nonnegative" since we can have zeros
That makes sense, I'll change that back, because we know it must be non-negative after wrapping it.

> wrapping of negative values should happen before the real value is assigned to the symbolic value

Btw, I'm trying to understand what this looks like. Can you help me understand? Thanks!
see the codegen for indirect_index in common.py
Great, thank you!
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 1 job has failed: inductor / linux-jammy-cpu-py3.8-gcc11-inductor / test (inductor_torchbench_cpu_smoketest_perf, 1, 1, linux.24xl.spr-metal). Details for Dev Infra team: raised by workflow job.

@pytorchbot merge -i

Merge started. Your change will be merged while ignoring the following 1 check: inductor / linux-jammy-cpu-py3.8-gcc11-inductor / test (inductor_torchbench_cpu_smoketest_perf, 1, 1, linux.24xl.spr-metal). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 1 job has failed: trunk / linux-focal-cuda12.4-py3.10-gcc9-sm86 / test (default, 3, 5, linux.g5.4xlarge.nvidia.gpu). Details for Dev Infra team: raised by workflow job.

@pytorchbot merge -i

@pytorchbot merge -i
Pull Request resolved: pytorch#128378. Approved by: https://github.com/lezcano, https://github.com/peterbell10
Tries to fix #127677.

Context

Just as @peterbell10 pointed out, we have the following scenario:

```
a = ops.indirect_indexing(...)
b = ops.index_expr(a, ...)
c = ops.indirect_indexing(b, ...)
```

We can repro this as:

```
def forward(self, arg0_1, arg1_1, arg2_1):
    iota = torch.ops.prims.iota.default(arg0_1, start=0, step=1, index=0)
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(arg1_1)
    index = torch.ops.aten.index.Tensor(iota, [repeat_interleave])
    index_1 = torch.ops.aten.index.Tensor(arg2_1, [index])
    return (index_1,)
```

which should generate a JIT py file like this:

```
def triton_poi_fused_index_select_0(in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    ...
    tmp0 = tl.load(in_ptr0 + (x1), xmask, eviction_policy='evict_last')
    tmp1 = ks0
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < 0
    tmp4 = tl.where(tmp3, tmp2, tmp0)
    # check_bounds()
    tl.device_assert(((0 <= tmp4) & (tmp4 < ks0)) | ~(xmask), "index out of bounds: 0 <= tmp4 < ks0")
```

Issue

In our `IndexPropagation.indirect_indexing()` call we have `expr=indirect0`, which is spawned in `LoopBodyBlock.indirect_indexing()` (pytorch/torch/_inductor/ir.py, lines 8154 to 8160 in 3b555ba).

When we try to see if we can prove its bounds, we fail because `indirect0` isn't in `var_ranges`.

Approach

When creating `indirect` symbols from fallback, specify the range to be `[-size, size - 1]` to avoid a lookup error with `indirectX`.

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @desertfire @chauhang