Conversation

@karthickai
Collaborator

karthickai commented Nov 7, 2025

@pytorch-bot

pytorch-bot bot commented Nov 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167380

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit b3e3ede with merge base 1727a71:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

karthickai added a commit that referenced this pull request Nov 7, 2025
ghstack-source-id: 06b2eb8
Pull Request resolved: #167380
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
karthickai added a commit that referenced this pull request Nov 7, 2025
ghstack-source-id: 253359b
Pull Request resolved: #167380
eellison requested review from mlazos and shunting314 and removed request for eellison November 11, 2025 19:31
"tl.atomic_add" in code[0],
"bf16 should generate tl.atomic_add",
)
torch.testing.assert_close(
Contributor

This will always pass since result is an alias to output?

Collaborator Author

Thanks for the catch. I created a new expected tensor instead of reusing output.
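For context, here is a minimal standalone sketch (not the PR's test code; it assumes a CUDA device and uses made-up shapes) of why asserting against an alias is vacuous and why the expected tensor has to be built independently:

```python
import torch

out = torch.zeros(8, dtype=torch.bfloat16, device="cuda")
idx = torch.tensor([0, 1, 1, 3], device="cuda")
src = torch.ones(4, dtype=torch.bfloat16, device="cuda")

# In-place index_add_ returns self, so `result` aliases `out`.
result = out.index_add_(0, idx, src)
torch.testing.assert_close(result, out)  # always passes: same storage

# A meaningful check builds the expected value from a fresh tensor.
expected = torch.zeros(8, dtype=torch.bfloat16, device="cuda")
expected.index_add_(0, idx, src)
torch.testing.assert_close(result, expected)
```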

and dtype == torch.bfloat16
and torch.cuda.is_available()
and torch.cuda.get_device_capability() >= (9, 0)
and config.bfloat16_atomic_adds_enabled
Contributor

is this config still needed?

Collaborator Author

Yeah, the config is not needed anymore. I removed it, along with this test case:

@skipCUDAIf(
    not SM90OrLater, "uses bfloat16 atomic add instrs which requires SM >= 90"
)
@unittest.skipIf(
    config.is_fbcode(),
    "bfloat16 atomic add is supported in fbcode, so we won't fallback",
)
def test_index_add_fallback(self):
    def f(x, y):
        return torch.index_select(x, 0, y)

    x = torch.randn(
        2000, 384, dtype=torch.bfloat16, device="cuda", requires_grad=True
    )
    y = torch.ones(713268, dtype=torch.int64, device="cuda")
    x_ref = x.clone().detach().requires_grad_(True)
    y_ref = y.clone().detach()
    out, (_, bw_code) = run_fw_bw_and_get_code(lambda: torch.compile(f)(x, y))
    fc = FileCheck()
    fc.check("aten.index_add")
    fc.run(bw_code)
    self.assertEqual(f(x_ref, y_ref), out)

Fixes: #97016 

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
karthickai added a commit that referenced this pull request Nov 12, 2025
ghstack-source-id: d068fdd
Pull Request resolved: #167380
        return False
    else:
-       return dtype in OrderedSet([torch.int64, torch.bool, torch.bfloat16])
+       return dtype in OrderedSet([torch.int64, torch.bool])
Contributor

Do we still need the fallback atomic_add for bfloat16 if compute_capability < 90?

Collaborator Author

Yes, we still need the fallback. I've added a check for sm < (9, 0) so we fall back in that case.
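To make the intent concrete, here is a rough sketch of the gating described above (illustrative only; the function name and exact dtype set are assumptions, not necessarily the PR's final helper):

```python
import torch

def atomic_add_needs_fallback(dtype: torch.dtype) -> bool:
    # int64 and bool atomic adds always go through the aten fallback.
    if dtype in (torch.int64, torch.bool):
        return True
    # bf16 tl.atomic_add is only emitted on SM >= 90; older GPUs keep the fallback.
    if dtype == torch.bfloat16:
        return (
            not torch.cuda.is_available()
            or torch.cuda.get_device_capability() < (9, 0)
        )
    return False
```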

Fixes: #97016 

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
karthickai added a commit that referenced this pull request Nov 12, 2025
ghstack-source-id: e214b26
Pull Request resolved: #167380
@mlazos
Contributor

mlazos commented Nov 12, 2025

Hi why are we removing this?

This is used internally

@karthickai
Collaborator Author

Hi why are we removing this?

This is used internally

We are addressing issue #97016. Previously, atomic_add did not support bf16, so a fallback was implemented. Triton now supports bf16 atomic_add, so we are removing this fallback.
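For reference, a minimal sketch (not part of this PR) of the kind of kernel this change enables: with a recent Triton build and an SM90+ GPU, tl.atomic_add accepts bfloat16 operands directly, so Inductor no longer has to fall back to aten.index_add:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def scatter_add_kernel(out_ptr, idx_ptr, val_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    idx = tl.load(idx_ptr + offs, mask=mask)
    val = tl.load(val_ptr + offs, mask=mask)
    # bf16 atomic add: needs a Triton with bf16 atomics and SM >= 90.
    tl.atomic_add(out_ptr + idx, val, mask=mask)

out = torch.zeros(128, dtype=torch.bfloat16, device="cuda")
idx = torch.randint(0, 128, (1024,), device="cuda")
val = torch.ones(1024, dtype=torch.bfloat16, device="cuda")
scatter_add_kernel[(triton.cdiv(1024, 256),)](out, idx, val, 1024, BLOCK=256)
```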

@mlazos
Contributor

mlazos commented Nov 13, 2025

Hi why are we removing this?
This is used internally

We are addressing issue #97016. Previously, atomic_add did not support bf16, so a fallback was implemented. Triton now supports bf16 atomic_add, so we are removing this fallback.

Okay, sounds good. You may want to ensure the config is not used internally before removing it; otherwise this may cause errors when landing.

@karthickai
Collaborator Author

Hi why are we removing this?
This is used internally

We are addressing issue #97016. Previously, atomic_add did not support bf16, so a fallback was implemented. Triton now supports bf16 atomic_add, so we are removing this fallback.

Okay, sounds good. You may want to ensure the config is not used internally before removing it; otherwise this may cause errors when landing.

I did an internal code search for the config bfloat16_atomic_adds_enabled, and only the files we touched are affected. I think it’s safe to merge.

@karthickai
Collaborator Author

@pytorchbot merge

pytorch-bot bot added the ciflow/trunk label Nov 13, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

shunting314 self-requested a review November 13, 2025 19:07