Skip to content

Commit

Permalink
bpo-31855: unittest.mock.mock_open() results now respects the argumen…
Browse files Browse the repository at this point in the history
…t of read([size]) (GH-11521)

unittest.mock.mock_open() results now respects the argument of read([size])

Co-Authored-By: remilapeyre <remi.lapeyre@henki.fr>
Backports: 11a8832c98b3db78727312154dd1d3ba76d639ec
Signed-off-by: Chris Withers <chris@simplistix.co.uk>
  • Loading branch information
Rémi Lapeyre authored and cjw296 committed May 7, 2019
1 parent 74f6a7e commit 4bd71fe
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 25 deletions.
2 changes: 2 additions & 0 deletions NEWS.d/2019-01-11-17-09-15.bpo-31855.PlhfsX.rst
@@ -0,0 +1,2 @@
:func:`unittest.mock.mock_open` results now respects the argument of read([size]).
Patch contributed by Rémi Lapeyre.
38 changes: 14 additions & 24 deletions mock/mock.py
Expand Up @@ -55,6 +55,7 @@


from functools import partial
import io
import inspect
import pprint
import sys
Expand Down Expand Up @@ -2493,25 +2494,13 @@ def __init__(self, spec, spec_set=False, parent=None,

file_spec = None

def _iterate_read_data(read_data):
# Helper for mock_open:
# Retrieve lines from read_data via a generator so that separate calls to
# readline, read, and readlines are properly interleaved
sep = b'\n' if isinstance(read_data, bytes) else '\n'
data_as_list = [l + sep for l in read_data.split(sep)]

if data_as_list[-1] == sep:
# If the last line ended in a newline, the list comprehension will have an
# extra entry that's just a newline. Remove this.
data_as_list = data_as_list[:-1]

def _to_stream(read_data):
if isinstance(read_data, bytes):
return io.BytesIO(read_data)
else:
# If there wasn't an extra newline by itself, then the file being
# emulated doesn't have a newline to end the last line remove the
# newline that our naive format() added
data_as_list[-1] = data_as_list[-1][:-1]
return io.StringIO(read_data)

for line in data_as_list:
yield line

def mock_open(mock=None, read_data=''):
"""
Expand All @@ -2525,21 +2514,24 @@ def mock_open(mock=None, read_data=''):
`read_data` is a string for the `read`, `readline` and `readlines` of the
file handle to return. This is an empty string by default.
"""
_read_data = _to_stream(read_data)
_state = [_read_data, None]

def _readlines_side_effect(*args, **kwargs):
if handle.readlines.return_value is not None:
return handle.readlines.return_value
return list(_state[0])
return _state[0].readlines(*args, **kwargs)

def _read_side_effect(*args, **kwargs):
if handle.read.return_value is not None:
return handle.read.return_value
return type(read_data)().join(_state[0])
return _state[0].read(*args, **kwargs)

def _readline_side_effect():
def _readline_side_effect(*args, **kwargs):
for item in _iter_side_effect():
yield item
while True:
yield type(read_data)()
yield _state[0].readline(*args, **kwargs)

def _iter_side_effect():
if handle.readline.return_value is not None:
Expand All @@ -2563,8 +2555,6 @@ def _iter_side_effect():
handle = MagicMock(spec=file_spec)
handle.__enter__.return_value = handle

_state = [_iterate_read_data(read_data), None]

handle.write.return_value = None
handle.read.return_value = None
handle.readline.return_value = None
Expand All @@ -2577,7 +2567,7 @@ def _iter_side_effect():
handle.__iter__.side_effect = _iter_side_effect

def reset_data(*args, **kwargs):
_state[0] = _iterate_read_data(read_data)
_state[0] = _to_stream(read_data)
if handle.readline.side_effect == _state[1]:
# Only reset the side effect if the user hasn't overridden it.
_state[1] = _readline_side_effect()
Expand Down
7 changes: 6 additions & 1 deletion mock/tests/testwith.py
Expand Up @@ -288,7 +288,12 @@ def test_mock_open_read_with_argument(self):
# for mocks returned by mock_open
some_data = 'foo\nbar\nbaz'
mock = mock_open(read_data=some_data)
self.assertEqual(mock().read(10), some_data)
self.assertEqual(mock().read(10), some_data[:10])
self.assertEqual(mock().read(10), some_data[:10])

f = mock()
self.assertEqual(f.read(10), some_data[:10])
self.assertEqual(f.read(10), some_data[10:])


def test_interleaved_reads(self):
Expand Down

0 comments on commit 4bd71fe

Please sign in to comment.