diff --git a/diag_master/src/diag_master/main.py b/diag_master/src/diag_master/main.py index f71521d..b7d58eb 100644 --- a/diag_master/src/diag_master/main.py +++ b/diag_master/src/diag_master/main.py @@ -15,6 +15,7 @@ from ros2_transport.server import Ros2Server from sync_graph.sync_graph import SyncGraph from sync_graph.timed_graph_update_queue import TimedGraphUpdateQueue +from sync_graph.update_aggregator import aggregate_clock_diff_measurements from sync_graph.yaml import to_sync_graph_args from sync_tooling_msgs.graph_update_pb2 import GraphUpdate @@ -79,7 +80,9 @@ def shutdown(self): @property def sync_graph(self): sg = self._sync_graph_factory() - for u in self._update_queue.updates: + updates = self._update_queue.updates + aggregated_updates = aggregate_clock_diff_measurements(updates) + for u in aggregated_updates: sg.update(u) return sg diff --git a/sync_graph/src/sync_graph/update_aggregator.py b/sync_graph/src/sync_graph/update_aggregator.py new file mode 100644 index 0000000..7c422ee --- /dev/null +++ b/sync_graph/src/sync_graph/update_aggregator.py @@ -0,0 +1,72 @@ +""" +Some update types might be high-frequency, or might be noisy. +This module provides functions to aggregate such updates, e.g. by computing the median +of multiple measurements between the same clocks. + +This effectively reduces update frequency and noise, at the cost of delayed reaction time +to diagnostic state changes. +""" + +import statistics +from collections import defaultdict +from typing import Iterable + +from sync_tooling_msgs.clock_diff_measurement_pb2 import ClockDiffMeasurement +from sync_tooling_msgs.graph_update_pb2 import GraphUpdate + + +def aggregate_clock_diff_measurements( + updates: Iterable[GraphUpdate], +) -> list[GraphUpdate]: + """ + Aggregate ClockDiffMeasurements by grouping them by (src, dst) pairs. + + For each group, compute the median diff_ns and create a single aggregated measurement. + All measurements in the group are removed from the iterable and replaced with the single + aggregated one. + + Args: + updates: An iterable of GraphUpdate messages (not limited to ClockDiffMeasurements). + + Returns: + The `updates` iterable with all ClockDiffMeasurements grouped by (src, dst) and replaced + with a single aggregated measurement for each group. + Non-measurement updates are preserved as-is. + """ + # Group measurements by (src, dst) pairs + measurement_groups = defaultdict(list) + non_measurement_updates = [] + + for update in updates: + if update.HasField("clock_diff_measurement"): + measurement = update.clock_diff_measurement + # Use a tuple of (src, dst) as the grouping key + key = (measurement.src, measurement.dst) + measurement_groups[key].append(measurement) + else: + # Keep non-measurement updates as-is + non_measurement_updates.append(update) + + # Create aggregated measurements + aggregated_updates = [] + for (src, dst), measurements in measurement_groups.items(): + if len(measurements) == 1: + # If only one measurement, no aggregation needed + aggregated_updates.append( + GraphUpdate(clock_diff_measurement=measurements[0]) + ) + else: + # Compute median of diff_ns values + diff_values = [m.diff_ns for m in measurements] + median_diff = statistics.median(diff_values) + + # Create aggregated measurement with median value + aggregated_measurement = ClockDiffMeasurement( + src=src, dst=dst, diff_ns=int(median_diff) + ) + aggregated_updates.append( + GraphUpdate(clock_diff_measurement=aggregated_measurement) + ) + + # Combine non-measurement updates with aggregated measurements + return non_measurement_updates + aggregated_updates diff --git a/sync_graph/tests/test_update_aggregator.py b/sync_graph/tests/test_update_aggregator.py new file mode 100644 index 0000000..bb99754 --- /dev/null +++ b/sync_graph/tests/test_update_aggregator.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +"""Simple test script for the measurement aggregator function.""" + +from sync_graph.update_aggregator import aggregate_clock_diff_measurements +from sync_tooling_msgs.clock_diff_measurement_pb2 import ClockDiffMeasurement +from sync_tooling_msgs.clock_master_update_pb2 import ClockMasterUpdate + +from .util import _gu + + +def test_aggregation(sample_clock, nic_clock, remote_clock): + a = sample_clock + b = nic_clock + c = remote_clock + + # Create test measurements + measurements = [ + # Multiple measurements between a -> b + _gu(ClockDiffMeasurement(src=a, dst=b, diff_ns=1000)), + _gu(ClockDiffMeasurement(src=a, dst=b, diff_ns=2000)), + _gu(ClockDiffMeasurement(src=a, dst=b, diff_ns=3000)), + # Single measurement between b -> c + _gu(ClockDiffMeasurement(src=b, dst=c, diff_ns=5000)), + # Multiple measurements between c -> a + _gu(ClockDiffMeasurement(src=c, dst=a, diff_ns=10000)), + _gu(ClockDiffMeasurement(src=c, dst=a, diff_ns=20000)), + ] + + # Test the aggregation + result = aggregate_clock_diff_measurements(measurements) + + # Verify results + measurement_updates = [u for u in result if u.HasField("clock_diff_measurement")] + assert ( + len(measurement_updates) == 3 + ), f"Expected 3 aggregated measurements, got {len(measurement_updates)}" + + # Check that a -> b has median value of 2000 (median of [1000, 2000, 3000]) + a_to_b = next( + u + for u in measurement_updates + if u.clock_diff_measurement.src == a and u.clock_diff_measurement.dst == b + ) + assert ( + a_to_b.clock_diff_measurement.diff_ns == 2000 + ), f"Expected 2000, got {a_to_b.clock_diff_measurement.diff_ns}" + + # Check that b -> c remains unchanged (single measurement) + b_to_c = next( + u + for u in measurement_updates + if u.clock_diff_measurement.src == b and u.clock_diff_measurement.dst == c + ) + assert ( + b_to_c.clock_diff_measurement.diff_ns == 5000 + ), f"Expected 5000, got {b_to_c.clock_diff_measurement.diff_ns}" + + # Check that c -> a has median value of 15000 (median of [10000, 20000]) + c_to_a = next( + u + for u in measurement_updates + if u.clock_diff_measurement.src == c and u.clock_diff_measurement.dst == a + ) + assert ( + c_to_a.clock_diff_measurement.diff_ns == 15000 + ), f"Expected 15000, got {c_to_a.clock_diff_measurement.diff_ns}" + + +def test_other_updates_untouched(sample_clock, nic_clock, remote_clock): + a = sample_clock + b = nic_clock + c = remote_clock + + other_update = _gu(ClockMasterUpdate(clock_id=c, master=b, master_offset_ns=1000)) + + # Create some other updates that should not be aggregated + updates = [ + _gu(ClockDiffMeasurement(src=a, dst=c, diff_ns=5000)), + other_update, + _gu(ClockDiffMeasurement(src=a, dst=c, diff_ns=3000)), + ] + + # Test the aggregation with other updates included + result = aggregate_clock_diff_measurements(updates) + + # Verify that the other updates remain unchanged + assert len(result) == 2, f"Expected 2 updates, got {len(result)}" + assert other_update in result, "Other update was not preserved"