@@ -104,8 +119,19 @@ export const JobDetailChartsPage = () => {
+ {data?.datasets && data.datasets.length > 0 && (
+
+
+
+ )}
+
diff --git a/dashboard/client/src/pages/job/TaskProgressBar.tsx b/dashboard/client/src/pages/job/TaskProgressBar.tsx
index 9497b64704c48..0b2039fb18561 100644
--- a/dashboard/client/src/pages/job/TaskProgressBar.tsx
+++ b/dashboard/client/src/pages/job/TaskProgressBar.tsx
@@ -20,6 +20,7 @@ export const TaskProgressBar = ({
numPendingNodeAssignment = 0,
numSubmittedToWorker = 0,
numFailed = 0,
+ numCancelled = 0,
numUnknown = 0,
showAsComplete = false,
showTooltip = true,
@@ -55,6 +56,11 @@ export const TaskProgressBar = ({
value: numPendingArgsAvail,
color: "#f79e02",
},
+ {
+ label: "Cancelled",
+ value: numCancelled,
+ color: theme.palette.grey.A100,
+ },
{
label: "Unknown",
value: numUnknown,
diff --git a/dashboard/client/src/service/data.ts b/dashboard/client/src/service/data.ts
new file mode 100644
index 0000000000000..66169392b9b35
--- /dev/null
+++ b/dashboard/client/src/service/data.ts
@@ -0,0 +1,6 @@
+import { DatasetResponse } from "../type/data";
+import { get } from "./requestHandlers";
+
+export const getDataDatasets = () => {
+ return get<DatasetResponse>("api/data/datasets");
+};
diff --git a/dashboard/client/src/type/data.ts b/dashboard/client/src/type/data.ts
new file mode 100644
index 0000000000000..e02fee5718ce6
--- /dev/null
+++ b/dashboard/client/src/type/data.ts
@@ -0,0 +1,22 @@
+export type DatasetResponse = {
+ datasets: DatasetMetrics[];
+};
+
+export type DatasetMetrics = {
+ dataset: string;
+ state: string;
+ ray_data_current_bytes: {
+ value: number;
+ max: number;
+ };
+ ray_data_output_bytes: {
+ max: number;
+ };
+ ray_data_spilled_bytes: {
+ max: number;
+ };
+ progress: number;
+ total: number;
+ start_time: number;
+ end_time: number | undefined;
+};
diff --git a/dashboard/client/src/type/job.ts b/dashboard/client/src/type/job.ts
index 377bea9a7d9ee..cccb797804633 100644
--- a/dashboard/client/src/type/job.ts
+++ b/dashboard/client/src/type/job.ts
@@ -97,6 +97,7 @@ export type TaskProgress = {
numRunning?: number;
numPendingNodeAssignment?: number;
numFailed?: number;
+ numCancelled?: number;
numUnknown?: number;
};
diff --git a/dashboard/modules/data/__init__.py b/dashboard/modules/data/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/dashboard/modules/data/data_head.py b/dashboard/modules/data/data_head.py
new file mode 100644
index 0000000000000..b846d61accb55
--- /dev/null
+++ b/dashboard/modules/data/data_head.py
@@ -0,0 +1,104 @@
+import json
+import os
+from enum import Enum
+import aiohttp
+from aiohttp.web import Request, Response
+import ray.dashboard.optional_utils as optional_utils
+import ray.dashboard.utils as dashboard_utils
+from ray.dashboard.modules.metrics.metrics_head import (
+ PROMETHEUS_HOST_ENV_VAR,
+ DEFAULT_PROMETHEUS_HOST,
+ PrometheusQueryError,
+)
+from urllib.parse import quote
+import ray
+
+
+MAX_TIME_WINDOW = "1h"
+SAMPLE_RATE = "1s"
+
+
+class PrometheusQuery(Enum):
+ VALUE = ("value", "sum({}) by (dataset)")
+ MAX = (
+ "max",
+ "max_over_time(sum({}) by (dataset)[" + f"{MAX_TIME_WINDOW}:{SAMPLE_RATE}])",
+ )
+
+
+DATASET_METRICS = {
+ "ray_data_output_bytes": (PrometheusQuery.MAX,),
+ "ray_data_spilled_bytes": (PrometheusQuery.MAX,),
+ "ray_data_current_bytes": (PrometheusQuery.VALUE, PrometheusQuery.MAX),
+}
+
+
+class DataHead(dashboard_utils.DashboardHeadModule):
+ def __init__(self, dashboard_head):
+ super().__init__(dashboard_head)
+ self.http_session = aiohttp.ClientSession()
+ self.prometheus_host = os.environ.get(
+ PROMETHEUS_HOST_ENV_VAR, DEFAULT_PROMETHEUS_HOST
+ )
+
+ @optional_utils.DashboardHeadRouteTable.get("/api/data/datasets")
+ @optional_utils.init_ray_and_catch_exceptions()
+ async def get_datasets(self, req: Request) -> Response:
+ try:
+ from ray.data._internal.stats import _get_or_create_stats_actor
+
+ _stats_actor = _get_or_create_stats_actor()
+ datasets = ray.get(_stats_actor.get_datasets.remote())
+ # Initializes dataset metric values
+ for dataset in datasets:
+ for metric, queries in DATASET_METRICS.items():
+ datasets[dataset][metric] = {query.value[0]: 0 for query in queries}
+ # Query dataset metric values from prometheus
+ try:
+ # TODO (Zandew): store results of completed datasets in stats actor.
+ for metric, queries in DATASET_METRICS.items():
+ for query in queries:
+ result = await self._query_prometheus(
+ query.value[1].format(metric)
+ )
+ for res in result["data"]["result"]:
+ dataset, value = res["metric"]["dataset"], res["value"][1]
+ if dataset in datasets:
+ datasets[dataset][metric][query.value[0]] = value
+ except Exception:
+ # Prometheus server may not be running,
+ # leave these values blank and return other data
+ pass
+ # Flatten response
+ datasets = list(
+ map(lambda item: {"dataset": item[0], **item[1]}, datasets.items())
+ )
+ # Sort by descending start time
+ datasets = sorted(datasets, key=lambda x: x["start_time"], reverse=True)
+ return Response(
+ text=json.dumps({"datasets": datasets}),
+ content_type="application/json",
+ )
+ except Exception as e:
+ return Response(
+ status=503,
+ text=str(e),
+ )
+
+ async def run(self, server):
+ pass
+
+ @staticmethod
+ def is_minimal_module():
+ return False
+
+ async def _query_prometheus(self, query):
+ async with self.http_session.get(
+ f"{self.prometheus_host}/api/v1/query?query={quote(query)}"
+ ) as resp:
+ if resp.status == 200:
+ prom_data = await resp.json()
+ return prom_data
+
+ message = await resp.text()
+ raise PrometheusQueryError(resp.status, message)
diff --git a/dashboard/modules/data/tests/test_data_head.py b/dashboard/modules/data/tests/test_data_head.py
new file mode 100644
index 0000000000000..6b44f73f2fcff
--- /dev/null
+++ b/dashboard/modules/data/tests/test_data_head.py
@@ -0,0 +1,48 @@
+import ray
+import requests
+import sys
+import pytest
+
+DATA_HEAD_URLS = {"GET": "http://localhost:8265/api/data/datasets"}
+
+RESPONSE_SCHEMA = [
+ "dataset",
+ "state",
+ "progress",
+ "start_time",
+ "end_time",
+ "total",
+ "ray_data_output_bytes",
+ "ray_data_spilled_bytes",
+ "ray_data_current_bytes",
+]
+
+
+def test_get_datasets():
+ ray.init()
+ ds = ray.data.range(1).map_batches(lambda x: x)
+ ds._set_name("data_head_test")
+ ds.materialize()
+
+ data = requests.get(DATA_HEAD_URLS["GET"]).json()
+
+ assert len(data["datasets"]) == 1
+ assert sorted(data["datasets"][0].keys()) == sorted(RESPONSE_SCHEMA)
+
+ dataset = data["datasets"][0]
+ assert dataset["dataset"].startswith("data_head_test")
+ assert dataset["state"] == "FINISHED"
+ assert dataset["end_time"] is not None
+
+ ds.map_batches(lambda x: x).materialize()
+ data = requests.get(DATA_HEAD_URLS["GET"]).json()
+
+ assert len(data["datasets"]) == 2
+ dataset = data["datasets"][1]
+ assert dataset["dataset"].startswith("data_head_test")
+ assert dataset["state"] == "FINISHED"
+ assert dataset["end_time"] is not None
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-v", __file__]))
diff --git a/python/ray/data/_internal/execution/streaming_executor.py b/python/ray/data/_internal/execution/streaming_executor.py
index 813b45d663627..1d786d051105d 100644
--- a/python/ray/data/_internal/execution/streaming_executor.py
+++ b/python/ray/data/_internal/execution/streaming_executor.py
@@ -37,6 +37,8 @@
from ray.data._internal.stats import (
DatasetStats,
clear_stats_actor_metrics,
+ register_dataset_to_stats_actor,
+ update_stats_actor_dataset,
update_stats_actor_metrics,
)
from ray.data.context import DataContext
@@ -125,6 +127,7 @@ def execute(
self._global_info = ProgressBar("Running", dag.num_outputs_total())
self._output_node: OpState = self._topology[dag]
+ register_dataset_to_stats_actor(self._dataset_tag)
self.start()
class StreamIterator(OutputIterator):
@@ -154,8 +157,8 @@ def get_next(self, output_split_idx: Optional[int] = None) -> RefBundle:
return item
# Needs to be BaseException to catch KeyboardInterrupt. Otherwise we
# can leave dangling progress bars by skipping shutdown.
- except BaseException:
- self._outer.shutdown()
+ except BaseException as e:
+ self._outer.shutdown(isinstance(e, StopIteration))
raise
def __del__(self):
@@ -166,7 +169,7 @@ def __del__(self):
def __del__(self):
self.shutdown()
- def shutdown(self):
+ def shutdown(self, execution_completed: bool = True):
context = DataContext.get_current()
global _num_shutdown
@@ -174,6 +177,13 @@ def shutdown(self):
if self._shutdown:
return
logger.get_logger().debug(f"Shutting down {self}.")
+ update_stats_actor_dataset(
+ self._dataset_tag,
+ {
+ "state": "FINISHED" if execution_completed else "FAILED",
+ "end_time": time.time(),
+ },
+ )
_num_shutdown += 1
self._shutdown = True
# Give the scheduling loop some time to finish processing.
@@ -298,9 +308,15 @@ def _scheduling_loop_step(self, topology: Topology) -> bool:
if not DEBUG_TRACE_SCHEDULING:
_debug_dump_topology(topology, log_to_stdout=False)
+ last_op, last_state = list(topology.items())[-1]
update_stats_actor_metrics(
[op.metrics for op in self._topology],
self._get_metrics_tags(),
+ # TODO (Zandew): report progress at operator level
+ {
+ "progress": last_state.num_completed_tasks,
+ "total": last_op.num_outputs_total(),
+ },
)
# Log metrics of newly completed operators.
diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py
index 033f73affbf1e..e37f0e5841f21 100644
--- a/python/ray/data/_internal/stats.py
+++ b/python/ray/data/_internal/stats.py
@@ -146,6 +146,8 @@ def __init__(self, max_stats=1000):
# Assign dataset uuids with a global counter.
self.next_dataset_id = 0
+ # Dataset metadata to be queried directly by DashboardHead api.
+ self.datasets: Dict[str, Any] = {}
# Ray Data dashboard metrics
# Everything is a gauge because we need to reset all of
@@ -250,6 +252,7 @@ def update_metrics(
self,
op_metrics: List[Dict[str, Union[int, float]]],
tags_list: List[Dict[str, str]],
+ state: Dict[str, Any],
):
for stats, tags in zip(op_metrics, tags_list):
self.bytes_spilled.set(stats.get("obj_store_mem_spilled", 0), tags)
@@ -261,6 +264,8 @@ def update_metrics(
self.gpu_usage.set(stats.get("gpu_usage", 0), tags)
self.block_generation_time.set(stats.get("block_generation_time", 0), tags)
+ self.update_dataset(tags["dataset"], state)
+
def update_iter_metrics(self, stats: "DatasetStats", tags):
self.iter_total_blocked_s.set(stats.iter_total_blocked_s.get(), tags)
self.iter_user_s.set(stats.iter_user_s.get(), tags)
@@ -280,6 +285,21 @@ def clear_iter_metrics(self, tags: Dict[str, str]):
self.iter_total_blocked_s.set(0, tags)
self.iter_user_s.set(0, tags)
+ def register_dataset(self, dataset_tag):
+ self.datasets[dataset_tag] = {
+ "state": "RUNNING",
+ "progress": 0,
+ "total": 0,
+ "start_time": time.time(),
+ "end_time": None,
+ }
+
+ def update_dataset(self, dataset_tag, state):
+ self.datasets[dataset_tag].update(state)
+
+ def get_datasets(self):
+ return self.datasets
+
def _get_or_create_stats_actor():
ctx = DataContext.get_current()
@@ -320,13 +340,15 @@ def _check_cluster_stats_actor():
def update_stats_actor_metrics(
- op_metrics: List[OpRuntimeMetrics], tags_list: List[Dict[str, str]]
+ op_metrics: List[OpRuntimeMetrics],
+ tags_list: List[Dict[str, str]],
+ state: Dict[str, Any],
):
global _stats_actor
_check_cluster_stats_actor()
_stats_actor.update_metrics.remote(
- [metric.as_dict() for metric in op_metrics], tags_list
+ [metric.as_dict() for metric in op_metrics], tags_list, state
)
@@ -362,6 +384,18 @@ def get_dataset_id_from_stats_actor() -> str:
return uuid4().hex
+def register_dataset_to_stats_actor(dataset_tag):
+ global _stats_actor
+ _check_cluster_stats_actor()
+ _stats_actor.register_dataset.remote(dataset_tag)
+
+
+def update_stats_actor_dataset(dataset_tag, state):
+ global _stats_actor
+ _check_cluster_stats_actor()
+ _stats_actor.update_dataset.remote(dataset_tag, state)
+
+
class DatasetStats:
"""Holds the execution times for a given Dataset.
diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py
index 0e9c0ec4a54f8..53ce56881e5ed 100644
--- a/python/ray/data/tests/test_stats.py
+++ b/python/ray/data/tests/test_stats.py
@@ -11,7 +11,11 @@
import ray
from ray._private.test_utils import wait_for_condition
from ray.data._internal.dataset_logger import DatasetLogger
-from ray.data._internal.stats import DatasetStats, _StatsActor
+from ray.data._internal.stats import (
+ DatasetStats,
+ _get_or_create_stats_actor,
+ _StatsActor,
+)
from ray.data.block import BlockMetadata
from ray.data.context import DataContext
from ray.data.tests.util import column_udf
@@ -1314,6 +1318,23 @@ def test_op_state_logging():
assert times_asserted > 0
+def test_stats_actor_datasets(ray_start_cluster):
+ ds = ray.data.range(100, parallelism=20).map_batches(lambda x: x)
+ ds._set_name("test_stats_actor_datasets")
+ ds.materialize()
+ stats_actor = _get_or_create_stats_actor()
+
+ datasets = ray.get(stats_actor.get_datasets.remote())
+ dataset_name = list(filter(lambda x: x.startswith(ds._name), datasets))
+ assert len(dataset_name) == 1
+ dataset_name = dataset_name[0]
+
+ assert datasets[dataset_name]["state"] == "FINISHED"
+ assert datasets[dataset_name]["progress"] == 20
+ assert datasets[dataset_name]["total"] == 20
+ assert datasets[dataset_name]["end_time"] is not None
+
+
if __name__ == "__main__":
import sys