[RFC] CUDA-aware future for distributed #48305
Labels: module: c10d, module: rpc, oncall: distributed, triaged
We're working on adding CUDA support to the RPC module, and we're revamping the ProcessGroups by having them return Futures rather than Works. For those purposes we're designing a general-purpose Future class that can handle CUDA tensors (and thus work with CUDA streams). The FutureNCCL class would be merged with/replaced by this one.
Here is our current proposal, on which we welcome any kind of feedback!
The main question is: when do we mark such futures as complete, and what are the requirements for that to happen? One option is to mark them complete when the value they hold can be used by the user. In CUDA, however, “using” a tensor simply consists of enqueuing new kernels on a stream, as long as those kernels are properly sequenced with the previous ops (same stream, or through events).
Hence, the PG should mark the future complete once it can ensure that the user is able to schedule new kernels and have them properly ordered after the transfer.
Note that there is no requirement for the future to be marked complete immediately when the PG has enqueued its operations: that’s just a lower bound. The future can also be marked complete at any later point in time, and the above condition would still hold.
This means that different PGs could take different approaches: for instance, marking the future complete as soon as the CUDA ops are enqueued, or only once the transfer has actually finished.
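As a concrete illustration, here is a minimal sketch (not the actual FutureNCCL/CudaFuture code) of the "mark complete at enqueue time" end of that spectrum: right after enqueuing its CUDA ops, the PG records one event per device on its internal streams and flips the future to complete. `MyCudaFuture` and `markCompleted` are hypothetical names; only the event/stream APIs are ATen's.

```cpp
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <vector>

// Hypothetical future type: one CUDA event per device involved in the op.
struct MyCudaFuture {
  std::vector<at::cuda::CUDAEvent> events_;
  bool completed_ = false;

  // Called by the PG right after it has enqueued its async CUDA ops on its
  // internal streams: record the post-op point on each stream, then mark
  // the future complete. The GPU work may still be running at this point.
  void markCompleted(const std::vector<c10::cuda::CUDAStream>& pgStreams) {
    for (const auto& stream : pgStreams) {
      at::cuda::CUDAEvent event;
      event.record(stream);
      events_.push_back(std::move(event));
    }
    completed_ = true;
  }
};
```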
In order for the user of the future to be able to wait for the async CUDA ops to complete, we allow the .wait() method to be either blocking or non-blocking (the former should be the default, as it's safer, while the latter is for power users). The blocking version will block the user's calling thread, whereas the non-blocking version will just synchronize the user's current streams, so that any new ops enqueued on them will run after the async CUDA transfers.
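Under the one-event-per-device scheme, the two flavors of .wait() could look roughly like this (a sketch assuming the events recorded in the previous snippet; `waitOnEvents` is a hypothetical helper):

```cpp
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <vector>

// Blocking: stall the calling CPU thread until the async ops have finished.
// Non-blocking: only make the caller's current streams wait on the events,
// so any new kernels enqueued on them run after the transfer.
void waitOnEvents(std::vector<at::cuda::CUDAEvent>& events, bool blocking) {
  for (auto& event : events) {
    if (blocking) {
      event.synchronize();
    } else {
      event.block(c10::cuda::getCurrentCUDAStream(event.device_index()));
    }
  }
}
```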
The .then() method will only have a non-blocking version, since it's a more advanced API (less likely to be used by inexperienced users) and since a blocking version would be harder to support with NCCL (it would require an extra watchdog thread in which to run the callback) and with RPC. Before running the callbacks, we'll set the current streams and synchronize them with the PG's streams, so that any op launched by the callback will naturally be sequenced after the transfer.
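A sketch of what that could look like for a single callback (hypothetical names again; the real design would set fresh current streams first, while this sketch reuses the ambient ones for brevity):

```cpp
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <functional>
#include <vector>

// Make the current streams wait on the future's events, then invoke the
// callback: any kernel it launches is ordered after the async transfer,
// without ever blocking a CPU thread.
void runCallback(std::vector<at::cuda::CUDAEvent>& events,
                 const std::function<void()>& callback) {
  for (auto& event : events) {
    event.block(c10::cuda::getCurrentCUDAStream(event.device_index()));
  }
  callback();
}
```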
In order for all of the above to work, all we need is for the CudaFuture class to contain one cudaEvent for each device.
Implementation plan: add a non_blocking flag to ivalue::Future, even though it would have no effect there (as it's CPU-only). However, that would allow us to keep the CudaFuture class as a private subclass, which we upcast to a regular ivalue::Future when we return it. This way a uniform interface is presented, and no extra Python or TorchScript bindings are necessary.
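To make the shape of that plan concrete, here is a hypothetical sketch of the resulting interface (stand-in classes, not the actual c10::ivalue declarations):

```cpp
#include <memory>

// CPU-only base: accepts non_blocking for interface uniformity, but the
// flag has no effect here.
struct Future {
  virtual ~Future() = default;
  virtual void wait(bool non_blocking = false) { (void)non_blocking; }
};

// Private CUDA subclass: overrides wait() with the event-based logic
// sketched earlier (block the thread, or just the current streams).
struct CudaFuture : Future {
  void wait(bool non_blocking = false) override {
    // ... synchronize or block on the per-device events ...
  }
};

// Factories return the base type, so Python/TorchScript only ever need
// bindings for a single Future class.
std::shared_ptr<Future> makeCollectiveFuture() {
  return std::make_shared<CudaFuture>();
}
```

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @rohan-varma @jjlilley @osalpekar @jiayisuse @agolynski @SciPioneer @H-Huang @mrzzd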