
[Inductor] Fallback scatter when src dtype is bf16 #113204

Status: Closed (6 commits)

Conversation

@oulgen (Contributor, Author) commented Nov 7, 2023:

basic_gnn_gcn, basic_gnn_gin, basic_gnn_sage now pass

[ghstack-poisoned]
@pytorch-bot (bot) commented Nov 7, 2023:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/113204

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e5b540f with merge base ee777a7:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

oulgen added a commit that referenced this pull request Nov 7, 2023
basic_gnn_gcn, basic_gnn_gin, basic_gnn_sage now pass

ghstack-source-id: 884b08008526cdd9038b77fb11e0f52af487dcb9
Pull Request resolved: #113204
@oulgen oulgen added ciflow/trunk Trigger trunk jobs on your pull request topic: bug fixes topic category and removed module: inductor ciflow/inductor labels Nov 7, 2023
@@ -3079,6 +3079,8 @@ def scatter_fallback(
     reduce_ty = "add" if fn == "aten.scatter_" else "sum"
     if (
         reduce not in {None, reduce_ty}
+        # tl.atomic_add does not support bf16
+        or src.get_dtype() in {torch.bfloat16}
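The condition in this hunk can be sketched as a standalone predicate (the function name and signature below are my own for illustration; the real check lives inside Inductor's `scatter_fallback` lowering):

```python
import torch

def should_fallback(fn, reduce, src_dtype):
    # Mirror of the hunk above: fall back to the ATen kernel when the reduce
    # mode is unsupported, or when src is bf16 (tl.atomic_add cannot do bf16).
    reduce_ty = "add" if fn == "aten.scatter_" else "sum"
    return reduce not in {None, reduce_ty} or src_dtype in {torch.bfloat16}

print(should_fallback("aten.scatter_", "add", torch.bfloat16))  # True: bf16 forces fallback
print(should_fallback("aten.scatter_", "add", torch.float32))   # False: Triton path is fine
```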
@oulgen (Contributor, Author) commented:

Leaving this as a set since I assume there will need to be more things here.

A reviewer (Contributor) commented:

Use a tuple? A set seems like overkill for a single element.
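For context on this nit: the three membership forms below are interchangeable for a single element, so the choice is purely stylistic. The set literal mainly signals that more dtypes may be added later.

```python
import torch

dtype = torch.bfloat16

# All three checks are equivalent for one element; the set form is the one
# the PR uses, anticipating more unsupported dtypes being added later.
print(dtype in {torch.bfloat16})   # set literal
print(dtype in (torch.bfloat16,))  # single-element tuple
print(dtype == torch.bfloat16)     # direct equality
```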

@oulgen oulgen mentioned this pull request Nov 7, 2023
@eellison (Contributor) left a comment:

Can you add a fallback for index_put from this issue while we're at it? See #97016.

Also, can you add a test?


@Chillee (Contributor) left a comment:

I think we need to guard this on which GPU we're running on? I believe tl.atomic_add fails only on A100 GPUs and below.

@oulgen (Contributor, Author) commented Nov 7, 2023:

@Chillee How can I check which GPU I am currently on? Or rather, how do I express "A100 or below"?
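One plausible way to answer this (my own sketch, not necessarily what the PR settled on) is via CUDA compute capability: A100 is sm_80, so "A100 and below" roughly corresponds to a capability of (8, 0) or lower:

```python
import torch

def is_a100_or_older():
    # Returns False on machines without CUDA; otherwise compares the current
    # device's compute capability against A100's sm_80 (major=8, minor=0).
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() <= (8, 0)

print(is_a100_or_older())
```

As eellison notes further down in the thread, the bf16 limitation turned out to be device-independent in Triton, so a capability guard like this ended up being unnecessary for this particular check.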

@oulgen (Contributor, Author) commented Nov 7, 2023:

@eellison Do you want me to combine both of those "tl.atomic_add does not work" checks?

@eellison (Contributor) commented Nov 7, 2023:

@Chillee as far as I can tell it's not supported in triton regardless of the device: triton-lang/triton#1387.

@oulgen - yea, I think that makes sense.
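The fallback target itself is just the eager ATen kernel, which (in recent torch builds, at least on CPU) accumulates into bf16 without issue; a quick sanity check:

```python
import torch

# scatter_add_ in eager mode handles bf16 accumulation; this is the kernel
# Inductor falls back to instead of emitting a tl.atomic_add on bf16.
src = torch.ones(5, dtype=torch.bfloat16)
index = torch.tensor([0, 1, 0, 1, 2])
out = torch.zeros(3, dtype=torch.bfloat16)
out.scatter_add_(0, index, src)
print(out.tolist())  # [2.0, 2.0, 1.0]
```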

basic_gnn_gcn, basic_gnn_gin, basic_gnn_sage now pass

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
oulgen added a commit that referenced this pull request Nov 7, 2023
basic_gnn_gcn, basic_gnn_gin, basic_gnn_sage now pass

ghstack-source-id: e7ff37d21d8bcc0e696641e24522528a67f2fd8a
Pull Request resolved: #113204
oulgen added a commit that referenced this pull request Nov 7, 2023
basic_gnn_gcn, basic_gnn_gin, basic_gnn_sage now pass

ghstack-source-id: 46879f885f8cfe15988aeda2abfe246498b30486
Pull Request resolved: #113204
oulgen added a commit that referenced this pull request Nov 8, 2023
basic_gnn_gcn, basic_gnn_gin, basic_gnn_sage now pass

ghstack-source-id: 68621a60fc2c0e9918c7be789b91b710268dbf48
Pull Request resolved: #113204
@oulgen oulgen requested a review from eellison November 8, 2023 19:26
@eellison (Contributor) left a comment:
looks good, thanks!

@oulgen (Contributor, Author) commented Nov 8, 2023:

@pytorchbot merge

@@ -2926,6 +2926,11 @@ def _unsafe_index_put_(self, indices, values, accumulate=False):
     return index_put_impl_(self, indices, values, accumulate, check=False)


+def needs_fallback_due_to_atomic_add_limitations(dtype):
+    # tl.atomic_add does NOT support the following types
+    return dtype in {torch.int64, torch.bool, torch.bfloat16}
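Reproducing the helper from this hunk as a standalone snippet to show its behavior across dtypes:

```python
import torch

def needs_fallback_due_to_atomic_add_limitations(dtype):
    # tl.atomic_add cannot accumulate into these dtypes, so Inductor must
    # fall back to the ATen kernel for scatter reductions on them.
    return dtype in {torch.int64, torch.bool, torch.bfloat16}

for dt in (torch.int64, torch.bool, torch.bfloat16, torch.float32):
    print(dt, needs_fallback_due_to_atomic_add_limitations(dt))
```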
A reviewer (Collaborator) commented:

Shouldn't this check the device as well?

@oulgen (Contributor, Author) replied:

Yeah, I guess we should only do the pessimization in CUDA mode and not CPU mode.
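A device-aware variant along the lines of this exchange might look like the following (a sketch under my own naming; the exact form in the merged code may differ):

```python
import torch

def needs_fallback(dtype, device):
    # Only pessimize on CUDA: the CPU backend does not lower scatter
    # reductions to tl.atomic_add, so the dtype restriction is irrelevant
    # there, per the discussion above.
    if device.type != "cuda":
        return False
    return dtype in {torch.int64, torch.bool, torch.bfloat16}

print(needs_fallback(torch.bfloat16, torch.device("cuda")))  # True
print(needs_fallback(torch.bfloat16, torch.device("cpu")))   # False
```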

@pytorchmergebot (Collaborator):
Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Raised by workflow job.

@jansel (Contributor) commented Nov 8, 2023:

cc @htyu

oulgen added a commit that referenced this pull request Nov 8, 2023
basic_gnn_gcn, basic_gnn_gin, basic_gnn_sage now pass

ghstack-source-id: ef37bea53dbeadd12241f96b776faa216df79f01
Pull Request resolved: #113204
@oulgen oulgen added the topic: not user facing topic category label Nov 8, 2023
oulgen added a commit that referenced this pull request Nov 9, 2023
basic_gnn_gcn, basic_gnn_gin, basic_gnn_sage now pass

ghstack-source-id: ca07bd6e82a81bc327e0c26216e6841c8a96987b
Pull Request resolved: #113204
@oulgen (Contributor, Author) commented Nov 9, 2023:

@pytorchbot merge

@pytorchmergebot (Collaborator):
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@facebook-github-bot facebook-github-bot deleted the gh/oulgen/30/head branch November 12, 2023 15:24
nmacchioni added a commit to nmacchioni/pytorch that referenced this pull request Nov 14, 2023
…#113204)"

Summary: Revert due to Llama 7b performance regression on Mi250x (83tok/s -> 79.5tok/s, ~4% regression)

Test Plan: CI

Differential Revision: D51287379
nmacchioni added a commit to nmacchioni/pytorch that referenced this pull request Nov 14, 2023
…#113204)" (pytorch#113599)

Summary:

Revert due to Llama 7b performance regression on Mi250x (83tok/s -> 79.5tok/s, ~4% regression)

Test Plan: CI

Differential Revision: D51287379
nmacchioni added a commit to nmacchioni/pytorch that referenced this pull request Nov 14, 2023
…#113204)" (pytorch#113599)

Summary:

Revert due to Llama 7b performance regression on Mi250x (83tok/s -> 79.5tok/s, ~4% regression)

Test Plan: CI

Reviewed By: xw285cornell

Differential Revision: D51287379
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
basic_gnn_gcn, basic_gnn_gin, basic_gnn_sage now pass

Pull Request resolved: pytorch#113204
Approved by: https://github.com/eellison