diff --git a/Pipfile.lock b/Pipfile.lock index 308377c7..9632f48e 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -40,6 +40,14 @@ } }, "develop": { + "appnope": { + "hashes": [ + "sha256:5b26757dc6f79a3b7dc9fab95359328d5747fcb2409d331ea66d0272b90ab2a0", + "sha256:8b995ffe925347a2138d7ac0fe77155e4311a0ea6d6da4f5128fe4b3cbe5ed71" + ], + "markers": "sys_platform == 'darwin'", + "version": "==0.1.0" + }, "astroid": { "hashes": [ "sha256:71ea07f44df9568a75d0f354c49143a4575d90645e9fead6dfb52c26a85ed13a", @@ -47,13 +55,6 @@ ], "version": "==2.3.3" }, - "atomicwrites": { - "hashes": [ - "sha256:03472c30eb2c5d1ba9227e4c2ca66ab8287fbfbbda3888aa93dc2e28fc6811b4", - "sha256:75a9445bac02d8d058d5e1fe689654ba5a6556a1dfd8ce6ec55a0ed79866cfa6" - ], - "version": "==1.3.0" - }, "attrs": { "hashes": [ "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c", @@ -198,10 +199,10 @@ }, "pluggy": { "hashes": [ - "sha256:0db4b7601aae1d35b4a033282da476845aa19185c1e6964b25cf324b5e4ec3e6", - "sha256:fa5fa1622fa6dd5c030e9cad086fa19ef6a0cf6d7a2d12318e10cb49d6d68f34" + "sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0", + "sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d" ], - "version": "==0.13.0" + "version": "==0.13.1" }, "prompt-toolkit": { "hashes": [ @@ -248,11 +249,11 @@ }, "pylint": { "hashes": [ - "sha256:7b76045426c650d2b0f02fc47c14d7934d17898779da95288a74c2a7ec440702", - "sha256:856476331f3e26598017290fd65bebe81c960e806776f324093a46b76fb2d1c0" + "sha256:3db5468ad013380e987410a8d6956226963aed94ecb5f9d3a28acca6d9ac36cd", + "sha256:886e6afc935ea2590b462664b161ca9a5e40168ea99e5300935f6591ad467df4" ], "index": "pypi", - "version": "==2.4.3" + "version": "==2.4.4" }, "pyparsing": { "hashes": [ @@ -263,11 +264,11 @@ }, "pytest": { "hashes": [ - "sha256:27abc3fef618a01bebb1f0d6d303d2816a99aa87a5968ebc32fe971be91eb1e6", - "sha256:58cee9e09242937e136dbb3dab466116ba20d6b7828c7620f23947f37eb4dae4" + "sha256:1897d74f60a5d8be02e06d708b41bf2445da2ee777066bd68edf14474fc201eb", + "sha256:f6a567e20c04259d41adce9a360bd8991e6aa29dd9695c5e6bd25a9779272673" ], "index": "pypi", - "version": "==5.2.2" + "version": "==5.3.0" }, "pytest-asyncio": { "hashes": [ diff --git a/receptor/buffers/file.py b/receptor/buffers/file.py index e300a9d9..0e22a6f7 100644 --- a/receptor/buffers/file.py +++ b/receptor/buffers/file.py @@ -27,13 +27,13 @@ def __init__(self, dir_, key, loop): pass for item in self._read_manifest(): self.q.put_nowait(item) - + async def put(self, data): ident = str(uuid.uuid4()) await self._loop.run_in_executor(pool, self._write_file, data, ident) await self.q.put(ident) await self._save_manifest() - + async def get(self, handle_only=False, delete=True): while True: msg = await self.q.get() @@ -42,21 +42,25 @@ async def get(self, handle_only=False, delete=True): return await self._get_file(msg, handle_only=handle_only, delete=delete) except FileNotFoundError: pass - + async def _save_manifest(self): async with self._manifest_lock: await self._loop.run_in_executor(pool, self._write_manifest) - + def _write_manifest(self): with open(self._manifest_path, "w") as fp: json.dump(list(self.q._queue), fp) - + def _read_manifest(self): try: with open(self._manifest_path, "r") as fp: return json.load(fp) except FileNotFoundError: return [] + except json.decoder.JSONDecodeError: + with open(self._manifest_path, "r") as fp: + logger.error("failed to decode manifest: %s", fp.read()) + raise def _path_for_ident(self, ident): return os.path.join(self._message_path, ident) diff --git a/receptor/messages/directive.py b/receptor/messages/directive.py index c2b5c7fa..175571fa 100644 --- a/receptor/messages/directive.py +++ b/receptor/messages/directive.py @@ -1,6 +1,6 @@ import datetime -import logging import json +import logging from ..exceptions import UnknownDirective from . import envelope @@ -26,7 +26,7 @@ async def __call__(self, router, inner_env): serial = 0 async for response in responses: serial += 1 - enveloped_response = envelope.InnerEnvelope.make_response( + enveloped_response = envelope.Inner.make_response( receptor=router.receptor, recipient=inner_env.sender, payload=response, diff --git a/receptor/messages/envelope.py b/receptor/messages/envelope.py index 8dabbf55..34683ad6 100644 --- a/receptor/messages/envelope.py +++ b/receptor/messages/envelope.py @@ -1,41 +1,187 @@ +import asyncio import base64 import datetime import json import logging -import uuid +import struct import time +import uuid +from enum import IntEnum logger = logging.getLogger(__name__) +MAX_INT64 = (2 ** 64 - 1) -class OuterEnvelope: - def __init__(self, frame_id, sender, recipient, route_list, inner): - self.frame_id = frame_id - self.sender = sender - self.recipient = recipient - self.route_list = route_list - self.inner = inner - self.inner_obj = None - async def deserialize_inner(self, receptor): - self.inner_obj = await InnerEnvelope.deserialize(receptor, self.inner) +class FramedMessage: + """ + A complete, two-part message. + """ + + __slots__ = ("msg_id", "header", "payload") + + def __init__(self, msg_id=None, header=None, payload=None): + if msg_id is None: + msg_id = uuid.uuid4().int + self.msg_id = msg_id + self.header = header + self.payload = payload + + def serialize(self): + h = json.dumps(self.header).encode("utf-8") + return b''.join([ + Frame.wrap(h, type_=Frame.Types.HEADER, msg_id=self.msg_id).serialize(), + h, + Frame.wrap(self.payload, msg_id=self.msg_id).serialize(), + self.payload + ]) + + +class CommandMessage(FramedMessage): + """ + A complete, single part message, meant to encapsulate point to point + commands or naive broadcasts. + """ + + def serialize(self): + h = json.dumps(self.header).encode("utf-8") + return b''.join([ + Frame.wrap(h, type_=Frame.Types.COMMAND, msg_id=self.msg_id).serialize(), + h, + ]) + + +class FramedBuffer: + """ + A buffer that accumulates frames and bytes to produce a header and a + payload. + + This buffer assumes that an entire message (denoted by msg_id) will be + sent before another message is sent. + """ + def __init__(self, loop=None): + self.q = asyncio.Queue(loop=loop) + self.header = None + self.framebuffer = bytearray() + self.bb = bytearray() + self.current_frame = None + self.to_read = 0 + + async def put(self, data): + if not self.to_read: + return await self.handle_frame(data) + await self.consume(data) + + async def handle_frame(self, data): + try: + self.framebuffer += data + frame, rest = Frame.from_data(self.framebuffer) + except struct.error: + return # We don't have enough data yet + else: + self.framebuffer = bytearray() + + if frame.type not in Frame.Types: + raise Exception("Unknown Frame Type") + + self.current_frame = frame + self.to_read = self.current_frame.length + await self.consume(rest) + + async def consume(self, data): + logger.debug("consuming %d bytes; to_read = %d bytes", len(data), self.to_read) + data, rest = data[:self.to_read], data[self.to_read:] + self.to_read -= len(data) + self.bb += data + if self.to_read == 0: + await self.finish() + if rest: + await self.handle_frame(rest) + + async def finish(self): + if self.current_frame.type == Frame.Types.HEADER: + self.header = json.loads(self.bb) + elif self.current_frame.type == Frame.Types.PAYLOAD: + await self.q.put(FramedMessage( + self.current_frame.msg_id, header=self.header, + payload=self.bb)) + self.header = None + elif self.current_frame.type == Frame.Types.COMMAND: + await self.q.put(FramedMessage( + self.current_frame.msg_id, header=json.loads(self.bb))) + else: + raise Exception("Unknown Frame Type") + self.to_read = 0 + self.bb = bytearray() + + async def get(self): + return await self.q.get() + + +class Frame: + """ + A Frame represents the minimal metadata about a transmission. + + Usually you should not create one directly, but rather use the + FramedMessage or CommandMessage classes. + """ + + class Types(IntEnum): + HEADER = 0 + PAYLOAD = 1 + COMMAND = 2 + + fmt = struct.Struct(">ccIIQQ") + + __slots__ = ('type', 'version', 'length', 'msg_id', 'id') + + def __init__(self, type_, version, length, msg_id, id_): + self.type = type_ + self.version = version + self.length = length + self.msg_id = msg_id + self.id = id_ + + def __repr__(self): + return f"Frame({self.type}, {self.version}, {self.length}, {self.msg_id}, {self.id})" - @classmethod - def from_raw(cls, raw): - doc = json.loads(raw) - return cls(**doc) - def serialize(self): - return json.dumps(dict( - frame_id=self.frame_id, - sender=self.sender, - recipient=self.recipient, - route_list=self.route_list, - inner=self.inner - )) + return self.fmt.pack( + bytes([self.type]), bytes([self.version]), + self.id, self.length, *split_uuid(self.msg_id)) + + @classmethod + def deserialize(cls, buf): + t, v, i, length, hi, lo = Frame.fmt.unpack(buf) + msg_id = join_uuid(hi, lo) + return cls(Frame.Types(ord(t)), ord(v), length, msg_id, i) + + @classmethod + def from_data(cls, data): + return cls.deserialize(data[:Frame.fmt.size]), data[Frame.fmt.size:] + + @classmethod + def wrap(cls, data, type_=Types.PAYLOAD, msg_id=None): + """ + Returns a frame for the passed data. + """ + if not msg_id: + msg_id = uuid.uuid4().int + + return cls(type_, 1, len(data), msg_id, 1) + + +def split_uuid(data): + "Splits a 128 bit int into two 64 bit words for binary encoding" + return ((data >> 64) & MAX_INT64, data & MAX_INT64) + + +def join_uuid(hi, lo): + "Joins two 64 bit words into a 128bit int from binary encoding" + return (hi << 64) | lo -class InnerEnvelope: +class Inner: def __init__(self, receptor, message_id, sender, recipient, message_type, timestamp, raw_payload, directive=None, in_response_to=None, ttl=None, serial=1, code=0, expire_time_delta=300): diff --git a/receptor/protocol.py b/receptor/protocol.py index 71be5915..3ab68520 100644 --- a/receptor/protocol.py +++ b/receptor/protocol.py @@ -1,7 +1,6 @@ import asyncio import datetime import functools -import json import logging import time import uuid @@ -12,23 +11,6 @@ logger = logging.getLogger(__name__) DELIM = b"\x1b[K" -SIZEB = b"\x1b[%dD" - - -class DataBuffer: - def __init__(self, loop=None, deserializer=json.loads): - self.q = asyncio.Queue(loop=loop) - self.data_buffer = b"" - self.deserializer = deserializer - - def add(self, data): - self.data_buffer = self.data_buffer + data - *ready, self.data_buffer = self.data_buffer.rsplit(DELIM) - for chunk in ready: - self.q.put_nowait(chunk) - - async def get(self): - return self.deserializer(await self.q.get()) class BaseProtocol(asyncio.Protocol): @@ -57,7 +39,7 @@ async def watch_queue(self): continue try: - self.transport.write(msg + DELIM) + self.transport.write(msg) except Exception: logger.exception("Error received trying to write to %s", self.id) await buffer_obj.put(msg) @@ -68,7 +50,7 @@ def connection_made(self, transport): self.peername = transport.get_extra_info('peername') self.transport = transport connected_peers_gauge.inc() - self.incoming_buffer = DataBuffer(loop=self.loop) + self.incoming_buffer = envelope.FramedBuffer(loop=self.loop) self.loop.create_task(self.wait_greeting()) def connection_lost(self, exc): @@ -76,8 +58,9 @@ def connection_lost(self, exc): self.receptor.remove_connection(self) def data_received(self, data): - logger.debug(data) - self.incoming_buffer.add(data) + # TODO: The put() call can raise an exception which should trigger a + # transport failure. + self.loop.create_task(self.incoming_buffer.put(data)) async def wait_greeting(self): ''' @@ -86,15 +69,14 @@ async def wait_greeting(self): ''' logger.debug('Looking for handshake...') data = await self.incoming_buffer.get() - logger.debug(data) - if data["cmd"] == "HI": - self.handle_handshake(data) - logger.debug("handshake complete, starting normal handle loop") + if data.header["cmd"] == "HI": + self.handle_handshake(data.header) else: logger.error("Handshake failed!") self.transport.close() def handle_handshake(self, data): + logger.debug("handle_handshake: %s", data) self.id = data["id"] self.meta = data.get("meta", {}) self.receptor.add_connection(self) @@ -102,15 +84,15 @@ def handle_handshake(self, data): self.loop.create_task(self.receptor.message_handler(self.incoming_buffer)) def send_handshake(self): - msg = json.dumps({ + msg = envelope.CommandMessage(header={ "cmd": "HI", "id": self.receptor.node_id, "expire_time": time.time() + 10, "meta": dict(capabilities=self.receptor.work_manager.get_capabilities(), groups=self.receptor.config.node_groups, work=self.receptor.work_manager.get_work()) - }).encode("utf-8") - self.transport.write(msg + DELIM) + }) + self.transport.write(msg.serialize()) class BasicProtocol(BaseProtocol): @@ -180,14 +162,14 @@ def emit_response(self, response): def _do_emit_callback(self, fut): res = fut.result() - self.transport.write(res.encode() + DELIM) + self.transport.write(res + DELIM) def data_received(self, data): recipient, directive, payload = data.rstrip(DELIM).decode('utf8').split('\n', 2) message_id = str(uuid.uuid4()) logger.info(f'{message_id}: Sending {directive} to {recipient}') sent_timestamp = datetime.datetime.utcnow() - inner_env = envelope.InnerEnvelope( + inner_env = envelope.Inner( receptor=self.receptor, message_id=message_id, sender=self.receptor.node_id, @@ -212,7 +194,7 @@ def _data_received_callback(self, inner_env, fut): try: fut.result() except Exception as e: - err_resp = envelope.InnerEnvelope.make_response( + err_resp = envelope.Inner.make_response( receptor=self.receptor, recipient=inner_env.sender, payload=str(e), diff --git a/receptor/receptor.py b/receptor/receptor.py index 0281ac0a..9c63da4b 100644 --- a/receptor/receptor.py +++ b/receptor/receptor.py @@ -93,12 +93,18 @@ def update_connections(self, protocol_obj): self.update_connection_manifest(protocol_obj.id) async def message_handler(self, buf): + logger.debug("spawning message_handler") while True: - data = await buf.get() - if "cmd" in data and data["cmd"] == "ROUTE": - await self.handle_route_advertisement(data) + try: + data = await buf.get() + except Exception: + logger.exception("message_handler") else: - await self.handle_message(data) + logger.debug("message_handler: %s", data) + if "cmd" in data.header and data.header["cmd"] == "ROUTE": + await self.handle_route_advertisement(data.header) + else: + await self.handle_message(data) def add_connection(self, protocol_obj): self.update_connections(protocol_obj) @@ -137,56 +143,57 @@ async def send_route_advertisement(self, edges=None, seen=[]): for target in destinations: buf = self.buffer_mgr.get_buffer_for_node(target, self) try: - await buf.put(json.dumps({ + msg = envelope.CommandMessage(header={ "cmd": "ROUTE", "id": self.node_id, "capabilities": self.work_manager.get_capabilities(), "groups": self.config.node_groups, "edges": edges, "seen": seens - }).encode("utf-8")) + }) + await buf.put(msg.serialize()) except Exception as e: logger.exception("Error trying to broadcast routes and capabilities: {}".format(e)) - async def handle_directive(self, outer_env): + async def handle_directive(self, inner): try: - namespace, _ = outer_env.inner_obj.directive.split(':', 1) + namespace, _ = inner.directive.split(':', 1) if namespace == RECEPTOR_DIRECTIVE_NAMESPACE: - await directive.control(self.router, outer_env.inner_obj) + await directive.control(self.router, inner) else: # other namespace/work directives - await self.work_manager.handle(outer_env.inner_obj) + await self.work_manager.handle(inner) except ValueError: - logger.error("error in handle_message: Invalid directive -> '%s'. Sending failure response back." % (outer_env.inner_obj.directive,)) - err_resp = outer_env.inner_obj.make_response( + logger.error("error in handle_message: Invalid directive -> '%s'. Sending failure response back." % (inner.directive,)) + err_resp = inner.make_response( receptor=self, - recipient=outer_env.inner_obj.sender, - payload="An invalid directive ('%s') was specified." % (outer_env.inner_obj.directive,), - in_response_to=outer_env.inner_obj.message_id, - serial=outer_env.inner_obj.serial + 1, + recipient=inner.sender, + payload="An invalid directive ('%s') was specified." % (inner.directive,), + in_response_to=inner.message_id, + serial=inner.serial + 1, ttl=15, code=1, ) await self.router.send(err_resp) except Exception as e: logger.error("error in handle_message: '%s'. Sending failure response back." % (str(e),)) - err_resp = outer_env.inner_obj.make_response( + err_resp = inner.make_response( receptor=self, - recipient=outer_env.inner_obj.sender, + recipient=inner.sender, payload=str(e), - in_response_to=outer_env.inner_obj.message_id, - serial=outer_env.inner_obj.serial + 1, + in_response_to=inner.message_id, + serial=inner.serial + 1, ttl=15, code=1, ) await self.router.send(err_resp) - async def handle_response(self, outer_env): - in_response_to = outer_env.inner_obj.in_response_to + async def handle_response(self, inner): + in_response_to = inner.in_response_to if in_response_to in self.router.response_registry: logger.info(f'Handling response to {in_response_to} with callback.') for connection in self.controller_connections: - connection.emit_response(outer_env.inner_obj) + connection.emit_response(inner) else: logger.warning(f'Received response to {in_response_to} but no record of sent message.') @@ -196,15 +203,14 @@ async def handle_message(self, msg): response=self.handle_response, ) messages_received_counter.inc() - outer_env = envelope.OuterEnvelope(**msg) - next_hop = self.router.next_hop(outer_env.recipient) + next_hop = self.router.next_hop(msg.header["recipient"]) if next_hop: - return await self.router.forward(outer_env, next_hop) + return await self.router.forward(msg, next_hop) - await outer_env.deserialize_inner(self) + inner = await envelope.Inner.deserialize(self, msg.payload) - if outer_env.inner_obj.message_type not in handlers: + if inner.message_type not in handlers: raise exceptions.UnknownMessageType( - f'Unknown message type: {outer_env.inner_obj.message_type}') + f'Unknown message type: {inner.message_type}') - await handlers[outer_env.inner_obj.message_type](outer_env) + await handlers[inner.message_type](inner) diff --git a/receptor/router.py b/receptor/router.py index b3f0ffbf..18e03b53 100644 --- a/receptor/router.py +++ b/receptor/router.py @@ -97,7 +97,7 @@ def get_nodes(self): async def ping_node(self, node_id, callback=log_ping): logger.info(f'Sending ping to node {node_id}') now = datetime.datetime.utcnow().isoformat() - ping_envelope = envelope.InnerEnvelope( + ping_envelope = envelope.Inner( receptor=self.receptor, message_id=str(uuid.uuid4()), sender=self.node_id, @@ -137,17 +137,17 @@ def find_shortest_path(self, to_node_id): mins[next_vertex] = next_total_cost heapq.heappush(heap, (next_total_cost, next_vertex, path)) - async def forward(self, outer_envelope, next_hop): + async def forward(self, msg, next_hop): """ Forward a message on to the next hop closer to its destination """ buffer_mgr = self.receptor.config.components_buffer_manager buffer_obj = buffer_mgr.get_buffer_for_node(next_hop, self.receptor) - outer_envelope.route_list.append(self.node_id) - logger.debug(f'Forwarding frame {outer_envelope.frame_id} to {next_hop}') + msg.header["route_list"].append(self.node_id) + logger.debug(f'Forwarding frame {msg.msg_id} to {next_hop}') try: route_counter.inc() - await buffer_obj.put(outer_envelope.serialize().encode("utf-8")) + await buffer_obj.put(msg.serialize()) except ReceptorBufferError as e: logger.exception("Receptor Buffer Write Error forwarding message to {}: {}".format(next_hop, e)) # TODO: Possible to find another route? This might be a hard failure @@ -177,14 +177,14 @@ async def send(self, inner_envelope, expected_response=False): #TODO: This probably needs to emit an error response raise UnrouteableError(f'No route found to {inner_envelope.recipient}') signed = await inner_envelope.sign_and_serialize() - outer_envelope = envelope.OuterEnvelope( - frame_id=str(uuid.uuid4()), - sender=self.node_id, - recipient=inner_envelope.recipient, - route_list=[self.node_id], - inner=signed - ) + + header = { + "sender": self.node_id, + "recipient": inner_envelope.recipient, + "route_list": [self.node_id] + } + msg = envelope.FramedMessage(msg_id=uuid.uuid4().int, header=header, payload=signed) logger.debug(f'Sending {inner_envelope.message_id} to {inner_envelope.recipient} via {next_node_id}') if expected_response and inner_envelope.message_type == 'directive': self.response_registry[inner_envelope.message_id] = dict(message_sent_time=inner_envelope.timestamp) - await self.forward(outer_envelope, next_node_id) + await self.forward(msg, next_node_id) diff --git a/receptor/security/__init__.py b/receptor/security/__init__.py index 1d0ce0db..8137511e 100644 --- a/receptor/security/__init__.py +++ b/receptor/security/__init__.py @@ -1,5 +1,6 @@ import json import logging + logger = logging.getLogger(__name__) @@ -26,5 +27,4 @@ async def sign_response(self, inner_envelope): for attr in ['message_id', 'sender', 'recipient', 'message_type', 'timestamp', 'raw_payload', 'directive', 'in_response_to', 'ttl', 'serial', 'code']} - ) - + ).encode("utf-8") diff --git a/receptor/tests/test_framedbuffer.py b/receptor/tests/test_framedbuffer.py new file mode 100644 index 00000000..7b38426d --- /dev/null +++ b/receptor/tests/test_framedbuffer.py @@ -0,0 +1,136 @@ +import json +import uuid + +import pytest + +from receptor.messages.envelope import Frame, FramedBuffer, FramedMessage + + +@pytest.yield_fixture +def msg_id(): + return uuid.uuid4().int + + +@pytest.yield_fixture +def framed_buffer(event_loop): + return FramedBuffer(loop=event_loop) + + +@pytest.mark.asyncio +async def test_framedbuffer(framed_buffer, msg_id): + header = {"sender": "node1", "recipient": "node2", "route_list": []} + header_bytes = json.dumps(header).encode("utf-8") + f1 = Frame(Frame.Types.HEADER, 1, len(header_bytes), msg_id, 1) + + await framed_buffer.put(f1.serialize() + header_bytes) + + payload = b"payload one is very boring" + payload2 = b"payload two is also very boring" + f2 = Frame(Frame.Types.PAYLOAD, 1, len(payload) + len(payload2), msg_id, 2) + + await framed_buffer.put(f2.serialize() + payload) + await framed_buffer.put(payload2) + + m = await framed_buffer.get() + + assert m.header == header + assert m.payload == payload + payload2 + + +@pytest.mark.asyncio +async def test_hi(msg_id, framed_buffer): + hi = json.dumps({"cmd": "hi"}).encode("utf-8") + f1 = Frame(Frame.Types.PAYLOAD, 1, len(hi), msg_id, 1) + + await framed_buffer.put(f1.serialize()) + await framed_buffer.put(hi) + + m = await framed_buffer.get() + + assert m.header is None + assert m.payload == hi + + +@pytest.mark.asyncio +async def test_extra_header(framed_buffer, msg_id): + h1 = {"sender": "node1", "recipient": "node2", "route_list": []} + payload = json.dumps(h1).encode("utf-8") + f1 = Frame(Frame.Types.HEADER, 1, len(payload), msg_id, 1) + await framed_buffer.put(f1.serialize()) + await framed_buffer.put(payload) + + h2 = {"sender": "node3", "recipient": "node4", "route_list": []} + payload = json.dumps(h2).encode("utf-8") + f2 = Frame(Frame.Types.HEADER, 1, len(payload), msg_id, 2) + await framed_buffer.put(f2.serialize()) + await framed_buffer.put(payload) + + assert framed_buffer.header == h2 + + +@pytest.mark.asyncio +async def test_command(framed_buffer, msg_id): + cmd = {"cmd": "hi"} + payload = json.dumps(cmd).encode("utf-8") + f1 = Frame(Frame.Types.COMMAND, 1, len(payload), msg_id, 1) + await framed_buffer.put(f1.serialize()) + await framed_buffer.put(payload) + + m = await framed_buffer.get() + assert m.header == cmd + assert m.payload is None + + +@pytest.mark.asyncio +async def test_overfull(framed_buffer, msg_id): + header = {"foo": "bar"} + payload = b"this is a test" + msg = FramedMessage(header=header, payload=payload) + + await framed_buffer.put(msg.serialize()) + + m = await framed_buffer.get() + + assert m.header == header + assert m.payload == payload + + +@pytest.mark.asyncio +async def test_underfull(framed_buffer, msg_id): + header = {"foo": "bar"} + payload = b"this is a test" + msg = FramedMessage(header=header, payload=payload) + b = msg.serialize() + + await framed_buffer.put(b[:10]) + await framed_buffer.put(b[10:]) + + m = await framed_buffer.get() + + assert m.header == header + assert m.payload == payload + + +@pytest.mark.asyncio +async def test_malformed_frame(framed_buffer, msg_id): + with pytest.raises(ValueError): + await framed_buffer.put( + b"this is total garbage and should break things very nicely" + ) + + +@pytest.mark.skip( + reason=""" + This test illustrates that sending an incomplete stream corrupts the transport""" +) +@pytest.mark.asyncio +async def test_too_short(framed_buffer, msg_id): + f1 = Frame(Frame.Types.HEADER, 1, 100, 1, 1) + too_short_header = b"this is not long enough" + f2 = Frame(Frame.Types.PAYLOAD, 1, 100, 1, 2) + too_short_payload = b"this is also not long enough" + + await framed_buffer.put(f1.serialize() + too_short_header) + await framed_buffer.put(f2.serialize() + too_short_payload) + + await framed_buffer.get() diff --git a/receptor/tests/test_protocol.py b/receptor/tests/test_protocol.py deleted file mode 100644 index ec293813..00000000 --- a/receptor/tests/test_protocol.py +++ /dev/null @@ -1,34 +0,0 @@ -import pytest - -from receptor import protocol - - -def deser(x): - return x - - -@pytest.mark.asyncio -async def test_databuffer(): - b = protocol.DataBuffer(deserializer=deser) - msg = b"this is a test" - s = protocol.DELIM + msg + protocol.DELIM - b.add(s) - assert await b.get() == b"" - assert await b.get() == msg - - -def test_databuffer_no_delim(): - b = protocol.DataBuffer(deserializer=deser) - msg = b"this is a test" - b.add(msg) - assert b.q.empty() - - -@pytest.mark.asyncio -async def test_databuffer_many_msgs(): - b = protocol.DataBuffer(deserializer=deser) - msg = [b"first bit", b"second bit", b"third bit unfinished"] - b.add(protocol.DELIM.join(msg)) - assert msg[0] == await b.get() - assert msg[1] == await b.get() - assert b.q.empty() diff --git a/receptor/work.py b/receptor/work.py index fd5f12e8..b747e3c8 100644 --- a/receptor/work.py +++ b/receptor/work.py @@ -1,11 +1,11 @@ import logging import traceback + import pkg_resources from . import exceptions from .messages import envelope -from .stats import work_counter, active_work_gauge - +from .stats import active_work_gauge, work_counter logger = logging.getLogger(__name__) @@ -58,7 +58,7 @@ async def handle(self, inner_env): async for response in responses: serial += 1 logger.debug(f'Response emitted for {inner_env.message_id}, serial {serial}') - enveloped_response = envelope.InnerEnvelope.make_response( + enveloped_response = envelope.Inner.make_response( receptor=self.receptor, recipient=inner_env.sender, payload=response, @@ -70,7 +70,7 @@ async def handle(self, inner_env): serial += 1 logger.error(f'Error encountered while handling the response, replying with an error message ({e})') logger.error(traceback.format_tb(e.__traceback__)) - enveloped_response = envelope.InnerEnvelope.make_response( + enveloped_response = envelope.Inner.make_response( receptor=self.receptor, recipient=inner_env.sender, payload=str(e), @@ -80,4 +80,3 @@ async def handle(self, inner_env): ) self.remove_work(inner_env) await self.receptor.router.send(enveloped_response) -