From 0274b06bb05591b654609da6f7368dd58001178b Mon Sep 17 00:00:00 2001 From: Rohan Varma Date: Thu, 1 Oct 2020 17:16:02 -0700 Subject: [PATCH 1/3] Docs improvement --- docs/source/distributed.rst | 58 ++++++++++++++++++++++++++++--------- 1 file changed, 44 insertions(+), 14 deletions(-) diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index 9c2f4a103b5e..99e959f86627 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -305,28 +305,58 @@ as they should never be created manually, but they are guaranteed to support two Synchronous and asynchronous collective operations -------------------------------------------------- -Every collective operation function supports the following two kinds of operations: - -synchronous operation - the default mode, when ``async_op`` is set to False. -when the function returns, it is guaranteed that -the collective operation is performed (not necessarily completed if it's a CUDA op since all -CUDA ops are asynchronous), and any further function calls depending on the data of the -collective operation can be called. In the synchronous mode, the collective function does not -return anything - -asynchronous operation - when ``async_op`` is set to True. The collective operation function +Every collective operation function supports the following two kinds of operations, depending on the setting of the ``async_op`` flag passed into the collective: + +**Synchronous operation** - the default mode, when ``async_op`` is set to ``False``. +When the function returns, it is guaranteed that +the collective operation is performed. In the case of CUDA operations, it is not guaranteed +that the CUDA operation is completed, since CUDA operations are asynchronous. For CPU collectives, any +further function calls utilizing the output of the collective call will behave as expected. For CUDA collectives, +function calls utilizing the output on the same CUDA stream will behave as expected. Users must take care of +synchronization under the scenario of running under different streams. For details on CUDA semantics such as stream +synchronization, see `cuda semantics `__. +See the below script to see examples of differences in these semantics for CPU and CUDA operations. + +**Asynchronous operation** - when ``async_op`` is set to True. The collective operation function returns a distributed request object. In general, you don't need to create it manually and it is guaranteed to support two methods: -* ``is_completed()`` - returns True if the operation has finished -* ``wait()`` - will block the process until the operation is finished. +* ``is_completed()`` - in the case of CPU collectives, returns ``True`` if completed. In the case of CUDA operations, + returns ``True`` if the operation has been successfully enqueued onto a CUDA stream and the output can be utilized on the + default stream without further synchronization. +* ``wait()`` - in the case of CPU collectives, will block the process until the operation is completed. In the case + of CUDA collectives, will block until the operation has been successfully enqueued onto a CUDA stream and the + output can be utilized on the default stream without further synchronization. + +**Example** + +The following code can serve as a reference regarding semantics for CUDA operations when using distributed collectives. +It shows the explicit need to synchronize when using collective outputs on different CUDA streams: + +:: + + # Code runs on each rank. + dist.init_process_group("nccl", rank=rank, world_size=2) + output = torch.tensor([rank]).cuda(rank) + s = torch.cuda.Stream() + handle = dist.all_reduce(output, async_op=True) + # Wait ensures the operation is enqueued, but not necessarily complete. + handle.wait() + # Using result on non-default stream. + with torch.cuda.stream(s): + s.wait_stream(torch.cuda.default_stream()) + output.add_(100) + if rank == 0: + # if the explicit call to wait_stream was omitted, the output below will be + # non-deterministically 1 or 101, depending on whether the allreduce overwrote + # the value after the add completed. + print(output) Collective functions -------------------- -.. autofunction:: broadcast - +.. autofunction:: broadcast .. autofunction:: broadcast_object_list .. autofunction:: all_reduce From 7374d0accb05aa3ce9b28c6da74358731689cf8a Mon Sep 17 00:00:00 2001 From: Rohan Varma Date: Thu, 1 Oct 2020 17:21:51 -0700 Subject: [PATCH 2/3] Update --- docs/source/distributed.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index 99e959f86627..8e846542d8fd 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -357,6 +357,7 @@ Collective functions -------------------- .. autofunction:: broadcast + .. autofunction:: broadcast_object_list .. autofunction:: all_reduce From b8a050a0979c8943f06a56ed8fe49dbe4781e867 Mon Sep 17 00:00:00 2001 From: Rohan Varma Date: Wed, 7 Oct 2020 13:26:23 -0700 Subject: [PATCH 3/3] Update --- docs/source/distributed.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index 8e846542d8fd..eb4bc8725346 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -305,7 +305,8 @@ as they should never be created manually, but they are guaranteed to support two Synchronous and asynchronous collective operations -------------------------------------------------- -Every collective operation function supports the following two kinds of operations, depending on the setting of the ``async_op`` flag passed into the collective: +Every collective operation function supports the following two kinds of operations, +depending on the setting of the ``async_op`` flag passed into the collective: **Synchronous operation** - the default mode, when ``async_op`` is set to ``False``. When the function returns, it is guaranteed that @@ -314,7 +315,7 @@ that the CUDA operation is completed, since CUDA operations are asynchronous. Fo further function calls utilizing the output of the collective call will behave as expected. For CUDA collectives, function calls utilizing the output on the same CUDA stream will behave as expected. Users must take care of synchronization under the scenario of running under different streams. For details on CUDA semantics such as stream -synchronization, see `cuda semantics `__. +synchronization, see `CUDA Semantics `__. See the below script to see examples of differences in these semantics for CPU and CUDA operations. **Asynchronous operation** - when ``async_op`` is set to True. The collective operation function