diff --git a/resultsdb/__init__.py b/resultsdb/__init__.py index a659949..e115535 100644 --- a/resultsdb/__init__.py +++ b/resultsdb/__init__.py @@ -38,6 +38,7 @@ from resultsdb.controllers.main import main from resultsdb.controllers.api_v2 import api as api_v2 from resultsdb.controllers.api_v3 import api as api_v3, create_endpoints +from resultsdb.messaging import load_messaging_plugin from resultsdb.models import db from . import config @@ -119,6 +120,8 @@ def create_app(config_obj=None): else: app.logger.info("OpenIDConnect authentication is disabled") + setup_messaging(app) + app.logger.debug("Finished ResultsDB initialization") return app @@ -173,6 +176,21 @@ def setup_logging(app): app.logger.addHandler(file_handler) +def setup_messaging(app): + app.messaging_plugin = None + if not app.config["MESSAGE_BUS_PUBLISH"]: + app.logger.info("No messaging plugin") + return + + plugin_name = app.config["MESSAGE_BUS_PLUGIN"] + app.logger.info("Using messaging plugin %s", plugin_name) + plugin_args = app.config["MESSAGE_BUS_KWARGS"] + app.messaging_plugin = load_messaging_plugin( + name=plugin_name, + kwargs=plugin_args, + ) + + def register_handlers(app): # TODO: find out why error handler works for 404 but not for 400 @app.errorhandler(400) diff --git a/resultsdb/controllers/common.py b/resultsdb/controllers/common.py index 9b38eec..5c5b495 100644 --- a/resultsdb/controllers/common.py +++ b/resultsdb/controllers/common.py @@ -4,7 +4,6 @@ from resultsdb.models import db from resultsdb.messaging import ( - load_messaging_plugin, create_message, publish_taskotron_message, ) @@ -28,13 +27,10 @@ def commit_result(result): result.outcome, ) - if app.config["MESSAGE_BUS_PUBLISH"]: + if app.messaging_plugin: app.logger.debug("Preparing to publish message for result id %d", result.id) - plugin = load_messaging_plugin( - name=app.config["MESSAGE_BUS_PLUGIN"], - kwargs=app.config["MESSAGE_BUS_KWARGS"], - ) - plugin.publish(create_message(result)) + message = create_message(result) + app.messaging_plugin.publish(message) if app.config["MESSAGE_BUS_PUBLISH_TASKOTRON"]: app.logger.debug("Preparing to publish Taskotron message for result id %d", result.id) diff --git a/resultsdb/messaging.py b/resultsdb/messaging.py index a7f3dce..3e0fd26 100644 --- a/resultsdb/messaging.py +++ b/resultsdb/messaging.py @@ -205,7 +205,7 @@ def __init__(self, **kwargs): required = ["connection", "destination"] for attr in required: if getattr(self, attr, None) is None: - raise ValueError("%r required for %r." % (attr, self)) + raise ValueError(f"Missing {attr!r} option for STOMP messaging plugin") def publish(self, msg): msg = json.dumps(msg) diff --git a/testing/test_app.py b/testing/test_app.py new file mode 100644 index 0000000..ef4014c --- /dev/null +++ b/testing/test_app.py @@ -0,0 +1,49 @@ +from unittest.mock import Mock + +from pytest import raises + +from resultsdb import setup_messaging + + +def test_app_messaging(app): + assert app.messaging_plugin is not None + assert type(app.messaging_plugin).__name__ == "DummyPlugin" + + +def test_app_messaging_none(): + app = Mock() + app.config = {"MESSAGE_BUS_PUBLISH": False} + setup_messaging(app) + app.logger.info.assert_called_once_with("No messaging plugin") + + +def test_app_messaging_stomp(): + app = Mock() + app.config = { + "MESSAGE_BUS_PUBLISH": True, + "MESSAGE_BUS_PLUGIN": "stomp", + "MESSAGE_BUS_KWARGS": { + "destination": "results.new", + "connection": { + "host_and_ports": [("localhost", 1234)], + }, + }, + } + setup_messaging(app) + app.logger.info.assert_called_once_with("Using messaging plugin %s", "stomp") + + +def test_app_messaging_stomp_bad(): + app = Mock() + app.config = { + "MESSAGE_BUS_PUBLISH": True, + "MESSAGE_BUS_PLUGIN": "stomp", + "MESSAGE_BUS_KWARGS": { + "connection": { + "host_and_ports": [("localhost", 1234)], + }, + }, + } + expected_error = "Missing 'destination' option for STOMP messaging plugin" + with raises(ValueError, match=expected_error): + setup_messaging(app) diff --git a/testing/test_general.py b/testing/test_general.py index 91693af..5a2c6cc 100644 --- a/testing/test_general.py +++ b/testing/test_general.py @@ -1,10 +1,33 @@ import datetime import ssl +from unittest.mock import patch + +import stomp +from pytest import fixture, raises import resultsdb.controllers.api_v2 as apiv2 import resultsdb.messaging as messaging from resultsdb.parsers.api_v2 import parse_since +MESSAGE_BUS_KWARGS = { + "destination": "results.new", + "connection": { + "host_and_ports": [("localhost", 1234)], + "use_ssl": True, + "ssl_version": ssl.PROTOCOL_TLSv1_2, + "ssl_key_file": "/etc/secret/umb-client.key", + "ssl_cert_file": "/etc/secret/umb-client.crt", + "ssl_ca_certs": "/etc/secret/ca.pem", + }, +} + + +@fixture +def mock_stomp(): + with patch("resultsdb.messaging.stomp.connect.StompConnection11") as mock: + mock().is_connected.return_value = False + yield mock + class MyRequest(object): def __init__(self, url): @@ -136,7 +159,7 @@ def test_load_plugin(self): " resultsdb.messaging:FedmsgPlugin" ) - def test_load_stomp(self): + def test_stomp_load(self): message_bus_kwargs = { "destination": "results.new", "connection": { @@ -147,22 +170,20 @@ def test_load_stomp(self): assert isinstance(plugin, messaging.StompPlugin) assert plugin.destination == "results.new" - def test_stomp_ssl(self): + def test_stomp_missing_destination(self): message_bus_kwargs = { - "destination": "results.new", "connection": { "host_and_ports": [("localhost", 1234)], - "use_ssl": True, - "ssl_version": ssl.PROTOCOL_TLSv1_2, - "ssl_key_file": "/etc/secret/umb-client.key", - "ssl_cert_file": "/etc/secret/umb-client.crt", - "ssl_ca_certs": "/etc/secret/ca.pem", }, } + expected_error = "Missing 'destination' option for STOMP messaging plugin" + with raises(ValueError, match=expected_error): + messaging.load_messaging_plugin("stomp", message_bus_kwargs) + def test_stomp_ssl(self): # Run twice to ensure that the original configuration is not modified. for _ in (1, 2): - plugin = messaging.load_messaging_plugin("stomp", message_bus_kwargs) + plugin = messaging.load_messaging_plugin("stomp", MESSAGE_BUS_KWARGS) assert plugin.connection == { "host_and_ports": [("localhost", 1234)], } @@ -175,6 +196,36 @@ def test_stomp_ssl(self): "ssl_version": ssl.PROTOCOL_TLSv1_2, } + def test_stomp_publish(self, mock_stomp): + plugin = messaging.load_messaging_plugin("stomp", MESSAGE_BUS_KWARGS) + assert mock_stomp().is_connected() is False + plugin.publish({}) + mock_stomp().connect.assert_called_once() + mock_stomp().send.assert_called_once() + mock_stomp().disconnect.assert_called_once() + + def test_stomp_publish_connect_failed(self, mock_stomp): + plugin = messaging.load_messaging_plugin("stomp", MESSAGE_BUS_KWARGS) + + mock_stomp().connect.side_effect = stomp.exception.ConnectFailedException() + with raises(stomp.exception.ConnectFailedException): + plugin.publish({}) + + mock_stomp().connect.assert_called_once() + mock_stomp().send.assert_not_called() + mock_stomp().disconnect.assert_not_called() + + def test_stomp_publish_send_failed(self, mock_stomp): + plugin = messaging.load_messaging_plugin("stomp", MESSAGE_BUS_KWARGS) + + mock_stomp().send.side_effect = stomp.exception.StompException() + with raises(stomp.exception.StompException): + plugin.publish({}) + + mock_stomp().connect.assert_called_once() + mock_stomp().send.assert_called_once() + mock_stomp().disconnect.assert_called_once() + class TestGetResultsParseArgs: # TODO: write something!