diff --git a/tools/rosbag/src/rosbag/bag.py b/tools/rosbag/src/rosbag/bag.py index dfa18499cb..dbe26a4aa3 100644 --- a/tools/rosbag/src/rosbag/bag.py +++ b/tools/rosbag/src/rosbag/bag.py @@ -122,6 +122,7 @@ class Compression: LZ4 = 'lz4' BagMessage = collections.namedtuple('BagMessage', 'topic message timestamp') +BagMessageWithConnectionHeader = collections.namedtuple('BagMessageWithConnectionHeader', 'topic message timestamp connection_header') class _ROSBagEncryptor(object): """ @@ -534,7 +535,7 @@ def _set_chunk_threshold(self, chunk_threshold): chunk_threshold = property(_get_chunk_threshold, _set_chunk_threshold) - def read_messages(self, topics=None, start_time=None, end_time=None, connection_filter=None, raw=False): + def read_messages(self, topics=None, start_time=None, end_time=None, connection_filter=None, raw=False, return_connection_header=False): """ Read messages from the bag, optionally filtered by topic, timestamp and connection details. @param topics: list of topics or a single topic. if an empty list is given all topics will be read [optional] @@ -555,7 +556,7 @@ def read_messages(self, topics=None, start_time=None, end_time=None, connection_ if topics and type(topics) is str: topics = [topics] - return self._reader.read_messages(topics, start_time, end_time, connection_filter, raw) + return self._reader.read_messages(topics, start_time, end_time, connection_filter, raw, return_connection_header) def flush(self): """ @@ -568,7 +569,7 @@ def flush(self): if self._chunk_open: self._stop_writing_chunk() - def write(self, topic, msg, t=None, raw=False): + def write(self, topic, msg, t=None, raw=False, connection_header=None): """ Write a message to the bag. @param topic: name of topic @@ -627,10 +628,20 @@ def write(self, topic, msg, t=None, raw=False): if pytype._md5sum != md5sum: print('WARNING: md5sum of loaded type [%s] does not match that specified' % msg_type, file=sys.stderr) #raise ROSBagException('md5sum of loaded type does not match that of data being recorded') - - header = { 'topic' : topic, 'type' : msg_type, 'md5sum' : md5sum, 'message_definition' : pytype._full_text } + + header = connection_header if connection_header is not None else { + 'topic': topic, + 'type': msg_type, + 'md5sum': md5sum, + 'message_definition': pytype._full_text + } else: - header = { 'topic' : topic, 'type' : msg.__class__._type, 'md5sum' : msg.__class__._md5sum, 'message_definition' : msg._full_text } + header = connection_header if connection_header is not None else { + 'topic': topic, + 'type': msg.__class__._type, + 'md5sum': msg.__class__._md5sum, + 'message_definition': msg._full_text + } connection_info = _ConnectionInfo(conn_id, topic, header) # No need to encrypt connection records in chunk (encrypt=False) @@ -2041,7 +2052,7 @@ def __init__(self, bag): def start_reading(self): raise NotImplementedError() - def read_messages(self, topics, start_time, end_time, connection_filter, raw): + def read_messages(self, topics, start_time, end_time, connection_filter, raw, return_connection_header): raise NotImplementedError() def reindex(self): @@ -2106,7 +2117,7 @@ def reindex(self): offset = f.tell() - def read_messages(self, topics, start_time, end_time, topic_filter, raw): + def read_messages(self, topics, start_time, end_time, topic_filter, raw, return_connection_header): f = self.bag._file f.seek(self.bag._file_header_pos) @@ -2166,7 +2177,10 @@ def read_messages(self, topics, start_time, end_time, topic_filter, raw): msg = msg_type() msg.deserialize(data) - yield BagMessage(topic, msg, t) + if return_connection_header: + yield BagMessageWithConnectionHeader(topic, msg, t, info.header) + else: + yield BagMessage(topic, msg, t) self.bag._connection_indexes_read = True @@ -2656,10 +2670,10 @@ def _read_connection_index_records(self): self.bag._connection_indexes_read = True - def read_messages(self, topics, start_time, end_time, connection_filter, raw): + def read_messages(self, topics, start_time, end_time, connection_filter, raw, return_connection_header): connections = self.bag._get_connections(topics, connection_filter) for entry in self.bag._get_entries(connections, start_time, end_time): - yield self.seek_and_read_message_data_record((entry.chunk_pos, entry.offset), raw) + yield self.seek_and_read_message_data_record((entry.chunk_pos, entry.offset), raw, return_connection_header) ### @@ -2760,7 +2774,7 @@ def read_connection_index_record(self): return (connection_id, index) - def seek_and_read_message_data_record(self, position, raw): + def seek_and_read_message_data_record(self, position, raw, return_connection_header): chunk_pos, offset = position chunk_header = self.bag._chunk_headers.get(chunk_pos) @@ -2835,8 +2849,11 @@ def seek_and_read_message_data_record(self, position, raw): else: msg = msg_type() msg.deserialize(data) - - return BagMessage(connection_info.topic, msg, t) + + if return_connection_header: + return BagMessageWithConnectionHeader(connection_info.topic, msg, t, connection_info.header) + else: + return BagMessage(connection_info.topic, msg, t) def _time_to_str(secs): secs_frac = secs - int(secs) diff --git a/tools/rosbag/src/rosbag/rosbag_main.py b/tools/rosbag/src/rosbag/rosbag_main.py index 34f5351ce8..5f5fe5869c 100644 --- a/tools/rosbag/src/rosbag/rosbag_main.py +++ b/tools/rosbag/src/rosbag/rosbag_main.py @@ -355,27 +355,27 @@ def eval_fn(topic, m, t): if options.verbose_pattern: verbose_pattern = expr_eval(options.verbose_pattern) - for topic, raw_msg, t in inbag.read_messages(raw=True): + for topic, raw_msg, t, conn_header in inbag.read_messages(raw=True, return_connection_header=True): msg_type, serialized_bytes, md5sum, pos, pytype = raw_msg msg = pytype() msg.deserialize(serialized_bytes) if filter_fn(topic, msg, t): print('MATCH', verbose_pattern(topic, msg, t)) - outbag.write(topic, msg, t) + outbag.write(topic, msg, t, connection_header=conn_header) else: print('NO MATCH', verbose_pattern(topic, msg, t)) total_bytes += len(serialized_bytes) meter.step(total_bytes) else: - for topic, raw_msg, t in inbag.read_messages(raw=True): + for topic, raw_msg, t, conn_header in inbag.read_messages(raw=True, return_connection_header=True): msg_type, serialized_bytes, md5sum, pos, pytype = raw_msg msg = pytype() msg.deserialize(serialized_bytes) if filter_fn(topic, msg, t): - outbag.write(topic, msg, t) + outbag.write(topic, msg, t, connection_header=conn_header) total_bytes += len(serialized_bytes) meter.step(total_bytes) @@ -764,16 +764,16 @@ def change_compression_op(inbag, outbag, compression, quiet): outbag.compression = compression if quiet: - for topic, msg, t in inbag.read_messages(raw=True): - outbag.write(topic, msg, t, raw=True) + for topic, msg, t, conn_header in inbag.read_messages(raw=True, return_connection_header=True): + outbag.write(topic, msg, t, raw=True, connection_header=conn_header) else: meter = ProgressMeter(outbag.filename, inbag._uncompressed_size) total_bytes = 0 - for topic, msg, t in inbag.read_messages(raw=True): + for topic, msg, t, conn_header in inbag.read_messages(raw=True, return_connection_header=True): msg_type, serialized_bytes, md5sum, pos, pytype = msg - outbag.write(topic, msg, t, raw=True) + outbag.write(topic, msg, t, raw=True, connection_header=conn_header) total_bytes += len(serialized_bytes) meter.step(total_bytes) @@ -789,8 +789,8 @@ def reindex_op(inbag, outbag, quiet): except: pass - for (topic, msg, t) in inbag.read_messages(): - outbag.write(topic, msg, t) + for (topic, msg, t, conn_header) in inbag.read_messages(return_connection_header=True): + outbag.write(topic, msg, t, connection_header=conn_header) else: meter = ProgressMeter(outbag.filename, inbag.size) try: @@ -801,8 +801,8 @@ def reindex_op(inbag, outbag, quiet): meter.finish() meter = ProgressMeter(outbag.filename, inbag.size) - for (topic, msg, t) in inbag.read_messages(): - outbag.write(topic, msg, t) + for (topic, msg, t, conn_header) in inbag.read_messages(return_connection_header=True): + outbag.write(topic, msg, t, connection_header=conn_header) meter.step(inbag._file.tell()) meter.finish() else: