Skip to content

Commit

Permalink
Cancel broadcast task when connection is lost
Browse files Browse the repository at this point in the history
  • Loading branch information
Selutario committed Apr 19, 2022
1 parent 5ce9cd8 commit 6dd550f
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 47 deletions.
6 changes: 4 additions & 2 deletions framework/wazuh/core/cluster/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(self, server, loop: asyncio.AbstractEventLoop, fernet_key: str,
self.name = None
self.ip = None
self.transport = None
self.handler_tasks = []
self.broadcast_queue = asyncio.Queue()

def to_dict(self) -> Dict:
Expand Down Expand Up @@ -148,7 +149,7 @@ def hello(self, data: bytes) -> Tuple[bytes, bytes]:
self.server.clients[self.name] = self
self.tag = f'{self.tag} {self.name}'
context_tag.set(self.tag)
self.loop.create_task(self.broadcast_reader())
self.handler_tasks.append(self.loop.create_task(self.broadcast_reader()))
return b'ok', f'Client {self.name} added'.encode()

def process_response(self, command: bytes, payload: bytes) -> bytes:
Expand Down Expand Up @@ -187,9 +188,10 @@ def connection_lost(self, exc):
else:
self.logger.error(f"Error during connection with '{self.name}': {exc}.\n"
f"{''.join(traceback.format_tb(exc.__traceback__))}", exc_info=False)

if self.name in self.server.clients:
del self.server.clients[self.name]
for task in self.handler_tasks:
task.cancel()
else:
if exc is not None:
self.logger.error(f"Error during handshake with incoming connection: {exc}. \n"
Expand Down
103 changes: 58 additions & 45 deletions framework/wazuh/core/cluster/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from asyncio import Transport
from contextvars import ContextVar
from logging import Logger
from unittest.mock import call, patch, ANY, MagicMock
from unittest.mock import call, patch, ANY, Mock

import pytest
from freezegun import freeze_time
Expand Down Expand Up @@ -45,6 +45,7 @@ def test_AbstractServerHandler_init():
assert isinstance(abstract_server_handler.last_keepalive, float)
assert abstract_server_handler.tag == "NoClient"
assert mock_contextvar.get() == "NoClient"
assert abstract_server_handler.handler_tasks == []
assert isinstance(abstract_server_handler.broadcast_queue, asyncio.Queue)


Expand Down Expand Up @@ -106,18 +107,21 @@ def test_AbstractServerHandler_echo_master():
assert abstract_server_handler.last_keepalive == 0.0


def test_AbstractServerHandler_hello():
@patch("asyncio.create_task")
def test_AbstractServerHandler_hello(create_task_mock):
"""Check that the information of the new client invoking this function is stored correctly."""

class ServerMock:
def __init__(self):
self.clients = {}
self.configuration = {"node_name": "elif_test"}

loop.create_task = Mock()
abstract_server_handler = AbstractServerHandler(server="Test", loop=loop, fernet_key=fernet_key,
cluster_items={"test": "server"})
abstract_server_handler.server = ServerMock()
abstract_server_handler.tag = "FixBehaviour"
abstract_server_handler.broadcast_reader = Mock()

with patch("wazuh.core.cluster.server.context_tag", ContextVar("tag", default="")) as mock_contextvar:
assert abstract_server_handler.hello(b"else_test") == (b"ok",
Expand All @@ -126,6 +130,7 @@ def __init__(self):
assert abstract_server_handler.server.clients["else_test"] == abstract_server_handler
assert abstract_server_handler.tag == f"FixBehaviour {abstract_server_handler.name}"
assert mock_contextvar.get() == abstract_server_handler.tag
loop.create_task.assert_called_once()

with pytest.raises(WazuhClusterError, match=".* 3029 .*"):
abstract_server_handler.hello(b"elif_test")
Expand Down Expand Up @@ -179,8 +184,11 @@ def __init__(self):
assert "unit" not in abstract_server_handler.server.clients.keys()

with patch.object(logger, "debug") as mock_debug_logger:
task_mock = Mock()
abstract_server_handler.handler_tasks = [task_mock]
abstract_server_handler.connection_lost(exc=None)
mock_debug_logger.assert_called_once_with("Disconnected unit.")
task_mock.cancel.assert_called_once()


@patch("asyncio.Queue")
Expand All @@ -203,8 +211,8 @@ async def async_mock_func():
def sync_mock_func():
return 'Result'

server_mock = MagicMock()
logger_mock = MagicMock()
server_mock = Mock()
logger_mock = Mock()
server_mock.broadcast_results = {'test1': {'worker1': {}}, 'test2': {'worker1': {}}, 'test3': {'worker1': {}}}
abstract_server_handler = AbstractServerHandler(server=server_mock, loop=loop, fernet_key=fernet_key,
cluster_items={"test": "server"}, logger=logger_mock)
Expand All @@ -223,8 +231,10 @@ def sync_mock_func():
" is not callable.")


@patch("asyncio.get_running_loop", return_value=loop)
def test_AbstractServer_init(loop_mock):
@patch("asyncio.get_running_loop", new=Mock())
@patch('wazuh.core.cluster.server.AbstractServer.check_clients_keepalive')
@patch('wazuh.core.cluster.server.AbstractServerHandler')
def test_AbstractServer_init(AbstractServerHandler_mock, keepalive_mock):
"""Check the correct initialization of the AbstractServer object."""
with patch("wazuh.core.cluster.server.context_tag", ContextVar("tag", default="")) as mock_contextvar:
abstract_server = AbstractServer(performance_test=1, concurrency_test=2, configuration={"test3": 3},
Expand All @@ -249,15 +259,17 @@ def test_AbstractServer_init(loop_mock):
assert abstract_server.broadcast_results == {}


@patch("asyncio.get_running_loop", return_value=loop)
def test_AbstractServer_broadcast(loop_mock):
@patch("asyncio.get_running_loop", new=Mock())
@patch('wazuh.core.cluster.server.AbstractServer.check_clients_keepalive')
@patch('wazuh.core.cluster.server.AbstractServerHandler')
def test_AbstractServer_broadcast(AbstractServerHandler_mock, asynckeepalive_mock):
"""Check that add_request is called with expected parameters."""
def test_func():
pass

logger_mock = MagicMock()
worker1_instance = MagicMock()
worker2_instance = MagicMock()
logger_mock = Mock()
worker1_instance = Mock()
worker2_instance = Mock()
abstract_server = AbstractServer(performance_test=1, concurrency_test=2, configuration={"test3": 3},
cluster_items={"test4": 4}, enable_ssl=True, logger=logger_mock)
abstract_server.clients = {"worker1": worker1_instance, "worker2": worker2_instance}
Expand All @@ -269,10 +281,10 @@ def test_func():
call('Added broadcast request to execute "test_func" in worker2.')]


@patch("asyncio.get_running_loop", return_value=loop)
def test_AbstractServer_broadcast_ko(loop_mock):
@patch("asyncio.get_running_loop", new=Mock())
def test_AbstractServer_broadcast_ko():
"""Verify that expected error log is printed when an exception is raised."""
logger_mock = MagicMock()
logger_mock = Mock()
abstract_server = AbstractServer(performance_test=1, concurrency_test=2, configuration={"test3": 3},
cluster_items={"test4": 4}, enable_ssl=True, logger=logger_mock)
abstract_server.clients = {"worker1": "test"}
Expand All @@ -281,16 +293,17 @@ def test_AbstractServer_broadcast_ko(loop_mock):
logger_mock.error.assert_called_once_with("Error while adding broadcast request in worker1: 'str' object "
"has no attribute 'add_request'", exc_info=False)


@patch("wazuh.core.cluster.server.uuid4", return_value="abc123")
@patch("asyncio.get_running_loop", return_value=loop)
def test_AbstractServer_broadcast_add(loop_mock, uuid_mock):
@patch("asyncio.get_running_loop", new=Mock())
def test_AbstractServer_broadcast_add(uuid_mock):
"""Check that add_request is called with expected parameters."""
def test_func():
pass

logger_mock = MagicMock()
worker1_instance = MagicMock()
worker2_instance = MagicMock()
logger_mock = Mock()
worker1_instance = Mock()
worker2_instance = Mock()
abstract_server = AbstractServer(performance_test=1, concurrency_test=2, configuration={"test3": 3},
cluster_items={"test4": 4}, enable_ssl=True, logger=logger_mock)
abstract_server.broadcast_results = {}
Expand All @@ -303,10 +316,10 @@ def test_func():


@patch("wazuh.core.cluster.server.uuid4", return_value="abc123")
@patch("asyncio.get_running_loop", return_value=loop)
def test_AbstractServer_broadcast_add_ko(loop_mock, uuid_mock):
@patch("asyncio.get_running_loop", new=Mock())
def test_AbstractServer_broadcast_add_ko(uuid_mock):
"""Check that expected error log is printed and that broadcast_results is deleted."""
logger_mock = MagicMock()
logger_mock = Mock()
abstract_server = AbstractServer(performance_test=1, concurrency_test=2, configuration={"test3": 3},
cluster_items={"test4": 4}, enable_ssl=True, logger=logger_mock)
abstract_server.broadcast_results = {}
Expand All @@ -326,10 +339,10 @@ def test_AbstractServer_broadcast_add_ko(loop_mock, uuid_mock):
({"abc123": {"worker1": "Response", "worker2": "Response", "worker3": "Response"}},
{"worker1": "Response", "worker2": "Response", "worker3": "Response"}),
])
@patch("asyncio.get_running_loop", return_value=loop)
def test_AbstractServer_broadcast_pop(loop_mock, broadcast_results, expected_response):
@patch("asyncio.get_running_loop", new=Mock())
def test_AbstractServer_broadcast_pop(broadcast_results, expected_response):
"""Check that expected response is returned for each case."""
logger_mock = MagicMock()
logger_mock = Mock()
abstract_server = AbstractServer(performance_test=1, concurrency_test=2, configuration={"test3": 3},
cluster_items={"test4": 4}, enable_ssl=True, logger=logger_mock)
abstract_server.broadcast_results = broadcast_results
Expand All @@ -338,8 +351,8 @@ def test_AbstractServer_broadcast_pop(loop_mock, broadcast_results, expected_res
assert abstract_server.broadcast_pop("abc123") == expected_response


@patch("asyncio.get_running_loop", return_value=loop)
def test_AbstractServer_to_dict(loop_mock):
@patch("asyncio.get_running_loop", new=Mock())
def test_AbstractServer_to_dict():
"""Check the correct transformation of an AbstractServer to a dict."""
configuration = {"test_to_dict": 0,
"nodes": [0, 1],
Expand All @@ -349,8 +362,8 @@ def test_AbstractServer_to_dict(loop_mock):
assert abstract_server.to_dict() == {"info": {"ip": configuration["nodes"][0], "name": configuration['node_name']}}


@patch("asyncio.get_running_loop", return_value=loop)
def test_AbstractServer_setup_task_logger(loop_mock):
@patch("asyncio.get_running_loop", new=Mock())
def test_AbstractServer_setup_task_logger():
"""Check that a logger is created with a specific tag."""
logger = Logger("setup_task_logger")
abstract_server = AbstractServer(performance_test=1, concurrency_test=2, configuration={"test3": 3},
Expand All @@ -363,8 +376,8 @@ def test_AbstractServer_setup_task_logger(loop_mock):


@patch("wazuh.core.cluster.server.utils.process_array")
@patch("asyncio.get_running_loop", return_value=loop)
def test_AbstractServer_get_connected_nodes(loop_mock, mock_process_array):
@patch("asyncio.get_running_loop", new=Mock())
def test_AbstractServer_get_connected_nodes(mock_process_array):
"""Check that all the necessary data is sent to the utils.process_array
function to return all the information of the connected nodes."""
abstract_server = AbstractServer(performance_test=1, concurrency_test=2, configuration={"test3": 3},
Expand All @@ -381,8 +394,8 @@ def test_AbstractServer_get_connected_nodes(loop_mock, mock_process_array):


@patch("wazuh.core.cluster.server.utils.process_array")
@patch("asyncio.get_running_loop", return_value=loop)
def test_AbstractServer_get_connected_nodes_ko(loop_mock, mock_process_array):
@patch("asyncio.get_running_loop", new=Mock())
def test_AbstractServer_get_connected_nodes_ko(mock_process_array):
"""Check all exceptions that can be returned by the get_connected_nodes function."""
abstract_server = AbstractServer(performance_test=1, concurrency_test=2, configuration={"test3": 3},
cluster_items={"test4": 4}, enable_ssl=True)
Expand All @@ -408,8 +421,8 @@ def test_AbstractServer_get_connected_nodes_ko(loop_mock, mock_process_array):

@pytest.mark.asyncio
@patch("asyncio.sleep", side_effect=IndexError)
@patch("asyncio.get_running_loop", return_value=loop)
async def test_AbstractServer_check_clients_keepalive(loop_mock, sleep_mock):
@patch("asyncio.get_running_loop", new=Mock())
async def test_AbstractServer_check_clients_keepalive(sleep_mock):
"""Check that the function check_clients_keepalive checks the date of the
last last_keepalive of the clients to verify if they are connected or not."""

Expand Down Expand Up @@ -451,8 +464,8 @@ def __init__(self):

@pytest.mark.asyncio
@patch("asyncio.sleep", side_effect=IndexError)
@patch("asyncio.get_running_loop", return_value=loop)
async def test_AbstractServer_echo(loop_mock, sleep_mock):
@patch("asyncio.get_running_loop", new=Mock())
async def test_AbstractServer_echo(sleep_mock):
"""Check that the echo function sends a message to all clients and that the information is written to the log."""

class ClientMock:
Expand All @@ -475,9 +488,9 @@ async def send_request(self, command, data):

@freeze_time("2022-01-01")
@patch("asyncio.sleep", side_effect=IndexError)
@patch("asyncio.get_running_loop", return_value=loop)
@patch("asyncio.get_running_loop", new=Mock())
@patch('wazuh.core.cluster.server.perf_counter', return_value=0)
async def test_AbstractServer_performance_test(perf_counter_mock, loop_mock, sleep_mock):
async def test_AbstractServer_performance_test(perf_counter_mock, sleep_mock):
"""Check that the function performance_test sends a big message to all clients
and then get the time it took to send them."""

Expand All @@ -500,9 +513,9 @@ async def send_request(self, command, data):

@freeze_time("2022-01-01")
@patch("asyncio.sleep", side_effect=IndexError)
@patch("asyncio.get_running_loop", return_value=loop)
@patch("asyncio.get_running_loop", new=Mock())
@patch('wazuh.core.cluster.server.perf_counter', return_value=0)
async def test_AbstractServer_concurrency_test(perf_counter_mock, loop_mock, sleep_mock):
async def test_AbstractServer_concurrency_test(perf_counter_mock, sleep_mock):
"""Check that the function concurrency_test sends messages to all clients
and then get the time it took to send them."""

Expand All @@ -527,8 +540,8 @@ async def send_request(self, command, data):
@patch("os.path.join", return_value="testing_path")
@patch("uvloop.EventLoopPolicy")
@patch("asyncio.set_event_loop_policy")
@patch("asyncio.sleep", side_effect=IndexError)
async def test_AbstractServer_start(sleep_mock, set_event_loop_policy_mock, eventlooppolicy_mock, mock_path_join):
@patch('wazuh.core.cluster.server.AbstractServer.check_clients_keepalive')
async def test_AbstractServer_start(keepalive_mock, set_event_loop_policy_mock, eventlooppolicy_mock, mock_path_join):
"""Check that the start function starts infinite asynchronous tasks according
to the parameters with which the AbstractServer object has been created."""

Expand Down Expand Up @@ -588,8 +601,8 @@ def load_cert_chain(self):
@patch("wazuh.core.cluster.server.AbstractServerHandler")
@patch("uvloop.EventLoopPolicy")
@patch("asyncio.set_event_loop_policy")
@patch("asyncio.sleep", side_effect=IndexError)
async def test_AbstractServer_start_ko(sleep_mock, set_event_loop_policy_mock, eventlooppolicy_mock,
@patch('wazuh.core.cluster.server.AbstractServer.check_clients_keepalive')
async def test_AbstractServer_start_ko(keepalive_mock, set_event_loop_policy_mock, eventlooppolicy_mock,
mock_AbstractServerHandler):
"""Check for exceptions that may arise inside the start function."""

Expand Down

0 comments on commit 6dd550f

Please sign in to comment.