[RFC] Manage CUDA Stream in TensorPipe RPC Agent #44084

Open
mrshenli opened this issue Sep 3, 2020 · 5 comments
Labels
  • feature (a request for a proper, new feature)
  • module: rpc (related to RPC, distributed autograd, RRef, and distributed optimizer)
  • module: tensorpipe (related to Tensorpipe RPC Agent)
  • triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

mrshenli (Contributor) commented Sep 3, 2020

with @lw @beauby

API

There is no new API, but there is new behavior when sending and receiving CUDA tensors. The TensorPipe RPC agent will maintain a stream pool and grab streams from the pool to 1) send tensors, 2) receive tensors, and 3) run user functions.

The guarantees we offer are:

  1. RPC user functions are launched once all comm ops have been enqueued to the current stream for all involved devices, but there is no guarantee that any of the comm ops are done. This is fine if user functions do not switch streams. However, if they do use a different stream, they need to synchronize explicitly (see the sketch after this list).
  2. NCCL_BLOCKING_WAIT dictates the behavior of Future.wait() when the response contains CUDA tensors. If NCCL_BLOCKING_WAIT=1 at RPC initialization time, a returning Future.wait() means that all ops are done. Otherwise, Future.wait() only guarantees that all CUDA ops have been enqueued to the stream.
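
To make guarantee 1 concrete, here is a minimal C++ sketch, using ATen/c10's public CUDA stream and event helpers, of what a user function that switches to its own stream would need to do. The function, its body, and the stream names are purely illustrative and not part of the proposal.

```cpp
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

// Hypothetical RPC user function that runs its work on a side stream.
at::Tensor userFunction(const at::Tensor& input) {
  const auto device = input.device().index();
  auto current = c10::cuda::getCurrentCUDAStream(device);
  auto side = c10::cuda::getStreamFromPool(/*isHighPriority=*/false, device);

  // The receive of `input` is only *enqueued* on `current`, so the side
  // stream must first wait for it (without blocking the CPU).
  at::cuda::CUDAEvent received;
  received.record(current);
  received.block(side);

  at::Tensor out;
  {
    c10::cuda::CUDAStreamGuard guard(side);
    out = input * 2;  // the actual user computation
  }

  // The response will be sent on the current stream, so hand the result back:
  // make `current` wait for the side stream before returning.
  at::cuda::CUDAEvent done;
  done.record(side);
  done.block(current);
  return out;
}
```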

Design

The design follows two principles:

  • No RPC CUDA tensor transfer should block ops in the current stream.
  • All ops in a stream must be completed before the stream is returned to the pool, to avoid polluting its next user.

Request on the Caller

  • Determine the set of devices used by the CUDA tensors in the message.
  • Get the current streams for all these devices.
  • Record an event in each of these current streams (and then we won’t do anything more on these streams)
  • Pick a new fresh stream for each device from the stream pool
  • On each of these new streams, enqueue a wait for the above events
  • Pass these new streams to TensorPipe in the form of cudaStream_t (as TensorPipe cannot depend on PyTorch and hence cannot use PyTorch's Stream type)
  • In the on-complete callback of pipe->write, either capture the stream in another lambda for cudaStreamAddCallback, or use a thread to synchronize the stream before destroying it (returning it to the pool).
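
A rough sketch of this caller-side setup against c10's public stream/event helpers; prepareSendStreams, the tensor list argument, and the raw-stream out-parameter are stand-ins rather than the agent's real interface. The deferred synchronization in the write callback is sketched after the response sections below.

```cpp
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <set>
#include <vector>

// For each CUDA device touched by the outgoing message: record an event on
// the user's current stream, grab a fresh pool stream, make it wait for that
// event, and expose it to TensorPipe as a plain cudaStream_t.
std::vector<c10::cuda::CUDAStream> prepareSendStreams(
    const std::vector<at::Tensor>& tensors,
    std::vector<cudaStream_t>& rawStreamsForTensorPipe) {
  std::set<c10::DeviceIndex> devices;
  for (const auto& t : tensors) {
    if (t.is_cuda()) {
      devices.insert(t.device().index());
    }
  }

  std::vector<c10::cuda::CUDAStream> sendStreams;
  for (const auto deviceIdx : devices) {
    auto current = c10::cuda::getCurrentCUDAStream(deviceIdx);
    at::cuda::CUDAEvent userWorkDone;
    userWorkDone.record(current);  // nothing else is done on `current`

    auto sendStream =
        c10::cuda::getStreamFromPool(/*isHighPriority=*/false, deviceIdx);
    // The transfer waits for the user's pending work, but the user's stream
    // is never blocked by the transfer.
    userWorkDone.block(sendStream);

    // TensorPipe cannot depend on PyTorch, so it only sees cudaStream_t.
    rawStreamsForTensorPipe.push_back(sendStream.stream());
    sendStreams.push_back(sendStream);
  }
  return sendStreams;
}
```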

Request on the Callee

  • Let TensorPipe::Tensor.metadata capture the device information, so that the callee can tell from the descriptor which devices it will use before calling pipe->read.
  • Grab one stream for each distinct device, and pass those streams to TensorPipe pipe->read.
  • In pipe->read's on-complete callback, set those streams from the pool as current, and then directly launch user functions without synchronization. All comm ops are already enqueued into the current stream. If user functions need to use different streams, they need to explicitly synchronize the streams.
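
A small sketch of that last step, assuming the agent already holds the streams used by pipe->read; runUserFunctionOnRecvStreams and the manual save/restore of the current streams are illustrative stand-ins for whatever guard the real implementation uses.

```cpp
#include <c10/cuda/CUDAStream.h>
#include <functional>
#include <vector>

// Make each receive stream the current stream of its device, run the user
// function directly on it (the receives are already ordered before it), then
// restore the previously-current streams.
void runUserFunctionOnRecvStreams(
    const std::vector<c10::cuda::CUDAStream>& recvStreams,
    const std::function<void()>& userFn) {
  std::vector<c10::cuda::CUDAStream> previous;
  for (const auto& s : recvStreams) {
    previous.push_back(c10::cuda::getCurrentCUDAStream(s.device_index()));
    c10::cuda::setCurrentCUDAStream(s);
  }

  // No synchronization: stream order alone guarantees the user function sees
  // the received tensors. Switching streams inside userFn is the user's
  // responsibility (guarantee 1 above).
  userFn();

  for (const auto& s : previous) {
    c10::cuda::setCurrentCUDAStream(s);
  }
}
```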

Response on the Callee

  • Once the user function returns, preserve the current streams (they should still be the same ones from the pool).
  • Do not synchronize on these streams and directly pass them to TensorPipe to send response message.
  • After the write callback from TensorPipe pipe->write fires, we synchronize on those streams and only return them to their pools once all the enqueued work on them has completed.
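
The deferred synchronization is the same on both write paths (the caller's request and the callee's response). Below is a sketch of the thread-based variant mentioned earlier; the detached thread only stands in for the agent's own thread pool, and the cudaStreamAddCallback variant would avoid tying up a thread at all.

```cpp
#include <c10/cuda/CUDAStream.h>
#include <thread>
#include <vector>

// Called from the pipe->write on-complete callback: drain the streams used
// for the transfer before they are treated as free again.
void releaseStreamsAfterWrite(std::vector<c10::cuda::CUDAStream> streams) {
  std::thread([streams = std::move(streams)]() mutable {
    for (auto& s : streams) {
      s.synchronize();  // all enqueued sends on this stream are now done
    }
    // In this design, dropping the CUDAStream handles is what "returns" the
    // streams to the pool.
  }).detach();
}
```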

Response on the Caller

  • When the response comes in, pick a stream for each relevant device from the pool.
  • Pass these streams to TensorPipe to use for receiving the response tensors
  • After the read callback fires, synchronize with the work in these streams before returning them to the pool.
  • If NCCL_BLOCKING_WAIT=1, mark the Future as complete after all streams are synchronized. Otherwise, mark the Future as complete when TensorPipe pipe->read returns.
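
A sketch of how the last two bullets might look inside the pipe->read callback; future, response, and the blockingWait flag (read from NCCL_BLOCKING_WAIT at agent construction) are hypothetical stand-ins for the agent's actual members.

```cpp
#include <ATen/core/ivalue.h>
#include <c10/cuda/CUDAStream.h>
#include <vector>

// Complete the caller-side future once TensorPipe has enqueued the receives.
void onResponseRead(
    std::vector<c10::cuda::CUDAStream> recvStreams,
    c10::intrusive_ptr<c10::ivalue::Future> future,
    c10::IValue response,
    bool blockingWait /* NCCL_BLOCKING_WAIT=1 at RPC init time */) {
  if (blockingWait) {
    // Strong semantics: by the time wait() returns, the tensors are ready.
    for (auto& s : recvStreams) {
      s.synchronize();
    }
  }
  // Otherwise the receives are only enqueued; a user consuming the tensors on
  // a different stream must synchronize explicitly.
  future->markCompleted(std::move(response));
}
```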

Discussion

  1. Do we need a "power-user" mode to disable the stream pool on the caller? We could remember the current streams when sending the request and use them to receive the response. This would give users more control.
  2. Regarding the stream pool: the first version will use c10::cuda::getStreamFromPool, and a later version will add a new stream pool implementation which, instead of using a fixed number of streams, creates new ones when its queue is depleted (a sketch of such a pool follows).
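
To make point 2 concrete, here is a hypothetical sketch of such a growing pool, written against the raw CUDA runtime API (since c10::cuda::getStreamFromPool cycles over a fixed set of streams); GrowingStreamPool and its methods do not correspond to any existing PyTorch class, and error handling is omitted.

```cpp
#include <cuda_runtime.h>
#include <deque>
#include <mutex>
#include <unordered_map>

// Hand out an idle stream if one exists, otherwise create a new one instead
// of reusing a stream that may still be busy.
class GrowingStreamPool {
 public:
  cudaStream_t get(int device) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto& idle = idle_[device];
    if (!idle.empty()) {
      cudaStream_t s = idle.front();
      idle.pop_front();
      return s;
    }
    // Queue depleted: mint a fresh stream rather than queueing behind an
    // in-flight transfer on a recycled one.
    cudaSetDevice(device);
    cudaStream_t s;
    cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking);
    return s;
  }

  // Streams must be drained (all enqueued work done) before being returned,
  // per design principle 2 above.
  void giveBack(int device, cudaStream_t s) {
    std::lock_guard<std::mutex> lock(mutex_);
    idle_[device].push_back(s);
  }

 private:
  std::mutex mutex_;
  std::unordered_map<int, std::deque<cudaStream_t>> idle_;
};
```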

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @rohan-varma @xush6528 @jjlilley @osalpekar @jiayisuse @lw @beauby

mrshenli added the feature, triaged, module: rpc, and module: tensorpipe labels on Sep 3, 2020
rohan-varma (Member) commented

For the first version, where we will use a fixed number of streams, what happens when there is no stream available, such as when the caller tries to retrieve one when sending an RPC? Would we block in this case, waiting for a stream to be returned to the pool, or pick a stream that's already being used by some other operation? If we do the latter, could it cause some unintended synchronization?

lw (Contributor) commented Sep 7, 2020

After some more offline discussion, we've refined the proposal. Let me try to recap what we said.

[Image: RPC CUDA diagram]

Main new ideas:

  • We thought we had to work with the standard PyTorch Future class, which is not CUDA-aware, but we realized that we can subclass it to add the functionality we need without duplicating it.
  • We should try to provide a consistent user experience with the NCCL process group. In #42335 ([NCCL] DDP communication hook: getFuture() without cudaStreamAddCallback), a FutureNCCL was added to it, which also subclasses the standard Future. We can try to converge to a single future class for CUDA work, either right away (if easy enough) or in a later iteration.
  • We can have the agent keep its own stream pool only on the callee side, as that's where heavy computations could occur; on the caller, use the global ATen stream pool.
  • Avoid synchronizing too much with the streams, in order to cater to power users who may wish to synchronize on their own, and because it's inefficient (cudaStreamAddCallback seems to have a very high latency, and cudaStreamSynchronize would block the calling thread). This also simplifies the design, as we'd return streams to the stream pool at the same time as we return worker threads to their pool, so we'll never have a situation where one pool is empty and the other one isn't.
  • Another rationale behind not synchronizing is that it may be fine to append new operations (transfers or computation) to a stream that still has stuff in it: if we have so many concurrent operations that 32 streams aren't enough, then there's so much contention on the device that additional contention on a stream may not be noticeable.
  • We do still want to keep the barrier for beginners as low as possible, meaning that unaware users shouldn't have to be too aware of CUDA synchronization and of streams to overlap compute with communication.
  • Having some flag (like CUDA_LAUNCH_BLOCKING or NCCL_BLOCKING_WAIT) that changes the synchronization behavior is a good idea. And, if we're subclassing the Future class, we could use keyword args for that, instead of env vars (like the non_blocking flag of torch.Tensor.copy_(...)).

Some more details

When the caller sends a request:

  • We record an event on the user's current streams, but that's all we do with them. This means the communication will wait for the previous user work to complete, while any work that the user enqueues later will not have to wait for the communication to finish (allowing computation/communication overlap).
  • We then pick a set of streams from the global aten pool and immediately wait for that event on them. The event used to do so (actually, one event per device) can be a member field of the agent, since it can be reused for each send.
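
A sketch of that reuse, assuming one event per device stored on the agent; SendSyncState and bridgeToPoolStream are made-up names, not the agent's real members.

```cpp
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <unordered_map>

// Hypothetical agent member: one reusable event per device for the send path.
// The event is waited on immediately after being recorded, so re-recording it
// on the next send is safe; the receive path cannot do this, because there the
// event must stay alive until the user waits on the future.
struct SendSyncState {
  std::unordered_map<c10::DeviceIndex, at::cuda::CUDAEvent> eventPerDevice;

  c10::cuda::CUDAStream bridgeToPoolStream(c10::DeviceIndex device) {
    auto current = c10::cuda::getCurrentCUDAStream(device);
    auto poolStream =
        c10::cuda::getStreamFromPool(/*isHighPriority=*/false, device);
    auto& ev = eventPerDevice[device];  // lazily default-constructed
    ev.record(current);    // "user work so far" on the current stream
    ev.block(poolStream);  // the pool stream waits for it
    return poolStream;
  }
};
```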

When the caller receives a response:

  • We pick a set of streams from the global aten pool, transfer on them and record an event just after the transfer. Since we'll only wait for the event once the user explicitly waits on the future, we can't immediately reuse the event, and we need to keep it around, therefore we'll have a separate event stored on each CudaFuture class.
  • When calling CudaFuture::wait() we first call wait on the superclass (which will wait until we receive the response over TCP and have enqueued the transfers on the stream) and only then we wait for the event on the user's current streams. Since we use the streams that are current when calling wait(), we allow the user to use separate streams for sending and receiving, or to use the same streams but still be able to do something else with them in between.
  • We could have a bool non_blocking parameter in the CudaFuture::wait() method, which defaults to False. If it's false then, after waiting for the event on the stream, we also explicitly synchronize with the stream. This would help beginner users, who aren't familiar with stream semantics, avoid forgetting to synchronize (but still allow power users to opt out and do their own thing). Even though we said we shouldn't synchronize, it's fine to do it here, because the thread we would be blocking is the user's thread (not an internal one), and that's exactly what the user wants us to do.
  • The Future's callbacks are a bit different: as they could potentially be executing some heavy load, we want to deal with them in the same way as the remote user functions, and get them a stream from the stream pool. (We do the same with threads: each future callback is offloaded to a worker thread.)
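
Putting these pieces together, here is a hedged sketch of what such a CUDA-aware future could look like. The subclass, its members, and waitWithStreams are illustrative only (the proposal would fold this behavior into wait() itself), and the per-device events are assumed to have been recorded by the agent's read callback right after the transfers were enqueued.

```cpp
#include <ATen/core/ivalue.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <utility>
#include <vector>

// Illustrative CUDA-aware future: the agent records one event per device on
// the receive streams and stores them here together with the device index.
struct CudaFuture : c10::ivalue::Future {
  using c10::ivalue::Future::Future;

  std::vector<std::pair<c10::DeviceIndex, at::cuda::CUDAEvent>> recvEvents;

  void waitWithStreams(bool non_blocking = false) {
    // 1. Base wait: the response arrived and the CUDA copies were enqueued.
    wait();
    // 2. Make the streams that are current *now* wait for those copies, so
    //    the user may consume the tensors on whichever streams they like.
    for (auto& e : recvEvents) {
      e.second.block(c10::cuda::getCurrentCUDAStream(e.first));
    }
    // 3. Beginner-friendly default: also block the calling user thread until
    //    the copies are done; power users pass non_blocking=true to skip this.
    if (!non_blocking) {
      for (auto& e : recvEvents) {
        e.second.synchronize();
      }
    }
  }
};
```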

mrshenli (Contributor, Author) commented Sep 8, 2020

We can have the agent keep its own stream pool only on the callee side, as that's where heavy computations could occur; on the caller, use the global ATen stream pool.

What's the benefit of using two different stream pools? Wouldn't always using the agent's stream pool be sufficient?

mrshenli (Contributor, Author) commented Sep 8, 2020

For the first version, where we will use a fixed number of streams, what happens when there is no stream available, such as when the caller tries to retrieve one when sending an RPC? Would we block in this case, waiting for a stream to be returned to the pool, or pick a stream that's already being used by some other operation? If we do the latter, could it cause some unintended synchronization?

IIUC, it will just reuse one of the previous streams in a round-robin fashion.

lw (Contributor) commented Sep 9, 2020

What's the benefit of using two different stream pools? Wouldn't always using the agent's stream pool be sufficient?

Yes, that would work too. I found it nice to keep it consistent with how we use threads, where we use the agent's thread pool for the user functions (the heavy lifting) but we use private TensorPipe threads for I/O.
