[inductor] fix mkldnn linear binary fusion check ut #127296
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/127296
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure) As of commit 5fbf5bd with merge base 4d4d2a9.
UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
mod = M(binary_fn, input_shape[-1], out_feature, bias).eval()
v = torch.randn(input_shape)
other = torch.randn(input_shape[:-1] + [out_feature]).to(dtype)
Why do we not convert the dtype on `mod` and the input `v`, but convert it on `other` here?
Hi, Jiong.
Here we choose not to convert the dtype on `mod` and `v` because we expect autocast to handle it. As for `other`, autocast will not cast it because `add` is a fall-through op.
y = linear(x) + z
Currently we do not fuse the "+ z" if z is float and linear(x) is low precision (lp).
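To make the dtype behavior concrete, here is a minimal sketch (module, shapes, and device are assumed for illustration, not taken from the test) of how CPU autocast casts the linear to bf16 but leaves a float32 addend alone:

```python
import torch

linear = torch.nn.Linear(16, 32)
x = torch.randn(2, 16)
other = torch.randn(2, 32)  # stays float32; autocast does not touch it

with torch.autocast("cpu", dtype=torch.bfloat16):
    y = linear(x)        # autocast casts linear's inputs, so y is bf16
    out = y + other      # add falls through; type promotion yields float32

print(y.dtype, out.dtype)  # torch.bfloat16 torch.float32
```

Because the add produces a float32 result when `other` stays in float32, the "+ other" would not be fused with the low-precision linear, which is why the test casts `other` with `.to(dtype)` explicitly.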
Previously the UT `python test/inductor/test_mkldnn_pattern_matcher.py -k test_linear_binary` did not check the fusion status; this PR fixes that. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
In this PR:

(1) Fix the unary fusion for bf16/fp16 conv/linear. Previously we registered the same fusion pattern for `bf16` and `fp16` and did not check the dtype while matching the pattern. As a result, the `fp16` case matched the `bf16` pattern, but in the later replacement we found a float16 where it was not expected, so we did not fuse. We fix it by checking dtypes so the `fp16` case no longer matches the `bf16` pattern.

```python
def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None):
    def fn(match):
        matched = _is_single_computation_op(computation_op, lowp_dtype)(match)  # previously we did not check lowp_dtype here
```

This was not exposed before because we only checked the match count, and the match count is correct anyway because the pattern did match. To address this, we add a check on the number of `generated_kernel`s: if the post op is not fused, there will be an additional kernel to compute it.

(2) Previously the UT `python test/inductor/test_mkldnn_pattern_matcher.py -k test_linear_binary` did not check the fusion status; this PR fixes that.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
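As an illustration of the mechanism only (a self-contained toy sketch; `make_dtype_guard` is a hypothetical stand-in, not the real helper from `torch._inductor`), a dtype guard at match time is enough to keep an fp16 graph from matching a bf16 pattern:

```python
import torch

# Hypothetical stand-in for the dtype guard described above; the real check
# lives in the inductor mkldnn fusion pattern matcher.
def make_dtype_guard(lowp_dtype=None):
    def check(node_dtype):
        # Without this guard, an fp16 node would "match" the bf16 pattern and
        # the replacement would have to bail out later without fusing.
        return lowp_dtype is None or node_dtype == lowp_dtype
    return check

bf16_guard = make_dtype_guard(torch.bfloat16)
print(bf16_guard(torch.bfloat16))  # True: a bf16 graph matches the bf16 pattern
print(bf16_guard(torch.float16))   # False: an fp16 graph is rejected at match time
```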
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot revert -m 'Sorry for reverting you change but one of the tests is failing on trunk ROCm. Please help fix and reland the change https://github.com/pytorch/pytorch/actions/runs/9302535020/job/25606932572' -c nosignal
@pytorchbot successfully started a revert job. Check the current status here.
This reverts commit cdeb242. Reverted #127296 on behalf of https://github.com/huydhn due to Sorry for reverting you change but one of the tests is failing on trunk ROCm. Please help fix and reland the change https://github.com/pytorch/pytorch/actions/runs/9302535020/job/25606932572
@zhuhaozhe your PR has been successfully reverted.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
In this PR:

(1) Fix the unary fusion for bf16/fp16 conv/linear. Previously we registered the same fusion pattern for `bf16` and `fp16` and did not check the dtype while matching the pattern. As a result, the `fp16` case matched the `bf16` pattern, but in the later replacement we found a float16 where it was not expected, so we did not fuse. We fix it by checking dtypes so the `fp16` case no longer matches the `bf16` pattern.

```python
def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None):
    def fn(match):
        matched = _is_single_computation_op(computation_op, lowp_dtype)(match)  # previously we did not check lowp_dtype here
```

This was not exposed before because we only checked the match count, and the match count is correct anyway because the pattern did match. To address this, we add a check on the number of `generated_kernel`s: if the post op is not fused, there will be an additional kernel to compute it.

(2) Previously the UT `python test/inductor/test_mkldnn_pattern_matcher.py -k test_linear_binary` did not check the fusion status; this PR fixes that.

(3) Extend `test_conv_binary` to test with lp.

Pull Request resolved: pytorch#127296
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jansel
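For reference, the fusion check via kernel counting can be sketched as follows (the module, shapes, and printed metric are assumptions for illustration; the actual test builds `M(binary_fn, ...)` and compares against an expected kernel count):

```python
import torch
from torch._inductor import metrics

# Assumed toy module and shapes for illustration only.
mod = torch.nn.Linear(16, 32).eval()
x = torch.randn(2, 16)

metrics.reset()
with torch.no_grad():
    torch.compile(mod)(x)

# If a post op were left unfused, Inductor would emit an extra kernel for it,
# so the generated-kernel count would be one higher than expected.
print(metrics.generated_kernel_count)
```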
In this PR:

(1) Fix the unary fusion for bf16 conv/linear. Previously we registered the same fusion pattern for `bf16` and `fp16` and did not check the dtype while matching the pattern. As a result, the `fp16` case matched the `bf16` pattern, but in the later replacement we found a float16 where it was not expected, so we did not fuse. We fix it by checking dtypes so the `fp16` case no longer matches the `bf16` pattern. This was not exposed before because we only checked the match count, and the match count is correct anyway because the pattern did match. To address this, we add a check on the number of `generated_kernel`s: if the post op is not fused, there will be an additional kernel to compute the post op.

(2) Previously the UT `python test/inductor/test_mkldnn_pattern_matcher.py -k test_linear_binary` did not check the fusion status; this PR fixes that.

(3) Extend `test_conv_binary` to test with lp (low precision).

Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang