[inductor] Adding a way to force fusion of int_mm with mul #111125
Conversation
Summary: Unsure if there's a more general way to set it up rather than forcing the order of the muls based on their shape. Test Plan: Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: When doing quantization, `int_mm -> mul` or `int_mm -> mul -> to(dtype)` is an extremely common op pattern that is currently not handled well by inductor. Ideally, since the output of `int_mm` has dtype int32, we'd prefer to realize only a smaller dtype like bf16 or float16. Currently inductor has no way to force this; in many cases the mul gets fused with a bunch of subsequent pointwise ops from the dequant, increasing memory overhead and causing a general slowdown compared to the fused version. Test Plan: python test/inductor/test_pattern_matcher.py -k "int_mm_mul" Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
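For context, a minimal sketch of the pattern being targeted (the names, shapes, and CUDA device below are illustrative, not taken from the PR):

```python
import torch

# Sketch of the int_mm -> mul pattern described above. Shapes are
# chosen to satisfy torch._int_mm's size/alignment checks; a CUDA
# device is assumed.
a = torch.randint(-128, 128, (32, 32), dtype=torch.int8, device="cuda")
b = torch.randint(-128, 128, (32, 32), dtype=torch.int8, device="cuda")
scale = torch.rand((32, 32), dtype=torch.float16, device="cuda")

def int_mm_mul(a, b, scale):
    # int8 @ int8 -> int32 accumulator, then a dequant-style mul.
    # Without the fusion this PR adds, inductor may realize the int32
    # buffer and fuse the mul with later pointwise ops instead.
    return torch._int_mm(a, b) * scale

out = torch.compile(int_mm_mul)(a, b, scale)  # float16 output
```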
I don't understand why we need a pattern for this. Why don't we just fix it so that it prioritizes this fusion? |
@Chillee - I agree with you. This kernel also uses block ptrs. It's a starting point. This isn't supposed to be the final version. |
Yeah, this would ideally be temporary until we fix it so that the fusion is prioritized. The reason we needed a pattern for this is mostly that I knew how to do that but couldn't figure out how to fix the fusion. |
Review comment on `tuned_fused_int_mm_mul`:

```python
def tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype, *, layout=None):
    out_dtype = (
        torch.promote_types(mat3.get_dtype(), torch.int32)
```

Suggested change, replacing `torch.promote_types(mat3.get_dtype(), torch.int32)` with:

```suggestion
        functools.reduce(torch.promote_types, [
            mat1.get_dtype(),
            mat2.get_dtype(),
            mat3.get_dtype(),
        ])
```
I am assuming this will match the semantics of aten._int_mm, but you should check to make sure. Will aten._int_mm promote to torch.int32 even if none of the inputs are int32? What about int64?
int_mm only does int8 x int8 -> int32 matmuls, so I don't think those other cases are possible.
With `torch.promote_types(mat3.get_dtype(), torch.int32)` I was trying to set the output type to what you'd naturally get from int32 * float16, so that it matches the non-torch.compiled model. What is the benefit of including the mat1 and mat2 dtypes, if what we really want is the promotion from the int_mm output, whose dtype we already know?
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team. |
@pytorchbot revert -m 'Sorry for reverting your change, but it fails on ROCm https://hud.pytorch.org/pytorch/pytorch/commit/f4297576e63e4110f6bdf2522ae6a5fb4c7f3816' -c nosignals

The error is … |
❌ 🤖 pytorchbot command failed. Try …
@pytorchbot revert -m 'Sorry for reverting your change, but it fails on ROCm https://hud.pytorch.org/pytorch/pytorch/commit/f4297576e63e4110f6bdf2522ae6a5fb4c7f3816' -c nosignal |
@pytorchbot successfully started a revert job. Check the current status here. |
@HDCharles your PR has been successfully reverted. |
…111125)" This reverts commit f429757. Reverted #111125 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it fails on ROCm https://hud.pytorch.org/pytorch/pytorch/commit/f4297576e63e4110f6bdf2522ae6a5fb4c7f3816 ([comment](#111125 (comment)))
Stack from ghstack (oldest at bottom):

Summary: When doing quantization, `int_mm -> mul` or `int_mm -> mul -> to(dtype)` is an extremely common op pattern that is currently not handled well by inductor. Ideally, since the output of `int_mm` has dtype int32, we'd prefer to realize only a smaller dtype like bf16 or float16. Currently inductor has no way to force this; in many cases the mul gets fused with a bunch of subsequent pointwise ops from the dequant, increasing memory overhead and causing a general slowdown compared to the fused version.

Theoretically, with better control of / smarter inductor fusion, this could be something we get for free, at which point these changes can be removed.

Test Plan: python test/inductor/test_pattern_matcher.py -k "int_mm_mul"

Reviewers:
Subscribers:
Tasks:
Tags:
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler