[Train] Split overloaded ray.train.torch.get_device into another `get_devices` API for multi-GPU worker setup (#42314)

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
woshiyyya committed Jan 30, 2024
1 parent f256fbf commit d7a4f25
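
At a glance, the change splits the previously overloaded return type into two APIs: `ray.train.torch.get_device()` now always returns a single `torch.device`, and the new `ray.train.torch.get_devices()` returns the full list of devices assigned to the worker. A minimal sketch of the new usage follows (hedged; the `ScalingConfig` values are illustrative and mirror the documentation example added in this commit):

    import torch
    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer, get_device, get_devices


    def train_func(config):
        # Before this commit, get_device() returned a single torch.device for one
        # GPU per worker but a List[torch.device] when a worker had several GPUs.
        device = get_device()    # always a single torch.device
        devices = get_devices()  # always a list of all devices assigned to this worker
        assert isinstance(device, torch.device)
        assert isinstance(devices, list)


    trainer = TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(
            num_workers=1,
            use_gpu=True,
            resources_per_worker={"GPU": 2},
        ),
    )
    trainer.fit()
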
Showing 15 changed files with 182 additions and 76 deletions.
1 change: 1 addition & 0 deletions doc/source/train/api/api.rst
@@ -28,6 +28,7 @@ PyTorch
:toctree: doc/

~train.torch.get_device
~train.torch.get_devices
~train.torch.prepare_model
~train.torch.prepare_data_loader
~train.torch.enable_reproducibility
34 changes: 33 additions & 1 deletion doc/source/train/user-guides/using-gpus.rst
@@ -61,7 +61,6 @@ You can get the associated devices with :meth:`ray.train.torch.get_device`.
device = get_device()
assert device == torch.device("cuda:0")


trainer = TorchTrainer(
train_func,
scaling_config=ScalingConfig(
@@ -71,6 +70,39 @@ You can get the associated devices with :meth:`ray.train.torch.get_device`.
)
trainer.fit()

Assigning multiple GPUs to a worker
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Sometimes you may want to allocate multiple GPUs to each worker. For example,
to assign 2 GPUs to each worker, specify `resources_per_worker={"GPU": 2}` in the
`ScalingConfig`.

You can get a list of associated devices with :meth:`ray.train.torch.get_devices`.

.. testcode::

    import torch
    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer, get_device, get_devices


    def train_func(config):
        assert torch.cuda.is_available()

        device = get_device()
        devices = get_devices()
        assert device == torch.device("cuda:0")
        assert devices == [torch.device("cuda:0"), torch.device("cuda:1")]

    trainer = TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(
            num_workers=1,
            use_gpu=True,
            resources_per_worker={"GPU": 2}
        )
    )
    trainer.fit()


Setting the resources per worker
--------------------------------
13 changes: 6 additions & 7 deletions python/ray/air/_internal/torch_utils.py
@@ -10,11 +10,11 @@
from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed


def get_device() -> Union[torch.device, List[torch.device]]:
"""Gets the correct torch device configured for this process.
def get_devices() -> List[torch.device]:
"""Gets the correct torch device list configured for this process.
Returns a list of devices if more than 1 GPU per worker
is requested.
Returns a list of torch CUDA devices allocated for the current worker.
If no GPUs are assigned, then it returns a list with a single CPU device.
Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
superset of the `ray.get_gpu_ids()`.
@@ -55,11 +55,10 @@ def get_device() -> Union[torch.device, List[torch.device]]:
device_ids.append(0)

devices = [torch.device(f"cuda:{device_id}") for device_id in device_ids]
device = devices[0] if len(devices) == 1 else devices
else:
device = torch.device("cpu")
devices = [torch.device("cpu")]

return device
return devices


def convert_pandas_to_torch_tensor(
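
For orientation, here is a hedged sketch of the device-resolution behavior the new docstring describes: each GPU id from `ray.get_gpu_ids()` is mapped to its index within `CUDA_VISIBLE_DEVICES`, and a single CPU device is returned when no GPUs are assigned. The helper name `get_devices_sketch` and the exact edge-case handling are assumptions, not the verbatim implementation.

    import os
    from typing import List

    import ray
    import torch


    def get_devices_sketch() -> List[torch.device]:
        # Hypothetical approximation of ray.air._internal.torch_utils.get_devices.
        gpu_ids = [str(gpu_id) for gpu_id in ray.get_gpu_ids()]
        if gpu_ids:
            # Per the docstring, CUDA_VISIBLE_DEVICES is assumed to be set and to be
            # a superset of ray.get_gpu_ids().
            visible = os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")
            # torch numbers devices relative to CUDA_VISIBLE_DEVICES, so map each
            # assigned GPU id to its index within the visible list.
            device_ids = [visible.index(gpu_id) for gpu_id in gpu_ids]
            return [torch.device(f"cuda:{device_id}") for device_id in device_ids]
        # No GPUs assigned to this worker: return a single CPU device.
        return [torch.device("cpu")]
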
6 changes: 2 additions & 4 deletions python/ray/air/util/torch_dist.py
@@ -17,7 +17,7 @@
from ray.actor import ActorHandle
from ray.train._internal.utils import get_address_and_port
from ray.train.constants import DEFAULT_NCCL_SOCKET_IFNAME
from ray.air._internal.torch_utils import get_device
from ray.air._internal.torch_utils import get_devices


class TorchDistributedWorker(ABC):
@@ -183,9 +183,7 @@ def _shutdown_torch_distributed():
return

# Clean up cuda memory.
devices = get_device()
if not isinstance(devices, list):
devices = [devices]
devices = get_devices()
for device in devices:
with torch.cuda.device(device):
torch.cuda.empty_cache()
2 changes: 1 addition & 1 deletion python/ray/data/iterator.py
@@ -335,8 +335,8 @@ def iter_torch_batches(

from ray.air._internal.torch_utils import (
convert_ndarray_batch_to_torch_tensor_batch,
get_device,
)
from ray.train.torch import get_device

if collate_fn is not None and (dtypes is not None or device != "auto"):
raise ValueError(
5 changes: 3 additions & 2 deletions python/ray/data/tests/test_iterator.py
@@ -158,9 +158,10 @@ def collate_fn(batch: Dict[str, np.ndarray]):

# Test that we don't automatically set device if collate_fn is specified.
with patch(
"ray.air._internal.torch_utils.get_device", lambda: torch.device("cuda")
"ray.air._internal.torch_utils.get_devices", lambda: [torch.device("cuda")]
):
assert ray.air._internal.torch_utils.get_device().type == "cuda"
devices = ray.air._internal.torch_utils.get_devices()
assert devices[0].type == "cuda"

it.iter_batches = MagicMock()
for batch in it.iter_torch_batches(collate_fn=collate_fn):
15 changes: 3 additions & 12 deletions python/ray/train/lightning/_lightning_utils.py
@@ -51,15 +51,6 @@ def import_lightning(): # noqa: F402
LIGHTNING_REPORT_STAGE_KEY = "_report_on"


def get_worker_root_device():
"""Get the first torch device of the current worker if there are multiple."""
devices = ray.train.torch.get_device()
if isinstance(devices, list):
return devices[0]
else:
return devices


@PublicAPI(stability="beta")
class RayDDPStrategy(pl.strategies.DDPStrategy):
"""Subclass of DDPStrategy to ensure compatibility with Ray orchestration.
@@ -77,7 +68,7 @@ def __init__(self, *args, **kwargs):

@property
def root_device(self) -> torch.device:
return get_worker_root_device()
return ray.train.torch.get_device()

@property
def distributed_sampler_kwargs(self) -> Dict[str, Any]:
@@ -101,7 +92,7 @@ def __init__(self, *args, **kwargs):

@property
def root_device(self) -> torch.device:
return get_worker_root_device()
return ray.train.torch.get_device()

@property
def distributed_sampler_kwargs(self) -> Dict[str, Any]:
@@ -144,7 +135,7 @@ def __init__(self, *args, **kwargs):

@property
def root_device(self) -> torch.device:
return get_worker_root_device()
return ray.train.torch.get_device()

@property
def distributed_sampler_kwargs(self) -> Dict[str, Any]:
20 changes: 6 additions & 14 deletions python/ray/train/tests/test_gpu.py
@@ -71,11 +71,7 @@ def train_fn():
visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
assert visible_devices == "1,2"

devices = (
sorted([device.index for device in train.torch.get_device()])
if num_gpus_per_worker > 1
else train.torch.get_device().index
)
devices = sorted([device.index for device in train.torch.get_devices()])
write_rank_data(tmp_path, devices)

trainer = TorchTrainer(
@@ -92,9 +88,9 @@ def train_fn():
devices = list(rank_data.values())

if num_gpus_per_worker == 0.5:
assert sorted(devices) == [0, 0, 1, 1]
assert sorted(devices) == [[0], [0], [1], [1]]
elif num_gpus_per_worker == 1:
assert sorted(devices) == [0, 1]
assert sorted(devices) == [[0], [1]]
elif num_gpus_per_worker == 2:
assert sorted(devices[0]) == [0, 1]
else:
@@ -108,11 +104,7 @@
def test_torch_get_device_dist(ray_2_node_2_gpu, num_gpus_per_worker, tmp_path):
@patch("torch.cuda.is_available", lambda: True)
def train_fn():
devices = (
sorted([device.index for device in train.torch.get_device()])
if num_gpus_per_worker > 1
else train.torch.get_device().index
)
devices = sorted([device.index for device in train.torch.get_devices()])
write_rank_data(tmp_path, devices)

trainer = TorchTrainer(
@@ -138,12 +130,12 @@ def train_fn():
# 4 workers on node 1, 4 workers on node 2
# `ray.get_gpu_ids()` returns [0], [0], [1], [1] on node 1
# and [0], [0], [1], [1] on node 2
assert sorted(devices) == [0, 0, 0, 0, 1, 1, 1, 1]
assert sorted(devices) == [[0], [0], [0], [0], [1], [1], [1], [1]]
elif num_gpus_per_worker == 1:
# worker gpu topology:
# 2 workers on node 1, 2 workers on node 2
# `ray.get_gpu_ids()` returns [0], [1] on node 1 and [0], [1] on node 2
assert sorted(devices) == [0, 0, 1, 1]
assert sorted(devices) == [[0], [0], [1], [1]]
elif num_gpus_per_worker == 2:
# worker gpu topology:
# 1 worker on node 1, 1 worker on node 2
7 changes: 2 additions & 5 deletions python/ray/train/tests/test_torch_trainer.py
@@ -215,11 +215,8 @@ def train_fn():
# the other is taken by the other sample) so device index should be 0.
# For the multiple GPU case, each worker has 2 visible devices so device
# index should be either 0 or 1. It doesn't matter which.
devices = train.torch.get_device()
if isinstance(devices, list):
assert sorted([device.index for device in devices]) == [0, 1]
else:
assert train.torch.get_device().index == 0
device_ids = sorted([device.index for device in train.torch.get_devices()])
assert device_ids in [[0], [0, 1]]

@ray.remote(num_cpus=0)
class TrialActor:
9 changes: 8 additions & 1 deletion python/ray/train/tests/test_train_usage.py
@@ -15,7 +15,12 @@ def shutdown_only():
def run_torch():
from torch.utils.data import DataLoader, TensorDataset

from ray.train.torch import get_device, prepare_data_loader, prepare_model
from ray.train.torch import (
get_device,
get_devices,
prepare_data_loader,
prepare_model,
)

def train_func():
# Create dummy model and data loader
@@ -27,6 +32,7 @@ def train_func():
prepare_data_loader(dataloader)
prepare_model(model)
get_device()
get_devices()

trainer = TorchTrainer(
train_func, scaling_config=ScalingConfig(num_workers=2, use_gpu=False)
@@ -109,6 +115,7 @@ def test_torch_utility_usage_tags(shutdown_only, framework):
run_torch()
expected_tags = [
TagKey.TRAIN_TORCH_GET_DEVICE,
TagKey.TRAIN_TORCH_GET_DEVICES,
TagKey.TRAIN_TORCH_PREPARE_MODEL,
TagKey.TRAIN_TORCH_PREPARE_DATALOADER,
]
2 changes: 2 additions & 0 deletions python/ray/train/torch/__init__.py
@@ -17,6 +17,7 @@
backward,
enable_reproducibility,
get_device,
get_devices,
prepare_data_loader,
prepare_model,
prepare_optimizer,
Expand All @@ -28,6 +29,7 @@
"TorchConfig",
"accelerate",
"get_device",
"get_devices",
"prepare_model",
"prepare_optimizer",
"prepare_data_loader",
10 changes: 3 additions & 7 deletions python/ray/train/torch/config.py
@@ -113,11 +113,9 @@ def _setup_torch_process_group(


def _shutdown_torch(destroy_process_group=False):
from ray.air._internal.torch_utils import get_device
from ray.air._internal.torch_utils import get_devices

devices = get_device()
if not isinstance(devices, list):
devices = [devices]
devices = get_devices()
if destroy_process_group:
dist.destroy_process_group()
if torch.cuda.is_available():
@@ -129,7 +127,7 @@ def _set_torch_distributed_env_vars():
def _set_torch_distributed_env_vars():
# Same env vars as in
# https://pytorch.org/docs/stable/elastic/run.html#environment-variables
from ray.air._internal.torch_utils import get_device
from ray.train.torch import get_device

context = ray.train.get_context()
os.environ["LOCAL_RANK"] = str(context.get_local_rank())
Expand All @@ -140,8 +138,6 @@ def _set_torch_distributed_env_vars():

# Makes sure Hugging Face Accelerate uses the correct device
device = get_device()
if isinstance(device, list):
device = device[0]
os.environ["ACCELERATE_TORCH_DEVICE"] = str(device)
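
For reference, a short worker-side sketch of how these variables can be consumed once they are set (hedged; only `LOCAL_RANK` and `ACCELERATE_TORCH_DEVICE` appear in this hunk, and the remaining torchrun-style variables follow the PyTorch documentation linked above):

    import os

    import torch

    # Inside a Ray Train worker after _set_torch_distributed_env_vars() has run:
    local_rank = int(os.environ["LOCAL_RANK"])  # this worker's rank on its node
    # Always a single device string now that get_device() no longer returns a list.
    device = torch.device(os.environ["ACCELERATE_TORCH_DEVICE"])  # e.g. "cuda:0"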

