Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ros2doctor hello verb. #521

Merged
merged 3 commits into from
Jun 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
281 changes: 164 additions & 117 deletions ros2doctor/ros2doctor/verb/hello.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
import socket
import struct
import threading
import time

import rclpy
from rclpy.duration import Duration
from rclpy.executors import SingleThreadedExecutor
from rclpy.node import Node
from rclpy.qos import qos_profile_system_default
from ros2cli.node.direct import DirectNode
from ros2doctor.verb import VerbExtension

from std_msgs.msg import String

DEFAULT_GROUP = '225.0.0.1'
Expand All @@ -34,14 +34,20 @@
f"ros2doctor_{re.sub(r'[^0-9a-zA-Z_]', '_', socket.gethostname())}_{os.getpid()}"


def positive_int(string: str) -> int:
try:
value = int(string)
except ValueError:
value = -1
if value <= 0:
raise ArgumentTypeError('value must be a positive integer')
return value
def positive(type_):
def _coerce(string):
try:
value = type_(string)
except ValueError:
value = -1
if value <= 0:
raise ArgumentTypeError('value must be a positive {type_.__name__}')
return value
return _coerce


positive_float = positive(float)
positive_int = positive(int)


class HelloVerb(VerbExtension):
Expand All @@ -59,10 +65,10 @@ def add_arguments(self, parser, cli_name):
'-t', '--topic', nargs='?', default='/canyouhearme',
help="Name of ROS topic to publish to (default: '/canyouhearme')")
parser.add_argument(
'-ep', '--emit-period', metavar='N', type=float, default=0.1,
'-ep', '--emit-period', metavar='N', type=positive_float, default=0.1,
help='Time period to publish/send one message (default: 0.1s)')
parser.add_argument(
'-pp', '--print-period', metavar='N', type=float, default=1.0,
'-pp', '--print-period', metavar='N', type=positive_float, default=1.0,
help='Time period to print summary table (default: 1.0s)')
parser.add_argument(
'--ttl', type=positive_int,
Expand All @@ -71,125 +77,166 @@ def add_arguments(self, parser, cli_name):
'-1', '--once', action='store_true', default=False,
help='Publish and multicast send for one emit period then exit; used in test case.')

def main(self, *, args):
global summary_table
summary_table = SummaryTable()
rclpy.init()
executor = SingleThreadedExecutor()
pub_node = Talker(args.topic, args.emit_period)
sub_node = Listener(args.topic)
executor.add_node(pub_node)
executor.add_node(sub_node)
try:
prev_time = time.time()
# pub/sub thread
exec_thread = threading.Thread(target=executor.spin)
exec_thread.start()
while True:
if (time.time() - prev_time > args.print_period):
summary_table.format_print_summary(args.topic, args.print_period)
summary_table.reset()
prev_time = time.time()
# multicast threads
send_thread = threading.Thread(target=_send, kwargs={'ttl': args.ttl})
send_thread.daemon = True
receive_thread = threading.Thread(target=_receive)
receive_thread.daemon = True
receive_thread.start()
send_thread.start()
time.sleep(args.emit_period)
if args.once:
return summary_table
except KeyboardInterrupt:
pass
finally:
executor.shutdown()
rclpy.shutdown()
pub_node.destroy_node()
sub_node.destroy_node()


class Talker(Node):
"""Initialize talker node."""

def __init__(self, topic, time_period, *, qos=10):
node_name = NODE_NAME_PREFIX + '_talker'
super().__init__(node_name)
self._i = 0
self._pub = self.create_publisher(String, topic, qos)
self._timer = self.create_timer(time_period, self.timer_callback)

def timer_callback(self):
def main(self, *, args, summary_table=None):
if summary_table is None:
summary_table = SummaryTable()
with DirectNode(args, node_name=NODE_NAME_PREFIX + '_node') as node:
publisher = HelloPublisher(node, args.topic, summary_table)
subscriber = HelloSubscriber(node, args.topic, summary_table)
sender = HelloMulticastUDPSender(summary_table, ttl=args.ttl)
receiver = HelloMulticastUDPReceiver(summary_table)
receiver_thread = threading.Thread(target=receiver.recv)
receiver_thread.start()

executor = SingleThreadedExecutor()
executor.add_node(node.node)

executor_thread = threading.Thread(target=executor.spin)
executor_thread.start()

try:
clock = node.get_clock()
prev_time = clock.now()
print_period = Duration(seconds=args.print_period)
emit_rate = node.create_rate(frequency=1.0 / args.emit_period, clock=clock)
while rclpy.ok():
current_time = clock.now()
if (current_time - prev_time) > print_period:
summary_table.format_print_summary(args.topic, args.print_period)
summary_table.reset()
prev_time = current_time
publisher.publish()
sender.send()
emit_rate.sleep()
if args.once:
summary_table.format_print_summary(args.topic, args.print_period)
break
except KeyboardInterrupt:
pass
finally:
executor.shutdown()
executor_thread.join()
receiver.shutdown()
receiver_thread.join()
sender.shutdown()
subscriber.destroy()
publisher.destroy()


class HelloPublisher:
"""Publish 'hello' messages over an std_msgs/msg/String topic."""

def __init__(self, node, topic, summary_table, *, qos=qos_profile_system_default):
self._summary_table = summary_table
self._pub = node.create_publisher(String, topic, qos)

def destroy(self):
self._pub.destroy()

def publish(self):
msg = String()
hostname = socket.gethostname()
msg.data = f"hello, it's me {hostname}"
summary_table.increment_pub()
self._summary_table.increment_pub()
self._pub.publish(msg)
self._i += 1


class Listener(Node):
"""Initialize listener node."""
class HelloSubscriber:
"""Subscribe to 'hello' messages over an std_msgs/msg/String topic."""

def __init__(self, node, topic, summary_table, *, qos=qos_profile_system_default):
self._summary_table = summary_table
self._sub = node.create_subscription(String, topic, self._callback, qos)

def __init__(self, topic, *, qos=10):
node_name = NODE_NAME_PREFIX + '_listener'
super().__init__(node_name)
self._sub = self.create_subscription(
String,
topic,
self.sub_callback,
qos)
def destroy(self):
self._sub.destroy()

def sub_callback(self, msg):
def _callback(self, msg):
msg_data = msg.data.split()
pub_hostname = msg_data[-1]
if pub_hostname != socket.gethostname():
summary_table.increment_sub(pub_hostname)


def _send(*, group=DEFAULT_GROUP, port=DEFAULT_PORT, ttl=None):
"""Multicast send one message."""
hostname = socket.gethostname()
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
if ttl is not None:
packed_ttl = struct.pack('b', ttl)
s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, packed_ttl)
try:
s.sendto(f"hello, it's me {hostname}".encode('utf-8'), (group, port))
summary_table.increment_send()
finally:
s.close()


def _receive(*, group=DEFAULT_GROUP, port=DEFAULT_PORT, timeout=None):
"""Multicast receive."""
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
try:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self._summary_table.increment_sub(pub_hostname)


class HelloMulticastUDPSender:
"""Send 'hello' messages over a multicast UDP socket."""

def __init__(self, summary_table, group=DEFAULT_GROUP, port=DEFAULT_PORT, ttl=None):
self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
try:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
except AttributeError:
# not available on Windows
pass
s.bind(('', port))
if ttl is not None:
packed_ttl = struct.pack('b', ttl)
self._socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, packed_ttl)
except Exception:
self._socket.close()
raise
self._summary_table = summary_table
self._group = group
self._port = port

def send(self):
hostname = socket.gethostname()
self._socket.sendto(
f"hello, it's me {hostname}".encode('utf-8'), (self._group, self._port)
)
self._summary_table.increment_send()

def shutdown(self):
self._socket.close()


s.settimeout(timeout)
class HelloMulticastUDPReceiver:
"""Receive 'hello' messages over a multicast UDP socket."""

mreq = struct.pack('4sl', socket.inet_aton(group), socket.INADDR_ANY)
s.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
def __init__(self, summary_table, group=DEFAULT_GROUP, port=DEFAULT_PORT, timeout=None):
self._dummy_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
try:
data, _ = s.recvfrom(4096)
data = data.decode('utf-8')
sender_hostname = data.split()[-1]
if sender_hostname != socket.gethostname():
summary_table.increment_receive(sender_hostname)
finally:
s.setsockopt(socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP, mreq)
finally:
s.close()


class SummaryTable():
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
except AttributeError:
# not available on Windows
pass
self._socket.bind(('', port))

self._socket.settimeout(timeout)

self._mreq = struct.pack('4sl', socket.inet_aton(group), socket.INADDR_ANY)
self._socket.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, self._mreq)
except Exception:
self._dummy_socket.close()
self._socket.close()
raise
self._is_shutdown = False
self._summary_table = summary_table
self._group = group
self._port = port

def recv(self):
try:
while not self._is_shutdown:
data, _ = self._socket.recvfrom(4096)
data = data.decode('utf-8')
sender_hostname = data.split()[-1]
if sender_hostname != socket.gethostname():
self._summary_table.increment_receive(sender_hostname)
except socket.timeout:
pass

def shutdown(self):
if self._is_shutdown:
return
self._is_shutdown = True
self._dummy_socket.sendto(
f'{socket.gethostname()}'.encode('utf-8'), ('127.0.0.1', self._port)
)
self._dummy_socket.close()
self._socket.setsockopt(socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP, self._mreq)
self._socket.close()


class SummaryTable:
"""Summarize number of msgs published/sent and subscribed/received."""

def __init__(self):
Expand Down
3 changes: 2 additions & 1 deletion ros2doctor/test/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ def test_hello_single_host(self):
args.ttl = None
args.once = True
with mock.patch('socket.gethostname', return_value='!nv@lid-n*de-n4me'):
summary = SummaryTable()
hello_verb = HelloVerb()
summary = hello_verb.main(args=args)
hello_verb.main(args=args, summary_table=summary)
expected_summary = _generate_expected_summary_table()
self.assertEqual(summary._pub, expected_summary._pub)
self.assertEqual(summary._sub, expected_summary._sub)
Expand Down