Skip to content

Commit

Permalink
[data] store bytes spilled/restored after plan execution (#39361)
Browse files Browse the repository at this point in the history
Adds bytes spilled/restored to DatasetStats after a plan finishes execution.

Refactors internal_api a bit so that we can get the actual reply instead of a prettified string.

First part of #38847. Next steps would be to report spilling of individual blocks from ray.

---------

Signed-off-by: Andrew Xue <andrewxue@anyscale.com>
  • Loading branch information
Zandew committed Sep 13, 2023
1 parent 8d80377 commit ed32450
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 16 deletions.
25 changes: 16 additions & 9 deletions python/ray/_private/internal_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ def global_gc():
worker.core_worker.global_gc()


def get_state_from_address(address=None):
address = services.canonicalize_bootstrap_address_or_die(address)

state = GlobalState()
options = GcsClientOptions.from_gcs_address(address)
state._initialize_global_state(options)
return state


def memory_summary(
address=None,
redis_password=ray_constants.REDIS_DEFAULT_PASSWORD,
Expand All @@ -30,20 +39,18 @@ def memory_summary(
):
from ray.dashboard.memory_utils import memory_summary

address = services.canonicalize_bootstrap_address_or_die(address)
state = get_state_from_address(address)
reply = get_memory_info_reply(state)

state = GlobalState()
options = GcsClientOptions.from_gcs_address(address)
state._initialize_global_state(options)
if stats_only:
return get_store_stats(state)
return store_stats_summary(reply)
return memory_summary(
state, group_by, sort_by, line_wrap, units, num_entries
) + get_store_stats(state)
) + store_stats_summary(reply)


def get_store_stats(state, node_manager_address=None, node_manager_port=None):
"""Returns a formatted string describing memory usage in the cluster."""
def get_memory_info_reply(state, node_manager_address=None, node_manager_port=None):
"""Returns global memory info."""

from ray.core.generated import node_manager_pb2, node_manager_pb2_grpc

Expand Down Expand Up @@ -76,7 +83,7 @@ def get_store_stats(state, node_manager_address=None, node_manager_port=None):
node_manager_pb2.FormatGlobalMemoryInfoRequest(include_memory_info=False),
timeout=60.0,
)
return store_stats_summary(reply)
return reply


def node_stats(
Expand Down
12 changes: 12 additions & 0 deletions python/ray/data/_internal/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)

