[distributed] Provide parameter to pass GPU ID in barrier function #49069

Closed
wants to merge 9 commits
35 changes: 35 additions & 0 deletions test/distributed/test_c10d.py
@@ -4641,6 +4641,41 @@ def test_nccl_barrier_timeout_new_group_non_member(self):
with self.assertRaisesRegex(RuntimeError, "Timed out initializing process group"):
c10d.new_group([0], timeout=timedelta(seconds=1))

@requires_nccl()
def test_nccl_barrier_device_ids(self):
store = c10d.FileStore(self.file_name, self.world_size)
c10d.init_process_group(
backend="nccl",
rank=self.rank,
world_size=self.world_size,
store=store)

c10d.barrier(device_ids=[self.rank])

@requires_nccl()
def test_nccl_barrier_device_ids_function_argument(self):
store = c10d.FileStore(self.file_name, self.world_size)
c10d.init_process_group(
backend="nccl",
rank=self.rank,
world_size=self.world_size,
store=store)

with self.assertRaisesRegex(RuntimeError, "Invalid function argument"):
c10d.barrier(device_ids=self.rank)

@requires_gloo()
def test_gloo_barrier_device_ids(self):
store = c10d.FileStore(self.file_name, self.world_size)
c10d.init_process_group(
backend="gloo",
rank=self.rank,
world_size=self.world_size,
store=store)

with self.assertRaisesRegex(RuntimeError, "device_ids not supported"):
c10d.barrier(device_ids=[self.rank])

if __name__ == "__main__":
assert (
not torch.cuda._initialized
1 change: 1 addition & 0 deletions torch/_C/_distributed_c10d.pyi
@@ -77,6 +77,7 @@ class ReduceScatterOptions:
timeout: timedelta

class BarrierOptions:
device_ids: List[int]
timeout: timedelta

class AllToAllOptions:
1 change: 1 addition & 0 deletions torch/csrc/distributed/c10d/init.cpp
@@ -345,6 +345,7 @@ They are used in specifying strategies for reduction collectives, e.g.,

py::class_<::c10d::BarrierOptions>(module, "BarrierOptions")
.def(py::init<>())
.def_readwrite("device_ids", &::c10d::BarrierOptions::device_ids)
.def_readwrite("timeout", &::c10d::BarrierOptions::timeout);

py::class_<::c10d::AllToAllOptions>(module, "AllToAllOptions")
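For reference, the new field on the binding can be set directly from Python (a minimal sketch using the internal torch._C._distributed_c10d module; regular users should go through the dist.barrier(device_ids=...) wrapper added below):

    from datetime import timedelta
    from torch._C._distributed_c10d import BarrierOptions

    opts = BarrierOptions()
    opts.device_ids = [0]                 # new field added by this PR: list of GPU indices
    opts.timeout = timedelta(seconds=30)  # existing field, unchanged
    # The Python wrapper below forwards such an object as pg.barrier(opts=opts).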
23 changes: 20 additions & 3 deletions torch/distributed/distributed_c10d.py
@@ -17,6 +17,7 @@
AllreduceOptions,
AllreduceCoalescedOptions,
AllToAllOptions,
BarrierOptions,
BroadcastOptions,
FileStore,
GatherOptions,
@@ -2370,8 +2371,11 @@ def all_to_all(output_tensor_list,
work.wait()



def barrier(group=GroupMember.WORLD,
            async_op=False,
            device_ids=None):
    """
Synchronizes all processes.

@@ -2382,6 +2386,8 @@ def barrier(group=GroupMember.WORLD,
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
async_op (bool, optional): Whether this op should be an async op
device_ids ([int], optional): List of device/GPU ids.
Valid only for NCCL backend.

Returns:
Async work handle, if async_op is set to True.
@@ -2390,11 +2396,22 @@ def barrier(group=GroupMember.WORLD,
if _rank_not_in_group(group):
return

opts = BarrierOptions()
if device_ids is not None:
if get_backend(group) != Backend.NCCL:
raise RuntimeError("Function argument device_ids not supported "
"for the selected backend {}".format(get_backend(group)))
if isinstance(device_ids, list):
opts.device_ids = device_ids
else:
raise RuntimeError("Invalid function argument: "
"device_ids type should be List[int]")
Comment on lines +2401 to +2408

Contributor: We should add tests for these invalid cases.

Contributor Author: Test case added.

if group is None:
default_pg = _get_default_group()
-        work = default_pg.barrier()
+        work = default_pg.barrier(opts=opts)
else:
-        work = group.barrier()
+        work = group.barrier(opts=opts)

if async_op:
return work
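Taken together, the Python-side change is used like this (a hedged sketch assuming one GPU per rank, the NCCL backend, and the usual env:// launcher variables such as RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT):

    import torch
    import torch.distributed as dist

    # Assumes the launcher (e.g. torch.distributed.launch) has set the env:// variables.
    dist.init_process_group(backend="nccl", init_method="env://")
    local_rank = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    # New in this PR: pin the barrier to this rank's GPU instead of letting NCCL guess.
    dist.barrier(device_ids=[local_rank])

    dist.destroy_process_group()

Passing anything other than a list raises "Invalid function argument: device_ids type should be List[int]", and non-NCCL backends raise "device_ids not supported", as the tests above verify.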
8 changes: 7 additions & 1 deletion torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -1409,7 +1409,13 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::barrier(
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::barrier(
const BarrierOptions& opts) {
std::vector<at::Device> devices;
-  if (usedDeviceIdxs_.empty()) {
+
+  // Use user defined GPU device ids if provided
+  if (!opts.device_ids.empty()) {
+    for (auto device : opts.device_ids) {
+      devices.push_back(at::Device(at::DeviceType::CUDA, device));
+    }
+  } else if (usedDeviceIdxs_.empty()) {
// This means there is not yet a NCCL collective being called
// Here we have to use the best guesses and will use a single GPU to call
// allreduce to achieve barrier.
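In plain terms, the NCCL barrier now prefers the user-supplied device ids and only falls back to its previous heuristics otherwise. A rough Python rendering of that precedence (illustrative only; the helper name is hypothetical, and the rank-based guess in the fallback branch is an assumption about code not shown in this hunk):

    # Hypothetical sketch of ProcessGroupNCCL::barrier's device selection, not real PyTorch code.
    def _barrier_devices(device_ids, used_device_idxs, rank, num_gpus):
        if device_ids:                          # explicit user choice wins
            return [f"cuda:{d}" for d in device_ids]
        if not used_device_idxs:                # no NCCL collective has run yet
            return [f"cuda:{rank % num_gpus}"]  # best-effort single-GPU guess (assumed)
        return [f"cuda:{d}" for d in sorted(used_device_idxs)]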
1 change: 1 addition & 0 deletions torch/lib/c10d/Types.hpp
@@ -62,6 +62,7 @@ struct AllToAllOptions {
};

struct BarrierOptions {
std::vector<int> device_ids;
std::chrono::milliseconds timeout = kUnsetTimeout;
};
