Skip to content

Commit 1a83fb0

Browse files
committed
refactor: _wrap_async_fixture forwards positional args from the original fixture function to the wrapper.
1 parent 626d44b commit 1a83fb0

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

pytest_asyncio/plugin.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@
5151
)
5252

5353
if sys.version_info >= (3, 10):
54-
from typing import ParamSpec
54+
from typing import Concatenate, ParamSpec
5555
else:
56-
from typing_extensions import ParamSpec
56+
from typing_extensions import Concatenate, ParamSpec
5757

5858

5959
_ScopeName = Literal["session", "package", "module", "class", "function"]
@@ -347,15 +347,22 @@ async def async_finalizer() -> None:
347347
return _asyncgen_fixture_wrapper
348348

349349

350+
AsyncFixtureParams = ParamSpec("AsyncFixtureParams")
350351
AsyncFixtureReturnType = TypeVar("AsyncFixtureReturnType")
351352

352353

353354
def _wrap_async_fixture(
354-
fixture_function: Callable[..., CoroutineType[Any, Any, AsyncFixtureReturnType]],
355-
) -> Callable[..., AsyncFixtureReturnType]:
355+
fixture_function: Callable[
356+
AsyncFixtureParams, CoroutineType[Any, Any, AsyncFixtureReturnType]
357+
],
358+
) -> Callable[Concatenate[FixtureRequest, AsyncFixtureParams], AsyncFixtureReturnType]:
356359

357360
@functools.wraps(fixture_function) # type: ignore[arg-type]
358-
def _async_fixture_wrapper(request: FixtureRequest, **kwargs: Any):
361+
def _async_fixture_wrapper(
362+
request: FixtureRequest,
363+
*args: AsyncFixtureParams.args,
364+
**kwargs: AsyncFixtureParams.kwargs,
365+
):
359366
func = _perhaps_rebind_fixture_func(fixture_function, request.instance)
360367
event_loop_fixture_id = _get_event_loop_fixture_id_for_async_fixture(
361368
request, func
@@ -364,7 +371,7 @@ def _async_fixture_wrapper(request: FixtureRequest, **kwargs: Any):
364371
kwargs.pop(event_loop_fixture_id, None)
365372

366373
async def setup():
367-
res = await func(**_add_kwargs(func, kwargs, request))
374+
res = await func(*args, **_add_kwargs(func, kwargs, request))
368375
return res
369376

370377
context = contextvars.copy_context()

0 commit comments

Comments
 (0)