
[DTensor] Allow DTensor support cuda-like device #102468

Closed
wants to merge 3 commits

Conversation

shaoyf42
Contributor

@shaoyf42 shaoyf42 commented May 29, 2023

Allow DTensor to support cuda-like devices; fixes #102442

Currently, DTensor supports cuda and cpu. There are other efforts to make DTensor support third-party devices, for example #101914 and #101911, but those only cover a subset of third-party devices and do not properly support third-party cuda-like devices. Therefore, we would like to extend DTensor to support cuda-like devices; after all, cuda is very popular! This PR makes two changes:

  1. Similar to what is done here, we need to initialize the communication backend for the device set by DeviceMesh, so _default_backend_for_device is added to Backend. Note that when we register a new backend for a device other than cpu and cuda, we also need to register a default backend for that device.
  2. _device_handle is added to DeviceMesh for cuda-like devices, similar to what FSDP does. When _device_handle is not None, the device behaves like cuda, and calls such as torch.cuda.device_count() become device_mesh._device_handle.device_count(). A rough sketch follows the list.
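
The following is a minimal, hypothetical sketch of the two pieces; the names and structure are simplified for illustration and do not reproduce the exact diff:

```python
import torch

# 1. A reverse map from device type to its default communication backend.
#    Registering a backend for a new device type would also record its
#    default here (hypothetical structure, not the exact c10d change).
_default_backend_for_device = {"cpu": "gloo", "cuda": "nccl"}

def get_default_backend_for_device(device: str) -> str:
    try:
        return _default_backend_for_device[device]
    except KeyError:
        raise ValueError(f"no default backend registered for device '{device}'")

# 2. A handle that mirrors the torch.cuda module API for cuda-like devices
#    registered via torch._register_device_module(); None means the cpu path.
def _get_device_handle(device_type: str):
    return getattr(torch, device_type, None) if device_type != "cpu" else None

# With a handle, torch.cuda.device_count() becomes handle.device_count(), etc.
```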


pytorch-bot bot commented May 29, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/102468

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 0afebac:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.


linux-foundation-easycla bot commented May 29, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

@shaoyf42
Contributor Author

@wanchaol Could you take a look?

Contributor

@wanchaol wanchaol left a comment


Thanks for contributing! It seems interesting that we want a cuda-like device but not the nccl communicator. I left a few comments about the integration piece; I would love to learn more about the detailed use case :)

I am also trying to use c10d's dispatchable backend to make backend integration easier (#102336). I'll land that today, and then we can see whether there are additional gaps to close to make your case work in this PR.

@@ -190,6 +195,17 @@ def __new__(cls, name: str):
value = name.lower()
return value

@classmethod
def get_default_backend_for_device(cls, device: str):
Contributor


I am working on a PR that allows easier integration of custom backends (#102336), so I am not sure we would still need to change the c10d backend to set the default backend for a device. I assume that if you register the custom backend, it should automatically override the backend config if you initialize the world_pg first?

I think maybe we could allow the device_type passed to device_mesh to be something like cuda:non-nccl and pass that through to the init_process_group call.

elif ":" in backend.lower():

Contributor


Oh, actually, if we don't specify anything in init_process_group, it would not pick up the custom registered backend and initialize it... it seems like we do need this map from device_type to backend. cc @H-Huang, does it make sense to have this reverse map for custom backends?

Contributor Author


> I assume that if you register the custom backend, it should automatically override the backend config if you initialize the world_pg first?

> I think maybe we could allow the device_type passed to device_mesh to be something like cuda:non-nccl and pass that through to the init_process_group call.

> Oh, actually, if we don't specify anything in init_process_group, it would not pick up the custom registered backend and initialize it.

> it seems like we do need this map from device_type to backend

As mentioned here, after registering a custom backend, there are two ways to initialize it (including the custom backend):

  1. The user calls init_process_group before using it, or passes the backend explicitly to init_process_group.
  2. Maintain a map from device_type to backend, updated when the custom backend is registered. When nothing is specified in init_process_group, it checks whether each device/backend pair in the map is available and then uses the available pairs to initialize the process group. A rough sketch follows the list.
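
A rough sketch of option 2, assuming a module-level map that backend registration keeps up to date; the names below are illustrative and do not reflect the actual c10d code:

```python
import torch

# Illustrative map from device type to its default backend; registering a
# third-party backend would also record an entry here (hypothetical API).
_device_to_default_backend = {"cpu": "gloo", "cuda": "nccl"}

def register_custom_backend(name: str, device_type: str) -> None:
    # Alongside the usual Backend.register_backend(...) call, remember which
    # device this backend serves so it can be found without an explicit name.
    _device_to_default_backend[device_type] = name

def infer_backends() -> dict:
    # Called when init_process_group() receives no explicit backend: keep only
    # the device/backend pairs that are usable in the current process.
    usable = {}
    for device, backend in _device_to_default_backend.items():
        if device == "cpu" or getattr(torch, device, None) is not None:
            usable[device] = backend
    return usable
```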

Contributor Author


By the way, a unified function could be used to determine whether a backend is available, such as is_backend_available, instead of a special case for each backend. is_backend_available would support a unified check for both built-in and third-party registered backends. I have an implementation in #101945.
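
A rough sketch of what such a unified check could look like; this is only an illustration, not the actual #101945 implementation:

```python
import torch.distributed as dist

def is_backend_available(backend: str) -> bool:
    # Built-in backends expose is_<backend>_available() helpers (gloo, nccl, mpi, ...).
    builtin_check = getattr(dist, f"is_{backend.lower()}_available", None)
    if callable(builtin_check):
        return builtin_check()
    # A third-party backend counts as available once it has been registered.
    return backend.lower() in dist.Backend.backend_list
```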

@@ -107,6 +107,9 @@ def __init__(
_init_process_groups: bool = True,
) -> None:
self.device_type = device_type
self._device_handle = (
Contributor


I don't quite like the device_handle thing; it feels like a hack specialized to cuda. If you are inventing a new backend that conforms with cuda, shouldn't it also have an identical call to torch.cuda.set_device? Or maybe we could use CUDA_VISIBLE_DEVICES if it's hard to hook into torch.cuda.

Contributor Author


> if there are additional gaps to make your case work in this PR?

We hope that PyTorch's distributed stack can support third-party devices as well as third-party backends, not just third-party backends.

In fact, our device has the same semantics as cuda, so it is also consistent with the cuda interface, which makes this a natural approach. For example, xpu uses torch._register_device_module("xpu", current_module) to register its extension as torch.xpu, after which torch.xpu.device_count() and torch.xpu.get_rng_state() provide the same functionality as their cuda counterparts. Our implementation is similar to xpu, so for a device registered this way (torch.custom_device) we want to use _device_handle instead of torch.cuda directly, to support cuda-like devices the way FSDP does. A rough sketch follows.
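
For illustration, a rough sketch of this pattern, using a made-up stand-in module for the real extension:

```python
import types
import torch

# Stand-in for a real third-party extension that implements a cuda-like API
# on top of its own runtime (the lambdas are placeholders).
current_module = types.ModuleType("custom_device_module")
current_module.device_count = lambda: 1
current_module.get_rng_state = lambda: torch.empty(0, dtype=torch.uint8)

# Register it so it becomes reachable as torch.xpu (guarded in case this
# build already ships a real torch.xpu module).
if not hasattr(torch, "xpu"):
    torch._register_device_module("xpu", current_module)

# DeviceMesh/DTensor code can then resolve the handle by name instead of
# hard-coding torch.cuda:
device_handle = getattr(torch, "xpu", None)
device_handle.device_count()   # plays the role of torch.cuda.device_count()
device_handle.get_rng_state()  # plays the role of torch.cuda.get_rng_state()
```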

Contributor


I see, this makes sense then. Could you rebase the PR and fix the merge conflict?

Contributor Author


I have rebased the PR. In addition, the issue that init_process_group does not initialize a custom backend by default can be discussed and resolved in a follow-up PR.

@albanD albanD added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label May 31, 2023
@shaoyf42 shaoyf42 requested a review from wanchaol June 3, 2023 15:20
Contributor

@wanchaol wanchaol left a comment


Looks good! I have one more suggestion before we merge it

@@ -101,6 +101,9 @@ def __init__(
_init_process_groups: bool = True,
) -> None:
self.device_type = device_type
self._device_handle = (
Contributor


One thing I would suggest is that we not attach device_handle to self. We need to be a bit careful when adding additional attributes to DeviceMesh: one thing on the radar is that we want to make sure DTensor is picklable, so an additional attribute there is not ideal.

We can simply create _device_handle and use it in __init__ without saving it to the device mesh (see the sketch below).
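
Roughly, as an illustration of the suggestion only (simplified signature, not the real DeviceMesh):

```python
import torch

class DeviceMesh:
    def __init__(self, device_type: str, mesh, _init_process_groups: bool = True) -> None:
        self.device_type = device_type
        # Resolve the handle locally; nothing extra is stored on self, so the
        # mesh (and DTensor) stays easy to pickle.
        device_handle = (
            getattr(torch, device_type, None) if device_type != "cpu" else None
        )
        if device_handle is not None:
            # cuda-like setup would use the handle here, for example:
            num_devices_per_host = device_handle.device_count()
```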

if device_mesh.device_type == "cuda":
torch.cuda.set_rng_state(new_state)

if device_mesh._device_handle:
Contributor


Similarly here, we can create device_handle again so that we don't need to save it on device_mesh itself (see the sketch below).
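
For example, something along these lines (illustrative helper, not the exact diff):

```python
import torch

def _set_device_rng_state(device_mesh, new_state: torch.Tensor) -> None:
    # Resolve the handle on the spot instead of reading it off device_mesh.
    device_handle = (
        getattr(torch, device_mesh.device_type, None)
        if device_mesh.device_type != "cpu"
        else None
    )
    if device_handle is not None:
        device_handle.set_rng_state(new_state)  # e.g. torch.cuda.set_rng_state
```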

@shaoyf42 shaoyf42 requested a review from fduwjj as a code owner June 7, 2023 03:55
@shaoyf42 shaoyf42 requested a review from wanchaol June 7, 2023 09:41
Contributor

@wanchaol wanchaol left a comment


lgtm, thanks for addressing the comments!

Contributor

wanchaol commented Jun 7, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 7, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Labels: ciflow/trunk, Merged, open source, release notes: distributed (c10d), triaged
Successfully merging this pull request may close these issues: [DTensor] Allow DTensor support third-party device.