In [1]:
# Used only by the test cells below to tag messages with random numbers.
from random import randint

In [2]:
import asyncio
import functools
import inspect
import logging
import queue
import threading
from collections.abc import Callable
from queue import Queue
from threading import Event
from uuid import uuid4

import paho.mqtt.client as mqtt
from aiomqtt import Client as AsyncClient, MqttError, MqttCodeError
from paho.mqtt.client import CallbackAPIVersion, MQTTErrorCode
from paho.mqtt.client import Client as SyncClient

# Root handler: timestamped records, DEBUG and above pass the root logger.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.DEBUG,
)
# This notebook's own logger only emits INFO and above.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

### Dependencies: payload formatter, client factory and connection manager

In [3]:
class Formatter:
    """Pass-through payload formatter.

    ``pack`` converts a value into the payload that gets published and
    ``unpack`` converts a received payload back into a value. This default
    implementation returns its argument untouched in both directions;
    subclass it and override both methods to add (de)serialization.
    """

    def pack(self, content):
        """Outgoing direction: return ``content`` as-is."""
        return content

    def unpack(self, content):
        """Incoming direction: return ``content`` as-is."""
        return content

In [4]:
class ClientGenerator:
    """
    ClientGenerator establishes a common API for creating either
    a synchronous MQTT client (paho.mqtt.client.Client) or an
    asynchronous client (aiomqtt.client.Client).

    An appropriate client is created depending on the context in
    which the client operates. The stored keyword arguments are filtered
    against the constructor signature of the class being created, so keys
    that class does not accept are silently dropped. The asynchronous
    client is returned *unconnected* (entering it connects) and takes
    ``hostname`` rather than ``host``, e.g.

    .. code-block:: python

        with ClientGenerator() as client:
            # Produces a synchronous client
            client.connect(host="localhost", port=1883)
            client.publish(...)

        async with ClientGenerator(hostname="localhost", port=1883) as client:
            # Produces an asynchronous client; entering it connects.
            async with client:
                await client.publish(...)

    Parameters
    ----------
    credentials : object, optional
        If truthy, ``credentials.authenticate(client)`` is called on every
        newly created client and its return value is used as the client.
    **kwargs
        Client constructor arguments. ``client_id`` defaults to a random
        UUID4 string and ``callback_api_version`` to
        ``CallbackAPIVersion.VERSION2``.
        NOTE(review): keys the target class does not accept (possibly
        ``client_id`` for the aiomqtt client) are dropped; confirm that is
        intended.
    """
    def __init__(self, credentials=None, **kwargs):
        self.kwargs = kwargs
        self.credentials = credentials
        self.kwargs.setdefault("client_id", str(uuid4()))
        self.kwargs.setdefault("callback_api_version", CallbackAPIVersion.VERSION2)

    def _build(self, client_cls):
        """Instantiate ``client_cls`` with the kwargs it accepts, then authenticate it."""
        # Computed once per build instead of once per kwarg.
        accepted = inspect.signature(client_cls).parameters
        kwargs = {k: v for k, v in self.kwargs.items() if k in accepted}
        client = client_cls(**kwargs)
        if self.credentials:
            client = self.credentials.authenticate(client)
        return client

    def __enter__(self) -> SyncClient:
        """Return a new, unconnected paho client."""
        logger.debug("Entering ClientGenerator context.")
        client = self._build(SyncClient)
        logger.debug("ClientGenerator context setup complete.")
        return client

    def __exit__(self, *exc):
        """No cleanup: the generator does not connect or close the client."""
        logger.debug("Exited ClientGenerator context.")

    async def __aenter__(self) -> AsyncClient:
        """Return a new, unconnected aiomqtt client."""
        logger.debug("Entering async ClientGenerator context.")
        client = self._build(AsyncClient)
        logger.debug("Async ClientGenerator context setup complete.")
        return client

    async def __aexit__(self, *exc):
        """No cleanup: the generator does not connect or close the client."""
        logger.debug("Exited async ClientGenerator context.")

In [5]:
class Connection:
    """
    Reusable MQTT connection context built on a ClientGenerator.

    * ``with Connection() as client`` creates a paho client (first entry
      only), installs the callables in ``self.callbacks`` as client
      attributes, connects, starts paho's background loop and blocks until
      ``self.event`` is set by the ``on_connect`` callback.
    * ``async with Connection() as client`` creates an aiomqtt client (first
      entry only) and enters it, which connects it.

    Failed attempts are retried with exponential back-off (1 s, doubling,
    capped at 120 s). The sync path retries on ``ConnectionRefusedError``;
    the async path on ``ConnectionRefusedError`` and ``MqttError``.

    Parameters
    ----------
    gen : ClientGenerator, optional
        Factory for the underlying client; ``ClientGenerator()`` when omitted.
    **kwargs
        Connection parameters; ``host`` defaults to ``"localhost"`` and
        ``port`` to ``1883``. Passed to ``client.connect`` in sync mode. In
        async mode they are merged into the generator's kwargs with ``host``
        renamed to ``hostname``.
    """
    def __init__(self, gen: ClientGenerator | None=None, **kwargs):
        self.client_generator: ClientGenerator = gen or ClientGenerator()
        # Created lazily on first entry and reused by later entries.
        self.client: AsyncClient | SyncClient | None = None
        # Sync mode: set by on_connect to unblock __enter__.
        self.event: Event = Event()
        self.kwargs = kwargs
        self.kwargs.setdefault("host", "localhost")
        self.kwargs.setdefault("port", 1883)
        # Assigned onto the paho client on every sync entry; callers
        # (e.g. subscriber tasks) may replace or add entries.
        self.callbacks = {
            "on_connect": self.on_connect,
            "on_disconnect": self.on_disconnect,
        }

    def __enter__(self) -> SyncClient:
        """Connect synchronously, retrying while the broker refuses; return the paho client.

        Raises
        ------
        KeyboardInterrupt
            If the user interrupts a connection attempt.
        Exception
            Any error other than ConnectionRefusedError is logged and re-raised.
        """
        logger.debug("Entering Connection context.")
        if self.client is None:
            # First entry creates the client; later entries reuse it.
            self.client = self.client_generator.__enter__()
        for k, v in self.callbacks.items():
            logger.debug(f"Registering {k} callback: {v}")
            setattr(self.client, k, v)
        try:
            self.event.clear()
            wait = 1
            while True:
                try:
                    self.client.connect(**self.kwargs)
                    # Stop any network thread left over from an earlier entry,
                    # then start a fresh one.
                    self.client.loop_stop()
                    self.client.loop_start()
                    logger.debug("Waiting to establish connection.")
                    # NOTE(review): no timeout - blocks until on_connect sets the event.
                    self.event.wait()
                except ConnectionRefusedError:
                    logger.info(f"Broker unavailable. Retry in {wait} seconds.")
                    # Interruptible sleep: returns after `wait` s unless the event is set.
                    self.event.wait(wait)
                    wait = min(2*wait, 120)
                else:
                    logger.info("Connection established.")
                    break
        except KeyboardInterrupt:
            logger.debug("Attempt to connect cancelled.")
            raise
        except Exception as e:
            logger.error(f"Failed to connect {self.client} with {self.kwargs}: {e}")
            raise
        logger.debug("Connection context setup complete.")
        return self.client

    def __exit__(self, *exc):
        """Disconnect and stop paho's network loop; the client object is kept for reuse."""
        # Nothing to close if the context was never (successfully) entered.
        if self.client is not None:
            self.client.disconnect()
            self.client.loop_stop()
        self.client_generator.__exit__(*exc)
        logger.debug("Exited Connection context.")

    async def __aenter__(self) -> AsyncClient:
        """Connect the aiomqtt client, retrying with back-off; return the client.

        On re-entry the existing client is first torn down via ``__aexit__``
        so the same aiomqtt client object can be entered again.
        """
        logger.debug("Entering async Connection context.")
        # Aiomqtt Client combines client and connection.
        kwargs = {**self.client_generator.kwargs, **self.kwargs}
        # Handle API inconsistency between Paho Client and Aiomqtt Client.
        kwargs["hostname"] = kwargs.pop("host")
        self.client_generator.kwargs.update(kwargs)
        # Generate and connect client.
        if self.client is None:
            self.client = await self.client_generator.__aenter__()
        else:
            await self.__aexit__(None, None, None)
            logger.debug("Reentered async Connection context.")
        wait = 1
        while True:
            try:
                await self.client.__aenter__()
            except (ConnectionRefusedError, MqttError):
                logger.info(f"Broker not available. Attempting to reconnect in {wait} seconds.")
                await asyncio.sleep(wait)
                wait = min(2*wait, 120)
            except KeyboardInterrupt as e:
                logger.debug("Attempt to connect cancelled.")
                await self.client.__aexit__(type(e), e, None)
                raise
            except Exception as e:
                logger.error(f"An unexpected error occured while connecting ({e}). Aborting.")
                await self.client.__aexit__(type(e), e, None)
                raise
            else:
                logger.info(f"Connected to {kwargs['hostname']}:{kwargs['port']}")
                break
        logger.debug("Async Connection context complete.")
        return self.client

    async def __aexit__(self, *exc):
        """Disconnect the aiomqtt client, forcing its state closed if aiomqtt raises."""
        client = self.client
        # Nothing to close if the context was never entered.
        if client is not None:
            try:
                await client.__aexit__(*exc)
            except MqttError:
                # NOTE(review): presumably raised when the connection was already
                # lost; the code below relies on aiomqtt private attributes
                # (_client, _wait_for, _connected, _disconnected, _lock) - verify
                # against the installed aiomqtt version.
                # Try to gracefully disconnect from the broker
                rc = client._client.disconnect()
                if rc == mqtt.MQTT_ERR_SUCCESS:
                    # Wait for acknowledgement
                    await client._wait_for(client._disconnected, timeout=None)
                    # Reset `_connected` if it's still in completed state after disconnecting
                    if client._connected.done():
                        client._connected = asyncio.Future()
                else:
                    logger.warning(
                        "Could not gracefully disconnect: %d. Forcing disconnection.", rc
                    )
                # Force disconnection if we cannot gracefully disconnect
                if not client._disconnected.done():
                    client._disconnected.set_result(None)
                # Release the reusability lock
                if client._lock.locked():
                    client._lock.release()
        await self.client_generator.__aexit__(*exc)
        logger.debug("Exited async Connection context.")

    def on_connect(self, client, userdata, connect_flags, reason_code, properties):
        """paho on_connect callback (VERSION2 signature): unblocks ``__enter__``."""
        # NOTE(review): the event is set regardless of reason_code, so a refused
        # CONNACK (if paho reports it here) also unblocks __enter__.
        logger.debug("Starting connection loop.")
        self.event.set()

    def on_disconnect(self, client, userdata, disconnect_flags, reason_code, properties):
        """paho on_disconnect callback (VERSION2 signature); intentionally a no-op."""

