Skip to content

Commit

Permalink
Fix using ext functions in MQTT publish (#851)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelboulton committed Feb 16, 2023
1 parent 7e62469 commit 2213e10
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 29 deletions.
25 changes: 25 additions & 0 deletions example/mqtt/test_mqtt.tavern.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -740,3 +740,28 @@ stages:
payload: "there"
timeout: 5
qos: 1

---

test_name: Update an MQTT publish from an ext function

includes:
- !include common.yaml

paho-mqtt: *mqtt_spec

stages:
- *setup_device_for_test

- name: step 1 - ping/pong
mqtt_publish:
topic: /device/{random_device_id}/echo
json:
$ext:
function: testing_utils:return_hello
mqtt_response:
topic: /device/{random_device_id}/echo/response
timeout: 3
qos: 1
json:
hello: there
2 changes: 1 addition & 1 deletion example/mqtt/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ def message_says_hello(msg):
assert msg.payload.get("message") == "hello world"


def return_hello(_):
def return_hello(_=None):
return {"hello": "there"}
2 changes: 1 addition & 1 deletion tavern/_core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class InvalidFormattedJsonError(TavernException):
"""Tried to use the magic json format tag in an invalid way"""


class InvalidExtBlockException(TavernException):
class MisplacedExtBlockException(TavernException):
"""Tried to use the '$ext' block in a place it is no longer valid to use it"""

def __init__(self, block) -> None:
Expand Down
24 changes: 19 additions & 5 deletions tavern/_core/extfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ def inner():


def _get_ext_values(ext: Mapping):
if not isinstance(ext, Mapping):
raise exceptions.InvalidExtFunctionError(
"ext block should be a dict, but it was a {}".format(type(ext))
)

args = ext.get("extra_args") or ()
kwargs = ext.get("extra_kwargs") or {}
try:
Expand All @@ -145,14 +150,23 @@ def update_from_ext(request_args: dict, keys_to_check: List[str]) -> None:
"""

new_args = {}
logger = _getlogger()

for key in keys_to_check:
try:
func = get_wrapped_create_function(request_args[key].pop("$ext"))
except (KeyError, TypeError, AttributeError):
pass
else:
new_args[key] = func()
block = request_args[key]
except KeyError:
logger.debug("No %s block", key)
continue

try:
pop = block.pop("$ext")
except (KeyError, AttributeError, TypeError):
logger.debug("No ext functions in %s block", key)
continue

func = get_wrapped_create_function(pop)
new_args[key] = func()

merged_args = deep_dict_merge(request_args, new_args)

Expand Down
20 changes: 9 additions & 11 deletions tavern/_plugins/mqtt/request.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
import json
import logging
from typing import Mapping
from typing import Dict

from box.box import Box

Expand All @@ -16,21 +16,19 @@
logger = logging.getLogger(__name__)


def get_publish_args(rspec: Mapping, test_block_config: TestConfig) -> dict:
"""Format mqtt request args
Todo:
Anything else to do here?
"""
def get_publish_args(rspec: Dict, test_block_config: TestConfig) -> dict:
"""Format mqtt request args and update using ext functions"""

fspec = format_keys(rspec, test_block_config.variables)

if "json" in rspec:
if "payload" in rspec:
if "json" in fspec:
if "payload" in fspec:
raise exceptions.BadSchemaError(
"Can only specify one of 'payload' or 'json' in MQTT request"
)

update_from_ext(fspec, ["json"])

fspec["payload"] = json.dumps(fspec.pop("json"))

return fspec
Expand All @@ -43,15 +41,15 @@ class MQTTRequest(BaseRequest):
"""

def __init__(
self, client: MQTTClient, rspec: Mapping, test_block_config: TestConfig
self, client: MQTTClient, rspec: Dict, test_block_config: TestConfig
) -> None:
expected = {"topic", "payload", "json", "qos", "retain"}

check_expected_keys(expected, rspec)

publish_args = get_publish_args(rspec, test_block_config)
update_from_ext(publish_args, ["json"])

self._publish_args = publish_args
self._prepared = functools.partial(client.publish, **publish_args)

# Need to do this here because get_publish_args will modify the original
Expand Down
2 changes: 1 addition & 1 deletion tavern/_plugins/mqtt/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def _get_payload_vals(expected) -> Tuple[Optional[Union[str, dict]], bool]:
json_payload = True

if payload.pop("$ext", None):
raise exceptions.InvalidExtBlockException(
raise exceptions.MisplacedExtBlockException(
"json",
)
elif "payload" in expected:
Expand Down
2 changes: 1 addition & 1 deletion tavern/_plugins/rest/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def _validate_block(self, blockname: str, block: Mapping) -> None:

if isinstance(expected_block, dict):
if expected_block.pop("$ext", None):
raise exceptions.InvalidExtBlockException(
raise exceptions.MisplacedExtBlockException(
blockname,
)

Expand Down
2 changes: 1 addition & 1 deletion tavern/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def check_deprecated_validate(name):
if isinstance(block, dict):
check_ext_functions(block.get("$ext", None))
if nfuncs != len(self.validate_functions):
raise exceptions.InvalidExtBlockException(
raise exceptions.MisplacedExtBlockException(
name,
)

Expand Down
48 changes: 40 additions & 8 deletions tests/unit/test_mqtt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Dict
from unittest.mock import MagicMock, Mock, patch

import paho.mqtt.client as paho
Expand All @@ -18,18 +19,19 @@ def test_host_required():
MQTTClient(**args)


class TestClient:
@pytest.fixture(name="fake_client")
def fix_fake_client(self):
args = {"connect": {"host": "localhost"}}
@pytest.fixture(name="fake_client")
def fix_fake_client():
args = {"connect": {"host": "localhost"}}

mqtt_client = MQTTClient(**args)

mqtt_client = MQTTClient(**args)
mqtt_client._subscribed[2] = _Subscription("abc")
mqtt_client._subscription_mappings["abc"] = 2

mqtt_client._subscribed[2] = _Subscription("abc")
mqtt_client._subscription_mappings["abc"] = 2
return mqtt_client

return mqtt_client

class TestClient:
def test_no_queue(self, fake_client):
"""Trying to fetch from a nonexistent queue raised exception"""

Expand Down Expand Up @@ -192,3 +194,33 @@ def subscribe_success(topic, *args, **kwargs):
MQTTClient._on_subscribe(mock_client, "abc", {}, 123, 0)

assert mock_client._subscribed == {}


class TestExtFunctions:
@pytest.fixture()
def basic_mqtt_request_args(self) -> Dict:
return {
"topic": "/a/b/c",
}

def test_basic(self, fake_client, basic_mqtt_request_args, includes):
MQTTRequest(fake_client, basic_mqtt_request_args, includes)

def test_ext_function_bad(self, fake_client, basic_mqtt_request_args, includes):
basic_mqtt_request_args["json"] = {"$ext": "kk"}

with pytest.raises(exceptions.InvalidExtFunctionError):
MQTTRequest(fake_client, basic_mqtt_request_args, includes)

def test_ext_function_good(self, fake_client, basic_mqtt_request_args, includes):
basic_mqtt_request_args["json"] = {
"$ext": {
"function": "operator:add",
"extra_args": (1, 2),
}
}

m = MQTTRequest(fake_client, basic_mqtt_request_args, includes)

assert "payload" in m._publish_args
assert m._publish_args["payload"] == "3"

0 comments on commit 2213e10

Please sign in to comment.