
[Train] Split overloaded ray.train.torch.get_device into another get_devices API for multi-GPU worker setup #42314

Merged
merged 28 commits on Jan 30, 2024
Changes from 18 commits
11 changes: 5 additions & 6 deletions python/ray/air/_internal/torch_utils.py
@@ -10,11 +10,11 @@
 from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed


-def get_device() -> Union[torch.device, List[torch.device]]:
+def get_devices() -> List[torch.device]:
     """Gets the correct torch device configured for this process.

-    Returns a list of devices if more than 1 GPU per worker
-    is requested.
+    Returns a list of torch CUDA devices allocated for the current worker.
+    If no GPUs are assigned, then it returns a list with a single CPU device.

     Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
     superset of the `ray.get_gpu_ids()`.
@@ -55,11 +55,10 @@ def get_device() -> Union[torch.device, List[torch.device]]:
             device_ids.append(0)

         devices = [torch.device(f"cuda:{device_id}") for device_id in device_ids]
-        device = devices[0] if len(devices) == 1 else devices
     else:
-        device = torch.device("cpu")
+        devices = [torch.device("cpu")]

-    return device
+    return devices


 def convert_pandas_to_torch_tensor(
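With the internal helper now always returning a list, call sites no longer need the isinstance check that the old Union return type forced on them. A minimal sketch, not part of the PR, assuming it runs inside a Ray worker where ray.get_gpu_ids() is meaningful:

import torch

from ray.air._internal.torch_utils import get_devices


def move_model_to_first_device(model: torch.nn.Module) -> torch.nn.Module:
    # get_devices() now always returns List[torch.device], even on CPU-only
    # workers, so no isinstance(..., list) branching is needed.
    devices = get_devices()
    return model.to(devices[0])
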
6 changes: 2 additions & 4 deletions python/ray/air/util/torch_dist.py
@@ -17,7 +17,7 @@
 from ray.actor import ActorHandle
 from ray.train._internal.utils import get_address_and_port
 from ray.train.constants import DEFAULT_NCCL_SOCKET_IFNAME
-from ray.air._internal.torch_utils import get_device
+from ray.air._internal.torch_utils import get_devices


 class TorchDistributedWorker(ABC):
@@ -183,9 +183,7 @@ def _shutdown_torch_distributed():
         return

     # Clean up cuda memory.
-    devices = get_device()
-    if not isinstance(devices, list):
-        devices = [devices]
+    devices = get_devices()
     for device in devices:
         with torch.cuda.device(device):
             torch.cuda.empty_cache()
2 changes: 1 addition & 1 deletion python/ray/data/iterator.py
@@ -335,8 +335,8 @@ def iter_torch_batches(

         from ray.air._internal.torch_utils import (
             convert_ndarray_batch_to_torch_tensor_batch,
-            get_device,
         )
+        from ray.train.torch import get_device

         if collate_fn is not None and (dtypes is not None or device != "auto"):
            raise ValueError(
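The iterator now resolves its device through the public ray.train.torch.get_device() instead of the internal helper. A hedged usage sketch; the dataset contents are illustrative, and on a CPU-only machine the batches simply stay on CPU:

import ray

ds = ray.data.range(8)

# With device left at its "auto" default, batches are moved to the device
# returned by ray.train.torch.get_device() when running inside a Train
# worker, and stay on CPU otherwise.
for batch in ds.iter_torch_batches(batch_size=4):
    print({name: tensor.device for name, tensor in batch.items()})
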
5 changes: 3 additions & 2 deletions python/ray/data/tests/test_iterator.py
@@ -158,9 +158,10 @@ def collate_fn(batch: Dict[str, np.ndarray]):

     # Test that we don't automatically set device if collate_fn is specified.
     with patch(
-        "ray.air._internal.torch_utils.get_device", lambda: torch.device("cuda")
+        "ray.air._internal.torch_utils.get_devices", lambda: [torch.device("cuda")]
     ):
-        assert ray.air._internal.torch_utils.get_device().type == "cuda"
+        devices = ray.air._internal.torch_utils.get_devices()
+        assert devices[0].type == "cuda"

     it.iter_batches = MagicMock()
     for batch in it.iter_torch_batches(collate_fn=collate_fn):
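For reference, the mocking pattern used above works outside the test suite as well: unittest.mock.patch with a plain callable swaps out the module attribute for the duration of the with block. A standalone sketch; the CPU device here is illustrative:

import torch
from unittest.mock import patch

import ray.air._internal.torch_utils as torch_utils

with patch("ray.air._internal.torch_utils.get_devices", lambda: [torch.device("cpu")]):
    # Inside the block, the real implementation is replaced by the lambda.
    assert torch_utils.get_devices() == [torch.device("cpu")]
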
15 changes: 3 additions & 12 deletions python/ray/train/lightning/_lightning_utils.py
@@ -51,15 +51,6 @@ def import_lightning(): # noqa: F402
 LIGHTNING_REPORT_STAGE_KEY = "_report_on"


-def get_worker_root_device():
-    """Get the first torch device of the current worker if there are multiple."""
-    devices = ray.train.torch.get_device()
-    if isinstance(devices, list):
-        return devices[0]
-    else:
-        return devices
-
-
 @PublicAPI(stability="beta")
 class RayDDPStrategy(pl.strategies.DDPStrategy):
     """Subclass of DDPStrategy to ensure compatibility with Ray orchestration.
@@ -77,7 +68,7 @@ def __init__(self, *args, **kwargs):

     @property
     def root_device(self) -> torch.device:
-        return get_worker_root_device()
+        return ray.train.torch.get_device()

     @property
     def distributed_sampler_kwargs(self) -> Dict[str, Any]:
@@ -101,7 +92,7 @@ def __init__(self, *args, **kwargs):

     @property
     def root_device(self) -> torch.device:
-        return get_worker_root_device()
+        return ray.train.torch.get_device()

     @property
     def distributed_sampler_kwargs(self) -> Dict[str, Any]:
@@ -144,7 +135,7 @@ def __init__(self, *args, **kwargs):

     @property
     def root_device(self) -> torch.device:
-        return get_worker_root_device()
+        return ray.train.torch.get_device()

     @property
     def distributed_sampler_kwargs(self) -> Dict[str, Any]:
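Since root_device now just forwards to the public API, the same pattern works for any custom strategy. A hedged sketch: the subclass below is illustrative and not part of the PR, and it imports pytorch_lightning directly rather than through the module's import helper:

import pytorch_lightning as pl
import torch

import ray.train.torch


class MyRayDDPStrategy(pl.strategies.DDPStrategy):
    @property
    def root_device(self) -> torch.device:
        # get_device() always returns a single torch.device now; with multiple
        # GPUs per worker it picks the lowest-indexed one.
        return ray.train.torch.get_device()
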
20 changes: 6 additions & 14 deletions python/ray/train/tests/test_gpu.py
@@ -71,11 +71,7 @@ def train_fn():
         visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
         assert visible_devices == "1,2"

-        devices = (
-            sorted([device.index for device in train.torch.get_device()])
-            if num_gpus_per_worker > 1
-            else train.torch.get_device().index
-        )
+        devices = sorted([device.index for device in train.torch.get_devices()])
         write_rank_data(tmp_path, devices)

     trainer = TorchTrainer(
@@ -92,9 +88,9 @@
     devices = list(rank_data.values())

     if num_gpus_per_worker == 0.5:
-        assert sorted(devices) == [0, 0, 1, 1]
+        assert sorted(devices) == [[0], [0], [1], [1]]
     elif num_gpus_per_worker == 1:
-        assert sorted(devices) == [0, 1]
+        assert sorted(devices) == [[0], [1]]
     elif num_gpus_per_worker == 2:
         assert sorted(devices[0]) == [0, 1]
     else:
@@ -108,11 +104,7 @@
 def test_torch_get_device_dist(ray_2_node_2_gpu, num_gpus_per_worker, tmp_path):
     @patch("torch.cuda.is_available", lambda: True)
     def train_fn():
-        devices = (
-            sorted([device.index for device in train.torch.get_device()])
-            if num_gpus_per_worker > 1
-            else train.torch.get_device().index
-        )
+        devices = sorted([device.index for device in train.torch.get_devices()])
         write_rank_data(tmp_path, devices)

     trainer = TorchTrainer(
@@ -138,12 +130,12 @@ def train_fn():
         # 4 workers on node 1, 4 workers on node 2
         # `ray.get_gpu_ids()` returns [0], [0], [1], [1] on node 1
         # and [0], [0], [1], [1] on node 2
-        assert sorted(devices) == [0, 0, 0, 0, 1, 1, 1, 1]
+        assert sorted(devices) == [[0], [0], [0], [0], [1], [1], [1], [1]]
     elif num_gpus_per_worker == 1:
         # worker gpu topology:
         # 2 workers on node 1, 2 workers on node 2
         # `ray.get_gpu_ids()` returns [0], [1] on node 1 and [0], [1] on node 2
-        assert sorted(devices) == [0, 0, 1, 1]
+        assert sorted(devices) == [[0], [0], [1], [1]]
     elif num_gpus_per_worker == 2:
         # worker gpu topology:
         # 1 workers on node 1, 1 workers on node 2
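The reshaped assertions reflect that every rank now records a list of device indices, even with a single or fractional GPU. Illustrative values for the 0.5-GPU-per-worker case, not an excerpt from the test:

# Four workers sharing two GPUs: each rank still sees exactly one device,
# but it is now reported as a one-element list.
rank_to_devices = {0: [0], 1: [0], 2: [1], 3: [1]}
assert sorted(rank_to_devices.values()) == [[0], [0], [1], [1]]
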
7 changes: 2 additions & 5 deletions python/ray/train/tests/test_torch_trainer.py
@@ -215,11 +215,8 @@ def train_fn():
         # the other is taken by the other sample) so device index should be 0.
         # For the multiple GPU case, each worker has 2 visible devices so device
         # index should be either 0 or 1. It doesn't matter which.
-        devices = train.torch.get_device()
-        if isinstance(devices, list):
-            assert sorted([device.index for device in devices]) == [0, 1]
-        else:
-            assert train.torch.get_device().index == 0
+        device_ids = sorted([device.index for device in train.torch.get_devices()])
+        assert device_ids in [[0], [0, 1]]

     @ray.remote(num_cpus=0)
     class TrialActor:
8 changes: 7 additions & 1 deletion python/ray/train/tests/test_train_usage.py
@@ -15,7 +15,12 @@ def shutdown_only():
 def run_torch():
     from torch.utils.data import DataLoader, TensorDataset

-    from ray.train.torch import get_device, prepare_data_loader, prepare_model
+    from ray.train.torch import (
+        get_device,
+        get_devices,
+        prepare_data_loader,
+        prepare_model,
+    )

     def train_func():
         # Create dummy model and data loader
@@ -27,6 +32,7 @@ def train_func():
         prepare_data_loader(dataloader)
         prepare_model(model)
         get_device()
+        get_devices()

     trainer = TorchTrainer(
         train_func, scaling_config=ScalingConfig(num_workers=2, use_gpu=False)
2 changes: 2 additions & 0 deletions python/ray/train/torch/__init__.py
@@ -17,6 +17,7 @@
     backward,
     enable_reproducibility,
     get_device,
+    get_devices,
     prepare_data_loader,
     prepare_model,
     prepare_optimizer,
@@ -28,6 +29,7 @@
     "TorchConfig",
     "accelerate",
     "get_device",
+    "get_devices",
     "prepare_model",
     "prepare_optimizer",
     "prepare_data_loader",
10 changes: 3 additions & 7 deletions python/ray/train/torch/config.py
@@ -113,11 +113,9 @@ def _setup_torch_process_group(


 def _shutdown_torch(destroy_process_group=False):
-    from ray.air._internal.torch_utils import get_device
+    from ray.air._internal.torch_utils import get_devices

-    devices = get_device()
-    if not isinstance(devices, list):
-        devices = [devices]
+    devices = get_devices()
     if destroy_process_group:
         dist.destroy_process_group()
     if torch.cuda.is_available():
@@ -129,7 +127,7 @@ def _shutdown_torch(destroy_process_group=False):
 def _set_torch_distributed_env_vars():
     # Same env vars as in
     # https://pytorch.org/docs/stable/elastic/run.html#environment-variables
-    from ray.air._internal.torch_utils import get_device
+    from ray.train.torch import get_device

     context = ray.train.get_context()
     os.environ["LOCAL_RANK"] = str(context.get_local_rank())
@@ -140,8 +138,6 @@ def _set_torch_distributed_env_vars():

     # Makes sure Hugging Face Accelerate uses the correct device
     device = get_device()
-    if isinstance(device, list):
-        device = device[0]
     os.environ["ACCELERATE_TORCH_DEVICE"] = str(device)

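Because get_device() can no longer return a list, the value written to ACCELERATE_TORCH_DEVICE is always a single device string. A small sketch of what str() on a torch.device yields; the device index here is illustrative:

import torch

# str() of a torch.device is exactly the plain device string that ends up in
# the ACCELERATE_TORCH_DEVICE environment variable.
assert str(torch.device("cuda:1")) == "cuda:1"
assert str(torch.device("cpu")) == "cpu"
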
51 changes: 47 additions & 4 deletions python/ray/train/torch/train_loop_utils.py
@@ -39,11 +39,17 @@


 @PublicAPI(stability="stable")
-def get_device() -> Union[torch.device, List[torch.device]]:
+def get_device() -> torch.device:
     """Gets the correct torch device configured for this process.

-    Returns a list of devices if more than 1 GPU per worker
-    is requested.
+    Returns the torch device for the current worker. If more than 1 GPU is
+    requested per worker, returns the device with the minimal device index.
+
+    .. note::
+
+        If you requested multiple GPUs per worker, and want to get
+        the full list of torch devices, please use
+        :meth:`~ray.train.torch.get_devices`.

     Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
     superset of the `ray.get_gpu_ids()`.
@@ -63,11 +69,48 @@ def get_device() -> Union[torch.device, List[torch.device]]:
         >>> # ray.get_gpu_ids() == [4,5]
         >>> # torch.cuda.is_available() == True
         >>> # get_device() == torch.device("cuda:4")

[Review comment, Contributor] Nice, this wasn't actually working as expected before.

+
+        >>> # You can move model to device by:
+        >>> # model.to(ray.train.torch.get_device())
+        >>> #
+        >>> # instead of manually checking the device type:
+        >>> # model.to("cuda" if torch.cuda.is_available() else "cpu")
     """
     from ray.air._internal import torch_utils

     record_extra_usage_tag(TagKey.TRAIN_TORCH_GET_DEVICE, "1")
-    return torch_utils.get_device()
+    return torch_utils.get_devices()[0]
+
+
+@PublicAPI(stability="alpha")
+def get_devices() -> List[torch.device]:
+    """Gets the correct torch device list configured for this process.
+
+    Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
+    superset of the `ray.get_gpu_ids()`.
+
+    Example:
+        >>> # os.environ["CUDA_VISIBLE_DEVICES"] = "3,4"
+        >>> # ray.get_gpu_ids() == [3]
+        >>> # torch.cuda.is_available() == True
+        >>> # get_devices() == [torch.device("cuda:0")]
+
+        >>> # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4"
+        >>> # ray.get_gpu_ids() == [4]
+        >>> # torch.cuda.is_available() == True
+        >>> # get_devices() == [torch.device("cuda:4")]
+
+        >>> # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5"
+        >>> # ray.get_gpu_ids() == [4,5]
+        >>> # torch.cuda.is_available() == True
+        >>> # get_devices() == [torch.device("cuda:4"), torch.device("cuda:5")]

[Review thread]
Contributor: Can we make this a codeblock instead, if everything is commented?
@woshiyyya (Member, author), Jan 23, 2024: I think the current way is clearer than a tedious train_func + TorchTrainer code block, which may reduce readability, and we want to illustrate multiple scenarios here, so we don't want to write several separate examples for this docstring.
Contributor: I was thinking of just a `.. testcode::` block with `:skipif: True`. Right now the code blocks look weird (screenshot of the rendered docs attached).
@woshiyyya (author): Ah, I got you. I've updated the codeblock.

+    """
+
+    from ray.air._internal import torch_utils
+
+    record_extra_usage_tag(TagKey.TRAIN_TORCH_GET_DEVICE, "1")

[Review thread]
Contributor: This should have a new TRAIN_TORCH_GET_DEVICES key.
@woshiyyya (author): Previously Justin mentioned we could use a single key for these two APIs, but now I think it makes more sense to have a separate one, to get more accurate telemetry data.

-    return torch_utils.get_device()
+    return torch_utils.get_devices()


 @PublicAPI(stability="stable")
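Taken together, the split gives training code two unambiguous entry points. A hedged end-to-end sketch of how they are meant to be used; two GPU workers are assumed purely for illustration, and with use_gpu=False both calls report CPU devices:

import ray.train.torch
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def train_func():
    device = ray.train.torch.get_device()    # a single torch.device
    devices = ray.train.torch.get_devices()  # List[torch.device], one entry per assigned GPU
    print(f"primary device: {device}, all assigned devices: {devices}")


trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
)
trainer.fit()
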
6 changes: 3 additions & 3 deletions rllib/core/learner/torch/torch_learner.py
@@ -43,7 +43,7 @@
 torch, nn = try_import_torch()

 if torch:
-    from ray.air._internal.torch_utils import get_device
+    from ray.air._internal.torch_utils import get_devices


 logger = logging.getLogger(__name__)
@@ -298,12 +298,12 @@ def build(self) -> None:
         # TODO (Kourosh): Instead of using _TorchAccelerator, we should use the public
         # API in ray.train but allow for session to be None without any errors raised.
         if self._use_gpu:
-            # get_device() returns the 0th device if
+            # get_devices() returns a list that contains the 0th device if
             # it is called from outside of a Ray Train session. Its necessary to give
             # the user the option to run on the gpu of their choice, so we enable that
             # option here via the local gpu id scaling config parameter.
             if self._distributed:
-                self._device = get_device()
+                self._device = get_devices()[0]
             else:
                 assert self._local_gpu_idx < torch.cuda.device_count(), (
                     f"local_gpu_idx {self._local_gpu_idx} is not a valid GPU id or is "
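The Learner's device selection reduces to a small rule: take the first assigned device when running distributed under Ray Train, otherwise honor the user-chosen local GPU. A hedged standalone sketch of that rule; the function and its parameters are illustrative, not RLlib API:

import torch

from ray.air._internal.torch_utils import get_devices


def resolve_learner_device(distributed: bool, local_gpu_idx: int = 0) -> torch.device:
    if distributed:
        # Ray Train assigns the GPUs; take the first one for this learner.
        return get_devices()[0]
    # Local mode: fall back to the user-selected GPU index.
    assert local_gpu_idx < torch.cuda.device_count(), (
        f"local_gpu_idx {local_gpu_idx} is not a valid GPU id"
    )
    return torch.device(f"cuda:{local_gpu_idx}")
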