diff --git a/docs/changelog.rst b/docs/changelog.rst index ea4142b69b..9b5ea0a011 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,10 @@ Changelog :version:`Unreleased ` - TBD -------------------------------------------- +**Added** + +- Support running async Python tests with ``trio``. :issue:`1872` + .. _v3.25.1: :version:`3.25.1 ` - 2024-02-10 diff --git a/docs/python.rst b/docs/python.rst index d88f577059..03ab251d6d 100644 --- a/docs/python.rst +++ b/docs/python.rst @@ -499,6 +499,25 @@ If you don't supply the ``app`` argument to the loader, make sure you pass your # The `session` argument must be supplied. case.call_and_validate(session=client) +Async support +------------- + +Schemathesis supports asynchronous test functions executed via ``asyncio`` or ``trio``. +They work the same way as regular async tests and don't require any additional configuration beyond +installing ``pytest-asyncio`` or ``pytest-trio`` and follwing their usage guidelines. + +.. code:: python + + import pytest + import schemathesis + + schema = ... + + @pytest.mark.trio + @schema.parametrize() + async def test_api(case): + ... + Unittest support ---------------- diff --git a/pyproject.toml b/pyproject.toml index 4d98e1f90c..7da6c69105 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,6 @@ dependencies = [ "pytest>=4.6.4,<9", "PyYAML>=5.1,<7.0", "requests>=2.22,<3", - "sniffio>=1.2,<2.0", "starlette>=0.13,<1", "starlette-testclient<1", "tomli-w>=1.0.0,<2.0", @@ -64,6 +63,7 @@ tests = [ "pytest-asyncio>=0.18.0,<1.0", "pytest-httpserver>=1.0,<2.0", "pytest-mock>=3.7.0,<4.0", + "pytest-trio>=0.8,<1.0", "pytest-xdist>=3,<4.0", "strawberry-graphql[fastapi]>=0.109.0", "syrupy>=2,<5.0", diff --git a/src/schemathesis/_hypothesis.py b/src/schemathesis/_hypothesis.py index 337fbf724c..bea50f5f1b 100644 --- a/src/schemathesis/_hypothesis.py +++ b/src/schemathesis/_hypothesis.py @@ -1,13 +1,11 @@ """High-level API for creating Hypothesis tests.""" from __future__ import annotations + import asyncio import warnings -from typing import Any, Callable, Optional, Mapping, Generator, Tuple -from functools import partial +from typing import Any, Callable, Generator, Mapping, Optional, Tuple -import anyio import hypothesis -import sniffio from hypothesis import Phase from hypothesis import strategies as st from hypothesis.errors import HypothesisWarning, Unsatisfiable @@ -15,13 +13,13 @@ from jsonschema.exceptions import SchemaError from .auths import get_auth_storage_from_test -from .generation import DataGenerationMethod, GenerationConfig from .constants import DEFAULT_DEADLINE from .exceptions import OperationSchemaError, SerializationNotPossible +from .generation import DataGenerationMethod, GenerationConfig from .hooks import GLOBAL_HOOK_DISPATCHER, HookContext, HookDispatcher from .models import APIOperation, Case from .transports.content_types import parse_content_type -from .transports.headers import is_latin_1_encodable, has_invalid_characters +from .transports.headers import has_invalid_characters, is_latin_1_encodable from .utils import GivenInput, combine_strategies @@ -34,6 +32,7 @@ def create_test( data_generation_methods: list[DataGenerationMethod], generation_config: GenerationConfig | None = None, as_strategy_kwargs: dict[str, Any] | None = None, + keep_async_fn: bool = False, _given_args: tuple[GivenInput, ...] = (), _given_kwargs: dict[str, GivenInput] | None = None, ) -> Callable: @@ -59,16 +58,28 @@ def create_test( # tests in multiple threads because Hypothesis stores some internal attributes on function objects and re-writing # them from different threads may lead to unpredictable side-effects. - @proxies(test) # type: ignore - def test_function(*args: Any, **kwargs: Any) -> Any: - __tracebackhide__ = True - return test(*args, **kwargs) + if keep_async_fn: + + @proxies(test) # type: ignore + async def test_function(*args: Any, **kwargs: Any) -> Any: + __tracebackhide__ = True + return test(*args, **kwargs) + else: + + @proxies(test) # type: ignore + def test_function(*args: Any, **kwargs: Any) -> Any: + __tracebackhide__ = True + return test(*args, **kwargs) wrapped_test = hypothesis.given(*_given_args, **_given_kwargs)(test_function) if seed is not None: wrapped_test = hypothesis.seed(seed)(wrapped_test) if asyncio.iscoroutinefunction(test): - wrapped_test.hypothesis.inner_test = make_async_test(test) # type: ignore + # `pytest-trio` expects a coroutine function + if keep_async_fn: + wrapped_test.hypothesis.inner_test = test # type: ignore + else: + wrapped_test.hypothesis.inner_test = make_async_test(test) # type: ignore setup_default_deadline(wrapped_test) if settings is not None: wrapped_test = settings(wrapped_test) @@ -107,20 +118,12 @@ def _get_hypothesis_settings(test: Callable) -> hypothesis.settings | None: def make_async_test(test: Callable) -> Callable: def async_run(*args: Any, **kwargs: Any) -> None: try: - current_async_library = sniffio.current_async_library() - except sniffio.AsyncLibraryNotFoundError: - current_async_library = None - - if current_async_library == "trio": - anyio.run(partial(test, *args, **kwargs)) - else: - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - coro = test(*args, **kwargs) - future = asyncio.ensure_future(coro, loop=loop) - loop.run_until_complete(future) + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + coro = test(*args, **kwargs) + future = asyncio.ensure_future(coro, loop=loop) + loop.run_until_complete(future) return async_run diff --git a/src/schemathesis/extra/pytest_plugin.py b/src/schemathesis/extra/pytest_plugin.py index f8bd31b0b5..0cc4040f5b 100644 --- a/src/schemathesis/extra/pytest_plugin.py +++ b/src/schemathesis/extra/pytest_plugin.py @@ -114,6 +114,12 @@ def _gen_items( """ from .._hypothesis import create_test + is_trio_test = False + for mark in getattr(self.test_function, "pytestmark", []): + if mark.name == "trio": + is_trio_test = True + break + if isinstance(result, Ok): operation = result.ok() if self.is_invalid_test: @@ -135,6 +141,7 @@ def _gen_items( data_generation_methods=self.schemathesis_case.data_generation_methods, generation_config=self.schemathesis_case.generation_config, as_strategy_kwargs=as_strategy_kwargs, + keep_async_fn=is_trio_test, ) name = self._get_test_name(operation) else: diff --git a/test/test_async.py b/test/test_async.py index 30f84a45dc..2c66a7a40a 100644 --- a/test/test_async.py +++ b/test/test_async.py @@ -3,29 +3,46 @@ from .utils import integer -@pytest.fixture(params=["aiohttp.pytest_plugin", "pytest_asyncio"]) +ALL_PLUGINS = {"aiohttp.pytest_plugin": "", "asyncio": "@pytest.mark.asyncio", "trio": "@pytest.mark.trio"} + + +def build_pytest_args(plugin): + disabled_plugins = set(ALL_PLUGINS) - {plugin} + args = ["-v"] + for disabled in disabled_plugins: + args.extend(("-p", f"no:{disabled}")) + return args + + +@pytest.fixture(params=list(ALL_PLUGINS)) def plugin(request): return request.param def test_simple(testdir, plugin): # When the wrapped test is a coroutine function and pytest-aiohttp/asyncio plugin is used + marker = ALL_PLUGINS[plugin] testdir.make_test( f""" async def func(): return 1 +{marker} @schema.parametrize() -{"@pytest.mark.asyncio" if plugin == "pytest_asyncio" else ""} async def test_(request, case): request.config.HYPOTHESIS_CASES += 1 assert case.full_path == "/v1/users" assert case.method == "GET" assert await func() == 1 + if "{plugin}" == "trio": + import trio + + await trio.sleep(0) """, pytest_plugins=[plugin], ) - result = testdir.runpytest("-v") + args = build_pytest_args(plugin) + result = testdir.runpytest(*args) result.assert_outcomes(passed=1) # Then it should be executed as any other test result.stdout.re_match_lines([r"test_simple.py::test_\[GET /v1/users\] PASSED", r"Hypothesis calls: 1"]) @@ -34,10 +51,11 @@ async def test_(request, case): def test_settings_first(testdir, plugin): # When `hypothesis.settings` decorator is applied to a coroutine-based test before `parametrize` parameters = {"parameters": [integer(name="id", required=True)]} + marker = ALL_PLUGINS[plugin] testdir.make_test( f""" @schema.parametrize() -{"@pytest.mark.asyncio" if plugin == "pytest_asyncio" else ""} +{marker} @settings(max_examples=5) async def test_(request, case): request.config.HYPOTHESIS_CASES += 1 @@ -47,7 +65,8 @@ async def test_(request, case): pytest_plugins=[plugin], paths={"/users": {"get": parameters, "post": parameters}}, ) - result = testdir.runpytest("-v", "-s") + args = build_pytest_args(plugin) + result = testdir.runpytest(*args) result.assert_outcomes(passed=2) # Then it should be executed as any other test result.stdout.re_match_lines([r"Hypothesis calls: 10$"])