[data] standardize physical operator runtime metrics (#40173)
Standardize metrics recording for physical operators, and introduce a new `OpRuntimeMetrics` class to decouple metrics from individual operator implementations (see the usage sketch below).

There are currently 4 groups of metrics: inputs, outputs, tasks, and object store. The first two groups support all operators; the last two only support map operators for now.

Not implemented in this PR:
* Integration with DatasetStats.
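
For illustration, a minimal sketch (not part of this diff) of how the new surface can be queried. `op` is assumed to be any constructed `PhysicalOperator`; `summarize_metrics` is a hypothetical helper:

```python
from ray.data._internal.execution.interfaces.physical_operator import PhysicalOperator


def summarize_metrics(op: PhysicalOperator) -> None:
    # `op.metrics` replaces the old `op.get_metrics()` and returns an
    # OpRuntimeMetrics object rather than a plain dict.
    metrics = op.metrics
    # as_dict() exports the dataclass fields (skipping map-only fields for
    # non-map operators) and merges in operator-specific extra metrics.
    for name, value in metrics.as_dict().items():
        print(f"{name}: {value}")
    # Derived properties are also available, e.g. for scheduling decisions.
    print("input buffer bytes:", metrics.input_buffer_bytes)
    print("output buffer bytes:", metrics.output_buffer_bytes)
```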

---------

Signed-off-by: Hao Chen <chenh1024@gmail.com>
raulchen committed Oct 11, 2023
1 parent 4728bdb commit e9ed0f1
Showing 17 changed files with 452 additions and 139 deletions.
3 changes: 1 addition & 2 deletions python/ray/data/_internal/execution/bulk_executor.py
@@ -70,10 +70,9 @@ def execute_recursive(op: PhysicalOperator) -> List[RefBundle]:
             # Cache and return output.
             saved_outputs[op] = output
             op_stats = op.get_stats()
-            op_metrics = op.get_metrics()
             if op_stats:
                 self._stats = builder.build_multistage(op_stats)
-                self._stats.extra_metrics = op_metrics
+                self._stats.extra_metrics = op.metrics.as_dict()
             stats_summary = self._stats.to_summary()
             stats_summary_string = stats_summary.to_string(include_parent=False)
             context = DataContext.get_current()
212 changes: 212 additions & 0 deletions python/ray/data/_internal/execution/interfaces/op_runtime_metrics.py
@@ -0,0 +1,212 @@
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, Any, Dict, Optional

import ray
from ray.data._internal.execution.interfaces.ref_bundle import RefBundle
from ray.data._internal.memory_tracing import trace_allocation

if TYPE_CHECKING:
    from ray.data._internal.execution.interfaces.physical_operator import (
        PhysicalOperator,
    )


@dataclass
class RunningTaskInfo:
    inputs: RefBundle
    num_outputs: int
    bytes_outputs: int


@dataclass
class OpRuntimeMetrics:
    """Runtime metrics for a PhysicalOperator.

    Metrics are updated dynamically during the execution of the Dataset.
    This class can be used for either observability or scheduling purposes.

    DO NOT modify the fields of this class directly. Instead, use the provided
    callback methods.
    """

    # === Inputs-related metrics ===

    # Number of received input blocks.
    num_inputs_received: int = 0
    # Total size in bytes of received input blocks.
    bytes_inputs_received: int = 0

    # Number of processed input blocks.
    # TODO(hchen): Fields tagged with "map_only" currently only work for MapOperator.
    # We should make them work for all operators by unifying the task execution code.
    num_inputs_processed: int = field(default=0, metadata={"map_only": True})
    # Total size in bytes of processed input blocks.
    bytes_inputs_processed: int = field(default=0, metadata={"map_only": True})

    # === Outputs-related metrics ===

    # Number of generated output blocks.
    num_outputs_generated: int = field(default=0, metadata={"map_only": True})
    # Total size in bytes of generated output blocks.
    bytes_outputs_generated: int = field(default=0, metadata={"map_only": True})

    # Number of output blocks that are already taken by the downstream.
    num_outputs_taken: int = 0
    # Size in bytes of output blocks that are already taken by the downstream.
    bytes_outputs_taken: int = 0

    # Number of generated output blocks that are from finished tasks.
    num_outputs_of_finished_tasks: int = field(default=0, metadata={"map_only": True})
    # Size in bytes of generated output blocks that are from finished tasks.
    bytes_outputs_of_finished_tasks: int = field(default=0, metadata={"map_only": True})

    # === Tasks-related metrics ===

    # Number of submitted tasks.
    num_tasks_submitted: int = field(default=0, metadata={"map_only": True})
    # Number of running tasks.
    num_tasks_running: int = field(default=0, metadata={"map_only": True})
    # Number of tasks that have at least one output block.
    num_tasks_have_outputs: int = field(default=0, metadata={"map_only": True})
    # Number of finished tasks.
    num_tasks_finished: int = field(default=0, metadata={"map_only": True})

    # === Object store memory metrics ===

    # Allocated memory size in the object store.
    obj_store_mem_alloc: int = field(default=0, metadata={"map_only": True})
    # Freed memory size in the object store.
    obj_store_mem_freed: int = field(default=0, metadata={"map_only": True})
    # Current memory size in the object store.
    obj_store_mem_cur: int = field(default=0, metadata={"map_only": True})
    # Peak memory size in the object store.
    obj_store_mem_peak: int = field(default=0, metadata={"map_only": True})
    # Spilled memory size in the object store.
    obj_store_mem_spilled: int = field(default=0, metadata={"map_only": True})

    def __init__(self, op: "PhysicalOperator"):
        from ray.data._internal.execution.operators.map_operator import MapOperator

        self._is_map = isinstance(op, MapOperator)
        self._running_tasks: Dict[int, RunningTaskInfo] = {}
        self._extra_metrics: Dict[str, Any] = {}

    @property
    def extra_metrics(self) -> Dict[str, Any]:
        """Return a dict of extra metrics."""
        return self._extra_metrics

    def as_dict(self):
        """Return a dict representation of the metrics."""
        result = []
        for f in fields(self):
            if f.metadata.get("export", True):
                if not self._is_map and f.metadata.get("map_only", False):
                    continue
                value = getattr(self, f.name)
                result.append((f.name, value))
        result.extend(self._extra_metrics.items())
        return dict(result)

    @property
    def average_num_outputs_per_task(self) -> Optional[float]:
        """Average number of output blocks per task, or None if no task has finished."""
        if self.num_tasks_finished == 0:
            return None
        else:
            return self.num_outputs_of_finished_tasks / self.num_tasks_finished

    @property
    def average_bytes_outputs_per_task(self) -> Optional[float]:
        """Average size in bytes of output blocks per task,
        or None if no task has finished."""
        if self.num_tasks_finished == 0:
            return None
        else:
            return self.bytes_outputs_of_finished_tasks / self.num_tasks_finished

    @property
    def input_buffer_bytes(self) -> int:
        """Size in bytes of input blocks that are not processed yet."""
        return self.bytes_inputs_received - self.bytes_inputs_processed

    @property
    def output_buffer_bytes(self) -> int:
        """Size in bytes of output blocks that are not taken by the downstream yet."""
        return self.bytes_outputs_generated - self.bytes_outputs_taken

    def on_input_received(self, input: RefBundle):
        """Callback when the operator receives a new input."""
        self.num_inputs_received += 1
        input_size = input.size_bytes()
        self.bytes_inputs_received += input_size
        # Update object store metrics.
        self.obj_store_mem_cur += input_size
        if self.obj_store_mem_cur > self.obj_store_mem_peak:
            self.obj_store_mem_peak = self.obj_store_mem_cur

    def on_output_taken(self, output: RefBundle):
        """Callback when an output is taken from the operator."""
        output_bytes = output.size_bytes()
        self.num_outputs_taken += 1
        self.bytes_outputs_taken += output_bytes
        self.obj_store_mem_cur -= output_bytes

    def on_task_submitted(self, task_index: int, inputs: RefBundle):
        """Callback when the operator submits a task."""
        self.num_tasks_submitted += 1
        self.num_tasks_running += 1
        self._running_tasks[task_index] = RunningTaskInfo(inputs, 0, 0)

    def on_output_generated(self, task_index: int, output: RefBundle):
        """Callback when a new task generates an output."""
        num_outputs = len(output)
        output_bytes = output.size_bytes()

        self.num_outputs_generated += num_outputs
        self.bytes_outputs_generated += output_bytes

        task_info = self._running_tasks[task_index]
        if task_info.num_outputs == 0:
            self.num_tasks_have_outputs += 1
        task_info.num_outputs += num_outputs
        task_info.bytes_outputs += output_bytes

        # Update object store metrics.
        self.obj_store_mem_alloc += output_bytes
        self.obj_store_mem_cur += output_bytes
        if self.obj_store_mem_cur > self.obj_store_mem_peak:
            self.obj_store_mem_peak = self.obj_store_mem_cur

        for block_ref, _ in output.blocks:
            trace_allocation(block_ref, "operator_output")

    def on_task_finished(self, task_index: int):
        """Callback when a task is finished."""
        self.num_tasks_running -= 1
        self.num_tasks_finished += 1

        task_info = self._running_tasks[task_index]
        self.num_outputs_of_finished_tasks += task_info.num_outputs
        self.bytes_outputs_of_finished_tasks += task_info.bytes_outputs

        inputs = self._running_tasks[task_index].inputs
        self.num_inputs_processed += len(inputs)
        total_input_size = inputs.size_bytes()
        self.bytes_inputs_processed += total_input_size

        blocks = [input[0] for input in inputs.blocks]
        metadata = [input[1] for input in inputs.blocks]
        ctx = ray.data.context.DataContext.get_current()
        if ctx.enable_get_object_locations_for_metrics:
            locations = ray.experimental.get_object_locations(blocks)
            for block, meta in zip(blocks, metadata):
                if locations[block].get("did_spill", False):
                    assert meta.size_bytes is not None
                    self.obj_store_mem_spilled += meta.size_bytes

        self.obj_store_mem_freed += total_input_size
        self.obj_store_mem_cur -= total_input_size

        inputs.destroy_if_owned()
        del self._running_tasks[task_index]
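
To make the callback flow above concrete, here is a small self-contained sketch of the same record-via-callbacks pattern. `FakeBundle` and `MiniMetrics` are illustrative stand-ins, not Ray types:

```python
from dataclasses import dataclass


@dataclass
class FakeBundle:
    size: int  # stand-in for RefBundle.size_bytes()


class MiniMetrics:
    """Tracks input bytes the way OpRuntimeMetrics does: via callbacks only."""

    def __init__(self):
        self.bytes_inputs_received = 0
        self.bytes_inputs_processed = 0

    def on_input_received(self, bundle: FakeBundle) -> None:
        self.bytes_inputs_received += bundle.size

    def on_task_finished(self, bundle: FakeBundle) -> None:
        self.bytes_inputs_processed += bundle.size

    @property
    def input_buffer_bytes(self) -> int:
        # Same derivation as OpRuntimeMetrics.input_buffer_bytes.
        return self.bytes_inputs_received - self.bytes_inputs_processed


m = MiniMetrics()
m.on_input_received(FakeBundle(100))
m.on_input_received(FakeBundle(50))
m.on_task_finished(FakeBundle(100))  # the first task's input is now processed
assert m.input_buffer_bytes == 50
```

The point of the pattern is that derived values such as `input_buffer_bytes` never need to be stored; they fall out of a handful of counters that the callbacks keep consistent.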
41 changes: 26 additions & 15 deletions python/ray/data/_internal/execution/interfaces/physical_operator.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Dict, List, Union
+from typing import Any, Callable, Dict, List, Union
 
 import ray
 from .ref_bundle import RefBundle
@@ -8,6 +8,7 @@
     ExecutionOptions,
     ExecutionResources,
 )
+from ray.data._internal.execution.interfaces.op_runtime_metrics import OpRuntimeMetrics
 from ray.data._internal.logical.interfaces import Operator
 from ray.data._internal.stats import StatsDict
 
@@ -59,7 +60,6 @@ def __init__(
         self._streaming_gen = streaming_gen
         self._output_ready_callback = output_ready_callback
         self._task_done_callback = task_done_callback
-        self._num_output_blocks = 0
 
     def get_waitable(self) -> StreamingObjectRefGenerator:
         return self._streaming_gen
@@ -93,14 +93,6 @@ def on_waitable_ready(self):
                 RefBundle([(block_ref, meta)], owns_blocks=True)
             )
 
-    def add_num_output_blocks(self, num_output_blocks):
-        self._num_output_blocks += num_output_blocks
-
-    def get_num_output_blocks(
-        self,
-    ):
-        return self._num_output_blocks
-
 
 class MetadataOpTask(OpTask):
     """Represents an OpTask that only handles metadata, instead of Block data."""
@@ -163,6 +155,7 @@ def __init__(self, name: str, input_dependencies: List["PhysicalOperator"]):
         self._inputs_complete = not input_dependencies
         self._dependents_complete = False
         self._started = False
+        self._metrics = OpRuntimeMetrics(self)
        self._estimated_output_blocks = None
 
     def __reduce__(self):
@@ -185,12 +178,15 @@ def get_stats(self) -> StatsDict:
         """Return recorded execution stats for use with DatasetStats."""
         raise NotImplementedError
 
-    def get_metrics(self) -> Dict[str, int]:
-        """Returns dict of metrics reported from this operator.
-
-        These should be instant values that can be queried at any time, e.g.,
-        obj_store_mem_allocated, obj_store_mem_freed.
-        """
+    @property
+    def metrics(self) -> OpRuntimeMetrics:
+        """Returns the runtime metrics of this operator."""
+        self._metrics._extra_metrics = self._extra_metrics()
+        return self._metrics
+
+    def _extra_metrics(self) -> Dict[str, Any]:
+        """Subclasses should override this method to report extra metrics
+        that are specific to them."""
         return {}
 
     def progress_str(self) -> str:
@@ -241,12 +237,19 @@ def add_input(self, refs: RefBundle, input_index: int) -> None:
         Inputs may be added in any order, and calls to `add_input` may be interleaved
         with calls to `get_next` / `has_next` to implement streaming execution.
 
+        Subclasses should override `_add_input_inner` instead of this method.
+
         Args:
             refs: The ref bundle that should be added as input.
             input_index: The index identifying the input dependency producing the
                 input. For most operators, this is always `0` since there is only
                 one upstream input operator.
         """
-        raise NotImplementedError
+        self._metrics.on_input_received(refs)
+        self._add_input_inner(refs, input_index)
+
+    def _add_input_inner(self, refs: RefBundle, input_index: int) -> None:
+        """Subclasses should override this method to implement `add_input`."""
+        raise NotImplementedError
 
     def input_done(self, input_index: int) -> None:
@@ -283,7 +286,15 @@ def get_next(self) -> RefBundle:
         """Get the next downstream output.
 
         It is only allowed to call this if `has_next()` has returned True.
+
+        Subclasses should override `_get_next_inner` instead of this method.
         """
-        raise NotImplementedError
+        output = self._get_next_inner()
+        self._metrics.on_output_taken(output)
+        return output
+
+    def _get_next_inner(self) -> RefBundle:
+        """Subclasses should override this method to implement `get_next`."""
+        raise NotImplementedError
 
     def get_active_tasks(self) -> List[OpTask]:
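
A hedged sketch of what a subclass looks like under the new contract: the base class's `add_input`/`get_next` record metrics, and subclasses implement only the inner hooks plus optional extra metrics. `InMemoryOp` and its fields are hypothetical, not part of this diff:

```python
from typing import Any, Dict, List

from ray.data._internal.execution.interfaces.physical_operator import PhysicalOperator
from ray.data._internal.execution.interfaces.ref_bundle import RefBundle


class InMemoryOp(PhysicalOperator):
    """Hypothetical operator that buffers inputs and emits them unchanged."""

    def __init__(self, input_dependencies: List[PhysicalOperator]):
        super().__init__("InMemoryOp", input_dependencies)
        self._buffer: List[RefBundle] = []
        self._passthrough_count = 0

    def _add_input_inner(self, refs: RefBundle, input_index: int) -> None:
        # The base add_input() has already called metrics.on_input_received().
        self._buffer.append(refs)

    def has_next(self) -> bool:
        return len(self._buffer) > 0

    def _get_next_inner(self) -> RefBundle:
        # The base get_next() will call metrics.on_output_taken() on the result.
        self._passthrough_count += 1
        return self._buffer.pop(0)

    def _extra_metrics(self) -> Dict[str, Any]:
        # Merged into metrics.as_dict() by the base `metrics` property.
        return {"passthrough_count": self._passthrough_count}
```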
3 changes: 3 additions & 0 deletions python/ray/data/_internal/execution/interfaces/ref_bundle.py
@@ -110,3 +110,6 @@ def __eq__(self, other) -> bool:
 
     def __hash__(self) -> int:
         return id(self)
+
+    def __len__(self) -> int:
+        return len(self.blocks)
@@ -290,7 +290,7 @@ def current_resource_usage(self) -> ExecutionResources:
         return ExecutionResources(
             cpu=self._ray_remote_args.get("num_cpus", 0) * num_active_workers,
             gpu=self._ray_remote_args.get("num_gpus", 0) * num_active_workers,
-            object_store_memory=self._metrics.cur,
+            object_store_memory=self.metrics.obj_store_mem_cur,
         )
 
     def incremental_resource_usage(self) -> ExecutionResources:
@@ -311,12 +311,12 @@ def incremental_resource_usage(self) -> ExecutionResources:
             num_gpus = 0
         return ExecutionResources(cpu=num_cpus, gpu=num_gpus)
 
-    def get_metrics(self) -> Dict[str, int]:
-        parent = super().get_metrics()
+    def _extra_metrics(self) -> Dict[str, Any]:
+        res = {}
         if self._actor_locality_enabled:
-            parent["locality_hits"] = self._actor_pool._locality_hits
-            parent["locality_misses"] = self._actor_pool._locality_misses
-        return parent
+            res["locality_hits"] = self._actor_pool._locality_hits
+            res["locality_misses"] = self._actor_pool._locality_misses
+        return res
 
     @staticmethod
     def _apply_default_remote_args(ray_remote_args: Dict[str, Any]) -> Dict[str, Any]:
@@ -75,7 +75,7 @@ def num_outputs_total(self) -> int:
             else self.input_dependencies[0].num_outputs_total()
         )
 
-    def add_input(self, refs: RefBundle, input_index: int) -> None:
+    def _add_input_inner(self, refs: RefBundle, input_index: int) -> None:
         assert not self.completed()
         assert input_index == 0, input_index
         self._input_buffer.append(refs)
@@ -93,7 +93,7 @@ def all_inputs_done(self) -> None:
     def has_next(self) -> bool:
         return len(self._output_buffer) > 0
 
-    def get_next(self) -> RefBundle:
+    def _get_next_inner(self) -> RefBundle:
         return self._output_buffer.pop(0)
 
     def get_stats(self) -> StatsDict:
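
The renames above follow the template-method refactor: metric recording lives once in the base class, and each operator supplies only its inner logic. A standalone toy version of that split (plain Python, not the Ray classes):

```python
class BaseOp:
    """Toy base class: records metrics around the subclass hook."""

    def __init__(self):
        self.num_outputs_taken = 0

    def get_next(self):
        out = self._get_next_inner()  # subclass-specific logic
        self.num_outputs_taken += 1   # metrics recorded in exactly one place
        return out

    def _get_next_inner(self):
        raise NotImplementedError


class BufferOp(BaseOp):
    def __init__(self, items):
        super().__init__()
        self._buf = list(items)

    def _get_next_inner(self):
        return self._buf.pop(0)


op = BufferOp(["a", "b"])
assert op.get_next() == "a" and op.num_outputs_taken == 1
```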
@@ -53,7 +53,7 @@ def start(self, options: ExecutionOptions) -> None:
     def has_next(self) -> bool:
         return len(self._input_data) > 0
 
-    def get_next(self) -> RefBundle:
+    def _get_next_inner(self) -> RefBundle:
         return self._input_data.pop(0)
 
     def num_outputs_total(self) -> int:
@@ -62,7 +62,7 @@ def num_outputs_total(self) -> int:
     def get_stats(self) -> StatsDict:
         return {}
 
-    def add_input(self, refs, input_index) -> None:
+    def _add_input_inner(self, refs, input_index) -> None:
         raise ValueError("Inputs are not allowed for this operator.")
 
     def _initialize_metadata(self):