Skip to content

Commit

Permalink
Add .boxed() to c10d::ProcessGroup and c10d::Work's pybind (pytorch#1…
Browse files Browse the repository at this point in the history
…11997)

Summary:
When passed from C++ to Python, `c10d::ProcessGroup` and `c10d::Work` are automatically converted to their pybind classes, which can't be used for dispatcher ops. `.boxed()` exposes `c10d::ProcessGroup` and `c10d::Work` to Python as boxed custom class objects (`ScriptObject`s), and the static `.unbox()` converts such a boxed object back to the corresponding pybind class.

```python
import tempfile

import torch
import torch.distributed as dist

if __name__ == "__main__":
    # The temporary file's path is used as the file:// init_method for a
    # one-rank (rank 0, world_size 1) NCCL process group.
    with tempfile.NamedTemporaryFile(delete=False) as tmpf:
        init_method = f"file://{tmpf.name}"
        dist.init_process_group(
            backend="nccl", init_method=init_method, rank=0, world_size=1
        )
        world = dist.group.WORLD
        # First the regular pybind ProcessGroup object ...
        print(world)
        # ... then the same group exposed as a boxed custom-class ScriptObject.
        print(world.boxed())
```

```
<torch.distributed.distributed_c10d.ProcessGroup object at 0x7fe42fb78d30>
ScriptObject <__torch__.torch.classes.c10d.ProcessGroup>
```

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: pytorch#111997
Approved by: https://github.com/lw
  • Loading branch information
yifuwang authored and xuhancn committed Nov 8, 2023
1 parent 0e9ae15 commit 0558f63
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
7 changes: 7 additions & 0 deletions torch/_C/_distributed_c10d.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ from enum import Enum
from typing import Any, Dict, List, Optional, overload, Tuple, Union

from torch import Tensor
from torch._C import ScriptObject
from torch.futures import Future

# This module is defined in torch/csrc/distributed/c10d/init.cpp
Expand Down Expand Up @@ -208,6 +209,9 @@ class Work:
def _source_rank(self) -> int: ...
def result(self) -> List[Tensor]: ...
def synchronize(self): ...
# Returns this Work as a boxed custom-class ScriptObject
# (__torch__.torch.classes.c10d.Work), the form usable with dispatcher ops.
def boxed(self) -> ScriptObject: ...
# Inverse of boxed(): converts a ScriptObject of custom class
# __torch__.torch.classes.c10d.Work back into a Work.
@staticmethod
def unbox(obj: ScriptObject) -> Work: ...

class ProcessGroup:
class Options: ...
Expand Down Expand Up @@ -378,6 +382,9 @@ class ProcessGroup:
) -> Work: ...
def recv_anysource(self, tensors: List[Tensor], tag: int) -> Work: ...
def barrier(self, opts=...) -> Work: ...
# Returns this ProcessGroup as a boxed custom-class ScriptObject
# (__torch__.torch.classes.c10d.ProcessGroup), the form usable with dispatcher ops.
def boxed(self) -> ScriptObject: ...
# Inverse of boxed(): converts a ScriptObject of custom class
# __torch__.torch.classes.c10d.ProcessGroup back into a ProcessGroup.
@staticmethod
def unbox(obj: ScriptObject) -> ProcessGroup: ...

class ProcessGroupRoundRobin(ProcessGroup): ...

Expand Down
23 changes: 21 additions & 2 deletions torch/csrc/distributed/c10d/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1793,7 +1793,15 @@ The hook must have the following signature:
.def_property_readonly(
"group_name",
&::c10d::ProcessGroup::getGroupName,
"(Gets this process group name. It's cluster unique)");
"(Gets this process group name. It's cluster unique)")
.def(
    "boxed",
    // Hands the process group to Python as a boxed custom-class object
    // instead of the pybind ProcessGroup wrapper.
    [](c10::intrusive_ptr<::c10d::ProcessGroup> self) {
      c10::IValue boxedGroup(std::move(self));
      return torch::jit::toPyObject(std::move(boxedGroup));
    })
.def_static(
    "unbox",
    // Converts a boxed custom-class Python object back into a
    // c10::intrusive_ptr<ProcessGroup>.
    [](py::object obj) {
      // Resolve the custom class type by its qualified name, then convert
      // the Python object to an IValue of that type.
      const auto groupType = torch::getCustomClass(
          "__torch__.torch.classes.c10d.ProcessGroup");
      const c10::IValue boxedGroup = torch::jit::toIValue(obj, groupType);
      return boxedGroup.toCustomClass<::c10d::ProcessGroup>();
    });

py::enum_<::c10d::ProcessGroup::BackendType>(processGroup, "BackendType")
.value("UNDEFINED", ::c10d::ProcessGroup::BackendType::UNDEFINED)
Expand Down Expand Up @@ -2512,7 +2520,18 @@ Example::
.. warning ::
This API only works for NCCL backend for now and must set
NCCL_ENABLE_TIMING environment variable.
)");
)")
.def(
    "boxed",
    // Hands the Work to Python as a boxed custom-class object instead of
    // the pybind Work wrapper.
    [](c10::intrusive_ptr<::c10d::Work> self) {
      // `self` is this lambda's own by-value copy, so move it into the
      // IValue rather than copying it (saves a refcount increment and
      // decrement, and matches the ProcessGroup "boxed" binding).
      return torch::jit::toPyObject(c10::IValue(std::move(self)));
    })
.def_static(
    "unbox",
    // Converts a boxed custom-class Python object back into a
    // c10::intrusive_ptr<Work>.
    [](py::object obj) {
      // Resolve the custom class type by its qualified name, then convert
      // the Python object to an IValue of that type.
      const auto workType =
          torch::getCustomClass("__torch__.torch.classes.c10d.Work");
      const c10::IValue boxedWork = torch::jit::toIValue(obj, workType);
      return boxedWork.toCustomClass<::c10d::Work>();
    });

py::class_<c10::DDPLoggingData>(module, "DDPLoggingData")
.def(py::init<>())
Expand Down

0 comments on commit 0558f63

Please sign in to comment.