inductor: improve the index range check for index_expr vec check #102263
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/102263
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit c980fcf. This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: 251417a3fc16e87771ec92085bd6c896244ab7eb Pull Request resolved: #102263
I don't think this fix is correct. The issue it is hitting comes from the fact that @peterbell10's optimisation now provides us with better index variables at compile time (which is good), and this uncovers a latent bug in the C++ code generation. In particular, the following code
pytorch/torch/_inductor/codegen/cpp.py, lines 1992 to 2009 in 4882cd0:

```python
opt_ctx: OptimizationContext = node_ctx.get_opt_ctx()
assert opt_ctx
max_expr = expr.replace(
    ir.ModularIndexing, mod_indexing_rep
).replace(ir.FloorDiv, indexing_div_rep)
min_expr = max_expr
for idx in range(len(self.ranges)):
    max_expr = sympy.maximum(
        max_expr,
        self.itervars[idx],
        sympy.Interval(0, self.ranges[idx]),
    )
    min_expr = sympy.minimum(
        min_expr,
        self.itervars[idx],
        sympy.Interval(0, self.ranges[idx]),
    )
i32_iinfo = numpy.iinfo(numpy.int32)
```
is not correct.

Here you are trying to bound the value of a multivariate function on a cube (a product of intervals). In this example, the intervals are `[0, 7] x [0, 7]` and the function is `i0**2 - 2*i0*i1 + i1**2` (that's the lowering of `(i0 - i1)**2`). The code there tries to find the maximum over this region by calling `sympy.maximum`. This fails because `sympy.maximum` only works for univariate functions.
This optimisation is performed on the Triton path in
https://github.com/pytorch/pytorch/blob/main/torch/_inductor/optimize_indexing.py
via an eager algorithm. A proper fix here would lift this logic into common.py and use it in the C++ codegen as well. This should be fairly simple, as the code there works at the IR level.
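The eager, interval-arithmetic style of bounding that the Triton path uses (and which, unlike `sympy.maximum`, handles multivariate expressions on a box) can be illustrated with a standalone sketch. This is an illustrative toy, not inductor's actual `ValueRanges` implementation; the `ValueRange` class and its operators here are hypothetical:

```python
# Toy sketch of interval (value-range) arithmetic: bound a multivariate
# expression over a box by propagating [lower, upper] intervals through
# each operation, instead of symbolically maximising with sympy.
from dataclasses import dataclass


@dataclass(frozen=True)
class ValueRange:
    lower: float
    upper: float

    def __add__(self, other):
        # Sum of intervals: add the endpoints.
        return ValueRange(self.lower + other.lower, self.upper + other.upper)

    def __mul__(self, other):
        # Product of intervals: min/max over all endpoint products.
        products = [a * b
                    for a in (self.lower, self.upper)
                    for b in (other.lower, other.upper)]
        return ValueRange(min(products), max(products))


# Bound i0**2 - 2*i0*i1 + i1**2 over the box [0, 7] x [0, 7]:
i0 = ValueRange(0, 7)
i1 = ValueRange(0, 7)
minus_two = ValueRange(-2, -2)
bound = i0 * i0 + minus_two * i0 * i1 + i1 * i1
print(bound)  # ValueRange(lower=-98, upper=98)
```

The result is conservative (the true range of `(i0 - i1)**2` on this box is `[0, 49]`, which the computed `[-98, 98]` contains), but a sound over-approximation is all that is needed to decide whether an index fits in int32.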
I have #100549 open that touches that code as well (but doesn't change the API). I'll try to have it merged today so that we don't step on each other.
```
@@ -376,6 +376,22 @@ def fn(a):
        a = torch.randn(1, 3)
        self.common(fn, (a,))

    def test_index_propagation_issue_102065(self):
```
A slightly more concise repro is the following:

```python
import torch

@torch.compile
def fn(x):
    x = torch.arange(x.numel())
    return (x.unsqueeze(0) - x.unsqueeze(1))**2

fn(torch.randn(8))
```
Discussed with @XiaobingSuper offline. We can probably extend optimize_indexing.py to reduce the precision of the dtype of `index_expr` too.
torch/_inductor/optimize_indexing.py (outdated):

```python
if len(free_symbols) == 0:
    return ValueRanges(expr, expr)

def replace_symbols_for_deriv(expr, ignore_mod=False):
```
nit: `ignore_mod` is never used.
removed.
torch/_inductor/codegen/cpp.py (outdated):

```python
    for k, v in zip(self.itervars, self.ranges)
    if k in free_symbols
}
if not vars_ranges:
```
When would this happen? No free symbols?
Yes, there is a test case where the expr is `s0`, which is not in `vars_ranges`.
Got it. I guess it would be `ks0` in the dynamic-shape case? Not necessarily in this PR, but we may consider guarding the range to be within the int32 range to get more optimizations.
```python
    and expr <= i32_iinfo.max
    and expr >= i32_iinfo.min
)
expr_ranges = get_expr_range(expr, vars_ranges)
```
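The int32 guard above boils down to comparing an expression's computed bounds against `numpy.iinfo(numpy.int32)`. A minimal sketch of that check, with a hypothetical `fits_int32` helper that is not part of the PR:

```python
# Sketch: an index expression can be emitted with an int32 dtype only if
# its whole value range [lo, hi] stays within the int32 limits.
import numpy

i32_iinfo = numpy.iinfo(numpy.int32)


def fits_int32(lo, hi):
    # Hypothetical helper: True iff [lo, hi] is contained in
    # [int32 min, int32 max].
    return i32_iinfo.min <= lo and hi <= i32_iinfo.max


print(fits_int32(0, 49))     # True: fits comfortably in int32
print(fits_int32(0, 2**31))  # False: exceeds int32 max (2**31 - 1)
```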
Does `get_expr_range` assume that all free symbols in `expr` exist in `vars_ranges`? Could `expr` contain symbols that are not part of the itervars, like `tmp` vars from indirect indexing?
No, there is no such assumption that all free symbols in `expr` exist in `vars_ranges`.
```python
    for k, v in zip(self.itervars, self.ranges)
    if k in free_symbols
}
if not vars_ranges or len(vars_ranges) != len(free_symbols):
```
Is this equivalent to `any(x.startswith("tmp") for x in free_symbols)`? In other words, are we asking here whether the expression has indirect indexing, since we don't have bounds at hand for those symbols?
If that's the case, using this other, more explicit condition and leaving a comment would help readers understand the logic here.
I don't think it is equivalent to `any(x.startswith("tmp") for x in free_symbols)`. There is a case where the expr is just a kernel input (`s0`), which is not related to `itervars` and is not buffer indexing.
Right, then the same question applies with `startswith("tmp") or startswith("s")`. It'd be nice to know more clearly which cases we know how to treat and which we don't.
Yes, I agree that knowing more clearly which cases can occur would be better. Let me add this as a TODO.
Approving to unblock, but I think a better check here would be to make sure that `free_symbols` is a subset of `self.itervars`, i.e., that we have all the information needed to compute the upper bound on the index explicitly.
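The suggested subset check could look something like the following sketch. This is an illustration of the idea only, not the actual PR code; the `all_symbols_known` helper and the concrete symbol names are hypothetical:

```python
# Sketch: instead of comparing dict lengths, check directly that every
# free symbol of the index expression is one of the loop itervars, so we
# are guaranteed to have a range for each symbol we need to bound.
import sympy

i0, i1, s0 = sympy.symbols("i0 i1 s0")
itervars = [i0, i1]  # loop induction variables with known ranges


def all_symbols_known(expr, itervars):
    # True iff expr's free symbols are a subset of the itervars.
    return expr.free_symbols <= set(itervars)


print(all_symbols_known((i0 - i1) ** 2, itervars))  # True: only itervars
print(all_symbols_known(s0 + i0, itervars))         # False: s0 is a kernel input
```

This phrasing also handles both problem cases discussed above (`tmp` symbols from indirect indexing and kernel inputs like `s0`) with a single explicit condition.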
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…orch#102263) Fix pytorch#102065. Pull Request resolved: pytorch#102263 Approved by: https://github.com/lezcano, https://github.com/peterbell10, https://github.com/jgong5
Stack from ghstack (oldest at bottom):
Fix #102065.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10