Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Train] Split overloaded ray.train.torch.get_device into another get_devices API for multi-GPU worker setup #42314

Merged
merged 28 commits into from Jan 30, 2024

Conversation

woshiyyya
Copy link
Member

@woshiyyya woshiyyya commented Jan 11, 2024

Why are these changes needed?

The original ray.train.torch.get_device behaves inconsistently depending on the number of available devices on a ray train worker:

  • Single Device: Returns a single torch.device object.
  • Multiple Devices: Returns a list of torch.device objects.

The proposal involves two key changes:

  • Modification of ray.train.torch.get_device():
    • New Behavior: This function will always return a single torch.device object. If multiple devices are available, it will return the device with smaller id.
    • Rationale: Ensures consistent return type, simplifying the usage and handling in user code.
  • Introduction of ray.train.torch.get_devices():
    • Behavior: This new function will return a list of torch.device objects, representing all available devices.
    • Rationale: Provides a clear and explicit way to retrieve all devices for multi-gpu worker scenario.

Example:

Single-GPU workers:

def train_func():
    device = ray.train.torch.get_device() # torch.device("cuda:0") 
    devices = ray.train.torch.get_devices() # [torch.device("cuda:0")]

Multi-GPU workers (e.g. 2 gpus per worker):

def train_func():
    device = ray.train.torch.get_device() # torch.device("cuda:0")
    devices = ray.train.torch.get_devices() # [torch.device("cuda:0"), torch.device("cuda:1")]

Related issue number

Closes #42003, #38115

Checks

  • I've signed off every commit(by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

woshiyyya and others added 11 commits January 10, 2024 16:07
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
@justinvyu justinvyu changed the title [Train] Split overloaded ray.train.torch.get_device into another get_devices for multi-GPU worker setup [Train] Split overloaded ray.train.torch.get_device into another get_devices API for multi-GPU worker setup Jan 29, 2024
Copy link
Contributor

@justinvyu justinvyu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One more round of suggestions for docs, then I think it's good.

  1. Add a small subsection called "Assigning multiple GPUs to a worker" that shows a small tested example recommending get_devices
  2. Add get_devices to the API reference here: https://anyscale-ray--42314.com.readthedocs.build/en/42314/train/api/api.html#pytorch

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
python/ray/air/_internal/torch_utils.py Show resolved Hide resolved

from ray.air._internal import torch_utils

record_extra_usage_tag(TagKey.TRAIN_TORCH_GET_DEVICE, "1")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should have a new TRAIN_TORCH_GET_DEVICES key.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah previously Justin mentioned we can use a single Key for these two APIs. But now I think it makes more sense to have a separate one to get more accurate telemetry data.

python/ray/air/_internal/torch_utils.py Outdated Show resolved Hide resolved
python/ray/train/torch/train_loop_utils.py Outdated Show resolved Hide resolved
@@ -63,11 +69,64 @@ def get_device() -> Union[torch.device, List[torch.device]]:
>>> # ray.get_gpu_ids() == [4,5]
>>> # torch.cuda.is_available() == True
>>> # get_device() == torch.device("cuda:4")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, this wasn't actually working as expected before.

python/ray/train/torch/train_loop_utils.py Outdated Show resolved Hide resolved
woshiyyya and others added 5 commits January 29, 2024 13:39
Co-authored-by: matthewdeng <matthew.j.deng@gmail.com>
Signed-off-by: Yunxuan Xiao <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
@woshiyyya
Copy link
Member Author

woshiyyya commented Jan 29, 2024

@ArturNiederfahrenhorst Can you take a look? The only change to rllib is in rllib/core/learner/torch/torch_learner.py

@ArturNiederfahrenhorst
Copy link
Contributor

Yep, sorry!

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Copy link
Contributor

@ArturNiederfahrenhorst ArturNiederfahrenhorst left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved for RLlib changes

Copy link
Contributor

@justinvyu justinvyu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice!

Copy link
Contributor

@pcmoritz pcmoritz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For usage.proto

@matthewdeng matthewdeng merged commit d7a4f25 into ray-project:master Jan 30, 2024
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Data] [Train] Iter_torch_batches doesn't work with multi-GPU workers
8 participants