### Publish

#### Code

In [6]:
class Publisher:
    """
    Decorator that publishes a function's return value to an MQTT topic.

    ``@Publisher(topic)`` (aliased as ``publish``) wraps a plain function in a
    ``SyncTask`` and a coroutine function in an ``AsyncTask``. Calling the
    wrapper calls the function and publishes ``formatter.pack(result)`` to
    ``topic`` inside this Publisher's connection context. It then returns
    ``result`` unchanged. Publish failures are logged as warnings and never
    raised to the caller.

    Parameters
    ----------
    topic : str
        Topic every result is published to.
    connection : Connection, optional
        Connection used for publishing; ``Connection()`` when omitted.
    formatter : Formatter, optional
        Converts results into payloads; defaults to the pass-through ``Formatter``.
    **kwargs
        Extra keyword arguments forwarded to ``client.publish``.
    """
    class Task:
        """Base wrapper binding a user function to the Publisher that publishes its results."""
        def __init__(self, scope: "Publisher", function: Callable, *, timeout: float=10.0):
            # Copy __name__, __qualname__, __doc__, __module__ (and set __wrapped__)
            # so the decorated name still looks like the original function.
            functools.update_wrapper(self, function)
            self.scope = scope
            self.fn = function
            # Seconds SyncTask waits for the publish to complete.
            self.timeout: float = timeout
            # Serializes AsyncTask publishes across concurrent calls of the same task.
            self._lock: asyncio.Lock = asyncio.Lock()

    class AsyncTask(Task):
        """Wrapper for coroutine functions: ``await task(...)``."""

        async def __call__(self, *args, **kwds):
            """Await the wrapped coroutine, publish its result, and return the result."""
            logger.debug(f"Calling {self.fn.__name__}.")
            result = await self.fn(*args, **kwds)
            scope = self.scope
            try:
                async with self._lock:
                    async with scope as client:
                        await client.publish(scope.topic, payload=scope.formatter.pack(result), **scope.kwargs)
                    logger.debug(f"Published {result}.")
            except Exception as e:
                # Best effort: a failed publish is logged, not raised.
                logger.warning(f"Failed to publish {result} ({type(result)}) to {scope.topic}: {e}")
            return result

    class SyncTask(Task):
        """Wrapper for plain functions: ``task(...)``."""

        def __call__(self, *args, **kwds):
            """Call the wrapped function, publish its result, and return the result.

            Waits up to ``self.timeout`` seconds for the publish to complete.
            """
            logger.debug(f"Calling {self.fn.__name__}.")
            result = self.fn(*args, **kwds)
            scope = self.scope
            try:
                with scope as client:
                    info = client.publish(scope.topic, scope.formatter.pack(result), **scope.kwargs)
                    info.wait_for_publish(self.timeout)
                logger.debug(f"Published {result}.")
            except Exception as e:
                # Best effort: a failed publish is logged, not raised.
                logger.warning(f"Failed to publish {result} to {scope.topic}: {e}")
            return result

    def __init__(self, topic: str, *, connection: Connection | None=None, formatter: Formatter | None=None, **kwargs):
        self.connection: Connection = connection or Connection()
        self.topic: str = topic
        # Forwarded to client.publish on every publish.
        self.kwargs = kwargs
        self.formatter = formatter or Formatter()

    def __enter__(self) -> SyncClient:
        """Enter the underlying connection synchronously and return its client."""
        logger.debug("Entering Publisher context.")
        client = self.connection.__enter__()
        logger.debug("Publisher context setup complete.")
        return client

    def __exit__(self, *exc):
        """Exit the underlying connection."""
        self.connection.__exit__(*exc)
        logger.debug("Exited Publisher context.")

    async def __aenter__(self) -> AsyncClient:
        """Enter the underlying connection asynchronously and return its client."""
        logger.debug("Entering async Publisher context.")
        client = await self.connection.__aenter__()
        logger.debug("Async Publisher context setup complete.")
        return client

    async def __aexit__(self, *exc):
        """Exit the underlying connection asynchronously."""
        await self.connection.__aexit__(*exc)
        logger.debug("Exited async Publisher context.")

    def __call__(self, fn):
        """Wrap ``fn`` in an AsyncTask (coroutine function) or SyncTask (anything else)."""
        if inspect.iscoroutinefunction(fn):
            logger.debug(f"Wrapping coroutine function '{fn.__name__}'.")
            return Publisher.AsyncTask(self, fn)
        else:
            logger.debug(f"Wrapping function '{fn.__name__}'.")
            return Publisher.SyncTask(self, fn)

