Skip to content

Commit

Permalink
[data] Improve stall detection for StreamingOutputsBackpressurePolicy (
Browse files Browse the repository at this point in the history
…#41637) (#41720)

When there is non-Data code running in the same clusters. Data StreamExecutor will consider all submitted tasks as active, while they may not actually have resources to run.
#41603 is an attempt to fix the data+train workload by excluding training resources.

While this PR is a more general fix for other workloads, with two main changes:
1. Besides detecting active tasks, we also detect if the downstream is not making any progress for a specific interval.
2. Introduce a new `reserved_resources` option to allow specifying non-Data resources.

This PR along can also fix #41496
---------

Signed-off-by: Hao Chen <chenh1024@gmail.com>
Signed-off-by: Stephanie Wang <swang@cs.berkeley.edu>
Co-authored-by: Stephanie Wang <swang@cs.berkeley.edu>
  • Loading branch information
raulchen and stephanie-wang committed Dec 8, 2023
1 parent 2659bea commit 1f5a10a
Show file tree
Hide file tree
Showing 8 changed files with 320 additions and 53 deletions.
30 changes: 17 additions & 13 deletions python/ray/air/tests/test_new_dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ray.train import DataConfig, ScalingConfig
from ray.data import DataIterator
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.data._internal.execution.interfaces.execution_options import ExecutionOptions
from ray.tests.conftest import * # noqa


Expand Down Expand Up @@ -263,34 +264,37 @@ def test_materialized_preprocessing(ray_start_4_cpus):


def test_data_config_default_resource_limits(shutdown_only):
"""Test that DataConfig's default resource limits should exclude the resources
used by training."""
"""Test that DataConfig should exclude training resources from Data."""
cluster_cpus, cluster_gpus = 20, 10
num_workers = 2
# Resources used by training workers.
cpus_per_worker, gpus_per_worker = 2, 1
# Resources used by the trainer actor.
default_trainer_cpus, default_trainer_gpus = 1, 0
expected_cpu_limit = (
cluster_cpus - num_workers * cpus_per_worker - default_trainer_cpus
)
expected_gpu_limit = (
cluster_gpus - num_workers * gpus_per_worker - default_trainer_gpus
)
num_train_cpus = num_workers * cpus_per_worker + default_trainer_cpus
num_train_gpus = num_workers * gpus_per_worker + default_trainer_gpus

init_exclude_cpus = 2
init_exclude_gpus = 1

ray.init(num_cpus=cluster_cpus, num_gpus=cluster_gpus)

class MyTrainer(DataParallelTrainer):
def __init__(self, **kwargs):
def train_loop_fn():
train_ds = train.get_dataset_shard("train")
resource_limits = (
train_ds._base_dataset.context.execution_options.resource_limits
exclude_resources = (
train_ds._base_dataset.context.execution_options.exclude_resources
)
assert resource_limits.cpu == expected_cpu_limit
assert resource_limits.gpu == expected_gpu_limit
assert exclude_resources.cpu == num_train_cpus + init_exclude_cpus
assert exclude_resources.gpu == num_train_gpus + init_exclude_gpus

kwargs.pop("scaling_config", None)

execution_options = ExecutionOptions()
execution_options.exclude_resources.cpu = init_exclude_cpus
execution_options.exclude_resources.gpu = init_exclude_gpus

super().__init__(
train_loop_per_worker=train_loop_fn,
scaling_config=ScalingConfig(
Expand All @@ -302,7 +306,7 @@ def train_loop_fn():
},
),
datasets={"train": ray.data.range(10)},
dataset_config=DataConfig(),
dataset_config=DataConfig(execution_options=execution_options),
**kwargs,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from typing import TYPE_CHECKING, Dict
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Tuple

import ray
from .backpressure_policy import BackpressurePolicy
from ray.data._internal.dataset_logger import DatasetLogger

if TYPE_CHECKING:
from ray.data._internal.execution.interfaces import PhysicalOperator
from ray.data._internal.execution.streaming_executor_state import OpState, Topology


logger = DatasetLogger(__name__)


class StreamingOutputBackpressurePolicy(BackpressurePolicy):
"""A backpressure policy that throttles the streaming outputs of the `DataOpTask`s.
Expand Down Expand Up @@ -39,6 +46,10 @@ class StreamingOutputBackpressurePolicy(BackpressurePolicy):
"backpressure_policies.streaming_output.max_blocks_in_op_output_queue"
)

# If an operator has active tasks but no outputs for at least this time,
# we'll consider it as idle and temporarily unblock backpressure for its upstream.
MAX_OUTPUT_IDLE_SECONDS = 10

def __init__(self, topology: "Topology"):
data_context = ray.data.DataContext.get_current()
self._max_num_blocks_in_streaming_gen_buffer = data_context.get_config(
Expand All @@ -59,28 +70,80 @@ def __init__(self, topology: "Topology"):
)
assert self._max_num_blocks_in_op_output_queue > 0

# Latest number of outputs and the last time when the number changed
# for each op.
self._last_num_outputs_and_time: Dict[
"PhysicalOperator", Tuple[int, float]
] = defaultdict(lambda: (0, time.time()))
self._warning_printed = False

def calculate_max_blocks_to_read_per_op(
self, topology: "Topology"
) -> Dict["OpState", int]:
max_blocks_to_read_per_op: Dict["OpState", int] = {}
downstream_num_active_tasks = 0

# Indicates if the immediate downstream operator is idle.
downstream_idle = False

for op, state in reversed(topology.items()):
max_blocks_to_read_per_op[state] = (
self._max_num_blocks_in_op_output_queue - state.outqueue_num_blocks()
)
if downstream_num_active_tasks == 0:
# If all downstream operators are idle, it could be because no resources
# are available. In this case, we'll make sure to read at least one
# block to avoid deadlock.
# TODO(hchen): `downstream_num_active_tasks == 0` doesn't necessarily
# mean no enough resources. One false positive case is when the upstream
# op hasn't produced any blocks for the downstream op to consume.
# In this case, at least reading one block is fine.
# If there are other false positive cases, we may want to make this
# deadlock check more accurate by directly checking resources.

if downstream_idle:
max_blocks_to_read_per_op[state] = max(
max_blocks_to_read_per_op[state],
1,
)
downstream_num_active_tasks += len(op.get_active_tasks())

# An operator is considered idle if either of the following is true:
# - It has no active tasks.
# - This can happen when all resources are used by upstream operators.
# - It has active tasks, but no outputs for at least
# `MAX_OUTPUT_IDLE_SECONDS`.
# - This can happen when non-Data code preempted cluster resources, and
# - some of the active tasks don't actually have enough resources to run.
#
# If the operator is idle, we'll temporarily unblock backpressure by
# allowing reading at least one block from its upstream
# to avoid deadlock.
# NOTE, these 2 conditions don't necessarily mean deadlock.
# The first case can also happen when the upstream operator hasn't outputted
# any blocks yet. While the second case can also happen when the task is
# expected to output data slowly.
# The false postive cases are fine as we only allow reading one block
# each time.
downstream_idle = False
if op.num_active_tasks() == 0:
downstream_idle = True
else:
cur_num_outputs = state.op.metrics.num_outputs_generated
cur_time = time.time()
last_num_outputs, last_time = self._last_num_outputs_and_time[state.op]
if cur_num_outputs > last_num_outputs:
self._last_num_outputs_and_time[state.op] = (
cur_num_outputs,
cur_time,
)
else:
if cur_time - last_time > self.MAX_OUTPUT_IDLE_SECONDS:
downstream_idle = True
self._print_warning(state.op, cur_time - last_time)
return max_blocks_to_read_per_op

def _print_warning(self, op: "PhysicalOperator", idle_time: float):
if self._warning_printed:
return
self._warning_printed = True
msg = (
f"Operator {op} is running but has no outputs for {idle_time} seconds."
" Execution may be slower than expected.\n"
"Ignore this warning if your UDF is expected to be slow."
" Otherwise, this can happen when there are fewer cluster resources"
" available to Ray Data than expected."
" If you have non-Data tasks or actors running in the cluster, exclude"
" their resources from Ray Data with"
" `DataContext.get_current().execution_options.exclude_resources`."
" This message will only print once."
)
logger.get_logger().warning(msg)
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ class ExecutionOptions:
Attributes:
resource_limits: Set a soft limit on the resource usage during execution.
This is not supported in bulk execution mode. Autodetected by default.
exclude_resources: Amount of resources to exclude from Ray Data.
Set this if you have other workloads running on the same cluster.
Note,
- If using Ray Data with Ray Train, training resources will be
automatically excluded.
- For each resource type, resource_limits and exclude_resources can
not be both set.
locality_with_output: Set this to prefer running tasks on the same node as the
output node (node driving the execution). It can also be set to a list of
node ids to spread the outputs across those nodes. Off by default.
Expand All @@ -105,10 +112,26 @@ class ExecutionOptions:

resource_limits: ExecutionResources = field(default_factory=ExecutionResources)

exclude_resources: ExecutionResources = field(
default_factory=lambda: ExecutionResources(cpu=0, gpu=0, object_store_memory=0)
)

locality_with_output: Union[bool, List[NodeIdStr]] = False

preserve_order: bool = False

actor_locality_enabled: bool = True

verbose_progress: bool = bool(int(os.environ.get("RAY_DATA_VERBOSE_PROGRESS", "0")))

def validate(self) -> None:
"""Validate the options."""
for attr in ["cpu", "gpu", "object_store_memory"]:
if (
getattr(self.resource_limits, attr) is not None
and getattr(self.exclude_resources, attr, 0) > 0
):
raise ValueError(
"resource_limits and exclude_resources cannot "
f" both be set for {attr} resource."
)
1 change: 1 addition & 0 deletions python/ray/data/_internal/execution/interfaces/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class Executor:

def __init__(self, options: ExecutionOptions):
"""Create the executor."""
options.validate()
self._options = options

def execute(
Expand Down
28 changes: 20 additions & 8 deletions python/ray/data/_internal/execution/streaming_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(self, options: ExecutionOptions, dataset_tag: str = "unknown_datase

# The executor can be shutdown while still running.
self._shutdown_lock = threading.RLock()
self._execution_started = False
self._shutdown = False

# Internal execution state shared across thread boundaries. We run the control
Expand Down Expand Up @@ -133,6 +134,7 @@ def execute(
self._get_operator_tags(),
)
self.start()
self._execution_started = True

class StreamIterator(OutputIterator):
def __init__(self, outer: Executor):
Expand Down Expand Up @@ -165,7 +167,7 @@ def shutdown(self, execution_completed: bool = True):
global _num_shutdown

with self._shutdown_lock:
if self._shutdown:
if not self._execution_started or self._shutdown:
return
logger.get_logger().debug(f"Shutting down {self}.")
_num_shutdown += 1
Expand Down Expand Up @@ -332,16 +334,26 @@ def _get_or_refresh_resource_limits(self) -> ExecutionResources:
autoscaling.
"""
base = self._options.resource_limits
exclude = self._options.exclude_resources
cluster = ray.cluster_resources()
return ExecutionResources(
cpu=base.cpu if base.cpu is not None else cluster.get("CPU", 0.0),
gpu=base.gpu if base.gpu is not None else cluster.get("GPU", 0.0),
object_store_memory=base.object_store_memory
if base.object_store_memory is not None
else round(

cpu = base.cpu
if cpu is None:
cpu = cluster.get("CPU", 0.0) - (exclude.cpu or 0.0)
gpu = base.gpu
if gpu is None:
gpu = cluster.get("GPU", 0.0) - (exclude.gpu or 0.0)
object_store_memory = base.object_store_memory
if object_store_memory is None:
object_store_memory = round(
DEFAULT_OBJECT_STORE_MEMORY_LIMIT_FRACTION
* cluster.get("object_store_memory", 0.0)
),
) - (exclude.object_store_memory or 0)

return ExecutionResources(
cpu=cpu,
gpu=gpu,
object_store_memory=object_store_memory,
)

def _report_current_usage(
Expand Down
Loading

0 comments on commit 1f5a10a

Please sign in to comment.