diff --git a/docs/index.rst b/docs/index.rst index 9c733bb..c2f111a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -272,6 +272,69 @@ See the documentation of the :obj:`receiver_connected` built-in signal for an example. +Async receivers +--------------- + +Receivers can be coroutine functions which can be called and awaited +via the :meth:`~Signal.send_async` method: + +.. code-block:: python + + sig = blinker.Signal() + + async def receiver(): + ... + + sig.connect(receiver) + await sig.send_async() + +This however requires that all receivers are awaitable which then +precludes the usage of :meth:`~Signal.send`. To mix and match the +:meth:`~Signal.send_async` method takes a ``_sync_wrapper`` argument +such as: + +.. code-block:: python + + sig = blinker.Signal() + + def receiver(): + ... + + sig.connect(receiver) + + def wrapper(func): + + async def inner(*args, **kwargs): + func(*args, **kwargs) + + return inner + + await sig.send_async(_sync_wrapper=wrapper) + +The equivalent usage for :meth:`~Signal.send` is via the +``_async_wrapper`` argument. This usage is will depend on your event +loop, and in the simple case whereby you aren't running within an +event loop the following example can be used: + +.. code-block:: python + + sig = blinker.Signal() + + async def receiver(): + ... + + sig.connect(receiver) + + def wrapper(func): + + def inner(*args, **kwargs): + asyncio.run(func(*args, **kwargs)) + + return inner + + await sig.send(_async_wrapper=wrapper) + + API Documentation ----------------- diff --git a/pyproject.toml b/pyproject.toml index 98a0b20..65e5e5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,5 +44,6 @@ include-package-data = false version = {attr = "blinker.__version__"} [tool.pytest.ini_options] +asyncio_mode = "auto" testpaths = ["tests"] filterwarnings = ["error"] diff --git a/requirements/tests.in b/requirements/tests.in index e079f8a..ee4ba01 100644 --- a/requirements/tests.in +++ b/requirements/tests.in @@ -1 +1,2 @@ pytest +pytest-asyncio diff --git a/requirements/tests.txt b/requirements/tests.txt index e53d95b..edcf4fb 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -1,4 +1,4 @@ -# SHA1:0eaa389e1fdb3a1917c0f987514bd561be5718ee +# SHA1:738f1ea95febe383951f6eb64bdad13fefc1a97a # # This file is autogenerated by pip-compile-multi # To update, run: @@ -20,10 +20,16 @@ packaging==23.0 pluggy==1.0.0 # via pytest pytest==7.2.1 + # via + # -r requirements/tests.in + # pytest-asyncio +pytest-asyncio==0.20.3 # via -r requirements/tests.in tomli==2.0.1 # via pytest typing-extensions==4.4.0 - # via importlib-metadata + # via + # importlib-metadata + # pytest-asyncio zipp==3.11.0 # via importlib-metadata diff --git a/src/blinker/_utilities.py b/src/blinker/_utilities.py index 68e8422..22beb81 100644 --- a/src/blinker/_utilities.py +++ b/src/blinker/_utilities.py @@ -1,3 +1,8 @@ +import asyncio +import inspect +import sys +from functools import partial +from typing import Any from weakref import ref from blinker._saferef import BoundMethodWeakref @@ -93,3 +98,35 @@ def __get__(self, obj, cls): value = self._deferred(obj) setattr(obj, self._deferred.__name__, value) return value + + +def is_coroutine_function(func: Any) -> bool: + # Python < 3.8 does not correctly determine partially wrapped + # coroutine functions are coroutine functions, hence the need for + # this to exist. Code taken from CPython. + if sys.version_info >= (3, 8): + return asyncio.iscoroutinefunction(func) + else: + # Note that there is something special about the AsyncMock + # such that it isn't determined as a coroutine function + # without an explicit check. + try: + from unittest.mock import AsyncMock + + if isinstance(func, AsyncMock): + return True + except ImportError: + # Not testing, no asynctest to import + pass + + while inspect.ismethod(func): + func = func.__func__ + while isinstance(func, partial): + func = func.func + if not inspect.isfunction(func): + return False + result = bool(func.__code__.co_flags & inspect.CO_COROUTINE) + return ( + result + or getattr(func, "_is_coroutine", None) is asyncio.coroutines._is_coroutine + ) diff --git a/src/blinker/base.py b/src/blinker/base.py index f80750c..8d41721 100644 --- a/src/blinker/base.py +++ b/src/blinker/base.py @@ -13,6 +13,7 @@ from weakref import WeakValueDictionary from blinker._utilities import hashable_identity +from blinker._utilities import is_coroutine_function from blinker._utilities import lazy_property from blinker._utilities import reference from blinker._utilities import symbol @@ -242,7 +243,7 @@ def temporarily_connected_to(self, receiver, sender=ANY): ) return self.connected_to(receiver, sender) - def send(self, *sender, **kwargs): + def send(self, *sender, _async_wrapper=None, **kwargs): """Emit this signal on behalf of *sender*, passing on ``kwargs``. Returns a list of 2-tuples, pairing receivers with their return @@ -250,9 +251,51 @@ def send(self, *sender, **kwargs): :param sender: Any object or ``None``. If omitted, synonymous with ``None``. Only accepts one positional argument. + :param _async_wrapper: A callable that should wrap a coroutine + receiver and run it when called synchronously. :param kwargs: Data to be sent to receivers. """ + if self.is_muted: + return [] + + sender = self._extract_sender(sender) + results = [] + for receiver in self.receivers_for(sender): + if is_coroutine_function(receiver): + if _async_wrapper is None: + raise RuntimeError("Cannot send to a coroutine function") + receiver = _async_wrapper(receiver) + results.append((receiver, receiver(sender, **kwargs))) + return results + + async def send_async(self, *sender, _sync_wrapper=None, **kwargs): + """Emit this signal on behalf of *sender*, passing on ``kwargs``. + + Returns a list of 2-tuples, pairing receivers with their return + value. The ordering of receiver notification is undefined. + + :param sender: Any object or ``None``. If omitted, synonymous + with ``None``. Only accepts one positional argument. + :param _sync_wrapper: A callable that should wrap a synchronous + receiver and run it when awaited. + + :param kwargs: Data to be sent to receivers. + """ + if self.is_muted: + return [] + + sender = self._extract_sender(sender) + results = [] + for receiver in self.receivers_for(sender): + if not is_coroutine_function(receiver): + if _sync_wrapper is None: + raise RuntimeError("Cannot send to a non-coroutine function") + receiver = _sync_wrapper(receiver) + results.append((receiver, await receiver(sender, **kwargs))) + return results + + def _extract_sender(self, sender): if not self.receivers: # Ensure correct signature even on no-op sends, disable with -O # for lowest possible cost. @@ -273,14 +316,7 @@ def send(self, *sender, **kwargs): ) else: sender = sender[0] - - if self.is_muted: - return [] - else: - return [ - (receiver, receiver(sender, **kwargs)) - for receiver in self.receivers_for(sender) - ] + return sender def has_receivers_for(self, sender): """True if there is probably a receiver for *sender*. diff --git a/tests/test_signals.py b/tests/test_signals.py index 4f1114d..1fca3e5 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -332,6 +332,33 @@ def received(sender): assert [id(fn) for fn in sig.receivers.values()] == [fn_id] +async def test_async_receiver(): + sentinel = [] + + async def received_async(sender): + sentinel.append(sender) + + def received(sender): + sentinel.append(sender) + + def wrapper(func): + + async def inner(*args, **kwargs): + func(*args, **kwargs) + + return inner + + sig = blinker.Signal() + sig.connect(received) + sig.connect(received_async) + + await sig.send_async(_sync_wrapper=wrapper) + assert len(sentinel) == 2 + + with pytest.raises(RuntimeError): + sig.send() + + def test_instancemethod_receiver(): sentinel = []