In [48]:
import os
import asyncio
import json
import logging
import time

# from collections import deque
from functools import wraps
from paho.mqtt import client as pclient, publish as ppub, subscribe as psub
from pydantic import BaseModel
from queue import Queue, Full, Empty
from typing import Any, Callable, Generator
from uuid import uuid4


logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get("DEBUG_LEVEL", logging.INFO))

# README

This notebook requires a local MQTT broker, such as mosquitto, listening on localhost:1883 (the default host and port used by the code below). The following instructions assume mosquitto; other brokers provide equivalent commands.

Start the mosquitto server: `mosquitto`

Other instructions are provided where they're necessary.

In [49]:
# Note: I have not tested whether reconnect is necessary or if that is handled
# automatically when the client is created.

class MQTTConnection:
    """
    Base class that owns a paho-mqtt client for one topic on one broker.

    ``__init__`` creates the client but does not connect it. The client uses
    the paho VERSION2 callback API and a random UUID client id. ``publish``
    and ``subscribe.Task`` in this notebook build on this class.

    The reconnect backoff used by ``_attempt_reconnect`` is configured through
    environment variables. They are read once, when the class is defined.
    """
    FIRST_RECONNECT_DELAY = int(os.environ.get("FIRST_RECONNECT_DELAY", 1))  # seconds before the first attempt
    RECONNECT_RATE = int(os.environ.get("RECONNECT_RATE", 2))  # delay multiplier after each failed attempt
    MAX_RECONNECT_COUNT = int(os.environ.get("MAX_RECONNECT_COUNT", 12))  # attempts before giving up
    MAX_RECONNECT_DELAY = int(os.environ.get("MAX_RECONNECT_DELAY", 60))  # cap on the delay (seconds)

    class ConnectInfo(BaseModel):
        """Arguments that ``_on_connect`` received for a successful connection."""
        userdata: Any
        flags: Any
        return_code: Any
        properties: Any

    def __init__(self, topic: str, *, host: str="localhost", port: int=1883):
        """
        Parameters
        ----------
        topic : str
            Topic that subclasses publish to or subscribe to.
        host : str
            Broker host name. Default: localhost.
        port : int
            Broker port. Default: 1883.
        """
        self.topic: str = str(topic)
        self.host: str = str(host)
        self.port: int = int(port)
        self.client: pclient.Client = pclient.Client(
            callback_api_version=pclient.CallbackAPIVersion.VERSION2,
            client_id=str(uuid4()),
        )
        # Set by ``_on_connect`` (when that callback is installed); cleared on disconnect.
        self.connect_info: None | MQTTConnection.ConnectInfo = None

    def __del__(self):
        # If __init__ raised before ``client`` was assigned, there is nothing
        # to close. Touching ``self.client`` would then raise AttributeError
        # from the finalizer.
        if getattr(self, "client", None) is not None:
            self.disconnect()

    def connect(self):
        """Connect to the broker unless the client already reports a connection."""
        if self.client.is_connected():
            return
        # Callbacks are not installed at present:
        # self.client.on_connect = self._on_connect
        # self.client.on_disconnect = self._attempt_reconnect
        self.client.connect(self.host, self.port)

    def reconnect(self):
        """Disconnect (if connected) and then connect again."""
        self.disconnect()
        self.connect()

    def disconnect(self):
        """Disconnect from the broker if connected, and clear ``connect_info``."""
        # self.client.on_disconnect = self._disconnect
        if self.client.is_connected():
            self.client.disconnect()
            self.connect_info = None

    def _on_connect(self, client, userdata, flags, rc, properties) -> None:
        """paho ``on_connect`` callback: record the connection details on success."""
        # Under CallbackAPIVersion.VERSION2, ``rc`` is a paho ReasonCode, not
        # an int, so testing its truthiness (``not rc``) is not a reliable
        # success check. Comparing with 0 works for a ReasonCode and an int alike.
        if rc == 0:
            self.connect_info = MQTTConnection.ConnectInfo(
                userdata=userdata,
                flags=flags,
                return_code=rc,
                properties=properties,
            )
            logger.info(f"Connected {userdata} to {self.host}:{self.port}")
        else:
            logger.error(f"Connection to {self.host}:{self.port} refused: {rc}")

    @staticmethod
    def _disconnect_reason(args: tuple) -> Any:
        """Extract the reason code from the arguments after ``userdata``."""
        # VERSION2 passes (disconnect_flags, reason_code, properties); the
        # older form was (rc, properties). In both, the reason code is
        # second from the end.
        return args[-2] if len(args) >= 2 else (args[0] if args else None)

    def _attempt_reconnect(self, client, userdata, *args) -> None:
        """
        paho ``on_disconnect`` callback: reconnect with exponential backoff.

        Makes up to ``MAX_RECONNECT_COUNT`` attempts. The delay starts at
        ``FIRST_RECONNECT_DELAY`` seconds, is multiplied by ``RECONNECT_RATE``
        after each failure, and is capped at ``MAX_RECONNECT_DELAY``.
        """
        cls = type(self)
        rc = self._disconnect_reason(args)
        logger.info(f"Disconnected MQTT connection with return code {rc}.")
        reconnect_count, reconnect_delay = 0, cls.FIRST_RECONNECT_DELAY
        while reconnect_count < cls.MAX_RECONNECT_COUNT:
            logger.info(f"Reconnecting in {reconnect_delay} seconds...")
            time.sleep(reconnect_delay)
            try:
                self.client.reconnect()
                logger.info("Reconnect successful")
                return
            except Exception as err:
                logger.error(f"({str(err)}) Reconnect failed. Retrying...")

            reconnect_delay = min(reconnect_delay * cls.RECONNECT_RATE,
                                  cls.MAX_RECONNECT_DELAY)
            reconnect_count += 1
        self.connect_info = None
        logger.error(f"Reconnect failed after {reconnect_count} attempts. "
                     f"Exiting...")

    def _disconnect(self, client, userdata, *args) -> None:
        """paho ``on_disconnect`` callback that only logs the disconnection."""
        rc = self._disconnect_reason(args)
        logger.info(f"Disconnected MQTT connection with return code {rc}.")

