
Add .boxed() to c10d::ProcessGroup and c10d::Work's pybind #111997

Closed
wants to merge 3 commits

Conversation

@yifuwang (Contributor) commented Oct 25, 2023

Stack from ghstack (oldest at bottom):

Summary:
When passed from C++ to Python, `c10d::ProcessGroup` and `c10d::Work` are automatically converted to their pybind classes, which can't be used for dispatcher ops. `.boxed()` exposes `c10d::ProcessGroup` and `c10d::Work` to Python as boxed custom class objects.

```python
import tempfile

import torch
import torch.distributed as dist


if __name__ == "__main__":
    with tempfile.NamedTemporaryFile(delete=False) as tmpf:
        dist.init_process_group(
            backend="nccl", init_method=f"file://{tmpf.name}", rank=0, world_size=1
        )
        group = dist.group.WORLD
        print(group)
        print(group.boxed())
```

```
<torch.distributed.distributed_c10d.ProcessGroup object at 0x7fe42fb78d30>
ScriptObject <__torch__.torch.classes.c10d.ProcessGroup>
```
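For example (hypothetical library and op names, just to illustrate), the boxed object is what gets passed to a dispatcher op whose schema declares the custom class:

```python
import torch

# Hypothetical op for illustration only: the schema declares a custom-class
# argument, and the boxed ScriptObject form of the process group is what can
# be passed through the dispatcher for that argument.
lib = torch.library.Library("my_lib", "DEF")
lib.define(
    "my_op(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup pg) -> Tensor"
)

def my_op_impl(tensor, pg):
    # Inside the kernel, `pg` arrives as the boxed custom class object.
    return tensor

lib.impl("my_op", my_op_impl, "CompositeExplicitAutograd")

# Usage: out = torch.ops.my_lib.my_op(t, dist.group.WORLD.boxed())
```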

@pytorch-bot commented Oct 25, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111997

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c889d69 with merge base 31c0ef9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

yifuwang added a commit that referenced this pull request Oct 25, 2023
ghstack-source-id: 5017bef30b6618393fab2c2385f49858c9abf2f1
Pull Request resolved: #111997
yifuwang added a commit that referenced this pull request Oct 25, 2023
ghstack-source-id: b3bd9305f3ca1231b4eb29a51e8bef81b7e51ba0
Pull Request resolved: #111997
@lw (Contributor) left a comment

Thanks for getting this out so promptly!

I have a hard time building PyTorch from source, but I copied that code into my own C++ extensions and it works!

I have noticed, however, that inside my operator I'm now receiving an instance of torch.classes.c10d.ProcessGroup and need to convert it back to a regular ProcessGroup, so I also need an "unboxing" function.

Do you think it would be possible to add this as a static method to the pybind of c10d::ProcessGroup?

Based on my own trial and error, this would be the code:

```cpp
// Includes assumed based on where these helpers live in the PyTorch tree.
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/jit/python/pybind_utils.h>

c10::intrusive_ptr<c10d::ProcessGroup> unbox_process_group(
    const py::object& obj) {
  // Convert the ScriptObject back into an IValue holding the custom class,
  // then extract the intrusive_ptr to the underlying c10d::ProcessGroup.
  return torch::jit::toIValue(
             obj,
             c10::getCustomClassType<c10::intrusive_ptr<c10d::ProcessGroup>>())
      .toCustomClass<c10d::ProcessGroup>();
}
```

@yifuwang (Contributor, Author)

> I copied that code into my own C++ extensions and it works!

That's great to know! Thanks for testing it out!

> Do you think it would be possible to add this as a static method to the pybind of c10d::ProcessGroup?

This is possible. I understand you want a convenient unboxing function on the C++ side, or a way to unbox a c10d::Work on the Python side (since you might want to wait from Python). I'm curious if/why you want to unbox a c10d::ProcessGroup on the Python side, since the unboxed version is already available there.
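For the Work case, I'd imagine usage roughly like this (a hedged sketch; the op and the `unbox` counterpart are hypothetical names, not existing APIs):

```python
# Hypothetical sketch: a dispatcher op returns a boxed c10d::Work (a
# ScriptObject); to call .wait() from Python it would first have to be
# converted back to the regular pybind Work object by some unbox helper.
boxed_work = torch.ops.my_lib.my_async_op(tensor, group.boxed())  # hypothetical op
work = torch.distributed.Work.unbox(boxed_work)                   # hypothetical unbox
work.wait()
```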

@lw (Contributor) commented Oct 26, 2023

My end-to-end use case is something like this:

```python
import torch

def my_op(tensor, boxed_process_group):
    process_group = torch.distributed.ProcessGroup.unbox(boxed_process_group)
    torch.distributed.all_reduce(tensor, group=process_group)
    # all_reduce is in-place; return the tensor to match the declared schema
    return tensor

lib = torch.library.Library("my_lib", "DEF")
lib.define("my_op(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group) -> Tensor")
lib.impl("my_op", my_op, "CUDA")

def main():
    torch.distributed.init_process_group("nccl")
    group = torch.distributed.new_group()
    input_ = torch.ones((1,), device="cuda")
    output = torch.ops.my_lib.my_op(input_, group.boxed())
```

In short: I have a Python function that I need to invoke from some other Python code, but the call needs to go through the dispatcher, so I have to register the function as an operator. Therefore I need to box the PG when invoking the operator and unbox it inside the operator.

The reason the function needs to be an operator is that I want it to be intercepted by xFormers' selective activation checkpointing mechanism.

@yifuwang (Contributor, Author)

@lw Got it! I overlooked the part where the custom kernels could be in Python. I'll add `unbox` before landing.

yifuwang added a commit that referenced this pull request Nov 1, 2023
ghstack-source-id: 3d6bfd9955fdcb42d99fe6030686d2724c7f9878
Pull Request resolved: #111997
@yifuwang (Contributor, Author) commented Nov 2, 2023

@pytorchbot merge

@pytorch-bot added the ciflow/trunk label Nov 2, 2023
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@facebook-github-bot deleted the gh/yifuwang/11/head branch November 6, 2023 15:25
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
andreigh pushed a commit to andreigh/pytorch that referenced this pull request Nov 19, 2023
Labels: ciflow/trunk, Merged, release notes: distributed (c10d)