publish = Publisher

#### Test

In [7]:
@publish("test")
def hello():
    """Return a fixed greeting; the decorator also publishes it to topic "test"."""
    greeting = "Hello, World!"
    return greeting

In [8]:
# Synchronous call: returns the greeting and publishes it to topic "test" as a side effect.
hello()

2025-03-07 07:20:51,693 - INFO - Connection established.


'Hello, World!'

In [9]:
@publish("test")
async def async_hello():
    """Sleep 1 s, then return a greeting tagged with a random number (zero-padded to 2 digits).

    The decorator publishes the returned string to topic "test".
    """
    await asyncio.sleep(1)
    tag = randint(1, 100)
    return f"Await hello #{tag:02d}, World!"

In [10]:
# Coroutine call: awaits the wrapped coroutine, publishes its result, then returns it.
await async_hello()

2025-03-07 07:20:52,705 - INFO - Connected to localhost:1883


'Await hello #01, World!'

In [11]:
# The three coroutines run concurrently under asyncio.gather. The publish step is
# serialized by the task's asyncio.Lock, so results are published one at a time.
await asyncio.gather(
    async_hello(),
    async_hello(),
    async_hello(),
)

2025-03-07 07:20:53,717 - INFO - Connected to localhost:1883
2025-03-07 07:20:53,726 - INFO - Connected to localhost:1883
2025-03-07 07:20:53,734 - INFO - Connected to localhost:1883


['Await hello #24, World!',
 'Await hello #62, World!',
 'Await hello #04, World!']

### Subscribe

#### Code

In [12]:
class Subscriber:
    """
    Decorator that turns a message handler into a callable/iterable MQTT subscription.

    ``@Subscriber(topic, host=..., port=...)`` (aliased as ``subscribe``)
    wraps a plain function in a ``SyncTask`` and a coroutine function in an
    ``AsyncTask``. Calling the task waits for the next message on ``topic``
    and returns ``handler(formatter.unpack(payload))``. Iterating over the
    task yields one such result per message.

    Parameters
    ----------
    topic, qos, options, properties
        Forwarded as keyword arguments to the client's ``subscribe``.
    connection : Connection, optional
        Connection to use; when omitted one is built from ``**kwargs``.
    formatter : Formatter, optional
        Payload formatter; defaults to the pass-through ``Formatter``.
    **kwargs
        Connection parameters (e.g. ``host``, ``port``), used only when
        ``connection`` is not supplied.
    """
    class Task:
        """Base wrapper binding a message handler to the Subscriber that feeds it."""
        def __init__(self, scope: "Subscriber", function: Callable, *, timeout: float=10.0):
            self.scope = scope
            self.fn = function
            # NOTE(review): stored but not read by either subscriber task.
            self.timeout: float = timeout
            self._lock: asyncio.Lock = asyncio.Lock()

    class AsyncTask(Task):
        """Asynchronous subscription: ``await task()`` or ``async for x in task``."""
        def __init__(self, *args, **kwds):
            super().__init__(*args, **kwds)
            self.client = None

        def __aiter__(self):
            return self

        async def __anext__(self):
            """Wait for the next message and return the handler's result for it.

            Raises StopAsyncIteration when the wait is cancelled (CancelledError,
            e.g. from ``asyncio.timeout``) or interrupted (KeyboardInterrupt).
            On an MQTT error the connection is re-established and the wait
            restarts. Any other error tears the context down and propagates.
            """
            # Context is reentrant: every call (re)enters the connection and re-subscribes.
            client = await self.__aenter__()
            try:
                try:
                    message = await anext(client.messages)
                except asyncio.CancelledError:
                    logger.debug("Future or Task was cancelled.")
                    raise StopAsyncIteration()
                except KeyboardInterrupt:
                    logger.debug("Iteration halted with Keyboard Interrupt.")
                    raise StopAsyncIteration()
                except (MqttError, MqttCodeError) as e:
                    logger.info("Disconnected from the broker. Waiting to reconnect.")
                    await self.__aexit__(type(e), e, None)
                    # Connection.__aenter__ retries with back-off until the broker is back.
                    await self.__aenter__()
                    logger.info("Reconnected to the broker.")
                    return await anext(self)
                except Exception as e:
                    logger.debug(f"An unknown error occured: {e}")
                    # Propagate to the outer handler; falling through here would
                    # make the iteration yield None.
                    raise
                else:
                    return await self.fn(self.scope.formatter.unpack(message.payload))
            except Exception as e:
                # Also catches the StopAsyncIteration raised above, so the
                # connection is released before iteration ends.
                await self.__aexit__(type(e), e, None)
                raise

        async def __aenter__(self) -> AsyncClient:
            """(Re)connect through the Subscriber scope, (re)subscribe, and return the client."""
            # Keeps concurrent callers from entering the shared connection simultaneously.
            async with self._lock:
                self.client = await self.scope.__aenter__()
                await self.client.subscribe(**self.scope.kwargs)
                logger.debug("Async Task context setup complete.")
            return self.client

        async def __aexit__(self, *exc):
            """Exit the Subscriber scope if a client was ever obtained."""
            async with self._lock:
                if self.client:
                    logger.debug("Tearing down async Task context.")
                    await self.scope.__aexit__(*exc)
                    logger.debug("Exited async Task context.")

        async def __call__(self):
            """Wait for one message and return the handler's result."""
            return await anext(self)

    class SyncTask(Task):
        """Synchronous subscription: ``task()`` or ``for x in task``.

        The handler runs on paho's network thread (``on_message``); its results
        reach the caller through a thread-safe queue.
        """
        def __init__(self, *args, **kwds):
            super().__init__(*args, **kwds)
            # Install this task's callbacks before the connection is entered.
            self.scope.connection.callbacks["on_connect"] = self.on_connect
            self.scope.connection.callbacks["on_message"] = self.on_message
            self.client = None
            self.queue: Queue = Queue()
            # Replaces the asyncio lock from Task: guards client setup/teardown across threads.
            self._lock: threading.Lock = threading.Lock()

        def __iter__(self):
            return self

        def __next__(self):
            """Block until the next handler result is available and return it.

            A KeyboardInterrupt while waiting ends iteration (StopIteration).
            """
            if not self.client:
                self.__enter__()
            try:
                # No timeout: waits until on_message enqueues a result.
                return self.queue.get()
            except KeyboardInterrupt:
                logger.debug("Keyboard interrupt stopped Subscriber iteration.")
                raise StopIteration()
            except Exception as e:
                logger.error(f"An unexpected exception occured during Subscriber iteration: {e}.")
                self.__exit__(type(e), e, None)
                self.__enter__()
                return next(self)

        def __enter__(self):
            """Connect through the Subscriber scope once; later calls reuse the client."""
            # Context is reentrant
            with self._lock:
                if self.client is None:
                    logger.debug("Entering Subscriber task context.")
                    self.client = self.scope.__enter__()
            logger.debug("Subscriber task context setup complete.")
            return self.client

        def __exit__(self, *exc):
            """Leave the Subscriber scope and forget the client."""
            with self._lock:
                if self.client:
                    logger.debug("Tearing down Subscriber task context.")
                    self.scope.__exit__(*exc)
                    self.client = None
            logger.debug("Exited Subscriber task context.")

        def __call__(self):
            """Wait for one message and return the handler's result."""
            logger.debug("Called Subscriber.Task.__call__")
            return next(self)

        def on_message(self, client, userdata, message):
            """paho on_message callback: run the handler and enqueue its result."""
            logger.debug("Calling Subscriber.Task.on_message callback.")
            try:
                result = self.fn(self.scope.formatter.unpack(message.payload))
            except Exception:
                # Runs on paho's network thread. An escaping exception is re-raised
                # by paho's loop (unless suppress_exceptions is set), which would end
                # the loop_start() thread and leave __next__ waiting forever.
                logger.exception("Subscriber handler failed; message dropped.")
                return
            self.queue.put(result)

        def on_connect(self, client, userdata, connect_flags, reason_code, properties):
            """Chain to the connection's on_connect, then (re)subscribe on every connect."""
            logger.debug("Calling Subscriber.Task.on_connect callback.")
            self.scope.connection.on_connect(client, userdata, connect_flags, reason_code, properties)
            client.subscribe(**self.scope.kwargs)

    def __init__(self, topic: str, qos: int=0, options=None, properties=None,
                 *,
                 connection: Connection | None=None,
                 formatter: Formatter | None=None,
                 **kwargs):
        self.connection: Connection = connection or Connection(**kwargs)
        # Passed verbatim as keyword arguments to client.subscribe.
        self.kwargs = {
            "topic": topic,
            "qos": qos,
            "options": options,
            "properties": properties
        }
        self.formatter = formatter or Formatter()

    def __enter__(self) -> SyncClient:
        """Enter the underlying connection synchronously and return its client."""
        logger.debug("Entering Subscriber context.")
        client = self.connection.__enter__()
        logger.debug("Subscriber context setup complete.")
        return client

    def __exit__(self, *exc):
        """Exit the underlying connection."""
        self.connection.__exit__(*exc)
        logger.debug("Exited Subscriber context.")

    async def __aenter__(self) -> AsyncClient:
        """Enter the underlying connection asynchronously and return its client."""
        logger.debug("Entering async Subscriber context.")
        client = await self.connection.__aenter__()
        logger.debug("Async Subscriber context setup complete.")
        return client

    async def __aexit__(self, *exc):
        """Exit the underlying connection asynchronously."""
        await self.connection.__aexit__(*exc)
        logger.debug("Exited async Subscriber context.")

    def __call__(self, fn):
        """Wrap ``fn`` in an AsyncTask (coroutine function) or SyncTask (anything else)."""
        if inspect.iscoroutinefunction(fn):
            logger.debug(f"Wrapping coroutine function '{fn.__name__}'.")
            return Subscriber.AsyncTask(self, fn)
        else:
            logger.debug(f"Wrapping function '{fn.__name__}'.")
            return Subscriber.SyncTask(self, fn)

