inductor: improve the index range check for index_expr vec check #102263

Closed

pytorch-bot bot commented May 25, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/102263

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c980fcf:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

XiaobingSuper added a commit that referenced this pull request May 25, 2023
ghstack-source-id: 251417a3fc16e87771ec92085bd6c896244ab7eb
Pull Request resolved: #102263
Collaborator

@lezcano lezcano left a comment

I don't think this fix is correct. The issue this is hitting comes from the fact that @peterbell10's optimisation now provides us with better index variables at compile time (which is good), and this uncovers a latent bug in the C++ code generation. In particular, the following code

opt_ctx: OptimizationContext = node_ctx.get_opt_ctx()
assert opt_ctx
max_expr = expr.replace(
    ir.ModularIndexing, mod_indexing_rep
).replace(ir.FloorDiv, indexing_div_rep)
min_expr = max_expr
for idx in range(len(self.ranges)):
    max_expr = sympy.maximum(
        max_expr,
        self.itervars[idx],
        sympy.Interval(0, self.ranges[idx]),
    )
    min_expr = sympy.minimum(
        min_expr,
        self.itervars[idx],
        sympy.Interval(0, self.ranges[idx]),
    )
i32_iinfo = numpy.iinfo(numpy.int32)

is not correct.
Here you are trying to bound the value of a multivariate function on a cube (a product of intervals). In this example, the intervals are [0,7] x [0,7] and the function is i0**2 - 2 * i0 * i1 + i1 ** 2 (the lowering of (i0 - i1)**2). The code tries to find the maximum over this region by calling sympy.maximum once per iteration variable. This fails because sympy.maximum only works for univariate functions.

This optimisation is performed in the triton path in
https://github.com/pytorch/pytorch/blob/main/torch/_inductor/optimize_indexing.py
via an eager algorithm. A proper fix here would lift this logic to common.py and use it on the C++ codegen as well. This should be fairly simple, as the code there works at an IR level.
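To illustrate what an eager bound computation looks like, here is a minimal sketch of interval arithmetic applied to the expression from this example. It is not the code in optimize_indexing.py: the `Range` class and the `const` and `fits_int32` helpers are hypothetical names introduced only for this illustration, and it assumes integer variables on closed intervals.

```python
from dataclasses import dataclass
from itertools import product

INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


@dataclass(frozen=True)
class Range:
    """Closed integer interval [lower, upper] (hypothetical helper)."""

    lower: int
    upper: int

    def __add__(self, other: "Range") -> "Range":
        # Sum of intervals: add the endpoints.
        return Range(self.lower + other.lower, self.upper + other.upper)

    def __mul__(self, other: "Range") -> "Range":
        # Product of intervals: the extremes are attained at corner products.
        corners = [
            a * b
            for a, b in product((self.lower, self.upper), (other.lower, other.upper))
        ]
        return Range(min(corners), max(corners))


def const(k: int) -> Range:
    """Interval containing only the constant k."""
    return Range(k, k)


def fits_int32(r: Range) -> bool:
    """True when every value in the interval is representable as int32."""
    return INT32_MIN <= r.lower and r.upper <= INT32_MAX


# i0 and i1 both range over [0, 7], as in the example above.
i0 = i1 = Range(0, 7)

# Propagate bounds through i0**2 - 2*i0*i1 + i1**2, one operation at a time.
bound = i0 * i0 + const(-2) * i0 * i1 + i1 * i1

# Exact range of (i0 - i1)**2 by brute force, for comparison only.
values = [(a - b) ** 2 for a, b in product(range(8), repeat=2)]
exact = Range(min(values), max(values))

print(bound, exact, fits_int32(bound))
```

The propagated bound is looser than the exact range because each term is bounded independently, but it is always a safe over-approximation, which is all an int32 check needs.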

I have #100549 open that touches that code as well (but doesn't change the API). I'll try to have it merged today so that we don't step on each other.

test/inductor/test_cpu_repro.py
@@ -376,6 +376,22 @@ def fn(a):
        a = torch.randn(1, 3)
        self.common(fn, (a,))

    def test_index_propagation_issue_102065(self):
Collaborator

A slightly more concise repro is the following:

import torch


@torch.compile
def fn(x):
    x = torch.arange(x.numel())
    return (x.unsqueeze(0) - x.unsqueeze(1))**2


fn(torch.randn(8))

Collaborator

Discussed with @XiaobingSuper offline. We can probably extend optimize_indexing.py to reduce the precision of the dtype of index_expr too.

torch/_inductor/index_propagation.py
Fix #102065.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10

[ghstack-poisoned]
XiaobingSuper added a commit that referenced this pull request May 30, 2023
…vec check

ghstack-source-id: be24fae8e6b8ae1e65e8dcf19ae780c530ebd288
Pull Request resolved: #102263
XiaobingSuper added a commit that referenced this pull request May 30, 2023
…vec check

ghstack-source-id: 465b6bc6186d391358a2eedba49e78c7bf4be4a7
Pull Request resolved: #102263
@XiaobingSuper XiaobingSuper marked this pull request as draft May 30, 2023 07:44
XiaobingSuper added a commit that referenced this pull request May 30, 2023
ghstack-source-id: 9afb4e04462ce04841fcbe5ce2f19cae86f45c2c
Pull Request resolved: #102263
@XiaobingSuper XiaobingSuper marked this pull request as ready for review May 30, 2023 14:47
@XiaobingSuper XiaobingSuper changed the title from "inductor: fallback when IndexPropVars is empty" to "inductor: improve the index range check for index_expr vec check" May 30, 2023
XiaobingSuper added a commit that referenced this pull request May 30, 2023
ghstack-source-id: 056d86d54d85945756840c6d1f245fc7287cf342
Pull Request resolved: #102263
XiaobingSuper added a commit that referenced this pull request May 30, 2023
ghstack-source-id: 9df49eddaae1eebe4c3a92c2686f50d8b492c778
Pull Request resolved: #102263
XiaobingSuper added a commit that referenced this pull request May 31, 2023
ghstack-source-id: 7b298e516127342d2713bc8ae86252998115ab27
Pull Request resolved: #102263
    if len(free_symbols) == 0:
        return ValueRanges(expr, expr)

    def replace_symbols_for_deriv(expr, ignore_mod=False):
Collaborator

nit: ignore_mod is never used.

Collaborator Author

removed.

        for k, v in zip(self.itervars, self.ranges)
        if k in free_symbols
    }
    if not vars_ranges:
Collaborator

When would this happen? No free symbols?

Collaborator Author

Yes, there is a test case where expr is s0, which is not in vars_ranges.

Collaborator

Got it. I guess it is ks0 in the case of dynamic shapes? Not necessarily in this PR, but we may consider guarding that the range stays within the int32 range to get more optimizations.

        and expr <= i32_iinfo.max
        and expr >= i32_iinfo.min
    )
    expr_ranges = get_expr_range(expr, vars_ranges)
Collaborator

Does get_expr_range assume that all free symbols in expr exist in vars_ranges? Could expr contain symbols that are not part of the itervars, like tmp vars from indirect indexing?

Collaborator Author

No, get_expr_range does not assume that all free symbols in expr exist in vars_ranges.

XiaobingSuper added a commit that referenced this pull request May 31, 2023
ghstack-source-id: aac5c1e3eb9d2b7967539a7f88d95acd3257fd2f
Pull Request resolved: #102263
@XiaobingSuper XiaobingSuper requested a review from jgong5 May 31, 2023 06:26
        for k, v in zip(self.itervars, self.ranges)
        if k in free_symbols
    }
    if not vars_ranges or len(vars_ranges) != len(free_symbols):
Collaborator

Is this equivalent to any(x.startswith("tmp") for x in free_symbols)? In other words, are we asking here whether the expression has indirect indexing, since we don't have bounds at hand for those?
If that's the case, using this more explicit condition and leaving a comment would make the logic here easier to understand.

Collaborator Author

I don't think it is equivalent to any(x.startswith("tmp") for x in free_symbols). There is a case where expr is just a kernel input (s0), which is not related to itervars and is not buffer indexing.

Collaborator

Right, then same question with startswith("tmp") or startswith("s"). It'd be nice to know more clearly which cases we know how to treat and which we don't.

Collaborator Author

Yes, I agree that it would be better to know more clearly which cases can occur. Let me add this as a TODO.

Collaborator

Approving to unblock, but I think that a better check here would be to make sure that free_symbols is a subset of self.itervars, i.e., we have all the information to compute the upper bound on the index explicitly.
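As a rough illustration of the subset check suggested above, the sketch below uses plain strings as stand-ins for the sympy symbols. The function name `can_bound_index` and the symbol names are hypothetical and are not part of the PR; the point is only to show which of the cases discussed in this thread would pass the check.

```python
def can_bound_index(free_symbols, itervars):
    """Return True when every free symbol of the index expression is an
    iteration variable, so a range is known for each of them.

    Note: an empty set of free symbols (a constant expression) trivially
    passes this check.
    """
    return set(free_symbols) <= set(itervars)


# Loop iteration variables, for which ranges are known (stand-in names).
itervars = ["i0", "i1"]

only_itervars = can_bound_index({"i0", "i1"}, itervars)   # plain loop indexing
kernel_input = can_bound_index({"s0"}, itervars)           # dynamic-shape input
indirect = can_bound_index({"i0", "tmp0"}, itervars)       # indirect indexing

print(only_itervars, kernel_input, indirect)
```

Compared with comparing lengths, the subset form states the intent directly: fall back whenever some symbol has no known range, whatever its origin.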

@XiaobingSuper
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 1, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).
