From 53536bbc95e4abaa196c9ffb96f0b029754edcfa Mon Sep 17 00:00:00 2001 From: jell-o-fishi Date: Mon, 20 Feb 2023 20:50:02 +0200 Subject: [PATCH 1/2] remove data/metadata copy in sending frames --- rsocket/frame.py | 139 +++++++++++++++++++-------- rsocket/payload.py | 6 +- rsocket/transports/tcp.py | 16 ++- tests/rsocket/helpers.py | 5 + tests/rsocket/test_frame.py | 18 +++- tests/rsocket/test_payload.py | 7 ++ tests/rsocket/test_request_stream.py | 48 ++++++++- tests/tools/fixtures_tcp.py | 7 +- 8 files changed, 191 insertions(+), 55 deletions(-) diff --git a/rsocket/frame.py b/rsocket/frame.py index d9e987b8..51bb5fec 100644 --- a/rsocket/frame.py +++ b/rsocket/frame.py @@ -57,12 +57,21 @@ class FrameType(IntEnum): class Header: __slots__ = ( 'length', + 'prefix_length', 'frame_type', 'stream_id', 'flags_ignore', - 'flags_metadata' + '_flags_metadata' ) + @property + def flags_metadata(self): + return self._flags_metadata + + @flags_metadata.setter + def flags_metadata(self, value): + self._flags_metadata = value + def is_blank(value: Optional[bytes]) -> bool: return value is None or len(value) == 0 @@ -98,6 +107,7 @@ class Frame(Header, metaclass=ABCMeta): def __init__(self, frame_type: FrameType): self.length = 0 + self.prefix_length = 0 self.frame_type = frame_type self.stream_id = CONNECTION_STREAM_ID self.metadata = b'' @@ -113,6 +123,14 @@ def __init__(self, frame_type: FrameType): self.fragment_generator = None self.sent_future: Optional[Future] = None + @property + def flags_metadata(self): + return self.metadata or self._flags_metadata + + @flags_metadata.setter + def flags_metadata(self, value): + self._flags_metadata = value + def parse_metadata(self, buffer: bytes, offset: int) -> int: if not self.flags_metadata: return 0 @@ -136,18 +154,18 @@ def parse_data(self, buffer: bytes, offset: int) -> int: def parse(self, buffer: bytes, offset: int): ... - def serialize(self, middle=b'', flags: int = 0) -> bytes: + def serialize_frame_prefix(self, middle=b'', flags: int = 0) -> bytes: flags &= ~(_FLAG_IGNORE_BIT | _FLAG_METADATA_BIT) if self.flags_ignore: flags |= _FLAG_IGNORE_BIT if self.metadata: - self.flags_metadata = True flags |= _FLAG_METADATA_BIT - self.length = self._compute_frame_length(middle) + self.length = self.compute_frame_length(middle) + self.prefix_length = self._compute_frame_prefix_length(middle) offset = 0 - buffer = bytearray(self.length) + buffer = bytearray(self.prefix_length) struct.pack_into('>I', buffer, offset, self.stream_id) offset += 4 @@ -160,11 +178,23 @@ def serialize(self, middle=b'', flags: int = 0) -> bytes: buffer[offset:offset + len(middle)] = middle[:] offset += len(middle) - if self.flags_metadata and self.metadata: - length = len(self.metadata) + if self.metadata: if not self.metadata_only: - buffer[offset:offset + 3] = pack_24bit(length) + buffer[offset:offset + 3] = pack_24bit(len(self.metadata)) offset += 3 + + return buffer + + def serialize(self, middle=b'', flags: int = 0) -> bytes: + prefix_buffer = self.serialize_frame_prefix(middle, flags) + + offset = 0 + buffer = bytearray(len(prefix_buffer) + self._compute_data_metadata_length()) + buffer[:len(prefix_buffer)] = prefix_buffer + offset += len(prefix_buffer) + + if self.flags_metadata and self.metadata: + length = len(self.metadata) buffer[offset:offset + length] = self.metadata[:] offset += length @@ -174,20 +204,37 @@ def serialize(self, middle=b'', flags: int = 0) -> bytes: return bytes(buffer) - def _compute_frame_length(self, middle: bytes) -> int: - header_length = HEADER_LENGTH - length = header_length + len(middle) + def write_data_metadata(self, writer_method): + if self.metadata: + writer_method(self.metadata) + + if not self.metadata_only and self.data: + writer_method(self.data) + + def compute_frame_length(self, middle: bytes = b'') -> int: + return self._compute_frame_prefix_length(middle) + self._compute_data_metadata_length() + + def _compute_data_metadata_length(self): + length = 0 if self.flags_metadata and self.metadata: length += len(self.metadata) - if not self.metadata_only: - length += 3 if not self.metadata_only and self.data: length += len(self.data) return length + def _compute_frame_prefix_length(self, middle: bytes = b'') -> int: + header_length = HEADER_LENGTH + length = header_length + len(middle) + + if self.flags_metadata and self.metadata: + if not self.metadata_only: + length += 3 + + return length + def __str__(self): return str(f'({FrameType(self.frame_type).name},{self.data},{self.metadata},{self.flags_complete})') @@ -262,23 +309,29 @@ def parse(self, buffer: bytes, offset: int): offset += self.parse_metadata(buffer, offset) offset += self.parse_data(buffer, offset) - def serialize(self, middle=b'', flags=0) -> bytes: + def serialize_frame_prefix(self, middle=b'', flags=0) -> bytes: flags &= ~(_FLAG_LEASE_BIT | _FLAG_RESUME_BIT) + if self.flags_lease: flags |= _FLAG_LEASE_BIT + if self.flags_resume: flags |= _FLAG_RESUME_BIT + middle = struct.pack( '>HHII', self.major_version, self.minor_version, self.keep_alive_milliseconds, self.max_lifetime_milliseconds) + if self.flags_resume: middle += struct.pack('>H', self.token_length) # assert len(self.resume_identification_token) == self.token_length # assert isinstance(self.resume_identification_token, bytes) middle += self.resume_identification_token + middle += pack_string(self.metadata_encoding) middle += pack_string(self.data_encoding) - return Frame.serialize(self, middle, flags) + + return super().serialize_frame_prefix(middle, flags) class InvalidFrame: @@ -299,9 +352,9 @@ def parse(self, buffer: bytes, offset: int): offset += 4 offset += self.parse_data(buffer, offset) - def serialize(self, middle=b'', flags=0) -> bytes: + def serialize_frame_prefix(self, middle=b'', flags=0) -> bytes: middle = struct.pack('>I', self.error_code) - return Frame.serialize(self, middle, flags) + return super().serialize_frame_prefix(middle, flags) class LeaseFrame(Frame): @@ -324,11 +377,11 @@ def parse(self, buffer: bytes, offset: int): self.number_of_requests = number_of_requests & MASK_31_BITS offset += self.parse_metadata(buffer, offset + 8) - def serialize(self, middle=b'', flags=0): + def serialize_frame_prefix(self, middle=b'', flags=0): middle = struct.pack('>II', self.time_to_live & MASK_31_BITS, self.number_of_requests & MASK_31_BITS) - return Frame.serialize(self, middle, flags) + return super().serialize_frame_prefix(middle, flags) class KeepAliveFrame(Frame): @@ -350,12 +403,12 @@ def parse(self, buffer: bytes, offset: int): offset += 8 offset += self.parse_data(buffer, offset) - def serialize(self, middle=b'', flags: int = 0) -> bytes: + def serialize_frame_prefix(self, middle=b'', flags: int = 0) -> bytes: flags &= ~_FLAG_RESPOND_BIT if self.flags_respond: flags |= _FLAG_RESPOND_BIT middle += pack_position(self.last_received_position) - return Frame.serialize(self, middle, flags) + return super().serialize_frame_prefix(middle, flags) class RequestFrame(Frame): @@ -372,13 +425,13 @@ def parse(self, buffer, offset: int) -> Tuple[int, int]: self.flags_follows = is_flag_set(flags, _FLAG_FOLLOWS_BIT) return HEADER_LENGTH, flags - def serialize(self, middle=b'', flags: int = 0) -> bytes: + def serialize_frame_prefix(self, middle=b'', flags: int = 0) -> bytes: flags &= ~_FLAG_FOLLOWS_BIT if self.flags_follows: flags |= _FLAG_FOLLOWS_BIT - return Frame.serialize(self, middle, flags) + return super().serialize_frame_prefix(middle, flags) def _parse_payload(self, buffer: bytes, offset: int): offset += self.parse_metadata(buffer, offset) @@ -423,9 +476,9 @@ def parse(self, buffer, offset): offset += 4 self._parse_payload(buffer, offset) - def serialize(self, middle=b'', flags=0): + def serialize_frame_prefix(self, middle=b'', flags=0): middle = struct.pack('>I', self.initial_request_n) - return RequestFrame.serialize(self, middle) + return super().serialize_frame_prefix(middle) class RequestChannelFrame(RequestFrame, FrameFragmentMixin): @@ -450,14 +503,15 @@ def parse(self, buffer, offset): offset += 4 self._parse_payload(buffer, offset) - def serialize(self, middle=b'', flags=0): + def serialize_frame_prefix(self, middle=b'', flags=0): middle = struct.pack('>I', self.initial_request_n) flags &= ~_FLAG_COMPLETE_BIT + if self.flags_complete: flags |= _FLAG_COMPLETE_BIT - return RequestFrame.serialize(self, middle, flags) + return super().serialize_frame_prefix(middle, flags) class RequestNFrame(RequestFrame): @@ -472,9 +526,9 @@ def parse(self, buffer, offset): offset += HEADER_LENGTH self.request_n = unpack_32bit(buffer, offset) - def serialize(self, middle=b'', flags=0): + def serialize_frame_prefix(self, middle=b'', flags=0): middle = struct.pack('>I', self.request_n) - return Frame.serialize(self, middle, flags) + return super().serialize_frame_prefix(middle, flags) class CancelFrame(Frame): @@ -508,7 +562,7 @@ def parse(self, buffer, offset): offset += self.parse_metadata(buffer, offset) offset += self.parse_data(buffer, offset) - def serialize(self, middle=b'', flags=0): + def serialize_frame_prefix(self, middle=b'', flags=0): flags &= ~(_FLAG_FOLLOWS_BIT | _FLAG_COMPLETE_BIT | _FLAG_NEXT_BIT) @@ -524,7 +578,7 @@ def serialize(self, middle=b'', flags=0): if self.flags_next: flags |= _FLAG_NEXT_BIT - return Frame.serialize(self, flags=flags) + return super().serialize_frame_prefix(flags=flags) class MetadataPushFrame(Frame): @@ -577,7 +631,7 @@ def parse(self, buffer: bytes, offset: int): offset += 8 self.first_client_position = unpack_position(buffer[offset:]) - def serialize(self, middle=b'', flags=0) -> bytes: + def serialize_frame_prefix(self, middle=b'', flags=0) -> bytes: flags &= ~(_FLAG_LEASE_BIT | _FLAG_RESUME_BIT) middle = struct.pack('>HH', self.major_version, self.minor_version) @@ -589,7 +643,7 @@ def serialize(self, middle=b'', flags=0) -> bytes: middle += pack_position(self.last_server_position) middle += pack_position(self.first_client_position) - return Frame.serialize(self, middle) + return super().serialize_frame_prefix(middle) class ResumeOKFrame(Frame): @@ -606,9 +660,9 @@ def parse(self, buffer: bytes, offset: int): offset += HEADER_LENGTH self.last_received_client_position = unpack_position(buffer[offset:offset + 8]) - def serialize(self, middle=b'', flags: int = 0) -> bytes: + def serialize_frame_prefix(self, middle=b'', flags: int = 0) -> bytes: serialized = pack_position(self.last_received_client_position) - return super().serialize(serialized) + return super().serialize_frame_prefix(serialized) class ExtendedFrame(Frame, metaclass=abc.ABCMeta): @@ -681,10 +735,10 @@ def is_fragmentable_frame(frame: Frame) -> bool: FragmentableFrame = Union[PayloadFrame, - RequestResponseFrame, - RequestChannelFrame, - RequestStreamFrame, - RequestFireAndForgetFrame] +RequestResponseFrame, +RequestChannelFrame, +RequestStreamFrame, +RequestFireAndForgetFrame] def new_frame_fragment(base_frame: FragmentableFrame, fragment: Fragment) -> Frame: @@ -745,6 +799,13 @@ def serialize_with_frame_size_header(frame: Frame) -> bytes: return full_frame +def serialize_prefix_with_frame_size_header(frame: Frame) -> bytes: + serialized_frame_prefix = frame.serialize_frame_prefix() + header = struct.pack('>I', frame.length)[1:] + full_frame = header + serialized_frame_prefix + return full_frame + + initiate_request_frame_types = (RequestResponseFrame, RequestStreamFrame, RequestChannelFrame, diff --git a/rsocket/payload.py b/rsocket/payload.py index 1fe9d0ee..6c9d6417 100644 --- a/rsocket/payload.py +++ b/rsocket/payload.py @@ -1,6 +1,6 @@ from typing import Optional -from rsocket.frame_helpers import ensure_bytes, safe_len +from rsocket.frame_helpers import safe_len from rsocket.local_typing import ByteTypes @@ -22,8 +22,8 @@ def __init__(self, data: Optional[ByteTypes] = None, metadata: Optional[ByteType self._check(data) self._check(metadata) - self.data = ensure_bytes(data) - self.metadata = ensure_bytes(metadata) + self.data = data + self.metadata = metadata def __str__(self): return f"" diff --git a/rsocket/transports/tcp.py b/rsocket/transports/tcp.py index c9c0dccc..499aa96c 100644 --- a/rsocket/transports/tcp.py +++ b/rsocket/transports/tcp.py @@ -1,6 +1,6 @@ from asyncio import StreamReader, StreamWriter -from rsocket.frame import Frame, serialize_with_frame_size_header +from rsocket.frame import Frame, serialize_prefix_with_frame_size_header from rsocket.helpers import wrap_transport_exception from rsocket.transports.transport import Transport @@ -13,14 +13,22 @@ class TransportTCP(Transport): :param writer: asyncio connection writer stream """ - def __init__(self, reader: StreamReader, writer: StreamWriter): + def __init__(self, + reader: StreamReader, + writer: StreamWriter, + read_buffer_size=1024): super().__init__() + self._read_buffer_size = read_buffer_size self._writer = writer self._reader = reader async def send_frame(self, frame: Frame): + await self.serialize_partial(frame) + + async def serialize_partial(self, frame: Frame): with wrap_transport_exception(): - self._writer.write(serialize_with_frame_size_header(frame)) + self._writer.write(serialize_prefix_with_frame_size_header(frame)) + frame.write_data_metadata(self._writer.write) await self._writer.drain() async def on_send_queue_empty(self): @@ -33,7 +41,7 @@ async def close(self): async def next_frame_generator(self): with wrap_transport_exception(): - data = await self._reader.read(1024) + data = await self._reader.read(self._read_buffer_size) if not data: self._writer.close() diff --git a/tests/rsocket/helpers.py b/tests/rsocket/helpers.py index 0fc2fc97..2d7b778b 100644 --- a/tests/rsocket/helpers.py +++ b/tests/rsocket/helpers.py @@ -1,5 +1,6 @@ import asyncio import json +import os from dataclasses import dataclass from datetime import timedelta from math import ceil @@ -102,3 +103,7 @@ def to_json_bytes(item: Any) -> bytes: def create_data(base: bytes, multiplier: int, limit: float = None): return b''.join([ensure_bytes(str(i)) + base for i in range(multiplier)])[0:limit] + + +def create_large_random_data(size: int): + return bytearray(os.urandom(size)) diff --git a/tests/rsocket/test_frame.py b/tests/rsocket/test_frame.py index b580c0f6..fc23415f 100644 --- a/tests/rsocket/test_frame.py +++ b/tests/rsocket/test_frame.py @@ -10,7 +10,7 @@ RequestResponseFrame, RequestNFrame, ResumeFrame, MetadataPushFrame, PayloadFrame, LeaseFrame, ResumeOKFrame, KeepAliveFrame, serialize_with_frame_size_header, RequestStreamFrame, RequestChannelFrame, ParseError, - parse_or_ignore, Frame, RequestFireAndForgetFrame) + parse_or_ignore, Frame, RequestFireAndForgetFrame, serialize_prefix_with_frame_size_header) from rsocket.frame_parser import FrameParser from tests.rsocket.helpers import data_bits, build_frame, bits @@ -69,6 +69,7 @@ def variable_length(): frame = await parse_frame(data, frame_parser) assert isinstance(frame, SetupFrame) + assert serialize_prefix_with_frame_size_header(frame) == data[:frame.prefix_length + 3] assert serialize_with_frame_size_header(frame) == data assert frame.metadata_encoding == b'application/octet-stream' @@ -130,6 +131,7 @@ async def test_setup_with_resume(frame_parser, lease): frame = await parse_frame(data, frame_parser) assert isinstance(frame, SetupFrame) + assert serialize_prefix_with_frame_size_header(frame) == data[:frame.prefix_length + 3] assert serialize_with_frame_size_header(frame) == data assert frame.metadata_encoding == b'application/octet-stream' @@ -179,6 +181,7 @@ async def test_request_stream_frame(frame_parser, follows): frame = await parse_frame(data, frame_parser) assert isinstance(frame, RequestStreamFrame) + assert serialize_prefix_with_frame_size_header(frame) == data[:frame.prefix_length + 3] assert serialize_with_frame_size_header(frame) == data assert frame.data == b'\x01\x02\x03' @@ -254,6 +257,7 @@ async def test_request_channel_frame(frame_parser, follows, complete): frame = await parse_frame(data, frame_parser) assert isinstance(frame, RequestChannelFrame) + assert serialize_prefix_with_frame_size_header(frame) == data[:frame.prefix_length + 3] assert serialize_with_frame_size_header(frame) == data assert frame.data == b'\x01\x02\x03' @@ -340,6 +344,7 @@ async def test_request_with_composite_metadata(frame_parser): frame = await parse_frame(data, frame_parser) assert isinstance(frame, RequestResponseFrame) + assert serialize_prefix_with_frame_size_header(frame) == data[:frame.prefix_length + 3] assert serialize_with_frame_size_header(frame) == data assert frame.data == b'\x01\x02\x03' @@ -441,6 +446,7 @@ async def test_cancel(frame_parser): assert isinstance(frame, CancelFrame) assert frame.frame_type is FrameType.CANCEL + assert serialize_prefix_with_frame_size_header(frame) == data[:frame.prefix_length + 3] assert serialize_with_frame_size_header(frame) == data @@ -462,6 +468,7 @@ async def test_error(frame_parser): frame = await parse_frame(data, frame_parser) assert isinstance(frame, ErrorFrame) + assert serialize_prefix_with_frame_size_header(frame) == data[:frame.prefix_length + 3] assert serialize_with_frame_size_header(frame) == data assert frame.error_code == ErrorCode.REJECTED_SETUP @@ -487,6 +494,7 @@ async def test_request_n_frame(frame_parser): frame = await parse_frame(data, frame_parser) assert isinstance(frame, RequestNFrame) + assert serialize_prefix_with_frame_size_header(frame) == data[:frame.prefix_length + 3] assert serialize_with_frame_size_header(frame) == data assert frame.request_n == 23 @@ -521,6 +529,7 @@ async def test_resume_frame(frame_parser): assert frame.last_server_position == 123 assert frame.first_client_position == 456 assert frame.frame_type is FrameType.RESUME + assert serialize_prefix_with_frame_size_header(frame) == data[:frame.prefix_length + 3] assert serialize_with_frame_size_header(frame) == data @@ -542,6 +551,7 @@ async def test_metadata_push_frame(frame_parser): assert isinstance(frame, MetadataPushFrame) assert frame.metadata == b'metadata' assert frame.frame_type is FrameType.METADATA_PUSH + assert serialize_prefix_with_frame_size_header(frame) == data[:frame.prefix_length + 3] assert serialize_with_frame_size_header(frame) == data @@ -557,7 +567,6 @@ async def test_payload_frame(frame_parser, follows, complete, is_next): bits(1, 0, 'Padding'), bits(31, 6, 'Stream id'), bits(6, FrameType.PAYLOAD, 'Frame type'), - # Flags bits(1, 0, 'Ignore'), bits(1, 1, 'Metadata'), bits(1, follows, 'Follows'), @@ -576,6 +585,7 @@ async def test_payload_frame(frame_parser, follows, complete, is_next): assert frame.data == b'actual_data' assert frame.frame_type is FrameType.PAYLOAD assert frame.stream_id == 6 + assert serialize_prefix_with_frame_size_header(frame) == data[:frame.prefix_length + 3] assert serialize_with_frame_size_header(frame) == data assert frame.flags_follows is bool(follows) @@ -603,6 +613,7 @@ async def test_payload_without_body(frame_parser): assert isinstance(frame, PayloadFrame) assert frame.frame_type is FrameType.PAYLOAD assert frame.stream_id == 6 + assert serialize_prefix_with_frame_size_header(frame) == data[:frame.prefix_length + 3] assert serialize_with_frame_size_header(frame) == data assert frame.flags_follows is False @@ -634,6 +645,7 @@ async def test_lease_frame(frame_parser): assert frame.time_to_live == 456 assert frame.frame_type is FrameType.LEASE assert frame.metadata == b'Metadata on lease frame' + assert serialize_prefix_with_frame_size_header(frame) == data[:frame.prefix_length + 3] assert serialize_with_frame_size_header(frame) == data @@ -674,6 +686,7 @@ async def test_resume_ok_frame(frame_parser): assert isinstance(frame, ResumeOKFrame) assert frame.last_received_client_position == 456 assert frame.frame_type is FrameType.RESUME_OK + assert serialize_prefix_with_frame_size_header(frame) == data[:frame.prefix_length + 3] assert serialize_with_frame_size_header(frame) == data @@ -699,6 +712,7 @@ async def test_keepalive_frame(frame_parser): assert frame.stream_id == 0 assert frame.frame_type is FrameType.KEEPALIVE assert frame.data == b'additional data' + assert serialize_prefix_with_frame_size_header(frame) == data[:frame.prefix_length + 3] assert serialize_with_frame_size_header(frame) == data diff --git a/tests/rsocket/test_payload.py b/tests/rsocket/test_payload.py index 20e18b8a..fb091017 100644 --- a/tests/rsocket/test_payload.py +++ b/tests/rsocket/test_payload.py @@ -30,3 +30,10 @@ def test_payload_support_bytearray(): assert payload.data == b'\x01\x05\x0a' assert payload.metadata == b'\x04\x06\x07' + + +def test_payload_cannot_accept_strings(): + with pytest.raises(AssertionError) as exc_info: + Payload('data') + + assert isinstance(exc_info.value, AssertionError) diff --git a/tests/rsocket/test_request_stream.py b/tests/rsocket/test_request_stream.py index a63914f0..2749d28a 100644 --- a/tests/rsocket/test_request_stream.py +++ b/tests/rsocket/test_request_stream.py @@ -18,7 +18,8 @@ from rsocket.rsocket_server import RSocketServer from rsocket.streams.stream_from_async_generator import StreamFromAsyncGenerator from rsocket.streams.stream_from_generator import StreamFromGenerator -from tests.rsocket.helpers import get_components +from tests.rsocket.helpers import get_components, create_large_random_data +from tests.tools.helpers import measure_time @pytest.mark.parametrize('complete_inline', ( @@ -27,6 +28,7 @@ )) async def test_request_stream_properly_finished(pipe: Tuple[RSocketServer, RSocketClient], complete_inline): server, client = get_components(pipe) + response_count = 3 class Handler(BaseRequestHandler): @@ -34,23 +36,59 @@ async def request_stream(self, payload: Payload) -> Publisher: return StreamFromAsyncGenerator(self.feed) async def feed(self): - for x in range(3): + for x in range(response_count): value = Payload('Feed Item: {}'.format(x).encode('utf-8')) - yield value, complete_inline and x == 2 + yield value, complete_inline and x == response_count - 1 if not complete_inline: yield None, True server.set_handler_using_factory(Handler) - result = await AwaitableRSocket(client).request_stream(Payload()) + measured_result = await measure_time(AwaitableRSocket(client).request_stream(Payload())) - assert len(result) == 3 + logging.info(measured_result.delta / (response_count if complete_inline else response_count + 1)) + + result = measured_result.result + assert len(result) == response_count assert result[0].data == b'Feed Item: 0' assert result[1].data == b'Feed Item: 1' assert result[2].data == b'Feed Item: 2' +@pytest.mark.parametrize('complete_inline', ( + True, + False, +)) +async def test_request_stream_properly_finished_performance(pipe_tcp: Tuple[RSocketServer, RSocketClient], + complete_inline): + server, client = get_components(pipe_tcp) + response_count = 300 + sample_data = create_large_random_data(1920 * 1080 * 3) + + class Handler(BaseRequestHandler): + + async def request_stream(self, payload: Payload) -> Publisher: + return StreamFromAsyncGenerator(self.feed) + + async def feed(self): + for x in range(response_count): + value = Payload(sample_data) + yield value, complete_inline and x == response_count - 1 + + if not complete_inline: + yield None, True + + server.set_handler_using_factory(Handler) + + measured_result = await measure_time(AwaitableRSocket(client).request_stream(Payload())) + + logging.info(measured_result.delta / 2 / (response_count if complete_inline else response_count + 1)) + + result = measured_result.result + assert len(result) == response_count + + @pytest.mark.parametrize('initial_request_n', ( 0, -1, diff --git a/tests/tools/fixtures_tcp.py b/tests/tools/fixtures_tcp.py index 9d6ae137..bfed63ae 100644 --- a/tests/tools/fixtures_tcp.py +++ b/tests/tools/fixtures_tcp.py @@ -14,10 +14,12 @@ @asynccontextmanager async def pipe_factory_tcp(unused_tcp_port, client_arguments=None, server_arguments=None, auto_connect_client=True): wait_for_server = Event() + read_buffer_size = 2000 * 1080 * 3 def session(*connection): nonlocal server - server = RSocketServer(TransportTCP(*connection), **(server_arguments or {})) + + server = RSocketServer(TransportTCP(*connection, read_buffer_size=read_buffer_size), **(server_arguments or {})) wait_for_server.set() async def start(): @@ -30,7 +32,8 @@ async def start(): # client_arguments.update(test_overrides) - client = RSocketClient(single_transport_provider(TransportTCP(*connection)), **(client_arguments or {})) + client = RSocketClient(single_transport_provider(TransportTCP(*connection, read_buffer_size=read_buffer_size)), + **(client_arguments or {})) if auto_connect_client: await client.connect() From 57a9fd1ad83d4039faddba84f0223b0d520023f5 Mon Sep 17 00:00:00 2001 From: jell-o-fishi Date: Tue, 21 Feb 2023 18:37:10 +0200 Subject: [PATCH 2/2] performance test code --- performance/conftest.py | 6 +++++- performance/performance_client.py | 32 ++++++++++++++++++++----------- performance/performance_server.py | 14 ++++++++++++-- performance/test_performance.py | 8 ++++++++ 4 files changed, 46 insertions(+), 14 deletions(-) diff --git a/performance/conftest.py b/performance/conftest.py index 7fbf56cf..59b29b39 100644 --- a/performance/conftest.py +++ b/performance/conftest.py @@ -1,6 +1,10 @@ import logging +def pytest_configure(config): + config.addinivalue_line("markers", "performance: marks performance tests") + + def setup_logging(level=logging.DEBUG, use_file: bool = False): formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') @@ -19,4 +23,4 @@ def setup_logging(level=logging.DEBUG, use_file: bool = False): logging.basicConfig(level=level, handlers=handlers) -setup_logging(logging.ERROR) +setup_logging(logging.DEBUG) diff --git a/performance/performance_client.py b/performance/performance_client.py index a7177f15..28cd4017 100644 --- a/performance/performance_client.py +++ b/performance/performance_client.py @@ -13,25 +13,28 @@ from rsocket.rsocket_client import RSocketClient from rsocket.streams.stream_from_async_generator import StreamFromAsyncGenerator from rsocket.transports.tcp import TransportTCP -from tests.rsocket.helpers import to_json_bytes +from tests.rsocket.helpers import to_json_bytes, create_large_random_data + + +data_size = 1920 # * 1080 * 3 +large_data = create_large_random_data(data_size) def sample_publisher(wait_for_requester_complete: Event, - response_count: int = 3) -> Publisher: + response_count: int = 3, + data_generator=lambda index: ('Item to server from client on channel: %s' % index).encode('utf-8') + ) -> Publisher: async def generator() -> AsyncGenerator[Tuple[Fragment, bool], None]: - current_response = 0 for i in range(response_count): - is_complete = (current_response + 1) == response_count + is_complete = (i + 1) == response_count - message = 'Item to server from client on channel: %s' % current_response - yield Fragment(message.encode('utf-8')), is_complete + message = data_generator(i) + yield Payload(message), is_complete if is_complete: wait_for_requester_complete.set() break - current_response += 1 - return StreamFromAsyncGenerator(generator) @@ -47,13 +50,20 @@ async def request_response(self): return await self._client.request_response(payload) + async def large_request(self): + payload = Payload(large_data, composite( + route('large'), + authenticate_simple('user', '12345') + )) + + return await self._client.request_response(payload) + async def request_channel(self): - requester_completion_event = Event() payload = Payload(b'The quick brown fox', composite( route('channel'), authenticate_simple('user', '12345') )) - publisher = sample_publisher(requester_completion_event) + publisher = sample_publisher(Event()) return await self._client.request_channel(payload, publisher, limit_rate=5) @@ -97,7 +107,7 @@ async def __aenter__(self): connection = await asyncio.open_connection('localhost', self._server_port) self._client = AwaitableRSocket(RSocketClient( - single_transport_provider(TransportTCP(*connection)), + single_transport_provider(TransportTCP(*connection, read_buffer_size=data_size + 3000)), metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA) ) diff --git a/performance/performance_server.py b/performance/performance_server.py index e234c4f2..756ddcc9 100644 --- a/performance/performance_server.py +++ b/performance/performance_server.py @@ -15,6 +15,10 @@ from rsocket.rsocket_server import RSocketServer from rsocket.transports.tcp import TransportTCP from performance.sample_responses import response_stream_2, response_stream_1, LoggingSubscriber +from tests.rsocket.helpers import create_large_random_data + +data_size = 1920 # * 1080 * 3 +large_data = create_large_random_data(data_size) router = RequestRouter() @@ -34,6 +38,12 @@ async def single_request_response(payload, composite_metadata): return create_future(Payload(b'single_response')) +@router.response('large') +async def single_request_response(payload, composite_metadata): + logging.info('Got single request') + return create_future(Payload(large_data)) + + @router.response('last_fnf') async def get_last_fnf(): logging.info('Got single request') @@ -102,9 +112,9 @@ def handler_factory(): def client_handler_factory(on_ready=None): def handle_client(reader, writer): - RSocketServer(TransportTCP(reader, writer), + RSocketServer(TransportTCP(reader, writer, read_buffer_size=data_size + 3000), handler_factory=handler_factory, - fragment_size_bytes=64_000, + # fragment_size_bytes=64_000, on_ready=on_ready ) diff --git a/performance/test_performance.py b/performance/test_performance.py index 375444c3..227993a6 100644 --- a/performance/test_performance.py +++ b/performance/test_performance.py @@ -8,6 +8,7 @@ from performance.performance_client import PerformanceClient from performance.performance_server import run_server from rsocket.rsocket_server import RSocketServer +from tests.tools.helpers import measure_time @pytest.mark.timeout(5) @@ -31,6 +32,13 @@ async def test_request_stream(unused_tcp_port): assert result is not None +@pytest.mark.performance +async def test_large_request(): + async with run_with_client(6565) as client: + result = await measure_time(client.large_request()) + print(result.delta) + + @asynccontextmanager async def run_against_server(unused_tcp_port: int) -> PerformanceClient: server_ready = asyncio.Event()