diff --git a/stripe/_api_requestor.py b/stripe/_api_requestor.py index fe75e0f0b..6dd202478 100644 --- a/stripe/_api_requestor.py +++ b/stripe/_api_requestor.py @@ -366,18 +366,17 @@ def request_headers(self, method, options: RequestOptions): return headers - def request_raw( + def _args_for_request_with_retries( self, method: str, url: str, params: Optional[Mapping[str, Any]] = None, options: Optional[RequestOptions] = None, - is_streaming: bool = False, *, base_address: BaseAddress, api_mode: ApiMode, _usage: Optional[List[str]] = None, - ) -> Tuple[object, int, Mapping[str, str]]: + ): """ Mechanism for issuing an API call """ @@ -446,11 +445,55 @@ def request_raw( for key, value in supplied_headers.items(): headers[key] = value + max_network_retries = request_options.get("max_network_retries") + + return ( + # Actual args + method, + abs_url, + headers, + post_data, + max_network_retries, + _usage, + # For logging + encoded_params, + request_options.get("stripe_version"), + ) + + def request_raw( + self, + method: str, + url: str, + params: Optional[Mapping[str, Any]] = None, + options: Optional[RequestOptions] = None, + is_streaming: bool = False, + *, + base_address: BaseAddress, + api_mode: ApiMode, + _usage: Optional[List[str]] = None, + ) -> Tuple[object, int, Mapping[str, str]]: + ( + method, + abs_url, + headers, + post_data, + max_network_retries, + _usage, + encoded_params, + api_version, + ) = self._args_for_request_with_retries( + method, + url, + params, + options, + base_address=base_address, + api_mode=api_mode, + _usage=_usage, + ) + log_info("Request to Stripe api", method=method, url=abs_url) log_debug( - "Post details", - post_data=encoded_params, - api_version=request_options.get("stripe_version"), + "Post details", post_data=encoded_params, api_version=api_version ) if is_streaming: @@ -463,7 +506,7 @@ def request_raw( abs_url, headers, post_data, - max_network_retries=request_options.get("max_network_retries"), + max_network_retries=max_network_retries, _usage=_usage, ) else: @@ -476,7 +519,7 @@ def request_raw( abs_url, headers, post_data, - max_network_retries=request_options.get("max_network_retries"), + max_network_retries=max_network_retries, _usage=_usage, ) diff --git a/tests/test_integration.py b/tests/test_integration.py index 81b81d2a8..74e44a4c4 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -20,7 +20,7 @@ from http.server import BaseHTTPRequestHandler, HTTPServer -class TestHandler(BaseHTTPRequestHandler): +class MyTestHandler(BaseHTTPRequestHandler): num_requests = 0 requests = defaultdict(Queue) @@ -118,7 +118,7 @@ def setup_mock_server(self, handler): self.mock_server_thread.start() def test_hits_api_base(self): - class MockServerRequestHandler(TestHandler): + class MockServerRequestHandler(MyTestHandler): pass self.setup_mock_server(MockServerRequestHandler) @@ -129,7 +129,7 @@ class MockServerRequestHandler(TestHandler): assert reqs[0].path == "/v1/balance" def test_hits_proxy_through_default_http_client(self): - class MockServerRequestHandler(TestHandler): + class MockServerRequestHandler(MyTestHandler): pass self.setup_mock_server(MockServerRequestHandler) @@ -150,7 +150,7 @@ class MockServerRequestHandler(TestHandler): assert MockServerRequestHandler.num_requests == 2 def test_hits_proxy_through_custom_client(self): - class MockServerRequestHandler(TestHandler): + class MockServerRequestHandler(MyTestHandler): pass self.setup_mock_server(MockServerRequestHandler) @@ -164,7 +164,7 @@ class MockServerRequestHandler(TestHandler): assert MockServerRequestHandler.num_requests == 1 def test_hits_proxy_through_stripe_client_proxy(self): - class MockServerRequestHandler(TestHandler): + class MockServerRequestHandler(MyTestHandler): pass self.setup_mock_server(MockServerRequestHandler) @@ -179,7 +179,7 @@ class MockServerRequestHandler(TestHandler): assert MockServerRequestHandler.num_requests == 1 def test_hits_proxy_through_stripe_client_http_client(self): - class MockServerRequestHandler(TestHandler): + class MockServerRequestHandler(MyTestHandler): pass self.setup_mock_server(MockServerRequestHandler) @@ -196,7 +196,7 @@ class MockServerRequestHandler(TestHandler): assert MockServerRequestHandler.num_requests == 1 def test_passes_client_telemetry_when_enabled(self): - class MockServerRequestHandler(TestHandler): + class MockServerRequestHandler(MyTestHandler): def do_request(self, req_num): if req_num == 0: time.sleep(31 / 1000) # 31 ms @@ -248,7 +248,7 @@ def do_request(self, req_num): assert "usage" not in metrics def test_uses_thread_local_client_telemetry(self): - class MockServerRequestHandler(TestHandler): + class MockServerRequestHandler(MyTestHandler): local_num_requests = 0 seen_metrics = set() stats_lock = Lock()