inductor: tighter upper bound for rblock scaling #109839
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109839
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit ee9e5a5 with merge base 2c1554a. This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: 9ca6707fadbe1ef4a319822dccf1325721bd6055 Pull Request resolved: #109839
# make sure rblock is not too small
if rblock <= 64:
    continue
Why do we want to do this before the rest of the logic?
Generally it is good to have simple checks first, so the later code can assume rblock > 64.
# < 65536 / ((65536 / max_threads_per_multi_processor) * 32 * num_warps)
# = max_threads_per_multi_processor / (32 * num_warps)
# Using a tighter upper bound can reveal more optimization opportunities.
max_blocks_per_sm = max(65536 // nreg_per_block, 1)
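As a rough illustration of why this bound is tighter, the following sketch compares the old thread-count bound with the new register-file bound. All concrete values (`2048` threads per SM, `8` warps, `64` registers per thread) are assumptions for illustration, not read from the PR or the driver:

```python
# Hypothetical per-SM limits, roughly an A100-class GPU (assumed values).
max_threads_per_multi_processor = 2048
num_warps = 8          # warps per block in the triton config
nreg_per_thread = 64   # registers used by each thread (assumed)

nreg_per_block = nreg_per_thread * 32 * num_warps  # 32 threads per warp

# Old, looser bound: only limited by how many threads fit on one SM.
old_bound = max_threads_per_multi_processor // (32 * num_warps)

# New, tighter bound: limited by the 65536-entry register file per SM.
new_bound = max(65536 // nreg_per_block, 1)

print(old_bound, new_bound)  # the register bound is the binding one here
```

With these numbers the register file, not the thread count, limits residency, which is exactly the case the tighter bound is meant to expose.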
Here, if 65536 // nreg_per_block < 1, there is going to be register spillage, so we probably want to decrease the RBLOCK as well.
Actually, provided that there is no register spilling, we have the tighter bound:
max_warps_per_sm = 64
max_blocks_per_sm = min(65536 // nreg_per_block, max_warps_per_sm // triton_conf.num_warps)
This would also give that if 65536 // nreg_per_block < max_warps_per_sm // triton_conf.num_warps and nreg < 255 // 2, we can probably increase the RBLOCK and increase the efficiency.
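The bound suggested in this comment can be sketched as a small helper. This is a sketch of the reviewer's proposal, assuming no register spills and `max_warps_per_sm = 64` as stated above; the example inputs are made up:

```python
def max_blocks_per_sm(nreg_per_block: int, num_warps: int) -> int:
    """Suggested tighter bound: the minimum of the register-file limit
    and the warp-scheduler limit (assuming no register spills)."""
    max_warps_per_sm = 64  # per the comment above
    return min(max(65536 // nreg_per_block, 1),
               max_warps_per_sm // num_warps)

# A block using 16384 registers with num_warps=8 is register-limited;
# a block using 512 registers is warp-limited (assumed values).
print(max_blocks_per_sm(16384, 8))
print(max_blocks_per_sm(512, 8))
```

Whichever term of the `min` is smaller tells you which resource caps occupancy for that config.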
Here, if 65536 // nreg_per_block < 1, there is going to be register spillage, so we probably want to decrease the RBLOCK as well.

So I use max(65536 // nreg_per_block, 1) just to be safe. nreg_per_block is actually the register usage after considering register spills; the spilled registers are not counted in nreg_per_block. So if 65536 // nreg_per_block < 1, the kernel probably cannot be run at all.
Register spill is actually handled implicitly here. When registers spill, each thread uses 255 registers (the maximum number of registers a thread can use). In that case, there is a large chance of triggering the decrease of RBLOCK.
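A sketch of the implicit spill handling described above: when a kernel spills, each thread reports the per-thread maximum of 255 registers, which collapses the register-file bound. The `num_warps` value is an assumption for illustration:

```python
num_warps = 8
nreg_per_thread = 255  # a spilling kernel reports the per-thread maximum
nreg_per_block = nreg_per_thread * 32 * num_warps  # 32 threads per warp

# With ~65K registers consumed by a single block, at most one block fits
# per SM, which very likely triggers the RBLOCK-decreasing path.
max_blocks_per_sm = max(65536 // nreg_per_block, 1)
print(max_blocks_per_sm)
```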
max_blocks_per_sm = min(65536 // nreg_per_block, max_warps_per_sm // triton_conf.num_warps)
This is equivalent to the current code. My comment in the code shows why 65536 // nreg_per_block < max_warps_per_sm // triton_conf.num_warps holds here.
Scaling up RBLOCK can do harm since it increases register usage. Even if there are no register spills, large register usage can still hurt occupancy and may result in under-performing configs.
When registers spill, each thread uses 255 registers (the maximum number of registers a thread can use). In that case, there is a large chance of triggering the decrease of RBLOCK.

Why don't we handle it explicitly now that we are at it?
Here is the perf test result: link
So the PR improves things in general, but the overall metric is neutral.
@pytorchbot merge
Merge failed. Reason: This PR needs a `release notes:` label. If not applicable, please add the `topic: not user facing` label. To add a label, you can comment to pytorchbot, for example `@pytorchbot label "topic: not user facing"`. (Details for the Dev Infra team: raised by workflow job.)
@pytorchbot label "topic: not user facing"
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
Previously, when deciding whether to dynamically scale down rblock, we used the following formula to compute the upper bound on the number of blocks per SM: max_threads_per_multi_processor / (32 * num_warps).
This is correct, but it is a bit loose, and because of the loose upper bound we sometimes skip optimization opportunities.
The new upper bound is 65536 / n_reg_used_by_each_block. This is a tighter upper bound and can be helpful if the kernel uses many registers per thread (i.e. much larger than 32).
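To illustrate when the new bound actually bites, the sketch below compares the two bounds for a light and a heavy kernel. The concrete numbers (2048 threads per SM, 8 warps, 32 vs 128 registers per thread) are assumptions, not measurements from the PR:

```python
def old_bound(max_threads_per_sm: int, num_warps: int) -> int:
    # Previous bound: limited only by how many threads fit on one SM.
    return max_threads_per_sm // (32 * num_warps)

def new_bound(nreg_per_thread: int, num_warps: int) -> int:
    # This PR's bound: limited by the 65536-entry register file per SM.
    nreg_per_block = nreg_per_thread * 32 * num_warps
    return max(65536 // nreg_per_block, 1)

# At ~32 registers per thread the two bounds agree; well above 32
# registers, the new bound is strictly tighter (assumed values).
for nreg in (32, 128):
    print(nreg, old_bound(2048, 8), new_bound(nreg, 8))
```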
For kernel https://gist.github.com/shunting314/59aeafd297ed8ff03aa12030a2dd41ae (a real kernel inductor generates for HF), the change improves its perf from
0.485ms 0.332GB 684.29GB/s
to
0.240ms 0.332GB 1382.70GB/s.
The perf was bad previously because of register spills.
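As a sanity check on the reported numbers, the bandwidth figures are consistent with bytes moved divided by runtime (the small discrepancies come from rounding of the 0.332GB figure):

```python
gb_moved = 0.332              # bytes moved, as reported
before_gbps = gb_moved / 0.485e-3   # ≈ 684 GB/s before the change
after_gbps = gb_moved / 0.240e-3    # ≈ 1383 GB/s after the change
speedup = 0.485 / 0.240             # ≈ 2x end-to-end for this kernel
print(before_gbps, after_gbps, speedup)
```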
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov