diff --git a/ros2doctor/ros2doctor/verb/hello.py b/ros2doctor/ros2doctor/verb/hello.py index 74f7cff7c..acafb6fa0 100644 --- a/ros2doctor/ros2doctor/verb/hello.py +++ b/ros2doctor/ros2doctor/verb/hello.py @@ -18,13 +18,13 @@ import socket import struct import threading -import time import rclpy +from rclpy.duration import Duration from rclpy.executors import SingleThreadedExecutor -from rclpy.node import Node +from rclpy.qos import qos_profile_system_default +from ros2cli.node.direct import DirectNode from ros2doctor.verb import VerbExtension - from std_msgs.msg import String DEFAULT_GROUP = '225.0.0.1' @@ -34,14 +34,20 @@ f"ros2doctor_{re.sub(r'[^0-9a-zA-Z_]', '_', socket.gethostname())}_{os.getpid()}" -def positive_int(string: str) -> int: - try: - value = int(string) - except ValueError: - value = -1 - if value <= 0: - raise ArgumentTypeError('value must be a positive integer') - return value +def positive(type_): + def _coerce(string): + try: + value = type_(string) + except ValueError: + value = -1 + if value <= 0: + raise ArgumentTypeError('value must be a positive {type_.__name__}') + return value + return _coerce + + +positive_float = positive(float) +positive_int = positive(int) class HelloVerb(VerbExtension): @@ -59,10 +65,10 @@ def add_arguments(self, parser, cli_name): '-t', '--topic', nargs='?', default='/canyouhearme', help="Name of ROS topic to publish to (default: '/canyouhearme')") parser.add_argument( - '-ep', '--emit-period', metavar='N', type=float, default=0.1, + '-ep', '--emit-period', metavar='N', type=positive_float, default=0.1, help='Time period to publish/send one message (default: 0.1s)') parser.add_argument( - '-pp', '--print-period', metavar='N', type=float, default=1.0, + '-pp', '--print-period', metavar='N', type=positive_float, default=1.0, help='Time period to print summary table (default: 1.0s)') parser.add_argument( '--ttl', type=positive_int, @@ -71,125 +77,166 @@ def add_arguments(self, parser, cli_name): '-1', '--once', action='store_true', default=False, help='Publish and multicast send for one emit period then exit; used in test case.') - def main(self, *, args): - global summary_table - summary_table = SummaryTable() - rclpy.init() - executor = SingleThreadedExecutor() - pub_node = Talker(args.topic, args.emit_period) - sub_node = Listener(args.topic) - executor.add_node(pub_node) - executor.add_node(sub_node) - try: - prev_time = time.time() - # pub/sub thread - exec_thread = threading.Thread(target=executor.spin) - exec_thread.start() - while True: - if (time.time() - prev_time > args.print_period): - summary_table.format_print_summary(args.topic, args.print_period) - summary_table.reset() - prev_time = time.time() - # multicast threads - send_thread = threading.Thread(target=_send, kwargs={'ttl': args.ttl}) - send_thread.daemon = True - receive_thread = threading.Thread(target=_receive) - receive_thread.daemon = True - receive_thread.start() - send_thread.start() - time.sleep(args.emit_period) - if args.once: - return summary_table - except KeyboardInterrupt: - pass - finally: - executor.shutdown() - rclpy.shutdown() - pub_node.destroy_node() - sub_node.destroy_node() - - -class Talker(Node): - """Initialize talker node.""" - - def __init__(self, topic, time_period, *, qos=10): - node_name = NODE_NAME_PREFIX + '_talker' - super().__init__(node_name) - self._i = 0 - self._pub = self.create_publisher(String, topic, qos) - self._timer = self.create_timer(time_period, self.timer_callback) - - def timer_callback(self): + def main(self, *, args, summary_table=None): + if summary_table is None: + summary_table = SummaryTable() + with DirectNode(args, node_name=NODE_NAME_PREFIX + '_node') as node: + publisher = HelloPublisher(node, args.topic, summary_table) + subscriber = HelloSubscriber(node, args.topic, summary_table) + sender = HelloMulticastUDPSender(summary_table, ttl=args.ttl) + receiver = HelloMulticastUDPReceiver(summary_table) + receiver_thread = threading.Thread(target=receiver.recv) + receiver_thread.start() + + executor = SingleThreadedExecutor() + executor.add_node(node.node) + + executor_thread = threading.Thread(target=executor.spin) + executor_thread.start() + + try: + clock = node.get_clock() + prev_time = clock.now() + print_period = Duration(seconds=args.print_period) + emit_rate = node.create_rate(frequency=1.0 / args.emit_period, clock=clock) + while rclpy.ok(): + current_time = clock.now() + if (current_time - prev_time) > print_period: + summary_table.format_print_summary(args.topic, args.print_period) + summary_table.reset() + prev_time = current_time + publisher.publish() + sender.send() + emit_rate.sleep() + if args.once: + summary_table.format_print_summary(args.topic, args.print_period) + break + except KeyboardInterrupt: + pass + finally: + executor.shutdown() + executor_thread.join() + receiver.shutdown() + receiver_thread.join() + sender.shutdown() + subscriber.destroy() + publisher.destroy() + + +class HelloPublisher: + """Publish 'hello' messages over an std_msgs/msg/String topic.""" + + def __init__(self, node, topic, summary_table, *, qos=qos_profile_system_default): + self._summary_table = summary_table + self._pub = node.create_publisher(String, topic, qos) + + def destroy(self): + self._pub.destroy() + + def publish(self): msg = String() hostname = socket.gethostname() msg.data = f"hello, it's me {hostname}" - summary_table.increment_pub() + self._summary_table.increment_pub() self._pub.publish(msg) - self._i += 1 -class Listener(Node): - """Initialize listener node.""" +class HelloSubscriber: + """Subscribe to 'hello' messages over an std_msgs/msg/String topic.""" + + def __init__(self, node, topic, summary_table, *, qos=qos_profile_system_default): + self._summary_table = summary_table + self._sub = node.create_subscription(String, topic, self._callback, qos) - def __init__(self, topic, *, qos=10): - node_name = NODE_NAME_PREFIX + '_listener' - super().__init__(node_name) - self._sub = self.create_subscription( - String, - topic, - self.sub_callback, - qos) + def destroy(self): + self._sub.destroy() - def sub_callback(self, msg): + def _callback(self, msg): msg_data = msg.data.split() pub_hostname = msg_data[-1] if pub_hostname != socket.gethostname(): - summary_table.increment_sub(pub_hostname) - - -def _send(*, group=DEFAULT_GROUP, port=DEFAULT_PORT, ttl=None): - """Multicast send one message.""" - hostname = socket.gethostname() - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) - if ttl is not None: - packed_ttl = struct.pack('b', ttl) - s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, packed_ttl) - try: - s.sendto(f"hello, it's me {hostname}".encode('utf-8'), (group, port)) - summary_table.increment_send() - finally: - s.close() - - -def _receive(*, group=DEFAULT_GROUP, port=DEFAULT_PORT, timeout=None): - """Multicast receive.""" - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) - try: - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self._summary_table.increment_sub(pub_hostname) + + +class HelloMulticastUDPSender: + """Send 'hello' messages over a multicast UDP socket.""" + + def __init__(self, summary_table, group=DEFAULT_GROUP, port=DEFAULT_PORT, ttl=None): + self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) try: - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) - except AttributeError: - # not available on Windows - pass - s.bind(('', port)) + if ttl is not None: + packed_ttl = struct.pack('b', ttl) + self._socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, packed_ttl) + except Exception: + self._socket.close() + raise + self._summary_table = summary_table + self._group = group + self._port = port + + def send(self): + hostname = socket.gethostname() + self._socket.sendto( + f"hello, it's me {hostname}".encode('utf-8'), (self._group, self._port) + ) + self._summary_table.increment_send() + + def shutdown(self): + self._socket.close() + - s.settimeout(timeout) +class HelloMulticastUDPReceiver: + """Receive 'hello' messages over a multicast UDP socket.""" - mreq = struct.pack('4sl', socket.inet_aton(group), socket.INADDR_ANY) - s.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq) + def __init__(self, summary_table, group=DEFAULT_GROUP, port=DEFAULT_PORT, timeout=None): + self._dummy_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) + self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) try: - data, _ = s.recvfrom(4096) - data = data.decode('utf-8') - sender_hostname = data.split()[-1] - if sender_hostname != socket.gethostname(): - summary_table.increment_receive(sender_hostname) - finally: - s.setsockopt(socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP, mreq) - finally: - s.close() - - -class SummaryTable(): + self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + try: + self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + except AttributeError: + # not available on Windows + pass + self._socket.bind(('', port)) + + self._socket.settimeout(timeout) + + self._mreq = struct.pack('4sl', socket.inet_aton(group), socket.INADDR_ANY) + self._socket.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, self._mreq) + except Exception: + self._dummy_socket.close() + self._socket.close() + raise + self._is_shutdown = False + self._summary_table = summary_table + self._group = group + self._port = port + + def recv(self): + try: + while not self._is_shutdown: + data, _ = self._socket.recvfrom(4096) + data = data.decode('utf-8') + sender_hostname = data.split()[-1] + if sender_hostname != socket.gethostname(): + self._summary_table.increment_receive(sender_hostname) + except socket.timeout: + pass + + def shutdown(self): + if self._is_shutdown: + return + self._is_shutdown = True + self._dummy_socket.sendto( + f'{socket.gethostname()}'.encode('utf-8'), ('127.0.0.1', self._port) + ) + self._dummy_socket.close() + self._socket.setsockopt(socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP, self._mreq) + self._socket.close() + + +class SummaryTable: """Summarize number of msgs published/sent and subscribed/received.""" def __init__(self): diff --git a/ros2doctor/test/test_cli.py b/ros2doctor/test/test_cli.py index 965bf90ad..6f173477a 100644 --- a/ros2doctor/test/test_cli.py +++ b/ros2doctor/test/test_cli.py @@ -62,8 +62,9 @@ def test_hello_single_host(self): args.ttl = None args.once = True with mock.patch('socket.gethostname', return_value='!nv@lid-n*de-n4me'): + summary = SummaryTable() hello_verb = HelloVerb() - summary = hello_verb.main(args=args) + hello_verb.main(args=args, summary_table=summary) expected_summary = _generate_expected_summary_table() self.assertEqual(summary._pub, expected_summary._pub) self.assertEqual(summary._sub, expected_summary._sub)