
[inductor] Fix nested indirect indexing case for index_propagation #128378

Closed
wants to merge 7 commits

Conversation


@ColinPeppler ColinPeppler commented Jun 11, 2024

Tries to fix #127677.

Context

Just as @peterbell10 pointed out, we have the following scenario:

```
a = ops.indirect_indexing(...)
b = ops.index_expr(a, ...)
c = ops.indirect_indexing(b, ...)
```

We can repro this as:

```
def forward(self, arg0_1, arg1_1, arg2_1):
    iota = torch.ops.prims.iota.default(arg0_1, start = 0, step = 1, index=0);
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(arg1_1);
    index = torch.ops.aten.index.Tensor(iota, [repeat_interleave]);
    index_1 = torch.ops.aten.index.Tensor(arg2_1, [index]);
    return (index_1,)
```

which should generate a JIT py file like this:

```
def triton_poi_fused_index_select_0(in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    ...
    tmp0 = tl.load(in_ptr0 + (x1), xmask, eviction_policy='evict_last')
    tmp1 = ks0
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < 0
    tmp4 = tl.where(tmp3, tmp2, tmp0)
    # check_bounds()
    tl.device_assert(((0 <= tmp4) & (tmp4 < ks0)) | ~(xmask), "index out of bounds: 0 <= tmp4 < ks0")

def call():
    arg0_1, arg1_1, arg2_1 = args
    buf1 = aten.repeat_interleave.Tensor(arg1_1)
    buf4 = empty_strided_cuda((u0, 64), (64, 1))
    triton_poi_fused_index_select_0.run(
        buf1, arg2_1, buf4, s0,
        triton_poi_fused_index_select_0_xnumel,
        grid=grid(triton_poi_fused_index_select_0_xnumel),
        stream=stream0)
```
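The tmp2/tmp3/tmp4 sequence in the kernel is standard negative-index wrapping. A minimal per-element Python sketch of the same logic (`wrap_index` is a hypothetical helper, not an Inductor API):

```python
def wrap_index(idx: int, size: int) -> int:
    # tmp2 = tmp0 + tmp1; tmp3 = tmp0 < 0; tmp4 = tl.where(tmp3, tmp2, tmp0)
    wrapped = idx + size if idx < 0 else idx
    # check_bounds(): mirrors the tl.device_assert in the kernel
    assert 0 <= wrapped < size, f"index out of bounds: 0 <= {wrapped} < {size}"
    return wrapped
```

For example, with `size=8`, `wrap_index(-1, 8)` normalizes to `7`, while an out-of-range index such as `9` trips the assert.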

Issue

In our IndexPropagation.indirect_indexing() call we have expr=indirect0 which is spawned in LoopBodyBlock.indirect_indexing().

pytorch/torch/_inductor/ir.py, lines 8154 to 8160 (at 3b555ba):
```
def indirect_indexing(index_proxy, size, check=True):
    """
    Flow data from tensors into indexing formulas.
    Introduce a call_module to update the indexing.
    """
    var = self.body.add_indirect(size)
```

When we try to see if we can prove its bounds, we fail because indirect0 isn't in var_ranges.

Approach

When creating `indirect` symbols from the fallback, give them the range `[-size, size - 1]` so the lookup for `indirectX` no longer errors.
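A rough sketch of that bookkeeping, assuming a plain dict as a stand-in for the shape env's symbol-to-range map (the real `ShapeEnv` structures differ):

```python
import sympy

# Hypothetical stand-in for the shape env's symbol -> range map.
var_to_range = {}

def add_indirect(idx: int, size: int) -> sympy.Symbol:
    """Create a fresh indirectN symbol and record its pre-wrapping range
    [-size, size - 1], so later bound queries don't hit a lookup error."""
    var = sympy.Symbol(f"indirect{idx}", integer=True)
    assert var not in var_to_range  # fallback symbols are always fresh
    var_to_range[var] = (-size, size - 1)
    return var

v = add_indirect(0, 32)
```

After this, a bounds query on `indirect0` can be answered from `var_to_range[v]` instead of raising a lookup error.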

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @desertfire @chauhang


pytorch-bot bot commented Jun 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128378

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 2 Cancelled Jobs, 3 Unrelated Failures

As of commit 6bcfd5e with merge base c58d3af:

NEW FAILURES - The following jobs have failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

UNSTABLE - The following jobs failed, likely due to flakiness present on trunk, and have been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ColinPeppler added a commit that referenced this pull request Jun 11, 2024
ghstack-source-id: df9a4bf82658f0fe1f6d1ed4b821477c1889bdc5
Pull Request resolved: #128378
@pytorch-bot pytorch-bot bot added module: inductor release notes: fx release notes category labels Jun 11, 2024
ColinPeppler added a commit that referenced this pull request Jun 11, 2024
ghstack-source-id: 38bbf45abec7a8b63e723ecf2bc48d6b3815bfa7
Pull Request resolved: #128378
@ColinPeppler
Contributor Author

So, if we still include the bounds check instead of returning expr immediately, we'll have two bounds checks.

```
tmp4 = tl.where(tmp3, tmp2, tmp0)
# One from the codegen pass
tl.device_assert(((0 <= tmp4) & (tmp4 < ks0)) | ~(xmask), "index out of bounds: 0 <= tmp4 < ks0")
# One from IndexPropagation -- this one seems like the stricter bounds check (but safer?).
tl.device_assert(((0 <= tmp4) & (tmp4 < 32)) | ~(xmask), "index out of bounds: 0 <= tmp4 < 32")
```

Thinking about whether it makes sense to avoid the bounds check from the codegen pass and use the one from IndexPropagation.

Comment on lines 328 to 330:
```
# a = ops.indirect_indexing(...)
# b = ops.index_expr(a, ...)
# c = ops.indirect_indexing(b, ...)
```
Collaborator

In this case, it doesn't seem like we can meaningfully propagate a sympy expression, so don't we want to fall back?

Collaborator

No, we should be able to propagate the sympy expression without issue. We should just fix the _maybe_evaluate_static call so it can handle indirect0.

Ideally we would keep track of the bounds for ops.indirect_indexing since these do have known bounds. But at the very least we should treat it as unbounded.

Collaborator

For example

```
a = ops.indirect_indexing(...)
b = ops.index_expr(a, torch.int32)
c = ops.add(b, ops.constant(1, torch.int32))
d = ops.indirect_indexing(c, ...)
```

`d = indirect1` even though we know that `indirect1 = indirect0 + 1`.

We can simplify to:

```
a = ops.indirect_indexing(...)
d = a + 1
```

where `d = indirect0 + 1`.
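In sympy terms, the benefit of keeping `d` as a transparent expression (a hedged sketch, not Inductor's actual representation) is that known bounds on `indirect0` transfer to `d` directly:

```python
import sympy

indirect0 = sympy.Symbol("indirect0", integer=True)

# Opaque fallback: d would be a fresh indirect1 with no visible relation
# to indirect0. Propagated form: the +1 stays visible to symbolic reasoning.
d = indirect0 + 1

# If 0 <= indirect0 <= size - 1, then 1 <= d <= size follows immediately.
size = 32
assert d.subs(indirect0, 0) == 1
assert d.subs(indirect0, size - 1) == size
```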

Collaborator

To do this we should:

  • Make sure that the shapes of the two indirect_indexings are the same
  • Make sure we are not wrapping the variable twice (I think we may be currently wrapping it twice).

It's not clear to me whether we can do the first thing in the IndexPropagation pass, though.

Collaborator

We should propagate the bounds from the first indirect0 into the reasoning, so we might be able to elide the wrapping.

In general though, we might need to have two wrappings and two different sizes because each ops.indirect_indexing call might have come from a different tensor which may have different shapes and each are wrapped independently.

Collaborator

indirect0 variables barely ever have bounds, as loads do not have bounds, and indirect indexing is often performed on information that comes from a load.

But yeah, I guess that in general it makes sense for the first indirect to be wrapped and so on at a codegen level, and then we can have the second indirect wrapped here.

Collaborator

> indirect0 variables barely ever have bounds

On the contrary, indirect0 always has a bound. ops.indirect_indexing(var, size) has the bound [0, size) as checked by the runtime assert.

Collaborator

ugh, right.

@lezcano lezcano left a comment (Collaborator)

Just use wrap_expr, but everything else SGTM

```
# a = ops.indirect_indexing(...)
# b = ops.index_expr(a, ...)
# c = ops.indirect_indexing(b, ...)
return Where(expr < 0, expr + size, expr)
```
Collaborator

Use `wrap_expr` to cover the (rare) case when we can do better than wrapping.

@lezcano lezcano left a comment (Collaborator)

Actually, we still need to emit the check_bounds here unless we can prove that the sizes are the same, no? And in that case, we probably also want to emit it and just CSE it at codegen level.

…pagation"

# Approach
We can probably skip the check to prove `indirect0`'s bounds, because its purpose is to add a bounds check as a device-side assert. Thankfully, we already do this in the codegen pass:
https://github.com/pytorch/pytorch/blob/3b555ba47713d489975a9bb6cb6c31975f805e3f/torch/_inductor/codegen/common.py#L1730-L1733

[ghstack-poisoned]
ColinPeppler added a commit that referenced this pull request Jun 12, 2024
ghstack-source-id: 2695b4716a8bec7087b5e27fceb1986af97f009a
Pull Request resolved: #128378
@ColinPeppler
Contributor Author

We can't prove the sizes are the same, i.e. `s0 == 32`.

So, now we'll emit another check_bounds() which gives us two of them:

```
# First check_bounds() from torch.arange(repeats.numel())
tl.device_assert(((0 <= tmp4) & (tmp4 < ks0)) | ~(xmask), "index out of bounds: 0 <= tmp4 < ks0")
# Second check_bounds() from torch.index_select(x ...)
tl.device_assert((tmp4 < 32) | ~(xmask), "index out of bounds: tmp4 < 32")
```

IMO the double check_bounds on tmp4 does feel a bit weird. If anything, the second check is all that's necessary, but I'm unsure how to prevent the first check.

```
# b = ops.index_expr(a, ...)
# c = ops.indirect_indexing(b, ...)
indirect_var_to_default_range = tuple(
    (sym, self.shape_env._default_unspecified_value_range())
```
Collaborator

So, in the pattern above, you are going to hit this function twice. In the first call you will not be able to remove the indirect_indexing, but you will be able to set proper ranges for it. Then, in the second call, the ranges will already be set, and you can just use them.

The invariant here is that, if you fall back, you will be able to get the result of the fallback (which is a sympy.Symbol) and add its range as [-size, size - 1] (as it's not wrapped at this point). Then any call after it should work as expected.
You can even assert that the returned symbol is not in var_to_range before adding it.

@ColinPeppler
Contributor Author

ColinPeppler commented Jun 12, 2024

I realized that indirect0 was initially created with the assumption that it's non-negative. This means 0 <= expr will always evaluate to true, even though it can be negative (at least until wrap_expr).

```
def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol:
    """
    Used to generate an integer-nonnegative symbol.
    """
    # This should never be used for creating shape/stride symbols, as those
    # should all be allocated before Inductor.
    assert prefix != SymT.SIZE
    # NOTE: shape symbols are positive (> 0), but index variables are only
    # non-negative (>= 0).
    return make_symbol(prefix, idx, integer=True, nonnegative=True)
```

So, I tossed that assumption out since indirect0 can be negative until it gets wrapped.
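The effect of that assumption is easy to see directly in sympy: with `nonnegative=True` the relational auto-evaluates, so the bounds check is vacuously "proved" even before wrapping; with the assumption dropped it stays symbolic and the recorded range has to decide it. A sketch using plain `sympy.Symbol` rather than Inductor's `make_symbol`:

```python
import sympy

# With the old assumption, 0 <= indirect0 collapses to True unconditionally:
assumed = sympy.Symbol("indirect0", integer=True, nonnegative=True)
assert (assumed >= 0) == sympy.true

# Without it, the relation stays symbolic, so the [-size, size - 1] range
# recorded in var_to_range is what answers the bounds question:
plain = sympy.Symbol("indirect0", integer=True)
assert (plain >= 0) != sympy.true
```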

…pagation"

# Approach
Add `indirectX` symbols with a default range (-inf, +inf) to `self.var_to_range` to avoid a lookup error with `indirectX`.

[ghstack-poisoned]
ColinPeppler added a commit that referenced this pull request Jun 12, 2024
ghstack-source-id: 7199ae4a66ff5132e40b4172229c65a4cfd08e66
Pull Request resolved: #128378
ColinPeppler added a commit that referenced this pull request Jun 12, 2024
ghstack-source-id: e242befc8ff5528c1431eb9b1e7d1267db47869b
Pull Request resolved: #128378
Comment on lines 353 to 356:
```
if indirect_var not in self.var_to_range:
    lower, upper = -upper_bound(size), upper_bound(size) - 1
    indirect_range = (indirect_var, ValueRanges(lower, upper))
    self.var_to_range = self.var_to_range + (indirect_range,)
```
Collaborator

Can this be turned into

Suggested change:
```
assert indirect_var not in self.var_to_range
lower, upper = -upper_bound(size), upper_bound(size) - 1
indirect_range = (indirect_var, ValueRanges(lower, upper))
self.var_to_range = self.var_to_range + (indirect_range,)
```

?

Contributor Author

sure!

@lezcano lezcano left a comment (Collaborator)

The approach largely looks good. Just one question.

Also, it's a pity that, in general, we have to drop the is_positive assumption from SymPy, but it is true that it was incorrect as written.

```
from torch.utils._sympy.symbol import make_symbol

# Note: indirect index variables can be negative, positive or 0.
var = make_symbol(SymT.INDIRECT, len(self.indirect_vars), integer=True)
```
@peterbell10 peterbell10 commented Jun 13, 2024 (Collaborator)

How can an indirect variable be negative? The wrapping of negative values should happen before the real value is assigned to the symbolic value.

Collaborator

Ah, right, we do the wrapping within indirect_indexing in codegen, so I guess the positive condition is correct here.

Collaborator

I think it should be "nonnegative" since we can have zeros

Contributor Author

That makes sense, I'll change that back because we know it must be non-negative after wrapping it.

> wrapping of negative values should happen before the real value is assigned to the symbolic value

Btw I'm trying to understand what this looks like. Can you help me understand? Thanks!

Collaborator

see the codegen for indirect_index in common.py

ColinPeppler added a commit that referenced this pull request Jun 13, 2024
ghstack-source-id: 573b4ef35b0ead2fc080497050d44cc317d5b141
Pull Request resolved: #128378
ColinPeppler added a commit that referenced this pull request Jun 13, 2024
ghstack-source-id: 469e53f432193ad75c0aa035405610ae3a040c4b
Pull Request resolved: #128378
@lezcano lezcano left a comment (Collaborator)

Great, thank you!

@lezcano lezcano added release notes: inductor and removed release notes: fx release notes category labels Jun 14, 2024
@lezcano
Collaborator

lezcano commented Jun 14, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 14, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: inductor / linux-jammy-cpu-py3.8-gcc11-inductor / test (inductor_torchbench_cpu_smoketest_perf, 1, 1, linux.24xl.spr-metal)

Details for Dev Infra team Raised by workflow job

@lezcano
Collaborator

lezcano commented Jun 14, 2024

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: inductor / linux-jammy-cpu-py3.8-gcc11-inductor / test (inductor_torchbench_cpu_smoketest_perf, 1, 1, linux.24xl.spr-metal)


@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-focal-cuda12.4-py3.10-gcc9-sm86 / test (default, 3, 5, linux.g5.4xlarge.nvidia.gpu)


@lezcano
Collaborator

lezcano commented Jun 14, 2024

@pytorchbot merge -i

@lezcano lezcano closed this Jun 14, 2024
@lezcano lezcano reopened this Jun 14, 2024
@lezcano
Collaborator

lezcano commented Jun 14, 2024

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 15 checks: inductor / linux-jammy-cpu-py3.8-gcc11-inductor / test (inductor_torchbench_cpu_smoketest_perf, 1, 1, linux.24xl.spr-metal), inductor / cuda12.1-py3.10-gcc9-sm80 / test (inductor_torchbench_smoketest_perf, 1, 1, linux.gcp.a100), inductor-periodic / cuda12.4-py3.10-gcc9-sm80 / test (inductor_torchbench_smoketest_perf, 1, 1, linux.gcp.a100), linux-binary-manywheel / manywheel-py3_8-cuda12_4-test / test, linux-binary-manywheel / manywheel-py3_8-cuda12_1-test / test, linux-binary-manywheel / manywheel-py3_8-cuda11_8-test / test, trunk / linux-focal-rocm6.1-py3.8 / test (default, 1, 2, linux.rocm.gpu), trunk / linux-focal-rocm6.1-py3.8 / test (default, 2, 2, linux.rocm.gpu), trunk / linux-focal-rocm6.1-py3.8 / test (distributed, 1, 1, linux.rocm.gpu), trunk / win-vs2019-cpu-py3 / test (default, 1, 3, windows.4xlarge.nonephemeral), trunk / linux-focal-cuda12.4-py3.10-gcc9-sm86 / test (default, 1, 5, linux.g5.4xlarge.nvidia.gpu), trunk / linux-focal-cuda12.4-py3.10-gcc9-sm86 / test (default, 2, 5, linux.g5.4xlarge.nvidia.gpu), trunk / linux-focal-cuda12.4-py3.10-gcc9-sm86 / test (default, 3, 5, linux.g5.4xlarge.nvidia.gpu), trunk / linux-focal-cuda12.4-py3.10-gcc9-sm86 / test (default, 4, 5, linux.g5.4xlarge.nvidia.gpu), trunk / linux-focal-cuda12.4-py3.10-gcc9-sm86 / test (default, 5, 5, linux.g5.4xlarge.nvidia.gpu)


ignaciobartol pushed a commit to ignaciobartol/pytorch that referenced this pull request Jun 14, 2024
…ytorch#128378)

Pull Request resolved: pytorch#128378
Approved by: https://github.com/lezcano, https://github.com/peterbell10