[SymmMem] Multi-root tile reduction #164757
looking good on UX part
    root = rank;
  }
  i++;
}
Should we check that root != world_size here?
This implementation uses root == world_size to indicate that the current rank does not need to reduce any tile (it still calls into this API to fulfill the collective requirement). See the "Note" above.
We don't have a test for it though (root == world_size), do we?
In test_multi_root_tile_reduce, when root_ratio is 2, we exercise this case. root_ratio=2 means only half of the ranks are roots; the remaining ranks provide root == world_size here to skip the reduction.
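The sentinel convention discussed in this thread can be sketched on the host side as follows. This is a minimal illustration, not the PR's code: `make_roots` and `find_local_root` are hypothetical names, and the rule that a `root_ratio` of 2 selects every second rank is an assumption; the test may pick its roots differently.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: pick every `root_ratio`-th rank as a root.
// Assumed selection rule; requires root_ratio > 0.
std::vector<int64_t> make_roots(int64_t world_size, int64_t root_ratio) {
  std::vector<int64_t> roots;
  for (int64_t r = 0; r < world_size; r += root_ratio) {
    roots.push_back(r);
  }
  return roots;
}

// Hypothetical helper mirroring the loop under review: returns `rank` if this
// rank appears in `roots`, otherwise the sentinel `world_size`, meaning
// "take part in the collective call, but reduce no tile".
int64_t find_local_root(const std::vector<int64_t>& roots,
                        int64_t rank,
                        int64_t world_size) {
  int64_t root = world_size;  // sentinel: not a root
  for (int64_t candidate : roots) {
    if (candidate == rank) {
      root = rank;
    }
  }
  return root;
}
```

Under these assumptions, with 8 ranks and `root_ratio` of 2, ranks 0, 2, 4 and 6 get their own rank back, while the odd ranks get 8 and skip the reduction.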
- `reduce_op` is the reduction operation to perform. Currently only "sum" is supported.
*/
TORCH_CHECK(reduce_op == "sum", "tile_reduce: only sum is supported for now");
TORCH_CHECK(out_tile.dtype() == at::kFloat, "Only float is supported");
Can you support at least BFloat16 as well?
Added in base PR #162243
for (auto& in_tile : in_tiles) {
  TORCH_CHECK(in_tile.dtype() == at::kFloat, "Only float is supported");
  c10d::symmetric_memory::rendezvous(in_tile, group_name);
  if (roots[i] == rank) {
We should check that roots[i] is valid (>= 0 and < world_size).
Added now.
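A range check of the kind requested here could look roughly like the sketch below. `roots_are_valid` is a hypothetical standalone helper written for illustration; the PR itself presumably expresses the same condition with `TORCH_CHECK`, and its exact form is not shown in this thread.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: true only if every entry of `roots` names a real rank,
// i.e. lies in the half-open range [0, world_size).
bool roots_are_valid(const std::vector<int64_t>& roots, int64_t world_size) {
  for (int64_t root : roots) {
    if (root < 0 || root >= world_size) {
      return false;
    }
  }
  return true;
}
```

Note that this checks the user-supplied `roots` list only. The value `world_size` that a non-root rank ends up with internally is a derived sentinel, not an entry of `roots`, so rejecting it here does not conflict with the skip behavior discussed earlier.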
int nblocks = at::ceil_div(
    out_tile.numel() * out_tile.element_size(),
    (int64_t)THREADS_PER_BLOCK * 16);
nblocks = std::min(nblocks, 16);
Why limit it to 16? I think for the CUDA backend we limit it to at least 24; we may need even more for Blackwell.
Set it to 24 now.
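For reference, the block-count arithmetic with the new cap can be sketched as below. The value of `THREADS_PER_BLOCK` is not shown in this thread, so `kThreadsPerBlock = 512` is an assumption, and reading the factor 16 as "bytes handled per thread" is an interpretation of the snippet, not something the PR states.

```cpp
#include <algorithm>
#include <cstdint>

constexpr int64_t kThreadsPerBlock = 512;  // assumed; real value not shown
constexpr int64_t kBytesPerThread = 16;    // the literal 16 in the snippet
constexpr int64_t kMaxBlocks = 24;         // cap after this review comment

// Integer division rounding up, for positive operands.
int64_t ceil_div(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

// One block per (kThreadsPerBlock * kBytesPerThread) bytes of output tile,
// clamped to at most kMaxBlocks.
int64_t num_blocks(int64_t numel, int64_t element_size) {
  const int64_t bytes = numel * element_size;
  const int64_t nblocks = ceil_div(bytes, kThreadsPerBlock * kBytesPerThread);
  return std::min(nblocks, kMaxBlocks);
}
```

With these assumed constants, a tile of 1024 floats (4096 bytes) needs a single block, while any tile larger than 24 * 8192 bytes hits the cap.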
@pytorchbot merge
Perform multiple tile reductions concurrently, with each tile reduced to a separate root.
The number of concurrent reductions can be smaller than the world size, i.e. the roots can be a subset of all ranks. All ranks are still required to call into this API, including those that are not roots.
Currently supports NVLink SHARP scope only.
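The intended semantics, as read from the description above, can be illustrated with a small single-process simulation. This is a sketch under assumptions, not the actual kernel: `simulate_multi_root_tile_reduce` and `Tile` are hypothetical names, every rank is assumed to contribute one input tile per root, and the only reduction modeled is "sum".

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using Tile = std::vector<float>;

// in_tiles[rank][i] is the tile that `rank` contributes to the i-th reduction;
// roots[i] is the rank that receives the i-th result.
// Returns out[rank]: the summed tile on each root, empty on non-root ranks.
// Assumes in_tiles is non-empty, every rank supplies roots.size() tiles, and
// the i-th tile has the same length on every rank.
std::vector<Tile> simulate_multi_root_tile_reduce(
    const std::vector<std::vector<Tile>>& in_tiles,
    const std::vector<int64_t>& roots) {
  const std::size_t world_size = in_tiles.size();
  std::vector<Tile> out(world_size);
  for (std::size_t i = 0; i < roots.size(); ++i) {
    const std::size_t tile_len = in_tiles[0][i].size();
    Tile sum(tile_len, 0.0f);
    // Every rank contributes, whether or not it is a root itself.
    for (std::size_t rank = 0; rank < world_size; ++rank) {
      for (std::size_t k = 0; k < tile_len; ++k) {
        sum[k] += in_tiles[rank][i][k];
      }
    }
    out[static_cast<std::size_t>(roots[i])] = sum;
  }
  return out;
}
```

In the simulation, ranks that are not listed in `roots` end up with an empty output, which corresponds to the ranks that call the API only to satisfy the collective requirement.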
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci