diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index 44e787643c7b..3813e6aef77c 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -654,8 +654,29 @@ def matches(self, sample: Sample): return True +@dataclass +class PrometheusTimeseries: + """A collection of timeseries from multiple addresses. Each timeseries is a + collection of samples with the same metric name and labels. Concretely: + - components_dict: a dictionary of addresses to the Component labels + - metric_descriptors: a dictionary of metric names to the Metric object + - metric_samples: the latest value of each label + """ + + components_dict: Dict[str, Set[str]] = field(default_factory=dict) + metric_descriptors: Dict[str, Metric] = field(default_factory=dict) + metric_samples: Dict[frozenset, Sample] = field(default_factory=dict) + + def flush(self): + self.components_dict.clear() + self.metric_descriptors.clear() + self.metric_samples.clear() + + def get_metric_check_condition( - metrics_to_check: List[MetricSamplePattern], export_addr: Optional[str] = None + metrics_to_check: List[MetricSamplePattern], + timeseries: PrometheusTimeseries, + export_addr: Optional[str] = None, ) -> Callable[[], bool]: """A condition to check if a prometheus metrics reach a certain value. @@ -665,6 +686,7 @@ def get_metric_check_condition( Args: metrics_to_check: A list of MetricSamplePattern. The fields that aren't `None` will be matched. + timeseries: A PrometheusTimeseries object to store the metrics. export_addr: Optional address to export metrics to. Returns: @@ -677,7 +699,9 @@ def get_metric_check_condition( def f(): for metric_pattern in metrics_to_check: - _, _, metric_samples = fetch_prometheus([prom_addr]) + metric_samples = fetch_prometheus_timeseries( + [prom_addr], timeseries + ).metric_samples.values() for metric_sample in metric_samples: if metric_pattern.matches(metric_sample): break @@ -993,25 +1017,6 @@ def fetch_prometheus(prom_addresses): return components_dict, metric_descriptors, metric_samples -@dataclass -class PrometheusTimeseries: - """A collection of timeseries from multiple addresses. Each timeseries is a - collection of samples with the same metric name and labels. Concretely: - - components_dict: a dictionary of addresses to the Component labels - - metric_descriptors: a dictionary of metric names to the Metric object - - metric_samples: the latest value of each label - """ - - components_dict: Dict[str, Set[str]] = field(default_factory=defaultdict) - metric_descriptors: Dict[str, Metric] = field(default_factory=defaultdict) - metric_samples: Dict[frozenset, Sample] = field(default_factory=defaultdict) - - def flush(self): - self.components_dict.clear() - self.metric_descriptors.clear() - self.metric_samples.clear() - - def fetch_prometheus_timeseries( prom_addreses: List[str], result: PrometheusTimeseries, diff --git a/python/ray/tests/test_autoscaler_e2e.py b/python/ray/tests/test_autoscaler_e2e.py index 5585413e86b9..97f3b4679245 100644 --- a/python/ray/tests/test_autoscaler_e2e.py +++ b/python/ray/tests/test_autoscaler_e2e.py @@ -8,6 +8,7 @@ from ray._common.test_utils import SignalActor, wait_for_condition from ray._private.test_utils import ( MetricSamplePattern, + PrometheusTimeseries, get_metric_check_condition, ) from ray.autoscaler._private.constants import AUTOSCALER_METRIC_PORT @@ -174,6 +175,7 @@ class Foo: def ping(self): return True + timeseries = PrometheusTimeseries() zero_reported_condition = get_metric_check_condition( [ MetricSamplePattern( @@ -199,6 +201,7 @@ def ping(self): partial_label_match={"NodeType": "ray.head.default"}, ), ], + timeseries, export_addr=autoscaler_export_addr, ) wait_for_condition(zero_reported_condition) @@ -239,6 +242,7 @@ def ping(self): partial_label_match={"NodeType": "ray.head.default"}, ), ], + timeseries, export_addr=autoscaler_export_addr, ) wait_for_condition(two_cpu_no_pending_condition) diff --git a/python/ray/tests/test_scheduling.py b/python/ray/tests/test_scheduling.py index 786006b06139..1d3c091f4e1e 100644 --- a/python/ray/tests/test_scheduling.py +++ b/python/ray/tests/test_scheduling.py @@ -16,6 +16,7 @@ from ray._private.internal_api import memory_summary from ray._private.test_utils import ( MetricSamplePattern, + PrometheusTimeseries, get_metric_check_condition, object_memory_usage, ) @@ -673,13 +674,18 @@ def start_infeasible(n): # longer timeout is necessary to pass on windows debug/asan builds. timeout = 180 + timeseries = PrometheusTimeseries() wait_for_condition( - get_metric_check_condition([MetricSamplePattern(name=metric_name, value=2)]), + get_metric_check_condition( + [MetricSamplePattern(name=metric_name, value=2)], timeseries + ), timeout=timeout, ) start_infeasible.remote(2) wait_for_condition( - get_metric_check_condition([MetricSamplePattern(name=metric_name, value=3)]), + get_metric_check_condition( + [MetricSamplePattern(name=metric_name, value=3)], timeseries + ), timeout=timeout, ) diff --git a/python/ray/tests/test_scheduling_2.py b/python/ray/tests/test_scheduling_2.py index efe801c2e9ba..ab2034df6325 100644 --- a/python/ray/tests/test_scheduling_2.py +++ b/python/ray/tests/test_scheduling_2.py @@ -12,6 +12,7 @@ from ray._common.test_utils import SignalActor, wait_for_condition from ray._private.test_utils import ( MetricSamplePattern, + PrometheusTimeseries, get_metric_check_condition, make_global_state_accessor, ) @@ -784,6 +785,7 @@ def ready(self): pg = placement_group(bundles=[{"CPU": 1}], strategy="SPREAD") ray.get(pg.ready()) + timeseries = PrometheusTimeseries() placement_metric_condition = get_metric_check_condition( [ MetricSamplePattern( @@ -802,6 +804,7 @@ def ready(self): partial_label_match={"WorkloadType": "PlacementGroup"}, ), ], + timeseries, ) wait_for_condition(placement_metric_condition, timeout=60)