[Train] Split overloaded ray.train.torch.get_device into another get_devices API for multi-GPU worker setup #42314

Merged · 28 commits · Jan 30, 2024
1 change: 1 addition & 0 deletions doc/source/train/api/api.rst
@@ -28,6 +28,7 @@ PyTorch
:toctree: doc/

~train.torch.get_device
~train.torch.get_devices
~train.torch.prepare_model
~train.torch.prepare_data_loader
~train.torch.enable_reproducibility
34 changes: 33 additions & 1 deletion doc/source/train/user-guides/using-gpus.rst
@@ -61,7 +61,6 @@ You can get the associated devices with :meth:`ray.train.torch.get_device`.
device = get_device()
assert device == torch.device("cuda:0")


trainer = TorchTrainer(
train_func,
scaling_config=ScalingConfig(
@@ -71,6 +70,39 @@
)
trainer.fit()

Assigning multiple GPUs to a worker
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Sometimes you might want to allocate multiple GPUs to a worker. For example,
specify ``resources_per_worker={"GPU": 2}`` in the ``ScalingConfig`` to assign
2 GPUs to each worker.

You can get a list of associated devices with :meth:`ray.train.torch.get_devices`.

.. testcode::

import torch
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer, get_device, get_devices


def train_func(config):
assert torch.cuda.is_available()

device = get_device()
devices = get_devices()
assert device == torch.device("cuda:0")
assert devices == [torch.device("cuda:0"), torch.device("cuda:1")]

trainer = TorchTrainer(
train_func,
scaling_config=ScalingConfig(
num_workers=1,
use_gpu=True,
resources_per_worker={"GPU": 2}
)
)
trainer.fit()
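
A worker can then fan work out across its assigned devices. The following is a
minimal sketch under the 2-GPU-per-worker setup above; ``DataParallel`` is used
purely for illustration and is not part of this change.

    import torch
    import torch.nn as nn
    from ray.train.torch import get_devices

    def train_func(config):
        devices = get_devices()                  # e.g. [cuda:0, cuda:1]
        model = nn.Linear(8, 1).to(devices[0])   # parameters live on the first device
        # Replicate the module across every device assigned to this worker.
        parallel_model = nn.DataParallel(model, device_ids=devices)
        out = parallel_model(torch.randn(4, 8, device=devices[0]))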


Setting the resources per worker
--------------------------------
13 changes: 6 additions & 7 deletions python/ray/air/_internal/torch_utils.py
@@ -10,11 +10,11 @@
from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed


def get_device() -> Union[torch.device, List[torch.device]]:
"""Gets the correct torch device configured for this process.
def get_devices() -> List[torch.device]:
"""Gets the correct torch device list configured for this process.

Returns a list of devices if more than 1 GPU per worker
is requested.
Returns a list of torch CUDA devices allocated for the current worker.
If no GPUs are assigned, then it returns a list with a single CPU device.

Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
superset of the `ray.get_gpu_ids()`.
@@ -55,11 +55,10 @@ def get_device() -> Union[torch.device, List[torch.device]]:
device_ids.append(0)

devices = [torch.device(f"cuda:{device_id}") for device_id in device_ids]
device = devices[0] if len(devices) == 1 else devices
else:
device = torch.device("cpu")
devices = [torch.device("cpu")]

return device
return devices


def convert_pandas_to_torch_tensor(
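To make the docstring's assumption concrete: the device index that ``get_devices``
returns is the position of each Ray-assigned GPU id inside ``CUDA_VISIBLE_DEVICES``.
A simplified standalone sketch of that mapping, not the helper's actual body:

    import os
    import torch

    def map_gpu_ids_to_devices(gpu_ids):
        """Sketch: turn Ray-assigned GPU ids into local torch devices.

        Assumes CUDA_VISIBLE_DEVICES is set and contains every id in gpu_ids.
        """
        visible = [int(v) for v in os.environ["CUDA_VISIBLE_DEVICES"].split(",") if v]
        if not gpu_ids:
            return [torch.device("cpu")]
        return [torch.device(f"cuda:{visible.index(gpu_id)}") for gpu_id in gpu_ids]

    # With CUDA_VISIBLE_DEVICES="2,3" and gpu_ids=[2, 3] (from ray.get_gpu_ids()),
    # the worker sees its GPUs as [cuda:0, cuda:1].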
6 changes: 2 additions & 4 deletions python/ray/air/util/torch_dist.py
@@ -17,7 +17,7 @@
from ray.actor import ActorHandle
from ray.train._internal.utils import get_address_and_port
from ray.train.constants import DEFAULT_NCCL_SOCKET_IFNAME
from ray.air._internal.torch_utils import get_device
from ray.air._internal.torch_utils import get_devices


class TorchDistributedWorker(ABC):
@@ -183,9 +183,7 @@ def _shutdown_torch_distributed():
return

# Clean up cuda memory.
devices = get_device()
if not isinstance(devices, list):
devices = [devices]
devices = get_devices()
for device in devices:
with torch.cuda.device(device):
torch.cuda.empty_cache()
2 changes: 1 addition & 1 deletion python/ray/data/iterator.py
@@ -335,8 +335,8 @@ def iter_torch_batches(

from ray.air._internal.torch_utils import (
convert_ndarray_batch_to_torch_tensor_batch,
get_device,
)
from ray.train.torch import get_device

if collate_fn is not None and (dtypes is not None or device != "auto"):
raise ValueError(
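The relocated import only changes where the default device lookup lives; calling
code is unchanged. A minimal usage sketch (the column access pattern is an
assumption about the dataset, not part of this diff):

    import ray
    import torch

    ds = ray.data.range(32)

    # Batches arrive as dicts of torch tensors, moved to the device that
    # ray.train.torch.get_device() reports when running inside a Train worker.
    for batch in ds.iter_torch_batches(batch_size=8):
        tensor = next(iter(batch.values()))
        assert isinstance(tensor, torch.Tensor)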
5 changes: 3 additions & 2 deletions python/ray/data/tests/test_iterator.py
@@ -158,9 +158,10 @@ def collate_fn(batch: Dict[str, np.ndarray]):

# Test that we don't automatically set device if collate_fn is specified.
with patch(
"ray.air._internal.torch_utils.get_device", lambda: torch.device("cuda")
"ray.air._internal.torch_utils.get_devices", lambda: [torch.device("cuda")]
):
assert ray.air._internal.torch_utils.get_device().type == "cuda"
devices = ray.air._internal.torch_utils.get_devices()
assert devices[0].type == "cuda"

it.iter_batches = MagicMock()
for batch in it.iter_torch_batches(collate_fn=collate_fn):
15 changes: 3 additions & 12 deletions python/ray/train/lightning/_lightning_utils.py
@@ -51,15 +51,6 @@ def import_lightning(): # noqa: F402
LIGHTNING_REPORT_STAGE_KEY = "_report_on"


def get_worker_root_device():
"""Get the first torch device of the current worker if there are multiple."""
devices = ray.train.torch.get_device()
if isinstance(devices, list):
return devices[0]
else:
return devices


@PublicAPI(stability="beta")
class RayDDPStrategy(pl.strategies.DDPStrategy):
"""Subclass of DDPStrategy to ensure compatibility with Ray orchestration.
@@ -77,7 +68,7 @@ def __init__(self, *args, **kwargs):

@property
def root_device(self) -> torch.device:
return get_worker_root_device()
return ray.train.torch.get_device()

@property
def distributed_sampler_kwargs(self) -> Dict[str, Any]:
@@ -101,7 +92,7 @@ def __init__(self, *args, **kwargs):

@property
def root_device(self) -> torch.device:
return get_worker_root_device()
return ray.train.torch.get_device()

@property
def distributed_sampler_kwargs(self) -> Dict[str, Any]:
@@ -144,7 +135,7 @@ def __init__(self, *args, **kwargs):

@property
def root_device(self) -> torch.device:
return get_worker_root_device()
return ray.train.torch.get_device()

@property
def distributed_sampler_kwargs(self) -> Dict[str, Any]:
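With ``root_device`` now delegating to ``ray.train.torch.get_device()``, a Lightning
training function keeps its usual shape. A rough sketch; ``MyLightningModule`` and
``train_loader`` are hypothetical placeholders:

    import pytorch_lightning as pl
    from ray.train.lightning import (
        RayDDPStrategy,
        RayLightningEnvironment,
        prepare_trainer,
    )

    def train_func(config):
        model = MyLightningModule()            # hypothetical LightningModule
        trainer = pl.Trainer(
            max_epochs=1,
            accelerator="auto",
            devices="auto",
            strategy=RayDDPStrategy(),         # root_device resolved via get_device()
            plugins=[RayLightningEnvironment()],
        )
        trainer = prepare_trainer(trainer)
        trainer.fit(model, train_dataloaders=train_loader)  # hypothetical dataloader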
20 changes: 6 additions & 14 deletions python/ray/train/tests/test_gpu.py
@@ -71,11 +71,7 @@ def train_fn():
visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
assert visible_devices == "1,2"

devices = (
sorted([device.index for device in train.torch.get_device()])
if num_gpus_per_worker > 1
else train.torch.get_device().index
)
devices = sorted([device.index for device in train.torch.get_devices()])
write_rank_data(tmp_path, devices)

trainer = TorchTrainer(
@@ -92,9 +88,9 @@ def train_fn():
devices = list(rank_data.values())

if num_gpus_per_worker == 0.5:
assert sorted(devices) == [0, 0, 1, 1]
assert sorted(devices) == [[0], [0], [1], [1]]
elif num_gpus_per_worker == 1:
assert sorted(devices) == [0, 1]
assert sorted(devices) == [[0], [1]]
elif num_gpus_per_worker == 2:
assert sorted(devices[0]) == [0, 1]
else:
@@ -108,11 +104,7 @@
def test_torch_get_device_dist(ray_2_node_2_gpu, num_gpus_per_worker, tmp_path):
@patch("torch.cuda.is_available", lambda: True)
def train_fn():
devices = (
sorted([device.index for device in train.torch.get_device()])
if num_gpus_per_worker > 1
else train.torch.get_device().index
)
devices = sorted([device.index for device in train.torch.get_devices()])
write_rank_data(tmp_path, devices)

trainer = TorchTrainer(
@@ -138,12 +130,12 @@ def train_fn():
# 4 workers on node 1, 4 workers on node 2
# `ray.get_gpu_ids()` returns [0], [0], [1], [1] on node 1
# and [0], [0], [1], [1] on node 2
assert sorted(devices) == [0, 0, 0, 0, 1, 1, 1, 1]
assert sorted(devices) == [[0], [0], [0], [0], [1], [1], [1], [1]]
elif num_gpus_per_worker == 1:
# worker gpu topology:
# 2 workers on node 1, 2 workers on node 2
# `ray.get_gpu_ids()` returns [0], [1] on node 1 and [0], [1] on node 2
assert sorted(devices) == [0, 0, 1, 1]
assert sorted(devices) == [[0], [0], [1], [1]]
elif num_gpus_per_worker == 2:
# worker gpu topology:
# 1 workers on node 1, 1 workers on node 2
7 changes: 2 additions & 5 deletions python/ray/train/tests/test_torch_trainer.py
@@ -215,11 +215,8 @@ def train_fn():
# the other is taken by the other sample) so device index should be 0.
# For the multiple GPU case, each worker has 2 visible devices so device
# index should be either 0 or 1. It doesn't matter which.
devices = train.torch.get_device()
if isinstance(devices, list):
assert sorted([device.index for device in devices]) == [0, 1]
else:
assert train.torch.get_device().index == 0
device_ids = sorted([device.index for device in train.torch.get_devices()])
assert device_ids in [[0], [0, 1]]

@ray.remote(num_cpus=0)
class TrialActor:
9 changes: 8 additions & 1 deletion python/ray/train/tests/test_train_usage.py
@@ -15,7 +15,12 @@ def shutdown_only():
def run_torch():
from torch.utils.data import DataLoader, TensorDataset

from ray.train.torch import get_device, prepare_data_loader, prepare_model
from ray.train.torch import (
get_device,
get_devices,
prepare_data_loader,
prepare_model,
)

def train_func():
# Create dummy model and data loader
@@ -27,6 +32,7 @@ def train_func():
prepare_data_loader(dataloader)
prepare_model(model)
get_device()
get_devices()

trainer = TorchTrainer(
train_func, scaling_config=ScalingConfig(num_workers=2, use_gpu=False)
@@ -109,6 +115,7 @@ def test_torch_utility_usage_tags(shutdown_only, framework):
run_torch()
expected_tags = [
TagKey.TRAIN_TORCH_GET_DEVICE,
TagKey.TRAIN_TORCH_GET_DEVICES,
TagKey.TRAIN_TORCH_PREPARE_MODEL,
TagKey.TRAIN_TORCH_PREPARE_DATALOADER,
]
2 changes: 2 additions & 0 deletions python/ray/train/torch/__init__.py
@@ -17,6 +17,7 @@
backward,
enable_reproducibility,
get_device,
get_devices,
prepare_data_loader,
prepare_model,
prepare_optimizer,
@@ -28,6 +29,7 @@
"TorchConfig",
"accelerate",
"get_device",
"get_devices",
"prepare_model",
"prepare_optimizer",
"prepare_data_loader",
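With the new export in place, both helpers come from the same public module; a
quick sketch of the intended call pattern (the final assertion relies on
``get_device`` returning the first entry of ``get_devices``, as documented above):

    from ray.train.torch import get_device, get_devices

    def train_func(config):
        device = get_device()     # single device for this worker
        devices = get_devices()   # full list; longer than 1 only with multiple GPUs per worker
        assert device == devices[0]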
10 changes: 3 additions & 7 deletions python/ray/train/torch/config.py
@@ -113,11 +113,9 @@ def _setup_torch_process_group(


def _shutdown_torch(destroy_process_group=False):
from ray.air._internal.torch_utils import get_device
from ray.air._internal.torch_utils import get_devices

devices = get_device()
if not isinstance(devices, list):
devices = [devices]
devices = get_devices()
if destroy_process_group:
dist.destroy_process_group()
if torch.cuda.is_available():
@@ -129,7 +127,7 @@ def _set_torch_distributed_env_vars():
def _set_torch_distributed_env_vars():
# Same env vars as in
# https://pytorch.org/docs/stable/elastic/run.html#environment-variables
from ray.air._internal.torch_utils import get_device
from ray.train.torch import get_device

context = ray.train.get_context()
os.environ["LOCAL_RANK"] = str(context.get_local_rank())
@@ -140,8 +138,6 @@ def _set_torch_distributed_env_vars():

# Makes sure Hugging Face Accelerate uses the correct device
device = get_device()
if isinstance(device, list):
device = device[0]
os.environ["ACCELERATE_TORCH_DEVICE"] = str(device)


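For reference, a worker process can read the variables set here the same way it
would under ``torchrun``; a minimal sketch (only the variables visible in this
diff are shown, values are illustrative):

    import os

    # Set by _set_torch_distributed_env_vars on each Ray Train worker.
    local_rank = int(os.environ["LOCAL_RANK"])

    # Hugging Face Accelerate reads this to choose its default device,
    # e.g. "cuda:0" on a GPU worker or "cpu" when no GPU is assigned.
    accelerate_device = os.environ["ACCELERATE_TORCH_DEVICE"]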