diff --git a/.stats.yml b/.stats.yml index 576d7e0..d19dda0 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,2 +1,2 @@ configured_endpoints: 3 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/runwayml%2Frunwayml-e9db3689e5377f05e22b2dd594ac7eea68b859b98ba4d18103ed7376dbd43a4f.yml +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/runwayml%2Frunwayml-256f9e345a6be31b0c9fc494e70c5ff977f52316c2a7b07f388154522e74d7bb.yml diff --git a/README.md b/README.md index 5389808..df47db1 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![PyPI version](https://img.shields.io/pypi/v/runwayml.svg)](https://pypi.org/project/runwayml/) -The RunwayML Python library provides convenient access to the RunwayML REST API from any Python 3.7+ +The RunwayML Python library provides convenient access to the RunwayML REST API from any Python 3.8+ application. The library includes type definitions for all request params and response fields, and offers both synchronous and asynchronous clients powered by [httpx](https://github.com/encode/httpx). @@ -342,7 +342,7 @@ print(runwayml.__version__) ## Requirements -Python 3.7 or higher. +Python 3.8 or higher. ## Contributing diff --git a/pyproject.toml b/pyproject.toml index 91a735a..96c9d1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,11 +16,10 @@ dependencies = [ "sniffio", "cached-property; python_version < '3.8'", ] -requires-python = ">= 3.7" +requires-python = ">= 3.8" classifiers = [ "Typing :: Typed", "Intended Audience :: Developers", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", @@ -63,11 +62,11 @@ format = { chain = [ "format:ruff", "format:docs", "fix:ruff", + # run formatting again to fix any inconsistencies when imports are stripped + "format:ruff", ]} -"format:black" = "black ." "format:docs" = "python scripts/utils/ruffen-docs.py README.md api.md" "format:ruff" = "ruff format" -"format:isort" = "isort ." "lint" = { chain = [ "check:ruff", @@ -125,10 +124,6 @@ path = "README.md" pattern = '\[(.+?)\]\(((?!https?://)\S+?)\)' replacement = '[\1](https://github.com/runwayml/sdk-python/tree/main/\g<2>)' -[tool.black] -line-length = 120 -target-version = ["py37"] - [tool.pytest.ini_options] testpaths = ["tests"] addopts = "--tb=short" @@ -143,7 +138,7 @@ filterwarnings = [ # there are a couple of flags that are still disabled by # default in strict mode as they are experimental and niche. typeCheckingMode = "strict" -pythonVersion = "3.7" +pythonVersion = "3.8" exclude = [ "_dev", diff --git a/requirements-dev.lock b/requirements-dev.lock index fbe99fa..023cb22 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -16,8 +16,6 @@ anyio==4.4.0 # via runwayml argcomplete==3.1.2 # via nox -attrs==23.1.0 - # via pytest certifi==2023.7.22 # via httpcore # via httpx @@ -28,8 +26,9 @@ distlib==0.3.7 # via virtualenv distro==1.8.0 # via runwayml -exceptiongroup==1.1.3 +exceptiongroup==1.2.2 # via anyio + # via pytest filelock==3.12.4 # via virtualenv h11==0.14.0 @@ -49,7 +48,7 @@ markdown-it-py==3.0.0 # via rich mdurl==0.1.2 # via markdown-it-py -mypy==1.11.2 +mypy==1.13.0 mypy-extensions==1.0.0 # via mypy nodeenv==1.8.0 @@ -60,27 +59,25 @@ packaging==23.2 # via pytest platformdirs==3.11.0 # via virtualenv -pluggy==1.3.0 - # via pytest -py==1.11.0 +pluggy==1.5.0 # via pytest -pydantic==2.7.1 +pydantic==2.9.2 # via runwayml -pydantic-core==2.18.2 +pydantic-core==2.23.4 # via pydantic pygments==2.18.0 # via rich pyright==1.1.380 -pytest==7.1.1 +pytest==8.3.3 # via pytest-asyncio -pytest-asyncio==0.21.1 +pytest-asyncio==0.24.0 python-dateutil==2.8.2 # via time-machine pytz==2023.3.post1 # via dirty-equals respx==0.20.2 rich==13.7.1 -ruff==0.6.5 +ruff==0.6.9 setuptools==68.2.2 # via nodeenv six==1.16.0 @@ -90,10 +87,10 @@ sniffio==1.3.0 # via httpx # via runwayml time-machine==2.9.0 -tomli==2.0.1 +tomli==2.0.2 # via mypy # via pytest -typing-extensions==4.8.0 +typing-extensions==4.12.2 # via anyio # via mypy # via pydantic diff --git a/requirements.lock b/requirements.lock index 919ef6b..089546c 100644 --- a/requirements.lock +++ b/requirements.lock @@ -19,7 +19,7 @@ certifi==2023.7.22 # via httpx distro==1.8.0 # via runwayml -exceptiongroup==1.1.3 +exceptiongroup==1.2.2 # via anyio h11==0.14.0 # via httpcore @@ -30,15 +30,15 @@ httpx==0.25.2 idna==3.4 # via anyio # via httpx -pydantic==2.7.1 +pydantic==2.9.2 # via runwayml -pydantic-core==2.18.2 +pydantic-core==2.23.4 # via pydantic sniffio==1.3.0 # via anyio # via httpx # via runwayml -typing-extensions==4.8.0 +typing-extensions==4.12.2 # via anyio # via pydantic # via pydantic-core diff --git a/src/runwayml/_base_client.py b/src/runwayml/_base_client.py index 7257e30..157e55e 100644 --- a/src/runwayml/_base_client.py +++ b/src/runwayml/_base_client.py @@ -143,6 +143,12 @@ def __init__( self.url = url self.params = params + @override + def __repr__(self) -> str: + if self.url: + return f"{self.__class__.__name__}(url={self.url})" + return f"{self.__class__.__name__}(params={self.params})" + class BasePage(GenericModel, Generic[_T]): """ @@ -689,7 +695,8 @@ def _calculate_retry_timeout( if retry_after is not None and 0 < retry_after <= 60: return retry_after - nb_retries = max_retries - remaining_retries + # Also cap retry count to 1000 to avoid any potential overflows with `pow` + nb_retries = min(max_retries - remaining_retries, 1000) # Apply exponential backoff, but not more than the max. sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY) @@ -1568,7 +1575,7 @@ async def _request( except Exception as err: log.debug("Encountered Exception", exc_info=True) - if retries_taken > 0: + if remaining_retries > 0: return await self._retry_request( input_options, cast_to, diff --git a/src/runwayml/_compat.py b/src/runwayml/_compat.py index 162a6fb..4794129 100644 --- a/src/runwayml/_compat.py +++ b/src/runwayml/_compat.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload from datetime import date, datetime -from typing_extensions import Self +from typing_extensions import Self, Literal import pydantic from pydantic.fields import FieldInfo @@ -133,13 +133,15 @@ def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str: def model_dump( model: pydantic.BaseModel, *, - exclude: IncEx = None, + exclude: IncEx | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, warnings: bool = True, + mode: Literal["json", "python"] = "python", ) -> dict[str, Any]: - if PYDANTIC_V2: + if PYDANTIC_V2 or hasattr(model, "model_dump"): return model.model_dump( + mode=mode, exclude=exclude, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, diff --git a/src/runwayml/_models.py b/src/runwayml/_models.py index d386eaa..6cb469e 100644 --- a/src/runwayml/_models.py +++ b/src/runwayml/_models.py @@ -37,6 +37,7 @@ PropertyInfo, is_list, is_given, + json_safe, lru_cache, is_mapping, parse_date, @@ -176,7 +177,7 @@ def __str__(self) -> str: # Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836. @classmethod @override - def construct( + def construct( # pyright: ignore[reportIncompatibleMethodOverride] cls: Type[ModelT], _fields_set: set[str] | None = None, **values: object, @@ -248,8 +249,8 @@ def model_dump( self, *, mode: Literal["json", "python"] | str = "python", - include: IncEx = None, - exclude: IncEx = None, + include: IncEx | None = None, + exclude: IncEx | None = None, by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, @@ -279,8 +280,8 @@ def model_dump( Returns: A dictionary representation of the model. """ - if mode != "python": - raise ValueError("mode is only supported in Pydantic v2") + if mode not in {"json", "python"}: + raise ValueError("mode must be either 'json' or 'python'") if round_trip != False: raise ValueError("round_trip is only supported in Pydantic v2") if warnings != True: @@ -289,7 +290,7 @@ def model_dump( raise ValueError("context is only supported in Pydantic v2") if serialize_as_any != False: raise ValueError("serialize_as_any is only supported in Pydantic v2") - return super().dict( # pyright: ignore[reportDeprecated] + dumped = super().dict( # pyright: ignore[reportDeprecated] include=include, exclude=exclude, by_alias=by_alias, @@ -298,13 +299,15 @@ def model_dump( exclude_none=exclude_none, ) + return cast(dict[str, Any], json_safe(dumped)) if mode == "json" else dumped + @override def model_dump_json( self, *, indent: int | None = None, - include: IncEx = None, - exclude: IncEx = None, + include: IncEx | None = None, + exclude: IncEx | None = None, by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, diff --git a/src/runwayml/_response.py b/src/runwayml/_response.py index 56f5538..4d87ed3 100644 --- a/src/runwayml/_response.py +++ b/src/runwayml/_response.py @@ -192,6 +192,9 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T: if cast_to == float: return cast(R, float(response.text)) + if cast_to == bool: + return cast(R, response.text.lower() == "true") + origin = get_origin(cast_to) or cast_to if origin == APIResponse: diff --git a/src/runwayml/_types.py b/src/runwayml/_types.py index 1648dc5..df691bf 100644 --- a/src/runwayml/_types.py +++ b/src/runwayml/_types.py @@ -16,7 +16,7 @@ Optional, Sequence, ) -from typing_extensions import Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable +from typing_extensions import Set, Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable import httpx import pydantic @@ -193,7 +193,9 @@ def get(self, __key: str) -> str | None: ... # Note: copied from Pydantic # https://github.com/pydantic/pydantic/blob/32ea570bf96e84234d2992e1ddf40ab8a565925a/pydantic/main.py#L49 -IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None" +IncEx: TypeAlias = Union[ + Set[int], Set[str], Mapping[int, Union["IncEx", Literal[True]]], Mapping[str, Union["IncEx", Literal[True]]] +] PostParser = Callable[[Any], Any] diff --git a/src/runwayml/_utils/__init__.py b/src/runwayml/_utils/__init__.py index 3efe66c..a7cff3c 100644 --- a/src/runwayml/_utils/__init__.py +++ b/src/runwayml/_utils/__init__.py @@ -6,6 +6,7 @@ is_list as is_list, is_given as is_given, is_tuple as is_tuple, + json_safe as json_safe, lru_cache as lru_cache, is_mapping as is_mapping, is_tuple_t as is_tuple_t, diff --git a/src/runwayml/_utils/_transform.py b/src/runwayml/_utils/_transform.py index 47e262a..d7c0534 100644 --- a/src/runwayml/_utils/_transform.py +++ b/src/runwayml/_utils/_transform.py @@ -173,6 +173,11 @@ def _transform_recursive( # Iterable[T] or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) ): + # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually + # intended as an iterable, so we don't transform it. + if isinstance(data, dict): + return cast(object, data) + inner_type = extract_type_arg(stripped_type, 0) return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] @@ -186,7 +191,7 @@ def _transform_recursive( return data if isinstance(data, pydantic.BaseModel): - return model_dump(data, exclude_unset=True) + return model_dump(data, exclude_unset=True, mode="json") annotated_type = _get_annotated_type(annotation) if annotated_type is None: @@ -324,7 +329,7 @@ async def _async_transform_recursive( return data if isinstance(data, pydantic.BaseModel): - return model_dump(data, exclude_unset=True) + return model_dump(data, exclude_unset=True, mode="json") annotated_type = _get_annotated_type(annotation) if annotated_type is None: diff --git a/src/runwayml/_utils/_utils.py b/src/runwayml/_utils/_utils.py index 0bba17c..e5811bb 100644 --- a/src/runwayml/_utils/_utils.py +++ b/src/runwayml/_utils/_utils.py @@ -16,6 +16,7 @@ overload, ) from pathlib import Path +from datetime import date, datetime from typing_extensions import TypeGuard import sniffio @@ -395,3 +396,19 @@ def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]: maxsize=maxsize, ) return cast(Any, wrapper) # type: ignore[no-any-return] + + +def json_safe(data: object) -> object: + """Translates a mapping / sequence recursively in the same fashion + as `pydantic` v2's `model_dump(mode="json")`. + """ + if is_mapping(data): + return {json_safe(key): json_safe(value) for key, value in data.items()} + + if is_iterable(data) and not isinstance(data, (str, bytes, bytearray)): + return [json_safe(item) for item in data] + + if isinstance(data, (datetime, date)): + return data.isoformat() + + return data diff --git a/src/runwayml/resources/image_to_video.py b/src/runwayml/resources/image_to_video.py index 746d9b9..2394866 100644 --- a/src/runwayml/resources/image_to_video.py +++ b/src/runwayml/resources/image_to_video.py @@ -2,6 +2,7 @@ from __future__ import annotations +from typing import Union, Iterable from typing_extensions import Literal import httpx @@ -50,10 +51,10 @@ def create( self, *, model: Literal["gen3a_turbo"], - prompt_image: str, + prompt_image: Union[str, Iterable[image_to_video_create_params.PromptImagePromptImage]], duration: Literal[5, 10] | NotGiven = NOT_GIVEN, prompt_text: str | NotGiven = NOT_GIVEN, - ratio: Literal["16:9", "9:16"] | NotGiven = NOT_GIVEN, + ratio: Literal["1280:768", "768:1280"] | NotGiven = NOT_GIVEN, seed: int | NotGiven = NOT_GIVEN, watermark: bool | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. @@ -69,14 +70,15 @@ def create( Args: model: The model variant to use. - prompt_image: A HTTPS URL pointing to an image. Images must be JPEG, PNG, or WebP and are - limited to 16MB. Responses must include a valid `Content-Length` header. + prompt_image: A HTTPS URL or data URI containing an encoded image to be used as the first + frame of the generated video. See [our docs](/assets/inputs#images) on image + inputs for more information. duration: The number of seconds of duration for the output video. prompt_text - ratio: The aspect ratio of the output video. + ratio seed: If unspecified, a random number is chosen. Varying the seed integer is a way to get different results for the same other request parameters. Using the same seed @@ -138,10 +140,10 @@ async def create( self, *, model: Literal["gen3a_turbo"], - prompt_image: str, + prompt_image: Union[str, Iterable[image_to_video_create_params.PromptImagePromptImage]], duration: Literal[5, 10] | NotGiven = NOT_GIVEN, prompt_text: str | NotGiven = NOT_GIVEN, - ratio: Literal["16:9", "9:16"] | NotGiven = NOT_GIVEN, + ratio: Literal["1280:768", "768:1280"] | NotGiven = NOT_GIVEN, seed: int | NotGiven = NOT_GIVEN, watermark: bool | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. @@ -157,14 +159,15 @@ async def create( Args: model: The model variant to use. - prompt_image: A HTTPS URL pointing to an image. Images must be JPEG, PNG, or WebP and are - limited to 16MB. Responses must include a valid `Content-Length` header. + prompt_image: A HTTPS URL or data URI containing an encoded image to be used as the first + frame of the generated video. See [our docs](/assets/inputs#images) on image + inputs for more information. duration: The number of seconds of duration for the output video. prompt_text - ratio: The aspect ratio of the output video. + ratio seed: If unspecified, a random number is chosen. Varying the seed integer is a way to get different results for the same other request parameters. Using the same seed diff --git a/src/runwayml/types/image_to_video_create_params.py b/src/runwayml/types/image_to_video_create_params.py index 8ac74f8..d6a0d01 100644 --- a/src/runwayml/types/image_to_video_create_params.py +++ b/src/runwayml/types/image_to_video_create_params.py @@ -2,22 +2,23 @@ from __future__ import annotations +from typing import Union, Iterable from typing_extensions import Literal, Required, Annotated, TypedDict from .._utils import PropertyInfo -__all__ = ["ImageToVideoCreateParams"] +__all__ = ["ImageToVideoCreateParams", "PromptImagePromptImage"] class ImageToVideoCreateParams(TypedDict, total=False): model: Required[Literal["gen3a_turbo"]] """The model variant to use.""" - prompt_image: Required[Annotated[str, PropertyInfo(alias="promptImage")]] - """A HTTPS URL pointing to an image. - - Images must be JPEG, PNG, or WebP and are limited to 16MB. Responses must - include a valid `Content-Length` header. + prompt_image: Required[Annotated[Union[str, Iterable[PromptImagePromptImage]], PropertyInfo(alias="promptImage")]] + """ + A HTTPS URL or data URI containing an encoded image to be used as the first + frame of the generated video. See [our docs](/assets/inputs#images) on image + inputs for more information. """ duration: Literal[5, 10] @@ -25,8 +26,7 @@ class ImageToVideoCreateParams(TypedDict, total=False): prompt_text: Annotated[str, PropertyInfo(alias="promptText")] - ratio: Literal["16:9", "9:16"] - """The aspect ratio of the output video.""" + ratio: Literal["1280:768", "768:1280"] seed: int """If unspecified, a random number is chosen. @@ -41,3 +41,18 @@ class ImageToVideoCreateParams(TypedDict, total=False): A boolean indicating whether or not the output video will contain a Runway watermark. """ + + +class PromptImagePromptImage(TypedDict, total=False): + position: Required[Literal["first", "last"]] + """The position of the image in the output video. + + "first" will use the image as the first frame of the video, "last" will use the + image as the last frame of the video. + """ + + uri: Required[str] + """A HTTPS URL or data URI containing an encoded image. + + See [our docs](/assets/inputs#images) on image inputs for more information. + """ diff --git a/src/runwayml/types/image_to_video_create_response.py b/src/runwayml/types/image_to_video_create_response.py index ae3c00d..9814879 100644 --- a/src/runwayml/types/image_to_video_create_response.py +++ b/src/runwayml/types/image_to_video_create_response.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - from .._models import BaseModel __all__ = ["ImageToVideoCreateResponse"] @@ -9,4 +8,7 @@ class ImageToVideoCreateResponse(BaseModel): id: str - """The ID of the newly created task.""" + """The ID of the newly created task. + + Use this ID to query the task status and retrieve the generated video. + """ diff --git a/src/runwayml/types/task_retrieve_response.py b/src/runwayml/types/task_retrieve_response.py index 69866b8..47b9601 100644 --- a/src/runwayml/types/task_retrieve_response.py +++ b/src/runwayml/types/task_retrieve_response.py @@ -13,6 +13,7 @@ class TaskRetrieveResponse(BaseModel): id: str + """The ID of the task being returned.""" created_at: datetime = FieldInfo(alias="createdAt") """The timestamp that the task was submitted at.""" diff --git a/tests/api_resources/test_image_to_video.py b/tests/api_resources/test_image_to_video.py index f8dc1c3..8c8f0f7 100644 --- a/tests/api_resources/test_image_to_video.py +++ b/tests/api_resources/test_image_to_video.py @@ -32,7 +32,7 @@ def test_method_create_with_all_params(self, client: RunwayML) -> None: prompt_image="https://example.com", duration=5, prompt_text="promptText", - ratio="16:9", + ratio="1280:768", seed=0, watermark=True, ) @@ -83,7 +83,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncRunwayML) prompt_image="https://example.com", duration=5, prompt_text="promptText", - ratio="16:9", + ratio="1280:768", seed=0, watermark=True, ) diff --git a/tests/conftest.py b/tests/conftest.py index 897b00a..ff25164 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,11 @@ from __future__ import annotations import os -import asyncio import logging from typing import TYPE_CHECKING, Iterator, AsyncIterator import pytest +from pytest_asyncio import is_async_test from runwayml import RunwayML, AsyncRunwayML @@ -17,11 +17,13 @@ logging.getLogger("runwayml").setLevel(logging.DEBUG) -@pytest.fixture(scope="session") -def event_loop() -> Iterator[asyncio.AbstractEventLoop]: - loop = asyncio.new_event_loop() - yield loop - loop.close() +# automatically add `pytest.mark.asyncio()` to all of our async tests +# so we don't have to add that boilerplate everywhere +def pytest_collection_modifyitems(items: list[pytest.Function]) -> None: + pytest_asyncio_tests = (item for item in items if is_async_test(item)) + session_scope_marker = pytest.mark.asyncio(loop_scope="session") + for async_test in pytest_asyncio_tests: + async_test.add_marker(session_scope_marker, append=False) base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") diff --git a/tests/test_client.py b/tests/test_client.py index e5f0b7c..af65c3f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -10,6 +10,7 @@ import tracemalloc from typing import Any, Union, cast from unittest import mock +from typing_extensions import Literal import httpx import pytest @@ -692,6 +693,7 @@ class Model(BaseModel): [3, "", 0.5], [2, "", 0.5 * 2.0], [1, "", 0.5 * 4.0], + [-1100, "", 8], # test large number potentially overflowing ], ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) @@ -750,7 +752,14 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> Non @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("runwayml._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - def test_retries_taken(self, client: RunwayML, failures_before_success: int, respx_mock: MockRouter) -> None: + @pytest.mark.parametrize("failure_mode", ["status", "exception"]) + def test_retries_taken( + self, + client: RunwayML, + failures_before_success: int, + failure_mode: Literal["status", "exception"], + respx_mock: MockRouter, + ) -> None: client = client.with_options(max_retries=4) nb_retries = 0 @@ -759,6 +768,8 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: nonlocal nb_retries if nb_retries < failures_before_success: nb_retries += 1 + if failure_mode == "exception": + raise RuntimeError("oops") return httpx.Response(500) return httpx.Response(200) @@ -1477,6 +1488,7 @@ class Model(BaseModel): [3, "", 0.5], [2, "", 0.5 * 2.0], [1, "", 0.5 * 4.0], + [-1100, "", 8], # test large number potentially overflowing ], ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) @@ -1537,8 +1549,13 @@ async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) @mock.patch("runwayml._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) @pytest.mark.asyncio + @pytest.mark.parametrize("failure_mode", ["status", "exception"]) async def test_retries_taken( - self, async_client: AsyncRunwayML, failures_before_success: int, respx_mock: MockRouter + self, + async_client: AsyncRunwayML, + failures_before_success: int, + failure_mode: Literal["status", "exception"], + respx_mock: MockRouter, ) -> None: client = async_client.with_options(max_retries=4) @@ -1548,6 +1565,8 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: nonlocal nb_retries if nb_retries < failures_before_success: nb_retries += 1 + if failure_mode == "exception": + raise RuntimeError("oops") return httpx.Response(500) return httpx.Response(200) diff --git a/tests/test_models.py b/tests/test_models.py index 5fc477c..8d19f7c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -245,7 +245,7 @@ class Model(BaseModel): assert m.foo is True m = Model.construct(foo="CARD_HOLDER") - assert m.foo is "CARD_HOLDER" + assert m.foo == "CARD_HOLDER" m = Model.construct(foo={"bar": False}) assert isinstance(m.foo, Submodel1) @@ -520,19 +520,15 @@ class Model(BaseModel): assert m3.to_dict(exclude_none=True) == {} assert m3.to_dict(exclude_defaults=True) == {} - if PYDANTIC_V2: - - class Model2(BaseModel): - created_at: datetime + class Model2(BaseModel): + created_at: datetime - time_str = "2024-03-21T11:39:01.275859" - m4 = Model2.construct(created_at=time_str) - assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)} - assert m4.to_dict(mode="json") == {"created_at": time_str} - else: - with pytest.raises(ValueError, match="mode is only supported in Pydantic v2"): - m.to_dict(mode="json") + time_str = "2024-03-21T11:39:01.275859" + m4 = Model2.construct(created_at=time_str) + assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)} + assert m4.to_dict(mode="json") == {"created_at": time_str} + if not PYDANTIC_V2: with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"): m.to_dict(warnings=False) @@ -558,9 +554,6 @@ class Model(BaseModel): assert m3.model_dump(exclude_none=True) == {} if not PYDANTIC_V2: - with pytest.raises(ValueError, match="mode is only supported in Pydantic v2"): - m.model_dump(mode="json") - with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"): m.model_dump(round_trip=True) diff --git a/tests/test_response.py b/tests/test_response.py index a9918c7..6f6bf75 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -190,6 +190,56 @@ async def test_async_response_parse_annotated_type(async_client: AsyncRunwayML) assert obj.bar == 2 +@pytest.mark.parametrize( + "content, expected", + [ + ("false", False), + ("true", True), + ("False", False), + ("True", True), + ("TrUe", True), + ("FalSe", False), + ], +) +def test_response_parse_bool(client: RunwayML, content: str, expected: bool) -> None: + response = APIResponse( + raw=httpx.Response(200, content=content), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + result = response.parse(to=bool) + assert result is expected + + +@pytest.mark.parametrize( + "content, expected", + [ + ("false", False), + ("true", True), + ("False", False), + ("True", True), + ("TrUe", True), + ("FalSe", False), + ], +) +async def test_async_response_parse_bool(client: AsyncRunwayML, content: str, expected: bool) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=content), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + result = await response.parse(to=bool) + assert result is expected + + class OtherModel(BaseModel): a: str diff --git a/tests/test_transform.py b/tests/test_transform.py index 48b05f6..8e28609 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -177,17 +177,32 @@ class DateDict(TypedDict, total=False): foo: Annotated[date, PropertyInfo(format="iso8601")] +class DatetimeModel(BaseModel): + foo: datetime + + +class DateModel(BaseModel): + foo: Optional[date] + + @parametrize @pytest.mark.asyncio async def test_iso8601_format(use_async: bool) -> None: dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") + tz = "Z" if PYDANTIC_V2 else "+00:00" assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap] + assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692" + tz} # type: ignore[comparison-overlap] dt = dt.replace(tzinfo=None) assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap] + assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap] assert await transform({"foo": None}, DateDict, use_async) == {"foo": None} # type: ignore[comparison-overlap] + assert await transform(DateModel(foo=None), Any, use_async) == {"foo": None} # type: ignore assert await transform({"foo": date.fromisoformat("2023-02-23")}, DateDict, use_async) == {"foo": "2023-02-23"} # type: ignore[comparison-overlap] + assert await transform(DateModel(foo=date.fromisoformat("2023-02-23")), DateDict, use_async) == { + "foo": "2023-02-23" + } # type: ignore[comparison-overlap] @parametrize