Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add broadcasting system for the cluster #13124

Merged
merged 4 commits into from
May 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
159 changes: 153 additions & 6 deletions framework/wazuh/core/cluster/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Created by Wazuh, Inc. <info@wazuh.com>.
# This program is free software; you can redistribute it and/or modify it under the terms of GPLv2

import asyncio
import contextlib
import functools
import inspect
import itertools
import logging
import os
Expand Down Expand Up @@ -51,6 +55,8 @@ def __init__(self, server, loop: asyncio.AbstractEventLoop, fernet_key: str,
self.name = None
self.ip = None
self.transport = None
self.handler_tasks = []
self.broadcast_queue = asyncio.Queue()

def to_dict(self) -> Dict:
"""Get basic info from AbstractServerHandler instance.
Expand Down Expand Up @@ -142,6 +148,7 @@ def hello(self, data: bytes) -> Tuple[bytes, bytes]:
self.server.clients[self.name] = self
self.tag = f'{self.tag} {self.name}'
context_tag.set(self.tag)
self.handler_tasks.append(self.loop.create_task(self.broadcast_reader()))
return b'ok', f'Client {self.name} added'.encode()

def process_response(self, command: bytes, payload: bytes) -> bytes:
Expand Down Expand Up @@ -180,22 +187,66 @@ def connection_lost(self, exc):
else:
self.logger.error(f"Error during connection with '{self.name}': {exc}.\n"
f"{''.join(traceback.format_tb(exc.__traceback__))}", exc_info=False)

if self.name in self.server.clients:
del self.server.clients[self.name]
for task in self.handler_tasks:
task.cancel()
elif exc is not None:
self.logger.error(f"Error during handshake with incoming connection: {exc}. \n"
f"{''.join(traceback.format_tb(exc.__traceback__))}", exc_info=False)
else:
if exc is not None:
self.logger.error(f"Error during handshake with incoming connection: {exc}. \n"
f"{''.join(traceback.format_tb(exc.__traceback__))}", exc_info=False)
else:
self.logger.error("Error during handshake with incoming connection.", exc_info=False)
self.logger.error("Error during handshake with incoming connection.", exc_info=False)

def add_request(self, broadcast_id, f, *args, **kwargs):
    """Queue a function call to be executed later by this server handler.

    The callable is bound to this handler (which is passed as its first
    positional argument) together with `args` and `kwargs`, and the
    resulting item is placed in `broadcast_queue` without waiting.

    Parameters
    ----------
    broadcast_id : str or None
        Request identifier stored next to the callable in the queued item.
    f : callable
        Function reference to be run. It receives this handler as its first argument.
    *args
        Positional arguments forwarded to `f` after the handler.
    **kwargs
        Keyword arguments forwarded to `f`.
    """
    bound_call = functools.partial(f, self, *args, **kwargs)
    request = dict(broadcast_id=broadcast_id, func=bound_call)
    self.broadcast_queue.put_nowait(request)

async def broadcast_reader(self):
    """Consume and execute the requests placed in this handler's broadcast_queue.

    Runs forever, waiting for items shaped like
    {'broadcast_id': Union[str, None], 'func': Callable}.

    Each 'func' is called without arguments (awaited when it is a coroutine
    function). Its return value, or the exception it raised, is then stored in
    `self.server.broadcast_results[broadcast_id][self.name]` when an entry for
    that 'broadcast_id' exists; otherwise the result is discarded.
    """
    while True:
        request = await self.broadcast_queue.get()

        try:
            func = request['func']
            # Coroutine functions have to be awaited; plain callables are invoked directly.
            result = await func() if inspect.iscoroutinefunction(func) else func()
        except Exception as e:
            self.logger.error(f"Error while broadcasting function. ID: {request['broadcast_id']}. Error: {e}.")
            # The exception itself becomes the result reported for this handler.
            result = e

        try:
            self.server.broadcast_results[request['broadcast_id']][self.name] = result
        except KeyError:
            # No entry for this identifier in broadcast_results: nothing to record.
            pass


class AbstractServer:
"""
Define an asynchronous server. Handle connections from all clients.
"""

NO_RESULT = 'no_result'

def __init__(self, performance_test: int, concurrency_test: int, configuration: Dict, cluster_items: Dict,
enable_ssl: bool, logger: logging.Logger = None, tag: str = "Abstract Server"):
"""Class constructor.
Expand Down Expand Up @@ -230,6 +281,102 @@ def __init__(self, performance_test: int, concurrency_test: int, configuration:
self.tasks = [self.check_clients_keepalive]
self.handler_class = AbstractServerHandler
self.loop = asyncio.get_running_loop()
self.broadcast_results = {}

def broadcast(self, f, *args, **kwargs):
    """Add a function to the broadcast_queue of each server handler.

    Parameters
    ----------
    f : Callable
        Function to be run in each server handler.
    *args
        Arguments to be passed to function `f`.
    **kwargs
        Keyword arguments to be passed to function `f`.

    Notes
    -----
    This method does not allow determining whether the function has been
    executed in all server handlers or the result for each one. For those
    features, see `broadcast_add` and `broadcast_pop`.

    Failures while adding the request to one handler are logged and do not
    prevent the request from being added to the remaining handlers.
    """
    # Not every callable has a __name__ (e.g. functools.partial objects). Resolve a
    # printable name up front so that logging cannot be mistaken for an enqueue failure.
    f_name = getattr(f, '__name__', repr(f))

    for name, client in self.clients.items():
        try:
            # No identifier: the result of this request is not tracked.
            client.add_request(None, f, *args, **kwargs)
            self.logger.debug2(f'Added broadcast request to execute "{f_name}" in {name}.')
        except Exception as e:
            self.logger.error(f'Error while adding broadcast request in {name}: {e}', exc_info=False)

def broadcast_add(self, f, *args, **kwargs):
    """Add a function to the broadcast_queue of each server handler and obtain an identifier.

    Parameters
    ----------
    f : Callable
        Function to be run in each server handler.
    *args
        Arguments to be passed to function `f`.
    **kwargs
        Keyword arguments to be passed to function `f`.

    Returns
    -------
    broadcast_id : str or None
        Identifier to check the status of the broadcast request. None if there
        are no clients or the request could not be added to any of them.

    Notes
    -----
    It is important to run `broadcast_pop` to remove the result entry from the
    broadcast_results dict after using this method. Otherwise, it will be kept
    until restarting the server. See `broadcast` method if broadcast results
    are not needed.
    """
    if not self.clients:
        return None

    # Not every callable has a __name__ (e.g. functools.partial objects). Resolve a
    # printable name up front: an AttributeError raised while logging would otherwise
    # discard the placeholder of a request that was already queued.
    f_name = getattr(f, '__name__', repr(f))

    broadcast_id = str(uuid4())
    self.broadcast_results[broadcast_id] = {}

    for name, client in self.clients.items():
        try:
            # Register the placeholder before queuing so the entry exists for this client.
            self.broadcast_results[broadcast_id][name] = AbstractServer.NO_RESULT
            client.add_request(broadcast_id, f, *args, **kwargs)
            self.logger.debug2(f'Added broadcast request to execute "{f_name}" in {name}.')
        except Exception as e:
            self.broadcast_results[broadcast_id].pop(name, None)
            self.logger.error(f'Error while adding broadcast request in {name}: {e}', exc_info=False)

    if not self.broadcast_results[broadcast_id]:
        # The request could not be added to any client: drop the empty entry.
        self.broadcast_results.pop(broadcast_id, None)
        return None

    return broadcast_id

def broadcast_pop(self, broadcast_id):
    """Retrieve and remove the results of a broadcast request, once they are ready.

    A request is considered pending while any server handler that is still in
    `self.clients` holds the `NO_RESULT` placeholder for it.

    Parameters
    ----------
    broadcast_id : str
        Identifier to check the status of the broadcast request.

    Returns
    -------
    Dict, bool
        False if the `broadcast_id` exists but the request is still pending in
        some connected server handler. True if the `broadcast_id` is unknown.
        Otherwise, a dict with the result stored for each server handler; in that
        case the entry is removed from the broadcast_results dict.
    """
    stored = self.broadcast_results.get(broadcast_id, {})
    pending = any(
        name in self.clients and result == AbstractServer.NO_RESULT
        for name, result in stored.items()
    )
    if pending:
        return False

    # Unknown identifiers fall back to True; known ones are removed and returned.
    return self.broadcast_results.pop(broadcast_id, True)

def to_dict(self) -> Dict:
"""Get basic info from AbstractServer instance.
Expand Down