Fix distributed documentation for asynchronous collective Work objects #45709
Changes from 2 commits
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -305,27 +305,58 @@ as they should never be created manually, but they are guaranteed to support two | |
|
||
Synchronous and asynchronous collective operations | ||
-------------------------------------------------- | ||
Every collective operation function supports the following two kinds of operations: | ||
|
||
synchronous operation - the default mode, when ``async_op`` is set to False. | ||
when the function returns, it is guaranteed that | ||
the collective operation is performed (not necessarily completed if it's a CUDA op since all | ||
CUDA ops are asynchronous), and any further function calls depending on the data of the | ||
collective operation can be called. In the synchronous mode, the collective function does not | ||
return anything | ||
|
||
asynchronous operation - when ``async_op`` is set to True. The collective operation function | ||
Every collective operation function supports the following two kinds of operations, depending on the setting of the ``async_op`` flag passed into the collective: | ||
|
||
**Synchronous operation** - the default mode, when ``async_op`` is set to ``False``. | ||
When the function returns, it is guaranteed that | ||
the collective operation is performed. In the case of CUDA operations, it is not guaranteed | ||
that the CUDA operation is completed, since CUDA operations are asynchronous. For CPU collectives, any | ||
further function calls utilizing the output of the collective call will behave as expected. For CUDA collectives, | ||
function calls utilizing the output on the same CUDA stream will behave as expected. Users must take care of | ||
synchronization under the scenario of running under different streams. For details on CUDA semantics such as stream | ||
synchronization, see `cuda semantics <https://pytorch.org/docs/stable/autograd.html#profiler>`__. | ||
cuda semantics points to profiler, is this intentional?
Thanks for the catch! It should actually point to https://pytorch.org/docs/stable/notes/cuda.html
||
See the below script to see examples of differences in these semantics for CPU and CUDA operations. | ||
|
||
**Asynchronous operation** - when ``async_op`` is set to True. The collective operation function | ||
returns a distributed request object. In general, you don't need to create it manually and it | ||
is guaranteed to support two methods: | ||
|
||
* ``is_completed()`` - returns True if the operation has finished | ||
* ``wait()`` - will block the process until the operation is finished. | ||
* ``is_completed()`` - in the case of CPU collectives, returns ``True`` if completed. In the case of CUDA operations, | ||
returns ``True`` if the operation has been successfully enqueued onto a CUDA stream and the output can be utilized on the | ||
default stream without further synchronization. | ||
* ``wait()`` - in the case of CPU collectives, will block the process until the operation is completed. In the case | ||
of CUDA collectives, will block until the operation has been successfully enqueued onto a CUDA stream and the | ||
output can be utilized on the default stream without further synchronization. | ||
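As a rough illustration of the CPU-side contract described above (a synchronous call returns nothing, an asynchronous call returns a handle exposing ``is_completed()`` and ``wait()``), here is a minimal pure-Python sketch. It does not use ``torch.distributed``; ``ToyWork`` and ``toy_all_reduce`` are hypothetical names invented for this sketch, and a background thread stands in for the communication backend.

```python
import threading
import time


class ToyWork:
    """Hypothetical stand-in for the request object returned when async_op=True."""

    def __init__(self, target):
        self._done = threading.Event()

        def run():
            target()
            self._done.set()

        self._thread = threading.Thread(target=run)
        self._thread.start()

    def is_completed(self):
        # Non-blocking check: True only once the operation has finished.
        return self._done.is_set()

    def wait(self):
        # Block the caller until the operation has finished.
        self._thread.join()


def toy_all_reduce(buffers, async_op=False):
    """Sum the per-rank lists element-wise and write the result into every buffer."""

    def reduce_in_place():
        time.sleep(0.05)  # pretend the communication takes a while
        total = [sum(column) for column in zip(*buffers)]
        for buf in buffers:
            buf[:] = total

    if async_op:
        # Asynchronous mode: hand back a handle right away.
        return ToyWork(reduce_in_place)
    # Synchronous mode: the work is done on return, and nothing is returned.
    reduce_in_place()
    return None


# Synchronous mode: buffers are already reduced when the call returns.
sync_bufs = [[0], [1]]
sync_result = toy_all_reduce(sync_bufs)

# Asynchronous mode: the buffers are only safe to read after wait().
async_bufs = [[0], [1]]
work = toy_all_reduce(async_bufs, async_op=True)
work.wait()
```

The sketch only models the CPU case. The CUDA case differs in that ``wait()`` returning means the operation is enqueued on a stream, not that it has finished.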
|
||
**Example** | ||
|
||
The following code can serve as a reference regarding semantics for CUDA operations when using distributed collectives. | ||
It shows the explicit need to synchronize when using collective outputs on different CUDA streams: | ||
|
||
:: | ||
|
||
# Code runs on each rank. | ||
dist.init_process_group("nccl", rank=rank, world_size=2) | ||
output = torch.tensor([rank]).cuda(rank) | ||
s = torch.cuda.Stream() | ||
handle = dist.all_reduce(output, async_op=True) | ||
# Wait ensures the operation is enqueued, but not necessarily complete. | ||
handle.wait() | ||
# Using result on non-default stream. | ||
with torch.cuda.stream(s): | |
    s.wait_stream(torch.cuda.default_stream()) | |
    output.add_(100) | |
if rank == 0: | |
    # if the explicit call to wait_stream was omitted, the output below will be | |
    # non-deterministically 1 or 101, depending on whether the allreduce overwrote | |
    # the value after the add completed. | |
    print(output) | |
|
||
|
||
Collective functions | ||
-------------------- | ||
|
||
.. autofunction:: broadcast | ||
.. autofunction:: broadcast | ||
|
||
.. autofunction:: broadcast_object_list | ||
|
||
|
let's break this into shorter lines