diff --git a/test/distributed/test_c10d.py b/test/distributed/test_c10d.py
index 5ffd4b4fb088..b6e78ff7a680 100644
--- a/test/distributed/test_c10d.py
+++ b/test/distributed/test_c10d.py
@@ -4641,6 +4641,41 @@ def test_nccl_barrier_timeout_new_group_non_member(self):
         with self.assertRaisesRegex(RuntimeError, "Timed out initializing process group"):
             c10d.new_group([0], timeout=timedelta(seconds=1))
 
+    @requires_nccl()
+    def test_nccl_barrier_device_ids(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="nccl",
+            rank=self.rank,
+            world_size=self.world_size,
+            store=store)
+
+        c10d.barrier(device_ids=[self.rank])
+
+    @requires_nccl()
+    def test_nccl_barrier_device_ids_function_argument(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="nccl",
+            rank=self.rank,
+            world_size=self.world_size,
+            store=store)
+
+        with self.assertRaisesRegex(RuntimeError, "Invalid function argument"):
+            c10d.barrier(device_ids=self.rank)
+
+    @requires_gloo()
+    def test_gloo_barrier_device_ids(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="gloo",
+            rank=self.rank,
+            world_size=self.world_size,
+            store=store)
+
+        with self.assertRaisesRegex(RuntimeError, "device_ids not supported"):
+            c10d.barrier(device_ids=[self.rank])
+
 if __name__ == "__main__":
     assert (
         not torch.cuda._initialized
diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index cd9a0f7d46a9..5ac2c0a8315d 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -77,6 +77,7 @@ class ReduceScatterOptions:
     timeout: timedelta
 
 class BarrierOptions:
+    device_ids: List[int]
     timeout: timedelta
 
 class AllToAllOptions:
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index b31d44a1d295..76b466c91f10 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -345,6 +345,7 @@ They are used in specifying strategies for reduction collectives, e.g.,
 
   py::class_<::c10d::BarrierOptions>(module, "BarrierOptions")
       .def(py::init<>())
+      .def_readwrite("device_ids", &::c10d::BarrierOptions::device_ids)
       .def_readwrite("timeout", &::c10d::BarrierOptions::timeout);
 
   py::class_<::c10d::AllToAllOptions>(module, "AllToAllOptions")
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index a11a12c6bf82..c4b8201085b5 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -17,6 +17,7 @@
     AllreduceOptions,
     AllreduceCoalescedOptions,
     AllToAllOptions,
+    BarrierOptions,
     BroadcastOptions,
     FileStore,
     GatherOptions,
@@ -2370,8 +2371,11 @@ def all_to_all(output_tensor_list,
         work.wait()
 
+
 def barrier(group=GroupMember.WORLD,
-            async_op=False):
+            async_op=False,
+            device_ids=None):
+
     """
     Synchronizes all processes.
 
     This collective blocks processes until the whole group enters this function,
@@ -2382,6 +2386,8 @@ def barrier(group=GroupMember.WORLD,
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         async_op (bool, optional): Whether this op should be an async op
+        device_ids ([int], optional): List of device/GPU ids.
+                                      Valid only for NCCL backend.
 
     Returns:
         Async work handle, if async_op is set to True.
@@ -2390,11 +2396,22 @@ def barrier(group=GroupMember.WORLD,
     if _rank_not_in_group(group):
         return
 
+    opts = BarrierOptions()
+    if device_ids is not None:
+        if get_backend(group) != Backend.NCCL:
+            raise RuntimeError("Function argument device_ids not supported "
+                               "for the selected backend {}".format(get_backend(group)))
+        if isinstance(device_ids, list):
+            opts.device_ids = device_ids
+        else:
+            raise RuntimeError("Invalid function argument: "
+                               "device_ids type should be List[int]")
+
     if group is None:
         default_pg = _get_default_group()
-        work = default_pg.barrier()
+        work = default_pg.barrier(opts=opts)
     else:
-        work = group.barrier()
+        work = group.barrier(opts=opts)
 
     if async_op:
         return work
diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp
index 01ce71afd388..b9ac5aa77150 100644
--- a/torch/lib/c10d/ProcessGroupNCCL.cpp
+++ b/torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -1409,7 +1409,13 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce_scatter(
 c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::barrier(
     const BarrierOptions& opts) {
   std::vector<at::Device> devices;
-  if (usedDeviceIdxs_.empty()) {
+
+  // Use user defined GPU device ids if provided
+  if (!opts.device_ids.empty()) {
+    for (auto device : opts.device_ids) {
+      devices.push_back(at::Device(at::DeviceType::CUDA, device));
+    }
+  } else if (usedDeviceIdxs_.empty()) {
     // This means there is not yet a NCCL collective being called
     // Here we have to use the best guesses and will use a single GPU to call
     // allreduce to achieve barrier.
diff --git a/torch/lib/c10d/Types.hpp b/torch/lib/c10d/Types.hpp
index 03b2e59e4295..a5a0d5fa20df 100644
--- a/torch/lib/c10d/Types.hpp
+++ b/torch/lib/c10d/Types.hpp
@@ -62,6 +62,7 @@ struct AllToAllOptions {
 };
 
 struct BarrierOptions {
+  std::vector<int> device_ids;
   std::chrono::milliseconds timeout = kUnsetTimeout;
 };
 
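A minimal usage sketch of the `device_ids` argument added above. Only the `barrier(device_ids=...)` behavior and the two error messages come from this diff; the environment-variable rendezvous (`RANK`, `WORLD_SIZE`, `MASTER_ADDR`, `MASTER_PORT` set by a launcher such as `torch.distributed.launch`) is an assumption for illustration.

```python
import os
import torch.distributed as dist

def main():
    # Assumes RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT are set by a
    # launcher (e.g. torch.distributed.launch); this setup is illustrative.
    rank = int(os.environ["RANK"])
    dist.init_process_group(backend="nccl", init_method="env://")

    # NCCL backend: pin the barrier to this rank's GPU instead of letting
    # ProcessGroupNCCL guess from previously used device indices.
    dist.barrier(device_ids=[rank])

    # Per the diff: passing a bare int raises "Invalid function argument:
    # device_ids type should be List[int]", and any device_ids value on a
    # non-NCCL backend raises "Function argument device_ids not supported".

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```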