From 1f5a10a805bfe90fa2b180cc37af153b3375c779 Mon Sep 17 00:00:00 2001
From: Hao Chen
Date: Fri, 8 Dec 2023 09:45:01 -0800
Subject: [PATCH] [data] Improve stall detection for StreamingOutputsBackpressurePolicy (#41637) (#41720)

When there is non-Data code running in the same cluster, the Data
StreamingExecutor considers all submitted tasks as active, even though they
may not actually have the resources to run.

https://github.com/ray-project/ray/pull/41603 is an attempt to fix the
data+train workload by excluding training resources. This PR is a more
general fix that also covers other workloads, with two main changes:

1. Besides detecting active tasks, we also detect whether the downstream
   operator has not made any progress for a specific interval.
2. Introduce a new `exclude_resources` option to allow specifying resources
   used by non-Data workloads.

This PR alone can also fix https://github.com/ray-project/ray/issues/41496

---------

Signed-off-by: Hao Chen
Signed-off-by: Stephanie Wang
Co-authored-by: Stephanie Wang
---
 .../ray/air/tests/test_new_dataset_config.py  |  30 +++--
 .../streaming_output_backpressure_policy.py   |  89 +++++++++++--
 .../execution/interfaces/execution_options.py |  23 ++++
 .../execution/interfaces/executor.py          |   1 +
 .../_internal/execution/streaming_executor.py |  28 +++--
 .../data/tests/test_backpressure_policies.py  | 119 +++++++++++++++++-
 .../ray/data/tests/test_streaming_executor.py |  61 ++++++++-
 python/ray/train/_internal/data_config.py     |  22 ++--
 8 files changed, 320 insertions(+), 53 deletions(-)

diff --git a/python/ray/air/tests/test_new_dataset_config.py b/python/ray/air/tests/test_new_dataset_config.py
index a912a9a75125c..b4c7ea0917901 100644
--- a/python/ray/air/tests/test_new_dataset_config.py
+++ b/python/ray/air/tests/test_new_dataset_config.py
@@ -8,6 +8,7 @@
 from ray.train import DataConfig, ScalingConfig
 from ray.data import DataIterator
 from ray.train.data_parallel_trainer import DataParallelTrainer
+from ray.data._internal.execution.interfaces.execution_options import ExecutionOptions
 from ray.tests.conftest import *  # noqa


@@ -263,20 +264,18 @@ def test_materialized_preprocessing(ray_start_4_cpus):


 def test_data_config_default_resource_limits(shutdown_only):
-    """Test that DataConfig's default resource limits should exclude the resources
-    used by training."""
+    """Test that DataConfig should exclude training resources from Data."""
     cluster_cpus, cluster_gpus = 20, 10
     num_workers = 2
     # Resources used by training workers.
     cpus_per_worker, gpus_per_worker = 2, 1
     # Resources used by the trainer actor.
default_trainer_cpus, default_trainer_gpus = 1, 0 - expected_cpu_limit = ( - cluster_cpus - num_workers * cpus_per_worker - default_trainer_cpus - ) - expected_gpu_limit = ( - cluster_gpus - num_workers * gpus_per_worker - default_trainer_gpus - ) + num_train_cpus = num_workers * cpus_per_worker + default_trainer_cpus + num_train_gpus = num_workers * gpus_per_worker + default_trainer_gpus + + init_exclude_cpus = 2 + init_exclude_gpus = 1 ray.init(num_cpus=cluster_cpus, num_gpus=cluster_gpus) @@ -284,13 +283,18 @@ class MyTrainer(DataParallelTrainer): def __init__(self, **kwargs): def train_loop_fn(): train_ds = train.get_dataset_shard("train") - resource_limits = ( - train_ds._base_dataset.context.execution_options.resource_limits + exclude_resources = ( + train_ds._base_dataset.context.execution_options.exclude_resources ) - assert resource_limits.cpu == expected_cpu_limit - assert resource_limits.gpu == expected_gpu_limit + assert exclude_resources.cpu == num_train_cpus + init_exclude_cpus + assert exclude_resources.gpu == num_train_gpus + init_exclude_gpus kwargs.pop("scaling_config", None) + + execution_options = ExecutionOptions() + execution_options.exclude_resources.cpu = init_exclude_cpus + execution_options.exclude_resources.gpu = init_exclude_gpus + super().__init__( train_loop_per_worker=train_loop_fn, scaling_config=ScalingConfig( @@ -302,7 +306,7 @@ def train_loop_fn(): }, ), datasets={"train": ray.data.range(10)}, - dataset_config=DataConfig(), + dataset_config=DataConfig(execution_options=execution_options), **kwargs, ) diff --git a/python/ray/data/_internal/execution/backpressure_policy/streaming_output_backpressure_policy.py b/python/ray/data/_internal/execution/backpressure_policy/streaming_output_backpressure_policy.py index 91990fd9a5028..ba1ad2782c6ac 100644 --- a/python/ray/data/_internal/execution/backpressure_policy/streaming_output_backpressure_policy.py +++ b/python/ray/data/_internal/execution/backpressure_policy/streaming_output_backpressure_policy.py @@ -1,12 +1,19 @@ -from typing import TYPE_CHECKING, Dict +import time +from collections import defaultdict +from typing import TYPE_CHECKING, Dict, Tuple import ray from .backpressure_policy import BackpressurePolicy +from ray.data._internal.dataset_logger import DatasetLogger if TYPE_CHECKING: + from ray.data._internal.execution.interfaces import PhysicalOperator from ray.data._internal.execution.streaming_executor_state import OpState, Topology +logger = DatasetLogger(__name__) + + class StreamingOutputBackpressurePolicy(BackpressurePolicy): """A backpressure policy that throttles the streaming outputs of the `DataOpTask`s. @@ -39,6 +46,10 @@ class StreamingOutputBackpressurePolicy(BackpressurePolicy): "backpressure_policies.streaming_output.max_blocks_in_op_output_queue" ) + # If an operator has active tasks but no outputs for at least this time, + # we'll consider it as idle and temporarily unblock backpressure for its upstream. + MAX_OUTPUT_IDLE_SECONDS = 10 + def __init__(self, topology: "Topology"): data_context = ray.data.DataContext.get_current() self._max_num_blocks_in_streaming_gen_buffer = data_context.get_config( @@ -59,28 +70,80 @@ def __init__(self, topology: "Topology"): ) assert self._max_num_blocks_in_op_output_queue > 0 + # Latest number of outputs and the last time when the number changed + # for each op. 
+ self._last_num_outputs_and_time: Dict[ + "PhysicalOperator", Tuple[int, float] + ] = defaultdict(lambda: (0, time.time())) + self._warning_printed = False + def calculate_max_blocks_to_read_per_op( self, topology: "Topology" ) -> Dict["OpState", int]: max_blocks_to_read_per_op: Dict["OpState", int] = {} - downstream_num_active_tasks = 0 + + # Indicates if the immediate downstream operator is idle. + downstream_idle = False + for op, state in reversed(topology.items()): max_blocks_to_read_per_op[state] = ( self._max_num_blocks_in_op_output_queue - state.outqueue_num_blocks() ) - if downstream_num_active_tasks == 0: - # If all downstream operators are idle, it could be because no resources - # are available. In this case, we'll make sure to read at least one - # block to avoid deadlock. - # TODO(hchen): `downstream_num_active_tasks == 0` doesn't necessarily - # mean no enough resources. One false positive case is when the upstream - # op hasn't produced any blocks for the downstream op to consume. - # In this case, at least reading one block is fine. - # If there are other false positive cases, we may want to make this - # deadlock check more accurate by directly checking resources. + + if downstream_idle: max_blocks_to_read_per_op[state] = max( max_blocks_to_read_per_op[state], 1, ) - downstream_num_active_tasks += len(op.get_active_tasks()) + + # An operator is considered idle if either of the following is true: + # - It has no active tasks. + # - This can happen when all resources are used by upstream operators. + # - It has active tasks, but no outputs for at least + # `MAX_OUTPUT_IDLE_SECONDS`. + # - This can happen when non-Data code preempted cluster resources, and + # - some of the active tasks don't actually have enough resources to run. + # + # If the operator is idle, we'll temporarily unblock backpressure by + # allowing reading at least one block from its upstream + # to avoid deadlock. + # NOTE, these 2 conditions don't necessarily mean deadlock. + # The first case can also happen when the upstream operator hasn't outputted + # any blocks yet. While the second case can also happen when the task is + # expected to output data slowly. + # The false postive cases are fine as we only allow reading one block + # each time. + downstream_idle = False + if op.num_active_tasks() == 0: + downstream_idle = True + else: + cur_num_outputs = state.op.metrics.num_outputs_generated + cur_time = time.time() + last_num_outputs, last_time = self._last_num_outputs_and_time[state.op] + if cur_num_outputs > last_num_outputs: + self._last_num_outputs_and_time[state.op] = ( + cur_num_outputs, + cur_time, + ) + else: + if cur_time - last_time > self.MAX_OUTPUT_IDLE_SECONDS: + downstream_idle = True + self._print_warning(state.op, cur_time - last_time) return max_blocks_to_read_per_op + + def _print_warning(self, op: "PhysicalOperator", idle_time: float): + if self._warning_printed: + return + self._warning_printed = True + msg = ( + f"Operator {op} is running but has no outputs for {idle_time} seconds." + " Execution may be slower than expected.\n" + "Ignore this warning if your UDF is expected to be slow." + " Otherwise, this can happen when there are fewer cluster resources" + " available to Ray Data than expected." + " If you have non-Data tasks or actors running in the cluster, exclude" + " their resources from Ray Data with" + " `DataContext.get_current().execution_options.exclude_resources`." + " This message will only print once." 
+ ) + logger.get_logger().warning(msg) diff --git a/python/ray/data/_internal/execution/interfaces/execution_options.py b/python/ray/data/_internal/execution/interfaces/execution_options.py index 4b09843bc01da..561a7ad0e732f 100644 --- a/python/ray/data/_internal/execution/interfaces/execution_options.py +++ b/python/ray/data/_internal/execution/interfaces/execution_options.py @@ -89,6 +89,13 @@ class ExecutionOptions: Attributes: resource_limits: Set a soft limit on the resource usage during execution. This is not supported in bulk execution mode. Autodetected by default. + exclude_resources: Amount of resources to exclude from Ray Data. + Set this if you have other workloads running on the same cluster. + Note, + - If using Ray Data with Ray Train, training resources will be + automatically excluded. + - For each resource type, resource_limits and exclude_resources can + not be both set. locality_with_output: Set this to prefer running tasks on the same node as the output node (node driving the execution). It can also be set to a list of node ids to spread the outputs across those nodes. Off by default. @@ -105,6 +112,10 @@ class ExecutionOptions: resource_limits: ExecutionResources = field(default_factory=ExecutionResources) + exclude_resources: ExecutionResources = field( + default_factory=lambda: ExecutionResources(cpu=0, gpu=0, object_store_memory=0) + ) + locality_with_output: Union[bool, List[NodeIdStr]] = False preserve_order: bool = False @@ -112,3 +123,15 @@ class ExecutionOptions: actor_locality_enabled: bool = True verbose_progress: bool = bool(int(os.environ.get("RAY_DATA_VERBOSE_PROGRESS", "0"))) + + def validate(self) -> None: + """Validate the options.""" + for attr in ["cpu", "gpu", "object_store_memory"]: + if ( + getattr(self.resource_limits, attr) is not None + and getattr(self.exclude_resources, attr, 0) > 0 + ): + raise ValueError( + "resource_limits and exclude_resources cannot " + f" both be set for {attr} resource." + ) diff --git a/python/ray/data/_internal/execution/interfaces/executor.py b/python/ray/data/_internal/execution/interfaces/executor.py index 3489ba26de618..007346b60f294 100644 --- a/python/ray/data/_internal/execution/interfaces/executor.py +++ b/python/ray/data/_internal/execution/interfaces/executor.py @@ -46,6 +46,7 @@ class Executor: def __init__(self, options: ExecutionOptions): """Create the executor.""" + options.validate() self._options = options def execute( diff --git a/python/ray/data/_internal/execution/streaming_executor.py b/python/ray/data/_internal/execution/streaming_executor.py index 40312a36f850e..f4b4e51f1acc7 100644 --- a/python/ray/data/_internal/execution/streaming_executor.py +++ b/python/ray/data/_internal/execution/streaming_executor.py @@ -73,6 +73,7 @@ def __init__(self, options: ExecutionOptions, dataset_tag: str = "unknown_datase # The executor can be shutdown while still running. self._shutdown_lock = threading.RLock() + self._execution_started = False self._shutdown = False # Internal execution state shared across thread boundaries. 
We run the control @@ -133,6 +134,7 @@ def execute( self._get_operator_tags(), ) self.start() + self._execution_started = True class StreamIterator(OutputIterator): def __init__(self, outer: Executor): @@ -165,7 +167,7 @@ def shutdown(self, execution_completed: bool = True): global _num_shutdown with self._shutdown_lock: - if self._shutdown: + if not self._execution_started or self._shutdown: return logger.get_logger().debug(f"Shutting down {self}.") _num_shutdown += 1 @@ -332,16 +334,26 @@ def _get_or_refresh_resource_limits(self) -> ExecutionResources: autoscaling. """ base = self._options.resource_limits + exclude = self._options.exclude_resources cluster = ray.cluster_resources() - return ExecutionResources( - cpu=base.cpu if base.cpu is not None else cluster.get("CPU", 0.0), - gpu=base.gpu if base.gpu is not None else cluster.get("GPU", 0.0), - object_store_memory=base.object_store_memory - if base.object_store_memory is not None - else round( + + cpu = base.cpu + if cpu is None: + cpu = cluster.get("CPU", 0.0) - (exclude.cpu or 0.0) + gpu = base.gpu + if gpu is None: + gpu = cluster.get("GPU", 0.0) - (exclude.gpu or 0.0) + object_store_memory = base.object_store_memory + if object_store_memory is None: + object_store_memory = round( DEFAULT_OBJECT_STORE_MEMORY_LIMIT_FRACTION * cluster.get("object_store_memory", 0.0) - ), + ) - (exclude.object_store_memory or 0) + + return ExecutionResources( + cpu=cpu, + gpu=gpu, + object_store_memory=object_store_memory, ) def _report_current_usage( diff --git a/python/ray/data/tests/test_backpressure_policies.py b/python/ray/data/tests/test_backpressure_policies.py index 669165ceb1dd8..033df5a3e6c45 100644 --- a/python/ray/data/tests/test_backpressure_policies.py +++ b/python/ray/data/tests/test_backpressure_policies.py @@ -3,7 +3,7 @@ import unittest from collections import defaultdict from contextlib import contextmanager -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import numpy as np @@ -217,10 +217,16 @@ def setUpClass(cls): cls._num_blocks = 5 cls._block_size = 100 * 1024 * 1024 policy_cls = StreamingOutputBackpressurePolicy + cls._max_blocks_in_op_output_queue = 1 + cls._max_blocks_in_generator_buffer = 1 cls._configs = { ENABLED_BACKPRESSURE_POLICIES_CONFIG_KEY: [policy_cls], - policy_cls.MAX_BLOCKS_IN_OP_OUTPUT_QUEUE_CONFIG_KEY: 1, - policy_cls.MAX_BLOCKS_IN_GENERATOR_BUFFER_CONFIG_KEY: 1, + policy_cls.MAX_BLOCKS_IN_OP_OUTPUT_QUEUE_CONFIG_KEY: ( + cls._max_blocks_in_op_output_queue + ), + policy_cls.MAX_BLOCKS_IN_GENERATOR_BUFFER_CONFIG_KEY: ( + cls._max_blocks_in_generator_buffer + ), } for k, v in cls._configs.items(): data_context.set_config(k, v) @@ -234,6 +240,87 @@ def tearDownClass(cls): data_context.execution_options.preserve_order = False ray.shutdown() + def _create_mock_op_and_op_state( + self, + name, + outqueue_num_blocks=0, + num_active_tasks=0, + num_outputs_generated=0, + ): + op = MagicMock() + op.__str__.return_value = f"Op({name})" + op.num_active_tasks.return_value = num_active_tasks + op.metrics.num_outputs_generated = num_outputs_generated + + state = MagicMock() + state.__str__.return_value = f"OpState({name})" + state.outqueue_num_blocks.return_value = outqueue_num_blocks + + state.op = op + return op, state + + def test_policy_basic(self): + """Basic unit test for the policy without real execution.""" + up_op, up_state = self._create_mock_op_and_op_state("up") + down_op, down_state = self._create_mock_op_and_op_state("down") + topology = {} + topology[up_op] = up_state + 
topology[down_op] = down_state + + policy = StreamingOutputBackpressurePolicy(topology) + assert ( + policy._max_num_blocks_in_op_output_queue + == self._max_blocks_in_op_output_queue + ) + assert ( + policy._max_num_blocks_in_streaming_gen_buffer + == self._max_blocks_in_generator_buffer + ) + + # Buffers are empty, both ops can read up to the max. + res = policy.calculate_max_blocks_to_read_per_op(topology) + assert res == { + up_state: self._max_blocks_in_op_output_queue, + down_state: self._max_blocks_in_op_output_queue, + } + + # up_op's buffer is full, but down_up has no active tasks. + # We'll still allow up_op to read 1 block. + up_state.outqueue_num_blocks.return_value = self._max_blocks_in_op_output_queue + res = policy.calculate_max_blocks_to_read_per_op(topology) + assert res == { + up_state: 1, + down_state: self._max_blocks_in_op_output_queue, + } + + # down_op now has 1 active task. So we won't allow up_op to read any more. + down_op.num_active_tasks.return_value = 1 + res = policy.calculate_max_blocks_to_read_per_op(topology) + assert res == { + up_state: 0, + down_state: self._max_blocks_in_op_output_queue, + } + + # After `MAX_OUTPUT_IDLE_SECONDS` of no outputs from down_up, + # we'll allow up_op to read 1 block again. + with patch.object( + StreamingOutputBackpressurePolicy, "MAX_OUTPUT_IDLE_SECONDS", 0.1 + ): + time.sleep(0.11) + res = policy.calculate_max_blocks_to_read_per_op(topology) + assert res == { + up_state: 1, + down_state: self._max_blocks_in_op_output_queue, + } + + # down_up now has outputs, so we won't allow up_op to read any more. + down_op.metrics.num_outputs_generated = 1 + res = policy.calculate_max_blocks_to_read_per_op(topology) + assert res == { + up_state: 0, + down_state: self._max_blocks_in_op_output_queue, + } + def _run_dataset(self, producer_num_cpus, consumer_num_cpus): # Create a dataset with 2 operators: # - The producer op has only 1 task, which produces 5 blocks, each of which @@ -272,7 +359,7 @@ def consumer(batch): [row["consumer_timestamp"] for row in res], ) - def test_basic_backpressure(self): + def test_e2e_backpressure(self): producer_timestamps, consumer_timestamps = self._run_dataset( producer_num_cpus=1, consumer_num_cpus=2 ) @@ -297,6 +384,30 @@ def test_no_deadlock(self): consumer_timestamps, ) + def test_no_deadlock_for_resource_contention(self): + """Test no deadlock in case of resource contention from + non-Data code.""" + # Create a non-Data actor that uses 4 CPUs, only 1 CPU + # is left for Data. Currently Data StreamExecutor still + # incorrectly assumes it has all the 5 CPUs. + # Check that we don't deadlock in this case. 
+ + @ray.remote(num_cpus=4) + class DummyActor: + def foo(self): + return None + + dummy_actor = DummyActor.remote() + ray.get(dummy_actor.foo.remote()) + + producer_timestamps, consumer_timestamps = self._run_dataset( + producer_num_cpus=1, consumer_num_cpus=0.9 + ) + assert producer_timestamps[-1] < consumer_timestamps[0], ( + producer_timestamps, + consumer_timestamps, + ) + def test_large_e2e_backpressure(shutdown_only, restore_data_context): # noqa: F811 """Test backpressure on a synthetic large-scale workload.""" diff --git a/python/ray/data/tests/test_streaming_executor.py b/python/ray/data/tests/test_streaming_executor.py index 38f7b467daf2e..17a3a4d8a7d52 100644 --- a/python/ray/data/tests/test_streaming_executor.py +++ b/python/ray/data/tests/test_streaming_executor.py @@ -1,6 +1,7 @@ import collections +import math import time -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest @@ -18,10 +19,12 @@ create_map_transformer_from_block_fn, ) from ray.data._internal.execution.streaming_executor import ( + StreamingExecutor, _debug_dump_topology, _validate_dag, ) from ray.data._internal.execution.streaming_executor_state import ( + DEFAULT_OBJECT_STORE_MEMORY_LIMIT_FRACTION, AutoscalingState, DownstreamMemoryInfo, OpState, @@ -698,6 +701,62 @@ def test_execution_allowed_nothrottle(): ) +def test_resource_limits(): + cluster_resources = {"CPU": 10, "GPU": 5, "object_store_memory": 1000} + default_object_store_memory_limit = math.ceil( + cluster_resources["object_store_memory"] + * DEFAULT_OBJECT_STORE_MEMORY_LIMIT_FRACTION + ) + + with patch("ray.cluster_resources", return_value=cluster_resources): + # Test default resource limits. + # When no resource limits are set, the resource limits should default to + # the cluster resources for CPU/GPU, and + # DEFAULT_OBJECT_STORE_MEMORY_LIMIT_FRACTION of cluster object store memory. + options = ExecutionOptions() + executor = StreamingExecutor(options, "") + expected = ExecutionResources( + cpu=cluster_resources["CPU"], + gpu=cluster_resources["GPU"], + object_store_memory=default_object_store_memory_limit, + ) + assert executor._get_or_refresh_resource_limits() == expected + + # Test setting resource_limits + options = ExecutionOptions() + options.resource_limits = ExecutionResources( + cpu=1, gpu=2, object_store_memory=100 + ) + executor = StreamingExecutor(options, "") + expected = ExecutionResources( + cpu=1, + gpu=2, + object_store_memory=100, + ) + assert executor._get_or_refresh_resource_limits() == expected + + # Test setting exclude_resources + # The actual limit should be the default limit minus the excluded resources. + options = ExecutionOptions() + options.exclude_resources = ExecutionResources( + cpu=1, gpu=2, object_store_memory=100 + ) + executor = StreamingExecutor(options, "") + expected = ExecutionResources( + cpu=cluster_resources["CPU"] - 1, + gpu=cluster_resources["GPU"] - 2, + object_store_memory=default_object_store_memory_limit - 100, + ) + assert executor._get_or_refresh_resource_limits() == expected + + # Test that we don't support setting both resource_limits and exclude_resources. 
+ with pytest.raises(ValueError): + options = ExecutionOptions() + options.resource_limits = ExecutionResources(cpu=2) + options.exclude_resources = ExecutionResources(cpu=1) + options.validate() + + @pytest.mark.parametrize( "max_errored_blocks, num_errored_blocks", [ diff --git a/python/ray/train/_internal/data_config.py b/python/ray/train/_internal/data_config.py index e533e9c8ee03e..fb85dce793ea2 100644 --- a/python/ray/train/_internal/data_config.py +++ b/python/ray/train/_internal/data_config.py @@ -4,6 +4,7 @@ import ray from ray.actor import ActorHandle from ray.data import DataIterator, Dataset, ExecutionOptions, NodeIdStr +from ray.data._internal.execution.interfaces.execution_options import ExecutionResources from ray.data.preprocessor import Preprocessor # TODO(justinvyu): Fix the circular import error @@ -94,21 +95,14 @@ def configure( ds = ds.copy(ds) ds.context.execution_options = copy.deepcopy(self._execution_options) - # If CPU or GPU resource limits are not set, - # exclude the resources used by training from the resource limits. - # TODO(hchen): We calculate the resource limits based on the current - # cluster resources here, which means that auto-scaling is not supported. - # This should be fixed when we want to support auto-scaling for Ray Train. - resource_limits = ds.context.execution_options.resource_limits - cluster_resources = ray.cluster_resources() - if resource_limits.cpu is None: - resource_limits.cpu = ( - cluster_resources.get("CPU", 0) - self._num_train_cpus - ) - if resource_limits.gpu is None: - resource_limits.gpu = ( - cluster_resources.get("GPU", 0) - self._num_train_gpus + # Add training-reserved resources to Data's exclude_resources. + ds.context.execution_options.exclude_resources = ( + ds.context.execution_options.exclude_resources.add( + ExecutionResources( + cpu=self._num_train_cpus, gpu=self._num_train_gpus + ) ) + ) if name in datasets_to_split: for i, split in enumerate(
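
A minimal usage sketch of the new `exclude_resources` option introduced by this patch, assuming the `DataContext`, `ExecutionOptions`, and `ExecutionResources` APIs shown in the diff above. The cluster size, the `NonDataWorker` actor, and the resource amounts below are illustrative only and are not part of the patch.

import ray
from ray.data._internal.execution.interfaces.execution_options import (
    ExecutionResources,
)

# A 5-CPU cluster where a non-Data actor permanently holds 4 CPUs.
ray.init(num_cpus=5)


@ray.remote(num_cpus=4)
class NonDataWorker:
    def ping(self):
        return "ok"


worker = NonDataWorker.remote()
ray.get(worker.ping.remote())

# Exclude the 4 CPUs held by the non-Data actor so that Ray Data's
# StreamingExecutor does not count them toward its own resource limits.
# Note: resource_limits and exclude_resources cannot both be set for the
# same resource type; ExecutionOptions.validate() raises ValueError if so.
ctx = ray.data.DataContext.get_current()
ctx.execution_options.exclude_resources = ExecutionResources(cpu=4)

ds = ray.data.range(100).map_batches(lambda batch: batch)
print(ds.take(5))

With this setting, `_get_or_refresh_resource_limits` subtracts the excluded amounts from the detected cluster resources; even without it, the new `MAX_OUTPUT_IDLE_SECONDS` check in `StreamingOutputBackpressurePolicy` temporarily unblocks upstream reads when a downstream operator stalls, which is the scenario exercised by `test_no_deadlock_for_resource_contention`.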