Skip to content

Commit

Permalink
Merge pull request #86 from azogue/fix/typing
Browse files Browse the repository at this point in the history
 🐛 Fix compatibility for Python < 3.10 and more typing
  • Loading branch information
azogue committed Jan 3, 2024
2 parents 7e9cdbb + 8d75bfe commit d4da8ac
Show file tree
Hide file tree
Showing 10 changed files with 130 additions and 102 deletions.
59 changes: 34 additions & 25 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,60 +41,70 @@ pip install fastapi-mqtt
### 🕹 Guide

```python
from typing import Any

from fastapi import FastAPI
from gmqtt import Client as MQTTClient

from fastapi_mqtt import FastMQTT, MQTTConfig

app = FastAPI()

mqtt_config = MQTTConfig()
mqtt = FastMQTT(config=mqtt_config)
mqtt.init_app(app)

fast_mqtt = FastMQTT(config=mqtt_config)

@mqtt.on_connect()
def connect(client, flags, rc, properties):
mqtt.client.subscribe("/mqtt") # subscribing mqtt topic
app = FastAPI()
fast_mqtt.init_app(app)


@fast_mqtt.on_connect()
def connect(client: MQTTClient, flags: int, rc: int, properties: Any):
client.subscribe("/mqtt") # subscribing mqtt topic
print("Connected: ", client, flags, rc, properties)

@mqtt.on_message()
async def message(client, topic, payload, qos, properties):
print("Received message: ", topic, payload.decode(), qos, properties)
@fast_mqtt.subscribe("mqtt/+/temperature", "mqtt/+/humidity", qos=1)
async def home_message(client: MQTTClient, topic: str, payload: bytes, qos: int, properties: Any):
print("temperature/humidity: ", topic, payload.decode(), qos, properties)

@mqtt.subscribe("my/mqtt/topic/#")
async def message_to_topic(client, topic, payload, qos, properties):
print("Received message to specific topic: ", topic, payload.decode(), qos, properties)
@fast_mqtt.on_message()
async def message(client: MQTTClient, topic: str, payload: bytes, qos: int, properties: Any):
print("Received message: ", topic, payload.decode(), qos, properties)

@mqtt.subscribe("my/mqtt/topic/#", qos=2)
async def message_to_topic_with_high_qos(client, topic, payload, qos, properties):
@fast_mqtt.subscribe("my/mqtt/topic/#", qos=2)
async def message_to_topic_with_high_qos(
client: MQTTClient, topic: str, payload: bytes, qos: int, properties: Any
):
print(
"Received message to specific topic and QoS=2: ", topic, payload.decode(), qos, properties
)


@mqtt.on_disconnect()
def disconnect(client, packet, exc=None):
@fast_mqtt.on_disconnect()
def disconnect(client: MQTTClient, packet, exc=None):
print("Disconnected")

@mqtt.on_subscribe()
def subscribe(client, mid, qos, properties):
@fast_mqtt.on_subscribe()
def subscribe(client: MQTTClient, mid: int, qos: int, properties: Any):
print("subscribed", client, mid, qos, properties)

@app.get("/test")
async def func():
fast_mqtt.publish("/mqtt", "Hello from Fastapi") # publishing mqtt topic
return {"result": True, "message": "Published"}
```

Publish method:

```python
async def func():
mqtt.publish("/mqtt", "Hello from Fastapi") # publishing mqtt topic
fast_mqtt.publish("/mqtt", "Hello from Fastapi") # publishing mqtt topic
return {"result": True, "message": "Published"}
```

Subscribe method:

```python
@mqtt.on_connect()
@fast_mqtt.on_connect()
def connect(client, flags, rc, properties):
mqtt.client.subscribe("/mqtt") # subscribing mqtt topic
client.subscribe("/mqtt") # subscribing mqtt topic
print("Connected: ", client, flags, rc, properties)
```

Expand All @@ -108,8 +118,7 @@ mqtt_config = MQTTConfig(
username="username",
password="strong_password",
)

