From 1268814752a39fdcbaa7184dff4e362e4428ae2a Mon Sep 17 00:00:00 2001 From: Cuong Nguyen Date: Thu, 13 Nov 2025 07:33:53 -0800 Subject: [PATCH 1/2] [core] fix get_metric_check_condition tests Signed-off-by: Cuong Nguyen --- python/ray/_private/test_utils.py | 47 ++++++++++++++----------- python/ray/tests/test_autoscaler_e2e.py | 4 +++ python/ray/tests/test_scheduling.py | 10 ++++-- python/ray/tests/test_scheduling_2.py | 3 ++ 4 files changed, 41 insertions(+), 23 deletions(-) diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index 44e787643c7b..51f63dd37b64 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -654,8 +654,29 @@ def matches(self, sample: Sample): return True +@dataclass +class PrometheusTimeseries: + """A collection of timeseries from multiple addresses. Each timeseries is a + collection of samples with the same metric name and labels. Concretely: + - components_dict: a dictionary of addresses to the Component labels + - metric_descriptors: a dictionary of metric names to the Metric object + - metric_samples: the latest value of each label + """ + + components_dict: Dict[str, Set[str]] = field(default_factory=defaultdict) + metric_descriptors: Dict[str, Metric] = field(default_factory=defaultdict) + metric_samples: Dict[frozenset, Sample] = field(default_factory=defaultdict) + + def flush(self): + self.components_dict.clear() + self.metric_descriptors.clear() + self.metric_samples.clear() + + def get_metric_check_condition( - metrics_to_check: List[MetricSamplePattern], export_addr: Optional[str] = None + metrics_to_check: List[MetricSamplePattern], + timeseries: PrometheusTimeseries, + export_addr: Optional[str] = None, ) -> Callable[[], bool]: """A condition to check if a prometheus metrics reach a certain value. @@ -665,6 +686,7 @@ def get_metric_check_condition( Args: metrics_to_check: A list of MetricSamplePattern. The fields that aren't `None` will be matched. + timeseries: A PrometheusTimeseries object to store the metrics. export_addr: Optional address to export metrics to. Returns: @@ -677,7 +699,9 @@ def get_metric_check_condition( def f(): for metric_pattern in metrics_to_check: - _, _, metric_samples = fetch_prometheus([prom_addr]) + metric_samples = fetch_prometheus_timeseries( + [prom_addr], timeseries + ).metric_samples.values() for metric_sample in metric_samples: if metric_pattern.matches(metric_sample): break @@ -993,25 +1017,6 @@ def fetch_prometheus(prom_addresses): return components_dict, metric_descriptors, metric_samples -@dataclass -class PrometheusTimeseries: - """A collection of timeseries from multiple addresses. Each timeseries is a - collection of samples with the same metric name and labels. Concretely: - - components_dict: a dictionary of addresses to the Component labels - - metric_descriptors: a dictionary of metric names to the Metric object - - metric_samples: the latest value of each label - """ - - components_dict: Dict[str, Set[str]] = field(default_factory=defaultdict) - metric_descriptors: Dict[str, Metric] = field(default_factory=defaultdict) - metric_samples: Dict[frozenset, Sample] = field(default_factory=defaultdict) - - def flush(self): - self.components_dict.clear() - self.metric_descriptors.clear() - self.metric_samples.clear() - - def fetch_prometheus_timeseries( prom_addreses: List[str], result: PrometheusTimeseries, diff --git a/python/ray/tests/test_autoscaler_e2e.py b/python/ray/tests/test_autoscaler_e2e.py index 5585413e86b9..97f3b4679245 100644 --- a/python/ray/tests/test_autoscaler_e2e.py +++ b/python/ray/tests/test_autoscaler_e2e.py @@ -8,6 +8,7 @@ from ray._common.test_utils import SignalActor, wait_for_condition from ray._private.test_utils import ( MetricSamplePattern, + PrometheusTimeseries, get_metric_check_condition, ) from ray.autoscaler._private.constants import AUTOSCALER_METRIC_PORT @@ -174,6 +175,7 @@ class Foo: def ping(self): return True + timeseries = PrometheusTimeseries() zero_reported_condition = get_metric_check_condition( [ MetricSamplePattern( @@ -199,6 +201,7 @@ def ping(self): partial_label_match={"NodeType": "ray.head.default"}, ), ], + timeseries, export_addr=autoscaler_export_addr, ) wait_for_condition(zero_reported_condition) @@ -239,6 +242,7 @@ def ping(self): partial_label_match={"NodeType": "ray.head.default"}, ), ], + timeseries, export_addr=autoscaler_export_addr, ) wait_for_condition(two_cpu_no_pending_condition) diff --git a/python/ray/tests/test_scheduling.py b/python/ray/tests/test_scheduling.py index 786006b06139..1d3c091f4e1e 100644 --- a/python/ray/tests/test_scheduling.py +++ b/python/ray/tests/test_scheduling.py @@ -16,6 +16,7 @@ from ray._private.internal_api import memory_summary from ray._private.test_utils import ( MetricSamplePattern, + PrometheusTimeseries, get_metric_check_condition, object_memory_usage, ) @@ -673,13 +674,18 @@ def start_infeasible(n): # longer timeout is necessary to pass on windows debug/asan builds. timeout = 180 + timeseries = PrometheusTimeseries() wait_for_condition( - get_metric_check_condition([MetricSamplePattern(name=metric_name, value=2)]), + get_metric_check_condition( + [MetricSamplePattern(name=metric_name, value=2)], timeseries + ), timeout=timeout, ) start_infeasible.remote(2) wait_for_condition( - get_metric_check_condition([MetricSamplePattern(name=metric_name, value=3)]), + get_metric_check_condition( + [MetricSamplePattern(name=metric_name, value=3)], timeseries + ), timeout=timeout, ) diff --git a/python/ray/tests/test_scheduling_2.py b/python/ray/tests/test_scheduling_2.py index efe801c2e9ba..ab2034df6325 100644 --- a/python/ray/tests/test_scheduling_2.py +++ b/python/ray/tests/test_scheduling_2.py @@ -12,6 +12,7 @@ from ray._common.test_utils import SignalActor, wait_for_condition from ray._private.test_utils import ( MetricSamplePattern, + PrometheusTimeseries, get_metric_check_condition, make_global_state_accessor, ) @@ -784,6 +785,7 @@ def ready(self): pg = placement_group(bundles=[{"CPU": 1}], strategy="SPREAD") ray.get(pg.ready()) + timeseries = PrometheusTimeseries() placement_metric_condition = get_metric_check_condition( [ MetricSamplePattern( @@ -802,6 +804,7 @@ def ready(self): partial_label_match={"WorkloadType": "PlacementGroup"}, ), ], + timeseries, ) wait_for_condition(placement_metric_condition, timeout=60) From e7e756062cd13a8466faef9acb8ac3696513e778 Mon Sep 17 00:00:00 2001 From: Cuong Nguyen <128072568+can-anyscale@users.noreply.github.com> Date: Thu, 13 Nov 2025 08:33:24 -0800 Subject: [PATCH 2/2] Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Cuong Nguyen <128072568+can-anyscale@users.noreply.github.com> --- python/ray/_private/test_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index 51f63dd37b64..3813e6aef77c 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -663,9 +663,9 @@ class PrometheusTimeseries: - metric_samples: the latest value of each label """ - components_dict: Dict[str, Set[str]] = field(default_factory=defaultdict) - metric_descriptors: Dict[str, Metric] = field(default_factory=defaultdict) - metric_samples: Dict[frozenset, Sample] = field(default_factory=defaultdict) + components_dict: Dict[str, Set[str]] = field(default_factory=dict) + metric_descriptors: Dict[str, Metric] = field(default_factory=dict) + metric_samples: Dict[frozenset, Sample] = field(default_factory=dict) def flush(self): self.components_dict.clear()