Skip to content

Commit

Permalink
Support parametrized event_loop fixture (#278)
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov committed Jan 25, 2022
1 parent dab3b51 commit d8efa64
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 115 deletions.
1 change: 1 addition & 0 deletions README.rst
Expand Up @@ -261,6 +261,7 @@ Changelog
~~~~~~~~~~~~~~~~~~~

- Raise a warning if @pytest.mark.asyncio is applied to non-async function. `#275 <https://github.com/pytest-dev/pytest-asyncio/issues/275>`_
- Support parametrized ``event_loop`` fixture. `#278 <https://github.com/pytest-dev/pytest-asyncio/issues/278>`_

0.17.2 (22-01-17)
~~~~~~~~~~~~~~~~~~~
Expand Down
229 changes: 114 additions & 115 deletions pytest_asyncio/plugin.py
Expand Up @@ -165,7 +165,7 @@ def _set_explicit_asyncio_mark(obj: Any) -> None:

def _is_coroutine(obj: Any) -> bool:
"""Check to see if an object is really an asyncio coroutine."""
return asyncio.iscoroutinefunction(obj) or inspect.isgeneratorfunction(obj)
return asyncio.iscoroutinefunction(obj)


def _is_coroutine_or_asyncgen(obj: Any) -> bool:
Expand Down Expand Up @@ -198,6 +198,118 @@ def pytest_report_header(config: Config) -> List[str]:
return [f"asyncio: mode={mode}"]


def _preprocess_async_fixtures(config: Config, holder: Set[FixtureDef]) -> None:
asyncio_mode = _get_asyncio_mode(config)
fixturemanager = config.pluginmanager.get_plugin("funcmanage")
for fixtures in fixturemanager._arg2fixturedefs.values():
for fixturedef in fixtures:
if fixturedef is holder:
continue
func = fixturedef.func
if not _is_coroutine_or_asyncgen(func):
# Nothing to do with a regular fixture function
continue
if not _has_explicit_asyncio_mark(func):
if asyncio_mode == Mode.AUTO:
# Enforce asyncio mode if 'auto'
_set_explicit_asyncio_mark(func)
elif asyncio_mode == Mode.LEGACY:
_set_explicit_asyncio_mark(func)
try:
code = func.__code__
except AttributeError:
code = func.__func__.__code__
name = (
f"<fixture {func.__qualname__}, file={code.co_filename}, "
f"line={code.co_firstlineno}>"
)
warnings.warn(
LEGACY_ASYNCIO_FIXTURE.format(name=name),
DeprecationWarning,
)

to_add = []
for name in ("request", "event_loop"):
if name not in fixturedef.argnames:
to_add.append(name)

if to_add:
fixturedef.argnames += tuple(to_add)

if inspect.isasyncgenfunction(func):
fixturedef.func = _wrap_asyncgen(func)
elif inspect.iscoroutinefunction(func):
fixturedef.func = _wrap_async(func)

assert _has_explicit_asyncio_mark(fixturedef.func)
holder.add(fixturedef)


def _add_kwargs(
func: Callable[..., Any],
kwargs: Dict[str, Any],
event_loop: asyncio.AbstractEventLoop,
request: SubRequest,
) -> Dict[str, Any]:
sig = inspect.signature(func)
ret = kwargs.copy()
if "request" in sig.parameters:
ret["request"] = request
if "event_loop" in sig.parameters:
ret["event_loop"] = event_loop
return ret


def _wrap_asyncgen(func: Callable[..., AsyncIterator[_R]]) -> Callable[..., _R]:
@functools.wraps(func)
def _asyncgen_fixture_wrapper(
event_loop: asyncio.AbstractEventLoop, request: SubRequest, **kwargs: Any
) -> _R:
gen_obj = func(**_add_kwargs(func, kwargs, event_loop, request))

async def setup() -> _R:
res = await gen_obj.__anext__()
return res

def finalizer() -> None:
"""Yield again, to finalize."""

async def async_finalizer() -> None:
try:
await gen_obj.__anext__()
except StopAsyncIteration:
pass
else:
msg = "Async generator fixture didn't stop."
msg += "Yield only once."
raise ValueError(msg)

event_loop.run_until_complete(async_finalizer())

result = event_loop.run_until_complete(setup())
request.addfinalizer(finalizer)
return result

return _asyncgen_fixture_wrapper


def _wrap_async(func: Callable[..., Awaitable[_R]]) -> Callable[..., _R]:
@functools.wraps(func)
def _async_fixture_wrapper(
event_loop: asyncio.AbstractEventLoop, request: SubRequest, **kwargs: Any
) -> _R:
async def setup() -> _R:
res = await func(**_add_kwargs(func, kwargs, event_loop, request))
return res

return event_loop.run_until_complete(setup())

return _async_fixture_wrapper


_HOLDER: Set[FixtureDef] = set()


@pytest.mark.tryfirst
def pytest_pycollect_makeitem(
collector: Union[pytest.Module, pytest.Class], name: str, obj: object
Expand All @@ -212,6 +324,7 @@ def pytest_pycollect_makeitem(
or _is_hypothesis_test(obj)
and _hypothesis_test_wraps_coroutine(obj)
):
_preprocess_async_fixtures(collector.config, _HOLDER)
item = pytest.Function.from_parent(collector, name=name)
marker = item.get_closest_marker("asyncio")
if marker is not None:
Expand All @@ -230,31 +343,6 @@ def _hypothesis_test_wraps_coroutine(function: Any) -> bool:
return _is_coroutine(function.hypothesis.inner_test)


class FixtureStripper:
"""Include additional Fixture, and then strip them"""

EVENT_LOOP = "event_loop"

def __init__(self, fixturedef: FixtureDef) -> None:
self.fixturedef = fixturedef
self.to_strip: Set[str] = set()

def add(self, name: str) -> None:
"""Add fixture name to fixturedef
and record in to_strip list (If not previously included)"""
if name in self.fixturedef.argnames:
return
self.fixturedef.argnames += (name,)
self.to_strip.add(name)

def get_and_strip_from(self, name: str, data_dict: Dict[str, _T]) -> _T:
"""Strip name from data, and return value"""
result = data_dict[name]
if name in self.to_strip:
del data_dict[name]
return result


@pytest.hookimpl(trylast=True)
def pytest_fixture_post_finalizer(fixturedef: FixtureDef, request: SubRequest) -> None:
"""Called after fixture teardown"""
Expand Down Expand Up @@ -291,95 +379,6 @@ def pytest_fixture_setup(
policy.set_event_loop(loop)
return

func = fixturedef.func
if not _is_coroutine_or_asyncgen(func):
# Nothing to do with a regular fixture function
yield
return

config = request.node.config
asyncio_mode = _get_asyncio_mode(config)

if not _has_explicit_asyncio_mark(func):
if asyncio_mode == Mode.AUTO:
# Enforce asyncio mode if 'auto'
_set_explicit_asyncio_mark(func)
elif asyncio_mode == Mode.LEGACY:
_set_explicit_asyncio_mark(func)
try:
code = func.__code__
except AttributeError:
code = func.__func__.__code__
name = (
f"<fixture {func.__qualname__}, file={code.co_filename}, "
f"line={code.co_firstlineno}>"
)
warnings.warn(
LEGACY_ASYNCIO_FIXTURE.format(name=name),
DeprecationWarning,
)
else:
# asyncio_mode is STRICT,
# don't handle fixtures that are not explicitly marked
yield
return

if inspect.isasyncgenfunction(func):
# This is an async generator function. Wrap it accordingly.
generator = func

fixture_stripper = FixtureStripper(fixturedef)
fixture_stripper.add(FixtureStripper.EVENT_LOOP)

def wrapper(*args, **kwargs):
loop = fixture_stripper.get_and_strip_from(
FixtureStripper.EVENT_LOOP, kwargs
)

gen_obj = generator(*args, **kwargs)

async def setup():
res = await gen_obj.__anext__()
return res

def finalizer():
"""Yield again, to finalize."""

async def async_finalizer():
try:
await gen_obj.__anext__()
except StopAsyncIteration:
pass
else:
msg = "Async generator fixture didn't stop."
msg += "Yield only once."
raise ValueError(msg)

loop.run_until_complete(async_finalizer())

result = loop.run_until_complete(setup())
request.addfinalizer(finalizer)
return result

fixturedef.func = wrapper
elif inspect.iscoroutinefunction(func):
coro = func

fixture_stripper = FixtureStripper(fixturedef)
fixture_stripper.add(FixtureStripper.EVENT_LOOP)

def wrapper(*args, **kwargs):
loop = fixture_stripper.get_and_strip_from(
FixtureStripper.EVENT_LOOP, kwargs
)

async def setup():
res = await coro(*args, **kwargs)
return res

return loop.run_until_complete(setup())

fixturedef.func = wrapper
yield


Expand Down
31 changes: 31 additions & 0 deletions tests/async_fixtures/test_parametrized_loop.py
@@ -0,0 +1,31 @@
import asyncio

import pytest

TESTS_COUNT = 0


def teardown_module():
# parametrized 2 * 2 times: 2 for 'event_loop' and 2 for 'fix'
assert TESTS_COUNT == 4


@pytest.fixture(scope="module", params=[1, 2])
def event_loop(request):
request.param
loop = asyncio.new_event_loop()
yield loop
loop.close()


@pytest.fixture(params=["a", "b"])
async def fix(request):
await asyncio.sleep(0)
return request.param


@pytest.mark.asyncio
async def test_parametrized_loop(fix):
await asyncio.sleep(0)
global TESTS_COUNT
TESTS_COUNT += 1

0 comments on commit d8efa64

Please sign in to comment.