[Inductor] Remove bf16 fallback for atomic_add #167380
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167380
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure) As of commit b3e3ede with merge base 1727a71. FLAKY - The following job failed but was likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
test/inductor/test_torchinductor.py (Outdated)
| "tl.atomic_add" in code[0], | ||
| "bf16 should generate tl.atomic_add", | ||
| ) | ||
| torch.testing.assert_close( |
This will always pass since result is an alias to output?
Thanks for the catch. I now create a new expected tensor instead of reusing output.
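For context, a minimal standalone sketch of the aliasing pitfall and the fix (illustrative names and shapes, not the PR's actual test; it needs a CUDA device, and the bf16 atomic path additionally needs SM >= 90):

```python
import torch

def check_bf16_scatter_add():
    src = torch.randn(64, dtype=torch.bfloat16, device="cuda")
    idx = torch.randint(0, 8, (64,), device="cuda")
    out = torch.zeros(8, dtype=torch.bfloat16, device="cuda")

    compiled = torch.compile(lambda o, i, s: o.index_add_(0, i, s))
    result = compiled(out, idx, src)  # in-place op: `result` is the already-mutated `out`

    # Vacuous check: `result` is (an alias of) the mutated `out`, so this cannot fail.
    # torch.testing.assert_close(result, out)

    # Correct check: build an independent expected tensor in eager mode.
    expected = torch.zeros(8, dtype=torch.bfloat16, device="cuda")
    expected.index_add_(0, idx, src)
    torch.testing.assert_close(result, expected)
```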
```python
and dtype == torch.bfloat16
and torch.cuda.is_available()
and torch.cuda.get_device_capability() >= (9, 0)
and config.bfloat16_atomic_adds_enabled
```
is this config still needed?
Yeah, the config is not needed anymore. I removed it, along with this test case:
pytorch/test/inductor/test_cuda_repro.py
Lines 2260 to 2283 in e8d411e
```python
@skipCUDAIf(
    not SM90OrLater, "uses bfloat16 atomic add instrs which requires SM >= 90"
)
@unittest.skipIf(
    config.is_fbcode(),
    "bfloat16 atomic add is supported in fbcode, so we won't fallback",
)
def test_index_add_fallback(self):
    def f(x, y):
        return torch.index_select(x, 0, y)

    x = torch.randn(
        2000, 384, dtype=torch.bfloat16, device="cuda", requires_grad=True
    )
    y = torch.ones(713268, dtype=torch.int64, device="cuda")
    x_ref = x.clone().detach().requires_grad_(True)
    y_ref = y.clone().detach()
    out, (_, bw_code) = run_fw_bw_and_get_code(lambda: torch.compile(f)(x, y))
    fc = FileCheck()
    fc.check("aten.index_add")
    fc.run(bw_code)
    self.assertEqual(f(x_ref, y_ref), out)
```
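With the fallback removed, the expectation flips: on SM >= 90 the backward of this same pattern should now hit the Triton atomic path rather than aten.index_add. A hedged sketch of how that could be checked, reusing run_fw_bw_and_get_code (import path, shapes, and indices are assumptions, not taken from the PR):

```python
import torch
from torch._inductor.utils import run_fw_bw_and_get_code  # assumed import path

def f(x, y):
    return torch.index_select(x, 0, y)

x = torch.randn(2000, 384, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y = torch.randint(0, 2000, (10000,), dtype=torch.int64, device="cuda")

out, (_, bw_code) = run_fw_bw_and_get_code(lambda: torch.compile(f)(x, y))
# With this PR, the backward on SM >= 90 is expected to emit tl.atomic_add
# directly instead of falling back to aten.index_add.
assert "tl.atomic_add" in bw_code
```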
torch/_inductor/utils.py (Outdated)
```diff
         return False
     else:
-        return dtype in OrderedSet([torch.int64, torch.bool, torch.bfloat16])
+        return dtype in OrderedSet([torch.int64, torch.bool])
```
Do we still need the atomic_add fallback for bfloat16 if compute_capability < 90?
Yes, we still need the fallback there. I've added a check so we fall back when sm < (9, 0).
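Roughly, the updated check could look like the sketch below (hypothetical helper name; the actual code in torch/_inductor/utils.py may be organized differently):

```python
import torch
from torch.utils._ordered_set import OrderedSet  # assumed import path

def needs_atomic_add_fallback(dtype: torch.dtype) -> bool:
    # int64 and bool accumulation still cannot be expressed as tl.atomic_add.
    if dtype in OrderedSet([torch.int64, torch.bool]):
        return True
    # bf16 atomic adds are only emitted on CUDA devices with SM >= 90;
    # older GPUs keep the fallback.
    if dtype == torch.bfloat16:
        return not (
            torch.cuda.is_available()
            and torch.cuda.get_device_capability() >= (9, 0)
        )
    return False
```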
Hi, why are we removing this? This is used internally.
We are addressing issue #97016. Previously, atomic_add did not support bf16, so a fallback was implemented. Triton now supports bf16 atomic_add, so we are removing this fallback.
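For reference, a minimal standalone Triton sketch of a bf16 scatter-add via tl.atomic_add, the capability this PR relies on (not the Inductor-generated kernel; assumes a recent Triton and an SM >= 90 GPU):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def scatter_add_kernel(out_ptr, idx_ptr, val_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    idx = tl.load(idx_ptr + offs, mask=mask, other=0)
    val = tl.load(val_ptr + offs, mask=mask, other=0.0)
    # bf16 atomic add: this is the dtype Inductor previously had to fall back for.
    tl.atomic_add(out_ptr + idx, val, mask=mask)

# Illustrative launch: accumulate 1024 bf16 values into 16 output slots.
out = torch.zeros(16, dtype=torch.bfloat16, device="cuda")
idx = torch.randint(0, 16, (1024,), device="cuda")
val = torch.randn(1024, dtype=torch.bfloat16, device="cuda")
scatter_add_kernel[(triton.cdiv(1024, 256),)](out, idx, val, 1024, BLOCK=256)
```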
Okay, sounds good. You may want to ensure the config is not used internally before removing it; otherwise this may cause errors when landing.
I did an internal code search for the config.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
Fixes: #97016
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @mlazos