subscribe = Subscriber

#### Test

In [13]:
@subscribe("test", host="localhost", port=1883)
def sync_sub(msg):
    """Annotate a received payload with a random number (plain-function handler)."""
    suffix = randint(1, 100)
    return f"{msg} ({suffix:02})"

@subscribe("test", host="localhost", port=1883)
async def async_sub(msg):
    """Annotate a received payload with a random number and an 'async' marker."""
    suffix = randint(1, 100)
    return f"{msg} ({suffix:02}, async)"

In [14]:
# Blocks until one message arrives on "test", then prints the handler's result.
print(sync_sub())

2025-03-07 07:20:53,765 - INFO - Connection established.


b'Hello' (21)


In [15]:
# Runs until interrupted: the queue wait has no timeout, and a kernel
# interrupt while waiting is turned into StopIteration by SyncTask.__next__.
for message in sync_sub:
    print(message)

b'Hello' (39)
b'Hello' (63)
b'Hello' (42)
b'Hello' (56)


In [16]:
# Waits for a single message and returns the async handler's result.
await async_sub()

2025-03-07 07:21:08,498 - INFO - Connected to localhost:1883


"b'Hello' (40, async)"

In [17]:
# await asyncio.gather(async_sub(), async_sub(), async_sub())

In [18]:
# Collects results until iteration is stopped (e.g. by interrupting the kernel);
# each step re-enters the connection, hence one "Connected" log line per message.
[x async for x in async_sub]

