Skip to content

Commit

Permalink
Refactors for beta (#1212)
Browse files Browse the repository at this point in the history
* MyTestHandler rename

* Extract from _api_requestor.request_raw

* Fix
  • Loading branch information
richardm-stripe committed Jan 26, 2024
1 parent 9821778 commit 28d35d8
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 16 deletions.
59 changes: 51 additions & 8 deletions stripe/_api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,18 +366,17 @@ def request_headers(self, method, options: RequestOptions):

return headers

def request_raw(
def _args_for_request_with_retries(
self,
method: str,
url: str,
params: Optional[Mapping[str, Any]] = None,
options: Optional[RequestOptions] = None,
is_streaming: bool = False,
*,
base_address: BaseAddress,
api_mode: ApiMode,
_usage: Optional[List[str]] = None,
) -> Tuple[object, int, Mapping[str, str]]:
):
"""
Mechanism for issuing an API call
"""
Expand Down Expand Up @@ -446,11 +445,55 @@ def request_raw(
for key, value in supplied_headers.items():
headers[key] = value

max_network_retries = request_options.get("max_network_retries")

return (
# Actual args
method,
abs_url,
headers,
post_data,
max_network_retries,
_usage,
# For logging
encoded_params,
request_options.get("stripe_version"),
)

def request_raw(
self,
method: str,
url: str,
params: Optional[Mapping[str, Any]] = None,
options: Optional[RequestOptions] = None,
is_streaming: bool = False,
*,
base_address: BaseAddress,
api_mode: ApiMode,
_usage: Optional[List[str]] = None,
) -> Tuple[object, int, Mapping[str, str]]:
(
method,
abs_url,
headers,
post_data,
max_network_retries,
_usage,
encoded_params,
api_version,
) = self._args_for_request_with_retries(
method,
url,
params,
options,
base_address=base_address,
api_mode=api_mode,
_usage=_usage,
)

log_info("Request to Stripe api", method=method, url=abs_url)
log_debug(
"Post details",
post_data=encoded_params,
api_version=request_options.get("stripe_version"),
"Post details", post_data=encoded_params, api_version=api_version
)

if is_streaming:
Expand All @@ -463,7 +506,7 @@ def request_raw(
abs_url,
headers,
post_data,
max_network_retries=request_options.get("max_network_retries"),
max_network_retries=max_network_retries,
_usage=_usage,
)
else:
Expand All @@ -476,7 +519,7 @@ def request_raw(
abs_url,
headers,
post_data,
max_network_retries=request_options.get("max_network_retries"),
max_network_retries=max_network_retries,
_usage=_usage,
)

Expand Down
16 changes: 8 additions & 8 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from http.server import BaseHTTPRequestHandler, HTTPServer


class TestHandler(BaseHTTPRequestHandler):
class MyTestHandler(BaseHTTPRequestHandler):
num_requests = 0

requests = defaultdict(Queue)
Expand Down Expand Up @@ -118,7 +118,7 @@ def setup_mock_server(self, handler):
self.mock_server_thread.start()

def test_hits_api_base(self):
class MockServerRequestHandler(TestHandler):
class MockServerRequestHandler(MyTestHandler):
pass

self.setup_mock_server(MockServerRequestHandler)
Expand All @@ -129,7 +129,7 @@ class MockServerRequestHandler(TestHandler):
assert reqs[0].path == "/v1/balance"

def test_hits_proxy_through_default_http_client(self):
class MockServerRequestHandler(TestHandler):
class MockServerRequestHandler(MyTestHandler):
pass

self.setup_mock_server(MockServerRequestHandler)
Expand All @@ -150,7 +150,7 @@ class MockServerRequestHandler(TestHandler):
assert MockServerRequestHandler.num_requests == 2

def test_hits_proxy_through_custom_client(self):
class MockServerRequestHandler(TestHandler):
class MockServerRequestHandler(MyTestHandler):
pass

self.setup_mock_server(MockServerRequestHandler)
Expand All @@ -164,7 +164,7 @@ class MockServerRequestHandler(TestHandler):
assert MockServerRequestHandler.num_requests == 1

def test_hits_proxy_through_stripe_client_proxy(self):
class MockServerRequestHandler(TestHandler):
class MockServerRequestHandler(MyTestHandler):
pass

self.setup_mock_server(MockServerRequestHandler)
Expand All @@ -179,7 +179,7 @@ class MockServerRequestHandler(TestHandler):
assert MockServerRequestHandler.num_requests == 1

def test_hits_proxy_through_stripe_client_http_client(self):
class MockServerRequestHandler(TestHandler):
class MockServerRequestHandler(MyTestHandler):
pass

self.setup_mock_server(MockServerRequestHandler)
Expand All @@ -196,7 +196,7 @@ class MockServerRequestHandler(TestHandler):
assert MockServerRequestHandler.num_requests == 1

def test_passes_client_telemetry_when_enabled(self):
class MockServerRequestHandler(TestHandler):
class MockServerRequestHandler(MyTestHandler):
def do_request(self, req_num):
if req_num == 0:
time.sleep(31 / 1000) # 31 ms
Expand Down Expand Up @@ -248,7 +248,7 @@ def do_request(self, req_num):
assert "usage" not in metrics

def test_uses_thread_local_client_telemetry(self):
class MockServerRequestHandler(TestHandler):
class MockServerRequestHandler(MyTestHandler):
local_num_requests = 0
seen_metrics = set()
stats_lock = Lock()
Expand Down

0 comments on commit 28d35d8

Please sign in to comment.