Skip to content

Commit

Permalink
Testing: unify http client mock (#1242)
Browse files Browse the repository at this point in the history
  • Loading branch information
richardm-stripe committed Feb 16, 2024
1 parent 3e2a1e3 commit aa7d8bf
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 160 deletions.
20 changes: 9 additions & 11 deletions tests/api_resources/abstract/test_custom_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def test_call_custom_list_method_class_paginates(self, http_client_mock):

assert ids == ["cus_1", "cus_2", "cus_3"]

def test_call_custom_stream_method_class(self, http_client_mock_streaming):
http_client_mock_streaming.stub_request(
def test_call_custom_stream_method_class(self, http_client_mock):
http_client_mock.stub_request(
"post",
path="/v1/myresources/mid/do_the_stream_thing",
rbody=util.io.BytesIO(str.encode("response body")),
Expand All @@ -119,7 +119,7 @@ def test_call_custom_stream_method_class(self, http_client_mock_streaming):

resp = self.MyResource.do_stream_stuff("mid", foo="bar")

http_client_mock_streaming.assert_requested(
http_client_mock.assert_requested(
"post",
path="/v1/myresources/mid/do_the_stream_thing",
post_data="foo=bar",
Expand Down Expand Up @@ -150,9 +150,9 @@ def test_call_custom_method_class_with_object(self, http_client_mock):
assert obj.thing_done is True

def test_call_custom_stream_method_class_with_object(
self, http_client_mock_streaming
self, http_client_mock
):
http_client_mock_streaming.stub_request(
http_client_mock.stub_request(
"post",
path="/v1/myresources/mid/do_the_stream_thing",
rbody=util.io.BytesIO(str.encode("response body")),
Expand All @@ -162,7 +162,7 @@ def test_call_custom_stream_method_class_with_object(
obj = self.MyResource.construct_from({"id": "mid"}, "mykey")
resp = self.MyResource.do_stream_stuff(obj, foo="bar")

http_client_mock_streaming.assert_requested(
http_client_mock.assert_requested(
"post",
path="/v1/myresources/mid/do_the_stream_thing",
post_data="foo=bar",
Expand Down Expand Up @@ -192,10 +192,8 @@ def test_call_custom_method_instance(self, http_client_mock):
)
assert obj.thing_done is True

def test_call_custom_stream_method_instance(
self, http_client_mock_streaming
):
http_client_mock_streaming.stub_request(
def test_call_custom_stream_method_instance(self, http_client_mock):
http_client_mock.stub_request(
"post",
path="/v1/myresources/mid/do_the_stream_thing",
rbody=util.io.BytesIO(str.encode("response body")),
Expand All @@ -205,7 +203,7 @@ def test_call_custom_stream_method_instance(
obj = self.MyResource.construct_from({"id": "mid"}, "mykey")
resp = obj.do_stream_stuff(foo="bar")

http_client_mock_streaming.assert_requested(
http_client_mock.assert_requested(
"post",
path="/v1/myresources/mid/do_the_stream_thing",
post_data="foo=bar",
Expand Down
8 changes: 4 additions & 4 deletions tests/api_resources/test_quote.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,10 @@ def test_can_list_computed_upfront_line_items_classmethod(
)
assert isinstance(resources.data[0], stripe.LineItem)

def test_can_pdf(self, setup_upload_api_base, http_client_mock_streaming):
def test_can_pdf(self, setup_upload_api_base, http_client_mock):
resource = stripe.Quote.retrieve(TEST_RESOURCE_ID)
stream = resource.pdf()
http_client_mock_streaming.assert_requested(
http_client_mock.assert_requested(
"get",
api_base=stripe.upload_api_base,
path="/v1/quotes/%s/pdf" % TEST_RESOURCE_ID,
Expand All @@ -152,10 +152,10 @@ def test_can_pdf(self, setup_upload_api_base, http_client_mock_streaming):
assert content == b"Stripe binary response"

def test_can_pdf_classmethod(
self, setup_upload_api_base, http_client_mock_streaming
self, setup_upload_api_base, http_client_mock
):
stream = stripe.Quote.pdf(TEST_RESOURCE_ID)
http_client_mock_streaming.assert_requested(
http_client_mock.assert_requested(
"get",
api_base=stripe.upload_api_base,
path="/v1/quotes/%s/pdf" % TEST_RESOURCE_ID,
Expand Down
27 changes: 0 additions & 27 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,6 @@ def http_client_mock(mocker):
stripe.default_http_client = old_client


@pytest.fixture
def http_client_mock_streaming(mocker):
mock_client = HTTPClientMock(mocker, is_streaming=True)
old_client = stripe.default_http_client
stripe.default_http_client = mock_client.get_mock_http_client()
yield mock_client
stripe.default_http_client = old_client


@pytest.fixture
def stripe_mock_stripe_client(http_client_mock):
return StripeClient(
Expand All @@ -115,21 +106,3 @@ def file_stripe_mock_stripe_client(http_client_mock):
base_addresses={"files": MOCK_API_BASE},
http_client=http_client_mock.get_mock_http_client(),
)


@pytest.fixture
def stripe_mock_stripe_client_streaming(http_client_mock_streaming):
return StripeClient(
MOCK_API_KEY,
base_addresses={"api": MOCK_API_BASE},
http_client=http_client_mock_streaming.get_mock_http_client(),
)


@pytest.fixture
def file_stripe_mock_stripe_client_streaming(http_client_mock_streaming):
return StripeClient(
MOCK_API_KEY,
base_addresses={"files": MOCK_API_BASE},
http_client=http_client_mock_streaming.get_mock_http_client(),
)
148 changes: 73 additions & 75 deletions tests/http_client_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,28 +212,19 @@ def assert_post_data(self, expected, is_json=False):


class HTTPClientMock(object):
def __init__(self, mocker, is_streaming=False, is_async=False):
if is_async:
self.mock_client = mocker.Mock(
wraps=stripe.http_client.new_default_http_client_async()
)
else:
self.mock_client = mocker.Mock(
wraps=stripe.http_client.new_default_http_client()
)
def __init__(self, mocker):
self.mock_client = mocker.Mock(
wraps=stripe.http_client.new_default_http_client()
)

self.is_async = is_async
self.mock_client._verify_ssl_certs = True
self.mock_client.name = "mockclient"
if is_async and is_streaming:
self.func = self.mock_client.request_stream_with_retries_async
elif is_async and not is_streaming:
self.func = self.mock_client.request_with_retries_async
elif is_streaming:
self.func = self.mock_client.request_stream_with_retries
else:
self.func = self.mock_client.request_with_retries
self.registered_responses = {}
self.funcs = [
self.mock_client.request_with_retries,
self.mock_client.request_stream_with_retries,
]
self.func_call_order = []

def get_mock_http_client(self) -> Mock:
return self.mock_client
Expand All @@ -247,73 +238,78 @@ def stub_request(
rcode=200,
rheaders={},
) -> None:
def custom_side_effect(called_method, called_abs_url, *args, **kwargs):
called_path = urlsplit(called_abs_url).path
called_query = ""
if urlsplit(called_abs_url).query:
called_query = urlencode(
parse_and_sort(urlsplit(called_abs_url).query)
)
if (
called_method,
called_path,
called_query,
) not in self.registered_responses:
raise AssertionError(
"Unexpected request made to %s %s %s"
% (called_method, called_path, called_query)
)
return self.registered_responses[
(called_method, called_path, called_query)
]

async def awaitable(x):
return x
def custom_side_effect_for_func(func):
def custom_side_effect(
called_method, called_abs_url, *args, **kwargs
):
self.func_call_order.append(func)
called_path = urlsplit(called_abs_url).path
called_query = ""
if urlsplit(called_abs_url).query:
called_query = urlencode(
parse_and_sort(urlsplit(called_abs_url).query)
)
if (
called_method,
called_path,
called_query,
) not in self.registered_responses:
raise AssertionError(
"Unexpected request made to %s %s %s"
% (called_method, called_path, called_query)
)
ret = self.registered_responses[
(called_method, called_path, called_query)
]
return ret

return custom_side_effect

self.registered_responses[
(method, path, urlencode(parse_and_sort(query_string)))
] = (
awaitable(
(
rbody,
rcode,
rheaders,
)
)
if self.is_async
else (rbody, rcode, rheaders)
)
] = (rbody, rcode, rheaders)

self.func.side_effect = custom_side_effect
for func in self.funcs:
func.side_effect = custom_side_effect_for_func(func)

def get_last_call(self) -> StripeRequestCall:
if not self.func.called:
if len(self.func_call_order) == 0:
raise AssertionError(
"Expected request to have been made, but no calls were found."
)
return StripeRequestCall.from_mock_call(self.func.call_args)
return StripeRequestCall.from_mock_call(
self.func_call_order[-1].call_args
)

def get_all_calls(self) -> List[StripeRequestCall]:
calls_by_func = {
func: list(func.call_args_list) for func in self.funcs
}

calls = []
for func in self.func_call_order:
calls.append(calls_by_func[func].pop(0))

return [
StripeRequestCall.from_mock_call(call_args)
for call_args in self.func.call_args_list
StripeRequestCall.from_mock_call(call_args) for call_args in calls
]

def find_call(
self, method, api_base, path, query_string
) -> StripeRequestCall:
for call_args in self.func.call_args_list:
request_call = StripeRequestCall.from_mock_call(call_args)
try:
if request_call.check(
method=method,
api_base=api_base,
path=path,
query_string=query_string,
):
return request_call
except AssertionError:
pass
for func in self.funcs:
for call_args in func.call_args_list:
request_call = StripeRequestCall.from_mock_call(call_args)
try:
if request_call.check(
method=method,
api_base=api_base,
path=path,
query_string=query_string,
):
return request_call
except AssertionError:
pass
raise AssertionError(
"Expected request to have been made, but no calls were found."
)
Expand Down Expand Up @@ -369,13 +365,15 @@ def assert_requested(
)

def assert_no_request(self):
if self.func.called:
msg = (
"Expected no request to have been made, but %s calls were "
"found." % (self.func.call_count)
)
raise AssertionError(msg)
for func in self.funcs:
if func.called:
msg = (
"Expected no request to have been made, but %s calls were "
"found." % (sum([func.call_count for func in self.funcs]))
)
raise AssertionError(msg)

def reset_mock(self):
self.func.reset_mock()
for func in self.funcs:
func.reset_mock()
self.registered_responses = {}
10 changes: 4 additions & 6 deletions tests/services/test_quote.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,11 @@ def test_can_list_computed_upfront_line_items(

def test_can_pdf(
self,
file_stripe_mock_stripe_client_streaming,
http_client_mock_streaming,
file_stripe_mock_stripe_client,
http_client_mock,
):
stream = file_stripe_mock_stripe_client_streaming.quotes.pdf(
TEST_RESOURCE_ID
)
http_client_mock_streaming.assert_requested(
stream = file_stripe_mock_stripe_client.quotes.pdf(TEST_RESOURCE_ID)
http_client_mock.assert_requested(
"get",
api_base=stripe.upload_api_base,
path="/v1/quotes/%s/pdf" % TEST_RESOURCE_ID,
Expand Down

0 comments on commit aa7d8bf

Please sign in to comment.