Add multi-GPU support to FutureNCCL #48500
Conversation
💊 CI failures summary and remediations — as of commit 0928ffa (more details on the Dr. CI page): 💚 Looks good so far! There are no failures yet. 💚
LGTM!
torch/lib/c10d/ProcessGroupNCCL.hpp
Outdated
@@ -197,9 +198,7 @@ class ProcessGroupNCCL : public ProcessGroup {
//
// If created by WorkNCCL's getFuture API, FutureNCCL has a reference to
// WorkNCCL's cudaEvents, NCCL collective's outputs, and the device index of
index -> indices?
Thanks, fixed (it'll be visible once I re-export it)
I was thinking of doing a complete pass over comments, docstrings, etc. at the very end, once we have split out CUDAFuture, to document the reasons behind our design choices, the subtleties around stream usage, and so on.
torch/lib/c10d/ProcessGroupNCCL.hpp
Outdated
// outputs' device. Its value is NCCL collective's
// outputs. FutureNCCL only supports single-process single-device mode where
// the size of outputs is equal to 1.
// outputs' device. Its value is NCCL collective's outputs. |
device -> devices?
Thanks, fixed (it'll be visible once I re-export it)
// value, because the user's callback could use those other devices.
std::vector<at::cuda::CUDAStream> streams;
for (c10::DeviceIndex idx = 0; idx < c10::cuda::device_count(); idx++) {
  // FIXME Should we find a way to allow to change the priority of
when do we need high-priority streams?
I actually have no idea. Do you know when high- vs low-priority streams were used in ProcessGroupNCCL? What was the reason? Does it still apply here?
I have never seen a user trying to configure stream priority for NCCL ops, and I don't think it's possible with the init_process_group API. Users would probably have to use the undocumented ProcessGroupNCCL ctor API.
Besides, I am also not sure how much impact stream priority has on scheduling and how visible that is in end-to-end perf.
If there is no specific use case, we can probably keep it simple for now. It would be interesting to run some experiments to quantify its impact on perf in the future.
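For reference, here is a minimal sketch (not the exact PR code) of what grabbing one pool stream per device could look like, assuming the c10::cuda::getStreamFromPool API; its isHighPriority flag is the knob the FIXME above refers to.

```cpp
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>
#include <vector>

// Illustrative only: grab one side stream per device so a callback can run
// off the default stream on whichever device(s) it ends up touching.
std::vector<c10::cuda::CUDAStream> getCallbackStreams(bool isHighPriority = false) {
  std::vector<c10::cuda::CUDAStream> streams;
  for (c10::DeviceIndex idx = 0; idx < c10::cuda::device_count(); idx++) {
    // getStreamFromPool(isHighPriority, device) returns a stream from the
    // per-device pool; passing true yields a high-priority stream.
    streams.push_back(c10::cuda::getStreamFromPool(isHighPriority, idx));
  }
  return streams;
}
```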
Test failures look relevant:
Nov 26 23:08:31 ======================================================================
Nov 26 23:08:31 FAIL [3.027s]: test_all_gather_multigpu (__main__.TestDistBackendWithSpawn)
Nov 26 23:08:31 ----------------------------------------------------------------------
Nov 26 23:08:31 Traceback (most recent call last):
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 278, in wrapper
Nov 26 23:08:31 self._join_processes(fn)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 395, in _join_processes
Nov 26 23:08:31 self._check_return_codes(elapsed_time)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 450, in _check_return_codes
Nov 26 23:08:31 msg="Expected zero exit code but got {}".format(first_process.exitcode)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 1149, in assertEqual
Nov 26 23:08:31 self.assertTrue(result, msg=msg)
Nov 26 23:08:31 AssertionError: False is not true : Expected zero exit code but got -11
Nov 26 23:08:31
Nov 26 23:08:31 ======================================================================
Nov 26 23:08:31 FAIL [2.827s]: test_all_gather_multigpu_complex (__main__.TestDistBackendWithSpawn)
Nov 26 23:08:31 ----------------------------------------------------------------------
Nov 26 23:08:31 Traceback (most recent call last):
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 278, in wrapper
Nov 26 23:08:31 self._join_processes(fn)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 395, in _join_processes
Nov 26 23:08:31 self._check_return_codes(elapsed_time)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 450, in _check_return_codes
Nov 26 23:08:31 msg="Expected zero exit code but got {}".format(first_process.exitcode)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 1149, in assertEqual
Nov 26 23:08:31 self.assertTrue(result, msg=msg)
Nov 26 23:08:31 AssertionError: False is not true : Expected zero exit code but got -11
Nov 26 23:08:31
Nov 26 23:08:31 ======================================================================
Nov 26 23:08:31 FAIL [2.926s]: test_all_reduce_sum_cuda (__main__.TestDistBackendWithSpawn)
Nov 26 23:08:31 ----------------------------------------------------------------------
Nov 26 23:08:31 Traceback (most recent call last):
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 278, in wrapper
Nov 26 23:08:31 self._join_processes(fn)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 395, in _join_processes
Nov 26 23:08:31 self._check_return_codes(elapsed_time)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 450, in _check_return_codes
Nov 26 23:08:31 msg="Expected zero exit code but got {}".format(first_process.exitcode)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 1149, in assertEqual
Nov 26 23:08:31 self.assertTrue(result, msg=msg)
Nov 26 23:08:31 AssertionError: False is not true : Expected zero exit code but got -11
Nov 26 23:08:31
Nov 26 23:08:31 ======================================================================
Nov 26 23:08:31 FAIL [2.927s]: test_all_reduce_sum_cuda_async (__main__.TestDistBackendWithSpawn)
Nov 26 23:08:31 ----------------------------------------------------------------------
Nov 26 23:08:31 Traceback (most recent call last):
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 278, in wrapper
Nov 26 23:08:31 self._join_processes(fn)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 395, in _join_processes
Nov 26 23:08:31 self._check_return_codes(elapsed_time)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 450, in _check_return_codes
Nov 26 23:08:31 msg="Expected zero exit code but got {}".format(first_process.exitcode)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 1149, in assertEqual
Nov 26 23:08:31 self.assertTrue(result, msg=msg)
Nov 26 23:08:31 AssertionError: False is not true : Expected zero exit code but got -11
Nov 26 23:08:31
Nov 26 23:08:31 ======================================================================
Nov 26 23:08:31 FAIL [2.927s]: test_all_reduce_sum_cuda_complex (__main__.TestDistBackendWithSpawn)
Nov 26 23:08:31 ----------------------------------------------------------------------
Nov 26 23:08:31 Traceback (most recent call last):
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 278, in wrapper
Nov 26 23:08:31 self._join_processes(fn)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 395, in _join_processes
Nov 26 23:08:31 self._check_return_codes(elapsed_time)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 450, in _check_return_codes
Nov 26 23:08:31 msg="Expected zero exit code but got {}".format(first_process.exitcode)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 1149, in assertEqual
Nov 26 23:08:31 self.assertTrue(result, msg=msg)
Nov 26 23:08:31 AssertionError: False is not true : Expected zero exit code but got -11
Nov 26 23:08:31
Nov 26 23:08:31 ======================================================================
Nov 26 23:08:31 FAIL [2.927s]: test_broadcast_cuda (__main__.TestDistBackendWithSpawn)
Nov 26 23:08:31 ----------------------------------------------------------------------
Nov 26 23:08:31 Traceback (most recent call last):
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 278, in wrapper
Nov 26 23:08:31 self._join_processes(fn)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 395, in _join_processes
Nov 26 23:08:31 self._check_return_codes(elapsed_time)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 450, in _check_return_codes
Nov 26 23:08:31 msg="Expected zero exit code but got {}".format(first_process.exitcode)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 1149, in assertEqual
Nov 26 23:08:31 self.assertTrue(result, msg=msg)
Nov 26 23:08:31 AssertionError: False is not true : Expected zero exit code but got -11
Nov 26 23:08:31
Nov 26 23:08:31 ======================================================================
Nov 26 23:08:31 FAIL [2.932s]: test_nccl_high_priority_stream (__main__.TestDistBackendWithSpawn)
Nov 26 23:08:31 ----------------------------------------------------------------------
Nov 26 23:08:31 Traceback (most recent call last):
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 278, in wrapper
Nov 26 23:08:31 self._join_processes(fn)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 395, in _join_processes
Nov 26 23:08:31 self._check_return_codes(elapsed_time)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 450, in _check_return_codes
Nov 26 23:08:31 msg="Expected zero exit code but got {}".format(first_process.exitcode)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 1149, in assertEqual
Nov 26 23:08:31 self.assertTrue(result, msg=msg)
Nov 26 23:08:31 AssertionError: False is not true : Expected zero exit code but got -11
Nov 26 23:08:31
Nov 26 23:08:31 ======================================================================
Nov 26 23:08:31 FAIL [2.927s]: test_reduce_multigpu (__main__.TestDistBackendWithSpawn)
Nov 26 23:08:31 ----------------------------------------------------------------------
Nov 26 23:08:31 Traceback (most recent call last):
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 278, in wrapper
Nov 26 23:08:31 self._join_processes(fn)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 395, in _join_processes
Nov 26 23:08:31 self._check_return_codes(elapsed_time)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 450, in _check_return_codes
Nov 26 23:08:31 msg="Expected zero exit code but got {}".format(first_process.exitcode)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 1149, in assertEqual
Nov 26 23:08:31 self.assertTrue(result, msg=msg)
Nov 26 23:08:31 AssertionError: False is not true : Expected zero exit code but got -11
Nov 26 23:08:31
Nov 26 23:08:31 ======================================================================
Nov 26 23:08:31 FAIL [2.927s]: test_reduce_sum_cuda (__main__.TestDistBackendWithSpawn)
Nov 26 23:08:31 ----------------------------------------------------------------------
Nov 26 23:08:31 Traceback (most recent call last):
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 278, in wrapper
Nov 26 23:08:31 self._join_processes(fn)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 395, in _join_processes
Nov 26 23:08:31 self._check_return_codes(elapsed_time)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 450, in _check_return_codes
Nov 26 23:08:31 msg="Expected zero exit code but got {}".format(first_process.exitcode)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 1149, in assertEqual
Nov 26 23:08:31 self.assertTrue(result, msg=msg)
Nov 26 23:08:31 AssertionError: False is not true : Expected zero exit code but got -11
Nov 26 23:08:31
Nov 26 23:08:31 ======================================================================
Nov 26 23:08:31 FAIL [2.928s]: test_reduce_sum_cuda_twice (__main__.TestDistBackendWithSpawn)
Nov 26 23:08:31 ----------------------------------------------------------------------
Nov 26 23:08:31 Traceback (most recent call last):
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 278, in wrapper
Nov 26 23:08:31 self._join_processes(fn)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 395, in _join_processes
Nov 26 23:08:31 self._check_return_codes(elapsed_time)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 450, in _check_return_codes
Nov 26 23:08:31 msg="Expected zero exit code but got {}".format(first_process.exitcode)
Nov 26 23:08:31 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 1149, in assertEqual
Nov 26 23:08:31 self.assertTrue(result, msg=msg)
Nov 26 23:08:31 AssertionError: False is not true : Expected zero exit code but got -11
Nov 26 23:08:31
Nov 26 23:08:31 ----------------------------------------------------------------------
torch/lib/c10d/ProcessGroupNCCL.hpp
Outdated
auto fut = c10::make_intrusive<FutureNCCL>();
// The new future needs the DataPtr extractor when it gets marked complete
// but this might happen immediately inline or in parallel by another
// thread. In both these cases this would/might happen before the user has
// time to set their own DataPtr extractor, which might lead to failures
// if the default extractor can't handle some of the user's types.
// Therefore we propagate our extractor.
I realized that since we now use the DataPtr extractor also in markCompleted, we need to do this ugly thing. I suspect that this was the cause of the CI failures (we'll see shortly). I welcome ideas on how to do this better...
It turned out the segfault came from accessing the parameters rather than the fields when doing the assert inside the constructor (i.e., the difference was a trailing underscore). This led to dereferencing a null pointer.
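To illustrate the class of bug described here (hypothetical names, not the actual FutureNCCL code): the constructor parameter and the member field differ only by a trailing underscore, and once the parameter has been moved into the field, touching the wrong one dereferences an empty pointer.

```cpp
#include <cassert>
#include <memory>
#include <utility>

struct Holder {
  explicit Holder(std::shared_ptr<int> value) : value_(std::move(value)) {
    // BUG: `value` (no underscore) has just been moved into `value_`, so
    // dereferencing the parameter here is a null dereference:
    //   assert(*value == 42);
    // FIX: assert on the field, which now owns the data:
    assert(*value_ == 42);
  }
  std::shared_ptr<int> value_;
};

int main() {
  Holder h(std::make_shared<int>(42)); // fine with the corrected assert
  (void)h;
  return 0;
}
```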
torch/csrc/jit/python/pybind_utils.h
Outdated
@@ -224,6 +224,7 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper
// vector that has exactly one tensor.
static std::vector<std::reference_wrapper<const at::DataPtr>> dataPtrExtractor(
    const at::IValue& value) {
TORCH_INTERNAL_ASSERT(PyGILState_Check() == 0); |
I also added this because I was seeing some deadlocks (before I fixed the segfault) which seemed unrelated, and my suspicion was that they came from a reentrant GIL acquisition, possibly due to the DataPtr extractor now being used in more places. However, at the moment the CI doesn't seem to trigger either the deadlock or this check.
The deadlocks persist, so since I couldn't repro them on my local machine, I logged into the CircleCI executor and ran one test with CUDA_LAUNCH_BLOCKING=1. Here are the stack traces: https://gist.github.com/lw/64c4e677476461887c479c6fcbcf0d00. My current best guess for what is happening is:
For some more context on why it's happening:
Note that I don't think the order of operations here is critically relevant. I suspect that even if we invoked the callback after we performed the NCCL operation a similar issue could occur. I have a few ideas on how we could tackle this:
I don't really have a preference yet among these options; I don't like any of them. I'll have to think a bit more about it. The "core" design choice we need to make, I think, is whether we want to initialize all devices in all processes. When a process creates a CUDA context on a device, this consumes memory on that device, which is undesirable if the user doesn't intend to use the device. So, to rephrase the issue: can we assume that if users didn't explicitly prevent us from using some devices (via CUDA_VISIBLE_DEVICES), they are fine with us touching all of them?
This actually looks fine to me if we are only talking about Python callbacks. IIUC, if users touch a different device in the callback, the Python program will initialize the context first. And if they also want to use the result tensor on that new device, I would assume they need a to/copy operation, which will synchronize streams on both src and dst. Did I miss any situation where this might go wrong?
LGTM!
@@ -212,28 +211,37 @@ class ProcessGroupNCCL : public ProcessGroup {
 public:
  explicit FutureNCCL(
      at::IValue value,
      c10::DeviceIndex deviceIndex,
      std::vector<c10::DeviceIndex> deviceIndices,
do we expect devices to be distinct?
We do, although I think it should be fine if there are duplicates. When the CUDAFuture itself determines the set of devices (inside markCompleted) it explicitly deduplicates them. Let me add a check to the other constructor (the one invoked by ProcessGroupNCCL) to ensure they are distinct.
One issue I was concerned about was the case where the user's callback returns a value that contains a tensor on another device. Such a value will be stored in the child Future (the one returned by then()), which will need to record an event for each device used by the new value, and will do so on the current stream. Suppose the parent Future has a value that resides on device 0: when invoking the callback it only gets a stream from the pool for device 0, and doesn't touch the other devices. If device 1 is not initialized yet, its current stream at this point will be the default stream (i.e., nullptr). The user callback can very well use device 1 (and thus initialize it), but the user's current stream for device 1 will be reset when the callback exits, meaning the child Future can't access that stream to record an event on it.

If the user properly synchronizes their own stream with the default stream then everything should still work. However, the concern is that using the default stream has poor performance, since AFAIK it cannot run in parallel with any other stream.

Another concern is that we could run into the same deadlock issue with NCCL that we experienced at some point in the iterations of this PR. Especially with multiple threads, it could happen that one user callback initializes device 1 while another thread is performing a NCCL callback, and if these two operations race in the wrong way they could get stuck. However this issue might be minor, as from what I could observe it doesn't seem to affect newer devices, and moreover the user can easily circumvent it by explicitly initializing the devices before starting any NCCL operations.
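As a rough sketch of the mechanism described above (assumed, not the exact PR code): the child Future has to inspect the callback's result via its DataPtrs, deduplicate the devices the result lives on, and record one event per device on that device's current stream. If a device was never initialized, that current stream is the default stream, which is the performance concern raised here.

```cpp
#include <ATen/cuda/CUDAEvent.h>
#include <c10/core/Allocator.h>
#include <c10/cuda/CUDAStream.h>
#include <functional>
#include <set>
#include <vector>

// Illustrative only: given the DataPtrs extracted from the callback's result,
// figure out which CUDA devices it actually uses (deduplicated) and record
// one event per such device on that device's *current* stream.
std::vector<at::cuda::CUDAEvent> recordEventsForResult(
    const std::vector<std::reference_wrapper<const c10::DataPtr>>& dataPtrs) {
  std::set<c10::DeviceIndex> usedDevices;
  for (const c10::DataPtr& ptr : dataPtrs) {
    if (ptr.device().is_cuda()) {
      usedDevices.insert(ptr.device().index());
    }
  }
  std::vector<at::cuda::CUDAEvent> events;
  for (c10::DeviceIndex idx : usedDevices) {
    at::cuda::CUDAEvent event;
    // If the callback never initialized this device, the current stream here
    // is the default stream, hence the concern above.
    event.record(c10::cuda::getCurrentCUDAStream(idx));
    events.push_back(std::move(event));
  }
  return events;
}
```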
This pull request has been merged in e294c2d.
Stack from ghstack:
This commit is part of a stack that reworks FutureNCCL in order to extract a generic CUDA-aware Future subclass. The stack deliberately breaks up this transition into elementary changes, to make it easier to verify that the behavior is preserved (or to highlight how it gets changed).
After the previous changes, this is now much simpler than it sounds. For the most part it just consists of repeating some operations multiple times, once per device (e.g., recording and blocking on events). Funnily, we already had a vector of events, even though we only ever stored one element in it (this probably comes from the fact that this is shared with WorkNCCL, which can hold more than one event). Here, we now also store a vector of device indices.
Perhaps the only non-trivial part of this is that now, for "follow-up" Futures (for callbacks), we can't know in advance which device the result will be on, so we must determine it dynamically when we receive the result, by inspecting it. That's also easier than it sounds because we already have a dataptr extractor.
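For illustration of the "once per device" pattern on the wait side, here is a minimal sketch assuming the at::cuda::CUDAEvent::block and c10::cuda::getCurrentCUDAStream APIs; it shows the approach, not the code as merged.

```cpp
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <vector>

// Illustrative only: make the caller's current stream on every involved
// device wait on the event that was recorded for that device.
void blockCurrentStreams(
    const std::vector<c10::DeviceIndex>& deviceIndices,
    std::vector<at::cuda::CUDAEvent>& cudaEvents) {
  for (size_t i = 0; i < deviceIndices.size(); i++) {
    // block() enqueues a cudaStreamWaitEvent, so later work queued on that
    // stream runs only after the recorded event has fired.
    cudaEvents[i].block(c10::cuda::getCurrentCUDAStream(deviceIndices[i]));
  }
}
```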
Differential Revision: D25177556