
[RFC] CUDA-aware future for distributed #48305

Open
lw opened this issue Nov 20, 2020 · 3 comments

lw commented Nov 20, 2020

We're working on adding CUDA support to the RPC module, and we're revamping the ProcessGroups by having them return Futures rather than Works. For those purposes we're designing a general-purpose Future class that can handle CUDA tensors (and thus work with CUDA streams). The FutureNCCL class would be merged with/replaced by this one.

Here is our current proposal, on which we welcome any kind of feedback!

  • The main point is: when do we mark such futures as complete, and what are the requirements for that to happen? One option: we mark them complete when the user can use the value they hold. But in CUDA, “using” a tensor simply consists of enqueuing new kernels on a stream, as long as they are properly sequenced with the previous ops (same stream, or through events).

    Hence: the PG should mark the future complete once it can guarantee that the user is able to schedule new kernels and that those kernels can be properly ordered after the transfer.

    Note that there is no requirement for the future to be marked complete immediately when the PG has enqueued its operations: that’s just a lower bound. The future can also be marked complete at any later point in time, and the above condition would still hold.

    This means that different PGs could have different approaches:

    • The NCCL PG could mark the future complete immediately after the ncclFoo function call returns, as at that point NCCL has already enqueued all its ops. Hence the future is already completed when it’s returned.
    • The UCC PG (not part of PyTorch, but used as an example) could wait until the operations actually flush on the stream before marking it as complete. Hence the cudaEvents contained in the future (see below) will already have occurred when it’s marked complete. (In such a case, non-blocking wait would behave the same as blocking wait, since both are no-ops wrt CUDA, but that’s fine).
    • In RPC, we would mark it complete when we receive the CPU metadata from the callee’s response (which contains number and sizes of the tensors), without waiting for the CUDA ops to flush.
  • In order for the user of the future to be able to wait for the async CUDA ops to flush, we allow the .wait() method to be either blocking or non-blocking (the former should be the default, as it’s safer, while the latter is for power users). The blocking version will block the user’s calling thread, whereas the non-blocking version will just synchronize the user’s current streams so that any new ops enqueued on them run after the async CUDA transfers.

    • Note: the “current stream” concept is PyTorch-specific. It consists of a thread-local variable that identifies one stream per device, and all CUDA ops implicitly run on the current stream of their device.
  • The .then() method will only have a non-blocking version, since it’s a more advanced API (less likely to be used by inexperienced users) and since it’s harder to support a blocking version with NCCL (it would require an extra watchdog thread in which to run the callback) and with RPC. Before running a callback we’ll set new current streams and synchronize them with the PG’s streams, so that any op launched by the callback will naturally be ordered after the transfer.

    • It could still be possible to implement a blocking version later on, if we deem that the demand for it outweighs the difficulties in implementing it.
  • In order for all of the above to work, all we need is for the CudaFuture class to contain one cudaEvent per device (see the sketch after this list).

    • Just before the future is marked complete, these events are recorded on the PG’s I/O streams.
    • When the blocking wait is called, we call cudaEventSynchronize on these events.
    • When the non-blocking wait is called, we call cudaStreamWaitEvent on the user’s current streams with those events.
    • When we run a .then() callback, we pick new streams from the ATen stream pool, call cudaStreamWaitEvent on them with those events, set these streams as current, and launch the callback.
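
To make the mechanics concrete, here is a minimal C++ sketch of the per-device event bookkeeping described above. It is illustrative only: the class name, members, and helper methods are hypothetical (not the actual CUDAFuture implementation), and it uses the raw CUDA runtime API rather than ATen’s wrappers.

```cpp
#include <cuda_runtime.h>
#include <utility>
#include <vector>

// Hypothetical sketch, not the real CUDAFuture class.
struct CudaFutureSketch {
  // One (device, event) pair per device holding tensors of the future's value.
  std::vector<std::pair<int, cudaEvent_t>> events_;

  // Producer side (e.g. the ProcessGroup): record an event on each of its
  // I/O streams just before marking the future complete.
  void recordOn(int device, cudaStream_t ioStream) {
    cudaEvent_t ev;
    cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
    cudaEventRecord(ev, ioStream);
    events_.emplace_back(device, ev);
  }

  // Blocking wait: stall the calling host thread until the async CUDA work
  // behind each event has actually finished.
  void waitBlocking() {
    for (auto& entry : events_) {
      cudaEventSynchronize(entry.second);
    }
  }

  // Non-blocking wait: don't block the host thread; just make the caller's
  // current stream on each device wait for the corresponding event, so any
  // new kernels enqueued there run after the async transfers.
  void waitNonBlocking(const std::vector<std::pair<int, cudaStream_t>>& currentStreams) {
    for (auto& entry : events_) {
      for (auto& cur : currentStreams) {
        if (cur.first == entry.first) {
          cudaStreamWaitEvent(cur.second, entry.second, /*flags=*/0);
        }
      }
    }
  }
};
```

Under the same sketch, a .then() callback would pick fresh streams (e.g. from the ATen pool), make them wait on the recorded events via cudaStreamWaitEvent, set them as current, and only then invoke the callback.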

Implementation plan:

  • We intend to add the non_blocking flag to ivalue::Future, even though it would have no effect there (as it's CPU-only). However, that would allow us to keep the CudaFuture class as a private subclass, which we return upcast to a regular ivalue::Future. This way a uniform interface is presented, and no extra Python or TorchScript bindings are necessary (sketched below).
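
A rough sketch of that interface idea, with hypothetical class and method names (not the actual ivalue::Future API): the flag exists on the base class but only the CUDA subclass gives it meaning, so callers only ever see the base type.

```cpp
#include <memory>

// Hypothetical names; a sketch of the "uniform interface" idea only.
struct FutureSketch {
  virtual ~FutureSketch() = default;
  // CPU-only base: blocks until the value is set; non_blocking has no effect.
  virtual void wait(bool non_blocking = false) { (void)non_blocking; /* block on completion */ }
};

struct CudaFutureSketch : FutureSketch {
  void wait(bool non_blocking = false) override {
    if (non_blocking) {
      // cudaStreamWaitEvent(...) on the caller's current streams.
    } else {
      // cudaEventSynchronize(...) on the per-device events.
    }
  }
};

// The ProcessGroup / RPC agent builds the CUDA subclass internally but hands
// it out as the base type, so Python/TorchScript bindings stay unchanged:
//   std::shared_ptr<FutureSketch> fut = std::make_shared<CudaFutureSketch>();
```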

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @rohan-varma @jjlilley @osalpekar @jiayisuse @agolynski @SciPioneer @H-Huang @mrzzd

@lw lw added module: rpc Related to RPC, distributed autograd, RRef, and distributed optimizer module: c10d Issues/PRs related to collective communications and process groups labels Nov 20, 2020
@facebook-github-bot facebook-github-bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Nov 20, 2020
@mrshenli mrshenli added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Nov 20, 2020

idoh commented Nov 20, 2020

@lw Thank you for the great TensorPipe framework. I'm looking forward to the RPC CUDA support. Do you think this will be released in PyTorch v1.8?


lw commented Nov 26, 2020

Do you think this will be released in PyTorch v1.8?

That's our current intention, and I think we're on track to do so.


lw commented Nov 26, 2020

BTW, I just sent out a stack that extracts a generic CUDAFuture from FutureNCCL. It starts at #48495.

When doing that I realized that not all streams are equal: there are in fact both high- and low-priority streams, and we need to specify which type we want when we obtain them from the pool. Unless we hardcode one of these types, we may need to allow the user (or the creator of the Future, which may be the ProcessGroup or the RPC agent) to specify which priority it wants.

Another option that I'm still evaluating is whether it could make sense to run the callbacks in the streams that were current when the user called then() to install the callback. It may be a bit less intuitive (is it?), but on the other hand it gives the user full control over which streams are used for all operations, and it avoids the global ATen pool altogether. A rough sketch of both options is below. WDYT?
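
For illustration, here is a minimal sketch of the two options being compared, assuming the c10 CUDA stream API; the function names and bodies are hypothetical, not part of any existing Future interface.

```cpp
#include <c10/cuda/CUDAStream.h>

// Option A: grab a fresh stream from the global ATen pool, letting the caller
// (or the future's creator) choose the priority.
void runCallbackOnPoolStream(bool isHighPriority, c10::DeviceIndex device) {
  c10::cuda::CUDAStream stream =
      c10::cuda::getStreamFromPool(isHighPriority, device);
  // ... make `stream` wait on the future's events, set it as current,
  // then run the user's callback.
}

// Option B: reuse the stream that was current when then() was called, giving
// the user full control and avoiding the global pool entirely.
void runCallbackOnCallerStream(const c10::cuda::CUDAStream& streamAtThenTime) {
  // ... make `streamAtThenTime` wait on the future's events, set it as
  // current, then run the user's callback.
}
```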
