Skip to content

Commit

Permalink
[data][train] Fix locality config not being respected in DataConfig (ray-project#42204)
Browse files Browse the repository at this point in the history

Fix the bug that the locality config in DataConfig is not respected. This bug caused locality to always be enabled for training ingest workloads.

---------

Signed-off-by: Hao Chen <chenh1024@gmail.com>
  • Loading branch information
raulchen committed Jan 26, 2024
1 parent cc2646f commit 9a3d525
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
27 changes: 27 additions & 0 deletions python/ray/air/tests/test_new_dataset_config.py
Expand Up @@ -2,6 +2,7 @@

import random
import pytest
from unittest.mock import MagicMock

import ray
from ray import train
Expand Down Expand Up @@ -168,6 +169,32 @@ def test_configure_execution_options_carryover_context(ray_start_4_cpus):
assert ingest_options.verbose_progress is True


@pytest.mark.parametrize("enable_locality", [True, False])
def test_configure_locality(enable_locality):
    """Verify DataConfig.configure respects ``locality_with_output``.

    When locality is enabled, the worker node ids must be forwarded to
    ``streaming_split`` as locality hints; when disabled, no hints are passed.
    """
    options = DataConfig.default_ingest_options()
    options.locality_with_output = enable_locality
    data_config = DataConfig(execution_options=options)

    mock_ds = MagicMock()
    mock_ds.streaming_split = MagicMock()
    # configure() copies the dataset before splitting; return the same mock so
    # the streaming_split call is recorded on the object we assert against.
    mock_ds.copy = MagicMock(return_value=mock_ds)
    world_size = 2
    worker_handles = [MagicMock() for _ in range(world_size)]
    worker_node_ids = ["node" + str(i) for i in range(world_size)]
    data_config.configure(
        datasets={"train": mock_ds},
        world_size=world_size,
        worker_handles=worker_handles,
        worker_node_ids=worker_node_ids,
    )
    # assert_called_once_with combines the call-count and argument checks that
    # were previously done with two separate assertions.
    mock_ds.streaming_split.assert_called_once_with(
        world_size,
        equal=True,
        locality_hints=worker_node_ids if enable_locality else None,
    )


class CustomConfig(DataConfig):
def __init__(self):
pass
Expand Down
5 changes: 4 additions & 1 deletion python/ray/train/_internal/data_config.py
Expand Up @@ -91,6 +91,9 @@ def configure(
else:
datasets_to_split = set(self._datasets_to_split)

locality_hints = (
worker_node_ids if self._execution_options.locality_with_output else None
)
for name, ds in datasets.items():
ds = ds.copy(ds)
ds.context.execution_options = copy.deepcopy(self._execution_options)
Expand All @@ -107,7 +110,7 @@ def configure(
if name in datasets_to_split:
for i, split in enumerate(
ds.streaming_split(
world_size, equal=True, locality_hints=worker_node_ids
world_size, equal=True, locality_hints=locality_hints
)
):
output[i][name] = split
Expand Down

0 comments on commit 9a3d525

Please sign in to comment.