diff --git a/amqtt/mqtt/protocol/client_handler.py b/amqtt/mqtt/protocol/client_handler.py index 371feb9f..bf9073ac 100644 --- a/amqtt/mqtt/protocol/client_handler.py +++ b/amqtt/mqtt/protocol/client_handler.py @@ -106,17 +106,18 @@ async def mqtt_subscribe(self, topics, packet_id): # Wait for SUBACK is received waiter = futures.Future() self._subscriptions_waiter[subscribe.variable_header.packet_id] = waiter - return_codes = await waiter - - del self._subscriptions_waiter[subscribe.variable_header.packet_id] + try: + return_codes = await waiter + finally: + del self._subscriptions_waiter[subscribe.variable_header.packet_id] return return_codes async def handle_suback(self, suback: SubackPacket): packet_id = suback.variable_header.packet_id - try: - waiter = self._subscriptions_waiter.get(packet_id) + waiter = self._subscriptions_waiter.get(packet_id) + if waiter is not None: waiter.set_result(suback.payload.return_codes) - except KeyError: + else: self.logger.warning( "Received SUBACK for unknown pending subscription with Id: %s" % packet_id @@ -132,15 +133,17 @@ async def mqtt_unsubscribe(self, topics, packet_id): await self._send_packet(unsubscribe) waiter = futures.Future() self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id] = waiter - await waiter - del self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id] + try: + await waiter + finally: + del self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id] async def handle_unsuback(self, unsuback: UnsubackPacket): packet_id = unsuback.variable_header.packet_id - try: - waiter = self._unsubscriptions_waiter.get(packet_id) + waiter = self._unsubscriptions_waiter.get(packet_id) + if waiter is not None: waiter.set_result(None) - except KeyError: + else: self.logger.warning( "Received UNSUBACK for unknown pending subscription with Id: %s" % packet_id @@ -152,10 +155,12 @@ async def mqtt_disconnect(self): async def mqtt_ping(self): ping_packet = PingReqPacket() - await self._send_packet(ping_packet) - resp = await self._pingresp_queue.get() - if self._ping_task: - self._ping_task = None + try: + await self._send_packet(ping_packet) + resp = await self._pingresp_queue.get() + finally: + if self._ping_task: + self._ping_task = None return resp async def handle_pingresp(self, pingresp: PingRespPacket): diff --git a/amqtt/mqtt/protocol/handler.py b/amqtt/mqtt/protocol/handler.py index 8eb5a45c..e7e04189 100644 --- a/amqtt/mqtt/protocol/handler.py +++ b/amqtt/mqtt/protocol/handler.py @@ -293,12 +293,13 @@ async def _handle_qos1_message_flow(self, app_message): # Wait for puback waiter = asyncio.Future() self._puback_waiters[app_message.packet_id] = waiter - await waiter - del self._puback_waiters[app_message.packet_id] - app_message.puback_packet = waiter.result() - - # Discard inflight message - del self.session.inflight_out[app_message.packet_id] + try: + await waiter + app_message.puback_packet = waiter.result() + finally: + self._puback_waiters.pop(app_message.packet_id, None) + # Discard inflight message + self.session.inflight_out.pop(app_message.packet_id, None) elif app_message.direction == INCOMING: # Initiate delivery self.logger.debug("Add message to delivery") @@ -351,9 +352,12 @@ async def _handle_qos2_message_flow(self, app_message): raise AMQTTException(message) waiter = asyncio.Future() self._pubrec_waiters[app_message.packet_id] = waiter - await waiter - del self._pubrec_waiters[app_message.packet_id] - app_message.pubrec_packet = waiter.result() + try: + await waiter + app_message.pubrec_packet = waiter.result() + finally: + self._pubrec_waiters.pop(app_message.packet_id, None) + self.session.inflight_out.pop(app_message.packet_id, None) if not app_message.pubcomp_packet: # Send pubrel app_message.pubrel_packet = PubrelPacket.build(app_message.packet_id) @@ -361,11 +365,12 @@ async def _handle_qos2_message_flow(self, app_message): # Wait for PUBCOMP waiter = asyncio.Future() self._pubcomp_waiters[app_message.packet_id] = waiter - await waiter - del self._pubcomp_waiters[app_message.packet_id] - app_message.pubcomp_packet = waiter.result() - # Discard inflight message - del self.session.inflight_out[app_message.packet_id] + try: + await waiter + app_message.pubcomp_packet = waiter.result() + finally: + self._pubcomp_waiters.pop(app_message.packet_id, None) + self.session.inflight_out.pop(app_message.packet_id, None) elif app_message.direction == INCOMING: self.session.inflight_in[app_message.packet_id] = app_message # Send pubrec diff --git a/amqtt/session.py b/amqtt/session.py index 7213e223..d23f5b0b 100644 --- a/amqtt/session.py +++ b/amqtt/session.py @@ -159,16 +159,15 @@ def _init_states(self): @property def next_packet_id(self): - self._packet_id += 1 - if self._packet_id > 65535: - self._packet_id = 1 + self._packet_id = (self._packet_id % 65535) + 1 + limit = self._packet_id while ( self._packet_id in self.inflight_in or self._packet_id in self.inflight_out ): - self._packet_id += 1 - if self._packet_id > 65535: + self._packet_id = (self._packet_id % 65535) + 1 + if self._packet_id == limit: raise AMQTTException( - "More than 65525 messages pending. No free packet ID" + "More than 65535 messages pending. No free packet ID" ) return self._packet_id diff --git a/tests/test_client.py b/tests/test_client.py index cc1691db..142e4946 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -218,3 +218,80 @@ async def test_deliver_timeout(): await client.unsubscribe(["$SYS/broker/uptime"]) await client.disconnect() await broker.shutdown() + + +@pytest.mark.asyncio +async def test_cancel_publish_qos1(): + """ + Tests that timeouts on published messages will clean up in flight messages + """ + data = b"data" + broker = Broker(broker_config, plugin_namespace="amqtt.test.plugins") + await broker.start() + client_pub = MQTTClient() + await client_pub.connect("mqtt://127.0.0.1/") + assert client_pub.session.inflight_out_count == 0 + fut = asyncio.create_task(client_pub.publish("test_topic", data, QOS_1)) + assert len(client_pub._handler._puback_waiters) == 0 + while len(client_pub._handler._puback_waiters) == 0 or fut.done(): + await asyncio.sleep(0) + assert len(client_pub._handler._puback_waiters) == 1 + assert client_pub.session.inflight_out_count == 1 + fut.cancel() + await asyncio.wait([fut]) + assert len(client_pub._handler._puback_waiters) == 0 + assert client_pub.session.inflight_out_count == 0 + await client_pub.disconnect() + await broker.shutdown() + + +@pytest.mark.asyncio +async def test_cancel_publish_qos2_pubrec(): + """ + Tests that timeouts on published messages will clean up in flight messages + """ + data = b"data" + broker = Broker(broker_config, plugin_namespace="amqtt.test.plugins") + await broker.start() + client_pub = MQTTClient() + await client_pub.connect("mqtt://127.0.0.1/") + assert client_pub.session.inflight_out_count == 0 + fut = asyncio.create_task(client_pub.publish("test_topic", data, QOS_2)) + assert len(client_pub._handler._pubrec_waiters) == 0 + while ( + len(client_pub._handler._pubrec_waiters) == 0 or fut.done() or fut.cancelled() + ): + await asyncio.sleep(0) + assert len(client_pub._handler._pubrec_waiters) == 1 + assert client_pub.session.inflight_out_count == 1 + fut.cancel() + await asyncio.sleep(1) + await asyncio.wait([fut]) + assert len(client_pub._handler._pubrec_waiters) == 0 + assert client_pub.session.inflight_out_count == 0 + await client_pub.disconnect() + await broker.shutdown() + + +@pytest.mark.asyncio +async def test_cancel_publish_qos2_pubcomp(): + """ + Tests that timeouts on published messages will clean up in flight messages + """ + data = b"data" + broker = Broker(broker_config, plugin_namespace="amqtt.test.plugins") + await broker.start() + client_pub = MQTTClient() + await client_pub.connect("mqtt://127.0.0.1/") + assert client_pub.session.inflight_out_count == 0 + fut = asyncio.create_task(client_pub.publish("test_topic", data, QOS_2)) + assert len(client_pub._handler._pubcomp_waiters) == 0 + while len(client_pub._handler._pubcomp_waiters) == 0 or fut.done(): + await asyncio.sleep(0) + assert len(client_pub._handler._pubcomp_waiters) == 1 + fut.cancel() + await asyncio.wait([fut]) + assert len(client_pub._handler._pubcomp_waiters) == 0 + assert client_pub.session.inflight_out_count == 0 + await client_pub.disconnect() + await broker.shutdown()