Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 26 additions & 21 deletions python/ray/_private/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,8 +654,29 @@ def matches(self, sample: Sample):
return True


@dataclass
class PrometheusTimeseries:
    """Timeseries state collected from multiple addresses.

    A timeseries is a collection of samples that share a metric name and
    labels. The state is kept in three mappings:
    - components_dict: addresses mapped to their Component labels
    - metric_descriptors: metric names mapped to the Metric object
    - metric_samples: the latest value of each label
    """

    components_dict: Dict[str, Set[str]] = field(default_factory=dict)
    metric_descriptors: Dict[str, Metric] = field(default_factory=dict)
    metric_samples: Dict[frozenset, Sample] = field(default_factory=dict)

    def flush(self):
        """Empty all three mappings in place."""
        for store in (
            self.components_dict,
            self.metric_descriptors,
            self.metric_samples,
        ):
            store.clear()


def get_metric_check_condition(
metrics_to_check: List[MetricSamplePattern], export_addr: Optional[str] = None
metrics_to_check: List[MetricSamplePattern],
timeseries: PrometheusTimeseries,
export_addr: Optional[str] = None,
) -> Callable[[], bool]:
"""A condition to check if a prometheus metrics reach a certain value.

Expand All @@ -665,6 +686,7 @@ def get_metric_check_condition(
Args:
metrics_to_check: A list of MetricSamplePattern. The fields that
aren't `None` will be matched.
timeseries: A PrometheusTimeseries object to store the metrics.
export_addr: Optional address to export metrics to.

Returns:
Expand All @@ -677,7 +699,9 @@ def get_metric_check_condition(

def f():
for metric_pattern in metrics_to_check:
_, _, metric_samples = fetch_prometheus([prom_addr])
metric_samples = fetch_prometheus_timeseries(
[prom_addr], timeseries
).metric_samples.values()
for metric_sample in metric_samples:
if metric_pattern.matches(metric_sample):
break
Expand Down Expand Up @@ -993,25 +1017,6 @@ def fetch_prometheus(prom_addresses):
return components_dict, metric_descriptors, metric_samples


@dataclass
class PrometheusTimeseries:
    """A collection of timeseries from multiple addresses.

    Each timeseries is a collection of samples with the same metric name and
    labels. Concretely:
    - components_dict: a dictionary of addresses to the Component labels
    - metric_descriptors: a dictionary of metric names to the Metric object
    - metric_samples: the latest value of each label
    """

    # Plain dicts are used as the default containers. A defaultdict created
    # without a default_factory still raises KeyError on missing keys, so it
    # offers nothing over dict and only obscures the intent.
    components_dict: Dict[str, Set[str]] = field(default_factory=dict)
    metric_descriptors: Dict[str, Metric] = field(default_factory=dict)
    metric_samples: Dict[frozenset, Sample] = field(default_factory=dict)

    def flush(self):
        """Clear all three mappings in place so the instance can be reused."""
        self.components_dict.clear()
        self.metric_descriptors.clear()
        self.metric_samples.clear()


def fetch_prometheus_timeseries(
prom_addreses: List[str],
result: PrometheusTimeseries,
Expand Down
4 changes: 4 additions & 0 deletions python/ray/tests/test_autoscaler_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ray._common.test_utils import SignalActor, wait_for_condition
from ray._private.test_utils import (
MetricSamplePattern,
PrometheusTimeseries,
get_metric_check_condition,
)
from ray.autoscaler._private.constants import AUTOSCALER_METRIC_PORT
Expand Down Expand Up @@ -174,6 +175,7 @@ class Foo:
def ping(self):
return True

timeseries = PrometheusTimeseries()
zero_reported_condition = get_metric_check_condition(
[
MetricSamplePattern(
Expand All @@ -199,6 +201,7 @@ def ping(self):
partial_label_match={"NodeType": "ray.head.default"},
),
],
timeseries,
export_addr=autoscaler_export_addr,
)
wait_for_condition(zero_reported_condition)
Expand Down Expand Up @@ -239,6 +242,7 @@ def ping(self):
partial_label_match={"NodeType": "ray.head.default"},
),
],
timeseries,
export_addr=autoscaler_export_addr,
)
wait_for_condition(two_cpu_no_pending_condition)
Expand Down
10 changes: 8 additions & 2 deletions python/ray/tests/test_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ray._private.internal_api import memory_summary
from ray._private.test_utils import (
MetricSamplePattern,
PrometheusTimeseries,
get_metric_check_condition,
object_memory_usage,
)
Expand Down Expand Up @@ -673,13 +674,18 @@ def start_infeasible(n):
# longer timeout is necessary to pass on windows debug/asan builds.
timeout = 180

timeseries = PrometheusTimeseries()
wait_for_condition(
get_metric_check_condition([MetricSamplePattern(name=metric_name, value=2)]),
get_metric_check_condition(
[MetricSamplePattern(name=metric_name, value=2)], timeseries
),
timeout=timeout,
)
start_infeasible.remote(2)
wait_for_condition(
get_metric_check_condition([MetricSamplePattern(name=metric_name, value=3)]),
get_metric_check_condition(
[MetricSamplePattern(name=metric_name, value=3)], timeseries
),
timeout=timeout,
)

Expand Down
3 changes: 3 additions & 0 deletions python/ray/tests/test_scheduling_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ray._common.test_utils import SignalActor, wait_for_condition
from ray._private.test_utils import (
MetricSamplePattern,
PrometheusTimeseries,
get_metric_check_condition,
make_global_state_accessor,
)
Expand Down Expand Up @@ -784,6 +785,7 @@ def ready(self):
pg = placement_group(bundles=[{"CPU": 1}], strategy="SPREAD")
ray.get(pg.ready())

timeseries = PrometheusTimeseries()
placement_metric_condition = get_metric_check_condition(
[
MetricSamplePattern(
Expand All @@ -802,6 +804,7 @@ def ready(self):
partial_label_match={"WorkloadType": "PlacementGroup"},
),
],
timeseries,
)
wait_for_condition(placement_metric_condition, timeout=60)

Expand Down