[RFC] Manage CUDA Stream in TensorPipe RPC Agent #44084
Comments
For the first version, where we will use a fixed number of streams, what happens when there is no stream available, such as when the caller tries to retrieve one when sending an RPC? Would we block in this case until a stream is returned to the pool, or pick a stream that's already being used by some other operation? If we do the latter, could it cause some unintended synchronization?
After some more offline discussion, we've refined the proposal. Let me try to recap what we said. Main new ideas:

Some more details:

When the caller sends a request:

When the caller receives a response:
What's the benefit of using two different stream pools? Wouldn't always using the agent's stream pool be sufficient?
IIUC, it will just reuse one of the previous streams in round-robin fashion.
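For illustration, here is a minimal C++ sketch of such a fixed-size, round-robin pool. The class name and pool size are made up; this is not the actual agent code.

```cpp
// Minimal sketch of a fixed-size, round-robin CUDA stream pool; the pool
// size (32) and class name are illustrative only.
#include <cuda_runtime.h>

#include <array>
#include <atomic>
#include <cstdint>

class RoundRobinStreamPool {
 public:
  RoundRobinStreamPool() {
    for (auto& s : streams_) {
      cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking);
    }
  }
  ~RoundRobinStreamPool() {
    for (auto& s : streams_) {
      cudaStreamDestroy(s);
    }
  }
  // Never blocks: once all streams are handed out, it wraps around and
  // reuses one, so new work queues up behind whatever that stream is
  // already running.
  cudaStream_t get() {
    return streams_[next_.fetch_add(1) % streams_.size()];
  }

 private:
  std::array<cudaStream_t, 32> streams_{};
  std::atomic<std::uint64_t> next_{0};
};
```

Note that reusing a busy stream never blocks the caller, but it does serialize the new work behind whatever is already enqueued on that stream, which is exactly the unintended-synchronization concern raised above.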
Yes, that would work too. I found it nice to keep it consistent with how we use threads, where we use the agent's thread pool for the user functions (the heavy lifting) but we use private TensorPipe threads for I/O.
with @lw @beauby
API
There is no new API, but there is new behavior for sending/receiving CUDA tensors. The TensorPipe RPC agent will maintain a stream pool, and grab streams from the pool to 1) send tensors, 2) receive tensors, and 3) run user functions.
The guarantees we offer are:

- `NCCL_BLOCKING_WAIT` dictates the behavior of `Future.wait()` when the response contains CUDA tensors. If `NCCL_BLOCKING_WAIT=1` at RPC initialization time, `Future.wait()` means that all ops are done. Otherwise, `Future.wait()` only means all CUDA ops are enqueued to the stream.
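As a rough sketch of how these two wait semantics could look on the receiving side (`markFutureComplete` is a made-up stand-in for completing the RPC Future; this is not the actual agent code):

```cpp
// Sketch of the two completion policies for a response containing CUDA
// tensors; markFutureComplete() stands in for completing the RPC Future.
#include <cuda_runtime.h>

#include <vector>

void markFutureComplete() { /* notify waiters of the Future */ }

void onResponseRead(const std::vector<cudaStream_t>& streams,
                    bool ncclBlockingWait) {
  if (ncclBlockingWait) {
    // NCCL_BLOCKING_WAIT=1: Future.wait() must mean "all ops are done",
    // so drain every stream used for the response before completing.
    for (cudaStream_t s : streams) {
      cudaStreamSynchronize(s);
    }
  }
  // Otherwise the Future completes as soon as pipe->read returns, i.e. the
  // CUDA ops are merely enqueued and the user must synchronize explicitly.
  markFutureComplete();
}
```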
Design

Two principles:
Request on the Caller
- Grab a stream from the pool and pass it to TensorPipe as a raw `cudaStream_t` (as TensorPipe cannot depend on PyTorch and hence cannot use PyTorch's Stream type).
- After `pipe->write`, either capture the stream in another lambda for `cudaStreamAddCallback`, or use a thread to synchronize the stream before destructing it (returning it to the pool).
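A hedged sketch of these two options, with a hypothetical `returnStreamToPool` helper standing in for the pool:

```cpp
// The two cleanup options for the caller's send stream; returnStreamToPool
// is a hypothetical pool helper, not a real API.
#include <cuda_runtime.h>

#include <thread>

void returnStreamToPool(cudaStream_t stream) { /* hypothetical pool helper */ }

// Option 1: a driver-invoked host callback fires once all work enqueued
// on `stream` (including the send copies) has finished.
void releaseViaCallback(cudaStream_t stream) {
  cudaStreamAddCallback(
      stream,
      [](cudaStream_t s, cudaError_t /*status*/, void* /*userData*/) {
        returnStreamToPool(s);
      },
      /*userData=*/nullptr,
      /*flags=*/0);
}

// Option 2: a helper thread blocks until the stream drains, then returns it.
void releaseViaThread(cudaStream_t stream) {
  std::thread([stream] {
    cudaStreamSynchronize(stream);
    returnStreamToPool(stream);
  }).detach();
}
```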
Request on the Callee

- Let `TensorPipe::Tensor.metadata` capture the device information, so that the callee can know which device it will use from the descriptor before calling `pipe->read`.
- Grab streams from the pool and pass them to `pipe->read`.
- In `pipe->read`'s on-complete callback, set those streams from the pool as current, and then directly launch user functions without synchronization. All comm ops are already enqueued into the current stream. If user functions need to use different streams, they need to explicitly synchronize the streams.
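A minimal sketch of what the read callback could look like using c10's stream guard (`runUserFunction` is a hypothetical stand-in for invoking the target function; not the actual agent code):

```cpp
// Sketch of the callee-side read callback using c10's stream utilities.
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

void runUserFunction() { /* hypothetical: run the user-provided function */ }

void onRequestRead(c10::cuda::CUDAStream stream) {
  // Make the stream that pipe->read used current: the user function's
  // kernels are then enqueued behind the receive copies on that stream,
  // so no explicit synchronization is needed.
  c10::cuda::CUDAStreamGuard guard(stream);
  runUserFunction();
  // If the user function wants a different stream, it has to synchronize
  // against this one explicitly.
}
```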
Response on the callee

- When `pipe->write`'s on-complete callback fires, we synchronize on those streams and only return them to their pools once all the enqueued work on them has completed.
Response on the caller

- If `NCCL_BLOCKING_WAIT=1`, mark the `Future` as complete after all streams are synchronized. Otherwise, mark the `Future` as complete when TensorPipe `pipe->read` returns.

Discussion
- We will use `c10::cuda::getStreamFromPool` in the first version, and then add a new stream pool implementation which, instead of using a fixed number of streams, creates new ones when its queue is depleted.
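A hedged sketch of what such a growing pool could look like (all names are made up for illustration):

```cpp
// Sketch of the second-version pool: instead of a fixed set of streams, it
// creates a new one whenever its idle queue is depleted.
#include <cuda_runtime.h>

#include <deque>
#include <mutex>

class GrowableStreamPool {
 public:
  // Never reuses a busy stream: if the idle queue is empty, grow the pool.
  cudaStream_t get() {
    std::lock_guard<std::mutex> lock(mu_);
    if (idle_.empty()) {
      cudaStream_t s;
      cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking);
      return s;
    }
    cudaStream_t s = idle_.front();
    idle_.pop_front();
    return s;
  }

  // Called once all enqueued work on `s` has completed (see above).
  void giveBack(cudaStream_t s) {
    std::lock_guard<std::mutex> lock(mu_);
    idle_.push_back(s);
  }

 private:
  std::mutex mu_;
  std::deque<cudaStream_t> idle_;
};
```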
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @rohan-varma @xush6528 @jjlilley @osalpekar @jiayisuse @lw @beauby