
Adds DLPack support #57110

Closed
wants to merge 14 commits

Conversation

@emcastillo (Collaborator) commented Apr 28, 2021

Partially Fixes #55090
Depends on #55365

Inspired by dmlc/dlpack#57 (comment)

Question: in PyTorch we can't create streams, or easily synchronize with them, from just an integer. Should we add an ExternalStream object like the one we have in CuPy?

TODO: Add tests

Would like some feedback, as this design will need quite a few iterations.
@rgommers @leofang

cc @mruberry @rgommers @pmeier @asmeurer @leofang @AnirudhDagar @asi1024 @emcastillo @kmaehashi @heitorschueroff

@facebook-github-bot (Contributor) commented Apr 28, 2021

💊 CI failures summary and remediations

As of commit 2059638 (more details on the Dr. CI page):

💚 💚 Looks good so far! There are no failures yet. 💚 💚

This comment was automatically generated by Dr. CI.

@rgommers (Collaborator):

Questions, in PyTorch we can't create streams or easily synchronize them from just an integer. Should we add an ExternalStream object like the one we have in CuPy?

It shouldn't require a public object I think - I suspect that'd be a bigger discussion (not sure though, @mruberry?).

Looking at the CuPy version, I think you want a private version (implemented in C++ and with Python bindings so you can use it in Tensor.__dlpack__) of:

class ExternalStream(BaseStream):

    """External CUDA stream.

    This class allows using external streams in CuPy by providing the
    stream pointer obtained from a CUDA runtime call.
    The user is in charge of managing the life cycle of the stream.

    Args:
        ptr (intptr_t): Address of the `cudaStream_t` object.

    Attributes:
        ~Stream.ptr (intptr_t): Raw stream handle.
    """

    def __init__(self, ptr):
        self.ptr = ptr
Is that about what you were thinking?
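For illustration, here is a runnable sketch of how a producer's `__dlpack__` might consume a raw stream handle through such a private wrapper. All names (`dlpack_export`, the pointer value) are hypothetical stand-ins, and no CUDA is involved; the real wrapper would hold an actual `cudaStream_t` address.

```python
class ExternalStream:
    """Minimal stand-in: wraps a raw stream handle; the caller owns its lifetime."""
    def __init__(self, ptr):
        self.ptr = ptr

def dlpack_export(tensor_capsule, stream=None):
    # `stream` is the integer handle the consumer passes into __dlpack__.
    if stream is not None:
        ext = ExternalStream(stream)  # wrap the integer handle
        # Real code would synchronize the producer's work onto `ext` here.
        assert ext.ptr == stream
    return tensor_capsule

capsule = dlpack_export("capsule", stream=0x1234)
```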

@rgommers (Collaborator) left a comment:

Thanks @emcastillo, looks like a great start.

Review comment on torch/_tensor.py (outdated):
            a `synchronize` method. Optional.
        """
        if isinstance(stream, torch.cuda.Stream) or hasattr(stream, 'synchronize'):
            stream.synchronize()
Collaborator:

This shouldn't be necessary, but maybe the right API is missing at the Python level here to do this asynchronously? From dmlc/dlpack#57 (comment):

"In cases where both sides use their own stream, async exchange can still be done by stream dependency queuing":

// The event can also be created on the fly, or a synchronizer object could be
// created and cached. An auxiliary function like this could be exposed to the
// Python side if that helps the frameworks.
void PushStreamDep(cudaStream_t src, cudaStream_t dst) {
    cudaEvent_t event;
    cudaEventCreate(&event);
    cudaEventRecord(event, src);         // capture the work already queued on src
    cudaStreamWaitEvent(dst, event, 0);  // make dst wait for that work
    cudaEventDestroy(event);
}
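The same stream-dependency pattern can be sketched on the Python side against a minimal duck-typed stream interface, so it runs without CUDA. In real PyTorch code the objects would be `torch.cuda.Stream` and `torch.cuda.Event`; the fake classes here are illustrative stand-ins only.

```python
class FakeEvent:
    """Stand-in for a CUDA event; remembers which stream recorded it."""
    def __init__(self):
        self.recorded_on = None

class FakeStream:
    """Stand-in for a CUDA stream; remembers the events it waits on."""
    def __init__(self, name):
        self.name = name
        self.waits = []
    def record_event(self, event):
        event.recorded_on = self
        return event
    def wait_event(self, event):
        self.waits.append(event)

def push_stream_dep(src, dst):
    """Make work queued on `dst` wait for work already queued on `src`."""
    event = src.record_event(FakeEvent())
    dst.wait_event(event)
    return event

producer, consumer = FakeStream("producer"), FakeStream("consumer")
ev = push_stream_dep(producer, consumer)
```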

Further review comments on torch/_tensor.py and torch/utils/dlpack.py (resolved).
@emcastillo (Collaborator, Author):

Regarding the ExternalStream, I will send a PR to implement it in PyTorch so we can use it here.
Thanks!

@mruberry (Collaborator) commented May 2, 2021

Questions, in PyTorch we can't create streams or easily synchronize them from just an integer. Should we add an ExternalStream object like the one we have in CuPy?

It shouldn't require a public object I think - I suspect that'd be a bigger discussion (not sure though, @mruberry?).

Keeping things private to start, if possible, is always preferable as it gives us more flexibility in the future. If there's a compelling reason to make it public we can always do so, of course, but you'll have to educate me ;)

facebook-github-bot pushed a commit that referenced this pull request Jun 4, 2021
Summary:
This is required in #57110 (comment)

We need to provide means to synchronize on externally allocated streams for dlpack support in python array data api.

cc mruberry rgommers leofang asi1024 kmaehashi

Pull Request resolved: #57781

Reviewed By: mrshenli

Differential Revision: D28326365

Pulled By: ezyang

fbshipit-source-id: b67858c8033949951b49a3d319f649884dfd0a91
deniskokarev pushed a commit to deniskokarev/pytorch that referenced this pull request Jun 9, 2021
@emcastillo (Collaborator, Author):

@rgommers I updated the PR after #59527 was merged :).
I tried to address all your concerns; can I get a second review, please?
I also implemented the two-stream synchronization, but did it on the Python side instead of in C++.
I didn't think that growing the libtorch API with a function that would be used only in this case was worth it if it could be done on the Python side.

Thank you!

@leofang (Contributor) commented Jun 18, 2021

@kmaehashi and I would like to raise this discussion: as demonstrated in @emcastillo's current design, from_dlpack() can support both PyCapsule objects (for backward compatibility if a library has supported DLPack under this name) and any protocol-compliant object that comes with __dlpack__ and __dlpack_device__. The question is whether this is OK with everyone.

@rgommers (Collaborator) commented Jun 20, 2021

@kmaehashi and I would like to raise this discussion: as demonstrated in @emcastillo's current design, from_dlpack() can support both PyCapsule objects (for backward compatibility if a library has supported DLPack under this name) and any protocol-compliant object that comes with __dlpack__ and __dlpack_device__. The question is whether this is OK with everyone.

Thanks for bringing this up @leofang. This seems fine to me, since it's a superset of what's in the array API standard, so there's no conflict. Most libraries will support a superset for other functionality as well, for historical or other reasons.

I would recommend that if this is done, the documentation emphasizes that the capsule approach is there only for convenience, to support libraries that implement the old-style to_dlpack and not yet __dlpack__. The capsule approach is a bit less safe (a capsule must not be consumed more than once), takes an extra line of code, and the stream parameter cannot be used.
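The dual-input behavior discussed above can be sketched as a small dispatch function. Here `_from_capsule` and `FakeArray` are hypothetical stand-ins for the low-level capsule converter and a producer library's array type; only the dispatch logic reflects the design being discussed.

```python
def _from_capsule(capsule):
    # Hypothetical low-level converter; here it just passes the value through.
    return capsule

class FakeArray:
    """Stand-in for a protocol-compliant producer object."""
    def __dlpack__(self, stream=None):
        return "capsule-from-protocol"
    def __dlpack_device__(self):
        return (2, 0)  # (kDLCUDA, device 0)

def from_dlpack(ext):
    # Protocol path: ask the producer for a fresh capsule via __dlpack__.
    if hasattr(ext, '__dlpack__'):
        return _from_capsule(ext.__dlpack__())
    # Legacy path: `ext` is assumed to already be a one-shot PyCapsule.
    return _from_capsule(ext)

result = from_dlpack(FakeArray())
```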

@rgommers rgommers added the module: python array api Issues related to the Python Array API label Jun 20, 2021
@rgommers (Collaborator) left a comment:

Thanks @emcastillo.

I also tried to implement the two streams synchronization but did it python side instead of C++.
I didn't think that increasing the libtorch API with a function that would be used only in this case was worth it if it could be done python side.

I agree, this makes sense. Having ExternalStream available makes this change quite nice and small.

Review comments on torch/_tensor.py (resolved):
# CPU = 1, CPU_PINNED = 3, OPENCL = 4, VULKAN = 7,
# METAL = 8, VPI = 9
dlpack_ids = {'cpu': 1, 'cuda': 2, 'rocm': 10}
idx = self.device.index if self.device.index is not None else 0
Collaborator:

This still has TODOs. I think it would be nice if this returned Tuple[enum.IntEnum, int] as in the spec: https://data-apis.org/array-api/latest/API_specification/array_object.html#dlpack-device-self

Collaborator (Author):

This is a bit out-of-scope but if we were to support these other devices, how would the stream support work?
Should it be ignored in environments where a stream does not make any sense?

Contributor:

I think what @rgommers meant is to change the return type of this function:

def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:

This is a bit out-of-scope but if we were to support these other devices, how would the stream support work?
Should it be ignored in environments where a stream does not make any sense?

For __dlpack_device__ whether a device has the concept of stream/queue doesn't matter. For __dlpack__ stream can be Any:
https://data-apis.org/array-api/latest/API_specification/array_object.html#dlpack-self-stream-none
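A sketch of what such a `Tuple[enum.IntEnum, int]` return could look like, using the DLPack DLDeviceType numbering quoted in the review comment above. The helper function name and the string-keyed lookup are illustrative, not PyTorch's actual implementation.

```python
from enum import IntEnum
from typing import Tuple

class DLDeviceType(IntEnum):
    # Values follow the DLPack device-type numbering referenced above.
    kDLCPU = 1
    kDLCUDA = 2
    kDLCUDAHost = 3
    kDLOpenCL = 4
    kDLVulkan = 7
    kDLMetal = 8
    kDLVPI = 9
    kDLROCM = 10

def dlpack_device(device_type: str, index) -> Tuple[IntEnum, int]:
    """Map a device string and optional index to the spec's (enum, int) pair."""
    ids = {'cpu': DLDeviceType.kDLCPU,
           'cuda': DLDeviceType.kDLCUDA,
           'rocm': DLDeviceType.kDLROCM}
    return (ids[device_type], index if index is not None else 0)
```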

Review comments on torch/utils/dlpack.py (resolved).
@emcastillo emcastillo changed the title [WIP] Add support for dlpack to torch.tensor Add support for dlpack to torch.tensor Jun 21, 2021
@emcastillo (Collaborator, Author):

I think I addressed all the review concerns, and also added some small tests to verify the behavior.
Can you take another look?
This should be close to landing, but some small fixes or tweaks might still be required.

@anjali411 added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Jun 21, 2021
@rgommers (Collaborator) commented Jun 21, 2021

There are a bunch of test failures; not sure about the XLA ones.

For the __torch_function__ ones, the dunder methods need a snippet like this at the top of the method:

        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__dlpack__, (self,), self, stream)
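A runnable, simplified sketch of that guard pattern follows. `has_torch_function_unary` and `handle_torch_function` are real PyTorch helpers, but the minimal versions below are stand-ins (as is this bare `Tensor` class) so the control flow is visible without importing torch.

```python
def has_torch_function_unary(obj):
    # Stand-in: the real helper checks for a __torch_function__ override.
    return hasattr(type(obj), '__torch_function__')

def handle_torch_function(func, relevant_types, obj, *args):
    # Stand-in: the real helper dispatches across all overriding types.
    return obj.__torch_function__(func, args)

class Tensor:
    def __dlpack__(self, stream=None):
        # The snippet from the review comment, placed at the top of the dunder.
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__dlpack__, (self,), self, stream)
        return "capsule"

class MyTensor(Tensor):
    @classmethod
    def __torch_function__(cls, func, args):
        return "intercepted"

plain = Tensor().__dlpack__()       # takes the normal path
wrapped = MyTensor().__dlpack__()   # dispatched to __torch_function__
```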

One other thing to discuss: from_dlpack is left in the torch.utils.dlpack namespace here, but according to the array API standard we'd want it in the main namespace. I think it fits there too, next to from_numpy and (soon) frombuffer.

@emcastillo emcastillo force-pushed the add-py-dlpack branch 3 times, most recently from 379e3ac to 9764886 Compare June 22, 2021 00:53
@rgommers (Collaborator) left a comment:

This looks great, thanks @emcastillo!

@emcastillo (Collaborator, Author):

Thank you all for your advice and for taking the time to thoroughly review my horrible initial implementation 😅

@mruberry (Collaborator) commented Sep 7, 2021

Doh! Sorry, @emcastillo, looks like this picked up a merge conflict. Would you just rebase it and ping me so I can merge it?

@emcastillo (Collaborator, Author):

@mruberry rebased!

@facebook-github-bot (Contributor):

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor):

@mruberry merged this pull request in 1cb3507.

@@ -759,6 +759,8 @@ def compiled_with_cxx11_abi():
quantized_lstm = torch.ops.aten.quantized_lstm
quantized_gru = torch.ops.aten.quantized_gru

from torch.utils.dlpack import from_dlpack, to_dlpack
Collaborator:

Adding some docs for DLPack in gh-70437. I'd forgotten about this. Looked back at my review comment, and it mentioned just from_dlpack. I think adding to_dlpack to the main namespace was probably unnecessary, given that it shouldn't be used when all libraries support __dlpack__. Anyway, no action needed - mostly a comment to self.

take-cheeze added a commit to take-cheeze/cupy that referenced this pull request Nov 27, 2023
pytorch/pytorch#57110

Co-authored-by: Leo Fang <leo80042@gmail.com>
Labels: cla signed · Merged · module: python array api · open source · triaged
Successfully merging this pull request may close these issues.

Implement improved DLPack support
7 participants