[Data] Estimate object store memory from in-flight tasks #42504

Merged · 13 commits · Jan 25, 2024
8 changes: 8 additions & 0 deletions python/ray/data/BUILD
@@ -442,6 +442,14 @@ py_test(
deps = ["//:ray_lib", ":conftest"],
)

py_test(
name = "test_runtime_metrics_scheduling",
size = "small",
srcs = ["tests/test_runtime_metrics_scheduling.py"],
tags = ["team:data", "exclusive"],
deps = ["//:ray_lib", ":conftest"],
)

py_test(
name = "test_size_estimation",
size = "medium",
@@ -91,10 +91,12 @@ class OpRuntimeMetrics:
obj_store_mem_freed: int = field(
default=0, metadata={"map_only": True, "export_metric": True}
)
# Current memory size in the object store.
obj_store_mem_cur: int = field(
default=0, metadata={"map_only": True, "export_metric": True}
)

# Current memory size in the object store from inputs.
obj_store_mem_inputs: int = field(default=0, metadata={"map_only": True})
# Current memory size in the object store from outputs.
obj_store_mem_outputs: int = field(default=0, metadata={"map_only": True})

# Peak memory size in the object store.
obj_store_mem_peak: int = field(default=0, metadata={"map_only": True})
# Spilled memory size in the object store.
@@ -201,13 +203,27 @@ def output_buffer_bytes(self) -> int:
"""Size in bytes of output blocks that are not taken by the downstream yet."""
return self.bytes_outputs_generated - self.bytes_outputs_taken

@property
def obj_store_mem_cur(self) -> int:
return self.obj_store_mem_inputs + self.obj_store_mem_outputs

@property
def obj_store_mem_cur_upper_bound(self) -> int:
if self.average_bytes_outputs_per_task is not None:
return self.obj_store_mem_inputs + max(
self.obj_store_mem_outputs,
self.num_tasks_running * self.average_bytes_outputs_per_task,
)
else:
return self.obj_store_mem_inputs + self.obj_store_mem_outputs

def on_input_received(self, input: RefBundle):
"""Callback when the operator receives a new input."""
self.num_inputs_received += 1
input_size = input.size_bytes()
self.bytes_inputs_received += input_size
# Update object store metrics.
self.obj_store_mem_cur += input_size
self.obj_store_mem_inputs += input_size
if self.obj_store_mem_cur > self.obj_store_mem_peak:
self.obj_store_mem_peak = self.obj_store_mem_cur

@@ -216,7 +232,7 @@ def on_output_taken(self, output: RefBundle):
output_bytes = output.size_bytes()
self.num_outputs_taken += 1
self.bytes_outputs_taken += output_bytes
self.obj_store_mem_cur -= output_bytes
self.obj_store_mem_outputs -= output_bytes

def on_task_submitted(self, task_index: int, inputs: RefBundle):
"""Callback when the operator submits a task."""
@@ -241,7 +257,7 @@ def on_output_generated(self, task_index: int, output: RefBundle):

# Update object store metrics.
self.obj_store_mem_alloc += output_bytes
self.obj_store_mem_cur += output_bytes
self.obj_store_mem_outputs += output_bytes
if self.obj_store_mem_cur > self.obj_store_mem_peak:
self.obj_store_mem_peak = self.obj_store_mem_cur

@@ -279,7 +295,7 @@ def on_task_finished(self, task_index: int, exception: Optional[Exception]):
self.obj_store_mem_spilled += meta.size_bytes

self.obj_store_mem_freed += total_input_size
self.obj_store_mem_cur -= total_input_size
self.obj_store_mem_inputs -= total_input_size

inputs.destroy_if_owned()
del self._running_tasks[task_index]
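To make the new accounting concrete, here is a small illustrative calculation of what obj_store_mem_cur_upper_bound reports for a single operator. The numbers are hypothetical; the formula mirrors the property above.

    # Hypothetical snapshot of one operator's metrics, for illustration only.
    obj_store_mem_inputs = 200 * 1024**2           # 200 MiB of task inputs still referenced
    obj_store_mem_outputs = 50 * 1024**2           # 50 MiB of outputs not yet taken downstream
    num_tasks_running = 4                          # in-flight tasks
    average_bytes_outputs_per_task = 80 * 1024**2  # observed average from finished tasks

    # Charge the operator for the larger of the outputs it has already produced
    # and the outputs its running tasks are expected to produce.
    upper_bound = obj_store_mem_inputs + max(
        obj_store_mem_outputs,
        num_tasks_running * average_bytes_outputs_per_task,
    )
    assert upper_bound == 520 * 1024**2  # versus obj_store_mem_cur == 250 MiB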
@@ -308,7 +308,7 @@ def current_resource_usage(self) -> ExecutionResources:
return ExecutionResources(
cpu=self._ray_remote_args.get("num_cpus", 0) * num_active_workers,
gpu=self._ray_remote_args.get("num_gpus", 0) * num_active_workers,
object_store_memory=self.metrics.obj_store_mem_cur,
object_store_memory=self.metrics.obj_store_mem_cur_upper_bound,
)

def incremental_resource_usage(self) -> ExecutionResources:
@@ -95,7 +95,7 @@ def current_resource_usage(self) -> ExecutionResources:
return ExecutionResources(
cpu=self._ray_remote_args.get("num_cpus", 0) * num_active_workers,
gpu=self._ray_remote_args.get("num_gpus", 0) * num_active_workers,
object_store_memory=self.metrics.obj_store_mem_cur,
object_store_memory=self.metrics.obj_store_mem_cur_upper_bound,
)

def incremental_resource_usage(self) -> ExecutionResources:
3 changes: 0 additions & 3 deletions python/ray/data/_internal/execution/streaming_executor.py
@@ -270,7 +270,6 @@ def _scheduling_loop_step(self, topology: Topology) -> bool:
self._report_current_usage(cur_usage, limits)
op = select_operator_to_run(
topology,
cur_usage,
limits,
self._backpressure_policies,
ensure_at_least_one_running=self._consumer_idling(),
@@ -285,10 +284,8 @@ def _scheduling_loop_step(self, topology: Topology) -> bool:
if DEBUG_TRACE_SCHEDULING:
_debug_dump_topology(topology)
topology[op].dispatch_next_task()
cur_usage = TopologyResourceUsage.of(topology)
op = select_operator_to_run(
topology,
cur_usage,
limits,
self._backpressure_policies,
ensure_at_least_one_running=self._consumer_idling(),
@@ -524,7 +524,6 @@ def update_operator_states(topology: Topology) -> None:

def select_operator_to_run(
topology: Topology,
cur_usage: TopologyResourceUsage,
limits: ExecutionResources,
backpressure_policies: List[BackpressurePolicy],
ensure_at_least_one_running: bool,
@@ -544,11 +543,10 @@ def select_operator_to_run(
provides backpressure if the consumer is slow. However, once a bundle is returned
to the user, it is no longer tracked.
"""
assert isinstance(cur_usage, TopologyResourceUsage), cur_usage

# Filter to ops that are eligible for execution.
ops = []
for op, state in topology.items():
cur_usage = TopologyResourceUsage.of(topology)
under_resource_limits = _execution_allowed(op, cur_usage, limits)
if (
op.need_more_inputs()
@@ -714,11 +712,6 @@ def _execution_allowed(
cpu=global_limits.cpu, gpu=global_limits.gpu
)
global_ok_sans_memory = new_usage.satisfies_limit(global_limits_sans_memory)
downstream_usage = global_usage.downstream_memory_usage[op]
downstream_limit = global_limits.scale(downstream_usage.topology_fraction)
downstream_memory_ok = ExecutionResources(
object_store_memory=downstream_usage.object_store_memory
).satisfies_limit(downstream_limit)

# If completing a task decreases the overall object store memory usage, allow it
# even if we're over the global limit.
@@ -730,4 +723,4 @@
):
return True

return global_ok_sans_memory and downstream_memory_ok
return False
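With the per-operator downstream-memory check removed, _execution_allowed no longer falls back to global_ok_sans_memory and downstream_memory_ok: if none of the early-return branches above fire, including the one that admits a task expected to shrink overall object store usage, the task is simply not scheduled.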
42 changes: 42 additions & 0 deletions python/ray/data/tests/test_runtime_metrics_scheduling.py
@@ -0,0 +1,42 @@
import time

import numpy as np
import pytest

import ray
from ray._private.internal_api import memory_summary
from ray.data._internal.execution.backpressure_policy import (
ENABLED_BACKPRESSURE_POLICIES_CONFIG_KEY,
ConcurrencyCapBackpressurePolicy,
)


def test_spam(shutdown_only, restore_data_context):
    # Use a small object store so that unthrottled execution would spill to disk.
    ctx = ray.init(object_store_memory=100 * 1024**2)

    # Opt into runtime-metrics-based scheduling and start with a concurrency cap of 1.
    ray.data.DataContext.get_current().use_runtime_metrics_scheduling = True
    ray.data.DataContext.get_current().set_config(
        ENABLED_BACKPRESSURE_POLICIES_CONFIG_KEY, [ConcurrencyCapBackpressurePolicy]
    )
    ray.data.DataContext.get_current().set_config(
        ConcurrencyCapBackpressurePolicy.INIT_CAP_CONFIG_KEY, 1
    )

    def f(batch):
        time.sleep(0.1)
        # Each output batch is ~20 MiB, a fifth of the object store capacity.
        return {"data": np.zeros(20 * 1024**2, dtype=np.uint8)}

    ds = ray.data.range(10).repartition(10).materialize()
    ds = ds.map_batches(f, batch_size=None)

    for _ in ds.iter_batches(batch_size=None, batch_format="pyarrow"):
        pass

    # The pipeline should finish without spilling any objects.
    meminfo = memory_summary(ctx.address_info["address"], stats_only=True)
    assert "Spilled" not in meminfo, meminfo


if __name__ == "__main__":
import sys

sys.exit(pytest.main(["-v", __file__]))
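For reference, a minimal sketch of the opt-in configuration this test exercises, taken directly from the test above; the initial cap of 1 and the choice of ConcurrencyCapBackpressurePolicy are specific to this test.

    import ray
    from ray.data._internal.execution.backpressure_policy import (
        ENABLED_BACKPRESSURE_POLICIES_CONFIG_KEY,
        ConcurrencyCapBackpressurePolicy,
    )

    ctx = ray.data.DataContext.get_current()
    # Schedule tasks based on the runtime-metrics memory estimate.
    ctx.use_runtime_metrics_scheduling = True
    # Enable the concurrency-cap backpressure policy, starting at one task.
    ctx.set_config(
        ENABLED_BACKPRESSURE_POLICIES_CONFIG_KEY, [ConcurrencyCapBackpressurePolicy]
    )
    ctx.set_config(ConcurrencyCapBackpressurePolicy.INIT_CAP_CONFIG_KEY, 1)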