# Pub-Sub

In [50]:
class PubSubStopIteration(StopIteration, StopAsyncIteration):
    """
    Stop signal that can end both ``for`` and ``async for`` loops, because it
    subclasses both StopIteration and StopAsyncIteration.
    """

## Publisher

To watch the messages published below, start a command-line subscriber on all topics: `mosquitto_sub -t '#'`

In [51]:
class SerializationError(Exception):
    """Exception for failures to serialize a value (presumably a payload to publish)."""

class PublishError(Exception):
    """Signals that a message could not be published to the MQTT broker."""

class PublishSuccess(Exception):
    """Control-flow signal for a successful publish; it does not indicate an error."""

class publish(MQTTConnection):
    """
    Decorator that publishes the decorated function's return value to an
    MQTT topic every time the function is called.
    """
    FIRST_RETRY_DELAY = int(os.environ.get("FIRST_RETRY_DELAY", 1))  # seconds before the first retry
    RETRY_RATE = int(os.environ.get("RETRY_RATE", 2))  # scales the delay between successive publish attempts
    MAX_RETRY_COUNT = int(os.environ.get("MAX_RETRY_COUNT", 12))  # publish attempts before giving up
    MAX_RETRY_DELAY = int(os.environ.get("MAX_RETRY_DELAY", 60))  # cap on the retry delay (seconds)

    def __init__(self,
                 topic: str, *, host: str="localhost", port: int=1883,  # MQTT
                 serializer: None | Callable=None,
                 lazy: bool=True,):
        """
        Decorator class that publishes the result of the decorated function
        to an MQTT broker.

        Parameters
        ----------
        topic : str
            The topic to which the result will be published.
        host : str
            The host to which the result will be published. Default: localhost.
        port : int
            The port through which the result will be published. Default: 1883
        serializer : Callable, optional
            A function applied to the result of the function before
            publishing, typically converting an object to a string.
            Default: ``_default_serializer``, which tries
            ``model_dump_json()`` (pydantic models), then ``json.dumps``,
            then ``str``.
        lazy : bool
            If True (default), the MQTT connection is only made the first time
            the wrapped function is called. Otherwise the connection is
            established immediately.
        """
        super().__init__(topic, host=host, port=port)
        # Use the caller's serializer when one is given.
        self.serializer: Callable = serializer or self._default_serializer
        if not lazy:
            self.connect()

    def _default_serializer(self, content: Any) -> Any:
        """
        Serialize ``content`` by the first strategy that succeeds:
        ``content.model_dump_json()``, then ``json.dumps(content)``, then
        ``str(content)``. The last strategy always produces a value.
        """
        try:
            return content.model_dump_json()
        except Exception:
            logger.debug("Serialization content is not pydantic.BaseModel.")
        try:
            return json.dumps(content)
        except Exception:  # TypeError / ValueError / RecursionError from json
            logger.debug("Serialization content is not JSON serializable.")
        return str(content)

    def _publish_with_retry(self, payload: Any) -> bool:
        """
        Publish ``payload`` to ``self.topic``, retrying with exponential backoff.

        Makes up to ``MAX_RETRY_COUNT`` attempts. Between failed attempts it
        sleeps, starting at ``FIRST_RETRY_DELAY`` seconds, multiplying by
        ``RETRY_RATE`` each time, and capping at ``MAX_RETRY_DELAY``.

        Returns
        -------
        bool
            True if ``client.publish`` reported success (rc 0). False, after
            logging an error, if every attempt failed.
        """
        cls = type(self)
        rc = mid = None  # stay defined even if MAX_RETRY_COUNT <= 0
        delay = cls.FIRST_RETRY_DELAY
        attempts = 0
        while attempts < cls.MAX_RETRY_COUNT:
            info = self.client.publish(self.topic, payload)
            rc, mid = info.rc, info.mid
            attempts += 1
            if rc == 0:  # MQTT_ERR_SUCCESS
                return True
            if attempts < cls.MAX_RETRY_COUNT:
                # Do not sleep after the final attempt.
                logger.info(f"Retrying in {delay} seconds...")
                time.sleep(delay)
                delay = min(delay * cls.RETRY_RATE, cls.MAX_RETRY_DELAY)
        logger.error(f"Publishing message {mid!r} failed with status "
                     f"{rc} after {attempts} attempts.")
        return False

    def __call__(self, func: Callable) -> Callable:
        """Wrap ``func`` so that each call publishes its result and then returns it unchanged."""
        @wraps(func)
        def wrapper(*args, **kwds):
            # Connect lazily on use; this is a no-op while the client reports a connection.
            self.connect()
            result = func(*args, **kwds)
            # Serialize once rather than on every retry. A publish failure is
            # logged, not raised, so the caller still gets the result.
            self._publish_with_retry(self.serializer(result))
            return result
        return wrapper

