[Inductor] Fallback scatter when src dtype is bf16 #113204
Conversation
basic_gnn_gcn, basic_gnn_gin, basic_gnn_sage now pass [ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/113204
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit e5b540f with merge base ee777a7. This comment was automatically generated by Dr. CI and updates every 15 minutes.
basic_gnn_gcn, basic_gnn_gin, basic_gnn_sage now pass ghstack-source-id: 884b08008526cdd9038b77fb11e0f52af487dcb9 Pull Request resolved: #113204
torch/_inductor/lowering.py
Outdated
@@ -3079,6 +3079,8 @@ def scatter_fallback(
    reduce_ty = "add" if fn == "aten.scatter_" else "sum"
    if (
        reduce not in {None, reduce_ty}
        # tl.atomic_add does not support bf16
        or src.get_dtype() in {torch.bfloat16}
Leaving this as a set since I assume there will need to be more things here.
Use a tuple? A set for a single element seems like overkill.
I think we need to guard this on what GPU we're running on? I think tl.atomic_add
only doesn't work on A100 GPUs and below.
@Chillee How can I check which GPU I am currently on? Or rather, how do I check for A100 or below?
@eellison do you want me to combine both of those "tl.atomic_add does not work" checks?
@Chillee as far as I can tell it's not supported in triton regardless of the device: triton-lang/triton#1387. @oulgen - yeah, I think that makes sense.
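For reference, a minimal repro sketch of the limitation discussed above (assumes a CUDA device and a Triton build from around the time of triton-lang/triton#1387; the kernel and tensor names are illustrative, not from this PR):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def atomic_add_kernel(ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Expected to fail at compile time for bf16 buffers: Triton's
    # atomic_add does not support the bfloat16 element type.
    tl.atomic_add(ptr + offs, 1.0)

x = torch.zeros(16, device="cuda", dtype=torch.bfloat16)
atomic_add_kernel[(1,)](x, BLOCK=16)  # raises for bf16; works for fp32
```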
looks good, thanks!
@pytorchbot merge
@@ -2926,6 +2926,11 @@ def _unsafe_index_put_(self, indices, values, accumulate=False):
    return index_put_impl_(self, indices, values, accumulate, check=False)


def needs_fallback_due_to_atomic_add_limitations(dtype):
    # tl.atomic_add does NOT support the following types
    return dtype in {torch.int64, torch.bool, torch.bfloat16}
Shouldn't this check the device as well?
Yeah, I guess we should only do the pessimization in CUDA mode and not CPU mode.
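A hypothetical sketch of how that guard could look (the `device` parameter and the CPU short-circuit are assumptions drawn from this review thread, not the code that landed in this PR):

```python
import torch

def needs_fallback_due_to_atomic_add_limitations(dtype, device=None):
    # Only the CUDA lowering goes through tl.atomic_add, so the CPU
    # path would not need the pessimization (hypothetical guard).
    if device is not None and device.type != "cuda":
        return False
    # tl.atomic_add does NOT support the following types
    return dtype in {torch.int64, torch.bool, torch.bfloat16}
```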
Merge failed. Reason: This PR needs a … label. If not, please add the … label. To add a label, you can comment to pytorchbot, for example: … For more information, see … (Details for Dev Infra team: raised by workflow job …)

cc @htyu
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…#113204)" Summary: Revert due to Llama 7b performance regression on Mi250x (83tok/s -> 79.5tok/s, ~4% regression) Test Plan: CI Differential Revision: D51287379
…#113204)" (pytorch#113599) Summary: Revert due to Llama 7b performance regression on Mi250x (83tok/s -> 79.5tok/s, ~4% regression) Test Plan: CI Differential Revision: D51287379
…#113204)" (pytorch#113599) Summary: Revert due to Llama 7b performance regression on Mi250x (83tok/s -> 79.5tok/s, ~4% regression) Test Plan: CI Reviewed By: xw285cornell Differential Revision: D51287379
basic_gnn_gcn, basic_gnn_gin, basic_gnn_sage now pass Pull Request resolved: pytorch#113204 Approved by: https://github.com/eellison
Stack from ghstack (oldest at bottom):
basic_gnn_gcn, basic_gnn_gin, basic_gnn_sage now pass
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler
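For context, a minimal sketch of the kind of pattern this change is meant to fix (the tensors and shapes are illustrative, not taken from the GNN benchmarks named above): a compiled scatter reduction with a bf16 source now falls back to the ATen kernel instead of lowering to tl.atomic_add.

```python
import torch

def scatter_sum(x, index, src):
    return x.scatter_reduce(0, index, src, reduce="sum")

compiled = torch.compile(scatter_sum)

x = torch.zeros(8, device="cuda", dtype=torch.bfloat16)
src = torch.ones(4, device="cuda", dtype=torch.bfloat16)
index = torch.tensor([0, 1, 2, 3], device="cuda")
out = compiled(x, index, src)  # falls back to ATen scatter for bf16 src
```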