Skip to content

Commit

Permalink
bpo-36996: Handle async functions when mock.patch is used as a decora…
Browse files Browse the repository at this point in the history
…tor (GH-13562)

Return a coroutine while patching async functions with a decorator. 

Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>


https://bugs.python.org/issue36996
  • Loading branch information
tirkarthi authored and miss-islington committed May 28, 2019
1 parent 71dc7c5 commit 436c2b0
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 27 deletions.
84 changes: 57 additions & 27 deletions Lib/unittest/mock.py
Expand Up @@ -26,6 +26,7 @@
__version__ = '1.0'

import asyncio
import contextlib
import io
import inspect
import pprint
Expand Down Expand Up @@ -1220,6 +1221,8 @@ def copy(self):
def __call__(self, func):
if isinstance(func, type):
return self.decorate_class(func)
if inspect.iscoroutinefunction(func):
return self.decorate_async_callable(func)
return self.decorate_callable(func)


Expand All @@ -1237,41 +1240,68 @@ def decorate_class(self, klass):
return klass


@contextlib.contextmanager
def decoration_helper(self, patched, args, keywargs):
extra_args = []
entered_patchers = []
patching = None

exc_info = tuple()
try:
for patching in patched.patchings:
arg = patching.__enter__()
entered_patchers.append(patching)
if patching.attribute_name is not None:
keywargs.update(arg)
elif patching.new is DEFAULT:
extra_args.append(arg)

args += tuple(extra_args)
yield (args, keywargs)
except:
if (patching not in entered_patchers and
_is_started(patching)):
# the patcher may have been started, but an exception
# raised whilst entering one of its additional_patchers
entered_patchers.append(patching)
# Pass the exception to __exit__
exc_info = sys.exc_info()
# re-raise the exception
raise
finally:
for patching in reversed(entered_patchers):
patching.__exit__(*exc_info)


def decorate_callable(self, func):
# NB. Keep the method in sync with decorate_async_callable()
if hasattr(func, 'patchings'):
func.patchings.append(self)
return func

@wraps(func)
def patched(*args, **keywargs):
extra_args = []
entered_patchers = []
with self.decoration_helper(patched,
args,
keywargs) as (newargs, newkeywargs):
return func(*newargs, **newkeywargs)

exc_info = tuple()
try:
for patching in patched.patchings:
arg = patching.__enter__()
entered_patchers.append(patching)
if patching.attribute_name is not None:
keywargs.update(arg)
elif patching.new is DEFAULT:
extra_args.append(arg)

args += tuple(extra_args)
return func(*args, **keywargs)
except:
if (patching not in entered_patchers and
_is_started(patching)):
# the patcher may have been started, but an exception
# raised whilst entering one of its additional_patchers
entered_patchers.append(patching)
# Pass the exception to __exit__
exc_info = sys.exc_info()
# re-raise the exception
raise
finally:
for patching in reversed(entered_patchers):
patching.__exit__(*exc_info)
patched.patchings = [self]
return patched


def decorate_async_callable(self, func):
# NB. Keep the method in sync with decorate_callable()
if hasattr(func, 'patchings'):
func.patchings.append(self)
return func

@wraps(func)
async def patched(*args, **keywargs):
with self.decoration_helper(patched,
args,
keywargs) as (newargs, newkeywargs):
return await func(*newargs, **newkeywargs)

patched.patchings = [self]
return patched
Expand Down
16 changes: 16 additions & 0 deletions Lib/unittest/test/testmock/testasync.py
Expand Up @@ -66,6 +66,14 @@ def test_async(mock_method):

test_async()

def test_async_def_patch(self):
@patch(f"{__name__}.async_func", AsyncMock())
async def test_async():
self.assertIsInstance(async_func, AsyncMock)

asyncio.run(test_async())
self.assertTrue(inspect.iscoroutinefunction(async_func))


class AsyncPatchCMTest(unittest.TestCase):
def test_is_async_function_cm(self):
Expand All @@ -91,6 +99,14 @@ def test_async():

test_async()

def test_async_def_cm(self):
async def test_async():
with patch(f"{__name__}.async_func", AsyncMock()):
self.assertIsInstance(async_func, AsyncMock)
self.assertTrue(inspect.iscoroutinefunction(async_func))

asyncio.run(test_async())


class AsyncMockTest(unittest.TestCase):
def test_iscoroutinefunction_default(self):
Expand Down
@@ -0,0 +1 @@
Handle :func:`unittest.mock.patch` used as a decorator on async functions.

0 comments on commit 436c2b0

Please sign in to comment.