From cb3bdaf26dca8d8a6a0f8eb4dd92ba451240615c Mon Sep 17 00:00:00 2001 From: Andrew Xue Date: Tue, 14 Nov 2023 19:00:13 -0800 Subject: [PATCH 1/4] stats manager Signed-off-by: Andrew Xue --- .../_internal/block_batching/iter_batches.py | 18 +- .../_internal/execution/streaming_executor.py | 54 +++-- .../data/_internal/iterator/iterator_impl.py | 5 +- .../iterator/stream_split_iterator.py | 8 + python/ray/data/_internal/plan.py | 13 +- python/ray/data/_internal/stats.py | 217 +++++++++++++----- python/ray/data/_internal/util.py | 7 + python/ray/data/dataset.py | 8 +- python/ray/data/iterator.py | 11 +- python/ray/data/tests/test_stats.py | 78 ++++++- 10 files changed, 295 insertions(+), 124 deletions(-) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 0a87c9b3e949d..b01be673cfe96 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -1,5 +1,4 @@ import collections -import time from contextlib import nullcontext from typing import Any, Callable, Dict, Iterator, Optional, Tuple @@ -16,23 +15,15 @@ resolve_block_refs, ) from ray.data._internal.memory_tracing import trace_deallocation -from ray.data._internal.stats import ( - DatasetStats, - clear_stats_actor_iter_metrics, - update_stats_actor_iter_metrics, -) +from ray.data._internal.stats import DatasetStats from ray.data._internal.util import make_async_gen from ray.data.block import Block, BlockMetadata, DataBatch from ray.data.context import DataContext from ray.types import ObjectRef -# Interval for metrics update remote calls to _StatsActor during iteration. -STATS_UPDATE_INTERVAL_SECONDS = 30 - def iter_batches( block_refs: Iterator[Tuple[ObjectRef[Block], BlockMetadata]], - dataset_tag: str, *, stats: Optional[DatasetStats] = None, clear_block_after_read: bool = False, @@ -178,9 +169,7 @@ def _async_iter_batches( # Run everything in a separate thread to not block the main thread when waiting # for streaming results. 
async_batch_iter = make_async_gen(block_refs, fn=_async_iter_batches, num_workers=1) - metrics_tag = {"dataset": dataset_tag} - last_stats_update_time = 0 while True: with stats.iter_total_blocked_s.timer() if stats else nullcontext(): try: @@ -190,11 +179,6 @@ def _async_iter_batches( with stats.iter_user_s.timer() if stats else nullcontext(): yield next_batch - if time.time() - last_stats_update_time >= STATS_UPDATE_INTERVAL_SECONDS: - update_stats_actor_iter_metrics(stats, metrics_tag) - last_stats_update_time = time.time() - clear_stats_actor_iter_metrics(metrics_tag) - def _format_in_threadpool( batch_iter: Iterator[Batch], diff --git a/python/ray/data/_internal/execution/streaming_executor.py b/python/ray/data/_internal/execution/streaming_executor.py index 49e4fecf8492b..6ad43d095d534 100644 --- a/python/ray/data/_internal/execution/streaming_executor.py +++ b/python/ray/data/_internal/execution/streaming_executor.py @@ -34,13 +34,7 @@ update_operator_states, ) from ray.data._internal.progress_bar import ProgressBar -from ray.data._internal.stats import ( - DatasetStats, - clear_stats_actor_metrics, - register_dataset_to_stats_actor, - update_stats_actor_dataset, - update_stats_actor_metrics, -) +from ray.data._internal.stats import DatasetStats, StatsManager from ray.data.context import DataContext logger = DatasetLogger(__name__) @@ -127,9 +121,7 @@ def execute( self._global_info = ProgressBar("Running", dag.num_outputs_total()) self._output_node: OpState = self._topology[dag] - register_dataset_to_stats_actor( - self._dataset_tag, [tag["operator"] for tag in self._get_metrics_tags()] - ) + StatsManager.register_dataset_to_stats_actor(self._dataset_tag, [tag["operator"] for tag in self._get_metrics_tags()]) self.start() class StreamIterator(OutputIterator): @@ -183,12 +175,13 @@ def shutdown(self, execution_completed: bool = True): self._shutdown = True # Give the scheduling loop some time to finish processing. self.join(timeout=2.0) - update_stats_actor_dataset( - self._dataset_tag, - self._get_state_dict( - state="FINISHED" if execution_completed else "FAILED" - ), + self._update_stats_metrics( + state="FINISHED" if execution_completed else "FAILED", + force_update=True, ) + # Clears metrics for this dataset so that they do + # not persist in the grafana dashboard after execution + StatsManager.clear_stats_actor_metrics(self._get_metrics_tags()) # Freeze the stats and save it. self._final_stats = self._generate_stats() stats_summary_string = self._final_stats.to_summary().to_string( @@ -223,9 +216,6 @@ def run(self): finally: # Signal end of results. self._output_node.outqueue.append(None) - # Clears metrics for this dataset so that they do - # not persist in the grafana dashboard after execution - clear_stats_actor_metrics(self._get_metrics_tags()) def get_stats(self): """Return the stats object for the streaming execution. @@ -306,15 +296,11 @@ def _scheduling_loop_step(self, topology: Topology) -> bool: for op_state in topology.values(): op_state.refresh_progress_bar() + self._update_stats_metrics(state="RUNNING") + _log_op_metrics(topology) if not DEBUG_TRACE_SCHEDULING: _debug_dump_topology(topology, log_to_stdout=False) - update_stats_actor_metrics( - [op.metrics for op in self._topology], - self._get_metrics_tags(), - self._get_state_dict(state="RUNNING"), - ) - # Log metrics of newly completed operators. 
for op in topology: if op.completed() and not self._has_op_completed[op]: @@ -389,6 +375,14 @@ def _get_state_dict(self, state): }, } + def _update_stats_metrics(self, state: str, force_update: bool = False): + StatsManager.update_stats_actor_metrics( + [op.metrics for op in self._topology], + self._get_metrics_tags(), + self._get_state_dict(state=state), + force_update=force_update, + ) + def _validate_dag(dag: PhysicalOperator, limits: ExecutionResources) -> None: """Raises an exception on invalid DAGs. @@ -465,3 +459,15 @@ def _debug_dump_topology(topology: Topology, log_to_stdout: bool = True) -> None f"Blocks Outputted: {state.num_completed_tasks}/{op.num_outputs_total()}" ) logger.get_logger(log_to_stdout).info("") + + +def _log_op_metrics(topology: Topology) -> None: + """Logs the metrics of each operator. + + Args: + topology: The topology to debug. + """ + log_str = "Operator Metrics:\n" + for op in topology: + log_str += f"{op.name}: {op.metrics.as_dict()}\n" + logger.get_logger(log_to_stdout=False).info(log_str) diff --git a/python/ray/data/_internal/iterator/iterator_impl.py b/python/ray/data/_internal/iterator/iterator_impl.py index 55a3cccc16254..9d408485cd548 100644 --- a/python/ray/data/_internal/iterator/iterator_impl.py +++ b/python/ray/data/_internal/iterator/iterator_impl.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING, Iterator, Optional, Tuple, Union from ray.data._internal.stats import DatasetStats +from ray.data._internal.util import create_dataset_tag from ray.data.block import Block, BlockMetadata from ray.data.iterator import DataIterator from ray.types import ObjectRef @@ -58,4 +59,6 @@ def __getattr__(self, name): raise AttributeError() def _get_dataset_tag(self): - return (self._base_dataset._plan._dataset_name or "") + self._base_dataset._uuid + return create_dataset_tag( + self._base_dataset._plan._dataset_name, self._base_dataset._uuid + ) diff --git a/python/ray/data/_internal/iterator/stream_split_iterator.py b/python/ray/data/_internal/iterator/stream_split_iterator.py index d5ab948c06e27..bda9b115bc81f 100644 --- a/python/ray/data/_internal/iterator/stream_split_iterator.py +++ b/python/ray/data/_internal/iterator/stream_split_iterator.py @@ -11,6 +11,7 @@ from ray.data._internal.execution.operators.output_splitter import OutputSplitter from ray.data._internal.execution.streaming_executor import StreamingExecutor from ray.data._internal.stats import DatasetStats, DatasetStatsSummary +from ray.data._internal.util import create_dataset_tag from ray.data.block import Block, BlockMetadata from ray.data.iterator import DataIterator from ray.types import ObjectRef @@ -111,6 +112,13 @@ def world_size(self) -> int: """Returns the number of splits total.""" return self._world_size + def _get_dataset_tag(self): + return create_dataset_tag( + self._base_dataset._plan._dataset_name, + self._base_dataset._uuid, + self._output_split_idx, + ) + @ray.remote(num_cpus=0) class SplitCoordinator: diff --git a/python/ray/data/_internal/plan.py b/python/ray/data/_internal/plan.py index f038fba1d3c05..b6151722b17b7 100644 --- a/python/ray/data/_internal/plan.py +++ b/python/ray/data/_internal/plan.py @@ -39,7 +39,11 @@ apply_output_blocks_handling_to_read_task, ) from ray.data._internal.stats import DatasetStats, DatasetStatsSummary -from ray.data._internal.util import capitalize, unify_block_metadata_schema +from ray.data._internal.util import ( + capitalize, + create_dataset_tag, + unify_block_metadata_schema, +) from ray.data.block import Block, BlockMetadata from 
ray.data.context import DataContext from ray.types import ObjectRef @@ -532,9 +536,8 @@ def execute_to_iterator( ) from ray.data._internal.execution.streaming_executor import StreamingExecutor - executor = StreamingExecutor( - copy.deepcopy(ctx.execution_options), self._dataset_uuid - ) + metrics_tag = create_dataset_tag(self._dataset_name, self._dataset_uuid) + executor = StreamingExecutor(copy.deepcopy(ctx.execution_options), metrics_tag) block_iter = execute_to_legacy_block_iterator( executor, self, @@ -591,7 +594,7 @@ def execute( StreamingExecutor, ) - metrics_tag = (self._dataset_name or "dataset") + self._dataset_uuid + metrics_tag = create_dataset_tag(self._dataset_name, self._dataset_uuid) executor = StreamingExecutor( copy.deepcopy(context.execution_options), metrics_tag, diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 0a09d8002405e..9d7ffe1a7e0f7 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -1,8 +1,9 @@ import collections +import threading import time from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union from uuid import uuid4 import numpy as np @@ -20,6 +21,9 @@ STATS_ACTOR_NAME = "datasets_stats_actor" STATS_ACTOR_NAMESPACE = "_dataset_stats_actor" +# Interval for making remote calls to the _StatsActor. +STATS_ACTOR_UPDATE_INTERVAL_SECONDS = 5 + StatsDict = Dict[str, List[BlockMetadata]] @@ -248,7 +252,13 @@ def get_dataset_id(self): self.next_dataset_id += 1 return dataset_id - def update_metrics( + def update_metrics(self, execution_metrics, iteration_metrics): + for metrics in execution_metrics: + self.update_execution_metrics(*metrics) + for metrics in iteration_metrics: + self.update_iter_metrics(*metrics) + + def update_execution_metrics( self, op_metrics: List[Dict[str, Union[int, float]]], tags_list: List[Dict[str, str]], @@ -268,7 +278,11 @@ def update_metrics( # so all tags should contain the same dataset self.update_dataset(tags_list[0]["dataset"], state) - def update_iter_metrics(self, stats: "DatasetStats", tags): + def update_iter_metrics( + self, + stats: "DatasetStats", + tags: Dict[str, str], + ): self.iter_total_blocked_s.set(stats.iter_total_blocked_s.get(), tags) self.iter_user_s.set(stats.iter_user_s.get(), tags) @@ -330,80 +344,163 @@ def _get_or_create_stats_actor(): ).remote() -_stats_actor: Optional[_StatsActor] = None -_stats_actor_cluster_id: Optional[str] = None -"""This global _stats_actor may be from a previous cluster that was shutdown. -We store _cluster_id to check that the stored actor exists in the current cluster. -""" - - -def _check_cluster_stats_actor(): - # Checks if global _stats_actor belongs to current cluster, - # if not, creates a new one on the current cluster. - global _stats_actor, _stats_actor_cluster_id - if ray._private.worker._global_node is None: - raise RuntimeError("Global node is not initialized.") - current_cluster_id = ray._private.worker._global_node.cluster_id - if _stats_actor is None or _stats_actor_cluster_id != current_cluster_id: - _stats_actor = _get_or_create_stats_actor() - _stats_actor_cluster_id = current_cluster_id +class _StatsManager: + """A Class containing util functions that manage remote calls to _StatsActor. + This class collects stats from execution and iteration codepaths and keeps + track of the latest snapshot. 
-def update_stats_actor_metrics(
-    op_metrics: List[OpRuntimeMetrics],
-    tags_list: List[Dict[str, str]],
-    state: Dict[str, Any],
-):
-    global _stats_actor
-    _check_cluster_stats_actor()
-
-    _stats_actor.update_metrics.remote(
-        [metric.as_dict() for metric in op_metrics], tags_list, state
-    )
+    An instance of this class runs a single background thread that periodically
+    forwards the latest execution/iteration stats to the _StatsActor.
+    This thread will terminate itself after being inactive (meaning that there are
+    no active executors or iterators) for UPDATE_THREAD_INACTIVITY_LIMIT
+    iterations. After terminating, a new thread will start if more calls are made
+    to this class.
+    """

-def update_stats_actor_iter_metrics(stats: "DatasetStats", tags_list: Dict[str, str]):
-    global _stats_actor
-    _check_cluster_stats_actor()
+    def __init__(self):
+        # Lazily get stats actor handle to avoid circular import.
+        self._stats_actor_handle = None
+        self._stats_actor_cluster_id = None
+
+        # Last execution stats snapshots
+        self._last_execution_stats = {}
+        # Last iteration stats snapshots
+        self._last_iter_stats: Dict[str, Tuple[Dict[str, str], "DatasetStats"]] = {}
+        # Lock for updating stats snapshots
+        self._stats_lock: threading.Lock = threading.Lock()
+
+        # Background thread to make remote calls to _StatsActor
+        self._iter_update_thread: Optional[threading.Thread] = None
+        self._iter_update_thread_lock: threading.Lock = threading.Lock()
+
+    def _stats_actor(self, create_if_not_exists=True) -> _StatsActor:
+        if ray._private.worker._global_node is None:
+            raise RuntimeError("Global node is not initialized.")
+        current_cluster_id = ray._private.worker._global_node.cluster_id
+        if (
+            self._stats_actor_handle is None
+            or self._stats_actor_cluster_id != current_cluster_id
+        ):
+            self._stats_actor_cluster_id = current_cluster_id
+            if create_if_not_exists:
+                self._stats_actor_handle = _get_or_create_stats_actor()
+            else:
+                self._stats_actor_handle = ray.get_actor(
+                    name=STATS_ACTOR_NAME, namespace=STATS_ACTOR_NAMESPACE
+                )
+        return self._stats_actor_handle
+
+    # After this many iterations of inactivity,
+    # _StatsManager._iter_update_thread will close itself.
+    UPDATE_THREAD_INACTIVITY_LIMIT = 5
+
+    def _start_thread_if_not_running(self):
+        # Start background update thread if not running.
+        with self._iter_update_thread_lock:
+            if (
+                self._iter_update_thread is None
+                or not self._iter_update_thread.is_alive()
+            ):
+
+                def _run_update_loop():
+                    iter_stats_inactivity = 0
+                    while True:
+                        if self._last_iter_stats or self._last_execution_stats:
+                            try:
+                                # Do not create _StatsActor if it doesn't exist because
+                                # this thread can be running even after the cluster is
+                                # shutdown. Creating an actor will automatically start
+                                # a new cluster.
+ self._stats_actor( + create_if_not_exists=False + ).update_metrics.remote( + execution_metrics=list( + self._last_execution_stats.values() + ), + iteration_metrics=list( + self._last_iter_stats.values() + ), + ) + iter_stats_inactivity = 0 + except Exception: + return + else: + iter_stats_inactivity += 1 + if ( + iter_stats_inactivity + >= _StatsManager.UPDATE_THREAD_INACTIVITY_LIMIT + ): + return + time.sleep(STATS_ACTOR_UPDATE_INTERVAL_SECONDS) + + self._iter_update_thread = threading.Thread( + target=_run_update_loop, daemon=True + ) + self._iter_update_thread.start() - _stats_actor.update_iter_metrics.remote(stats, tags_list) + # Execution methods + def update_stats_actor_metrics( + self, + op_metrics: List[OpRuntimeMetrics], + tags_list: List[Dict[str, str]], + state: Dict[str, Any], + force_update: bool = False, + ): + op_metrics_dicts = [metric.as_dict() for metric in op_metrics] + if force_update: + self._stats_actor().update_execution_metrics.remote( + op_metrics_dicts, tags_list, state + ) + else: + with self._stats_lock: + self._last_execution_stats[tags_list[0]["dataset"]] = ( + op_metrics_dicts, + tags_list, + state, + ) + self._start_thread_if_not_running() -def clear_stats_actor_metrics(tags_list: List[Dict[str, str]]): - global _stats_actor - _check_cluster_stats_actor() + def clear_stats_actor_metrics(self, tags_list: List[Dict[str, str]]): + with self._stats_lock: + if tags_list[0]["dataset"] in self._last_execution_stats: + del self._last_execution_stats[tags_list[0]["dataset"]] - _stats_actor.clear_metrics.remote(tags_list) + self._stats_actor().clear_metrics.remote(tags_list) + # Iteration methods -def clear_stats_actor_iter_metrics(tags: Dict[str, str]): - global _stats_actor - _check_cluster_stats_actor() + def update_stats_actor_iter_metrics( + self, stats: "DatasetStats", tags: Dict[str, str] + ): + with self._stats_lock: + self._last_iter_stats[tags["dataset"]] = (stats, tags) + self._start_thread_if_not_running() - _stats_actor.clear_iter_metrics.remote(tags) + def clear_stats_actor_iter_metrics(self, tags: Dict[str, str]): + with self._stats_lock: + if tags["dataset"] in self._last_iter_stats: + del self._last_iter_stats[tags["dataset"]] + self._stats_actor().clear_iter_metrics.remote(tags) -def get_dataset_id_from_stats_actor() -> str: - global _stats_actor - try: - _check_cluster_stats_actor() - return ray.get(_stats_actor.get_dataset_id.remote()) - except Exception: - # Getting dataset id from _StatsActor may fail, in this case - # fall back to uuid4 - return uuid4().hex + # Other methods + def register_dataset_to_stats_actor(self, dataset_tag, operator_tags): + self._stats_actor().register_dataset.remote(dataset_tag, operator_tags) -def register_dataset_to_stats_actor(dataset_tag: str, operator_tags: List[str]): - global _stats_actor - _check_cluster_stats_actor() - _stats_actor.register_dataset.remote(dataset_tag, operator_tags) + def get_dataset_id_from_stats_actor(self) -> str: + try: + return ray.get(self._stats_actor().get_dataset_id.remote()) + except Exception: + # Getting dataset id from _StatsActor may fail, in this case + # fall back to uuid4 + return uuid4().hex -def update_stats_actor_dataset(dataset_tag: str, state: Dict[str, Any]): - global _stats_actor - _check_cluster_stats_actor() - _stats_actor.update_dataset.remote(dataset_tag, state) +StatsManager = _StatsManager() class DatasetStats: diff --git a/python/ray/data/_internal/util.py b/python/ray/data/_internal/util.py index 759632717d36a..188c382e4212b 100644 --- 
a/python/ray/data/_internal/util.py +++ b/python/ray/data/_internal/util.py @@ -864,3 +864,10 @@ def execute_computation(thread_index: int): num_threads_alive = num_workers - num_threads_finished if num_threads_alive > 0: output_queue.release(num_threads_alive) + + +def create_dataset_tag(dataset_name: Optional[str], *args): + tag = dataset_name or "dataset" + for arg in args: + tag += f"_{arg}" + return tag diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index aa5ea9ff06b78..2343e441f60c3 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -83,11 +83,7 @@ SortStage, ZipStage, ) -from ray.data._internal.stats import ( - DatasetStats, - DatasetStatsSummary, - get_dataset_id_from_stats_actor, -) +from ray.data._internal.stats import DatasetStats, DatasetStatsSummary, StatsManager from ray.data._internal.util import ( AllToAllAPI, ConsumptionAPI, @@ -256,7 +252,7 @@ def __init__( self._current_executor: Optional["Executor"] = None self._write_ds = None - self._set_uuid(get_dataset_id_from_stats_actor()) + self._set_uuid(StatsManager.get_dataset_id_from_stats_actor()) @staticmethod def copy( diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index 3d7b06b929b5a..9e46c422c432a 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -17,7 +17,7 @@ import numpy as np from ray.data._internal.block_batching.iter_batches import iter_batches -from ray.data._internal.stats import DatasetStats +from ray.data._internal.stats import DatasetStats, StatsManager from ray.data.block import ( Block, BlockAccessor, @@ -166,7 +166,6 @@ def _create_iterator() -> Iterator[DataBatch]: iterator = iter( iter_batches( block_iterator, - dataset_tag=self._get_dataset_tag(), stats=stats, clear_block_after_read=blocks_owned_by_consumer, batch_size=batch_size, @@ -180,8 +179,11 @@ def _create_iterator() -> Iterator[DataBatch]: ) ) + metrics_tag = {"dataset": self._get_dataset_tag()} for batch in iterator: yield batch + StatsManager.update_stats_actor_iter_metrics(stats, metrics_tag) + StatsManager.clear_stats_actor_iter_metrics(metrics_tag) if stats: stats.iter_total_s.add(time.perf_counter() - time_start) @@ -847,6 +849,11 @@ def iter_epochs(self, max_epoch: int = -1) -> None: "iter_torch_batches(), or to_tf()." ) + def __del__(self): + StatsManager.clear_stats_actor_iter_metrics( + {"dataset": self._get_dataset_tag()} + ) + # Backwards compatibility alias. 
DatasetIterator = DataIterator diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py index 6cfe090355f53..fb9ef1466edae 100644 --- a/python/ray/data/tests/test_stats.py +++ b/python/ray/data/tests/test_stats.py @@ -1,4 +1,5 @@ import re +import threading import time from collections import Counter from contextlib import contextmanager @@ -13,9 +14,11 @@ from ray.data._internal.dataset_logger import DatasetLogger from ray.data._internal.stats import ( DatasetStats, + StatsManager, _get_or_create_stats_actor, _StatsActor, ) +from ray.data._internal.util import create_dataset_tag from ray.data.block import BlockMetadata from ray.data.context import DataContext from ray.data.tests.util import column_udf @@ -147,7 +150,7 @@ def enable_get_object_locations_flag(): @contextmanager def patch_update_stats_actor(): with patch( - "ray.data._internal.execution.streaming_executor.update_stats_actor_metrics" + "ray.data._internal.stats.StatsManager.update_stats_actor_metrics" ) as update_fn: yield update_fn @@ -155,8 +158,10 @@ def patch_update_stats_actor(): @contextmanager def patch_update_stats_actor_iter(): with patch( - "ray.data._internal.block_batching.iter_batches.update_stats_actor_iter_metrics" - ) as update_fn: + "ray.data._internal.stats.StatsManager.update_stats_actor_iter_metrics" + ) as update_fn, patch( + "ray.data._internal.stats.StatsManager.clear_stats_actor_iter_metrics" + ): yield update_fn @@ -1197,7 +1202,7 @@ def test_stats_actor_metrics(): assert final_metric.obj_store_mem_cur == 0 tags = update_fn.call_args_list[-1].args[1] - assert all([tag["dataset"] == "dataset" + ds._uuid for tag in tags]) + assert all([tag["dataset"] == f"dataset_{ds._uuid}" for tag in tags]) assert tags[0]["operator"] == "Input0" assert tags[1]["operator"] == "ReadRange->MapBatches()1" @@ -1223,7 +1228,7 @@ def test_stats_actor_iter_metrics(): final_stats = update_fn.call_args_list[-1].args[0] assert final_stats == ds_stats - assert ds._uuid == update_fn.call_args_list[-1].args[1]["dataset"] + assert f"dataset_{ds._uuid}" == update_fn.call_args_list[-1].args[1]["dataset"] def test_dataset_name(): @@ -1238,7 +1243,7 @@ def test_dataset_name(): with patch_update_stats_actor() as update_fn: mds = ds.materialize() - assert update_fn.call_args_list[-1].args[1][0]["dataset"] == "test_ds" + mds._uuid + assert update_fn.call_args_list[-1].args[1][0]["dataset"] == f"test_ds_{mds._uuid}" # Names persist after an execution ds = ds.random_shuffle() @@ -1246,7 +1251,7 @@ def test_dataset_name(): with patch_update_stats_actor() as update_fn: mds = ds.materialize() - assert update_fn.call_args_list[-1].args[1][0]["dataset"] == "test_ds" + mds._uuid + assert update_fn.call_args_list[-1].args[1][0]["dataset"] == f"test_ds_{mds._uuid}" ds._set_name("test_ds_two") ds = ds.map_batches(lambda x: x) @@ -1255,7 +1260,7 @@ def test_dataset_name(): mds = ds.materialize() assert ( - update_fn.call_args_list[-1].args[1][0]["dataset"] == "test_ds_two" + mds._uuid + update_fn.call_args_list[-1].args[1][0]["dataset"] == f"test_ds_two_{mds._uuid}" ) ds._set_name(None) @@ -1264,7 +1269,7 @@ def test_dataset_name(): with patch_update_stats_actor() as update_fn: mds = ds.materialize() - assert update_fn.call_args_list[-1].args[1][0]["dataset"] == "dataset" + mds._uuid + assert update_fn.call_args_list[-1].args[1][0]["dataset"] == f"dataset_{mds._uuid}" ds = ray.data.range(100, parallelism=20) ds._set_name("very_loooooooong_name") @@ -1344,6 +1349,61 @@ def test_stats_actor_datasets(ray_start_cluster): 
assert value["state"] == "FINISHED" +@patch.object(ray.data._internal.stats, "STATS_ACTOR_UPDATE_INTERVAL_SECONDS", new=0.5) +@patch.object(StatsManager, "_stats_actor_handle") +@patch.object(StatsManager, "UPDATE_THREAD_INACTIVITY_LIMIT", new=1) +def test_stats_manager(shutdown_only): + num_threads = 10 + + datasets = [None] * num_threads + # Mock clear methods so that _last_execution_stats and _last_iter_stats + # are not cleared. We will assert on them afterwards. + with patch.object(StatsManager, "clear_stats_actor_metrics"), patch.object( + StatsManager, "clear_stats_actor_iter_metrics" + ): + + def update_stats_manager(i): + datasets[i] = ray.data.range(1e6).map_batches(lambda x: x) + for _ in datasets[i].iter_batches(batch_size=100): + pass + + threads = [ + threading.Thread(target=update_stats_manager, args=(i,), daemon=True) + for i in range(num_threads) + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert len(StatsManager._last_execution_stats) == num_threads + assert len(StatsManager._last_iter_stats) == num_threads + + # Clear dataset tags manually. + for dataset in datasets: + dataset_tag = create_dataset_tag(dataset._name, dataset._uuid) + assert dataset_tag in StatsManager._last_execution_stats + assert dataset_tag in StatsManager._last_iter_stats + StatsManager.clear_stats_actor_metrics( + [ + {"dataset": dataset_tag, "operator": "Input0"}, + { + "dataset": dataset_tag, + "operator": "ReadRange->MapBatches()1", + }, + ] + ) + StatsManager.clear_stats_actor_iter_metrics({"dataset": dataset_tag}) + + wait_for_condition(lambda: not StatsManager._iter_update_thread.is_alive()) + prev_thread = StatsManager._iter_update_thread + + ray.data.range(1e6).map_batches(lambda x: x).materialize() + # Check that a new different thread is spawned. 
+ assert StatsManager._iter_update_thread != prev_thread + wait_for_condition(lambda: not StatsManager._iter_update_thread.is_alive()) + + if __name__ == "__main__": import sys From cf81dfd4c18f478db97687b7c2016aca784712a9 Mon Sep 17 00:00:00 2001 From: Andrew Xue Date: Tue, 14 Nov 2023 19:47:33 -0800 Subject: [PATCH 2/4] fix Signed-off-by: Andrew Xue --- python/ray/data/_internal/execution/streaming_executor.py | 4 +++- python/ray/data/tests/test_stats.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/ray/data/_internal/execution/streaming_executor.py b/python/ray/data/_internal/execution/streaming_executor.py index 6ad43d095d534..769e2e57a622b 100644 --- a/python/ray/data/_internal/execution/streaming_executor.py +++ b/python/ray/data/_internal/execution/streaming_executor.py @@ -121,7 +121,9 @@ def execute( self._global_info = ProgressBar("Running", dag.num_outputs_total()) self._output_node: OpState = self._topology[dag] - StatsManager.register_dataset_to_stats_actor(self._dataset_tag, [tag["operator"] for tag in self._get_metrics_tags()]) + StatsManager.register_dataset_to_stats_actor( + self._dataset_tag, [tag["operator"] for tag in self._get_metrics_tags()] + ) self.start() class StreamIterator(OutputIterator): diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py index fb9ef1466edae..e4c5a8e3e701f 100644 --- a/python/ray/data/tests/test_stats.py +++ b/python/ray/data/tests/test_stats.py @@ -1353,6 +1353,7 @@ def test_stats_actor_datasets(ray_start_cluster): @patch.object(StatsManager, "_stats_actor_handle") @patch.object(StatsManager, "UPDATE_THREAD_INACTIVITY_LIMIT", new=1) def test_stats_manager(shutdown_only): + ray.init() num_threads = 10 datasets = [None] * num_threads From 60cde64c1fdaf17f9d3ccd0bb5988a980f29c0bb Mon Sep 17 00:00:00 2001 From: Andrew Xue Date: Tue, 14 Nov 2023 21:04:59 -0800 Subject: [PATCH 3/4] comment Signed-off-by: Andrew Xue --- python/ray/data/iterator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index 9e46c422c432a..3742fac48c2a8 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -850,6 +850,7 @@ def iter_epochs(self, max_epoch: int = -1) -> None: ) def __del__(self): + # Clear metrics on deletion in case the iterator was not fully consumed. StatsManager.clear_stats_actor_iter_metrics( {"dataset": self._get_dataset_tag()} ) From 3872a6da8005ff8fd30518c65a89c28d4db3c987 Mon Sep 17 00:00:00 2001 From: Andrew Xue Date: Fri, 17 Nov 2023 13:49:23 -0800 Subject: [PATCH 4/4] fix Signed-off-by: Andrew Xue --- .../_internal/execution/streaming_executor.py | 34 +++-- python/ray/data/_internal/stats.py | 118 ++++++++++-------- python/ray/data/iterator.py | 10 +- python/ray/data/tests/test_stats.py | 68 +++++----- 4 files changed, 121 insertions(+), 109 deletions(-) diff --git a/python/ray/data/_internal/execution/streaming_executor.py b/python/ray/data/_internal/execution/streaming_executor.py index 769e2e57a622b..ed9aadd9c3364 100644 --- a/python/ray/data/_internal/execution/streaming_executor.py +++ b/python/ray/data/_internal/execution/streaming_executor.py @@ -47,6 +47,9 @@ # progress bar seeming to stall for very large scale workloads. PROGRESS_BAR_UPDATE_INTERVAL = 50 +# Interval for logging execution progress updates and operator metrics. +DEBUG_LOG_INTERVAL_SECONDS = 5 + # Visible for testing. 
_num_shutdown = 0 @@ -84,6 +87,8 @@ def __init__(self, options: ExecutionOptions, dataset_tag: str = "unknown_datase # used for marking when an op has just completed. self._has_op_completed: Optional[Dict[PhysicalOperator, bool]] = None + self._last_debug_log_time = 0 + Executor.__init__(self, options) thread_name = f"StreamingExecutor-{self._execution_id}" threading.Thread.__init__(self, daemon=True, name=thread_name) @@ -122,7 +127,8 @@ def execute( self._output_node: OpState = self._topology[dag] StatsManager.register_dataset_to_stats_actor( - self._dataset_tag, [tag["operator"] for tag in self._get_metrics_tags()] + self._dataset_tag, + self._get_operator_tags(), ) self.start() @@ -183,7 +189,9 @@ def shutdown(self, execution_completed: bool = True): ) # Clears metrics for this dataset so that they do # not persist in the grafana dashboard after execution - StatsManager.clear_stats_actor_metrics(self._get_metrics_tags()) + StatsManager.clear_execution_metrics( + self._dataset_tag, self._get_operator_tags() + ) # Freeze the stats and save it. self._final_stats = self._generate_stats() stats_summary_string = self._final_stats.to_summary().to_string( @@ -299,9 +307,11 @@ def _scheduling_loop_step(self, topology: Topology) -> bool: op_state.refresh_progress_bar() self._update_stats_metrics(state="RUNNING") - _log_op_metrics(topology) - if not DEBUG_TRACE_SCHEDULING: - _debug_dump_topology(topology, log_to_stdout=False) + if time.time() - self._last_debug_log_time >= DEBUG_LOG_INTERVAL_SECONDS: + _log_op_metrics(topology) + if not DEBUG_TRACE_SCHEDULING: + _debug_dump_topology(topology, log_to_stdout=False) + self._last_debug_log_time = time.time() # Log metrics of newly completed operators. for op in topology: @@ -353,12 +363,9 @@ def _report_current_usage( if self._global_info: self._global_info.set_description(resources_status) - def _get_metrics_tags(self): - """Returns a list of tags for operator-level metrics.""" - return [ - {"dataset": self._dataset_tag, "operator": f"{op.name}{i}"} - for i, op in enumerate(self._topology) - ] + def _get_operator_tags(self): + """Returns a list of operator tags.""" + return [f"{op.name}{i}" for i, op in enumerate(self._topology)] def _get_state_dict(self, state): last_op, last_state = list(self._topology.items())[-1] @@ -378,9 +385,10 @@ def _get_state_dict(self, state): } def _update_stats_metrics(self, state: str, force_update: bool = False): - StatsManager.update_stats_actor_metrics( + StatsManager.update_execution_metrics( + self._dataset_tag, [op.metrics for op in self._topology], - self._get_metrics_tags(), + self._get_operator_tags(), self._get_state_dict(state=state), force_update=force_update, ) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 9d7ffe1a7e0f7..eca1c8c906d4d 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -10,6 +10,7 @@ import ray from ray.data._internal.block_list import BlockList +from ray.data._internal.dataset_logger import DatasetLogger from ray.data._internal.execution.interfaces.op_runtime_metrics import OpRuntimeMetrics from ray.data._internal.util import capfirst from ray.data.block import BlockMetadata @@ -18,11 +19,11 @@ from ray.util.metrics import Gauge from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy +logger = DatasetLogger(__name__) + STATS_ACTOR_NAME = "datasets_stats_actor" STATS_ACTOR_NAMESPACE = "_dataset_stats_actor" -# Interval for making remote calls to the _StatsActor. 
-STATS_ACTOR_UPDATE_INTERVAL_SECONDS = 5 StatsDict = Dict[str, List[BlockMetadata]] @@ -256,15 +257,17 @@ def update_metrics(self, execution_metrics, iteration_metrics): for metrics in execution_metrics: self.update_execution_metrics(*metrics) for metrics in iteration_metrics: - self.update_iter_metrics(*metrics) + self.update_iteration_metrics(*metrics) def update_execution_metrics( self, + dataset_tag: str, op_metrics: List[Dict[str, Union[int, float]]], - tags_list: List[Dict[str, str]], + operator_tags: List[str], state: Dict[str, Any], ): - for stats, tags in zip(op_metrics, tags_list): + for stats, operator_tag in zip(op_metrics, operator_tags): + tags = self._create_tags(dataset_tag, operator_tag) self.bytes_spilled.set(stats.get("obj_store_mem_spilled", 0), tags) self.bytes_allocated.set(stats.get("obj_store_mem_alloc", 0), tags) self.bytes_freed.set(stats.get("obj_store_mem_freed", 0), tags) @@ -276,18 +279,20 @@ def update_execution_metrics( # This update is called from a dataset's executor, # so all tags should contain the same dataset - self.update_dataset(tags_list[0]["dataset"], state) + self.update_dataset(dataset_tag, state) - def update_iter_metrics( + def update_iteration_metrics( self, stats: "DatasetStats", - tags: Dict[str, str], + dataset_tag, ): + tags = self._create_tags(dataset_tag) self.iter_total_blocked_s.set(stats.iter_total_blocked_s.get(), tags) self.iter_user_s.set(stats.iter_user_s.get(), tags) - def clear_metrics(self, tags_list: List[Dict[str, str]]): - for tags in tags_list: + def clear_execution_metrics(self, dataset_tag: str, operator_tags: List[str]): + for operator_tag in operator_tags: + tags = self._create_tags(dataset_tag, operator_tag) self.bytes_spilled.set(0, tags) self.bytes_allocated.set(0, tags) self.bytes_freed.set(0, tags) @@ -297,7 +302,8 @@ def clear_metrics(self, tags_list: List[Dict[str, str]]): self.gpu_usage.set(0, tags) self.block_generation_time.set(0, tags) - def clear_iter_metrics(self, tags: Dict[str, str]): + def clear_iteration_metrics(self, dataset_tag: str): + tags = self._create_tags(dataset_tag) self.iter_total_blocked_s.set(0, tags) self.iter_user_s.set(0, tags) @@ -324,6 +330,12 @@ def update_dataset(self, dataset_tag, state): def get_datasets(self): return self.datasets + def _create_tags(self, dataset_tag: str, operator_tag: Optional[str] = None): + tags = {"dataset": dataset_tag} + if operator_tag is not None: + tags["operator"] = operator_tag + return tags + def _get_or_create_stats_actor(): ctx = DataContext.get_current() @@ -359,21 +371,30 @@ class _StatsManager: to this class. """ + # Interval for making remote calls to the _StatsActor. + STATS_ACTOR_UPDATE_INTERVAL_SECONDS = 5 + + # After this many iterations of inactivity, + # _StatsManager._update_thread will close itself. + UPDATE_THREAD_INACTIVITY_LIMIT = 5 + def __init__(self): # Lazily get stats actor handle to avoid circular import. 
self._stats_actor_handle = None self._stats_actor_cluster_id = None - # Last execution stats snapshots + # Last execution stats snapshots for all executing datasets self._last_execution_stats = {} - # Last iteration stats snapshots - self._last_iter_stats: Dict[str, Tuple[Dict[str, str], "DatasetStats"]] = {} + # Last iteration stats snapshots for all running iterators + self._last_iteration_stats: Dict[ + str, Tuple[Dict[str, str], "DatasetStats"] + ] = {} # Lock for updating stats snapshots self._stats_lock: threading.Lock = threading.Lock() # Background thread to make remote calls to _StatsActor - self._iter_update_thread: Optional[threading.Thread] = None - self._iter_update_thread_lock: threading.Lock = threading.Lock() + self._update_thread: Optional[threading.Thread] = None + self._update_thread_lock: threading.Lock = threading.Lock() def _stats_actor(self, create_if_not_exists=True) -> _StatsActor: if ray._private.worker._global_node is None: @@ -392,22 +413,15 @@ def _stats_actor(self, create_if_not_exists=True) -> _StatsActor: ) return self._stats_actor_handle - # After this many iterations of inactivity, - # _StatsManager._iter_update_thread will close itself. - UPDATE_THREAD_INACTIVITY_LIMIT = 5 - def _start_thread_if_not_running(self): # Start background update thread if not running. - with self._iter_update_thread_lock: - if ( - self._iter_update_thread is None - or not self._iter_update_thread.is_alive() - ): + with self._update_thread_lock: + if self._update_thread is None or not self._update_thread.is_alive(): def _run_update_loop(): iter_stats_inactivity = 0 while True: - if self._last_iter_stats or self._last_execution_stats: + if self._last_iteration_stats or self._last_execution_stats: try: # Do not create _StatsActor if it doesn't exist because # this thread can be running even after the cluster is @@ -420,11 +434,14 @@ def _run_update_loop(): self._last_execution_stats.values() ), iteration_metrics=list( - self._last_iter_stats.values() + self._last_iteration_stats.values() ), ) iter_stats_inactivity = 0 except Exception: + logger.get_logger(log_to_stdout=False).exception( + "Error occurred during remote call to _StatsActor." + ) return else: iter_stats_inactivity += 1 @@ -432,59 +449,56 @@ def _run_update_loop(): iter_stats_inactivity >= _StatsManager.UPDATE_THREAD_INACTIVITY_LIMIT ): + logger.get_logger(log_to_stdout=False).info( + "Terminating StatsManager thread due to inactivity." 
+ ) return - time.sleep(STATS_ACTOR_UPDATE_INTERVAL_SECONDS) + time.sleep(StatsManager.STATS_ACTOR_UPDATE_INTERVAL_SECONDS) - self._iter_update_thread = threading.Thread( + self._update_thread = threading.Thread( target=_run_update_loop, daemon=True ) - self._iter_update_thread.start() + self._update_thread.start() # Execution methods - def update_stats_actor_metrics( + def update_execution_metrics( self, + dataset_tag: str, op_metrics: List[OpRuntimeMetrics], - tags_list: List[Dict[str, str]], + operator_tags: List[str], state: Dict[str, Any], force_update: bool = False, ): op_metrics_dicts = [metric.as_dict() for metric in op_metrics] + args = (dataset_tag, op_metrics_dicts, operator_tags, state) if force_update: - self._stats_actor().update_execution_metrics.remote( - op_metrics_dicts, tags_list, state - ) + self._stats_actor().update_execution_metrics.remote(*args) else: with self._stats_lock: - self._last_execution_stats[tags_list[0]["dataset"]] = ( - op_metrics_dicts, - tags_list, - state, - ) + self._last_execution_stats[dataset_tag] = args self._start_thread_if_not_running() - def clear_stats_actor_metrics(self, tags_list: List[Dict[str, str]]): + def clear_execution_metrics(self, dataset_tag: str, operator_tags: List[str]): with self._stats_lock: - if tags_list[0]["dataset"] in self._last_execution_stats: - del self._last_execution_stats[tags_list[0]["dataset"]] + if dataset_tag in self._last_execution_stats: + del self._last_execution_stats[dataset_tag] - self._stats_actor().clear_metrics.remote(tags_list) + self._stats_actor().clear_execution_metrics.remote(dataset_tag, operator_tags) # Iteration methods - def update_stats_actor_iter_metrics( - self, stats: "DatasetStats", tags: Dict[str, str] - ): + def update_iteration_metrics(self, stats: "DatasetStats", dataset_tag: str): with self._stats_lock: - self._last_iter_stats[tags["dataset"]] = (stats, tags) + self._last_iteration_stats[dataset_tag] = (stats, dataset_tag) self._start_thread_if_not_running() - def clear_stats_actor_iter_metrics(self, tags: Dict[str, str]): + def clear_iteration_metrics(self, dataset_tag: str): with self._stats_lock: - if tags["dataset"] in self._last_iter_stats: - del self._last_iter_stats[tags["dataset"]] + if dataset_tag in self._last_iteration_stats: + del self._last_iteration_stats[dataset_tag] - self._stats_actor().clear_iter_metrics.remote(tags) + self._stats_actor().clear_iteration_metrics.remote(dataset_tag) # Other methods diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index 3742fac48c2a8..794b6ea0100a1 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -179,11 +179,11 @@ def _create_iterator() -> Iterator[DataBatch]: ) ) - metrics_tag = {"dataset": self._get_dataset_tag()} + dataset_tag = self._get_dataset_tag() for batch in iterator: yield batch - StatsManager.update_stats_actor_iter_metrics(stats, metrics_tag) - StatsManager.clear_stats_actor_iter_metrics(metrics_tag) + StatsManager.update_iteration_metrics(stats, dataset_tag) + StatsManager.clear_iteration_metrics(dataset_tag) if stats: stats.iter_total_s.add(time.perf_counter() - time_start) @@ -851,9 +851,7 @@ def iter_epochs(self, max_epoch: int = -1) -> None: def __del__(self): # Clear metrics on deletion in case the iterator was not fully consumed. - StatsManager.clear_stats_actor_iter_metrics( - {"dataset": self._get_dataset_tag()} - ) + StatsManager.clear_iteration_metrics(self._get_dataset_tag()) # Backwards compatibility alias. 
diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py index e4c5a8e3e701f..75e510f68c0b5 100644 --- a/python/ray/data/tests/test_stats.py +++ b/python/ray/data/tests/test_stats.py @@ -150,7 +150,7 @@ def enable_get_object_locations_flag(): @contextmanager def patch_update_stats_actor(): with patch( - "ray.data._internal.stats.StatsManager.update_stats_actor_metrics" + "ray.data._internal.stats.StatsManager.update_execution_metrics" ) as update_fn: yield update_fn @@ -158,9 +158,9 @@ def patch_update_stats_actor(): @contextmanager def patch_update_stats_actor_iter(): with patch( - "ray.data._internal.stats.StatsManager.update_stats_actor_iter_metrics" + "ray.data._internal.stats.StatsManager.update_iteration_metrics" ) as update_fn, patch( - "ray.data._internal.stats.StatsManager.clear_stats_actor_iter_metrics" + "ray.data._internal.stats.StatsManager.clear_iteration_metrics" ): yield update_fn @@ -1186,7 +1186,7 @@ def test_stats_actor_metrics(): ds = ray.data.range(1000 * 80 * 80 * 4).map_batches(lambda x: x).materialize() # last emitted metrics from map operator - final_metric = update_fn.call_args_list[-1].args[0][-1] + final_metric = update_fn.call_args_list[-1].args[1][-1] assert final_metric.obj_store_mem_spilled == ds._plan.stats().dataset_bytes_spilled assert ( @@ -1201,10 +1201,10 @@ def test_stats_actor_metrics(): # There should be nothing in object store at the end of execution. assert final_metric.obj_store_mem_cur == 0 - tags = update_fn.call_args_list[-1].args[1] - assert all([tag["dataset"] == f"dataset_{ds._uuid}" for tag in tags]) - assert tags[0]["operator"] == "Input0" - assert tags[1]["operator"] == "ReadRange->MapBatches()1" + args = update_fn.call_args_list[-1].args + assert args[0] == f"dataset_{ds._uuid}" + assert args[2][0] == "Input0" + assert args[2][1] == "ReadRange->MapBatches()1" def sleep_three(x): import time @@ -1215,7 +1215,7 @@ def sleep_three(x): with patch_update_stats_actor() as update_fn: ds = ray.data.range(3).map_batches(sleep_three, batch_size=1).materialize() - final_metric = update_fn.call_args_list[-1].args[0][-1] + final_metric = update_fn.call_args_list[-1].args[1][-1] assert final_metric.block_generation_time >= 9 @@ -1228,7 +1228,7 @@ def test_stats_actor_iter_metrics(): final_stats = update_fn.call_args_list[-1].args[0] assert final_stats == ds_stats - assert f"dataset_{ds._uuid}" == update_fn.call_args_list[-1].args[1]["dataset"] + assert f"dataset_{ds._uuid}" == update_fn.call_args_list[-1].args[1] def test_dataset_name(): @@ -1243,7 +1243,7 @@ def test_dataset_name(): with patch_update_stats_actor() as update_fn: mds = ds.materialize() - assert update_fn.call_args_list[-1].args[1][0]["dataset"] == f"test_ds_{mds._uuid}" + assert update_fn.call_args_list[-1].args[0] == f"test_ds_{mds._uuid}" # Names persist after an execution ds = ds.random_shuffle() @@ -1251,7 +1251,7 @@ def test_dataset_name(): with patch_update_stats_actor() as update_fn: mds = ds.materialize() - assert update_fn.call_args_list[-1].args[1][0]["dataset"] == f"test_ds_{mds._uuid}" + assert update_fn.call_args_list[-1].args[0] == f"test_ds_{mds._uuid}" ds._set_name("test_ds_two") ds = ds.map_batches(lambda x: x) @@ -1259,9 +1259,7 @@ def test_dataset_name(): with patch_update_stats_actor() as update_fn: mds = ds.materialize() - assert ( - update_fn.call_args_list[-1].args[1][0]["dataset"] == f"test_ds_two_{mds._uuid}" - ) + assert update_fn.call_args_list[-1].args[0] == f"test_ds_two_{mds._uuid}" ds._set_name(None) ds = 
ds.map_batches(lambda x: x) @@ -1269,7 +1267,7 @@ def test_dataset_name(): with patch_update_stats_actor() as update_fn: mds = ds.materialize() - assert update_fn.call_args_list[-1].args[1][0]["dataset"] == f"dataset_{mds._uuid}" + assert update_fn.call_args_list[-1].args[0] == f"dataset_{mds._uuid}" ds = ray.data.range(100, parallelism=20) ds._set_name("very_loooooooong_name") @@ -1349,7 +1347,7 @@ def test_stats_actor_datasets(ray_start_cluster): assert value["state"] == "FINISHED" -@patch.object(ray.data._internal.stats, "STATS_ACTOR_UPDATE_INTERVAL_SECONDS", new=0.5) +@patch.object(StatsManager, "STATS_ACTOR_UPDATE_INTERVAL_SECONDS", new=0.5) @patch.object(StatsManager, "_stats_actor_handle") @patch.object(StatsManager, "UPDATE_THREAD_INACTIVITY_LIMIT", new=1) def test_stats_manager(shutdown_only): @@ -1357,15 +1355,15 @@ def test_stats_manager(shutdown_only): num_threads = 10 datasets = [None] * num_threads - # Mock clear methods so that _last_execution_stats and _last_iter_stats + # Mock clear methods so that _last_execution_stats and _last_iteration_stats # are not cleared. We will assert on them afterwards. - with patch.object(StatsManager, "clear_stats_actor_metrics"), patch.object( - StatsManager, "clear_stats_actor_iter_metrics" + with patch.object(StatsManager, "clear_execution_metrics"), patch.object( + StatsManager, "clear_iteration_metrics" ): def update_stats_manager(i): - datasets[i] = ray.data.range(1e6).map_batches(lambda x: x) - for _ in datasets[i].iter_batches(batch_size=100): + datasets[i] = ray.data.range(10).map_batches(lambda x: x) + for _ in datasets[i].iter_batches(batch_size=1): pass threads = [ @@ -1378,31 +1376,25 @@ def update_stats_manager(i): thread.join() assert len(StatsManager._last_execution_stats) == num_threads - assert len(StatsManager._last_iter_stats) == num_threads + assert len(StatsManager._last_iteration_stats) == num_threads # Clear dataset tags manually. for dataset in datasets: dataset_tag = create_dataset_tag(dataset._name, dataset._uuid) assert dataset_tag in StatsManager._last_execution_stats - assert dataset_tag in StatsManager._last_iter_stats - StatsManager.clear_stats_actor_metrics( - [ - {"dataset": dataset_tag, "operator": "Input0"}, - { - "dataset": dataset_tag, - "operator": "ReadRange->MapBatches()1", - }, - ] + assert dataset_tag in StatsManager._last_iteration_stats + StatsManager.clear_execution_metrics( + dataset_tag, ["Input0", "ReadRange->MapBatches()1"] ) - StatsManager.clear_stats_actor_iter_metrics({"dataset": dataset_tag}) + StatsManager.clear_iteration_metrics(dataset_tag) - wait_for_condition(lambda: not StatsManager._iter_update_thread.is_alive()) - prev_thread = StatsManager._iter_update_thread + wait_for_condition(lambda: not StatsManager._update_thread.is_alive()) + prev_thread = StatsManager._update_thread - ray.data.range(1e6).map_batches(lambda x: x).materialize() + ray.data.range(100).map_batches(lambda x: x).materialize() # Check that a new different thread is spawned. - assert StatsManager._iter_update_thread != prev_thread - wait_for_condition(lambda: not StatsManager._iter_update_thread.is_alive()) + assert StatsManager._update_thread != prev_thread + wait_for_condition(lambda: not StatsManager._update_thread.is_alive()) if __name__ == "__main__":