In [52]:
@publish("test")
def add(lhs, rhs):
    """Return ``lhs + rhs``; the decorator also publishes it to the "test" topic."""
    total = lhs + rhs
    return total

In [53]:
# Publishes 3 to the "test" topic and displays the returned value.
add(lhs=1, rhs=2)

3

In [54]:
# Publishes 46 to the "test" topic and displays the returned value.
add(lhs=12, rhs=34)

46

In [55]:
# Drop the decorated function. Its closure holds the publish instance, so,
# presumably, once no other reference remains, the instance is garbage-collected
# and MQTTConnection.__del__ disconnects its client.
del add

## Subscriber

The subscriber below processes messages published to the "test" topic. Publish one from a terminal, e.g.

`mosquitto_pub -t test -m 25`

(You may wish to substitute other numbers.)

In [56]:
class DeserializationError(Exception):
    """Exception for failures to deserialize an incoming message (presumably MQTT payloads)."""

class TimeoutException(Exception):
    """Raised when no message arrives before the subscriber's polling timeout."""
    

class subscribe:
    """
    Decorator that turns a function into an MQTT message handler.

    ``@subscribe(topic)`` replaces the decorated function with a
    ``subscribe.Task``. Calling the task waits for one message and returns
    the handler's result. Iterating over it yields one result per message.
    """

    class Task(MQTTConnection):
        """
        Subscription to ``topic``. Each message is decoded, deserialized,
        passed to ``func``, and the result is queued for the caller.
        """
        POLLING_INTERVAL = float(os.environ.get("POLLING_INTERVAL", 0.1))  # seconds per wait on the result queue
        POLLING_TIMEOUT = float(os.environ.get("POLLING_TIMEOUT", 30))  # polling timeout (seconds)

        def __init__(self,
                     topic: str, *, host: str="localhost", port: int=1883,
                     deserializer: None | Callable=None):
            super().__init__(topic, host=host, port=port)
            self.deserializer: Callable = deserializer or self._default_deserializer
            # Handler results, filled by _on_message on paho's network thread.
            self.result_queue: Queue = Queue()
            # Handler applied to each message; assigned by subscribe.__call__.
            self.func: None | Callable = None

        def connect(self):
            """Install the message callback, connect, and subscribe to ``self.topic``."""
            self.client.on_message = self._on_message
            super().connect()
            self.client.subscribe(self.topic)

        def _on_message(self, client, userdata, msg) -> None:
            """paho ``on_message`` callback: decode, deserialize, apply ``func``, and queue the result."""
            # This runs on the network thread started by loop_start(). paho
            # re-raises callback exceptions by default, which would stop that
            # thread, so a bad message is logged and dropped instead.
            try:
                payload = self.deserializer(msg.payload.decode())
                output = payload if self.func is None else self.func(payload)
            except Exception:
                logger.exception(f"Failed to handle message on topic {msg.topic!r}.")
                return
            self.result_queue.put(output)

        def _default_deserializer(self, message: str) -> Any:
            """Parse ``message`` as JSON, falling back to the raw string."""
            try:
                return json.loads(message)
            except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
                logger.debug("Message is not JSON deserializable.")
            return message

        def _poll(self) -> Any:
            """
            Return the next queued handler result.

            Waits in ``POLLING_INTERVAL`` slices. Raises ``StopIteration`` once
            more than ``POLLING_TIMEOUT`` seconds pass without a result.
            """
            cls = type(self)
            interval = cls.POLLING_INTERVAL
            waited = 0.0
            while True:
                try:
                    # The timeout is essential: a bare get() blocks forever, so
                    # Empty, and therefore the polling timeout, could never occur.
                    return self.result_queue.get(timeout=interval)
                except Empty:
                    waited += interval
                    if int(waited - interval) != int(waited):
                        logger.debug(f"Duration {int(waited)} seconds")
                    if waited > cls.POLLING_TIMEOUT:
                        raise StopIteration()

        def _stop(self) -> None:
            """Disconnect from the broker and stop paho's background network loop."""
            self.disconnect()
            self.client.loop_stop()

        def __iter__(self):
            return self

        def __next__(self):
            """Return the next result; stop on polling timeout or KeyboardInterrupt."""
            if not self.client.is_connected():
                self.connect()
                self.client.loop_start()
            try:
                return self._poll()
            except (StopIteration, KeyboardInterrupt):
                self._stop()
                raise StopIteration() from None

        def __aiter__(self):
            return self

        async def __anext__(self):
            """
            Async counterpart of ``__next__``.

            NOTE: ``_poll`` is synchronous, so this blocks the running event
            loop while it waits for a message.
            """
            if not self.client.is_connected():
                self.connect()
                self.client.loop_start()
            try:
                return self._poll()
            except (StopIteration, KeyboardInterrupt):
                self._stop()
                # This must be raised: returning it would hand the exception
                # object to ``async for`` as an ordinary value, and the loop
                # would never end.
                raise StopAsyncIteration() from None

        def __call__(self):
            """
            Wait for a single message and return the handler's result.

            Raises
            ------
            TimeoutException
                If no message arrives within ``POLLING_TIMEOUT`` seconds.
            """
            self.connect()
            self.client.loop_start()
            try:
                return self._poll()
            except StopIteration:
                raise TimeoutException(f"No messages received in {type(self).POLLING_TIMEOUT} seconds.") from None
            finally:
                self._stop()

    def __init__(self,
                 topic: str, *, host: str="localhost", port: int=1883,  # MQTT
                 deserializer: None | Callable=None,):
        """
        Decorator that subscribes to a topic on an MQTT broker and passes each
        message, deserialized, as the argument of the wrapped function.

        Parameters
        ----------
        topic : str
            The topic to subscribe to.
        host : str
            The broker host. Default: localhost.
        port : int
            The broker port. Default: 1883.
        deserializer : Callable
            Applied to the decoded message text before it is passed to the
            wrapped function. Default: json.loads, falling back to the raw
            string when the message is not valid JSON.
        """
        self.task: subscribe.Task = subscribe.Task(
            topic, host=host, port=port,
            deserializer=deserializer,)

    def __call__(self, func: None | Callable=None) -> Any:
        """Attach ``func`` as the message handler (if given) and return the task."""
        if func is not None:
            self.task.func = func
        return self.task

In [59]:
@subscribe("test")
def double(x):
    """Handler for "test" messages: return twice the deserialized payload."""
    doubled = 2 * x
    return doubled

In [61]:
# Waits for one message on "test", returns its doubled value, then disconnects.
# Publish one from a terminal with: mosquitto_pub -t test -m 25
double()

50

In [62]:
# Collects doubled values, one per message, until (a) KeyboardInterrupt
# ("Stop" icon in the action bar above) or (b) the subscriber's polling
# timeout (POLLING_TIMEOUT, default: 30 seconds without a message).
list(double)

[50, 50, 50, 50]