import ray
from ray._private.internal_api import get_memory_info_reply, get_state_from_address
from ray.data._internal.block_list import BlockList
from ray.data._internal.compute import (
ActorPoolStrategy,
Expand Down Expand Up @@ -635,6 +636,17 @@ def execute(
stats_summary_string,
)

# Retrieve memory-related stats from ray.
reply = get_memory_info_reply(
get_state_from_address(ray.get_runtime_context().gcs_address)
)
if reply.store_stats.spill_time_total_s > 0:
stats.global_bytes_spilled = int(reply.store_stats.spilled_bytes_total)
if reply.store_stats.restore_time_total_s > 0:
stats.global_bytes_restored = int(
reply.store_stats.restored_bytes_total
)

# Set the snapshot to the output of the final stage.
self._snapshot_blocks = blocks
self._snapshot_stats = stats
Expand Down
29 changes: 27 additions & 2 deletions python/ray/data/_internal/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,10 @@ def __init__(
self.iter_blocks_remote: int = 0
self.iter_unknown_location: int = 0

# Memory usage stats
self.global_bytes_spilled: int = 0
self.global_bytes_restored: int = 0

@property
def stats_actor(self):
return _get_or_create_stats_actor()
Expand Down Expand Up @@ -336,6 +340,8 @@ def to_summary(self) -> "DatasetStatsSummary":
self.time_total_s,
self.base_name,
self.extra_metrics,
self.global_bytes_spilled,
self.global_bytes_restored,
)


Expand All @@ -350,9 +356,14 @@ class DatasetStatsSummary:
time_total_s: float
base_name: str
extra_metrics: Dict[str, Any]
global_bytes_spilled: int
global_bytes_restored: int

def to_string(
self, already_printed: Optional[Set[str]] = None, include_parent: bool = True
self,
already_printed: Optional[Set[str]] = None,
include_parent: bool = True,
add_global_stats=True,
) -> str:
"""Return a human-readable summary of this Dataset's stats.
Expand All @@ -370,7 +381,7 @@ def to_string(
out = ""
if self.parents and include_parent:
for p in self.parents:
parent_sum = p.to_string(already_printed)
parent_sum = p.to_string(already_printed, add_global_stats=False)
if parent_sum:
out += parent_sum
out += "\n"
Expand Down Expand Up @@ -407,6 +418,18 @@ def to_string(
out += indent
out += "* Extra metrics: " + str(self.extra_metrics) + "\n"
out += str(self.iter_stats)

mb_spilled = round(self.global_bytes_spilled / 1e6)
mb_restored = round(self.global_bytes_restored / 1e6)
if (
len(self.stages_stats) > 0
and add_global_stats
and (mb_spilled or mb_restored)
):
out += "\nCluster memory:\n"
out += "* Spilled to disk: {}MB\n".format(mb_spilled)
out += "* Restored from disk: {}MB\n".format(mb_restored)

return out

def __repr__(self, level=0) -> str:
Expand All @@ -430,6 +453,8 @@ def __repr__(self, level=0) -> str:
f"{indent} extra_metrics={{{extra_metrics}}},\n"
f"{indent} stage_stats=[{stage_stats}],\n"
f"{indent} iter_stats={self.iter_stats.__repr__(level+1)},\n"
f"{indent} global_bytes_spilled={self.global_bytes_spilled / 1e6}MB,\n"
f"{indent} global_bytes_restored={self.global_bytes_restored / 1e6}MB,\n"
f"{indent} parents=[{parent_stats}],\n"
f"{indent})"
)
Expand Down
28 changes: 27 additions & 1 deletion python/ray/data/tests/test_object_gc.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def test_iter_batches_no_spilling_upon_no_transformation(shutdown_only):
ctx = ray.init(num_cpus=1, object_store_memory=300e6)
# The size of dataset is 500*(80*80*4)*8B, about 100MB.
ds = ray.data.range_tensor(500, shape=(80, 80, 4), parallelism=100)

check_no_spill(ctx, ds.repeat())
check_no_spill(ctx, ds.window(blocks_per_window=20))

Expand Down Expand Up @@ -235,6 +234,33 @@ def consume(p):
assert "Spilled" not in meminfo, meminfo


def test_global_bytes_spilled(shutdown_only):
# The object store is about 90MB.
ctx = ray.init(object_store_memory=90e6)
# The size of dataset is 500*(80*80*4)*8B, about 100MB.
ds = ray.data.range_tensor(500, shape=(80, 80, 4), parallelism=100).materialize()

with pytest.raises(AssertionError):
check_no_spill(ctx, ds.repeat())
assert ds._get_stats_summary().global_bytes_spilled > 0
assert ds._get_stats_summary().global_bytes_restored > 0

assert "Spilled to disk:" in ds.stats()


def test_no_global_bytes_spilled(shutdown_only):
# The object store is about 200MB.
ctx = ray.init(object_store_memory=200e6)
# The size of dataset is 500*(80*80*4)*8B, about 100MB.
ds = ray.data.range_tensor(500, shape=(80, 80, 4), parallelism=100).materialize()

check_no_spill(ctx, ds.repeat())
assert ds._get_stats_summary().global_bytes_spilled == 0
assert ds._get_stats_summary().global_bytes_restored == 0

assert "Cluster memory:" not in ds.stats()


if __name__ == "__main__":
import sys

Expand Down
16 changes: 12 additions & 4 deletions python/ray/data/tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@ def canonicalize(stats: str) -> str:
s0 = re.sub("([a-f\d]{32})", "U", stats)
# Time expressions.
s1 = re.sub("[0-9\.]+(ms|us|s)", "T", s0)
# Memory expressions.
s2 = re.sub("[0-9\.]+(B|MB|GB)", "M", s1)
# Handle zero values specially so we can check for missing values.
s2 = re.sub(" [0]+(\.[0]+)?", " Z", s1)
s3 = re.sub(" [0]+(\.[0]+)?", " Z", s2)
# Other numerics.
s3 = re.sub("[0-9]+(\.[0-9]+)?", "N", s2)
s4 = re.sub("[0-9]+(\.[0-9]+)?", "N", s3)
# Replace tabs with spaces.
s4 = re.sub("\t", " ", s3)
return s4
s5 = re.sub("\t", " ", s4)
return s5


def dummy_map_batches(x):
Expand Down Expand Up @@ -416,6 +418,8 @@ def test_dataset__repr__(ray_start_regular_shared):
" user_time=T,\n"
" total_time=T,\n"
" ),\n"
" global_bytes_spilled=M,\n"
" global_bytes_restored=M,\n"
" parents=[],\n"
")"
)
Expand Down Expand Up @@ -474,6 +478,8 @@ def check_stats():
" user_time=T,\n"
" total_time=T,\n"
" ),\n"
" global_bytes_spilled=M,\n"
" global_bytes_restored=M,\n"
" parents=[\n"
" DatasetStatsSummary(\n"
" dataset_uuid=U,\n"
Expand Down Expand Up @@ -505,6 +511,8 @@ def check_stats():
" user_time=T,\n"
" total_time=T,\n"
" ),\n"
" global_bytes_spilled=M,\n"
" global_bytes_restored=M,\n"
" parents=[],\n"
" ),\n"
" ],\n"
Expand Down

0 comments on commit ed32450

Please sign in to comment.