2025-03-07 07:21:13,175 - INFO - Connected to localhost:1883
2025-03-07 07:21:18,709 - INFO - Connected to localhost:1883
2025-03-07 07:21:19,668 - INFO - Connected to localhost:1883
2025-03-07 07:21:20,586 - INFO - Connected to localhost:1883
2025-03-07 07:21:21,679 - INFO - Connected to localhost:1883


["b'Hello' (66, async)",
 "b'Hello' (88, async)",
 "b'Hello' (95, async)",
 "b'Hello' (73, async)"]

In [19]:
# asyncio.timeout cancels the wait after 10 s. AsyncTask.__anext__ turns that
# CancelledError into StopAsyncIteration, so the results gathered so far are printed.
async with asyncio.timeout(10):
    print([x async for x in async_sub])

2025-03-07 07:21:24,693 - INFO - Connected to localhost:1883
2025-03-07 07:21:29,738 - INFO - Connected to localhost:1883
2025-03-07 07:21:30,230 - INFO - Connected to localhost:1883
2025-03-07 07:21:30,626 - INFO - Connected to localhost:1883


["b'Hello' (43, async)", "b'Hello' (22, async)", "b'Hello' (58, async)"]


In [20]:
# Manual reconnection check: follow the prompt. The log should show back-off
# retries ("Broker not available ...") followed by successful reconnection.
_ = input("Stop the broker. Press [enter] to continue. Restart the broker.")
[x async for x in async_sub]

