Skip to content

Commit

Permalink
Improve serverprotocol error handling (#1564)
Browse files Browse the repository at this point in the history
Also, improve type hints and variable naming

Related to @starkillerOG review in #1531
  • Loading branch information
rytilahti committed Oct 30, 2022
1 parent 975fe68 commit 1099dab
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 13 deletions.
5 changes: 3 additions & 2 deletions miio/push_server/server.py
Expand Up @@ -17,6 +17,7 @@
FAKE_DEVICE_MODEL = "chuangmi.plug.v3"

PushServerCallback = Callable[[str, str, str], None]
MethodDict = Dict[str, Union[Dict, Callable]]


def calculated_token_enc(token):
Expand Down Expand Up @@ -66,7 +67,7 @@ def __init__(self, device_ip=None):
self._listen_couroutine = None
self._registered_devices = {}

self._methods = {}
self._methods: MethodDict = {}

self._event_id = 1000000

Expand Down Expand Up @@ -325,6 +326,6 @@ def server_model(self):
return self._server_model

@property
def methods(self):
def methods(self) -> MethodDict:
"""Return a dict of implemented methods."""
return self._methods
31 changes: 26 additions & 5 deletions miio/push_server/serverprotocol.py
Expand Up @@ -11,6 +11,10 @@
"21310020ffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
)

ERR_INVALID = -1
ERR_UNSUPPORTED = -2
ERR_METHOD_EXEC_FAILED = -3


class ServerProtocol:
"""Handle responding to UDP packets."""
Expand Down Expand Up @@ -73,11 +77,11 @@ def send_response(self, host, port, msg_id, token, payload=None):
if payload is None:
payload = {}

result = {**payload, "id": msg_id}
msg = self._create_message(result, token, device_id=self.server.server_id)
data = {**payload, "id": msg_id}
msg = self._create_message(data, token, device_id=self.server.server_id)

self.transport.sendto(msg, (host, port))
_LOGGER.debug(">> %s:%s: %s", host, port, result)
_LOGGER.debug(">> %s:%s: %s", host, port, data)

def send_error(self, host, port, msg_id, token, code, message):
"""Send error message with given code and message to the client."""
Expand Down Expand Up @@ -121,19 +125,36 @@ def _handle_datagram_from_client(self, host: str, port: int, data):
msg_value,
)

if "method" not in msg_value:
return self.send_error(
host, port, msg_id, token, ERR_INVALID, "missing method"
)

methods = self.server.methods
if msg_value["method"] not in methods:
return self.send_error(host, port, msg_id, token, -1, "unsupported method")
return self.send_error(
host, port, msg_id, token, ERR_UNSUPPORTED, "unsupported method"
)

_LOGGER.debug("Got method call: %s", msg_value["method"])
method = methods[msg_value["method"]]
if callable(method):
try:
response = method(msg_value)
except Exception as ex:
return self.send_error(host, port, msg_id, token, -1, str(ex))
_LOGGER.exception(ex)
return self.send_error(
host,
port,
msg_id,
token,
ERR_METHOD_EXEC_FAILED,
f"Exception {type(ex)}: {ex}",
)
else:
response = method

_LOGGER.debug("Responding %s with %s", msg_id, response)
return self.send_response(host, port, msg_id, token, payload=response)

def datagram_received(self, data, addr):
Expand Down
46 changes: 40 additions & 6 deletions miio/push_server/test_serverprotocol.py
Expand Up @@ -2,7 +2,12 @@

from miio import Message

from .serverprotocol import ServerProtocol
from .serverprotocol import (
ERR_INVALID,
ERR_METHOD_EXEC_FAILED,
ERR_UNSUPPORTED,
ServerProtocol,
)

HOST = "127.0.0.1"
PORT = 1234
Expand Down Expand Up @@ -108,15 +113,44 @@ def test_datagram_with_known_method(protocol: ServerProtocol, mocker):
assert cargs["payload"] == response_payload


def test_datagram_with_unknown_method(protocol: ServerProtocol, mocker):
"""Test that regular client messages are handled properly."""
@pytest.mark.parametrize(
"method,err_code", [("unknown_method", ERR_UNSUPPORTED), (None, ERR_INVALID)]
)
def test_datagram_with_unknown_method(
method, err_code, protocol: ServerProtocol, mocker
):
"""Test that invalid payloads are erroring out correctly."""
protocol.send_error = mocker.Mock() # type: ignore[assignment]
protocol.server.methods = {}

msg = protocol._create_message({"id": 1, "method": "miIO.info"}, DUMMY_TOKEN, 1234)
data = {"id": 1}

if method is not None:
data["method"] = method

msg = protocol._create_message(data, DUMMY_TOKEN, 1234)
protocol._handle_datagram_from_client(HOST, PORT, msg)

protocol.send_error.assert_called() # type: ignore
cargs = protocol.send_error.call_args[0] # type: ignore
assert cargs[4] == err_code


def test_datagram_with_exception_raising(protocol: ServerProtocol, mocker):
"""Test that exception raising callbacks are ."""
protocol.send_error = mocker.Mock() # type: ignore[assignment]

def _raise(*args, **kwargs):
raise Exception("error message")

protocol.server.methods = {"raise": _raise}

data = {"id": 1, "method": "raise"}

msg = protocol._create_message(data, DUMMY_TOKEN, 1234)
protocol._handle_datagram_from_client(HOST, PORT, msg)

protocol.send_error.assert_called() # type: ignore
cargs = protocol.send_error.call_args[0] # type: ignore
assert cargs[4] == -1
assert cargs[5] == "unsupported method"
assert cargs[4] == ERR_METHOD_EXEC_FAILED
assert "error message" in cargs[5]

0 comments on commit 1099dab

Please sign in to comment.