From 7348fb9460313955313bdc8c347245b1aaa87508 Mon Sep 17 00:00:00 2001 From: "gabis@precog.co" Date: Mon, 14 Mar 2022 21:40:49 +0200 Subject: [PATCH 01/10] initial implementation reconnect support --- rsocket/awaitable/awaitable_rsocket.py | 4 +- .../load_balancer/load_balancer_rsocket.py | 6 +- .../load_balancer/load_balancer_strategy.py | 2 +- rsocket/load_balancer/random_client.py | 4 +- rsocket/load_balancer/round_robin.py | 4 +- rsocket/request_handler.py | 9 +- rsocket/rsocket.py | 2 +- rsocket/rsocket_base.py | 86 ++++++++++++------- rsocket/rsocket_client.py | 27 ++++-- rsocket/rsocket_internal.py | 4 + rsocket/rsocket_server.py | 37 ++++++++ rsocket/rx_support/rx_rsocket.py | 4 +- rsocket/transports/aiohttp_websocket.py | 46 ++++++++-- rsocket/transports/quart_websocket.py | 3 + rsocket/transports/tcp.py | 3 + rsocket/transports/transport.py | 4 + tests/conftest.py | 23 +++-- tests/rsocket/test_load_balancer.py | 14 ++- tests/rsocket/test_resume_unsupported.py | 5 +- tests/rsocket/test_rsocket.py | 6 +- tests/rx_support/test_rx_support.py | 6 +- 21 files changed, 223 insertions(+), 76 deletions(-) diff --git a/rsocket/awaitable/awaitable_rsocket.py b/rsocket/awaitable/awaitable_rsocket.py index f09af8e5..bfa2d57e 100644 --- a/rsocket/awaitable/awaitable_rsocket.py +++ b/rsocket/awaitable/awaitable_rsocket.py @@ -47,8 +47,8 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): await self._rsocket.__aexit__(exc_type, exc_val, exc_tb) - def connect(self): - return self._rsocket.connect() + async def connect(self): + return await self._rsocket.connect() def close(self): self._rsocket.close() diff --git a/rsocket/load_balancer/load_balancer_rsocket.py b/rsocket/load_balancer/load_balancer_rsocket.py index d5e8511e..c4e38caa 100644 --- a/rsocket/load_balancer/load_balancer_rsocket.py +++ b/rsocket/load_balancer/load_balancer_rsocket.py @@ -30,14 +30,14 @@ def request_stream(self, payload: Payload) -> Union[BackpressureApi, Publisher]: def metadata_push(self, metadata: bytes): self._select_client().metadata_push(metadata) - def connect(self): - self._strategy.connect() + async def connect(self): + await self._strategy.connect() async def close(self): await self._strategy.close() async def __aenter__(self) -> RSocket: - self._strategy.connect() + await self._strategy.connect() return self async def __aexit__(self, exc_type, exc_val, exc_tb): diff --git a/rsocket/load_balancer/load_balancer_strategy.py b/rsocket/load_balancer/load_balancer_strategy.py index 152adfdd..a707775e 100644 --- a/rsocket/load_balancer/load_balancer_strategy.py +++ b/rsocket/load_balancer/load_balancer_strategy.py @@ -9,7 +9,7 @@ def select(self) -> RSocket: ... @abc.abstractmethod - def connect(self): + async def connect(self): ... @abc.abstractmethod diff --git a/rsocket/load_balancer/random_client.py b/rsocket/load_balancer/random_client.py index 9ac6d4f5..16a07b93 100644 --- a/rsocket/load_balancer/random_client.py +++ b/rsocket/load_balancer/random_client.py @@ -20,9 +20,9 @@ def select(self) -> RSocket: random_client_id = random.randint(0, len(self._pool)) return self._pool[random_client_id] - def connect(self): + async def connect(self): if self._auto_connect: - [client.connect() for client in self._pool] + [await client.connect() for client in self._pool] async def close(self): if self._auto_close: diff --git a/rsocket/load_balancer/round_robin.py b/rsocket/load_balancer/round_robin.py index 5b7681be..c8ee1d8c 100644 --- a/rsocket/load_balancer/round_robin.py +++ b/rsocket/load_balancer/round_robin.py @@ -20,9 +20,9 @@ def select(self) -> RSocket: self._current_index = (self._current_index + 1) % len(self._pool) return client - def connect(self): + async def connect(self): if self._auto_connect: - [client.connect() for client in self._pool] + [await client.connect() for client in self._pool] async def close(self): if self._auto_close: diff --git a/rsocket/request_handler.py b/rsocket/request_handler.py index e364f3cc..f4166f16 100644 --- a/rsocket/request_handler.py +++ b/rsocket/request_handler.py @@ -58,7 +58,11 @@ async def on_error(self, error_code: ErrorCode, payload: Payload): @abstractmethod async def on_keepalive_timeout(self, time_since_last_keepalive: timedelta, - cancel_all_streams: Callable): + rsocket): + ... + + @abstractmethod + async def on_connection_lost(self, rsocket): ... def _parse_composite_metadata(self, metadata: bytes) -> CompositeMetadata: @@ -93,6 +97,9 @@ async def request_stream(self, payload: Payload) -> Publisher: async def on_error(self, error_code: ErrorCode, payload: Payload): logger().error('Error: %s, %s', error_code, payload) + async def on_connection_lost(self, rsocket): + pass + async def on_keepalive_timeout(self, time_since_last_keepalive: timedelta, cancel_all_streams: Callable): diff --git a/rsocket/rsocket.py b/rsocket/rsocket.py index 78ef4448..b3a80b6b 100644 --- a/rsocket/rsocket.py +++ b/rsocket/rsocket.py @@ -33,7 +33,7 @@ def metadata_push(self, metadata: bytes): ... @abc.abstractmethod - def connect(self): + async def connect(self): ... @abc.abstractmethod diff --git a/rsocket/rsocket_base.py b/rsocket/rsocket_base.py index 559495e9..a19b9c2c 100644 --- a/rsocket/rsocket_base.py +++ b/rsocket/rsocket_base.py @@ -10,11 +10,15 @@ from rsocket.error_codes import ErrorCode from rsocket.exceptions import RSocketProtocolException from rsocket.extensions.mimetypes import WellKnownMimeTypes -from rsocket.frame import KeepAliveFrame, \ - MetadataPushFrame, RequestFireAndForgetFrame, RequestResponseFrame, \ - RequestStreamFrame, Frame, exception_to_error_frame, LeaseFrame, ErrorFrame, RequestFrame, \ - initiate_request_frame_types, InvalidFrame, FragmentableFrame -from rsocket.frame import RequestChannelFrame, ResumeFrame, is_fragmentable_frame, CONNECTION_STREAM_ID +from rsocket.frame import (KeepAliveFrame, + MetadataPushFrame, RequestFireAndForgetFrame, + RequestResponseFrame, RequestStreamFrame, Frame, + exception_to_error_frame, + LeaseFrame, ErrorFrame, RequestFrame, + initiate_request_frame_types, InvalidFrame, + FragmentableFrame) +from rsocket.frame import (RequestChannelFrame, ResumeFrame, + is_fragmentable_frame, CONNECTION_STREAM_ID) from rsocket.frame import SetupFrame from rsocket.frame_builders import to_payload_frame, to_fire_and_forget_frame from rsocket.frame_fragment_cache import FrameFragmentCache @@ -56,7 +60,6 @@ def on_next(self, value, is_complete=False): self._socket.send_lease(value) def __init__(self, - transport: Transport, handler_factory: Type[RequestHandler] = BaseRequestHandler, honor_lease=False, lease_publisher: Optional[Publisher] = None, @@ -68,33 +71,17 @@ def __init__(self, setup_payload: Optional[Payload] = None ): - self._transport = transport - self._handler = handler_factory(self) - self._stream_control = StreamControl(self._get_first_stream_id()) + self._handler_factory = handler_factory + self._request_queue_size = request_queue_size self._honor_lease = honor_lease self._max_lifetime_period = max_lifetime_period self._keep_alive_period = keep_alive_period self._setup_payload = setup_payload - self._data_encoding = self._ensure_encoding_name(data_encoding) self._metadata_encoding = self._ensure_encoding_name(metadata_encoding) - self._is_closing = False - - if self._honor_lease: - self._requester_lease = DefinedLease(maximum_request_count=0) - else: - self._requester_lease = NullLease() - - self._responder_lease = NullLease() self._lease_publisher = lease_publisher - - self._frame_fragment_cache = FrameFragmentCache() - - self._send_queue = asyncio.Queue() - self._request_queue = asyncio.Queue(request_queue_size) - - self._receiver_task = self._start_task_if_not_closing(self._receiver) - self._sender_task = self._start_task_if_not_closing(self._sender) + self._sender_task = None + self._receiver_task = None self._async_frame_handler_by_type: Dict[Type[Frame], Any] = { RequestResponseFrame: self.handle_request_response, @@ -109,7 +96,38 @@ def __init__(self, ErrorFrame: self.handle_error } - def connect(self): + self._setup_internals() + + def _setup_internals(self): + pass + + @abc.abstractmethod + def _current_transport(self) -> Transport: + ... + + def _reset_internals(self): + self._frame_fragment_cache = FrameFragmentCache() + self._send_queue = asyncio.Queue() + self._request_queue = asyncio.Queue(self._request_queue_size) + + if self._honor_lease: + self._requester_lease = DefinedLease(maximum_request_count=0) + else: + self._requester_lease = NullLease() + + self._responder_lease = NullLease() + self._stream_control = StreamControl(self._get_first_stream_id()) + self._handler = self._handler_factory(self) + self._is_closing = False + + def close_all_streams(self): + self._stream_control.cancel_all_handlers() + + def _start_tasks(self): + self._receiver_task = self._start_task_if_not_closing(self._receiver) + self._sender_task = self._start_task_if_not_closing(self._sender) + + async def connect(self): logger().debug('%s: sending setup frame', self._log_identifier()) self.send_frame(self._create_setup_frame(self._data_encoding, @@ -131,6 +149,9 @@ def _start_task_if_not_closing(self, task_factory: Callable[[], Coroutine]) -> O if not self._is_closing: return asyncio.create_task(task_factory()) + def set_handler_factory(self, handler_factory): + self._handler_factory = handler_factory + def set_handler_using_factory(self, handler_factory) -> RequestHandler: self._handler = handler_factory(self) return self._handler @@ -321,7 +342,7 @@ async def _receiver_listen(self): while self.is_server_alive(): - next_frame_generator = await self._transport.next_frame_generator(self.is_server_alive()) + next_frame_generator = await self._current_transport().next_frame_generator(self.is_server_alive()) if next_frame_generator is None: break async for frame in next_frame_generator: @@ -382,11 +403,11 @@ async def _sender(self): while self.is_server_alive(): frame = await self._send_queue.get() - await self._transport.send_frame(frame) + await self._current_transport().send_frame(frame) self._send_queue.task_done() if self._send_queue.empty(): - await self._transport.on_send_queue_empty() + await self._current_transport().on_send_queue_empty() except ConnectionResetError as exception: logger().debug(str(exception)) except asyncio.CancelledError: @@ -401,7 +422,10 @@ async def close(self): self._is_closing = True await self._cancel_if_task_exists(self._sender_task) await self._cancel_if_task_exists(self._receiver_task) - await self._transport.close() + transport = self._current_transport() + + if transport is not None: + await transport.close() async def _cancel_if_task_exists(self, task): if task is not None: diff --git a/rsocket/rsocket_client.py b/rsocket/rsocket_client.py index df811aa2..7cf927df 100644 --- a/rsocket/rsocket_client.py +++ b/rsocket/rsocket_client.py @@ -1,6 +1,6 @@ import asyncio from datetime import timedelta, datetime -from typing import Optional, Type +from typing import Optional, Type, Callable, Coroutine from typing import Union from reactivestreams.publisher import Publisher @@ -16,7 +16,7 @@ class RSocketClient(RSocketBase): def __init__(self, - transport: Transport, + transport_provider: Callable[[], Coroutine[None, None, Transport]], handler_factory: Type[RequestHandler] = BaseRequestHandler, honor_lease=False, lease_publisher: Optional[Publisher] = None, @@ -25,13 +25,14 @@ def __init__(self, metadata_encoding: Union[bytes, WellKnownMimeTypes] = WellKnownMimeTypes.APPLICATION_JSON, keep_alive_period: timedelta = timedelta(milliseconds=500), max_lifetime_period: timedelta = timedelta(minutes=10), - setup_payload: Optional[Payload] = None + setup_payload: Optional[Payload] = None, ): + self._transport_provider = transport_provider self._is_server_alive = True self._update_last_keepalive() + self._transport: Optional[Transport] = None - super().__init__(transport, - handler_factory=handler_factory, + super().__init__(handler_factory=handler_factory, honor_lease=honor_lease, lease_publisher=lease_publisher, request_queue_size=request_queue_size, @@ -41,11 +42,23 @@ def __init__(self, max_lifetime_period=max_lifetime_period, setup_payload=setup_payload) + def _current_transport(self) -> Transport: + return self._transport + def _log_identifier(self) -> str: return 'client' + async def connect(self): + await self.close() + self._is_closing = False + self._transport = await self._transport_provider() + await self._transport.connect() + self._reset_internals() + self._start_tasks() + return await super().connect() + async def __aenter__(self) -> 'RSocketClient': - self.connect() + await self.connect() return self def _get_first_stream_id(self) -> int: @@ -82,7 +95,7 @@ async def _keepalive_timeout_task(self): self._is_server_alive = False await self._handler.on_keepalive_timeout( time_since_last_keepalive, - self._stream_control.cancel_all_handlers + self ) except asyncio.CancelledError: logger().debug('%s: Asyncio task canceled: keepalive_timeout', self._log_identifier()) diff --git a/rsocket/rsocket_internal.py b/rsocket/rsocket_internal.py index bdd38430..091a4085 100644 --- a/rsocket/rsocket_internal.py +++ b/rsocket/rsocket_internal.py @@ -28,3 +28,7 @@ def send_payload(self, stream_id: int, payload: Payload, complete=False, is_next @abc.abstractmethod def send_error(self, stream_id: int, exception: Exception): ... + + @abc.abstractmethod + def close_all_streams(self): + ... diff --git a/rsocket/rsocket_server.py b/rsocket/rsocket_server.py index e62cd346..2709687c 100644 --- a/rsocket/rsocket_server.py +++ b/rsocket/rsocket_server.py @@ -1,8 +1,45 @@ +from datetime import timedelta +from typing import Type, Optional, Union + +from reactivestreams.publisher import Publisher +from rsocket.extensions.mimetypes import WellKnownMimeTypes +from rsocket.payload import Payload +from rsocket.request_handler import RequestHandler, BaseRequestHandler from rsocket.rsocket_base import RSocketBase +from rsocket.transports.transport import Transport class RSocketServer(RSocketBase): + def __init__(self, + transport: Transport, + handler_factory: Type[RequestHandler] = BaseRequestHandler, + honor_lease=False, + lease_publisher: Optional[Publisher] = None, + request_queue_size: int = 0, + data_encoding: Union[str, bytes, WellKnownMimeTypes] = WellKnownMimeTypes.APPLICATION_JSON, + metadata_encoding: Union[str, bytes, WellKnownMimeTypes] = WellKnownMimeTypes.APPLICATION_JSON, + keep_alive_period: timedelta = timedelta(milliseconds=500), + max_lifetime_period: timedelta = timedelta(minutes=10), + setup_payload: Optional[Payload] = None): + super().__init__(handler_factory, + honor_lease, + lease_publisher, + request_queue_size, + data_encoding, + metadata_encoding, + keep_alive_period, + max_lifetime_period, + setup_payload) + self._transport = transport + + def _current_transport(self) -> Transport: + return self._transport + + def _setup_internals(self): + self._reset_internals() + self._start_tasks() + def _log_identifier(self) -> str: return 'server' diff --git a/rsocket/rx_support/rx_rsocket.py b/rsocket/rx_support/rx_rsocket.py index f394c970..c3073ac8 100644 --- a/rsocket/rx_support/rx_rsocket.py +++ b/rsocket/rx_support/rx_rsocket.py @@ -41,8 +41,8 @@ def fire_and_forget(self, request: Payload): def metadata_push(self, metadata: bytes): self._rsocket.metadata_push(metadata) - def connect(self): - return self._rsocket.connect() + async def connect(self): + return await self._rsocket.connect() async def close(self): await self._rsocket.close() diff --git a/rsocket/transports/aiohttp_websocket.py b/rsocket/transports/aiohttp_websocket.py index bba0793d..7507f3e8 100644 --- a/rsocket/transports/aiohttp_websocket.py +++ b/rsocket/transports/aiohttp_websocket.py @@ -13,15 +13,11 @@ @asynccontextmanager async def websocket_client(url, *args, **kwargs) -> RSocketClient: - async with aiohttp.ClientSession() as session: - async with session.ws_connect(url) as ws: - transport = TransportAioHttpWebsocket(ws) - message_handler = asyncio.create_task(transport.handle_incoming_ws_messages()) - async with RSocketClient(transport, *args, **kwargs) as client: - yield client + async def transport_provider(): + return TransportAioHttpClient(url) - message_handler.cancel() - await message_handler + async with RSocketClient(transport_provider, *args, **kwargs) as client: + yield client def websocket_handler_factory(*args, on_server_create=None, **kwargs): @@ -40,11 +36,45 @@ async def websocket_handler(request): return websocket_handler +class TransportAioHttpClient(AbstractWebsocketTransport): + + def __init__(self, url): + super().__init__() + self._url = url + + async def connect(self): + self._session = aiohttp.ClientSession() + self._ws_context = self._session.ws_connect(self._url) + self._ws = await self._ws_context.__aenter__() + self._message_handler = asyncio.create_task(self.handle_incoming_ws_messages()) + + async def handle_incoming_ws_messages(self): + try: + async for msg in self._ws: + if msg.type == aiohttp.WSMsgType.BINARY: + async for frame in self._frame_parser.receive_data(msg.data, 0): + self._incoming_frame_queue.put_nowait(frame) + except asyncio.CancelledError: + logger().debug('Asyncio task canceled: aiohttp_handle_incoming_ws_messages') + + async def send_frame(self, frame: Frame): + await self._ws.send_bytes(frame.serialize()) + + async def close(self): + await self._ws_context.__aexit__(None, None, None) + await self._session.__aexit__(None, None, None) + self._message_handler.cancel() + await self._message_handler + + class TransportAioHttpWebsocket(AbstractWebsocketTransport): def __init__(self, websocket): super().__init__() self._ws = websocket + async def connect(self): + pass + async def handle_incoming_ws_messages(self): try: async for msg in self._ws: diff --git a/rsocket/transports/quart_websocket.py b/rsocket/transports/quart_websocket.py index 7fdf0a65..3cefc935 100644 --- a/rsocket/transports/quart_websocket.py +++ b/rsocket/transports/quart_websocket.py @@ -20,6 +20,9 @@ async def websocket_handler(*args, on_server_create=None, **kwargs): class TransportQuartWebsocket(AbstractWebsocketTransport): + async def connect(self): + pass + async def handle_incoming_ws_messages(self): try: while True: diff --git a/rsocket/transports/tcp.py b/rsocket/transports/tcp.py index f99a16bf..17ff5697 100644 --- a/rsocket/transports/tcp.py +++ b/rsocket/transports/tcp.py @@ -11,6 +11,9 @@ def __init__(self, reader: StreamReader, writer: StreamWriter): self._writer = writer self._reader = reader + async def connect(self): + pass + async def send_frame(self, frame: Frame): self._writer.write(serialize_with_frame_size_header(frame)) diff --git a/rsocket/transports/transport.py b/rsocket/transports/transport.py index 29b6badf..030b486d 100644 --- a/rsocket/transports/transport.py +++ b/rsocket/transports/transport.py @@ -9,6 +9,10 @@ class Transport(metaclass=abc.ABCMeta): def __init__(self): self._frame_parser = FrameParser() + @abc.abstractmethod + async def connect(self): + ... + @abc.abstractmethod async def send_frame(self, frame: Frame): ... diff --git a/tests/conftest.py b/tests/conftest.py index d93a762c..b715294a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -95,17 +95,21 @@ def session(*connection): async def start(): nonlocal service, client service = await asyncio.start_server(session, host, port) - connection = await asyncio.open_connection(host, port) nonlocal client_arguments # test_overrides = {'keep_alive_period': timedelta(minutes=20)} client_arguments = client_arguments or {} + # client_arguments.update(test_overrides) - client = RSocketClient(TransportTCP(*connection), **(client_arguments or {})) + async def transport_provider(): + connection = await asyncio.open_connection(host, port) + return TransportTCP(*connection) + + client = RSocketClient(transport_provider, **(client_arguments or {})) if auto_connect_client: - client.connect() + await client.connect() async def finish(): if auto_connect_client: @@ -122,9 +126,18 @@ async def finish(): host = 'localhost' await start() - await wait_for_server.wait() + + async def server_provider(): + await wait_for_server.wait() + return server + try: - yield server, client + if auto_connect_client: + await wait_for_server.wait() + yield server, client + else: + yield server_provider, client + assert_no_open_streams(client, server) finally: await finish() diff --git a/tests/rsocket/test_load_balancer.py b/tests/rsocket/test_load_balancer.py index 7ef62f84..91e5667c 100644 --- a/tests/rsocket/test_load_balancer.py +++ b/tests/rsocket/test_load_balancer.py @@ -9,6 +9,14 @@ from tests.rsocket.helpers import future_from_request +class HandlerFactory(): + def __init__(self, server_id: int): + self._server_id = server_id + + def factory(self, socket): + return Handler(socket, self._server_id) + + class Handler(BaseRequestHandler): def __init__(self, socket, server_id: int): super().__init__(socket) @@ -19,7 +27,6 @@ async def request_response(self, request: Payload): async def test_load_balancer_round_robin(unused_tcp_port_factory): - servers = [] clients = [] server_count = 3 request_count = 7 @@ -27,11 +34,10 @@ async def test_load_balancer_round_robin(unused_tcp_port_factory): async with AsyncExitStack() as stack: for i in range(server_count): tcp_port = unused_tcp_port_factory() - server, client = await stack.enter_async_context( + _, client = await stack.enter_async_context( pipe_factory_tcp(tcp_port, - server_arguments={'handler_factory': lambda socket: Handler(socket, i)}, + server_arguments={'handler_factory': HandlerFactory(i).factory}, auto_connect_client=False)) - servers.append(server) clients.append(client) round_robin = LoadBalancerRoundRobin(clients) diff --git a/tests/rsocket/test_resume_unsupported.py b/tests/rsocket/test_resume_unsupported.py index a4808e1c..366848c1 100644 --- a/tests/rsocket/test_resume_unsupported.py +++ b/tests/rsocket/test_resume_unsupported.py @@ -15,7 +15,7 @@ @pytest.mark.allow_error_log async def test_setup_resume_unsupported(pipe_tcp_without_auto_connect: Tuple[RSocketServer, RSocketClient]): - server, client = pipe_tcp_without_auto_connect + _, client = pipe_tcp_without_auto_connect received_error_code = None error_received = asyncio.Event() @@ -25,7 +25,8 @@ async def on_error(self, error_code: ErrorCode, payload: Payload): received_error_code = error_code error_received.set() - client.set_handler_using_factory(Handler) + client.set_handler_factory(Handler) + await client.connect() bad_client = MisbehavingRSocket(client._transport) setup = SetupFrame() diff --git a/tests/rsocket/test_rsocket.py b/tests/rsocket/test_rsocket.py index 4dad0e85..9932fbf3 100644 --- a/tests/rsocket/test_rsocket.py +++ b/tests/rsocket/test_rsocket.py @@ -1,7 +1,6 @@ import asyncio import logging from datetime import timedelta -from typing import Callable import pytest @@ -10,6 +9,7 @@ from rsocket.helpers import create_future from rsocket.payload import Payload from rsocket.request_handler import BaseRequestHandler +from rsocket.rsocket_internal import RSocketInternal async def test_rsocket_client_closed_without_requests(lazy_pipe): @@ -46,8 +46,8 @@ class ClientHandler(BaseRequestHandler): async def on_keepalive_timeout(self, time_since_last_keepalive: timedelta, - cancel_all_streams: Callable): - cancel_all_streams() + socket: RSocketInternal): + socket.close_all_streams() async with lazy_pipe( client_arguments={ diff --git a/tests/rx_support/test_rx_support.py b/tests/rx_support/test_rx_support.py index 820fc0a3..87376b82 100644 --- a/tests/rx_support/test_rx_support.py +++ b/tests/rx_support/test_rx_support.py @@ -172,10 +172,12 @@ class Handler(BaseRequestHandler): async def request_response(self, payload: Payload) -> Future: return create_future(Payload(b'Response')) - server, client = pipe_tcp_without_auto_connect - server.set_handler_using_factory(Handler) + server_provider, client = pipe_tcp_without_auto_connect async with RxRSocket(client) as rx_client: + server = await server_provider() + server.set_handler_using_factory(Handler) + received_message = await rx_client.request_response(Payload(b'request text')).pipe( operators.map(lambda payload: payload.data), operators.single() From 403403f9b3192c89119d33856ef2ee05350801c3 Mon Sep 17 00:00:00 2001 From: gabi Date: Tue, 15 Mar 2022 13:11:53 +0200 Subject: [PATCH 02/10] work in progress reconnect flows --- rsocket/exceptions.py | 10 ++- rsocket/frame.py | 8 +- rsocket/helpers.py | 18 +++- rsocket/request_handler.py | 10 +-- rsocket/rsocket_base.py | 115 ++++++++++++++++-------- rsocket/rsocket_client.py | 33 +++++-- rsocket/rsocket_server.py | 15 +++- rsocket/streams/stream_handler.py | 4 +- rsocket/transports/aiohttp_websocket.py | 26 +++--- rsocket/transports/tcp.py | 19 ++-- tests/conftest.py | 12 +-- tests/rsocket/helpers.py | 31 ++++++- tests/rsocket/test_connection_lost.py | 82 +++++++++++++++++ tests/rsocket/test_frame.py | 4 +- tests/rsocket/test_lease.py | 16 ++-- tests/rsocket/test_load_balancer.py | 20 +---- tests/rsocket/test_request_response.py | 6 +- tests/rsocket/test_request_stream.py | 4 +- tests/rsocket/test_rsocket.py | 4 +- 19 files changed, 304 insertions(+), 133 deletions(-) create mode 100644 tests/rsocket/test_connection_lost.py diff --git a/rsocket/exceptions.py b/rsocket/exceptions.py index d6a033a0..006e9561 100644 --- a/rsocket/exceptions.py +++ b/rsocket/exceptions.py @@ -34,11 +34,11 @@ class RSocketStreamAllocationFailure(RSocketError): pass -class RSocketValueErrorException(RSocketError): +class RSocketValueError(RSocketError): pass -class RSocketProtocolException(RSocketError): +class RSocketProtocolError(RSocketError): def __init__(self, error_code: ErrorCode, data: Optional[str] = None): self.error_code = error_code self.data = data @@ -47,7 +47,7 @@ def __str__(self) -> str: return 'RSocket error %s(%s): "%s"' % (self.error_code.name, self.error_code.value, self.data or '') -class RSocketStreamIdInUse(RSocketProtocolException): +class RSocketStreamIdInUse(RSocketProtocolError): def __init__(self, stream_id: int): super().__init__(ErrorCode.REJECTED) @@ -56,3 +56,7 @@ def __init__(self, stream_id: int): class RSocketFrameFragmentDifferentType(RSocketError): pass + + +class RSocketTransportError(RSocketError): + pass diff --git a/rsocket/frame.py b/rsocket/frame.py index 45fae209..72695db4 100644 --- a/rsocket/frame.py +++ b/rsocket/frame.py @@ -5,7 +5,7 @@ from typing import Tuple, Optional from rsocket.error_codes import ErrorCode -from rsocket.exceptions import RSocketProtocolException, ParseError, RSocketUnknownFrameType +from rsocket.exceptions import RSocketProtocolError, ParseError, RSocketUnknownFrameType from rsocket.frame_helpers import is_flag_set, unpack_position, pack_position, unpack_24bit, pack_24bit, unpack_32bit, \ ensure_bytes @@ -626,7 +626,7 @@ def parse_or_ignore(buffer: bytes) -> Optional[Frame]: return frame except Exception as exception: if not header.flags_ignore: - raise RSocketProtocolException(ErrorCode.CONNECTION_ERROR, str(exception)) from exception + raise RSocketProtocolError(ErrorCode.CONNECTION_ERROR, str(exception)) from exception def is_fragmentable_frame(frame: Frame) -> bool: @@ -643,7 +643,7 @@ def exception_to_error_frame(stream_id: int, exception: Exception) -> ErrorFrame frame = ErrorFrame() frame.stream_id = stream_id - if isinstance(exception, RSocketProtocolException): + if isinstance(exception, RSocketProtocolError): frame.error_code = exception.error_code frame.data = ensure_bytes(exception.data) else: @@ -655,7 +655,7 @@ def exception_to_error_frame(stream_id: int, exception: Exception) -> ErrorFrame def error_frame_to_exception(frame: ErrorFrame) -> Exception: if frame.error_code != ErrorCode.APPLICATION_ERROR: - return RSocketProtocolException(frame.error_code, data=frame.data.decode()) + return RSocketProtocolError(frame.error_code, data=frame.data.decode()) return RuntimeError(frame.data.decode('utf-8')) diff --git a/rsocket/helpers.py b/rsocket/helpers.py index 777029bd..7d5d655d 100644 --- a/rsocket/helpers.py +++ b/rsocket/helpers.py @@ -1,20 +1,22 @@ import asyncio -from typing import Optional +from contextlib import contextmanager +from typing import Optional, Any from reactivestreams.publisher import DefaultPublisher from reactivestreams.subscriber import Subscriber from reactivestreams.subscription import DefaultSubscription +from rsocket.exceptions import RSocketTransportError from rsocket.frame import Frame from rsocket.payload import Payload _default = object() -def create_future(payload: Optional[Payload] = _default) -> asyncio.Future: +def create_future(value: Optional[Any] = _default) -> asyncio.Future: future = asyncio.get_event_loop().create_future() - if payload is not _default: - future.set_result(payload) + if value is not _default: + future.set_result(value) return future @@ -50,3 +52,11 @@ def __eq__(self, other): def __hash__(self): return hash((self.id, self.name)) + + +@contextmanager +def wrap_transport_exception(): + try: + yield + except Exception as exception: + raise RSocketTransportError from exception diff --git a/rsocket/request_handler.py b/rsocket/request_handler.py index f4166f16..ad7da3e6 100644 --- a/rsocket/request_handler.py +++ b/rsocket/request_handler.py @@ -2,7 +2,7 @@ from abc import ABCMeta, abstractmethod from asyncio import Future from datetime import timedelta -from typing import Tuple, Optional, Callable +from typing import Tuple, Optional from reactivestreams.publisher import Publisher from reactivestreams.subscriber import Subscriber @@ -62,7 +62,7 @@ async def on_keepalive_timeout(self, ... @abstractmethod - async def on_connection_lost(self, rsocket): + async def on_connection_lost(self, rsocket, exception): ... def _parse_composite_metadata(self, metadata: bytes) -> CompositeMetadata: @@ -97,10 +97,10 @@ async def request_stream(self, payload: Payload) -> Publisher: async def on_error(self, error_code: ErrorCode, payload: Payload): logger().error('Error: %s, %s', error_code, payload) - async def on_connection_lost(self, rsocket): - pass + async def on_connection_lost(self, rsocket, exception: Exception): + await rsocket.close() async def on_keepalive_timeout(self, time_since_last_keepalive: timedelta, - cancel_all_streams: Callable): + rsocket): pass diff --git a/rsocket/rsocket_base.py b/rsocket/rsocket_base.py index a19b9c2c..8bca69cc 100644 --- a/rsocket/rsocket_base.py +++ b/rsocket/rsocket_base.py @@ -1,6 +1,6 @@ import abc import asyncio -from asyncio import Future, QueueEmpty, Task +from asyncio import Future, QueueEmpty, Task, Event from datetime import timedelta from typing import Union, Optional, Dict, Any, Coroutine, Callable, Type, cast, TypeVar @@ -8,7 +8,7 @@ from reactivestreams.subscriber import DefaultSubscriber from rsocket.datetime_helpers import to_milliseconds from rsocket.error_codes import ErrorCode -from rsocket.exceptions import RSocketProtocolException +from rsocket.exceptions import RSocketProtocolError, RSocketTransportError from rsocket.extensions.mimetypes import WellKnownMimeTypes from rsocket.frame import (KeepAliveFrame, MetadataPushFrame, RequestFireAndForgetFrame, @@ -40,7 +40,6 @@ from rsocket.stream_control import StreamControl from rsocket.streams.backpressureapi import BackpressureApi from rsocket.streams.stream_handler import StreamHandler -from rsocket.transports.transport import Transport async def noop_frame_handler(frame): @@ -60,7 +59,7 @@ def on_next(self, value, is_complete=False): self._socket.send_lease(value) def __init__(self, - handler_factory: Type[RequestHandler] = BaseRequestHandler, + handler_factory: Callable[['RSocketBase'], RequestHandler] = BaseRequestHandler, honor_lease=False, lease_publisher: Optional[Publisher] = None, request_queue_size: int = 0, @@ -82,6 +81,7 @@ def __init__(self, self._lease_publisher = lease_publisher self._sender_task = None self._receiver_task = None + self._transport_ready = Event() self._async_frame_handler_by_type: Dict[Type[Frame], Any] = { RequestResponseFrame: self.handle_request_response, @@ -102,7 +102,7 @@ def _setup_internals(self): pass @abc.abstractmethod - def _current_transport(self) -> Transport: + def _current_transport(self) -> Future: ... def _reset_internals(self): @@ -127,16 +127,22 @@ def _start_tasks(self): self._receiver_task = self._start_task_if_not_closing(self._receiver) self._sender_task = self._start_task_if_not_closing(self._sender) + def _peek_at_queue_next_item(self, queue) -> Optional[Frame]: + if len(queue._queue) > 0: + return queue._queue[0] + return None + async def connect(self): logger().debug('%s: sending setup frame', self._log_identifier()) - self.send_frame(self._create_setup_frame(self._data_encoding, - self._metadata_encoding, - self._setup_payload)) + self.send_priority_frame(self._create_setup_frame(self._data_encoding, + self._metadata_encoding, + self._setup_payload)) if self._honor_lease: self._subscribe_to_lease_publisher() + self._transport_ready.set() return self def _ensure_encoding_name(self, encoding) -> bytes: @@ -177,6 +183,15 @@ def _queue_request_frame(self, frame: RequestFrame): self._request_queue.put_nowait(frame) + def send_priority_frame(self, frame: Frame): + items = [] + while not self._send_queue.empty(): + items.append(self._send_queue.get_nowait()) + + self._send_queue.put_nowait(frame) + for item in items: + self._send_queue.put_nowait(item) + def send_frame(self, frame: Frame): self._send_queue.put_nowait(frame) @@ -251,11 +266,11 @@ async def handle_request_stream(self, frame: RequestStreamFrame): async def handle_setup(self, frame: SetupFrame): if frame.flags_resume: - raise RSocketProtocolException(ErrorCode.UNSUPPORTED_SETUP, data='Resume not supported') + raise RSocketProtocolError(ErrorCode.UNSUPPORTED_SETUP, data='Resume not supported') if frame.flags_lease: if self._lease_publisher is None: - raise RSocketProtocolException(ErrorCode.UNSUPPORTED_SETUP, data='Lease not available') + raise RSocketProtocolError(ErrorCode.UNSUPPORTED_SETUP, data='Lease not available') else: self._subscribe_to_lease_publisher() @@ -267,7 +282,7 @@ async def handle_setup(self, frame: SetupFrame): payload_from_frame(frame)) except Exception as exception: logger().error('%s: Setup error', self._log_identifier(), exc_info=True) - raise RSocketProtocolException(ErrorCode.REJECTED_SETUP, data=str(exception)) from exception + raise RSocketProtocolError(ErrorCode.REJECTED_SETUP, data=str(exception)) from exception def _subscribe_to_lease_publisher(self): if self._lease_publisher is not None: @@ -308,7 +323,7 @@ async def handle_request_channel(self, frame: RequestChannelFrame): channel_responder.frame_received(frame) async def handle_resume(self, frame: ResumeFrame): - raise RSocketProtocolException(ErrorCode.REJECTED_RESUME, data='Resume not supported') + raise RSocketProtocolError(ErrorCode.REJECTED_RESUME, data='Resume not supported') async def handle_lease(self, frame: LeaseFrame): logger().debug('%s: received lease frame', self._log_identifier()) @@ -327,12 +342,20 @@ async def handle_lease(self, frame: LeaseFrame): async def _receiver(self): try: + await self._transport_ready.wait() await self._receiver_listen() except asyncio.CancelledError: logger().debug('%s: Asyncio task canceled: receiver', self._log_identifier()) - except Exception: + except RSocketTransportError as exception: + await self._on_connection_lost(exception) + except Exception as exception: logger().error('%s: Unknown error', self._log_identifier(), exc_info=True) - raise + await self._on_connection_lost(exception) + + async def _on_connection_lost(self, exception: Exception): + logger().debug(str(exception)) + self._transport_ready.clear() + await self._handler.on_connection_lost(self, exception) @abc.abstractmethod def is_server_alive(self) -> bool: @@ -340,17 +363,19 @@ def is_server_alive(self) -> bool: async def _receiver_listen(self): + transport = await self._current_transport() while self.is_server_alive(): - - next_frame_generator = await self._current_transport().next_frame_generator(self.is_server_alive()) + next_frame_generator = await transport.next_frame_generator(self.is_server_alive()) if next_frame_generator is None: break async for frame in next_frame_generator: try: await self._handle_next_frame(frame) - except RSocketProtocolException as exception: - logger().error('%s: RSocket Error %s', self._log_identifier(), str(exception)) + except RSocketProtocolError as exception: + logger().error('%s: Protocol Error %s', self._log_identifier(), str(exception)) self.send_error(frame.stream_id, exception) + except RSocketTransportError: + raise except Exception as exception: logger().error('%s: Unknown Error', self._log_identifier(), exc_info=True) self.send_error(frame.stream_id, exception) @@ -397,35 +422,52 @@ async def _finally_sender(self): pass async def _sender(self): - self._before_sender() - try: - while self.is_server_alive(): - frame = await self._send_queue.get() + await self._transport_ready.wait() + try: + transport = await self._current_transport() - await self._current_transport().send_frame(frame) - self._send_queue.task_done() + self._before_sender() + while self.is_server_alive(): + frame = await self._send_queue.get() - if self._send_queue.empty(): - await self._current_transport().on_send_queue_empty() - except ConnectionResetError as exception: - logger().debug(str(exception)) + await transport.send_frame(frame) + self._send_queue.task_done() + + if self._send_queue.empty(): + await transport.on_send_queue_empty() + except RSocketTransportError as exception: + await self._on_connection_lost(exception) + + except Exception as exception: + logger().error('%s: RSocket error', self._log_identifier(), exc_info=True) + await self._on_connection_lost(exception) + finally: + await self._finally_sender() except asyncio.CancelledError: logger().debug('%s: Asyncio task canceled: sender', self._log_identifier()) - except Exception: + + except Exception as exception: logger().error('%s: RSocket error', self._log_identifier(), exc_info=True) - raise - finally: - await self._finally_sender() + await self._on_connection_lost(exception) async def close(self): + logger().debug('%s: Closing', self._log_identifier()) + self._transport_ready.clear() self._is_closing = True await self._cancel_if_task_exists(self._sender_task) await self._cancel_if_task_exists(self._receiver_task) - transport = self._current_transport() - if transport is not None: - await transport.close() + if self._current_transport().done(): + logger().debug('%s: Closing transport', self._log_identifier()) + transport = self._current_transport().result() + + if transport is not None: + try: + await transport.close() + except Exception as exception: + logger().debug('Transport already closed or failed to close: %s', str(exception)) + pass async def _cancel_if_task_exists(self, task): if task is not None: @@ -433,7 +475,7 @@ async def _cancel_if_task_exists(self, task): try: await task except asyncio.CancelledError: - logger().debug('%s: Asyncio task canceled', self._log_identifier()) + logger().debug('%s: Asyncio task canceled: %s', self._log_identifier(), str(task)) async def __aenter__(self) -> 'RSocketBase': return self @@ -466,7 +508,6 @@ def request_channel( self, payload: Payload, local_publisher: Optional[Publisher] = None) -> Union[BackpressureApi, Publisher]: - logger().debug('%s: sending request-channel: %s', self._log_identifier(), payload) requester = RequestChannelRequester(self, payload, local_publisher) diff --git a/rsocket/rsocket_client.py b/rsocket/rsocket_client.py index 7cf927df..efa3f4cb 100644 --- a/rsocket/rsocket_client.py +++ b/rsocket/rsocket_client.py @@ -1,10 +1,12 @@ import asyncio +from asyncio import Future from datetime import timedelta, datetime -from typing import Optional, Type, Callable, Coroutine +from typing import Optional, Callable, Coroutine from typing import Union from reactivestreams.publisher import Publisher from rsocket.extensions.mimetypes import WellKnownMimeTypes +from rsocket.helpers import create_future from rsocket.logger import logger from rsocket.payload import Payload from rsocket.request_handler import BaseRequestHandler @@ -17,7 +19,7 @@ class RSocketClient(RSocketBase): def __init__(self, transport_provider: Callable[[], Coroutine[None, None, Transport]], - handler_factory: Type[RequestHandler] = BaseRequestHandler, + handler_factory: Callable[[RSocketBase], RequestHandler] = BaseRequestHandler, honor_lease=False, lease_publisher: Optional[Publisher] = None, request_queue_size: int = 0, @@ -31,6 +33,7 @@ def __init__(self, self._is_server_alive = True self._update_last_keepalive() self._transport: Optional[Transport] = None + self._next_transport = asyncio.Future() super().__init__(handler_factory=handler_factory, honor_lease=honor_lease, @@ -42,21 +45,33 @@ def __init__(self, max_lifetime_period=max_lifetime_period, setup_payload=setup_payload) - def _current_transport(self) -> Transport: - return self._transport + def _current_transport(self) -> Future: + return self._next_transport def _log_identifier(self) -> str: return 'client' async def connect(self): - await self.close() + logger().debug('%s: connecting', self._log_identifier()) + self._transport_ready.clear() + + if self._current_transport().done(): + await self.close() + self._is_closing = False - self._transport = await self._transport_provider() - await self._transport.connect() self._reset_internals() self._start_tasks() + + self._next_transport.set_result(await self._transport_provider()) + transport = await self._current_transport() + await transport.connect() + return await super().connect() + async def close(self): + await super().close() + self._next_transport = create_future() + async def __aenter__(self) -> 'RSocketClient': await self.connect() return self @@ -107,3 +122,7 @@ async def _receiver_listen(self): await super()._receiver_listen() finally: await self._cancel_if_task_exists(keepalive_timeout_task) + + async def _sender(self): + return await super()._sender() + diff --git a/rsocket/rsocket_server.py b/rsocket/rsocket_server.py index 2709687c..a5860cef 100644 --- a/rsocket/rsocket_server.py +++ b/rsocket/rsocket_server.py @@ -1,8 +1,11 @@ +import asyncio +from asyncio import Future from datetime import timedelta -from typing import Type, Optional, Union +from typing import Optional, Union, Callable from reactivestreams.publisher import Publisher from rsocket.extensions.mimetypes import WellKnownMimeTypes +from rsocket.helpers import create_future from rsocket.payload import Payload from rsocket.request_handler import RequestHandler, BaseRequestHandler from rsocket.rsocket_base import RSocketBase @@ -13,7 +16,7 @@ class RSocketServer(RSocketBase): def __init__(self, transport: Transport, - handler_factory: Type[RequestHandler] = BaseRequestHandler, + handler_factory: Callable[[RSocketBase], RequestHandler] = BaseRequestHandler, honor_lease=False, lease_publisher: Optional[Publisher] = None, request_queue_size: int = 0, @@ -32,9 +35,10 @@ def __init__(self, max_lifetime_period, setup_payload) self._transport = transport + self._transport_ready.set() - def _current_transport(self) -> Transport: - return self._transport + def _current_transport(self) -> Future: + return create_future(self._transport) def _setup_internals(self): self._reset_internals() @@ -48,3 +52,6 @@ def _get_first_stream_id(self) -> int: def is_server_alive(self) -> bool: return True + + async def _sender(self): + return await super()._sender() diff --git a/rsocket/streams/stream_handler.py b/rsocket/streams/stream_handler.py index 98aa69f9..8063542a 100644 --- a/rsocket/streams/stream_handler.py +++ b/rsocket/streams/stream_handler.py @@ -1,7 +1,7 @@ from abc import abstractmethod, ABCMeta from typing import Optional -from rsocket.exceptions import RSocketValueErrorException +from rsocket.exceptions import RSocketValueError from rsocket.frame import Frame, MAX_REQUEST_N from rsocket.frame_builders import to_cancel_frame, to_request_n_frame from rsocket.logger import logger @@ -22,7 +22,7 @@ def setup(self): def initial_request_n(self, n: int): if n <= 0: self.socket.finish_stream(self.stream_id) - raise RSocketValueErrorException('Initial request N must be > 0') + raise RSocketValueError('Initial request N must be > 0') self._initial_request_n = n return self diff --git a/rsocket/transports/aiohttp_websocket.py b/rsocket/transports/aiohttp_websocket.py index 7507f3e8..2e2b1206 100644 --- a/rsocket/transports/aiohttp_websocket.py +++ b/rsocket/transports/aiohttp_websocket.py @@ -5,7 +5,7 @@ from aiohttp import web from rsocket.frame import Frame -from rsocket.logger import logger +from rsocket.helpers import wrap_transport_exception from rsocket.rsocket_client import RSocketClient from rsocket.rsocket_server import RSocketServer from rsocket.transports.abstract_websocket import AbstractWebsocketTransport @@ -49,16 +49,15 @@ async def connect(self): self._message_handler = asyncio.create_task(self.handle_incoming_ws_messages()) async def handle_incoming_ws_messages(self): - try: + with wrap_transport_exception(): async for msg in self._ws: if msg.type == aiohttp.WSMsgType.BINARY: async for frame in self._frame_parser.receive_data(msg.data, 0): self._incoming_frame_queue.put_nowait(frame) - except asyncio.CancelledError: - logger().debug('Asyncio task canceled: aiohttp_handle_incoming_ws_messages') async def send_frame(self, frame: Frame): - await self._ws.send_bytes(frame.serialize()) + with wrap_transport_exception(): + await self._ws.send_bytes(frame.serialize()) async def close(self): await self._ws_context.__aexit__(None, None, None) @@ -75,17 +74,20 @@ def __init__(self, websocket): async def connect(self): pass - async def handle_incoming_ws_messages(self): - try: + async def _message_generator(self): + with wrap_transport_exception(): async for msg in self._ws: if msg.type == aiohttp.WSMsgType.BINARY: - async for frame in self._frame_parser.receive_data(msg.data, 0): - self._incoming_frame_queue.put_nowait(frame) - except asyncio.CancelledError: - logger().debug('Asyncio task canceled: aiohttp_handle_incoming_ws_messages') + yield msg.data + + async def handle_incoming_ws_messages(self): + async for message in self._message_generator(): + async for frame in self._frame_parser.receive_data(message, 0): + self._incoming_frame_queue.put_nowait(frame) async def send_frame(self, frame: Frame): - await self._ws.send_bytes(frame.serialize()) + with wrap_transport_exception(): + await self._ws.send_bytes(frame.serialize()) async def close(self): await self._ws.close() diff --git a/rsocket/transports/tcp.py b/rsocket/transports/tcp.py index 17ff5697..aed9d8b6 100644 --- a/rsocket/transports/tcp.py +++ b/rsocket/transports/tcp.py @@ -1,7 +1,7 @@ from asyncio import StreamReader, StreamWriter from rsocket.frame import Frame, serialize_with_frame_size_header -from rsocket.logger import logger +from rsocket.helpers import wrap_transport_exception from rsocket.transports.transport import Transport @@ -15,24 +15,23 @@ async def connect(self): pass async def send_frame(self, frame: Frame): - self._writer.write(serialize_with_frame_size_header(frame)) + with wrap_transport_exception(): + self._writer.write(serialize_with_frame_size_header(frame)) async def on_send_queue_empty(self): - await self._writer.drain() + with wrap_transport_exception(): + await self._writer.drain() async def close(self): self._writer.close() await self._writer.wait_closed() async def next_frame_generator(self, is_server_alive): - try: + with wrap_transport_exception(): data = await self._reader.read(1024) - except (ConnectionResetError, BrokenPipeError) as exception: - logger().debug(str(exception)) - return # todo: workaround to silence errors on client closing. this needs a better solution. - if not data: - self._writer.close() - return + if not data: + self._writer.close() + return return self._frame_parser.receive_data(data) diff --git a/tests/conftest.py b/tests/conftest.py index b715294a..dd80f4d7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,13 +11,13 @@ from quart import Quart from rsocket.frame_parser import FrameParser -from rsocket.logger import logger from rsocket.rsocket_base import RSocketBase from rsocket.rsocket_client import RSocketClient from rsocket.rsocket_server import RSocketServer from rsocket.transports.aiohttp_websocket import websocket_client, websocket_handler_factory from rsocket.transports.quart_websocket import websocket_handler from rsocket.transports.tcp import TransportTCP +from tests.rsocket.helpers import assert_no_open_streams logging.basicConfig(level=logging.DEBUG) @@ -143,16 +143,6 @@ async def server_provider(): await finish() -def assert_no_open_streams(client: RSocketBase, server: RSocketBase): - logger().info('Checking for open client streams') - - assert len(client._stream_control._streams) == 0, 'Client has open streams' - - logger().info('Checking for open server streams') - - assert len(server._stream_control._streams) == 0, 'Server has open streams' - - @pytest.fixture def connection(): return FrameParser() diff --git a/tests/rsocket/helpers.py b/tests/rsocket/helpers.py index a50dda96..4a76cec2 100644 --- a/tests/rsocket/helpers.py +++ b/tests/rsocket/helpers.py @@ -1,7 +1,11 @@ from math import ceil +from typing import Type from rsocket.helpers import create_future +from rsocket.logger import logger from rsocket.payload import Payload +from rsocket.request_handler import BaseRequestHandler +from rsocket.rsocket_base import RSocketBase def data_bits(data: bytes, name: str = None): @@ -24,6 +28,31 @@ def bits(bit_count, value, comment) -> str: return f'{value:b}'.zfill(bit_count) -def future_from_request(request: Payload): +def future_from_payload(request: Payload): return create_future(Payload(b'data: ' + request.data, b'meta: ' + request.metadata)) + + +def assert_no_open_streams(client: RSocketBase, server: RSocketBase): + logger().info('Checking for open client streams') + + assert len(client._stream_control._streams) == 0, 'Client has open streams' + + logger().info('Checking for open server streams') + + assert len(server._stream_control._streams) == 0, 'Server has open streams' + + +class IdentifiedHandler(BaseRequestHandler): + def __init__(self, socket, server_id: int): + super().__init__(socket) + self._server_id = server_id + + +class IdentifiedHandlerFactory: + def __init__(self, server_id: int, handler_factory: Type[IdentifiedHandler]): + self._server_id = server_id + self._handler_factory = handler_factory + + def factory(self, socket) -> BaseRequestHandler: + return self._handler_factory(socket, self._server_id) diff --git a/tests/rsocket/test_connection_lost.py b/tests/rsocket/test_connection_lost.py new file mode 100644 index 00000000..9e79d2bd --- /dev/null +++ b/tests/rsocket/test_connection_lost.py @@ -0,0 +1,82 @@ +import asyncio +from asyncio import Event, Future +from asyncio.base_events import Server +from typing import Optional, Tuple + +from rsocket.logger import logger +from rsocket.payload import Payload +from rsocket.request_handler import BaseRequestHandler +from rsocket.rsocket_client import RSocketClient +from rsocket.rsocket_server import RSocketServer +from rsocket.transports.tcp import TransportTCP +from tests.rsocket.helpers import future_from_payload, IdentifiedHandlerFactory, \ + IdentifiedHandler + + +class ServerHandler(IdentifiedHandler): + async def request_response(self, payload: Payload) -> Future: + return future_from_payload(Payload(payload.data + (' server %d' % self._server_id).encode(), payload.metadata)) + + +async def test_connection_lost(unused_tcp_port): + index_iterator = iter(range(1, 3)) + + wait_for_server = Event() + server_connection: Tuple = None + client_connection: Tuple = None + + class ClientHandler(BaseRequestHandler): + async def on_connection_lost(self, rsocket, exception: Exception): + logger().info('Reconnecting') + await rsocket.connect() + + def session(*connection): + nonlocal server, server_connection + server_connection = connection + server = RSocketServer(TransportTCP(*connection), + IdentifiedHandlerFactory(next(index_iterator), ServerHandler).factory) + wait_for_server.set() + + async def start(): + nonlocal service, client + service = await asyncio.start_server(session, host, port) + + async def transport_provider(): + try: + nonlocal client_connection + client_connection = await asyncio.open_connection(host, port) + return TransportTCP(*client_connection) + except Exception: + logger().error('Client connection error', exc_info=True) + raise + + client = RSocketClient(transport_provider, handler_factory=ClientHandler) + + service: Optional[Server] = None + server: Optional[RSocketServer] = None + client: Optional[RSocketClient] = None + port = unused_tcp_port + host = 'localhost' + + await start() + + try: + async with client as connection: + await wait_for_server.wait() + wait_for_server.clear() + response1 = await connection.request_response(Payload(b'request 1')) + force_closing_connection(server_connection) + await server.close() # cleanup async tasks from previous server to avoid errors (?) + await wait_for_server.wait() + response2 = await connection.request_response(Payload(b'request 2')) + + assert response1.data == b'data: request 1 server 1' + assert response2.data == b'data: request 2 server 2' + finally: + await server.close() + + service.close() + + +def force_closing_connection(current_connection): + current_connection[1].close() diff --git a/tests/rsocket/test_frame.py b/tests/rsocket/test_frame.py index 01075694..8c290508 100644 --- a/tests/rsocket/test_frame.py +++ b/tests/rsocket/test_frame.py @@ -2,7 +2,7 @@ import pytest from rsocket.error_codes import ErrorCode -from rsocket.exceptions import RSocketProtocolException +from rsocket.exceptions import RSocketProtocolError from rsocket.extensions.authentication_types import WellKnownAuthenticationTypes from rsocket.extensions.composite_metadata import CompositeMetadata from rsocket.extensions.mimetypes import WellKnownMimeTypes @@ -598,5 +598,5 @@ def test_parse_broken_frame_raises_exception(): bits(13, 23, 'Number of frames to request - broken. smaller than 31 bits'), ) - with pytest.raises(RSocketProtocolException): + with pytest.raises(RSocketProtocolError): parse_or_ignore(broken_frame_data) diff --git a/tests/rsocket/test_lease.py b/tests/rsocket/test_lease.py index 9bde2892..2e02bc76 100644 --- a/tests/rsocket/test_lease.py +++ b/tests/rsocket/test_lease.py @@ -4,11 +4,11 @@ import pytest from reactivestreams.subscriber import Subscriber -from rsocket.exceptions import RSocketProtocolException +from rsocket.exceptions import RSocketProtocolError from rsocket.lease import SingleLeasePublisher, DefinedLease from rsocket.payload import Payload from rsocket.request_handler import BaseRequestHandler -from tests.rsocket.helpers import future_from_request +from tests.rsocket.helpers import future_from_payload class PeriodicalLeasePublisher(SingleLeasePublisher): @@ -44,7 +44,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): async def test_request_response_with_server_side_lease_works(lazy_pipe): class Handler(BaseRequestHandler): async def request_response(self, request: Payload): - return future_from_request(request) + return future_from_payload(request) async with lazy_pipe(client_arguments={'honor_lease': True}, server_arguments={'handler_factory': Handler, @@ -59,7 +59,7 @@ async def request_response(self, request: Payload): async def test_request_response_with_client_and_server_side_lease_works(lazy_pipe): class Handler(BaseRequestHandler): async def request_response(self, request: Payload): - return future_from_request(request) + return future_from_payload(request) async with PeriodicalLeasePublisher( maximum_request_count=2, @@ -88,7 +88,7 @@ async def request_response(self, request: Payload): async def test_request_response_with_lease_too_many_requests(lazy_pipe): class Handler(BaseRequestHandler): async def request_response(self, request: Payload): - return future_from_request(request) + return future_from_payload(request) async with lazy_pipe(client_arguments={'honor_lease': True}, server_arguments={'handler_factory': Handler, @@ -106,7 +106,7 @@ async def request_response(self, request: Payload): async def test_request_response_with_lease_client_side_exception_requests_late(lazy_pipe): class Handler(BaseRequestHandler): async def request_response(self, request: Payload): - return future_from_request(request) + return future_from_payload(request) async with lazy_pipe(client_arguments={'honor_lease': True}, server_arguments={'handler_factory': Handler, @@ -134,7 +134,7 @@ async def test_server_rejects_all_requests_if_lease_not_supported(lazy_pipe): async def test_request_response_with_lease_server_side_exception(lazy_pipe): class Handler(BaseRequestHandler): async def request_response(self, request: Payload): - return future_from_request(request) + return future_from_payload(request) async with lazy_pipe(client_arguments={'honor_lease': True}, server_arguments={'handler_factory': Handler, @@ -145,5 +145,5 @@ async def request_response(self, request: Payload): response = await client.request_response(Payload(b'dog', b'cat')) assert response == Payload(b'data: dog', b'meta: cat') - with pytest.raises(RSocketProtocolException): + with pytest.raises(RSocketProtocolError): await client.request_response(Payload(b'invalid request')) diff --git a/tests/rsocket/test_load_balancer.py b/tests/rsocket/test_load_balancer.py index 91e5667c..f1829ffb 100644 --- a/tests/rsocket/test_load_balancer.py +++ b/tests/rsocket/test_load_balancer.py @@ -4,26 +4,14 @@ from rsocket.load_balancer.load_balancer_rsocket import LoadBalancerRSocket from rsocket.load_balancer.round_robin import LoadBalancerRoundRobin from rsocket.payload import Payload -from rsocket.request_handler import BaseRequestHandler from tests.conftest import pipe_factory_tcp -from tests.rsocket.helpers import future_from_request +from tests.rsocket.helpers import future_from_payload, IdentifiedHandlerFactory, IdentifiedHandler -class HandlerFactory(): - def __init__(self, server_id: int): - self._server_id = server_id - - def factory(self, socket): - return Handler(socket, self._server_id) - - -class Handler(BaseRequestHandler): - def __init__(self, socket, server_id: int): - super().__init__(socket) - self.server_id = server_id +class Handler(IdentifiedHandler): async def request_response(self, request: Payload): - return future_from_request(Payload(request.data + (' server %d' % self.server_id).encode(), request.metadata)) + return future_from_payload(Payload(request.data + (' server %d' % self._server_id).encode(), request.metadata)) async def test_load_balancer_round_robin(unused_tcp_port_factory): @@ -36,7 +24,7 @@ async def test_load_balancer_round_robin(unused_tcp_port_factory): tcp_port = unused_tcp_port_factory() _, client = await stack.enter_async_context( pipe_factory_tcp(tcp_port, - server_arguments={'handler_factory': HandlerFactory(i).factory}, + server_arguments={'handler_factory': IdentifiedHandlerFactory(i, Handler).factory}, auto_connect_client=False)) clients.append(client) diff --git a/tests/rsocket/test_request_response.py b/tests/rsocket/test_request_response.py index 74ec7d86..6db97ba0 100644 --- a/tests/rsocket/test_request_response.py +++ b/tests/rsocket/test_request_response.py @@ -7,13 +7,13 @@ from rsocket.helpers import create_future from rsocket.payload import Payload from rsocket.request_handler import BaseRequestHandler -from tests.rsocket.helpers import future_from_request +from tests.rsocket.helpers import future_from_payload async def test_request_response_awaitable_wrapper(pipe): class Handler(BaseRequestHandler): async def request_response(self, request: Payload): - return future_from_request(request) + return future_from_payload(request) server, client = pipe server._handler = Handler(server) @@ -25,7 +25,7 @@ async def request_response(self, request: Payload): async def test_request_response_repeated(pipe): class Handler(BaseRequestHandler): async def request_response(self, request: Payload): - return future_from_request(request) + return future_from_payload(request) server, client = pipe server._handler = Handler(server) diff --git a/tests/rsocket/test_request_stream.py b/tests/rsocket/test_request_stream.py index 78aa7028..d1ddd21c 100644 --- a/tests/rsocket/test_request_stream.py +++ b/tests/rsocket/test_request_stream.py @@ -8,7 +8,7 @@ from reactivestreams.subscriber import DefaultSubscriber, Subscriber from rsocket.awaitable.awaitable_rsocket import AwaitableRSocket from rsocket.awaitable.collector_subscriber import CollectorSubscriber -from rsocket.exceptions import RSocketValueErrorException +from rsocket.exceptions import RSocketValueError from rsocket.frame_helpers import ensure_bytes from rsocket.helpers import DefaultPublisherSubscription from rsocket.payload import Payload @@ -52,7 +52,7 @@ async def test_request_stream_prevent_negative_initial_request_n(pipe: Tuple[RSo initial_request_n): server, client = pipe - with pytest.raises(RSocketValueErrorException): + with pytest.raises(RSocketValueError): client.request_stream(Payload()).initial_request_n(initial_request_n) diff --git a/tests/rsocket/test_rsocket.py b/tests/rsocket/test_rsocket.py index 9932fbf3..a0140667 100644 --- a/tests/rsocket/test_rsocket.py +++ b/tests/rsocket/test_rsocket.py @@ -5,7 +5,7 @@ import pytest from rsocket.error_codes import ErrorCode -from rsocket.exceptions import RSocketProtocolException +from rsocket.exceptions import RSocketProtocolError from rsocket.helpers import create_future from rsocket.payload import Payload from rsocket.request_handler import BaseRequestHandler @@ -55,7 +55,7 @@ async def on_keepalive_timeout(self, 'max_lifetime_period': timedelta(seconds=1), 'handler_factory': ClientHandler}, server_arguments={'handler_factory': Handler}) as (server, client): - with pytest.raises(RSocketProtocolException) as exc_info: + with pytest.raises(RSocketProtocolError) as exc_info: await client.request_response(Payload(b'dog', b'cat')) assert exc_info.value.data == 'Server not alive' From 640aa616995ce116ac365d37cb2b4f7d718f2e0d Mon Sep 17 00:00:00 2001 From: gabi Date: Tue, 15 Mar 2022 13:13:57 +0200 Subject: [PATCH 03/10] fix tests with wrong assumptions about rsocket internals --- tests/rsocket/misbehaving_rsocket.py | 6 +++--- tests/rsocket/test_resume_unsupported.py | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/rsocket/misbehaving_rsocket.py b/tests/rsocket/misbehaving_rsocket.py index 95e43458..3892969c 100644 --- a/tests/rsocket/misbehaving_rsocket.py +++ b/tests/rsocket/misbehaving_rsocket.py @@ -3,11 +3,11 @@ class MisbehavingRSocket: - def __init__(self, socket: Transport): - self._socket = socket + def __init__(self, transport: Transport): + self._transport = transport async def send_frame(self, frame: Frame): - await self._socket.send_frame(frame) + await self._transport.send_frame(frame) class BrokenFrame: diff --git a/tests/rsocket/test_resume_unsupported.py b/tests/rsocket/test_resume_unsupported.py index 366848c1..13f99bdc 100644 --- a/tests/rsocket/test_resume_unsupported.py +++ b/tests/rsocket/test_resume_unsupported.py @@ -27,7 +27,8 @@ async def on_error(self, error_code: ErrorCode, payload: Payload): client.set_handler_factory(Handler) await client.connect() - bad_client = MisbehavingRSocket(client._transport) + transport = await client._current_transport() + bad_client = MisbehavingRSocket(transport) setup = SetupFrame() setup.flags_lease = False @@ -63,7 +64,8 @@ async def on_error(self, error_code: ErrorCode, payload: Payload): client.set_handler_using_factory(Handler) - bad_client = MisbehavingRSocket(client._transport) + transport = await client._current_transport() + bad_client = MisbehavingRSocket(transport) resume = ResumeFrame() resume.token_length = 1 From 3b3f0449bbefcd319422dc5eec116dd7d6c69c08 Mon Sep 17 00:00:00 2001 From: gabi Date: Tue, 15 Mar 2022 13:15:20 +0200 Subject: [PATCH 04/10] fix tests with wrong assumptions about rsocket internals --- tests/rsocket/test_resume_unsupported.py | 31 ++++++++++++------------ 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/tests/rsocket/test_resume_unsupported.py b/tests/rsocket/test_resume_unsupported.py index 13f99bdc..a57f7013 100644 --- a/tests/rsocket/test_resume_unsupported.py +++ b/tests/rsocket/test_resume_unsupported.py @@ -26,27 +26,26 @@ async def on_error(self, error_code: ErrorCode, payload: Payload): error_received.set() client.set_handler_factory(Handler) - await client.connect() - transport = await client._current_transport() - bad_client = MisbehavingRSocket(transport) - setup = SetupFrame() - setup.flags_lease = False - setup.flags_resume = True - setup.token_length = 1 - setup.resume_identification_token = b'a' - setup.keep_alive_milliseconds = 123 - setup.max_lifetime_milliseconds = 456 - setup.data_encoding = WellKnownMimeTypes.APPLICATION_JSON.name.encode() - setup.metadata_encoding = WellKnownMimeTypes.APPLICATION_JSON.name.encode() + async with client as connected_client: + transport = await connected_client._current_transport() + bad_client = MisbehavingRSocket(transport) - await bad_client.send_frame(setup) + setup = SetupFrame() + setup.flags_lease = False + setup.flags_resume = True + setup.token_length = 1 + setup.resume_identification_token = b'a' + setup.keep_alive_milliseconds = 123 + setup.max_lifetime_milliseconds = 456 + setup.data_encoding = WellKnownMimeTypes.APPLICATION_JSON.name.encode() + setup.metadata_encoding = WellKnownMimeTypes.APPLICATION_JSON.name.encode() - await error_received.wait() + await bad_client.send_frame(setup) - await client.close() + await error_received.wait() - assert received_error_code == ErrorCode.UNSUPPORTED_SETUP + assert received_error_code == ErrorCode.UNSUPPORTED_SETUP @pytest.mark.allow_error_log From 2188739e7de0204aeeb90ee9f1bac96ead38d549 Mon Sep 17 00:00:00 2001 From: gabi Date: Tue, 15 Mar 2022 14:13:06 +0200 Subject: [PATCH 05/10] fix examples support reconnection attempts made the transport provider an async generator instead of a single callback --- examples/client.py | 6 +- examples/client_springboot.py | 9 +- examples/client_with_routing.py | 9 +- examples/run_against_example_java_server.py | 7 +- rsocket/exceptions.py | 4 + rsocket/rsocket_base.py | 13 +-- rsocket/rsocket_client.py | 32 +++++-- rsocket/transports/aiohttp_websocket.py | 4 +- tests/conftest.py | 4 +- tests/rsocket/helpers.py | 4 + tests/rsocket/test_connection_lost.py | 93 +++++++++++++++++++-- 11 files changed, 153 insertions(+), 32 deletions(-) diff --git a/examples/client.py b/examples/client.py index 31d87a0b..e4ec07bd 100644 --- a/examples/client.py +++ b/examples/client.py @@ -15,9 +15,11 @@ def on_next(self, value, is_complete=False): async def main(): - connection = await asyncio.open_connection('localhost', 6565) + async def transport_provider(): + connection = await asyncio.open_connection('localhost', 7000) + yield TransportTCP(*connection) - async with RSocketClient(TransportTCP(*connection)) as client: + async with RSocketClient(transport_provider()) as client: payload = Payload(b'%Y-%m-%d %H:%M:%S') async def run_request_response(): diff --git a/examples/client_springboot.py b/examples/client_springboot.py index c009f756..6ede9b12 100644 --- a/examples/client_springboot.py +++ b/examples/client_springboot.py @@ -4,9 +4,9 @@ from uuid import uuid4 from reactivestreams.subscriber import DefaultSubscriber +from rsocket.extensions.helpers import composite, route, authenticate_simple from rsocket.extensions.mimetypes import WellKnownMimeTypes from rsocket.payload import Payload -from rsocket.extensions.helpers import composite, route, authenticate_simple from rsocket.rsocket_client import RSocketClient from rsocket.transports.tcp import TransportTCP @@ -19,12 +19,15 @@ def on_next(self, value, is_complete=False): async def main(): - connection = await asyncio.open_connection('localhost', 7000) + async def transport_provider(): + connection = await asyncio.open_connection('localhost', 7000) + yield TransportTCP(*connection) setup_payload = Payload( data=str(uuid4()).encode(), metadata=composite(route('shell-client'), authenticate_simple('user', 'pass'))) - async with RSocketClient(TransportTCP(*connection), + + async with RSocketClient(transport_provider(), setup_payload=setup_payload, metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA): await asyncio.sleep(5) diff --git a/examples/client_with_routing.py b/examples/client_with_routing.py index ce627084..dfb70c45 100644 --- a/examples/client_with_routing.py +++ b/examples/client_with_routing.py @@ -5,10 +5,10 @@ from reactivestreams.subscriber import Subscriber from reactivestreams.subscription import Subscription +from rsocket.extensions.helpers import route, composite, authenticate_simple from rsocket.extensions.mimetypes import WellKnownMimeTypes from rsocket.fragment import Fragment from rsocket.payload import Payload -from rsocket.extensions.helpers import route, composite, authenticate_simple from rsocket.rsocket_client import RSocketClient from rsocket.streams.stream_from_async_generator import StreamFromAsyncGenerator from rsocket.transports.tcp import TransportTCP @@ -143,9 +143,12 @@ async def request_fragmented_stream(socket: RSocketClient): async def main(): - connection = await asyncio.open_connection('localhost', 6565) - async with RSocketClient(TransportTCP(*connection), + async def transport_provider(): + connection = await asyncio.open_connection('localhost', 6565) + yield TransportTCP(*connection) + + async with RSocketClient(transport_provider(), metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA) as client: await request_response(client) await request_stream(client) diff --git a/examples/run_against_example_java_server.py b/examples/run_against_example_java_server.py index b25f0420..6dd4c0c9 100644 --- a/examples/run_against_example_java_server.py +++ b/examples/run_against_example_java_server.py @@ -33,8 +33,11 @@ def on_complete(self): def on_error(self, exception: Exception): completion_event.set() - connection = await asyncio.open_connection('localhost', 6565) - async with RSocketClient(TransportTCP(*connection), + async def transport_provider(): + connection = await asyncio.open_connection('localhost', 6565) + yield TransportTCP(*connection) + + async with RSocketClient(transport_provider(), metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA.value.name, data_encoding=WellKnownMimeTypes.APPLICATION_JSON.value.name) as client: metadata = CompositeMetadata() diff --git a/rsocket/exceptions.py b/rsocket/exceptions.py index 006e9561..22e702f1 100644 --- a/rsocket/exceptions.py +++ b/rsocket/exceptions.py @@ -60,3 +60,7 @@ class RSocketFrameFragmentDifferentType(RSocketError): class RSocketTransportError(RSocketError): pass + + +class RSocketNoAvailableTransport(RSocketError): + pass diff --git a/rsocket/rsocket_base.py b/rsocket/rsocket_base.py index 8bca69cc..74858a85 100644 --- a/rsocket/rsocket_base.py +++ b/rsocket/rsocket_base.py @@ -446,7 +446,8 @@ async def _sender(self): await self._finally_sender() except asyncio.CancelledError: logger().debug('%s: Asyncio task canceled: sender', self._log_identifier()) - + except RSocketTransportError as exception: + await self._on_connection_lost(exception) except Exception as exception: logger().error('%s: RSocket error', self._log_identifier(), exc_info=True) await self._on_connection_lost(exception) @@ -484,21 +485,21 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() def request_response(self, payload: Payload) -> Future: - logger().debug('%s: sending request-response: %s', self._log_identifier(), payload) + logger().debug('%s: request-response: %s', self._log_identifier(), payload) requester = RequestResponseRequester(self, payload) self.register_new_stream(requester).setup() return requester.run() def fire_and_forget(self, payload: Payload): - logger().debug('%s: sending fire-and-forget: %s', self._log_identifier(), payload) + logger().debug('%s: fire-and-forget: %s', self._log_identifier(), payload) stream_id = self._allocate_stream() self.send_request(to_fire_and_forget_frame(stream_id, payload)) self.finish_stream(stream_id) def request_stream(self, payload: Payload) -> Union[BackpressureApi, Publisher]: - logger().debug('%s: sending request-stream: %s', self._log_identifier(), payload) + logger().debug('%s: request-stream: %s', self._log_identifier(), payload) requester = RequestStreamRequester(self, payload) self.register_new_stream(requester) @@ -508,13 +509,15 @@ def request_channel( self, payload: Payload, local_publisher: Optional[Publisher] = None) -> Union[BackpressureApi, Publisher]: - logger().debug('%s: sending request-channel: %s', self._log_identifier(), payload) + logger().debug('%s: request-channel: %s', self._log_identifier(), payload) requester = RequestChannelRequester(self, payload, local_publisher) self.register_new_stream(requester) return requester def metadata_push(self, metadata: bytes): + logger().debug('%s: metadata-push: %s', self._log_identifier(), metadata) + frame = MetadataPushFrame() frame.metadata = metadata self.send_frame(frame) diff --git a/rsocket/rsocket_client.py b/rsocket/rsocket_client.py index efa3f4cb..1ebbcc5c 100644 --- a/rsocket/rsocket_client.py +++ b/rsocket/rsocket_client.py @@ -1,10 +1,11 @@ import asyncio from asyncio import Future from datetime import timedelta, datetime -from typing import Optional, Callable, Coroutine +from typing import Optional, Callable, AsyncGenerator, Any from typing import Union from reactivestreams.publisher import Publisher +from rsocket.exceptions import RSocketNoAvailableTransport, RSocketTransportError from rsocket.extensions.mimetypes import WellKnownMimeTypes from rsocket.helpers import create_future from rsocket.logger import logger @@ -18,7 +19,7 @@ class RSocketClient(RSocketBase): def __init__(self, - transport_provider: Callable[[], Coroutine[None, None, Transport]], + transport_provider: AsyncGenerator[Transport, Any], handler_factory: Callable[[RSocketBase], RequestHandler] = BaseRequestHandler, honor_lease=False, lease_publisher: Optional[Publisher] = None, @@ -29,7 +30,7 @@ def __init__(self, max_lifetime_period: timedelta = timedelta(minutes=10), setup_payload: Optional[Payload] = None, ): - self._transport_provider = transport_provider + self._transport_provider = transport_provider.__aiter__() self._is_server_alive = True self._update_last_keepalive() self._transport: Optional[Transport] = None @@ -62,12 +63,32 @@ async def connect(self): self._reset_internals() self._start_tasks() - self._next_transport.set_result(await self._transport_provider()) + try: + new_transport = await self._get_new_transport() + except Exception as exception: + logger().error('%s: Connection error', self._log_identifier(), exc_info=True) + self._handler.on_connection_lost(self, exception) + return + + if new_transport is None: + raise RSocketNoAvailableTransport() + + self._next_transport.set_result(new_transport) transport = await self._current_transport() - await transport.connect() + + try: + await transport.connect() + except Exception as exception: + raise RSocketTransportError from exception return await super().connect() + async def _get_new_transport(self): + try: + return await self._transport_provider.__anext__() + except StopAsyncIteration: + return + async def close(self): await super().close() self._next_transport = create_future() @@ -125,4 +146,3 @@ async def _receiver_listen(self): async def _sender(self): return await super()._sender() - diff --git a/rsocket/transports/aiohttp_websocket.py b/rsocket/transports/aiohttp_websocket.py index 2e2b1206..ebdb4b8e 100644 --- a/rsocket/transports/aiohttp_websocket.py +++ b/rsocket/transports/aiohttp_websocket.py @@ -14,9 +14,9 @@ @asynccontextmanager async def websocket_client(url, *args, **kwargs) -> RSocketClient: async def transport_provider(): - return TransportAioHttpClient(url) + yield TransportAioHttpClient(url) - async with RSocketClient(transport_provider, *args, **kwargs) as client: + async with RSocketClient(transport_provider(), *args, **kwargs) as client: yield client diff --git a/tests/conftest.py b/tests/conftest.py index dd80f4d7..544334b1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -104,9 +104,9 @@ async def start(): async def transport_provider(): connection = await asyncio.open_connection(host, port) - return TransportTCP(*connection) + yield TransportTCP(*connection) - client = RSocketClient(transport_provider, **(client_arguments or {})) + client = RSocketClient(transport_provider(), **(client_arguments or {})) if auto_connect_client: await client.connect() diff --git a/tests/rsocket/helpers.py b/tests/rsocket/helpers.py index 4a76cec2..458b9f84 100644 --- a/tests/rsocket/helpers.py +++ b/tests/rsocket/helpers.py @@ -56,3 +56,7 @@ def __init__(self, server_id: int, handler_factory: Type[IdentifiedHandler]): def factory(self, socket) -> BaseRequestHandler: return self._handler_factory(socket, self._server_id) + + +def force_closing_connection(current_connection): + current_connection[1].close() diff --git a/tests/rsocket/test_connection_lost.py b/tests/rsocket/test_connection_lost.py index 9e79d2bd..2ff8d2d6 100644 --- a/tests/rsocket/test_connection_lost.py +++ b/tests/rsocket/test_connection_lost.py @@ -3,14 +3,16 @@ from asyncio.base_events import Server from typing import Optional, Tuple +from rsocket.frame import Frame from rsocket.logger import logger from rsocket.payload import Payload from rsocket.request_handler import BaseRequestHandler from rsocket.rsocket_client import RSocketClient from rsocket.rsocket_server import RSocketServer from rsocket.transports.tcp import TransportTCP +from rsocket.transports.transport import Transport from tests.rsocket.helpers import future_from_payload, IdentifiedHandlerFactory, \ - IdentifiedHandler + IdentifiedHandler, force_closing_connection class ServerHandler(IdentifiedHandler): @@ -25,6 +27,82 @@ async def test_connection_lost(unused_tcp_port): server_connection: Tuple = None client_connection: Tuple = None + class ClientHandler(BaseRequestHandler): + async def on_connection_lost(self, rsocket, exception: Exception): + logger().info('Reconnecting') + await rsocket.connect() + + def session(*connection): + nonlocal server, server_connection + server_connection = connection + server = RSocketServer(TransportTCP(*connection), + IdentifiedHandlerFactory(next(index_iterator), ServerHandler).factory) + wait_for_server.set() + + async def start(): + nonlocal service, client + service = await asyncio.start_server(session, host, port) + + async def transport_provider(): + while True: + try: + nonlocal client_connection + client_connection = await asyncio.open_connection(host, port) + yield TransportTCP(*client_connection) + except Exception: + logger().error('Client connection error', exc_info=True) + raise + + client = RSocketClient(transport_provider(), handler_factory=ClientHandler) + + service: Optional[Server] = None + server: Optional[RSocketServer] = None + client: Optional[RSocketClient] = None + port = unused_tcp_port + host = 'localhost' + + await start() + + try: + async with client as connection: + await wait_for_server.wait() + wait_for_server.clear() + response1 = await connection.request_response(Payload(b'request 1')) + force_closing_connection(server_connection) + await server.close() # cleanup async tasks from previous server to avoid errors (?) + await wait_for_server.wait() + response2 = await connection.request_response(Payload(b'request 2')) + + assert response1.data == b'data: request 1 server 1' + assert response2.data == b'data: request 2 server 2' + finally: + await server.close() + + service.close() + + +class FailingTransportTCP(Transport): + + async def connect(self): + raise Exception + + async def send_frame(self, frame: Frame): + pass + + async def next_frame_generator(self, is_server_alive): + pass + + async def close(self): + pass + + +async def test_connection_failure(unused_tcp_port): + index_iterator = iter(range(1, 3)) + + wait_for_server = Event() + server_connection: Tuple = None + client_connection: Tuple = None + class ClientHandler(BaseRequestHandler): async def on_connection_lost(self, rsocket, exception: Exception): logger().info('Reconnecting') @@ -45,12 +123,17 @@ async def transport_provider(): try: nonlocal client_connection client_connection = await asyncio.open_connection(host, port) - return TransportTCP(*client_connection) + yield TransportTCP(*client_connection) + + yield FailingTransportTCP() + + client_connection = await asyncio.open_connection(host, port) + yield TransportTCP(*client_connection) except Exception: logger().error('Client connection error', exc_info=True) raise - client = RSocketClient(transport_provider, handler_factory=ClientHandler) + client = RSocketClient(transport_provider(), handler_factory=ClientHandler) service: Optional[Server] = None server: Optional[RSocketServer] = None @@ -76,7 +159,3 @@ async def transport_provider(): await server.close() service.close() - - -def force_closing_connection(current_connection): - current_connection[1].close() From 0c8bcdb9781808e2c393486fadf472a66696313a Mon Sep 17 00:00:00 2001 From: gabi Date: Tue, 15 Mar 2022 14:23:23 +0200 Subject: [PATCH 06/10] examples cleanup --- examples/client.py | 7 +++---- examples/client_springboot.py | 7 +++---- examples/client_with_routing.py | 7 +++---- examples/run_against_example_java_server.py | 7 +++---- rsocket/helpers.py | 4 ++++ 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/examples/client.py b/examples/client.py index e4ec07bd..f2e05be3 100644 --- a/examples/client.py +++ b/examples/client.py @@ -2,6 +2,7 @@ import logging from reactivestreams.subscriber import DefaultSubscriber +from rsocket.helpers import single_transport_provider from rsocket.payload import Payload from rsocket.rsocket_client import RSocketClient from rsocket.transports.tcp import TransportTCP @@ -15,11 +16,9 @@ def on_next(self, value, is_complete=False): async def main(): - async def transport_provider(): - connection = await asyncio.open_connection('localhost', 7000) - yield TransportTCP(*connection) + connection = await asyncio.open_connection('localhost', 7000) - async with RSocketClient(transport_provider()) as client: + async with RSocketClient(single_transport_provider(TransportTCP(*connection))) as client: payload = Payload(b'%Y-%m-%d %H:%M:%S') async def run_request_response(): diff --git a/examples/client_springboot.py b/examples/client_springboot.py index 6ede9b12..a15e762f 100644 --- a/examples/client_springboot.py +++ b/examples/client_springboot.py @@ -6,6 +6,7 @@ from reactivestreams.subscriber import DefaultSubscriber from rsocket.extensions.helpers import composite, route, authenticate_simple from rsocket.extensions.mimetypes import WellKnownMimeTypes +from rsocket.helpers import single_transport_provider from rsocket.payload import Payload from rsocket.rsocket_client import RSocketClient from rsocket.transports.tcp import TransportTCP @@ -19,15 +20,13 @@ def on_next(self, value, is_complete=False): async def main(): - async def transport_provider(): - connection = await asyncio.open_connection('localhost', 7000) - yield TransportTCP(*connection) + connection = await asyncio.open_connection('localhost', 7000) setup_payload = Payload( data=str(uuid4()).encode(), metadata=composite(route('shell-client'), authenticate_simple('user', 'pass'))) - async with RSocketClient(transport_provider(), + async with RSocketClient(single_transport_provider(TransportTCP(*connection)), setup_payload=setup_payload, metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA): await asyncio.sleep(5) diff --git a/examples/client_with_routing.py b/examples/client_with_routing.py index dfb70c45..9a2b5b3f 100644 --- a/examples/client_with_routing.py +++ b/examples/client_with_routing.py @@ -8,6 +8,7 @@ from rsocket.extensions.helpers import route, composite, authenticate_simple from rsocket.extensions.mimetypes import WellKnownMimeTypes from rsocket.fragment import Fragment +from rsocket.helpers import single_transport_provider from rsocket.payload import Payload from rsocket.rsocket_client import RSocketClient from rsocket.streams.stream_from_async_generator import StreamFromAsyncGenerator @@ -144,11 +145,9 @@ async def request_fragmented_stream(socket: RSocketClient): async def main(): - async def transport_provider(): - connection = await asyncio.open_connection('localhost', 6565) - yield TransportTCP(*connection) + connection = await asyncio.open_connection('localhost', 6565) - async with RSocketClient(transport_provider(), + async with RSocketClient(single_transport_provider(TransportTCP(*connection)), metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA) as client: await request_response(client) await request_stream(client) diff --git a/examples/run_against_example_java_server.py b/examples/run_against_example_java_server.py index 6dd4c0c9..916dcf2b 100644 --- a/examples/run_against_example_java_server.py +++ b/examples/run_against_example_java_server.py @@ -9,6 +9,7 @@ from rsocket.extensions.composite_metadata import CompositeMetadata from rsocket.extensions.mimetypes import WellKnownMimeTypes from rsocket.extensions.routing import RoutingMetadata +from rsocket.helpers import single_transport_provider from rsocket.payload import Payload from rsocket.rsocket_client import RSocketClient from rsocket.transports.tcp import TransportTCP @@ -33,11 +34,9 @@ def on_complete(self): def on_error(self, exception: Exception): completion_event.set() - async def transport_provider(): - connection = await asyncio.open_connection('localhost', 6565) - yield TransportTCP(*connection) + connection = await asyncio.open_connection('localhost', 6565) - async with RSocketClient(transport_provider(), + async with RSocketClient(single_transport_provider(TransportTCP(*connection)), metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA.value.name, data_encoding=WellKnownMimeTypes.APPLICATION_JSON.value.name) as client: metadata = CompositeMetadata() diff --git a/rsocket/helpers.py b/rsocket/helpers.py index 7d5d655d..ba12548a 100644 --- a/rsocket/helpers.py +++ b/rsocket/helpers.py @@ -60,3 +60,7 @@ def wrap_transport_exception(): yield except Exception as exception: raise RSocketTransportError from exception + + +def single_transport_provider(transport): + yield transport From e9220b6ff0d44ca81ae9462d4498c46cfd9fbc6c Mon Sep 17 00:00:00 2001 From: gabi Date: Tue, 15 Mar 2022 14:25:34 +0200 Subject: [PATCH 07/10] code cleanup --- examples/client.py | 2 +- rsocket/transports/aiohttp_websocket.py | 11 +++-------- rsocket/transports/quart_websocket.py | 3 --- rsocket/transports/tcp.py | 3 --- rsocket/transports/transport.py | 3 +-- 5 files changed, 5 insertions(+), 17 deletions(-) diff --git a/examples/client.py b/examples/client.py index f2e05be3..0837287c 100644 --- a/examples/client.py +++ b/examples/client.py @@ -16,7 +16,7 @@ def on_next(self, value, is_complete=False): async def main(): - connection = await asyncio.open_connection('localhost', 7000) + connection = await asyncio.open_connection('localhost', 6565) async with RSocketClient(single_transport_provider(TransportTCP(*connection))) as client: payload = Payload(b'%Y-%m-%d %H:%M:%S') diff --git a/rsocket/transports/aiohttp_websocket.py b/rsocket/transports/aiohttp_websocket.py index ebdb4b8e..0a1778de 100644 --- a/rsocket/transports/aiohttp_websocket.py +++ b/rsocket/transports/aiohttp_websocket.py @@ -5,7 +5,7 @@ from aiohttp import web from rsocket.frame import Frame -from rsocket.helpers import wrap_transport_exception +from rsocket.helpers import wrap_transport_exception, single_transport_provider from rsocket.rsocket_client import RSocketClient from rsocket.rsocket_server import RSocketServer from rsocket.transports.abstract_websocket import AbstractWebsocketTransport @@ -13,10 +13,8 @@ @asynccontextmanager async def websocket_client(url, *args, **kwargs) -> RSocketClient: - async def transport_provider(): - yield TransportAioHttpClient(url) - - async with RSocketClient(transport_provider(), *args, **kwargs) as client: + async with RSocketClient(single_transport_provider(TransportAioHttpClient(url)), + *args, **kwargs) as client: yield client @@ -71,9 +69,6 @@ def __init__(self, websocket): super().__init__() self._ws = websocket - async def connect(self): - pass - async def _message_generator(self): with wrap_transport_exception(): async for msg in self._ws: diff --git a/rsocket/transports/quart_websocket.py b/rsocket/transports/quart_websocket.py index 3cefc935..7fdf0a65 100644 --- a/rsocket/transports/quart_websocket.py +++ b/rsocket/transports/quart_websocket.py @@ -20,9 +20,6 @@ async def websocket_handler(*args, on_server_create=None, **kwargs): class TransportQuartWebsocket(AbstractWebsocketTransport): - async def connect(self): - pass - async def handle_incoming_ws_messages(self): try: while True: diff --git a/rsocket/transports/tcp.py b/rsocket/transports/tcp.py index aed9d8b6..376123c2 100644 --- a/rsocket/transports/tcp.py +++ b/rsocket/transports/tcp.py @@ -11,9 +11,6 @@ def __init__(self, reader: StreamReader, writer: StreamWriter): self._writer = writer self._reader = reader - async def connect(self): - pass - async def send_frame(self, frame: Frame): with wrap_transport_exception(): self._writer.write(serialize_with_frame_size_header(frame)) diff --git a/rsocket/transports/transport.py b/rsocket/transports/transport.py index 030b486d..9ef41777 100644 --- a/rsocket/transports/transport.py +++ b/rsocket/transports/transport.py @@ -9,9 +9,8 @@ class Transport(metaclass=abc.ABCMeta): def __init__(self): self._frame_parser = FrameParser() - @abc.abstractmethod async def connect(self): - ... + """"Optional if required""" @abc.abstractmethod async def send_frame(self, frame: Frame): From d50d3739c7e0babfcfef7fa20ac691a88f302209 Mon Sep 17 00:00:00 2001 From: gabi Date: Tue, 15 Mar 2022 14:31:14 +0200 Subject: [PATCH 08/10] removed unneeded asyncio event property fix single_transport_provider --- rsocket/helpers.py | 2 +- rsocket/rsocket_base.py | 8 +------- rsocket/rsocket_client.py | 1 - rsocket/rsocket_server.py | 2 -- rsocket/transports/aiohttp_websocket.py | 10 +++++++--- 5 files changed, 9 insertions(+), 14 deletions(-) diff --git a/rsocket/helpers.py b/rsocket/helpers.py index ba12548a..709805c7 100644 --- a/rsocket/helpers.py +++ b/rsocket/helpers.py @@ -62,5 +62,5 @@ def wrap_transport_exception(): raise RSocketTransportError from exception -def single_transport_provider(transport): +async def single_transport_provider(transport): yield transport diff --git a/rsocket/rsocket_base.py b/rsocket/rsocket_base.py index 74858a85..528f9c1e 100644 --- a/rsocket/rsocket_base.py +++ b/rsocket/rsocket_base.py @@ -1,6 +1,6 @@ import abc import asyncio -from asyncio import Future, QueueEmpty, Task, Event +from asyncio import Future, QueueEmpty, Task from datetime import timedelta from typing import Union, Optional, Dict, Any, Coroutine, Callable, Type, cast, TypeVar @@ -81,7 +81,6 @@ def __init__(self, self._lease_publisher = lease_publisher self._sender_task = None self._receiver_task = None - self._transport_ready = Event() self._async_frame_handler_by_type: Dict[Type[Frame], Any] = { RequestResponseFrame: self.handle_request_response, @@ -142,7 +141,6 @@ async def connect(self): if self._honor_lease: self._subscribe_to_lease_publisher() - self._transport_ready.set() return self def _ensure_encoding_name(self, encoding) -> bytes: @@ -342,7 +340,6 @@ async def handle_lease(self, frame: LeaseFrame): async def _receiver(self): try: - await self._transport_ready.wait() await self._receiver_listen() except asyncio.CancelledError: logger().debug('%s: Asyncio task canceled: receiver', self._log_identifier()) @@ -354,7 +351,6 @@ async def _receiver(self): async def _on_connection_lost(self, exception: Exception): logger().debug(str(exception)) - self._transport_ready.clear() await self._handler.on_connection_lost(self, exception) @abc.abstractmethod @@ -423,7 +419,6 @@ async def _finally_sender(self): async def _sender(self): try: - await self._transport_ready.wait() try: transport = await self._current_transport() @@ -454,7 +449,6 @@ async def _sender(self): async def close(self): logger().debug('%s: Closing', self._log_identifier()) - self._transport_ready.clear() self._is_closing = True await self._cancel_if_task_exists(self._sender_task) await self._cancel_if_task_exists(self._receiver_task) diff --git a/rsocket/rsocket_client.py b/rsocket/rsocket_client.py index 1ebbcc5c..52d63fcd 100644 --- a/rsocket/rsocket_client.py +++ b/rsocket/rsocket_client.py @@ -54,7 +54,6 @@ def _log_identifier(self) -> str: async def connect(self): logger().debug('%s: connecting', self._log_identifier()) - self._transport_ready.clear() if self._current_transport().done(): await self.close() diff --git a/rsocket/rsocket_server.py b/rsocket/rsocket_server.py index a5860cef..a5cb5449 100644 --- a/rsocket/rsocket_server.py +++ b/rsocket/rsocket_server.py @@ -1,4 +1,3 @@ -import asyncio from asyncio import Future from datetime import timedelta from typing import Optional, Union, Callable @@ -35,7 +34,6 @@ def __init__(self, max_lifetime_period, setup_payload) self._transport = transport - self._transport_ready.set() def _current_transport(self) -> Future: return create_future(self._transport) diff --git a/rsocket/transports/aiohttp_websocket.py b/rsocket/transports/aiohttp_websocket.py index 0a1778de..39739607 100644 --- a/rsocket/transports/aiohttp_websocket.py +++ b/rsocket/transports/aiohttp_websocket.py @@ -6,6 +6,7 @@ from rsocket.frame import Frame from rsocket.helpers import wrap_transport_exception, single_transport_provider +from rsocket.logger import logger from rsocket.rsocket_client import RSocketClient from rsocket.rsocket_server import RSocketServer from rsocket.transports.abstract_websocket import AbstractWebsocketTransport @@ -76,9 +77,12 @@ async def _message_generator(self): yield msg.data async def handle_incoming_ws_messages(self): - async for message in self._message_generator(): - async for frame in self._frame_parser.receive_data(message, 0): - self._incoming_frame_queue.put_nowait(frame) + try: + async for message in self._message_generator(): + async for frame in self._frame_parser.receive_data(message, 0): + self._incoming_frame_queue.put_nowait(frame) + except asyncio.CancelledError: + logger().debug('Asyncio task canceled: aiohttp_handle_incoming_ws_messages') async def send_frame(self, frame: Frame): with wrap_transport_exception(): From 410c7779aadd9d2d43306305779b7ffe68cb5bcc Mon Sep 17 00:00:00 2001 From: gabi Date: Tue, 15 Mar 2022 14:38:07 +0200 Subject: [PATCH 09/10] refactoring removed unused code --- rsocket/rsocket_client.py | 3 -- rsocket/rsocket_server.py | 3 -- tests/conftest.py | 11 +++--- tests/rsocket/test_frame.py | 56 +++++++++++++++--------------- tests/rsocket/test_frame_decode.py | 4 +-- 5 files changed, 34 insertions(+), 43 deletions(-) diff --git a/rsocket/rsocket_client.py b/rsocket/rsocket_client.py index 52d63fcd..5222eee1 100644 --- a/rsocket/rsocket_client.py +++ b/rsocket/rsocket_client.py @@ -142,6 +142,3 @@ async def _receiver_listen(self): await super()._receiver_listen() finally: await self._cancel_if_task_exists(keepalive_timeout_task) - - async def _sender(self): - return await super()._sender() diff --git a/rsocket/rsocket_server.py b/rsocket/rsocket_server.py index a5cb5449..016666f0 100644 --- a/rsocket/rsocket_server.py +++ b/rsocket/rsocket_server.py @@ -50,6 +50,3 @@ def _get_first_stream_id(self) -> int: def is_server_alive(self) -> bool: return True - - async def _sender(self): - return await super()._sender() diff --git a/tests/conftest.py b/tests/conftest.py index 544334b1..3e36ce1e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ from quart import Quart from rsocket.frame_parser import FrameParser +from rsocket.helpers import single_transport_provider from rsocket.rsocket_base import RSocketBase from rsocket.rsocket_client import RSocketClient from rsocket.rsocket_server import RSocketServer @@ -95,18 +96,14 @@ def session(*connection): async def start(): nonlocal service, client service = await asyncio.start_server(session, host, port) - + connection = await asyncio.open_connection(host, port) nonlocal client_arguments # test_overrides = {'keep_alive_period': timedelta(minutes=20)} client_arguments = client_arguments or {} # client_arguments.update(test_overrides) - async def transport_provider(): - connection = await asyncio.open_connection(host, port) - yield TransportTCP(*connection) - - client = RSocketClient(transport_provider(), **(client_arguments or {})) + client = RSocketClient(single_transport_provider(TransportTCP(*connection)), **(client_arguments or {})) if auto_connect_client: await client.connect() @@ -144,7 +141,7 @@ async def server_provider(): @pytest.fixture -def connection(): +def frame_parser(): return FrameParser() diff --git a/tests/rsocket/test_frame.py b/tests/rsocket/test_frame.py index 8c290508..fb2ce898 100644 --- a/tests/rsocket/test_frame.py +++ b/tests/rsocket/test_frame.py @@ -20,7 +20,7 @@ (1, b'\x04\x05\x06\x07\x08', 1, b''), (1, b'\x04\x05\x06\x07\x08', 0, b''), )) -async def test_setup_readable(connection, metadata_flag, metadata, lease, data): +async def test_setup_readable(frame_parser, metadata_flag, metadata, lease, data): def variable_length(): length = len(data) if metadata_flag != 0: @@ -65,7 +65,7 @@ def variable_length(): frame_data = build_frame(*items) - frames = await asyncstdlib.builtins.list(connection.receive_data(frame_data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(frame_data)) frame = frames[0] assert isinstance(frame, SetupFrame) assert serialize_with_frame_size_header(frame) == frame_data @@ -88,7 +88,7 @@ def variable_length(): (0), (1) )) -async def test_setup_with_resume(connection, lease): +async def test_setup_with_resume(frame_parser, lease): data = build_frame( bits(24, 84, 'Frame size'), bits(1, 0, 'Padding'), @@ -124,7 +124,7 @@ async def test_setup_with_resume(connection, lease): data_bits(b'\x01\x02\x03'), ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, SetupFrame) assert serialize_with_frame_size_header(frame) == data @@ -147,7 +147,7 @@ async def test_setup_with_resume(connection, lease): 0, 1 )) -async def test_request_stream_frame(connection, follows): +async def test_request_stream_frame(frame_parser, follows): data = build_frame( bits(24, 32, 'Frame size'), bits(1, 0, 'Padding'), @@ -173,7 +173,7 @@ async def test_request_stream_frame(connection, follows): data_bits(b'\x01\x02\x03'), ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, RequestStreamFrame) assert serialize_with_frame_size_header(frame) == data @@ -199,7 +199,7 @@ async def test_request_stream_frame(connection, follows): (0, 0), (1, 0) )) -async def test_request_channel_frame(connection, follows, complete): +async def test_request_channel_frame(frame_parser, follows, complete): data = build_frame( bits(24, 32, 'Frame size'), bits(1, 0, 'Padding'), @@ -226,7 +226,7 @@ async def test_request_channel_frame(connection, follows, complete): data_bits(b'\x01\x02\x03'), ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, RequestChannelFrame) assert serialize_with_frame_size_header(frame) == data @@ -265,7 +265,7 @@ def test_basic_composite_metadata_item(): assert data == serialized -async def test_request_with_composite_metadata(connection): +async def test_request_with_composite_metadata(frame_parser): data = build_frame( bits(24, 28, 'Frame size'), bits(1, 0, 'Padding'), @@ -289,7 +289,7 @@ async def test_request_with_composite_metadata(connection): data_bits(b'\x01\x02\x03'), ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, RequestResponseFrame) assert serialize_with_frame_size_header(frame) == data @@ -359,7 +359,7 @@ async def test_composite_metadata_multiple_items(): assert composite_metadata.serialize() == data -async def test_cancel(connection): +async def test_cancel(frame_parser): data = build_frame( bits(24, 6, 'Frame size'), bits(1, 0, 'Padding'), @@ -371,14 +371,14 @@ async def test_cancel(connection): bits(8, 0, 'Padding flags'), ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, CancelFrame) assert frame.frame_type is FrameType.CANCEL assert serialize_with_frame_size_header(frame) == data -async def test_error(connection): +async def test_error(frame_parser): data = build_frame( bits(24, 20, 'Frame size'), bits(1, 0, 'Padding'), @@ -393,7 +393,7 @@ async def test_error(connection): data_bits(b'error data') ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, ErrorFrame) assert serialize_with_frame_size_header(frame) == data @@ -403,7 +403,7 @@ async def test_error(connection): assert frame.frame_type is FrameType.ERROR -async def test_request_n_frame(connection): +async def test_request_n_frame(frame_parser): data = build_frame( bits(24, 10, 'Frame size'), bits(1, 0, 'Padding'), @@ -418,7 +418,7 @@ async def test_request_n_frame(connection): bits(31, 23, 'Number of frames to request'), ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, RequestNFrame) assert serialize_with_frame_size_header(frame) == data @@ -427,7 +427,7 @@ async def test_request_n_frame(connection): assert frame.frame_type is FrameType.REQUEST_N -async def test_resume_frame(connection): +async def test_resume_frame(frame_parser): data = build_frame( bits(24, 40, 'Frame size'), bits(1, 0, 'Padding'), @@ -449,7 +449,7 @@ async def test_resume_frame(connection): bits(63, 456, 'first client position'), ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, ResumeFrame) assert frame.last_server_position == 123 @@ -458,7 +458,7 @@ async def test_resume_frame(connection): assert serialize_with_frame_size_header(frame) == data -async def test_metadata_push_frame(connection): +async def test_metadata_push_frame(frame_parser): data = build_frame( bits(24, 14, 'Frame size'), bits(1, 0, 'Padding'), @@ -471,7 +471,7 @@ async def test_metadata_push_frame(connection): data_bits(b'metadata') ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, MetadataPushFrame) assert frame.metadata == b'metadata' @@ -479,7 +479,7 @@ async def test_metadata_push_frame(connection): assert serialize_with_frame_size_header(frame) == data -async def test_payload_frame(connection): +async def test_payload_frame(frame_parser): data = build_frame( bits(24, 28, 'Frame size'), bits(1, 0, 'Padding'), @@ -494,7 +494,7 @@ async def test_payload_frame(connection): data_bits(b'actual_data'), ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, PayloadFrame) assert frame.metadata == b'metadata' @@ -504,7 +504,7 @@ async def test_payload_frame(connection): assert serialize_with_frame_size_header(frame) == data -async def test_lease_frame(connection): +async def test_lease_frame(frame_parser): data = build_frame( bits(24, 37, 'Frame size'), bits(1, 0, 'Padding'), @@ -521,7 +521,7 @@ async def test_lease_frame(connection): data_bits(b'Metadata on lease frame') ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, LeaseFrame) assert frame.number_of_requests == 123 @@ -531,7 +531,7 @@ async def test_lease_frame(connection): assert serialize_with_frame_size_header(frame) == data -async def test_resume_ok_frame(connection): +async def test_resume_ok_frame(frame_parser): data = build_frame( bits(24, 14, 'Frame size'), bits(1, 0, 'Padding'), @@ -545,7 +545,7 @@ async def test_resume_ok_frame(connection): bits(63, 456, 'Last Received Client Position') ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, ResumeOKFrame) assert frame.last_received_client_position == 456 @@ -553,7 +553,7 @@ async def test_resume_ok_frame(connection): assert serialize_with_frame_size_header(frame) == data -async def test_keepalive_frame(connection): +async def test_keepalive_frame(frame_parser): data = build_frame( bits(24, 29, 'Frame size'), bits(1, 0, 'Padding'), @@ -568,7 +568,7 @@ async def test_keepalive_frame(connection): data_bits(b'additional data') ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, KeepAliveFrame) diff --git a/tests/rsocket/test_frame_decode.py b/tests/rsocket/test_frame_decode.py index 779661ea..0a743f0d 100644 --- a/tests/rsocket/test_frame_decode.py +++ b/tests/rsocket/test_frame_decode.py @@ -17,12 +17,12 @@ async def test_decode_spring_demo_auth(): assert authentication.password == b'pass' -async def test_multiple_frames(connection): +async def test_multiple_frames(frame_parser): data = b'\x00\x00\x06\x00\x00\x00\x7b\x24\x00' data += b'\x00\x00\x13\x00\x00\x26\x6a\x2c\x00\x00\x00\x02\x04\x77\x65\x69' data += b'\x72\x64\x6e\x65\x73\x73' data += b'\x00\x00\x06\x00\x00\x00\x7b\x24\x00' data += b'\x00\x00\x06\x00\x00\x00\x7b\x24\x00' data += b'\x00\x00\x06\x00\x00\x00\x7b\x24\x00' - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) assert len(frames) == 5 From 9bca41d975fe32135cb4c3f87863646bd6f85e12 Mon Sep 17 00:00:00 2001 From: gabi Date: Tue, 15 Mar 2022 14:39:53 +0200 Subject: [PATCH 10/10] refactoring removed unused code --- tests/rsocket/test_frame_decode.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/rsocket/test_frame_decode.py b/tests/rsocket/test_frame_decode.py index 0a743f0d..fa7b32b3 100644 --- a/tests/rsocket/test_frame_decode.py +++ b/tests/rsocket/test_frame_decode.py @@ -1,5 +1,9 @@ +from typing import cast + import asyncstdlib +from rsocket.extensions.authentication import AuthenticationSimple +from rsocket.extensions.authentication_content import AuthenticationContent from rsocket.extensions.composite_metadata import CompositeMetadata @@ -12,7 +16,8 @@ async def test_decode_spring_demo_auth(): assert len(composite_metadata.items) == 2 - authentication = composite_metadata.items[1].authentication + composite_item = cast(AuthenticationContent, composite_metadata.items[1]) + authentication = cast(AuthenticationSimple, composite_item.authentication) assert authentication.username == b'user' assert authentication.password == b'pass'