disconnect cleanly #78

Closed
nbraun-wolf opened this issue Oct 3, 2021 · 7 comments

Comments

@nbraun-wolf

nbraun-wolf commented Oct 3, 2021

Hi, what is the intended way to stop cleanly? When calling client.disconnect, it throws an error MqttError("Disconnected during message iteration").

If there is no other way, how can we know if the user called disconnect on purpose? In the original paho client, we have reason code 0 for that. But I can't find a way to get this info from this wrapper here.

I am trying to run this client in a background thread, with FastAPI running in the main thread. I have adapted the example from the README a bit. I am currently using this loop boolean to signal shutdown, but it's not pretty.

Client Code
from asyncio import CancelledError, create_task, gather, run, sleep
from contextlib import AsyncExitStack
from threading import Thread

from asyncio_mqtt import Client, MqttError

broker_host = "localhost"
topic_prefixes = ("test/#",)
topic_handlers = {}

client = None
loop = True


def topic(topic):
    def wrapper(handler):
        topic_handlers[topic] = handler
        return handler

    return wrapper


async def run_client():
    global client
    async with AsyncExitStack() as stack:
        tasks = set()

        async def cancel_tasks(tasks):
            for task in tasks:
                if task.done():
                    continue
                task.cancel()
                try:
                    await task
                except CancelledError:
                    pass

        async def handle_messages(client, messages, handler):
            async for message in messages:
                await handler(client, message.payload.decode())

        stack.push_async_callback(cancel_tasks, tasks)
        client = Client(broker_host)

        await stack.enter_async_context(client)

        for topic_filter, handler in topic_handlers.items():
            manager = client.filtered_messages(topic_filter)
            messages = await stack.enter_async_context(manager)
            task = create_task(handle_messages(client, messages, handler))
            tasks.add(task)

        for topic_prefix in topic_prefixes:
            await client.subscribe(topic_prefix)

        await gather(*tasks)


async def background_client():
    reconnect_interval = 3
    while loop:
        try:
            await run_client()
        except MqttError as error:
            if loop:
                print(f'Error "{error}"')
        finally:
            if loop:
                await sleep(reconnect_interval)


def _paho_thread():
    run(background_client())


def get_client():
    if not client:
        raise Exception("could not get client, did you forget to call mqtt_startup?")
    return client


paho_thread = Thread(target=_paho_thread, daemon=True)


async def mqtt_startup():
    paho_thread.start()


async def mqtt_shutdown():
    global loop
    client = get_client()
    loop = False
    await client.disconnect()
    paho_thread.join()
FastAPI Code
from json import dumps
from typing import Optional

from fastapi import FastAPI

from mqtt import get_client, mqtt_shutdown, mqtt_startup, topic

app = FastAPI()


@app.on_event("startup")
async def startup_event():
    await mqtt_startup()


@app.on_event("shutdown")
async def shutdown_event():
    await mqtt_shutdown()


@app.get("/")
async def read_root():
    client = get_client()
    await client.publish("test/bar", dumps({"pub": "foo"}))
    return {"Hello": "World"}


@topic("test/foo")
async def test(client, payload):
    print("foo handler")
    print(payload)
    await client.publish("test/bar", dumps({"pub": "foo"}))


@topic("test/bar")
async def foo(client, payload):
    print("bar handler")
    print(payload)
@frederikaalund
Member

Hi nbraun-wolf, thanks for opening this issue.

The recommended way is to use Client as a context manager in an async with statement. If you do so, you shouldn't call disconnect directly. It was a design blunder on my end to make disconnect a public method. It should be a private method, but I won't break the API just for that.

I can see that you use threads together with asyncio tasks. Why not just use the latter? This way you can call asyncio.Task.cancel instead of disconnect. Something like this:

import asyncio

_MQTT_CLIENT_TASK = None

async def mqtt_startup():
    # Declare the name global; otherwise the assignment below would create a local variable.
    global _MQTT_CLIENT_TASK
    assert _MQTT_CLIENT_TASK is None, "Don't call mqtt_startup twice"
    # Schedule the coroutine and return immediately. The coroutine continues in the background.
    _MQTT_CLIENT_TASK = asyncio.create_task(background_client())

async def mqtt_shutdown():
    assert _MQTT_CLIENT_TASK is not None, "Call mqtt_startup before mqtt_shutdown"
    # Cancel the client task
    _MQTT_CLIENT_TASK.cancel()
    # Wait until the task is done. This raises the task's exception (if any).
    try:
        await _MQTT_CLIENT_TASK
    # Ignore `CancelledError` since we know that the task got cancelled (we did so ourselves).
    except asyncio.CancelledError:
        pass

Does it make sense?
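
To illustrate why cancelling the task is enough to disconnect, here is a minimal, self-contained sketch. FakeClient is a hypothetical stand-in for asyncio_mqtt.Client (it is not part of the library) that only records when the context is entered and exited. The point is that the cleanup of an async with block also runs when the surrounding task is cancelled.

```python
import asyncio


class FakeClient:
    """Hypothetical stand-in for asyncio_mqtt.Client; it only logs enter/exit."""

    def __init__(self, log):
        self.log = log

    async def __aenter__(self):
        self.log.append("connected")
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Runs on normal exit, on errors, and on cancellation alike.
        self.log.append("disconnected")
        return False  # do not swallow the exception


async def background_client(log):
    async with FakeClient(log):
        await asyncio.sleep(3600)  # pretend to process messages


async def main():
    log = []
    task = asyncio.create_task(background_client(log))
    await asyncio.sleep(0.01)  # give the task time to "connect"
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass  # expected: we cancelled the task ourselves
    return log


if __name__ == "__main__":
    print(asyncio.run(main()))  # ['connected', 'disconnected']
```

With the real client, the same mechanism should let the context manager disconnect without an explicit call to disconnect.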

@nbraun-wolf
Author

nbraun-wolf commented Oct 4, 2021

Hi, thanks for getting back to me on this, and thanks for the suggestion. I have implemented it and it works well. I was thinking it would be better for application performance to run this on a separate thread, since now both the HTTP handlers and the MQTT handlers share the same event loop.

Client Wrapper and Subscriber Code
from asyncio import CancelledError, Task, create_task, gather, sleep
from contextlib import AsyncExitStack
from typing import Any, Callable, Coroutine, Dict, List, Optional, Set, Tuple, Union

from asyncio_mqtt import Client, MqttError
from rapidjson import loads
from structlog import get_logger

logger = get_logger()

Handler = Callable[[Client, str, Any], Coroutine[Any, Any, None]]
Parse_Func = Callable[[Union[str, bytes]], Any]
Handler_Tuple = Tuple[Handler, Union[Parse_Func, None]]
JSON_Decoded = Union[Dict[str, Any], List[Any]]


class Subscriber:
    """Decorate topic handlers for topic tree prefixes passed into the constructor.
    Must register the subscriber with register_subscriber from a MQTTClientWrapper class."""

    _handlers: Dict[str, Handler_Tuple] = dict()

    def __init__(self, *topic_prefixes: str) -> None:
        self.topic_prefixes: Tuple[str, ...] = topic_prefixes

    def topic(self, topic: str, parse_func: Optional[Parse_Func] = loads) -> Callable[[Handler], Handler]:
        """Subscribe to mqtt topic. by default the response is parsed with rapidjson.loads and passed into the handler.
        Override this with a custom parser or set it to None to receive the raw response."""

        def wrapper(handler: Handler) -> Handler:
            self._handlers[topic] = (handler, parse_func)
            return handler

        return wrapper


class MQTTClientWrapper:
    """MQTTClientWrapper holds the client state and contains methods to start and stop the client.
    Must register subscribers with its register_subscriber method."""

    _subscribers: List[Subscriber] = []
    _client: Union[Client, None] = None
    _client_task: Union[Task[None], None] = None
    _disconnected: bool = False

    def __init__(self, hostname: str, port: int, client_id: str, reconnect_interval: int = 10) -> None:
        self.hostname = hostname
        self.port = port
        self.client_id = client_id
        self.reconnect_interval: int = reconnect_interval

    @property
    def client(self) -> Client:
        assert self._client is not None, "Client task is not running, did you forget to call start_loop?"
        return self._client

    def register_subscriber(self, subscriber: Subscriber) -> None:
        """Register a new subscriber."""
        self._subscribers.append(subscriber)

    async def start_loop(self) -> None:
        """Start async client task and put it on the main loop. It will connect to the mqtt broker and subscribe to
        the topic trees from the subscribers registered via register_subscriber. Attempts to reconnect based on the
        set reconnect interval (default is 10 seconds). Await stop_loop to stop it."""
        assert self._client_task is None, "Don't call start_loop twice."
        self._client_task = create_task(self._async_mqtt_loop())

    async def stop_loop(self) -> None:
        """Stop the async client task. Disconnects from the mqtt broker."""
        assert self._client_task is not None, "Cannot call stop_loop without calling start_loop first."
        logger.debug("stopping mqtt client task")
        self._client_task.cancel()
        try:
            await self._client_task
        except CancelledError:
            pass

    async def connect_and_subscribe(self) -> None:
        """Connect to the mqtt broker and use the registered subscribers to subscribe to topic trees.
        Call topic handlers when a message with a specific topic is received.
        This function ends when an error occurs, use start_loop to reconnect on error."""
        async with AsyncExitStack() as stack:
            tasks: Set[Task[None]] = set()

            stack.push_async_callback(self._cancel_tasks, tasks)
            self._client = Client(hostname=self.hostname, port=self.port, client_id=self.client_id)
            await stack.enter_async_context(self._client)  # type: ignore
            logger.info("successfully connected to mqtt broker")
            self._disconnected = False

            for sub in self._subscribers:
                for topic_filter, handler_tuple in sub._handlers.items():
                    manager = self._client.filtered_messages(topic_filter)
                    messages = await stack.enter_async_context(manager)
                    task = create_task(self._handle_messages(self._client, messages, *handler_tuple))
                    tasks.add(task)

                for topic_prefix in sub.topic_prefixes:
                    logger.debug("subscribing to topic tree", topicPrefix=topic_prefix)
                    await self._client.subscribe(topic_prefix)

            await gather(*tasks)

    async def _handle_messages(
        self, client: Client, messages: Any, handler: Handler, parse_func: Optional[Parse_Func]
    ) -> None:
        async for message in messages:
            logger.debug("mqtt message", payload=message.payload, topic=message.topic)
            # Pass the parsed payload to the handler, or the raw payload when no parser is set.
            await handler(client, message.topic, parse_func(message.payload) if parse_func else message.payload)

    async def _cancel_tasks(self, tasks: Set[Task[None]]) -> None:
        for task in tasks:
            if task.done():
                continue
            task.cancel()
            try:
                await task
            except CancelledError:
                pass

    async def _async_mqtt_loop(self) -> None:
        while True:
            try:
                await self.connect_and_subscribe()
            except MqttError as error:
                if not self._disconnected:
                    logger.info("could not connect or lost the connection to mqtt broker", exc_info=error)
                    logger.debug(f"attempting to reconnect to mqtt broker every {self.reconnect_interval} seconds")
                    self._disconnected = True
            finally:
                await sleep(self.reconnect_interval)

@frederikaalund
Member

frederikaalund commented Oct 4, 2021

Glad to hear that you got it to work. 👍 Can I close this issue now?

I was thinking it would be better for the application performance to run this on a separate thread. Since now both http handlers and mqtt handler share the same event loop.

I understand what you want to do here. I doubt that threads together with asyncio can give you that performance. In any case, profile first and optimize later. Do you actually have performance problems right now or is it purely a theoretical exercise? :)
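
One rough way to "profile first" is to measure how late the event loop wakes up a sleeping coroutine while the real workload runs. The sketch below is only an illustration under that assumption; measure_loop_lag and its parameters are hypothetical names, not part of asyncio or asyncio_mqtt.

```python
import asyncio
import time


async def measure_loop_lag(interval=0.01, samples=5):
    """Return how much later than requested each sleep woke up, in seconds."""
    lags = []
    for _ in range(samples):
        start = time.perf_counter()
        await asyncio.sleep(interval)
        # Clamp at zero because timers can fire marginally early on some platforms.
        lags.append(max(0.0, time.perf_counter() - start - interval))
    return lags


async def main():
    # Run the probe next to the real workload (here nothing else is running).
    lags = await measure_loop_lag()
    return max(lags)


if __name__ == "__main__":
    print(f"worst observed lag: {asyncio.run(main()):.6f} s")
```

If the observed lag stays small while HTTP and MQTT handlers are busy, sharing one event loop is unlikely to be the bottleneck.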

@nbraun-wolf
Author

Yes, we can close the issue.

No, I had no performance issues. I was basically just trying to be cautious.

@nbraun-wolf
Author

nbraun-wolf commented Oct 19, 2021

@frederikaalund, do you have any idea why I can't publish messages like this when I have not registered any subscriber? I have a use case where I don't need to subscribe, but I want to use the same class as above.

It fails with the error below. I suspect it is because the loop coroutine has already finished, since await gather(*tasks) returns right away. But I am not sure; I was expecting the client task to keep running in the background.

  File "/test-mqtt/./wolf/my_namespace/my_package/routers/index.py", line 15, in index
    await mqtt.client.publish("sap/serial/new", dumps([1]))
  File "/test-mqtt/.venv/lib/python3.9/site-packages/asyncio_mqtt/client.py", line 237, in publish
    raise MqttCodeError(info.rc, "Could not publish message")
asyncio_mqtt.error.MqttCodeError: [code:4] Could not publish message

@frederikaalund
Member

frederikaalund commented Oct 20, 2021

It fails with error, I feel like it could be because the loop coroutine is already finished, since await gather(*tasks) finishes. But I am not sure. I was expecting the client task to keep running in the background.

To test this hypothesis, try to add await sleep(100000) at the end. This way, the background task stays up.
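
The hypothesis can also be checked in isolation with the standard library alone. The following minimal sketch (wait_for_tasks is a hypothetical helper name) shows that awaiting gather over an empty collection completes immediately instead of blocking.

```python
import asyncio


async def wait_for_tasks(tasks):
    # With an empty collection this is gather() with no arguments,
    # which completes immediately and yields an empty list.
    return await asyncio.gather(*tasks)


async def main():
    loop = asyncio.get_running_loop()
    start = loop.time()
    results = await wait_for_tasks(set())
    elapsed = loop.time() - start
    return results, elapsed


if __name__ == "__main__":
    results, elapsed = asyncio.run(main())
    print(results, f"{elapsed:.6f} s")
```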

@nbraun-wolf
Author

nbraun-wolf commented Oct 23, 2021

@frederikaalund, I have experimented with this, and I am fairly sure this is it. It also makes sense, given that the client is started inside the AsyncExitStack and only the task set is awaited with gather, but it's empty. So gather returns immediately, the stack exits, and the client disconnects.

async with AsyncExitStack() as stack:
    tasks: Set[Task[None]] = set()
    await stack.enter_async_context(self._client)
    await asyncio.gather(*tasks)
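
One possible way to keep the connection open without subscribers, sketched under assumptions: when the task set is empty, wait on something that never completes until the coroutine is cancelled. FakeClient and connect_and_wait are hypothetical names used only for this illustration; FakeClient stands in for asyncio_mqtt.Client.

```python
import asyncio


class FakeClient:
    """Hypothetical stand-in for asyncio_mqtt.Client; tracks connection state."""

    def __init__(self):
        self.connected = False

    async def __aenter__(self):
        self.connected = True
        return self

    async def __aexit__(self, exc_type, exc, tb):
        self.connected = False
        return False


async def connect_and_wait(client, tasks):
    async with client:
        if tasks:
            await asyncio.gather(*tasks)
        else:
            # Nothing to wait for: block until this coroutine is cancelled,
            # so the context manager does not exit (and disconnect) right away.
            await asyncio.Event().wait()


async def main():
    client = FakeClient()
    runner = asyncio.create_task(connect_and_wait(client, set()))
    await asyncio.sleep(0.01)
    still_connected = client.connected  # publishing would be possible here
    runner.cancel()
    try:
        await runner
    except asyncio.CancelledError:
        pass  # expected: we cancelled the task ourselves
    return still_connected, client.connected


if __name__ == "__main__":
    print(asyncio.run(main()))  # (True, False)
```

Compared with a very long sleep, waiting on an event that is never set has no arbitrary timeout and still ends cleanly on cancellation.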
