Skip to content

Commit

Permalink
Ensure body is consumed only once
Browse files Browse the repository at this point in the history
Fixes: kevin1024#846
Signed-off-by: Mathieu Parent <math.parent@gmail.com>
  • Loading branch information
sathieu committed Jul 21, 2024
1 parent 042e16c commit 241b0bb
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 3 deletions.
54 changes: 53 additions & 1 deletion tests/unit/test_stubs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import contextlib
import http.client as httplib
from io import BytesIO
from tempfile import NamedTemporaryFile
from unittest import mock

from pytest import mark

from vcr import mode
from vcr import mode, use_cassette
from vcr.cassette import Cassette
from vcr.stubs import VCRHTTPSConnection

Expand All @@ -21,3 +24,52 @@ def testing_connect(*args):
vcr_connection.cassette = Cassette("test", record_mode=mode.ALL)
vcr_connection.real_connection.connect()
assert vcr_connection.real_connection.sock is not None

def test_body_consumed_once_stream(self, tmpdir, httpbin):
self._test_body_consumed_once(
tmpdir,
httpbin,
BytesIO(b"1234567890"),
BytesIO(b"9876543210"),
BytesIO(b"9876543210"),
)

def test_body_consumed_once_iterator(self, tmpdir, httpbin):
self._test_body_consumed_once(
tmpdir,
httpbin,
iter([b"1234567890"]),
iter([b"9876543210"]),
iter([b"9876543210"]),
)

# data2 and data3 should serve the same data, potentially as iterators
def _test_body_consumed_once(
self,
tmpdir,
httpbin,
data1,
data2,
data3,
):
with NamedTemporaryFile(dir=tmpdir, suffix=".yml") as f:
testpath = f.name
# NOTE: ``use_cassette`` is not okay with the file existing
# already. So we using ``.close()`` to not only
# close but also delete the empty file, before we start.
f.close()
host, port = httpbin.host, httpbin.port
match_on = ["method", "uri", "body"]
with use_cassette(testpath, match_on=match_on):
conn1 = httplib.HTTPConnection(host, port)
conn1.request("POST", "/anything", body=data1)
conn1.getresponse()
conn2 = httplib.HTTPConnection(host, port)
conn2.request("POST", "/anything", body=data2)
conn2.getresponse()
with use_cassette(testpath, match_on=match_on) as cass:
conn3 = httplib.HTTPConnection(host, port)
conn3.request("POST", "/anything", body=data3)
conn3.getresponse()
assert cass.play_counts[0] == 0
assert cass.play_counts[1] == 1
33 changes: 33 additions & 0 deletions tests/unit/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from io import BytesIO, StringIO

import pytest

from vcr import request
from vcr.util import read_body


@pytest.mark.parametrize(
"input_, expected_output",
[
(BytesIO(b"Stream"), b"Stream"),
(StringIO("Stream"), b"Stream"),
(iter(["StringIter"]), b"StringIter"),
(iter(["String", "Iter"]), b"StringIter"),
(iter([b"BytesIter"]), b"BytesIter"),
(iter([b"Bytes", b"Iter"]), b"BytesIter"),
(iter([70, 111, 111]), b"Foo"),
(iter([]), b""),
("String", b"String"),
(b"Bytes", b"Bytes"),
],
)
def test_read_body(input_, expected_output):
r = request.Request("POST", "http://host.com/", input_, {})
assert read_body(r) == expected_output


def test_unsupported_read_body():
r = request.Request("POST", "http://host.com/", iter([[]]), {})
with pytest.raises(ValueError) as excinfo:
assert read_body(r)
assert excinfo.value.args == ("Body type <class 'list'> not supported",)
11 changes: 9 additions & 2 deletions vcr/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from io import BytesIO
from urllib.parse import parse_qsl, urlparse

from .util import CaseInsensitiveDict
from .util import CaseInsensitiveDict, _is_nonsequence_iterator

log = logging.getLogger(__name__)

Expand All @@ -17,8 +17,11 @@ def __init__(self, method, uri, body, headers):
self.method = method
self.uri = uri
self._was_file = hasattr(body, "read")
self._was_iter = _is_nonsequence_iterator(body)
if self._was_file:
self.body = body.read()
elif self._was_iter:
self.body = list(body)
else:
self.body = body
self.headers = headers
Expand All @@ -36,7 +39,11 @@ def headers(self, value):

@property
def body(self):
return BytesIO(self._body) if self._was_file else self._body
if self._was_file:
return BytesIO(self._body)
if self._was_iter:
return iter(self._body)
return self._body

@body.setter
def body(self, value):
Expand Down
19 changes: 19 additions & 0 deletions vcr/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,28 @@ def composed(incoming):
return composed


def _is_nonsequence_iterator(obj):
return hasattr(obj, "__iter__") and not isinstance(
obj,
(bytearray, bytes, dict, list, str),
)


def read_body(request):
if hasattr(request.body, "read"):
return request.body.read()
if _is_nonsequence_iterator(request.body):
body = list(request.body)
if body:
if isinstance(body[0], str):
return "".join(body).encode("utf-8")
elif isinstance(body[0], (bytes, bytearray)):
return b"".join(body)
elif isinstance(body[0], int):
return bytes(body)
else:
raise ValueError(f"Body type {type(body[0])} not supported")
return b""
return request.body


Expand Down

0 comments on commit 241b0bb

Please sign in to comment.