mqtt = FastMQTT(config=mqtt_config)
fast_mqtt = FastMQTT(config=mqtt_config)
```

### ✅ Testing
Expand Down
32 changes: 20 additions & 12 deletions examples/app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Any

from fastapi import FastAPI
from gmqtt import Client as MQTTClient

from fastapi_mqtt.config import MQTTConfig
from fastapi_mqtt.fastmqtt import FastMQTT
from fastapi_mqtt import FastMQTT, MQTTConfig

mqtt_config = MQTTConfig()

Expand All @@ -12,35 +14,41 @@


@fast_mqtt.on_connect()
def connect(client, flags, rc, properties):
fast_mqtt.client.subscribe("/mqtt") # subscribing mqtt topic
def connect(client: MQTTClient, flags: int, rc: int, properties: Any):
client.subscribe("/mqtt") # subscribing mqtt topic
print("Connected: ", client, flags, rc, properties)


@fast_mqtt.subscribe("mqtt/+/temperature", "mqtt/+/humidity")
async def home_message(client, topic, payload, qos, properties):
@fast_mqtt.subscribe("mqtt/+/temperature", "mqtt/+/humidity", qos=1)
async def home_message(client: MQTTClient, topic: str, payload: bytes, qos: int, properties: Any):
print("temperature/humidity: ", topic, payload.decode(), qos, properties)
return 0


@fast_mqtt.on_message()
async def message(client, topic, payload, qos, properties):
async def message(client: MQTTClient, topic: str, payload: bytes, qos: int, properties: Any):
print("Received message: ", topic, payload.decode(), qos, properties)
return 0


@fast_mqtt.subscribe("my/mqtt/topic/#", qos=2)
async def message_to_topic_with_high_qos(
client: MQTTClient, topic: str, payload: bytes, qos: int, properties: Any
):
print(
"Received message to specific topic and QoS=2: ", topic, payload.decode(), qos, properties
)


@fast_mqtt.on_disconnect()
def disconnect(client, packet, exc=None):
def disconnect(client: MQTTClient, packet, exc=None):
print("Disconnected")


@fast_mqtt.on_subscribe()
def subscribe(client, mid, qos, properties):
def subscribe(client: MQTTClient, mid: int, qos: int, properties: Any):
print("subscribed", client, mid, qos, properties)


@app.get("/test")
async def func():
fast_mqtt.publish("/mqtt", "Hello from Fastapi") # publishing mqtt topic

return {"result": True, "message": "Published"}
3 changes: 1 addition & 2 deletions examples/app_will_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,5 @@ def subscribe(client, mid, qos, properties):
@app.get("/")
async def func():
# publishing mqtt topic
await fast_mqtt.publish("/mqtt", "Hello from Fastapi")

fast_mqtt.publish("/mqtt", "Hello from Fastapi")
return {"result": True, "message": "Published"}
4 changes: 2 additions & 2 deletions examples/ws_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ def create_app():
ws_subscribers = DynamicMQTTClient(fast_mqtt)

@asynccontextmanager
async def _lifespan(application: FastAPI):
async def _lifespan(fastapi_app: FastAPI):
await fast_mqtt.mqtt_startup()
application.state.ws_subscribers = ws_subscribers
fastapi_app.state.ws_subscribers = ws_subscribers
yield
await ws_subscribers.close()
await fast_mqtt.mqtt_shutdown()
Expand Down
38 changes: 18 additions & 20 deletions fastapi_mqtt/fastmqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
from .handlers import MQTTHandlers

try:
from uvicorn.config import logger

log_info = logger
from uvicorn.config import logger as log_info
except ImportError:
log_info = logging.getLogger()

Expand Down Expand Up @@ -49,9 +47,9 @@ def __init__(
client_id: Optional[str] = None,
clean_session: bool = True,
optimistic_acknowledgement: bool = True,
mqtt_logger: logging.Logger | None = None,
mqtt_logger: Optional[logging.Logger] = None,
**kwargs: Any,
):
) -> None:
if not client_id:
client_id = uuid.uuid4().hex

Expand All @@ -70,8 +68,8 @@ def __init__(
self.client.on_message = self.__on_message
self.client.on_connect = self.__on_connect
self.subscriptions: Dict[str, Tuple[Subscription, List[Callable]]] = {}
self.mqtt_handlers = MQTTHandlers(self.client)
self._logger = mqtt_logger or log_info
self.mqtt_handlers = MQTTHandlers(self.client, self._logger)

if (
self.config.will_message_topic
Expand All @@ -92,7 +90,7 @@ def __init__(
)

@staticmethod
def match(topic, template):
def match(topic: str, template: str) -> bool:
"""
Defined match topics
Expand All @@ -102,10 +100,10 @@ def match(topic, template):
if str(template).startswith("$share/"):
template = template.split("/", 2)[2]

topic = topic.split("/")
template = template.split("/")
topic_parts = topic.split("/")
template_parts = template.split("/")

for topic_part, part in zip_longest(topic, template):
for topic_part, part in zip_longest(topic_parts, template_parts):
if part == "#" and not str(topic_part).startswith("$"):
return True
elif (topic_part is None or part not in {"+", topic_part}) or (
Expand All @@ -114,7 +112,7 @@ def match(topic, template):
return False
continue

return len(template) == len(topic)
return len(template_parts) == len(topic_parts)

async def connection(self) -> None:
if self.client._username:
Expand Down Expand Up @@ -155,8 +153,8 @@ def __on_connect(self, client: MQTTClient, flags: int, rc: int, properties: Any)
Will perform subscription for given topics.
It cannot be done earlier, since subscription relies on connection.
"""
if self.mqtt_handlers.get_user_connect_handler:
self.mqtt_handlers.get_user_connect_handler(client, flags, rc, properties)
if self.mqtt_handlers.user_connect_handler is not None:
self.mqtt_handlers.user_connect_handler(client, flags, rc, properties)

for topic in self.subscriptions:
self._logger.debug("Subscribing for %s", topic)
Expand All @@ -170,10 +168,10 @@ async def __on_message(
This will invoke per topic handlers that are subscribed for
"""
gather = []
if self.mqtt_handlers.get_user_message_handler:
if self.mqtt_handlers.user_message_handler is not None:
self._logger.debug("Calling user_message_handler")
gather.append(
self.mqtt_handlers.get_user_message_handler(client, topic, payload, qos, properties)
self.mqtt_handlers.user_message_handler(client, topic, payload, qos, properties)
)

for topic_template in self.subscriptions:
Expand All @@ -191,7 +189,7 @@ def publish(
qos: int = 0,
retain: bool = False,
**kwargs,
):
) -> None:
"""
Defined to publish payload MQTT server
Expand Down Expand Up @@ -219,23 +217,23 @@ def unsubscribe(self, topic: str, **kwargs):

return self.client.unsubscribe(topic, **kwargs)

async def mqtt_startup(self):
async def mqtt_startup(self) -> None:
"""Initial connection for MQTT client, for lifespan startup."""
await self.connection()

async def mqtt_shutdown(self):
async def mqtt_shutdown(self) -> None:
"""Final disconnection for MQTT client, for lifespan shutdown."""
await self.client.disconnect()

def init_app(self, app: FastAPI) -> None: # pragma: no cover
"""Add startup and shutdown event handlers for app without lifespan."""

@app.on_event("startup")
async def startup():
async def startup() -> None:
await self.mqtt_startup()

@app.on_event("shutdown")
async def shutdown():
async def shutdown() -> None:
await self.mqtt_shutdown()

def subscribe(
Expand Down
44 changes: 25 additions & 19 deletions fastapi_mqtt/handlers.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,55 @@
from typing import Any, Callable, Optional
import warnings
from logging import Logger
from typing import Any, Awaitable, Callable, Optional

from gmqtt import Client as MQTTClient

try:
from uvicorn.config import logger

log_info = logger
except ImportError:
import logging

log_info = logging.getLogger()
# client: MQTTClient, topic: str, payload: bytes, qos: int, properties: Any
MQTTMessageHandler = Callable[[MQTTClient, str, bytes, int, Any], Awaitable[Any]]
# client: MQTTClient, flags: int, rc: int, properties: Any
MQTTConnectionHandler = Callable[[MQTTClient, int, int, Any], Any]


class MQTTHandlers:
def __init__(self, client: MQTTClient):
def __init__(self, client: MQTTClient, logger: Logger):
self._logger = logger
self.client = client
self.user_message_handler: Optional[Callable[..., Any]] = None
self.user_connect_handler: Optional[Callable[..., Any]] = None
self.user_message_handler: Optional[MQTTMessageHandler] = None
self.user_connect_handler: Optional[MQTTConnectionHandler] = None

def on_message(self, handler: Callable) -> Callable[..., Any]:
log_info.info("on_message handler accepted")
def on_message(self, handler: MQTTMessageHandler) -> MQTTMessageHandler:
self._logger.info("on_message handler accepted")
self.user_message_handler = handler
return handler

def on_subscribe(self, handler: Callable) -> Callable[..., Any]:
"""
Decorator method is used to obtain subscribed topics and properties.
"""
log_info.info("on_subscribe handler accepted")
self._logger.info("on_subscribe handler accepted")
self.client.on_subscribe = handler
return handler

def on_disconnect(self, handler: Callable) -> Callable[..., Any]:
self.client.on_disconnect = handler
return handler

def on_connect(self, handler: Callable) -> Callable[..., Any]:
log_info.info("on_connect handler accepted")
def on_connect(self, handler: MQTTConnectionHandler) -> MQTTConnectionHandler:
self._logger.info("on_connect handler accepted")
self.user_connect_handler = handler
return handler

# TODO: Remove these unused properties on v3.0
@property
def get_user_message_handler(self) -> Optional[Callable[..., Any]]:
def get_user_message_handler(self) -> Optional[MQTTMessageHandler]: # pragma: no cover
warnings.warn(
"Deprecated property. Access to .user_message_handler", DeprecationWarning, stacklevel=1
)
return self.user_message_handler

@property
def get_user_connect_handler(self) -> Optional[Callable[..., Any]]:
def get_user_connect_handler(self) -> Optional[MQTTConnectionHandler]: # pragma: no cover
warnings.warn(
"Deprecated property. Access to .user_connect_handler", DeprecationWarning, stacklevel=1
)
return self.user_connect_handler
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "fastapi-mqtt"
version = "2.1.0"
version = "2.1.1"
description = "fastapi-mqtt is extension for MQTT protocol"
authors = ["sabuhish <sabuhi.shukurov@gmail.com>"]
license = "MIT"
Expand Down
Empty file added tests/__init__.py
Empty file.

0 comments on commit d4da8ac

Please sign in to comment.