diff --git a/python/subunit/v2.py b/python/subunit/v2.py index c2c63f66..e8a31d69 100644 --- a/python/subunit/v2.py +++ b/python/subunit/v2.py @@ -72,13 +72,32 @@ def has_nul(buffer_or_bytes): return NUL_ELEMENT in buffer_or_bytes +def read_exactly(stream, size): + """Read exactly size bytes from stream. + + :param stream: A file like object to read bytes from. Must support + read() and return bytes. + :param size: The number of bytes to retrieve. + """ + data = b'' + remaining = size + while remaining: + read = stream.read(remaining) + if len(read) == 0: + raise ParseError('Short read - got %d bytes, wanted %d bytes' % ( + len(data), size)) + data += read + remaining -= len(read) + return data + + class ParseError(Exception): """Used to pass error messages within the parser.""" class StreamResultToBytes(object): """Convert StreamResult API calls to bytes. - + The StreamResult API is defined by testtools.StreamResult. """ @@ -276,7 +295,7 @@ def __init__(self, source, non_subunit_name=None): def run(self, result): """Parse source and emit events to result. - + This is a blocking call: it will run until EOF is detected on source. """ self.codec.reset() @@ -406,21 +425,12 @@ def _parse_varint(self, data, pos, max_3_bytes=False): def _parse(self, packet, result): # 2 bytes flags, at most 3 bytes length. - packet.append(self.source.read(5)) - if len(packet[-1]) != 5: - raise ParseError( - 'Short read - got %d bytes, wanted 5' % len(packet[-1])) - - flag_bytes = packet[-1][:2] - flags = struct.unpack(FMT_16, flag_bytes)[0] - length, consumed = self._parse_varint( - packet[-1], 2, max_3_bytes=True) - remainder = self.source.read(length - 6) - if len(remainder) != length - 6: - raise ParseError( - 'Short read - got %d bytes, wanted %d bytes' % ( - len(remainder), length - 6)) + header = read_exactly(self.source, 5) + packet.append(header) + flags = struct.unpack(FMT_16, header[:2])[0] + length, consumed = self._parse_varint(header, 2, max_3_bytes=True) + remainder = read_exactly(self.source, length - 6) if consumed != 3: # Avoid having to parse torn values packet[-1] += remainder @@ -533,4 +543,3 @@ def _read_utf8(self, buf, pos): return utf8, length+pos except UnicodeDecodeError: raise ParseError('UTF8 string at offset %d is not UTF8' % (pos-2,)) -