Skip to content

Commit

Permalink
Correctly handle py3 RawIOBase read()
Browse files Browse the repository at this point in the history
Python3's RawIOBase guarantees only one syscall per read() requiring
a loop to accumulate the desired number of bytes or actually reach EOF.

TextIOBase.read does issue multiple syscalls (it must to correctly decode
partial unicode characters), but subunit unwraps that to get a binary stream,
and at least some of the time the layering is io.TextIOBase(_io.FileIO), where
_io.FileIO is a RawIOBase subclass rather than BufferedIOBase.

Signed-off-by: Stephen Finucane <stephenfin@redhat.com>
Partial-bug: #1813147
  • Loading branch information
stephenfin committed Mar 14, 2020
1 parent 8fb3e0c commit 26d31fa
Showing 1 changed file with 26 additions and 17 deletions.
43 changes: 26 additions & 17 deletions python/subunit/v2.py
Expand Up @@ -72,13 +72,32 @@ def has_nul(buffer_or_bytes):
return NUL_ELEMENT in buffer_or_bytes


def read_exactly(stream, size):
"""Read exactly size bytes from stream.
:param stream: A file like object to read bytes from. Must support
read(<count>) and return bytes.
:param size: The number of bytes to retrieve.
"""
data = b''
remaining = size
while remaining:
read = stream.read(remaining)
if len(read) == 0:
raise ParseError('Short read - got %d bytes, wanted %d bytes' % (
len(data), size))
data += read
remaining -= len(read)
return data


class ParseError(Exception):
"""Used to pass error messages within the parser."""


class StreamResultToBytes(object):
"""Convert StreamResult API calls to bytes.
The StreamResult API is defined by testtools.StreamResult.
"""

Expand Down Expand Up @@ -276,7 +295,7 @@ def __init__(self, source, non_subunit_name=None):

def run(self, result):
"""Parse source and emit events to result.
This is a blocking call: it will run until EOF is detected on source.
"""
self.codec.reset()
Expand Down Expand Up @@ -406,21 +425,12 @@ def _parse_varint(self, data, pos, max_3_bytes=False):

def _parse(self, packet, result):
# 2 bytes flags, at most 3 bytes length.
packet.append(self.source.read(5))
if len(packet[-1]) != 5:
raise ParseError(
'Short read - got %d bytes, wanted 5' % len(packet[-1]))

flag_bytes = packet[-1][:2]
flags = struct.unpack(FMT_16, flag_bytes)[0]
length, consumed = self._parse_varint(
packet[-1], 2, max_3_bytes=True)
remainder = self.source.read(length - 6)
if len(remainder) != length - 6:
raise ParseError(
'Short read - got %d bytes, wanted %d bytes' % (
len(remainder), length - 6))
header = read_exactly(self.source, 5)
packet.append(header)
flags = struct.unpack(FMT_16, header[:2])[0]
length, consumed = self._parse_varint(header, 2, max_3_bytes=True)

remainder = read_exactly(self.source, length - 6)
if consumed != 3:
# Avoid having to parse torn values
packet[-1] += remainder
Expand Down Expand Up @@ -533,4 +543,3 @@ def _read_utf8(self, buf, pos):
return utf8, length+pos
except UnicodeDecodeError:
raise ParseError('UTF8 string at offset %d is not UTF8' % (pos-2,))

0 comments on commit 26d31fa

Please sign in to comment.