Conversation

kwen2501 (Contributor) commented Oct 6, 2025

Stack from ghstack (oldest at bottom):

Perform multiple tile reductions concurrently, with each tile reduced to a separate root.

  • The number of concurrent reductions can be smaller than the world size, i.e. the roots can be a subset of all ranks, but all ranks are still required to call into this API (see the sketch after this list).

  • Currently supports NVLink SHARP scope only.
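
For concreteness, here is a hedged sketch of the contract those bullets describe. It is a pure C++ toy, not the real API: the names (`roots`, `root`, `world_size`) mirror fragments quoted in the review below, and everything else is assumed for illustration.

```cpp
// Toy single-process simulation of the multi-root contract
// (no PyTorch; all names here are illustrative assumptions).
#include <cstdio>
#include <vector>

int main() {
  const int world_size = 4;
  // Two concurrent reductions: tile 0 reduces to rank 0, tile 1 to rank 2.
  // The roots are a strict subset of all ranks, yet every rank takes part.
  const std::vector<int> roots = {0, 2};
  for (int rank = 0; rank < world_size; ++rank) {
    int root = world_size;  // sentinel meaning "this rank is not a root"
    for (int r : roots) {
      if (r == rank) root = rank;
    }
    std::printf("rank %d: %s\n", rank,
                root == world_size ? "calls in, reduces nothing"
                                   : "calls in, reduces its tile");
  }
  return 0;
}
```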

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci


pytorch-bot (bot) commented Oct 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164757

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit e17e8b3 with merge base a707042:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot added the ciflow/h100-symm-mem, oncall: distributed, and release notes: distributed (c10d) labels on Oct 6, 2025
kwen2501 added a commit that referenced this pull request Oct 6, 2025
ghstack-source-id: d62d23b
Pull-Request: #164757
kwen2501 requested review from fduwjj, fegin and ngimel on October 6, 2025 at 18:31
kwen2501 added the release notes: distributed (symm_mem) label on Oct 6, 2025
weifengpy (Contributor) left a comment:

Looking good on the UX part.

root = rank;
}
i++;
}
Contributor:

Should we check that root != world_size here?

kwen2501 (Contributor Author), Oct 6, 2025:

This implementation uses root == world_size to indicate that the current rank does not need to reduce any tile (yet it still calls into this API to fulfill the collective requirement). See the "Note" above.
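
A plausible reconstruction of the surrounding loop, with the sentinel made explicit (variable names beyond those in the quoted hunk are guesses):

```cpp
// root starts at the sentinel world_size; a rank that never matches an
// entry in `roots` keeps the sentinel and skips the reduction later on.
int root = world_size;   // sentinel: this rank reduces no tile
int64_t my_tile = -1;    // hypothetical: index of the tile this rank owns
int i = 0;
for (auto r : roots) {
  if (r == rank) {
    root = rank;         // the assignment quoted in the hunk above
    my_tile = i;
  }
  i++;
}
```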

Collaborator:

We don't have a test for it (root == world_size) though, do we?

kwen2501 (Contributor Author):

In test_multi_root_tile_reduce, we exercise this case when root_ratio is 2: only half of the ranks are roots, and the remaining ranks pass root == world_size here to skip the reduction; a sketch of the scheme follows.
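
As an illustration of that partitioning (the real test is Python; this C++ snippet only mimics the scheme, and make_roots is a made-up helper):

```cpp
// With world_size = 8 and root_ratio = 2, only ranks 0..3 are roots;
// ranks 4..7 keep the sentinel root == world_size and skip the reduction.
#include <numeric>
#include <vector>

std::vector<int> make_roots(int world_size, int root_ratio) {
  std::vector<int> roots(world_size / root_ratio);
  std::iota(roots.begin(), roots.end(), 0);  // {0, 1, ..., n-1}
  return roots;
}
```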


- `reduce_op` is the reduction operation to perform. Currently only "sum" is supported.
*/
TORCH_CHECK(reduce_op == "sum", "tile_reduce: only sum is supported for now");
TORCH_CHECK(out_tile.dtype() == at::kFloat, "Only float is supported");
Collaborator:

Can you support at least BFloat16 also?

kwen2501 (Contributor Author):

Added in base PR #162243.
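
The quoted check predates that change; after the base PR one would expect the dtype guard to be widened roughly like this (an assumption, not the merged code):

```cpp
// Hypothetical widened check once BFloat16 support landed in #162243:
TORCH_CHECK(
    out_tile.dtype() == at::kFloat || out_tile.dtype() == at::kBFloat16,
    "tile_reduce: only float and bfloat16 are supported for now");
```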

for (auto& in_tile : in_tiles) {
TORCH_CHECK(in_tile.dtype() == at::kFloat, "Only float is supported");
c10d::symmetric_memory::rendezvous(in_tile, group_name);
if (roots[i] == rank) {
Collaborator:

We should check that roots[i] is valid (>= 0 and < world_size).

kwen2501 (Contributor Author):

Added now.
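
A sketch of what that validation could look like inside the loop (the exact message and placement are assumptions):

```cpp
// Hypothetical bounds check for each requested root:
TORCH_CHECK(
    roots[i] >= 0 && roots[i] < world_size,
    "tile_reduce: roots[", i, "] = ", roots[i],
    " is out of range [0, ", world_size, ")");
```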

int nblocks = at::ceil_div(
out_tile.numel() * out_tile.element_size(),
(int64_t)THREADS_PER_BLOCK * 16);
nblocks = std::min(nblocks, 16);
Collaborator:

Why limit it at 16? I think for the CUDA backend we limit it at 24 at least; maybe we even need more for Blackwell.

kwen2501 (Contributor Author):

Set it to 24 now.
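
So the quoted hunk presumably became the following; only the cap change is confirmed above, and reading the `* 16` in the divisor as bytes handled per thread via vectorized loads is an assumption:

```cpp
// Updated launch-size sketch after the review: cap raised from 16 to 24.
int nblocks = at::ceil_div(
    out_tile.numel() * out_tile.element_size(),
    (int64_t)THREADS_PER_BLOCK * 16);
nblocks = std::min(nblocks, 24);
```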

kwen2501 added a commit that referenced this pull request Oct 7, 2025
ghstack-source-id: c47225e
Pull-Request: #164757
kwen2501 (Contributor Author) commented Oct 8, 2025

@pytorchbot merge

pytorch-bot added the ciflow/trunk label on Oct 8, 2025
pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

pytorchmergebot (Collaborator):

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x ef426ec413907290b5c25b0005bf265ec2d0e6eb returned non-zero exit code 1

Auto-merging test/distributed/test_nvshmem.py
CONFLICT (content): Merge conflict in test/distributed/test_nvshmem.py
Auto-merging torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp
CONFLICT (content): Merge conflict in torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp
Auto-merging torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu
CONFLICT (content): Merge conflict in torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu
Auto-merging torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh
CONFLICT (content): Merge conflict in torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh
error: could not apply ef426ec4139... [SymmMem] Multi-root tile reduction
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
hint: Disable this message with "git config set advice.mergeConflict false"

kwen2501 (Contributor Author) commented Oct 8, 2025

@pytorchbot merge

pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.


Labels

ciflow/h100-symm-mem, ciflow/trunk, Merged, oncall: distributed, release notes: distributed (c10d), release notes: distributed (symm_mem)


5 participants