Stop the broker. Press [enter] to continue. Restart the broker. 


2025-03-07 07:21:54,469 - INFO - Broker not available. Attempting to reconnect in 1 seconds.
2025-03-07 07:21:55,474 - INFO - Broker not available. Attempting to reconnect in 2 seconds.
2025-03-07 07:21:57,478 - INFO - Broker not available. Attempting to reconnect in 4 seconds.
2025-03-07 07:22:01,483 - INFO - Connected to localhost:1883
2025-03-07 07:22:03,211 - INFO - Connected to localhost:1883
2025-03-07 07:22:04,239 - INFO - Connected to localhost:1883
2025-03-07 07:22:05,095 - INFO - Connected to localhost:1883
2025-03-07 07:22:05,882 - INFO - Connected to localhost:1883
2025-03-07 07:22:08,008 - INFO - Disconnected from the broker. Waiting to reconnect.
2025-03-07 07:22:08,022 - INFO - Broker not available. Attempting to reconnect in 1 seconds.
2025-03-07 07:22:09,026 - INFO - Broker not available. Attempting to reconnect in 2 seconds.
2025-03-07 07:22:11,030 - INFO - Broker not available. Attempting to reconnect in 4 seconds.
2025-03-07 07:22:15,034 - INFO - Connected to localh

["b'Hello' (96, async)",
 "b'Hello' (80, async)",
 "b'Hello' (61, async)",
 "b'Hello' (45, async)",
 "b'Hello' (01, async)",
 "b'Hello' (43, async)",
 "b'Hello' (63, async)",
 "b'Hello' (49, async)"]

In [21]:
# Single awaited call. If the broker drops mid-wait, the task reconnects and keeps waiting.
await async_sub()

2025-03-07 07:22:22,158 - INFO - Connected to localhost:1883
2025-03-07 07:22:33,673 - INFO - Disconnected from the broker. Waiting to reconnect.
2025-03-07 07:22:33,683 - INFO - Broker not available. Attempting to reconnect in 1 seconds.
2025-03-07 07:22:34,690 - INFO - Broker not available. Attempting to reconnect in 2 seconds.
2025-03-07 07:22:36,694 - INFO - Broker not available. Attempting to reconnect in 4 seconds.
2025-03-07 07:22:40,699 - INFO - Connected to localhost:1883
2025-03-07 07:22:40,703 - INFO - Reconnected to the broker.
2025-03-07 07:22:40,707 - INFO - Connected to localhost:1883


"b'Hello' (32, async)"