Adding profiling capability to C++ DDP collective functions #46471
Conversation
Differential Revision: [D23948397](https://our.internmc.facebook.com/intern/diff/D23948397/) [ghstack-poisoned]
Differential Revision: [D23948397](https://our.internmc.facebook.com/intern/diff/D23948397/) ghstack-source-id: 114486799 Pull Request resolved: #46471
💊 CI failures summary: as of commit f995345 (details on the Dr. CI page), no failures so far.
💊 CI failures summary: as of commit 3f824f2 (details on the Dr. CI page), 3 failures not recognized by patterns; extra GitHub checks: 1 failed.
Can you update the PR summary to include an example of what the profiling output looks like with your changes?
Thanks for working on this, the overall structure looks good!
```cpp
// recordFunctionEndCallback_ is normally called in the finish() function by
// the base class, but since finish() is not called by WorkNCCL, we schedule
// this function to be run when the work is done.
work->getFuture()->addCallback(std::move(work->recordFunctionEndCallback_));
```
This would only work in the case `outputs.size() == 1`; we should validate that here.
I guess we would just need to change the if statement from `if (work->recordFunctionEndCallback_) {` to `if (work->recordFunctionEndCallback_ && can_profile) {`.
When the outputs size is greater than 1 (and so `can_profile` is false), `profiling_title` is null and so `recordFunctionEndCallback_` is not set. Added a comment.
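Taken together, the guard being discussed amounts to roughly the following sketch; `work` and `outputs` come from the quoted WorkNCCL context, so this is not standalone code:

```cpp
// Sketch (not standalone): schedule the end callback only when profiling
// attached one and there is a single output, i.e. can_profile is true.
bool can_profile = outputs.size() == 1;
if (work->recordFunctionEndCallback_ && can_profile) {
  // finish() would normally invoke recordFunctionEndCallback_, but WorkNCCL
  // never calls finish(), so run it when the work's future completes.
  work->getFuture()->addCallback(std::move(work->recordFunctionEndCallback_));
}
```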
```python
tensor = _build_tensor(src + 1).fill_(master_value if rank == src else worker_value)
if cuda:
    tensor = tensor.cuda(rank_to_GPU[rank][0])
self.call_dist_op("reduce", async_op, dist.reduce, tensor, src, op, group_id)
```
We should also test that `send` and `recv` work as well, for both Gloo and NCCL.
We have not implemented it for send and recv. What should we test?
```cpp
// Store references to outputs and futureNCCLCallbackStream to be used by
// WorkNCCL::getFuture.
work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
work->futureNCCLCallbackStreams_ = futureNCCLCallbackStreams_;
```
```cpp
if (work->recordFunctionEndCallback_) {
```
We probably need to enhance `pointToPoint` as well to cover send and recv?
I am not familiar with that part. Should we leave it as future work?
I think it's fine to add as a follow-up PR, but can we file a GH issue for this (and any other follow-up tasks)?
```cpp
using namespace torch::autograd::profiler;
// Make sure enabling the profiler does not cause any issues. Note that in
// single-process multi-device mode we do not expect any events to be
// populated for collective operations.
enableProfiler({ProfilerState::CPU});
auto results = pg_->allreduce(tensors_);
disableProfiler();
return results;
```
Do we need to add a C++ test since we're already covered by Python tests?
I think it would be useful from the profiler's perspective to test the C++ API as well, which skips the parsing/event-aggregation logic that happens in Python. We have similar tests in https://github.com/pytorch/pytorch/blob/master/test/cpp/jit/test_misc.cpp#L2185-L2198
```diff
@@ -9,6 +9,7 @@
 #include <c10/cuda/CUDAGuard.h>
```
We probably need to add tests to ProcessGroupMPITest to validate the profiling works correctly for that as well.
Awesome, this is looking great overall! Left some comments inline. Could you also paste what the profiling output looks like in the PR description? You can get that with `print(prof.key_averages().table())` in one of the tests.
```cpp
auto recordingFunction = std::make_shared<at::RecordFunction>(at::RecordScope::USER_SCOPE);
if (recordingFunction->active) {
  recordingFunction->before(profiling_title, {});
  std::function<void()> end_handler = [this, recordingFunction]() {
```
Can we `std::move(recordingFunction)` since it's not used after this anymore, to avoid a copy?
It is a shared_ptr, and the copy is cheap; right after this block the extra copy is destroyed anyway. If I wanted to use std::move, line 62 would change to something much less readable:
```cpp
std::function<void()> end_handler = [this, recordingFunction{std::move(recordingFunction)}]()
```
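As a self-contained aside (plain C++, not the PR's code), this is the trade-off in question: capturing a `shared_ptr` by copy costs one atomic refcount increment, while a C++14 init-capture moves the pointer in and empties the local variable.

```cpp
#include <functional>
#include <iostream>
#include <memory>

int main() {
  auto sp = std::make_shared<int>(42);

  // Capture by copy: refcount goes 1 -> 2; the lambda owns its own copy.
  std::function<void()> by_copy = [sp]() {
    std::cout << "copy sees " << *sp << "\n";
  };

  // Init-capture by move: no refcount bump, but the local `sp` is now null,
  // which is why the author finds this form less readable in context.
  std::function<void()> by_move = [p = std::move(sp)]() {
    std::cout << "move sees " << *p << "\n";
  };

  by_copy();  // still valid: the lambda holds its own shared_ptr
  by_move();
  std::cout << "local sp is " << (sp ? "non-null" : "null") << "\n";  // null
}
```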
torch/lib/c10d/ProcessGroupNCCL.cpp (Outdated)
```diff
@@ -1476,7 +1497,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall_base(
       comm,
       stream.stream());
     },
-    OpType::ALLTOALL_BASE);
+    OpType::ALLTOALL_BASE,
+    "all_to_all");
```
Just to confirm, this won't include shape information for now, right? That's fine for this diff, but I just wanted to make sure.
That's true!
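(If shape reporting is wanted later, one hedged possibility, assuming the `before()` overload quoted earlier accepts an inputs list as its second argument, would be to forward the collective's inputs; `inputs` here is a hypothetical `std::vector<at::Tensor>`:)

```cpp
// Hypothetical sketch: pass the inputs as IValues so the profiler could
// surface shapes when run with report_input_shapes enabled.
std::vector<c10::IValue> input_ivalues(inputs.begin(), inputs.end());
recordingFunction->before(profiling_title, input_ivalues);
```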
```cpp
enableProfiler({ProfilerState::CPU});
auto results = pg_->allreduce(tensors_);
disableProfiler();
return results;
```
Generally a test should assert something/some condition. Could we search through `results` and verify there is an allreduce here? You can see https://github.com/pytorch/pytorch/blob/master/test/cpp/jit/test_misc.cpp#L2185-L2198 as an example.
Actually, since this is the multi-device single-process case, no event should be collected unless the number of devices happens to be 1. So I am not sure what we can check.
Can we do a single-device-per-process test somehow? In the current version, are the profiling results empty?
Also, tangentially related: can you add a comment somewhere appropriate that specifies that profiling only works with a single device per process?
Following up here, is it possible to add some asserts on the expected result?
```python
events = [event for event in prof.function_events if partial_key in event.name]
return events[0] if len(events) > 0 else None
```
```python
recv_event = get_event(profiling_title)
```
Nit: here we've matched on a partial key; can we also add an assert for what the exact name would look like?
Changed it to match on the postfix, since the full name could be e.g. `nccl:reduce` or `gloo:reduce`.
```python
work = op(*args, async_op=async_op, **kwargs)
if async_op:
    work.wait()
    work._get_profiling_future().wait()
```
I don't think that's necessarily the case, as `work.wait()` could return without the profiling callback having run, and this wait ensures that the profiling callback (the one that terminates the record function) has finished successfully. We had to do something similar with RPC (see: https://github.com/pytorch/pytorch/pull/38352/files), although in that case it is transparent to the user.
Ideally, when there's profiling, `work.wait()` should ensure the profiling callbacks have run before returning, similar to RPC. It might depend on the future/work merge, though maybe we can get it to work now by modifying `::wait()` to await the profiling future if one exists.
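A rough sketch of that direction, with `profilingFuture_` as a hypothetical member holding the future completed by the record-function end callback; `cv_` and the exact signature are also assumptions, while `mutex_`, `completed_`, and `exception_` appear in the diff quoted further down:

```cpp
// Hypothetical sketch: have Work::wait() also block on the profiling future
// so users never need the explicit _get_profiling_future().wait().
bool ProcessGroup::Work::wait(std::chrono::milliseconds /* timeout elided */) {
  std::unique_lock<std::mutex> lock(mutex_);
  cv_.wait(lock, [&] { return completed_; });
  lock.unlock();
  if (profilingFuture_) {
    // Ensure the callback that ends the RecordFunction has run before
    // returning control to the user, mirroring what RPC does.
    profilingFuture_->wait();
  }
  if (exception_) {
    std::rethrow_exception(exception_);
  }
  return true;
}
```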
```python
events = [event for event in prof.function_events if partial_key in event.name]
return events[0] if len(events) > 0 else None
```
```python
recv_event = get_event(profiling_title)
```
Nit: Is it always a `recv` event, or are there different types of collective comm. calls here? If the latter is true, can we have a name such as `comm_event`, which might be less confusing?
No event is collected for send/recv for now. Should we add that in a follow-up PR?
Yes, that is fine. I was mostly asking because I was curious why it was named `recv_event`.
```python
recv_event = get_event(profiling_title)
if expect_event:
    self.assertEqual(recv_event.count, 1)
```
Can we have a test where we do > 1 collective comm of the same type, and then > 1 collective comm of different types, and validate the counts for those as well?
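For reference, a rough C++ analog of that kind of count check, reusing the assumed profiler API from the earlier sketch; the title substrings "allreduce"/"broadcast" are assumptions, and real names carry a backend prefix such as `nccl:` or `gloo:`:

```cpp
using namespace torch::autograd::profiler;

enableProfiler({ProfilerState::CPU});
pg_->allreduce(tensors_)->wait();  // two collectives of the same type...
pg_->allreduce(tensors_)->wait();
pg_->broadcast(tensors_)->wait();  // ...and one of a different type
auto event_lists = disableProfiler();  // assumed: per-thread event lists

size_t n_allreduce = 0, n_broadcast = 0;
for (const auto& events : event_lists) {
  for (const auto& evt : events) {
    const std::string name = evt.name();
    // Substring matches are assumptions about the profiling titles.
    n_allreduce += name.find("allreduce") != std::string::npos ? 1 : 0;
    n_broadcast += name.find("broadcast") != std::string::npos ? 1 : 0;
  }
}
TORCH_CHECK(n_allreduce == 2 && n_broadcast == 1,
            "unexpected collective event counts");
```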
```diff
@@ -131,6 +156,10 @@ void ProcessGroup::Work::finishAndThrow(std::exception_ptr exception) {
   std::unique_lock<std::mutex> lock(mutex_);
   completed_ = true;
   exception_ = exception;
+  if (recordFunctionEndCallback_) {
+    recordFunctionEndCallback_();
+    recordFunctionEndCallback_ = nullptr;
```
Do we have tests that exercise this code path (`finishAndThrow`)?
@mrzzd Just following up here - do we need to add these tests?
torch/lib/c10d/ProcessGroup.hpp (Outdated)
```diff
@@ -137,6 +135,10 @@ class ProcessGroup {

   OpType retrieveOpType();

+  // Keeps track of the future responsible for profiling owner creation
```
What does "profiling owner creation" mean? Do you just mean that this is a future that is completed when the profiling has finished?
```cpp
}


```
nit: unneeded extra lines?
```python
if is_async:
    for work in works:
        work.wait()
        work._get_profiling_future().wait()
```
I'm assuming that the test is flaky in the case where we remove this call? Is it okay to ship this prototype, given that we would need this explicit wait call from user code?
cc @pritamdamania87 - I guess we might be able to hack something, but in the long term this will probably depend on the future/work merge, and we would implement this by adding a `then` callback like we do for RPC. Do you have any thoughts on what we can do currently?
Thank you for the awesome work! It looks great, but there are some larger things I think we should talk about:
1. ProcessGroupMPI tests (looks like NCCL and Gloo are thoroughly tested).
2. send/recv profiling follow-up; file GH issues for that and any other follow-up tasks.
3. Discuss a design for removing `work._get_profiling_future().wait()` before we expose this to users.
4. Could you also add the profiling output to the PR summary?
Discussed (3) offline; we will try to remove the call, and the profiling should still be done transparently, since the NCCL callback should be invoked inline.
LGTM overall, thanks for doing this! Had a couple of nits and 2 comments about testing. Feel free to land after taking a look at those.
Codecov Report
```
@@            Coverage Diff             @@
##   gh/mrzzd/5/base   #46471      +/-   ##
===================================================
- Coverage     81.45%   81.43%    -0.02%
===================================================
  Files          1798     1798
  Lines        188242   188300       +58
===================================================
+ Hits         153333   153345       +12
- Misses        34909    34955       +46
```
This pull request has been merged in 160db3d.
Stack from ghstack:
Differential Revision: D23948397