
inductor: tighter upper bound for rblock scaling #109839

Closed

Conversation


@shunting314 shunting314 commented Sep 22, 2023

Stack from ghstack (oldest at bottom):

Previously, when deciding whether to dynamically scale down rblock, we used the following formula to compute the upper bound on the number of blocks per SM:

max_threads_per_multi_processor / (32 * num_warps)

This is correct, but it is a bit loose, and sometimes the loose upper bound makes us skip optimization opportunities.

The new upper bound is 65536 / n_reg_used_by_each_block. This is a tighter upper bound and can be helpful when the kernel uses many registers (i.e., much more than 32 per thread).
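For illustration, here is a rough sketch of the two bounds (this is not the actual inductor code; the 2048 threads per SM and the register/warp counts below are assumed example values):

# Sketch only: contrasts the old and new upper bounds described above.
def old_max_blocks_per_sm(max_threads_per_multi_processor, num_warps):
    # loose bound: limited only by how many threads an SM can run concurrently
    return max_threads_per_multi_processor // (32 * num_warps)

def new_max_blocks_per_sm(nreg_per_block):
    # tighter bound: limited by the 65536 registers available per SM
    return max(65536 // nreg_per_block, 1)

# assumed example: 8 warps per block, 100 registers per thread
num_warps = 8
nreg_per_block = 100 * 32 * num_warps            # 25600 registers per block
print(old_max_blocks_per_sm(2048, num_warps))    # 8 blocks (loose)
print(new_max_blocks_per_sm(nreg_per_block))     # 2 blocks (tight)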

For kernel https://gist.github.com/shunting314/59aeafd297ed8ff03aa12030a2dd41ae (a real kernel inductor generates for HF), the change improves its perf from
0.485ms 0.332GB 684.29GB/s
to
0.240ms 0.332GB 1382.70GB/s.

The perf was bad previously because of register spills.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov


pytorch-bot bot commented Sep 22, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109839

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ee9e5a5 with merge base 2c1554a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

shunting314 added a commit that referenced this pull request Sep 22, 2023
ghstack-source-id: 9ca6707fadbe1ef4a319822dccf1325721bd6055
Pull Request resolved: #109839
Comment on lines +202 to +204
# make sure rblock is not too small
if rblock <= 64:
continue
Collaborator

Why do we want to do this before the rest of the logic?

Contributor Author

Generally good to have simple checks first so later code can assume rblock > 64

# < 65536 / ((65536 / max_threads_per_multi_processor) * 32 * num_warps)
# = max_threads_per_multi_processor / (32 * num_warps)
# Using a tighter upper bound can reveal more optimization opportunities.
max_blocks_per_sm = max(65536 // nreg_per_block, 1)
Collaborator

Here, if 65536 // nreg_per_block < 1 there is going to be register spillage, so we probably want to decrease the RBLOCK as well.

Collaborator

Actually, provided that there is no register spilling, we have the tighter bound

max_warps_per_sm = 64
max_blocks_per_sm = min(65536 // nreg_per_block, max_warps_per_sm // triton_conf.num_warps)

This would also suggest that if 65536 // nreg_per_block > max_warps_per_sm // triton_conf.num_warps and nreg < 255 // 2, we can probably increase RBLOCK and improve efficiency.

Contributor Author

Here, if 65536 // nreg_per_block < 1 there is going to be register spillage, so we probably want to decrease the RBLOCK as well.

So I use max(65536 // nreg_per_block, 1) just to be safe. nreg_per_block is actually the register usage after accounting for register spills; the spilled registers are not counted in nreg_per_block. So if 65536 // nreg_per_block < 1, the kernel probably would not be able to run at all.

Register spill is actually handled implicitly here. When registers spill, each thread uses 255 registers (the maximum number of registers a thread can use). In that case, there is a large chance we trigger the decrease of RBLOCK.
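A rough worked example of that implicit handling (the num_warps value here is just an assumption for illustration):

# when a kernel spills, the reported per-thread usage hits the 255-register cap
num_warps = 8
nreg_per_block = 255 * 32 * num_warps                # 65280 registers per block
max_blocks_per_sm = max(65536 // nreg_per_block, 1)  # 65536 // 65280 == 1
# with only 1 block per SM, the heuristic is very likely to scale RBLOCK down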

Contributor Author

max_blocks_per_sm = min(65536 // nreg_per_block, max_warps_per_sm // triton_conf.num_warps)

This is equivalent to the current code. My comment in the code shows why the following is true here:

65536 // nreg_per_block < max_warps_per_sm // triton_conf.num_warps

Scaling up RBLOCK can do harm since it increases register usage. Even if there are no register spills, large register usage can still hurt occupancy and may result in under-performing configs.
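A minimal check of that inequality, assuming max_threads_per_multi_processor == 2048 (i.e. 64 warps per SM) and nreg_per_block == nreg_per_thread * 32 * num_warps:

max_warps_per_sm = 64
for num_warps in (1, 2, 4, 8):
    for nreg_per_thread in (33, 64, 128, 255):
        nreg_per_block = nreg_per_thread * 32 * num_warps
        # once a thread uses more than 32 registers, the register-based bound
        # is no larger than the warp-based bound
        assert 65536 // nreg_per_block <= max_warps_per_sm // num_warps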

Collaborator

When registers spill, each thread uses 255 registers (the maximum number of registers a thread can use). In that case, there is a large chance we trigger the decrease of RBLOCK.

Why don't we handle it explicitly while we are at it?


shunting314 commented Sep 22, 2023

Here is the perf test result: link

  • the compilation time is overall neutral
  • I see a 1% perf increase in TIMM, but since I see no green cells, I'll consider it noise.
  • I see a 1% perf decrease in HF. But the PR should not cause a real perf regression, since the net effect is that we consider one more triton config in some cases. In fact, I see no red cells but one green cell for HF: DebertaV2ForQuestionAnswering increases from 2.44 to 2.62. Also, the model the spilled kernel comes from is BlenderbotForCausalLM, which also increases from 1.17x to 1.18x. The improvement for BlenderbotForCausalLM is not very large, probably because that kernel does not account for much of the latency.

So the PR improves things in general, but the overall metrics are neutral.

@shunting314
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Sep 22, 2023
@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.


@shunting314
Contributor Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing label Sep 22, 2023
@shunting314
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@facebook-github-bot facebook-github-bot deleted the gh/shunting314/85/head branch September 26, 2023 14:26