[Data] Estimate object store memory from in-flight tasks #42504

Merged (13 commits) · Jan 25, 2024
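
In rough terms, the change splits an operator's current object store footprint into an input part and an output part, and has scheduling use an upper bound that also charges for outputs its running tasks are expected to produce. A minimal sketch of that estimate, mirroring the metrics added in this diff but with made-up numbers:

    # Hypothetical figures, for illustration only.
    obj_store_mem_inputs = 64 * 1024**2             # task inputs still held (64 MiB)
    obj_store_mem_outputs = 16 * 1024**2            # outputs not yet taken (16 MiB)
    num_tasks_running = 4
    average_bytes_outputs_per_task = 24 * 1024**2   # observed so far (24 MiB)

    # Current usage only counts what is already in the object store.
    obj_store_mem_cur = obj_store_mem_inputs + obj_store_mem_outputs

    # The upper bound assumes every running task will add an average-sized output.
    obj_store_mem_cur_upper_bound = obj_store_mem_inputs + max(
        obj_store_mem_outputs,
        num_tasks_running * average_bytes_outputs_per_task,
    )

    assert obj_store_mem_cur == 80 * 1024**2
    assert obj_store_mem_cur_upper_bound == 160 * 1024**2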
8 changes: 8 additions & 0 deletions python/ray/data/BUILD
@@ -442,6 +442,14 @@ py_test(
    deps = ["//:ray_lib", ":conftest"],
)

py_test(
    name = "test_runtime_metrics_scheduling",
    size = "small",
    srcs = ["tests/test_runtime_metrics_scheduling.py"],
    tags = ["team:data", "exclusive"],
    deps = ["//:ray_lib", ":conftest"],
)

py_test(
    name = "test_size_estimation",
    size = "medium",
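Assuming the usual Ray Bazel workflow, the new target can presumably be run on its own with `bazel test //python/ray/data:test_runtime_metrics_scheduling`; the test file added at the end of this diff also runs under plain pytest via its `__main__` block.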
@@ -91,10 +91,12 @@ class OpRuntimeMetrics:
    obj_store_mem_freed: int = field(
        default=0, metadata={"map_only": True, "export_metric": True}
    )
-   # Current memory size in the object store.
-   obj_store_mem_cur: int = field(
-       default=0, metadata={"map_only": True, "export_metric": True}
-   )
+   # Current memory size in the object store from inputs.
+   obj_store_mem_inputs: int = field(default=0, metadata={"map_only": True})
+   # Current memory size in the object store from outputs.
+   obj_store_mem_outputs: int = field(default=0, metadata={"map_only": True})

    # Peak memory size in the object store.
    obj_store_mem_peak: int = field(default=0, metadata={"map_only": True})
    # Spilled memory size in the object store.
@@ -201,13 +203,27 @@ def output_buffer_bytes(self) -> int:
"""Size in bytes of output blocks that are not taken by the downstream yet."""
return self.bytes_outputs_generated - self.bytes_outputs_taken

+   @property
+   def obj_store_mem_cur(self) -> int:
+       return self.obj_store_mem_inputs + self.obj_store_mem_outputs
+
+   @property
+   def obj_store_mem_cur_upper_bound(self) -> int:
+       if self.average_bytes_outputs_per_task is not None:
+           return self.obj_store_mem_inputs + max(
+               self.obj_store_mem_outputs,
+               self.num_tasks_running * self.average_bytes_outputs_per_task,
+           )
+       else:
+           return self.obj_store_mem_inputs + self.obj_store_mem_outputs

    def on_input_received(self, input: RefBundle):
        """Callback when the operator receives a new input."""
        self.num_inputs_received += 1
        input_size = input.size_bytes()
        self.bytes_inputs_received += input_size
        # Update object store metrics.
-       self.obj_store_mem_cur += input_size
+       self.obj_store_mem_inputs += input_size
        if self.obj_store_mem_cur > self.obj_store_mem_peak:
            self.obj_store_mem_peak = self.obj_store_mem_cur

@@ -216,7 +232,7 @@ def on_output_taken(self, output: RefBundle):
        output_bytes = output.size_bytes()
        self.num_outputs_taken += 1
        self.bytes_outputs_taken += output_bytes
-       self.obj_store_mem_cur -= output_bytes
+       self.obj_store_mem_outputs -= output_bytes

    def on_task_submitted(self, task_index: int, inputs: RefBundle):
        """Callback when the operator submits a task."""
@@ -241,7 +257,7 @@ def on_output_generated(self, task_index: int, output: RefBundle):

        # Update object store metrics.
        self.obj_store_mem_alloc += output_bytes
-       self.obj_store_mem_cur += output_bytes
+       self.obj_store_mem_outputs += output_bytes
        if self.obj_store_mem_cur > self.obj_store_mem_peak:
            self.obj_store_mem_peak = self.obj_store_mem_cur

@@ -279,7 +295,7 @@ def on_task_finished(self, task_index: int, exception: Optional[Exception]):
        self.obj_store_mem_spilled += meta.size_bytes

        self.obj_store_mem_freed += total_input_size
-       self.obj_store_mem_cur -= total_input_size
+       self.obj_store_mem_inputs -= total_input_size

        inputs.destroy_if_owned()
        del self._running_tasks[task_index]
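
Taken together, the callbacks above keep the two counters tied to the task lifecycle: inputs are charged on receipt and released when the task finishes; outputs are charged when generated and released when taken downstream. A self-contained toy version of that bookkeeping (not the real OpRuntimeMetrics class, just an illustration of the accounting):

    class ToyMetrics:
        """Simplified stand-in for the split input/output accounting."""

        def __init__(self):
            self.mem_inputs = 0   # bytes of task inputs still held in the object store
            self.mem_outputs = 0  # bytes of generated outputs not yet taken downstream

        @property
        def mem_cur(self):
            return self.mem_inputs + self.mem_outputs

    m = ToyMetrics()
    m.mem_inputs += 100    # on_input_received: a 100-byte input arrives
    m.mem_outputs += 80    # on_output_generated: the task produces an 80-byte block
    m.mem_inputs -= 100    # on_task_finished: the task's inputs are freed
    m.mem_outputs -= 80    # on_output_taken: downstream consumes the output
    assert m.mem_cur == 0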
@@ -308,7 +308,7 @@ def current_resource_usage(self) -> ExecutionResources:
        return ExecutionResources(
            cpu=self._ray_remote_args.get("num_cpus", 0) * num_active_workers,
            gpu=self._ray_remote_args.get("num_gpus", 0) * num_active_workers,
-           object_store_memory=self.metrics.obj_store_mem_cur,
+           object_store_memory=self.metrics.obj_store_mem_cur_upper_bound,
        )

    def incremental_resource_usage(self) -> ExecutionResources:
@@ -95,7 +95,7 @@ def current_resource_usage(self) -> ExecutionResources:
        return ExecutionResources(
            cpu=self._ray_remote_args.get("num_cpus", 0) * num_active_workers,
            gpu=self._ray_remote_args.get("num_gpus", 0) * num_active_workers,
-           object_store_memory=self.metrics.obj_store_mem_cur,
+           object_store_memory=self.metrics.obj_store_mem_cur_upper_bound,
        )

    def incremental_resource_usage(self) -> ExecutionResources:
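
Both map operator variants shown above now report the upper-bound figure as their object store usage, so the executor budgets for outputs that in-flight tasks have yet to materialize. A hypothetical illustration of why that matters (`budget` is a made-up share of the object store; the other numbers carry over from the sketch near the top of this page):

    MiB = 1024**2
    budget = 128 * MiB               # operator's hypothetical memory budget
    mem_cur = 80 * MiB               # inputs + outputs actually held right now
    mem_cur_upper_bound = 160 * MiB  # plus expected outputs of running tasks

    # Judged by current usage alone the operator still has headroom; the upper
    # bound signals that admitting more work would likely overshoot and spill.
    assert mem_cur < budget < mem_cur_upper_bound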
@@ -715,9 +715,15 @@ def _execution_allowed(
    )
    global_ok_sans_memory = new_usage.satisfies_limit(global_limits_sans_memory)
    downstream_usage = global_usage.downstream_memory_usage[op]
+   downstream_memory = downstream_usage.object_store_memory
+   if (
+       DataContext.get_current().use_runtime_metrics_scheduling
+       and inc.object_store_memory
+   ):
+       downstream_memory += inc.object_store_memory
    downstream_limit = global_limits.scale(downstream_usage.topology_fraction)
    downstream_memory_ok = ExecutionResources(
-       object_store_memory=downstream_usage.object_store_memory
+       object_store_memory=downstream_memory
    ).satisfies_limit(downstream_limit)

    # If completing a task decreases the overall object store memory usage, allow it
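
The check above reads as: when runtime-metrics scheduling is enabled, the prospective task's incremental object store memory (inc.object_store_memory) is charged against the downstream budget before comparing with the limit. A stripped-down, hypothetical version of just that comparison (plain ints in place of ExecutionResources; not Ray code):

    def downstream_memory_ok(
        downstream_memory: int,
        inc_object_store_memory: int,
        downstream_limit: int,
        use_runtime_metrics_scheduling: bool,
    ) -> bool:
        # Charge the new task's expected object store footprint up front.
        if use_runtime_metrics_scheduling and inc_object_store_memory:
            downstream_memory += inc_object_store_memory
        return downstream_memory <= downstream_limit

    MiB = 1024**2
    # With a 100 MiB downstream limit, 90 MiB already in use, and a task expected
    # to add 30 MiB, the check now rejects the task instead of admitting it.
    assert not downstream_memory_ok(90 * MiB, 30 * MiB, 100 * MiB, True)
    assert downstream_memory_ok(90 * MiB, 30 * MiB, 100 * MiB, False)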
44 changes: 44 additions & 0 deletions python/ray/data/tests/test_runtime_metrics_scheduling.py
@@ -0,0 +1,44 @@
import time

import numpy as np
import pytest

import ray
from ray._private.internal_api import memory_summary
from ray.data._internal.execution.backpressure_policy import (
    ENABLED_BACKPRESSURE_POLICIES_CONFIG_KEY,
    ConcurrencyCapBackpressurePolicy,
)


def test_spam(shutdown_only, restore_data_context):
    ctx = ray.init(object_store_memory=100 * 1024**2)

    ray.data.DataContext.get_current().use_runtime_metrics_scheduling = True
    ray.data.DataContext.get_current().set_config(
        ENABLED_BACKPRESSURE_POLICIES_CONFIG_KEY, [ConcurrencyCapBackpressurePolicy]
    )
    ray.data.DataContext.get_current().set_config(
        ConcurrencyCapBackpressurePolicy.INIT_CAP_CONFIG_KEY, 1
    )
    ray.data.DataContext.get_current().set_config(
        ConcurrencyCapBackpressurePolicy.CAP_MULTIPLIER_CONFIG_KEY, 4
    )

    def f(batch):
        time.sleep(0.1)
        return {"data": np.zeros(24 * 1024**2, dtype=np.uint8)}

    ds = ray.data.range(5, parallelism=5).map_batches(f, batch_size=None)

    for _ in ds.iter_batches(batch_size=None, batch_format="pyarrow"):
        pass

    meminfo = memory_summary(ctx.address_info["address"], stats_only=True)
    assert "Spilled" not in meminfo, meminfo


if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(["-v", __file__]))
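
A rough sanity check of the numbers the test relies on (illustrative arithmetic only):

    object_store = 100 * 1024**2     # object_store_memory passed to ray.init
    per_task_output = 24 * 1024**2   # np.zeros(24 * 1024**2, dtype=np.uint8)
    num_tasks = 5

    # Running all five tasks at once would need ~120 MiB of outputs, which cannot
    # fit in the 100 MiB store; the "no spill" assertion therefore only holds if
    # task submission is throttled based on the runtime-metrics estimate.
    assert num_tasks * per_task_output > object_store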