Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix contextvars not propagated from fixture to test #161

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions pytest_asyncio/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ def transfer_markers(*args, **kwargs): # noqa
except ImportError:
from inspect import isasyncgenfunction

try:
import contextvars
except ImportError:
contextvars = None


def _is_coroutine(obj):
"""Check to see if an object is really an asyncio coroutine."""
Expand Down Expand Up @@ -48,6 +53,12 @@ def pytest_pycollect_makeitem(collector, name, obj):
return list(collector._genfunctions(name, obj))


current_context = None


def apply_context(context):
for var in context:
var.set(context[var])
class FixtureStripper:
"""Include additional Fixture, and then strip them"""
REQUEST = "request"
Expand Down Expand Up @@ -91,6 +102,10 @@ def pytest_fixture_setup(fixturedef, request):
policy.set_event_loop(loop)
return

if current_context:
# Apply the current context
apply_context(current_context)

if isasyncgenfunction(fixturedef.func):
# This is an async generator function. Wrap it accordingly.
generator = fixturedef.func
Expand All @@ -108,7 +123,9 @@ def wrapper(*args, **kwargs):

async def setup():
res = await gen_obj.__anext__()
return res
# return the current context
# that is maybe modified by async gen_obj
return res, contextvars and contextvars.copy_context()

def finalizer():
"""Yield again, to finalize."""
Expand All @@ -124,7 +141,15 @@ async def async_finalizer():
loop.run_until_complete(async_finalizer())

request.addfinalizer(finalizer)
return loop.run_until_complete(setup())

res, context = asyncio.get_event_loop().run_until_complete(setup())
if context:
# Store the current context
global current_context

current_context = context

return res

fixturedef.func = wrapper
elif inspect.iscoroutinefunction(fixturedef.func):
Expand Down Expand Up @@ -152,6 +177,11 @@ def pytest_pyfunc_call(pyfuncitem):
Run asyncio marked test functions in an event loop instead of a normal
function call.
"""
global current_context
if current_context:
# Apply the current context
apply_context(current_context)

if 'asyncio' in pyfuncitem.keywords:
if getattr(pyfuncitem.obj, 'is_hypothesis_test', False):
pyfuncitem.obj.hypothesis.inner_test = wrap_in_sync(
Expand All @@ -164,6 +194,8 @@ def pytest_pyfunc_call(pyfuncitem):
_loop=pyfuncitem.funcargs['event_loop']
)
yield
# Cleanup the current context
current_context = None


def wrap_in_sync(func, _loop):
Expand Down
116 changes: 116 additions & 0 deletions tests/async_fixtures/test_async_gen_fixtures_within_context_37.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import unittest.mock

import pytest

START = object()
END = object()
RETVAL = object()


@pytest.fixture(scope="module")
def mock():
return unittest.mock.Mock(return_value=RETVAL)


@pytest.fixture
def var():
contextvars = pytest.importorskip("contextvars")

return contextvars.ContextVar("var_1")


@pytest.fixture
async def async_gen_fixture_within_context(mock, var):
var.set(1)
try:
yield mock(START)
except Exception as e:
mock(e)
else:
mock(END)


@pytest.mark.asyncio
async def test_async_gen_fixture_within_context(
async_gen_fixture_within_context, mock, var
):
assert var.get() == 1
assert mock.called
assert mock.call_args_list[-1] == unittest.mock.call(START)
assert async_gen_fixture_within_context is RETVAL


@pytest.mark.asyncio
async def test_async_gen_fixture_within_context_finalized(mock, var):
with pytest.raises(LookupError):
var.get()

try:
assert mock.called
assert mock.call_args_list[-1] == unittest.mock.call(END)
finally:
mock.reset_mock()


@pytest.fixture
async def async_gen_fixture_1(var):
var.set(1)
yield


@pytest.fixture
async def async_gen_fixture_2(async_gen_fixture_1, var):
assert var.get() == 1
var.set(2)
yield


@pytest.mark.asyncio
async def test_context_overwrited_by_another_async_gen_fixture(
async_gen_fixture_2, var
):
assert var.get() == 2


@pytest.fixture
async def async_fixture_within_context(async_gen_fixture_1, var):
assert var.get() == 1


@pytest.fixture
def fixture_within_context(async_gen_fixture_1, var):
assert var.get() == 1


@pytest.mark.asyncio
async def test_context_propagated_from_gen_fixture_to_normal_fixture(
fixture_within_context, async_fixture_within_context
):
pass


@pytest.fixture
def var_2():
contextvars = pytest.importorskip("contextvars")

return contextvars.ContextVar("var_2")


@pytest.fixture
async def async_gen_fixture_set_var_1(var):
var.set(1)
yield


@pytest.fixture
async def async_gen_fixture_set_var_2(var_2):
var_2.set(2)
yield


@pytest.mark.asyncio
async def test_context_modified_by_different_fixtures(
async_gen_fixture_set_var_1, async_gen_fixture_set_var_2, var, var_2
):
assert var.get() == 1
assert var_2.get() == 2
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
collect_ignore.append("async_fixtures/test_async_gen_fixtures_36.py")
collect_ignore.append("async_fixtures/test_nested_36.py")

if sys.version_info[:2] < (3, 7):
collect_ignore.append("async_fixtures/test_async_gen_fixtures_within_context_37.py")


@pytest.fixture
def dependent_fixture(event_loop):
Expand Down