Skip to content

Commit

Permalink
feat: Support trio
Browse files Browse the repository at this point in the history
Ref: #1872

Signed-off-by: Dmitry Dygalo <dmitry@dygalo.dev>
  • Loading branch information
Stranger6667 committed Feb 11, 2024
1 parent bf0b939 commit 074f809
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 31 deletions.
4 changes: 4 additions & 0 deletions docs/changelog.rst
Expand Up @@ -4,6 +4,10 @@ Changelog
:version:`Unreleased <v3.25.1...HEAD>` - TBD
--------------------------------------------

**Added**

- Support running async Python tests with ``trio``. :issue:`1872`

.. _v3.25.1:

:version:`3.25.1 <v3.25.0...v3.25.1>` - 2024-02-10
Expand Down
19 changes: 19 additions & 0 deletions docs/python.rst
Expand Up @@ -499,6 +499,25 @@ If you don't supply the ``app`` argument to the loader, make sure you pass your
# The `session` argument must be supplied.
case.call_and_validate(session=client)
Async support
-------------
Schemathesis supports asynchronous test functions executed via ``asyncio`` or ``trio``.
They work the same way as regular async tests and don't require any additional configuration beyond
installing ``pytest-asyncio`` or ``pytest-trio`` and follwing their usage guidelines.
.. code:: python
import pytest
import schemathesis
schema = ...
@pytest.mark.trio
@schema.parametrize()
async def test_api(case):
...
Unittest support
----------------
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Expand Up @@ -46,7 +46,6 @@ dependencies = [
"pytest>=4.6.4,<9",
"PyYAML>=5.1,<7.0",
"requests>=2.22,<3",
"sniffio>=1.2,<2.0",
"starlette>=0.13,<1",
"starlette-testclient<1",
"tomli-w>=1.0.0,<2.0",
Expand All @@ -64,6 +63,7 @@ tests = [
"pytest-asyncio>=0.18.0,<1.0",
"pytest-httpserver>=1.0,<2.0",
"pytest-mock>=3.7.0,<4.0",
"pytest-trio>=0.8,<1.0",
"pytest-xdist>=3,<4.0",
"strawberry-graphql[fastapi]>=0.109.0",
"syrupy>=2,<5.0",
Expand Down
53 changes: 28 additions & 25 deletions src/schemathesis/_hypothesis.py
@@ -1,27 +1,25 @@
"""High-level API for creating Hypothesis tests."""
from __future__ import annotations

import asyncio
import warnings
from typing import Any, Callable, Optional, Mapping, Generator, Tuple
from functools import partial
from typing import Any, Callable, Generator, Mapping, Optional, Tuple

import anyio
import hypothesis
import sniffio
from hypothesis import Phase
from hypothesis import strategies as st
from hypothesis.errors import HypothesisWarning, Unsatisfiable
from hypothesis.internal.reflection import proxies
from jsonschema.exceptions import SchemaError

from .auths import get_auth_storage_from_test
from .generation import DataGenerationMethod, GenerationConfig
from .constants import DEFAULT_DEADLINE
from .exceptions import OperationSchemaError, SerializationNotPossible
from .generation import DataGenerationMethod, GenerationConfig
from .hooks import GLOBAL_HOOK_DISPATCHER, HookContext, HookDispatcher
from .models import APIOperation, Case
from .transports.content_types import parse_content_type
from .transports.headers import is_latin_1_encodable, has_invalid_characters
from .transports.headers import has_invalid_characters, is_latin_1_encodable
from .utils import GivenInput, combine_strategies


Expand All @@ -34,6 +32,7 @@ def create_test(
data_generation_methods: list[DataGenerationMethod],
generation_config: GenerationConfig | None = None,
as_strategy_kwargs: dict[str, Any] | None = None,
keep_async_fn: bool = False,
_given_args: tuple[GivenInput, ...] = (),
_given_kwargs: dict[str, GivenInput] | None = None,
) -> Callable:
Expand All @@ -59,16 +58,28 @@ def create_test(
# tests in multiple threads because Hypothesis stores some internal attributes on function objects and re-writing
# them from different threads may lead to unpredictable side-effects.

@proxies(test) # type: ignore
def test_function(*args: Any, **kwargs: Any) -> Any:
__tracebackhide__ = True
return test(*args, **kwargs)
if keep_async_fn:

@proxies(test) # type: ignore
async def test_function(*args: Any, **kwargs: Any) -> Any:
__tracebackhide__ = True
return test(*args, **kwargs)

Check warning on line 66 in src/schemathesis/_hypothesis.py

View check run for this annotation

Codecov / codecov/patch

src/schemathesis/_hypothesis.py#L65-L66

Added lines #L65 - L66 were not covered by tests
else:

@proxies(test) # type: ignore
def test_function(*args: Any, **kwargs: Any) -> Any:
__tracebackhide__ = True
return test(*args, **kwargs)

wrapped_test = hypothesis.given(*_given_args, **_given_kwargs)(test_function)
if seed is not None:
wrapped_test = hypothesis.seed(seed)(wrapped_test)
if asyncio.iscoroutinefunction(test):
wrapped_test.hypothesis.inner_test = make_async_test(test) # type: ignore
# `pytest-trio` expects a coroutine function
if keep_async_fn:
wrapped_test.hypothesis.inner_test = test # type: ignore
else:
wrapped_test.hypothesis.inner_test = make_async_test(test) # type: ignore
setup_default_deadline(wrapped_test)
if settings is not None:
wrapped_test = settings(wrapped_test)
Expand Down Expand Up @@ -107,20 +118,12 @@ def _get_hypothesis_settings(test: Callable) -> hypothesis.settings | None:
def make_async_test(test: Callable) -> Callable:
def async_run(*args: Any, **kwargs: Any) -> None:
try:
current_async_library = sniffio.current_async_library()
except sniffio.AsyncLibraryNotFoundError:
current_async_library = None

if current_async_library == "trio":
anyio.run(partial(test, *args, **kwargs))
else:
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
coro = test(*args, **kwargs)
future = asyncio.ensure_future(coro, loop=loop)
loop.run_until_complete(future)
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
coro = test(*args, **kwargs)
future = asyncio.ensure_future(coro, loop=loop)
loop.run_until_complete(future)

return async_run

Expand Down
7 changes: 7 additions & 0 deletions src/schemathesis/extra/pytest_plugin.py
Expand Up @@ -114,6 +114,12 @@ def _gen_items(
"""
from .._hypothesis import create_test

is_trio_test = False
for mark in getattr(self.test_function, "pytestmark", []):
if mark.name == "trio":
is_trio_test = True
break

if isinstance(result, Ok):
operation = result.ok()
if self.is_invalid_test:
Expand All @@ -135,6 +141,7 @@ def _gen_items(
data_generation_methods=self.schemathesis_case.data_generation_methods,
generation_config=self.schemathesis_case.generation_config,
as_strategy_kwargs=as_strategy_kwargs,
keep_async_fn=is_trio_test,
)
name = self._get_test_name(operation)
else:
Expand Down
29 changes: 24 additions & 5 deletions test/test_async.py
Expand Up @@ -3,29 +3,46 @@
from .utils import integer


@pytest.fixture(params=["aiohttp.pytest_plugin", "pytest_asyncio"])
ALL_PLUGINS = {"aiohttp.pytest_plugin": "", "asyncio": "@pytest.mark.asyncio", "trio": "@pytest.mark.trio"}


def build_pytest_args(plugin):
disabled_plugins = set(ALL_PLUGINS) - {plugin}
args = ["-v"]
for disabled in disabled_plugins:
args.extend(("-p", f"no:{disabled}"))
return args


@pytest.fixture(params=list(ALL_PLUGINS))
def plugin(request):
return request.param


def test_simple(testdir, plugin):
# When the wrapped test is a coroutine function and pytest-aiohttp/asyncio plugin is used
marker = ALL_PLUGINS[plugin]
testdir.make_test(
f"""
async def func():
return 1
{marker}
@schema.parametrize()
{"@pytest.mark.asyncio" if plugin == "pytest_asyncio" else ""}
async def test_(request, case):
request.config.HYPOTHESIS_CASES += 1
assert case.full_path == "/v1/users"
assert case.method == "GET"
assert await func() == 1
if "{plugin}" == "trio":
import trio
await trio.sleep(0)
""",
pytest_plugins=[plugin],
)
result = testdir.runpytest("-v")
args = build_pytest_args(plugin)
result = testdir.runpytest(*args)
result.assert_outcomes(passed=1)
# Then it should be executed as any other test
result.stdout.re_match_lines([r"test_simple.py::test_\[GET /v1/users\] PASSED", r"Hypothesis calls: 1"])
Expand All @@ -34,10 +51,11 @@ async def test_(request, case):
def test_settings_first(testdir, plugin):
# When `hypothesis.settings` decorator is applied to a coroutine-based test before `parametrize`
parameters = {"parameters": [integer(name="id", required=True)]}
marker = ALL_PLUGINS[plugin]
testdir.make_test(
f"""
@schema.parametrize()
{"@pytest.mark.asyncio" if plugin == "pytest_asyncio" else ""}
{marker}
@settings(max_examples=5)
async def test_(request, case):
request.config.HYPOTHESIS_CASES += 1
Expand All @@ -47,7 +65,8 @@ async def test_(request, case):
pytest_plugins=[plugin],
paths={"/users": {"get": parameters, "post": parameters}},
)
result = testdir.runpytest("-v", "-s")
args = build_pytest_args(plugin)
result = testdir.runpytest(*args)
result.assert_outcomes(passed=2)
# Then it should be executed as any other test
result.stdout.re_match_lines([r"Hypothesis calls: 10$"])
Expand Down

0 comments on commit 074f809

Please sign in to comment.