diff --git a/src/runwayml/_base_client.py b/src/runwayml/_base_client.py index 084b078..7257e30 100644 --- a/src/runwayml/_base_client.py +++ b/src/runwayml/_base_client.py @@ -400,14 +400,7 @@ def _make_status_error( ) -> _exceptions.APIStatusError: raise NotImplementedError() - def _remaining_retries( - self, - remaining_retries: Optional[int], - options: FinalRequestOptions, - ) -> int: - return remaining_retries if remaining_retries is not None else options.get_max_retries(self.max_retries) - - def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers: + def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0) -> httpx.Headers: custom_headers = options.headers or {} headers_dict = _merge_mappings(self.default_headers, custom_headers) self._validate_headers(headers_dict, custom_headers) @@ -419,6 +412,11 @@ def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers: if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers: headers[idempotency_header] = options.idempotency_key or self._idempotency_key() + # Don't set the retry count header if it was already set or removed by the caller. We check + # `custom_headers`, which can contain `Omit()`, instead of `headers` to account for the removal case. + if "x-stainless-retry-count" not in (header.lower() for header in custom_headers): + headers["x-stainless-retry-count"] = str(retries_taken) + return headers def _prepare_url(self, url: str) -> URL: @@ -440,6 +438,8 @@ def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder: def _build_request( self, options: FinalRequestOptions, + *, + retries_taken: int = 0, ) -> httpx.Request: if log.isEnabledFor(logging.DEBUG): log.debug("Request options: %s", model_dump(options, exclude_unset=True)) @@ -455,7 +455,7 @@ def _build_request( else: raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`") - headers = self._build_headers(options) + headers = self._build_headers(options, retries_taken=retries_taken) params = _merge_mappings(self.default_query, options.params) content_type = headers.get("Content-Type") files = options.files @@ -489,12 +489,17 @@ def _build_request( if not files: files = cast(HttpxRequestFiles, ForceMultipartDict()) + prepared_url = self._prepare_url(options.url) + if "_" in prepared_url.host: + # work around https://github.com/encode/httpx/discussions/2880 + kwargs["extensions"] = {"sni_hostname": prepared_url.host.replace("_", "-")} + # TODO: report this error to httpx return self._client.build_request( # pyright: ignore[reportUnknownMemberType] headers=headers, timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout, method=options.method, - url=self._prepare_url(options.url), + url=prepared_url, # the `Query` type that we use is incompatible with qs' # `Params` type as it needs to be typed as `Mapping[str, object]` # so that passing a `TypedDict` doesn't cause an error. @@ -933,12 +938,17 @@ def request( stream: bool = False, stream_cls: type[_StreamT] | None = None, ) -> ResponseT | _StreamT: + if remaining_retries is not None: + retries_taken = options.get_max_retries(self.max_retries) - remaining_retries + else: + retries_taken = 0 + return self._request( cast_to=cast_to, options=options, stream=stream, stream_cls=stream_cls, - remaining_retries=remaining_retries, + retries_taken=retries_taken, ) def _request( @@ -946,7 +956,7 @@ def _request( *, cast_to: Type[ResponseT], options: FinalRequestOptions, - remaining_retries: int | None, + retries_taken: int, stream: bool, stream_cls: type[_StreamT] | None, ) -> ResponseT | _StreamT: @@ -958,8 +968,8 @@ def _request( cast_to = self._maybe_override_cast_to(cast_to, options) options = self._prepare_options(options) - retries = self._remaining_retries(remaining_retries, options) - request = self._build_request(options) + remaining_retries = options.get_max_retries(self.max_retries) - retries_taken + request = self._build_request(options, retries_taken=retries_taken) self._prepare_request(request) kwargs: HttpxSendArgs = {} @@ -977,11 +987,11 @@ def _request( except httpx.TimeoutException as err: log.debug("Encountered httpx.TimeoutException", exc_info=True) - if retries > 0: + if remaining_retries > 0: return self._retry_request( input_options, cast_to, - retries, + retries_taken=retries_taken, stream=stream, stream_cls=stream_cls, response_headers=None, @@ -992,11 +1002,11 @@ def _request( except Exception as err: log.debug("Encountered Exception", exc_info=True) - if retries > 0: + if remaining_retries > 0: return self._retry_request( input_options, cast_to, - retries, + retries_taken=retries_taken, stream=stream, stream_cls=stream_cls, response_headers=None, @@ -1019,13 +1029,13 @@ def _request( except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code log.debug("Encountered httpx.HTTPStatusError", exc_info=True) - if retries > 0 and self._should_retry(err.response): + if remaining_retries > 0 and self._should_retry(err.response): err.response.close() return self._retry_request( input_options, cast_to, - retries, - err.response.headers, + retries_taken=retries_taken, + response_headers=err.response.headers, stream=stream, stream_cls=stream_cls, ) @@ -1044,26 +1054,26 @@ def _request( response=response, stream=stream, stream_cls=stream_cls, - retries_taken=options.get_max_retries(self.max_retries) - retries, + retries_taken=retries_taken, ) def _retry_request( self, options: FinalRequestOptions, cast_to: Type[ResponseT], - remaining_retries: int, - response_headers: httpx.Headers | None, *, + retries_taken: int, + response_headers: httpx.Headers | None, stream: bool, stream_cls: type[_StreamT] | None, ) -> ResponseT | _StreamT: - remaining = remaining_retries - 1 - if remaining == 1: + remaining_retries = options.get_max_retries(self.max_retries) - retries_taken + if remaining_retries == 1: log.debug("1 retry left") else: - log.debug("%i retries left", remaining) + log.debug("%i retries left", remaining_retries) - timeout = self._calculate_retry_timeout(remaining, options, response_headers) + timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers) log.info("Retrying request to %s in %f seconds", options.url, timeout) # In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a @@ -1073,7 +1083,7 @@ def _retry_request( return self._request( options=options, cast_to=cast_to, - remaining_retries=remaining, + retries_taken=retries_taken + 1, stream=stream, stream_cls=stream_cls, ) @@ -1491,12 +1501,17 @@ async def request( stream_cls: type[_AsyncStreamT] | None = None, remaining_retries: Optional[int] = None, ) -> ResponseT | _AsyncStreamT: + if remaining_retries is not None: + retries_taken = options.get_max_retries(self.max_retries) - remaining_retries + else: + retries_taken = 0 + return await self._request( cast_to=cast_to, options=options, stream=stream, stream_cls=stream_cls, - remaining_retries=remaining_retries, + retries_taken=retries_taken, ) async def _request( @@ -1506,7 +1521,7 @@ async def _request( *, stream: bool, stream_cls: type[_AsyncStreamT] | None, - remaining_retries: int | None, + retries_taken: int, ) -> ResponseT | _AsyncStreamT: if self._platform is None: # `get_platform` can make blocking IO calls so we @@ -1521,8 +1536,8 @@ async def _request( cast_to = self._maybe_override_cast_to(cast_to, options) options = await self._prepare_options(options) - retries = self._remaining_retries(remaining_retries, options) - request = self._build_request(options) + remaining_retries = options.get_max_retries(self.max_retries) - retries_taken + request = self._build_request(options, retries_taken=retries_taken) await self._prepare_request(request) kwargs: HttpxSendArgs = {} @@ -1538,11 +1553,11 @@ async def _request( except httpx.TimeoutException as err: log.debug("Encountered httpx.TimeoutException", exc_info=True) - if retries > 0: + if remaining_retries > 0: return await self._retry_request( input_options, cast_to, - retries, + retries_taken=retries_taken, stream=stream, stream_cls=stream_cls, response_headers=None, @@ -1553,11 +1568,11 @@ async def _request( except Exception as err: log.debug("Encountered Exception", exc_info=True) - if retries > 0: + if retries_taken > 0: return await self._retry_request( input_options, cast_to, - retries, + retries_taken=retries_taken, stream=stream, stream_cls=stream_cls, response_headers=None, @@ -1575,13 +1590,13 @@ async def _request( except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code log.debug("Encountered httpx.HTTPStatusError", exc_info=True) - if retries > 0 and self._should_retry(err.response): + if remaining_retries > 0 and self._should_retry(err.response): await err.response.aclose() return await self._retry_request( input_options, cast_to, - retries, - err.response.headers, + retries_taken=retries_taken, + response_headers=err.response.headers, stream=stream, stream_cls=stream_cls, ) @@ -1600,26 +1615,26 @@ async def _request( response=response, stream=stream, stream_cls=stream_cls, - retries_taken=options.get_max_retries(self.max_retries) - retries, + retries_taken=retries_taken, ) async def _retry_request( self, options: FinalRequestOptions, cast_to: Type[ResponseT], - remaining_retries: int, - response_headers: httpx.Headers | None, *, + retries_taken: int, + response_headers: httpx.Headers | None, stream: bool, stream_cls: type[_AsyncStreamT] | None, ) -> ResponseT | _AsyncStreamT: - remaining = remaining_retries - 1 - if remaining == 1: + remaining_retries = options.get_max_retries(self.max_retries) - retries_taken + if remaining_retries == 1: log.debug("1 retry left") else: - log.debug("%i retries left", remaining) + log.debug("%i retries left", remaining_retries) - timeout = self._calculate_retry_timeout(remaining, options, response_headers) + timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers) log.info("Retrying request to %s in %f seconds", options.url, timeout) await anyio.sleep(timeout) @@ -1627,7 +1642,7 @@ async def _retry_request( return await self._request( options=options, cast_to=cast_to, - remaining_retries=remaining, + retries_taken=retries_taken + 1, stream=stream, stream_cls=stream_cls, ) diff --git a/src/runwayml/_compat.py b/src/runwayml/_compat.py index 21fe694..162a6fb 100644 --- a/src/runwayml/_compat.py +++ b/src/runwayml/_compat.py @@ -136,12 +136,14 @@ def model_dump( exclude: IncEx = None, exclude_unset: bool = False, exclude_defaults: bool = False, + warnings: bool = True, ) -> dict[str, Any]: if PYDANTIC_V2: return model.model_dump( exclude=exclude, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, + warnings=warnings, ) return cast( "dict[str, Any]", diff --git a/tests/test_client.py b/tests/test_client.py index a7f7d3e..e5f0b7c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -769,6 +769,57 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: ) assert response.retries_taken == failures_before_success + assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("runwayml._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + def test_omit_retry_count_header( + self, client: RunwayML, failures_before_success: int, respx_mock: MockRouter + ) -> None: + client = client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/v1/image_to_video").mock(side_effect=retry_handler) + + response = client.image_to_video.with_raw_response.create( + model="gen3a_turbo", prompt_image="https://example.com", extra_headers={"x-stainless-retry-count": Omit()} + ) + + assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0 + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("runwayml._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + def test_overwrite_retry_count_header( + self, client: RunwayML, failures_before_success: int, respx_mock: MockRouter + ) -> None: + client = client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/v1/image_to_video").mock(side_effect=retry_handler) + + response = client.image_to_video.with_raw_response.create( + model="gen3a_turbo", prompt_image="https://example.com", extra_headers={"x-stainless-retry-count": "42"} + ) + + assert response.http_request.headers.get("x-stainless-retry-count") == "42" class TestAsyncRunwayML: @@ -1507,3 +1558,56 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: ) assert response.retries_taken == failures_before_success + assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("runwayml._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + @pytest.mark.asyncio + async def test_omit_retry_count_header( + self, async_client: AsyncRunwayML, failures_before_success: int, respx_mock: MockRouter + ) -> None: + client = async_client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/v1/image_to_video").mock(side_effect=retry_handler) + + response = await client.image_to_video.with_raw_response.create( + model="gen3a_turbo", prompt_image="https://example.com", extra_headers={"x-stainless-retry-count": Omit()} + ) + + assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0 + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("runwayml._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + @pytest.mark.asyncio + async def test_overwrite_retry_count_header( + self, async_client: AsyncRunwayML, failures_before_success: int, respx_mock: MockRouter + ) -> None: + client = async_client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/v1/image_to_video").mock(side_effect=retry_handler) + + response = await client.image_to_video.with_raw_response.create( + model="gen3a_turbo", prompt_image="https://example.com", extra_headers={"x-stainless-retry-count": "42"} + ) + + assert response.http_request.headers.get("x-stainless-retry-count") == "42"