[data] create StatsManager to manage _StatsActor remote calls (ray-project#40913)

Creates a `StatsManager` class to manage remote calls to `_StatsActor`.

This singleton manager controls the time interval for reporting metrics to `_StatsActor`:
- Runs a single background thread that reports metrics to `_StatsActor` every 5s
- This thread is stopped after it has been idle for too long, and is restarted when a new update arrives (see the sketch below)

Also logs per-operator metrics alongside `_debug_dump_topology`.
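
For readers unfamiliar with this pattern, below is a minimal sketch of a reporter that buffers updates, flushes them from a single daemon thread every 5 seconds, and lets the thread exit after a stretch with no updates (it is restarted on the next one). All names (`_MetricsReporter`, `record`, `flush_fn`) and the idle timeout are hypothetical illustrations, not the actual `StatsManager` API; the real manager lives in `python/ray/data/_internal/stats.py` and issues the remote calls to `_StatsActor`.

```python
import threading
import time


class _MetricsReporter:
    """Hypothetical sketch of the reporting pattern described above."""

    REPORT_INTERVAL_S = 5  # flush buffered metrics every 5 seconds
    IDLE_SHUTDOWN_S = 30   # stop the thread after this long without updates (assumed value)

    def __init__(self, flush_fn):
        self._flush_fn = flush_fn  # e.g. a fire-and-forget call to a stats actor
        self._lock = threading.Lock()
        self._pending = {}         # latest metrics, keyed by dataset tag
        self._last_update = time.time()
        self._thread = None

    def record(self, tag, metrics):
        """Buffer an update and make sure the reporting thread is running."""
        with self._lock:
            self._pending[tag] = metrics
            self._last_update = time.time()
            if self._thread is None or not self._thread.is_alive():
                # (Re)start the thread if it previously exited while idle.
                self._thread = threading.Thread(target=self._run, daemon=True)
                self._thread.start()

    def _run(self):
        while True:
            time.sleep(self.REPORT_INTERVAL_S)
            with self._lock:
                if time.time() - self._last_update > self.IDLE_SHUTDOWN_S:
                    return  # go idle; record() restarts the thread on the next update
                pending, self._pending = self._pending, {}
            for tag, metrics in pending.items():
                self._flush_fn(tag, metrics)
```

Keeping the thread a daemon and letting it die when idle avoids holding a live reporting thread for datasets that have finished executing or iterating.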

---------

Signed-off-by: Andrew Xue <andewzxue@gmail.com>
Zandew authored and ujjawal-khare committed Nov 29, 2023
1 parent 1c9f85c commit 7b090cd
Showing 10 changed files with 329 additions and 142 deletions.
18 changes: 1 addition & 17 deletions python/ray/data/_internal/block_batching/iter_batches.py
@@ -1,5 +1,4 @@
import collections
import time
from contextlib import nullcontext
from typing import Any, Callable, Dict, Iterator, Optional, Tuple

@@ -16,23 +15,15 @@
resolve_block_refs,
)
from ray.data._internal.memory_tracing import trace_deallocation
from ray.data._internal.stats import (
DatasetStats,
clear_stats_actor_iter_metrics,
update_stats_actor_iter_metrics,
)
from ray.data._internal.stats import DatasetStats
from ray.data._internal.util import make_async_gen
from ray.data.block import Block, BlockMetadata, DataBatch
from ray.data.context import DataContext
from ray.types import ObjectRef

# Interval for metrics update remote calls to _StatsActor during iteration.
STATS_UPDATE_INTERVAL_SECONDS = 30


def iter_batches(
block_refs: Iterator[Tuple[ObjectRef[Block], BlockMetadata]],
dataset_tag: str,
*,
stats: Optional[DatasetStats] = None,
clear_block_after_read: bool = False,
@@ -178,9 +169,7 @@ def _async_iter_batches(
# Run everything in a separate thread to not block the main thread when waiting
# for streaming results.
async_batch_iter = make_async_gen(block_refs, fn=_async_iter_batches, num_workers=1)
metrics_tag = {"dataset": dataset_tag}

last_stats_update_time = 0
while True:
with stats.iter_total_blocked_s.timer() if stats else nullcontext():
try:
@@ -190,11 +179,6 @@
with stats.iter_user_s.timer() if stats else nullcontext():
yield next_batch

if time.time() - last_stats_update_time >= STATS_UPDATE_INTERVAL_SECONDS:
update_stats_actor_iter_metrics(stats, metrics_tag)
last_stats_update_time = time.time()
clear_stats_actor_iter_metrics(metrics_tag)


def _format_in_threadpool(
batch_iter: Iterator[Batch],
78 changes: 47 additions & 31 deletions python/ray/data/_internal/execution/streaming_executor.py
@@ -34,13 +34,7 @@
update_operator_states,
)
from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.stats import (
DatasetStats,
clear_stats_actor_metrics,
register_dataset_to_stats_actor,
update_stats_actor_dataset,
update_stats_actor_metrics,
)
from ray.data._internal.stats import DatasetStats, StatsManager
from ray.data.context import DataContext

logger = DatasetLogger(__name__)
@@ -53,6 +47,9 @@
# progress bar seeming to stall for very large scale workloads.
PROGRESS_BAR_UPDATE_INTERVAL = 50

# Interval for logging execution progress updates and operator metrics.
DEBUG_LOG_INTERVAL_SECONDS = 5

# Visible for testing.
_num_shutdown = 0

@@ -90,6 +87,8 @@ def __init__(self, options: ExecutionOptions, dataset_tag: str = "unknown_datase
# used for marking when an op has just completed.
self._has_op_completed: Optional[Dict[PhysicalOperator, bool]] = None

self._last_debug_log_time = 0

Executor.__init__(self, options)
thread_name = f"StreamingExecutor-{self._execution_id}"
threading.Thread.__init__(self, daemon=True, name=thread_name)
@@ -127,8 +126,9 @@ def execute(
self._global_info = ProgressBar("Running", dag.num_outputs_total())

self._output_node: OpState = self._topology[dag]
register_dataset_to_stats_actor(
self._dataset_tag, [tag["operator"] for tag in self._get_metrics_tags()]
StatsManager.register_dataset_to_stats_actor(
self._dataset_tag,
self._get_operator_tags(),
)
self.start()

@@ -183,11 +183,14 @@ def shutdown(self, execution_completed: bool = True):
self._shutdown = True
# Give the scheduling loop some time to finish processing.
self.join(timeout=2.0)
update_stats_actor_dataset(
self._dataset_tag,
self._get_state_dict(
state="FINISHED" if execution_completed else "FAILED"
),
self._update_stats_metrics(
state="FINISHED" if execution_completed else "FAILED",
force_update=True,
)
# Clears metrics for this dataset so that they do
# not persist in the grafana dashboard after execution
StatsManager.clear_execution_metrics(
self._dataset_tag, self._get_operator_tags()
)
# Freeze the stats and save it.
self._final_stats = self._generate_stats()
@@ -223,9 +226,6 @@ def run(self):
finally:
# Signal end of results.
self._output_node.outqueue.append(None)
# Clears metrics for this dataset so that they do
# not persist in the grafana dashboard after execution
clear_stats_actor_metrics(self._get_metrics_tags())

def get_stats(self):
"""Return the stats object for the streaming execution.
@@ -306,14 +306,12 @@ def _scheduling_loop_step(self, topology: Topology) -> bool:
for op_state in topology.values():
op_state.refresh_progress_bar()

if not DEBUG_TRACE_SCHEDULING:
_debug_dump_topology(topology, log_to_stdout=False)

update_stats_actor_metrics(
[op.metrics for op in self._topology],
self._get_metrics_tags(),
self._get_state_dict(state="RUNNING"),
)
self._update_stats_metrics(state="RUNNING")
if time.time() - self._last_debug_log_time >= DEBUG_LOG_INTERVAL_SECONDS:
_log_op_metrics(topology)
if not DEBUG_TRACE_SCHEDULING:
_debug_dump_topology(topology, log_to_stdout=False)
self._last_debug_log_time = time.time()

# Log metrics of newly completed operators.
for op in topology:
@@ -365,12 +363,9 @@ def _report_current_usage(
if self._global_info:
self._global_info.set_description(resources_status)

def _get_metrics_tags(self):
"""Returns a list of tags for operator-level metrics."""
return [
{"dataset": self._dataset_tag, "operator": f"{op.name}{i}"}
for i, op in enumerate(self._topology)
]
def _get_operator_tags(self):
"""Returns a list of operator tags."""
return [f"{op.name}{i}" for i, op in enumerate(self._topology)]

def _get_state_dict(self, state):
last_op, last_state = list(self._topology.items())[-1]
@@ -389,6 +384,15 @@ def _get_state_dict(self, state):
},
}

def _update_stats_metrics(self, state: str, force_update: bool = False):
StatsManager.update_execution_metrics(
self._dataset_tag,
[op.metrics for op in self._topology],
self._get_operator_tags(),
self._get_state_dict(state=state),
force_update=force_update,
)


def _validate_dag(dag: PhysicalOperator, limits: ExecutionResources) -> None:
"""Raises an exception on invalid DAGs.
@@ -465,3 +469,15 @@ def _debug_dump_topology(topology: Topology, log_to_stdout: bool = True) -> None
f"Blocks Outputted: {state.num_completed_tasks}/{op.num_outputs_total()}"
)
logger.get_logger(log_to_stdout).info("")


def _log_op_metrics(topology: Topology) -> None:
"""Logs the metrics of each operator.
Args:
topology: The topology to debug.
"""
log_str = "Operator Metrics:\n"
for op in topology:
log_str += f"{op.name}: {op.metrics.as_dict()}\n"
logger.get_logger(log_to_stdout=False).info(log_str)
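
Read together, the executor changes above follow a register → update → clear lifecycle against `StatsManager`: register the dataset when execution starts, push state updates from the scheduling loop (the manager throttles the actual remote calls), force a final update with the terminal state on shutdown, then clear the per-dataset metrics so they do not linger on dashboards. A condensed, hypothetical sketch of that flow, using only the `StatsManager` calls shown in the diff:

```python
from ray.data._internal.stats import StatsManager


# Hypothetical condensed view of the lifecycle in the diff above; the real
# logic is spread across execute(), _scheduling_loop_step(), and shutdown().
class _ExecutorSketch:
    def execute(self):
        StatsManager.register_dataset_to_stats_actor(
            self._dataset_tag, self._get_operator_tags()
        )

    def _scheduling_loop_step(self):
        # StatsManager decides how often this actually reaches _StatsActor.
        self._update_stats_metrics(state="RUNNING")

    def shutdown(self, execution_completed=True):
        # Record the terminal state with force_update=True, then clear the
        # per-dataset metrics so they do not persist after execution.
        self._update_stats_metrics(
            state="FINISHED" if execution_completed else "FAILED",
            force_update=True,
        )
        StatsManager.clear_execution_metrics(
            self._dataset_tag, self._get_operator_tags()
        )
```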
5 changes: 4 additions & 1 deletion python/ray/data/_internal/iterator/iterator_impl.py
@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING, Iterator, Optional, Tuple, Union

from ray.data._internal.stats import DatasetStats
from ray.data._internal.util import create_dataset_tag
from ray.data.block import Block, BlockMetadata
from ray.data.iterator import DataIterator
from ray.types import ObjectRef
@@ -58,4 +59,6 @@ def __getattr__(self, name):
raise AttributeError()

def _get_dataset_tag(self):
return (self._base_dataset._plan._dataset_name or "") + self._base_dataset._uuid
return create_dataset_tag(
self._base_dataset._plan._dataset_name, self._base_dataset._uuid
)
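
The `create_dataset_tag` helper itself is not part of this diff. Based on its call sites (dataset name, dataset uuid, and optionally an output split index) and the string concatenation it replaces, a plausible sketch is shown below; the fallback name and separator are assumptions.

```python
from typing import Optional


def create_dataset_tag(dataset_name: Optional[str], *args) -> str:
    # Assumed behavior inferred from the call sites: fall back to a generic
    # name, then append the uuid and any extra components such as the output
    # split index. The exact separator/format is a guess.
    tag = dataset_name or "dataset"
    for arg in args:
        tag += f"_{arg}"
    return tag
```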
8 changes: 8 additions & 0 deletions python/ray/data/_internal/iterator/stream_split_iterator.py
@@ -11,6 +11,7 @@
from ray.data._internal.execution.operators.output_splitter import OutputSplitter
from ray.data._internal.execution.streaming_executor import StreamingExecutor
from ray.data._internal.stats import DatasetStats, DatasetStatsSummary
from ray.data._internal.util import create_dataset_tag
from ray.data.block import Block, BlockMetadata
from ray.data.iterator import DataIterator
from ray.types import ObjectRef
@@ -111,6 +112,13 @@ def world_size(self) -> int:
"""Returns the number of splits total."""
return self._world_size

def _get_dataset_tag(self):
return create_dataset_tag(
self._base_dataset._plan._dataset_name,
self._base_dataset._uuid,
self._output_split_idx,
)


@ray.remote(num_cpus=0)
class SplitCoordinator:
13 changes: 8 additions & 5 deletions python/ray/data/_internal/plan.py
@@ -39,7 +39,11 @@
apply_output_blocks_handling_to_read_task,
)
from ray.data._internal.stats import DatasetStats, DatasetStatsSummary
from ray.data._internal.util import capitalize, unify_block_metadata_schema
from ray.data._internal.util import (
capitalize,
create_dataset_tag,
unify_block_metadata_schema,
)
from ray.data.block import Block, BlockMetadata
from ray.data.context import DataContext
from ray.types import ObjectRef
@@ -532,9 +536,8 @@ def execute_to_iterator(
)
from ray.data._internal.execution.streaming_executor import StreamingExecutor

executor = StreamingExecutor(
copy.deepcopy(ctx.execution_options), self._dataset_uuid
)
metrics_tag = create_dataset_tag(self._dataset_name, self._dataset_uuid)
executor = StreamingExecutor(copy.deepcopy(ctx.execution_options), metrics_tag)
block_iter = execute_to_legacy_block_iterator(
executor,
self,
@@ -591,7 +594,7 @@ def execute(
StreamingExecutor,
)

metrics_tag = (self._dataset_name or "dataset") + self._dataset_uuid
metrics_tag = create_dataset_tag(self._dataset_name, self._dataset_uuid)
executor = StreamingExecutor(
copy.deepcopy(context.execution_options),
metrics_tag,