diff --git a/examples/client.py b/examples/client.py index 31d87a0b..0837287c 100644 --- a/examples/client.py +++ b/examples/client.py @@ -2,6 +2,7 @@ import logging from reactivestreams.subscriber import DefaultSubscriber +from rsocket.helpers import single_transport_provider from rsocket.payload import Payload from rsocket.rsocket_client import RSocketClient from rsocket.transports.tcp import TransportTCP @@ -17,7 +18,7 @@ def on_next(self, value, is_complete=False): async def main(): connection = await asyncio.open_connection('localhost', 6565) - async with RSocketClient(TransportTCP(*connection)) as client: + async with RSocketClient(single_transport_provider(TransportTCP(*connection))) as client: payload = Payload(b'%Y-%m-%d %H:%M:%S') async def run_request_response(): diff --git a/examples/client_springboot.py b/examples/client_springboot.py index c009f756..a15e762f 100644 --- a/examples/client_springboot.py +++ b/examples/client_springboot.py @@ -4,9 +4,10 @@ from uuid import uuid4 from reactivestreams.subscriber import DefaultSubscriber +from rsocket.extensions.helpers import composite, route, authenticate_simple from rsocket.extensions.mimetypes import WellKnownMimeTypes +from rsocket.helpers import single_transport_provider from rsocket.payload import Payload -from rsocket.extensions.helpers import composite, route, authenticate_simple from rsocket.rsocket_client import RSocketClient from rsocket.transports.tcp import TransportTCP @@ -24,7 +25,8 @@ async def main(): setup_payload = Payload( data=str(uuid4()).encode(), metadata=composite(route('shell-client'), authenticate_simple('user', 'pass'))) - async with RSocketClient(TransportTCP(*connection), + + async with RSocketClient(single_transport_provider(TransportTCP(*connection)), setup_payload=setup_payload, metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA): await asyncio.sleep(5) diff --git a/examples/client_with_routing.py b/examples/client_with_routing.py index ce627084..9a2b5b3f 100644 --- a/examples/client_with_routing.py +++ b/examples/client_with_routing.py @@ -5,10 +5,11 @@ from reactivestreams.subscriber import Subscriber from reactivestreams.subscription import Subscription +from rsocket.extensions.helpers import route, composite, authenticate_simple from rsocket.extensions.mimetypes import WellKnownMimeTypes from rsocket.fragment import Fragment +from rsocket.helpers import single_transport_provider from rsocket.payload import Payload -from rsocket.extensions.helpers import route, composite, authenticate_simple from rsocket.rsocket_client import RSocketClient from rsocket.streams.stream_from_async_generator import StreamFromAsyncGenerator from rsocket.transports.tcp import TransportTCP @@ -143,9 +144,10 @@ async def request_fragmented_stream(socket: RSocketClient): async def main(): + connection = await asyncio.open_connection('localhost', 6565) - async with RSocketClient(TransportTCP(*connection), + async with RSocketClient(single_transport_provider(TransportTCP(*connection)), metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA) as client: await request_response(client) await request_stream(client) diff --git a/examples/run_against_example_java_server.py b/examples/run_against_example_java_server.py index b25f0420..916dcf2b 100644 --- a/examples/run_against_example_java_server.py +++ b/examples/run_against_example_java_server.py @@ -9,6 +9,7 @@ from rsocket.extensions.composite_metadata import CompositeMetadata from rsocket.extensions.mimetypes import WellKnownMimeTypes from rsocket.extensions.routing import RoutingMetadata +from rsocket.helpers import single_transport_provider from rsocket.payload import Payload from rsocket.rsocket_client import RSocketClient from rsocket.transports.tcp import TransportTCP @@ -34,7 +35,8 @@ def on_error(self, exception: Exception): completion_event.set() connection = await asyncio.open_connection('localhost', 6565) - async with RSocketClient(TransportTCP(*connection), + + async with RSocketClient(single_transport_provider(TransportTCP(*connection)), metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA.value.name, data_encoding=WellKnownMimeTypes.APPLICATION_JSON.value.name) as client: metadata = CompositeMetadata() diff --git a/rsocket/awaitable/awaitable_rsocket.py b/rsocket/awaitable/awaitable_rsocket.py index f09af8e5..bfa2d57e 100644 --- a/rsocket/awaitable/awaitable_rsocket.py +++ b/rsocket/awaitable/awaitable_rsocket.py @@ -47,8 +47,8 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): await self._rsocket.__aexit__(exc_type, exc_val, exc_tb) - def connect(self): - return self._rsocket.connect() + async def connect(self): + return await self._rsocket.connect() def close(self): self._rsocket.close() diff --git a/rsocket/exceptions.py b/rsocket/exceptions.py index d6a033a0..22e702f1 100644 --- a/rsocket/exceptions.py +++ b/rsocket/exceptions.py @@ -34,11 +34,11 @@ class RSocketStreamAllocationFailure(RSocketError): pass -class RSocketValueErrorException(RSocketError): +class RSocketValueError(RSocketError): pass -class RSocketProtocolException(RSocketError): +class RSocketProtocolError(RSocketError): def __init__(self, error_code: ErrorCode, data: Optional[str] = None): self.error_code = error_code self.data = data @@ -47,7 +47,7 @@ def __str__(self) -> str: return 'RSocket error %s(%s): "%s"' % (self.error_code.name, self.error_code.value, self.data or '') -class RSocketStreamIdInUse(RSocketProtocolException): +class RSocketStreamIdInUse(RSocketProtocolError): def __init__(self, stream_id: int): super().__init__(ErrorCode.REJECTED) @@ -56,3 +56,11 @@ def __init__(self, stream_id: int): class RSocketFrameFragmentDifferentType(RSocketError): pass + + +class RSocketTransportError(RSocketError): + pass + + +class RSocketNoAvailableTransport(RSocketError): + pass diff --git a/rsocket/frame.py b/rsocket/frame.py index 45fae209..72695db4 100644 --- a/rsocket/frame.py +++ b/rsocket/frame.py @@ -5,7 +5,7 @@ from typing import Tuple, Optional from rsocket.error_codes import ErrorCode -from rsocket.exceptions import RSocketProtocolException, ParseError, RSocketUnknownFrameType +from rsocket.exceptions import RSocketProtocolError, ParseError, RSocketUnknownFrameType from rsocket.frame_helpers import is_flag_set, unpack_position, pack_position, unpack_24bit, pack_24bit, unpack_32bit, \ ensure_bytes @@ -626,7 +626,7 @@ def parse_or_ignore(buffer: bytes) -> Optional[Frame]: return frame except Exception as exception: if not header.flags_ignore: - raise RSocketProtocolException(ErrorCode.CONNECTION_ERROR, str(exception)) from exception + raise RSocketProtocolError(ErrorCode.CONNECTION_ERROR, str(exception)) from exception def is_fragmentable_frame(frame: Frame) -> bool: @@ -643,7 +643,7 @@ def exception_to_error_frame(stream_id: int, exception: Exception) -> ErrorFrame frame = ErrorFrame() frame.stream_id = stream_id - if isinstance(exception, RSocketProtocolException): + if isinstance(exception, RSocketProtocolError): frame.error_code = exception.error_code frame.data = ensure_bytes(exception.data) else: @@ -655,7 +655,7 @@ def exception_to_error_frame(stream_id: int, exception: Exception) -> ErrorFrame def error_frame_to_exception(frame: ErrorFrame) -> Exception: if frame.error_code != ErrorCode.APPLICATION_ERROR: - return RSocketProtocolException(frame.error_code, data=frame.data.decode()) + return RSocketProtocolError(frame.error_code, data=frame.data.decode()) return RuntimeError(frame.data.decode('utf-8')) diff --git a/rsocket/helpers.py b/rsocket/helpers.py index 777029bd..709805c7 100644 --- a/rsocket/helpers.py +++ b/rsocket/helpers.py @@ -1,20 +1,22 @@ import asyncio -from typing import Optional +from contextlib import contextmanager +from typing import Optional, Any from reactivestreams.publisher import DefaultPublisher from reactivestreams.subscriber import Subscriber from reactivestreams.subscription import DefaultSubscription +from rsocket.exceptions import RSocketTransportError from rsocket.frame import Frame from rsocket.payload import Payload _default = object() -def create_future(payload: Optional[Payload] = _default) -> asyncio.Future: +def create_future(value: Optional[Any] = _default) -> asyncio.Future: future = asyncio.get_event_loop().create_future() - if payload is not _default: - future.set_result(payload) + if value is not _default: + future.set_result(value) return future @@ -50,3 +52,15 @@ def __eq__(self, other): def __hash__(self): return hash((self.id, self.name)) + + +@contextmanager +def wrap_transport_exception(): + try: + yield + except Exception as exception: + raise RSocketTransportError from exception + + +async def single_transport_provider(transport): + yield transport diff --git a/rsocket/load_balancer/load_balancer_rsocket.py b/rsocket/load_balancer/load_balancer_rsocket.py index d5e8511e..c4e38caa 100644 --- a/rsocket/load_balancer/load_balancer_rsocket.py +++ b/rsocket/load_balancer/load_balancer_rsocket.py @@ -30,14 +30,14 @@ def request_stream(self, payload: Payload) -> Union[BackpressureApi, Publisher]: def metadata_push(self, metadata: bytes): self._select_client().metadata_push(metadata) - def connect(self): - self._strategy.connect() + async def connect(self): + await self._strategy.connect() async def close(self): await self._strategy.close() async def __aenter__(self) -> RSocket: - self._strategy.connect() + await self._strategy.connect() return self async def __aexit__(self, exc_type, exc_val, exc_tb): diff --git a/rsocket/load_balancer/load_balancer_strategy.py b/rsocket/load_balancer/load_balancer_strategy.py index 152adfdd..a707775e 100644 --- a/rsocket/load_balancer/load_balancer_strategy.py +++ b/rsocket/load_balancer/load_balancer_strategy.py @@ -9,7 +9,7 @@ def select(self) -> RSocket: ... @abc.abstractmethod - def connect(self): + async def connect(self): ... @abc.abstractmethod diff --git a/rsocket/load_balancer/random_client.py b/rsocket/load_balancer/random_client.py index 9ac6d4f5..16a07b93 100644 --- a/rsocket/load_balancer/random_client.py +++ b/rsocket/load_balancer/random_client.py @@ -20,9 +20,9 @@ def select(self) -> RSocket: random_client_id = random.randint(0, len(self._pool)) return self._pool[random_client_id] - def connect(self): + async def connect(self): if self._auto_connect: - [client.connect() for client in self._pool] + [await client.connect() for client in self._pool] async def close(self): if self._auto_close: diff --git a/rsocket/load_balancer/round_robin.py b/rsocket/load_balancer/round_robin.py index 5b7681be..c8ee1d8c 100644 --- a/rsocket/load_balancer/round_robin.py +++ b/rsocket/load_balancer/round_robin.py @@ -20,9 +20,9 @@ def select(self) -> RSocket: self._current_index = (self._current_index + 1) % len(self._pool) return client - def connect(self): + async def connect(self): if self._auto_connect: - [client.connect() for client in self._pool] + [await client.connect() for client in self._pool] async def close(self): if self._auto_close: diff --git a/rsocket/request_handler.py b/rsocket/request_handler.py index e364f3cc..ad7da3e6 100644 --- a/rsocket/request_handler.py +++ b/rsocket/request_handler.py @@ -2,7 +2,7 @@ from abc import ABCMeta, abstractmethod from asyncio import Future from datetime import timedelta -from typing import Tuple, Optional, Callable +from typing import Tuple, Optional from reactivestreams.publisher import Publisher from reactivestreams.subscriber import Subscriber @@ -58,7 +58,11 @@ async def on_error(self, error_code: ErrorCode, payload: Payload): @abstractmethod async def on_keepalive_timeout(self, time_since_last_keepalive: timedelta, - cancel_all_streams: Callable): + rsocket): + ... + + @abstractmethod + async def on_connection_lost(self, rsocket, exception): ... def _parse_composite_metadata(self, metadata: bytes) -> CompositeMetadata: @@ -93,7 +97,10 @@ async def request_stream(self, payload: Payload) -> Publisher: async def on_error(self, error_code: ErrorCode, payload: Payload): logger().error('Error: %s, %s', error_code, payload) + async def on_connection_lost(self, rsocket, exception: Exception): + await rsocket.close() + async def on_keepalive_timeout(self, time_since_last_keepalive: timedelta, - cancel_all_streams: Callable): + rsocket): pass diff --git a/rsocket/rsocket.py b/rsocket/rsocket.py index 78ef4448..b3a80b6b 100644 --- a/rsocket/rsocket.py +++ b/rsocket/rsocket.py @@ -33,7 +33,7 @@ def metadata_push(self, metadata: bytes): ... @abc.abstractmethod - def connect(self): + async def connect(self): ... @abc.abstractmethod diff --git a/rsocket/rsocket_base.py b/rsocket/rsocket_base.py index 559495e9..528f9c1e 100644 --- a/rsocket/rsocket_base.py +++ b/rsocket/rsocket_base.py @@ -8,13 +8,17 @@ from reactivestreams.subscriber import DefaultSubscriber from rsocket.datetime_helpers import to_milliseconds from rsocket.error_codes import ErrorCode -from rsocket.exceptions import RSocketProtocolException +from rsocket.exceptions import RSocketProtocolError, RSocketTransportError from rsocket.extensions.mimetypes import WellKnownMimeTypes -from rsocket.frame import KeepAliveFrame, \ - MetadataPushFrame, RequestFireAndForgetFrame, RequestResponseFrame, \ - RequestStreamFrame, Frame, exception_to_error_frame, LeaseFrame, ErrorFrame, RequestFrame, \ - initiate_request_frame_types, InvalidFrame, FragmentableFrame -from rsocket.frame import RequestChannelFrame, ResumeFrame, is_fragmentable_frame, CONNECTION_STREAM_ID +from rsocket.frame import (KeepAliveFrame, + MetadataPushFrame, RequestFireAndForgetFrame, + RequestResponseFrame, RequestStreamFrame, Frame, + exception_to_error_frame, + LeaseFrame, ErrorFrame, RequestFrame, + initiate_request_frame_types, InvalidFrame, + FragmentableFrame) +from rsocket.frame import (RequestChannelFrame, ResumeFrame, + is_fragmentable_frame, CONNECTION_STREAM_ID) from rsocket.frame import SetupFrame from rsocket.frame_builders import to_payload_frame, to_fire_and_forget_frame from rsocket.frame_fragment_cache import FrameFragmentCache @@ -36,7 +40,6 @@ from rsocket.stream_control import StreamControl from rsocket.streams.backpressureapi import BackpressureApi from rsocket.streams.stream_handler import StreamHandler -from rsocket.transports.transport import Transport async def noop_frame_handler(frame): @@ -56,8 +59,7 @@ def on_next(self, value, is_complete=False): self._socket.send_lease(value) def __init__(self, - transport: Transport, - handler_factory: Type[RequestHandler] = BaseRequestHandler, + handler_factory: Callable[['RSocketBase'], RequestHandler] = BaseRequestHandler, honor_lease=False, lease_publisher: Optional[Publisher] = None, request_queue_size: int = 0, @@ -68,33 +70,17 @@ def __init__(self, setup_payload: Optional[Payload] = None ): - self._transport = transport - self._handler = handler_factory(self) - self._stream_control = StreamControl(self._get_first_stream_id()) + self._handler_factory = handler_factory + self._request_queue_size = request_queue_size self._honor_lease = honor_lease self._max_lifetime_period = max_lifetime_period self._keep_alive_period = keep_alive_period self._setup_payload = setup_payload - self._data_encoding = self._ensure_encoding_name(data_encoding) self._metadata_encoding = self._ensure_encoding_name(metadata_encoding) - self._is_closing = False - - if self._honor_lease: - self._requester_lease = DefinedLease(maximum_request_count=0) - else: - self._requester_lease = NullLease() - - self._responder_lease = NullLease() self._lease_publisher = lease_publisher - - self._frame_fragment_cache = FrameFragmentCache() - - self._send_queue = asyncio.Queue() - self._request_queue = asyncio.Queue(request_queue_size) - - self._receiver_task = self._start_task_if_not_closing(self._receiver) - self._sender_task = self._start_task_if_not_closing(self._sender) + self._sender_task = None + self._receiver_task = None self._async_frame_handler_by_type: Dict[Type[Frame], Any] = { RequestResponseFrame: self.handle_request_response, @@ -109,12 +95,48 @@ def __init__(self, ErrorFrame: self.handle_error } - def connect(self): + self._setup_internals() + + def _setup_internals(self): + pass + + @abc.abstractmethod + def _current_transport(self) -> Future: + ... + + def _reset_internals(self): + self._frame_fragment_cache = FrameFragmentCache() + self._send_queue = asyncio.Queue() + self._request_queue = asyncio.Queue(self._request_queue_size) + + if self._honor_lease: + self._requester_lease = DefinedLease(maximum_request_count=0) + else: + self._requester_lease = NullLease() + + self._responder_lease = NullLease() + self._stream_control = StreamControl(self._get_first_stream_id()) + self._handler = self._handler_factory(self) + self._is_closing = False + + def close_all_streams(self): + self._stream_control.cancel_all_handlers() + + def _start_tasks(self): + self._receiver_task = self._start_task_if_not_closing(self._receiver) + self._sender_task = self._start_task_if_not_closing(self._sender) + + def _peek_at_queue_next_item(self, queue) -> Optional[Frame]: + if len(queue._queue) > 0: + return queue._queue[0] + return None + + async def connect(self): logger().debug('%s: sending setup frame', self._log_identifier()) - self.send_frame(self._create_setup_frame(self._data_encoding, - self._metadata_encoding, - self._setup_payload)) + self.send_priority_frame(self._create_setup_frame(self._data_encoding, + self._metadata_encoding, + self._setup_payload)) if self._honor_lease: self._subscribe_to_lease_publisher() @@ -131,6 +153,9 @@ def _start_task_if_not_closing(self, task_factory: Callable[[], Coroutine]) -> O if not self._is_closing: return asyncio.create_task(task_factory()) + def set_handler_factory(self, handler_factory): + self._handler_factory = handler_factory + def set_handler_using_factory(self, handler_factory) -> RequestHandler: self._handler = handler_factory(self) return self._handler @@ -156,6 +181,15 @@ def _queue_request_frame(self, frame: RequestFrame): self._request_queue.put_nowait(frame) + def send_priority_frame(self, frame: Frame): + items = [] + while not self._send_queue.empty(): + items.append(self._send_queue.get_nowait()) + + self._send_queue.put_nowait(frame) + for item in items: + self._send_queue.put_nowait(item) + def send_frame(self, frame: Frame): self._send_queue.put_nowait(frame) @@ -230,11 +264,11 @@ async def handle_request_stream(self, frame: RequestStreamFrame): async def handle_setup(self, frame: SetupFrame): if frame.flags_resume: - raise RSocketProtocolException(ErrorCode.UNSUPPORTED_SETUP, data='Resume not supported') + raise RSocketProtocolError(ErrorCode.UNSUPPORTED_SETUP, data='Resume not supported') if frame.flags_lease: if self._lease_publisher is None: - raise RSocketProtocolException(ErrorCode.UNSUPPORTED_SETUP, data='Lease not available') + raise RSocketProtocolError(ErrorCode.UNSUPPORTED_SETUP, data='Lease not available') else: self._subscribe_to_lease_publisher() @@ -246,7 +280,7 @@ async def handle_setup(self, frame: SetupFrame): payload_from_frame(frame)) except Exception as exception: logger().error('%s: Setup error', self._log_identifier(), exc_info=True) - raise RSocketProtocolException(ErrorCode.REJECTED_SETUP, data=str(exception)) from exception + raise RSocketProtocolError(ErrorCode.REJECTED_SETUP, data=str(exception)) from exception def _subscribe_to_lease_publisher(self): if self._lease_publisher is not None: @@ -287,7 +321,7 @@ async def handle_request_channel(self, frame: RequestChannelFrame): channel_responder.frame_received(frame) async def handle_resume(self, frame: ResumeFrame): - raise RSocketProtocolException(ErrorCode.REJECTED_RESUME, data='Resume not supported') + raise RSocketProtocolError(ErrorCode.REJECTED_RESUME, data='Resume not supported') async def handle_lease(self, frame: LeaseFrame): logger().debug('%s: received lease frame', self._log_identifier()) @@ -309,9 +343,15 @@ async def _receiver(self): await self._receiver_listen() except asyncio.CancelledError: logger().debug('%s: Asyncio task canceled: receiver', self._log_identifier()) - except Exception: + except RSocketTransportError as exception: + await self._on_connection_lost(exception) + except Exception as exception: logger().error('%s: Unknown error', self._log_identifier(), exc_info=True) - raise + await self._on_connection_lost(exception) + + async def _on_connection_lost(self, exception: Exception): + logger().debug(str(exception)) + await self._handler.on_connection_lost(self, exception) @abc.abstractmethod def is_server_alive(self) -> bool: @@ -319,17 +359,19 @@ def is_server_alive(self) -> bool: async def _receiver_listen(self): + transport = await self._current_transport() while self.is_server_alive(): - - next_frame_generator = await self._transport.next_frame_generator(self.is_server_alive()) + next_frame_generator = await transport.next_frame_generator(self.is_server_alive()) if next_frame_generator is None: break async for frame in next_frame_generator: try: await self._handle_next_frame(frame) - except RSocketProtocolException as exception: - logger().error('%s: RSocket Error %s', self._log_identifier(), str(exception)) + except RSocketProtocolError as exception: + logger().error('%s: Protocol Error %s', self._log_identifier(), str(exception)) self.send_error(frame.stream_id, exception) + except RSocketTransportError: + raise except Exception as exception: logger().error('%s: Unknown Error', self._log_identifier(), exc_info=True) self.send_error(frame.stream_id, exception) @@ -376,32 +418,51 @@ async def _finally_sender(self): pass async def _sender(self): - self._before_sender() - try: - while self.is_server_alive(): - frame = await self._send_queue.get() + try: + transport = await self._current_transport() + + self._before_sender() + while self.is_server_alive(): + frame = await self._send_queue.get() + + await transport.send_frame(frame) + self._send_queue.task_done() - await self._transport.send_frame(frame) - self._send_queue.task_done() + if self._send_queue.empty(): + await transport.on_send_queue_empty() + except RSocketTransportError as exception: + await self._on_connection_lost(exception) - if self._send_queue.empty(): - await self._transport.on_send_queue_empty() - except ConnectionResetError as exception: - logger().debug(str(exception)) + except Exception as exception: + logger().error('%s: RSocket error', self._log_identifier(), exc_info=True) + await self._on_connection_lost(exception) + finally: + await self._finally_sender() except asyncio.CancelledError: logger().debug('%s: Asyncio task canceled: sender', self._log_identifier()) - except Exception: + except RSocketTransportError as exception: + await self._on_connection_lost(exception) + except Exception as exception: logger().error('%s: RSocket error', self._log_identifier(), exc_info=True) - raise - finally: - await self._finally_sender() + await self._on_connection_lost(exception) async def close(self): + logger().debug('%s: Closing', self._log_identifier()) self._is_closing = True await self._cancel_if_task_exists(self._sender_task) await self._cancel_if_task_exists(self._receiver_task) - await self._transport.close() + + if self._current_transport().done(): + logger().debug('%s: Closing transport', self._log_identifier()) + transport = self._current_transport().result() + + if transport is not None: + try: + await transport.close() + except Exception as exception: + logger().debug('Transport already closed or failed to close: %s', str(exception)) + pass async def _cancel_if_task_exists(self, task): if task is not None: @@ -409,7 +470,7 @@ async def _cancel_if_task_exists(self, task): try: await task except asyncio.CancelledError: - logger().debug('%s: Asyncio task canceled', self._log_identifier()) + logger().debug('%s: Asyncio task canceled: %s', self._log_identifier(), str(task)) async def __aenter__(self) -> 'RSocketBase': return self @@ -418,21 +479,21 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() def request_response(self, payload: Payload) -> Future: - logger().debug('%s: sending request-response: %s', self._log_identifier(), payload) + logger().debug('%s: request-response: %s', self._log_identifier(), payload) requester = RequestResponseRequester(self, payload) self.register_new_stream(requester).setup() return requester.run() def fire_and_forget(self, payload: Payload): - logger().debug('%s: sending fire-and-forget: %s', self._log_identifier(), payload) + logger().debug('%s: fire-and-forget: %s', self._log_identifier(), payload) stream_id = self._allocate_stream() self.send_request(to_fire_and_forget_frame(stream_id, payload)) self.finish_stream(stream_id) def request_stream(self, payload: Payload) -> Union[BackpressureApi, Publisher]: - logger().debug('%s: sending request-stream: %s', self._log_identifier(), payload) + logger().debug('%s: request-stream: %s', self._log_identifier(), payload) requester = RequestStreamRequester(self, payload) self.register_new_stream(requester) @@ -442,14 +503,15 @@ def request_channel( self, payload: Payload, local_publisher: Optional[Publisher] = None) -> Union[BackpressureApi, Publisher]: - - logger().debug('%s: sending request-channel: %s', self._log_identifier(), payload) + logger().debug('%s: request-channel: %s', self._log_identifier(), payload) requester = RequestChannelRequester(self, payload, local_publisher) self.register_new_stream(requester) return requester def metadata_push(self, metadata: bytes): + logger().debug('%s: metadata-push: %s', self._log_identifier(), metadata) + frame = MetadataPushFrame() frame.metadata = metadata self.send_frame(frame) diff --git a/rsocket/rsocket_client.py b/rsocket/rsocket_client.py index df811aa2..5222eee1 100644 --- a/rsocket/rsocket_client.py +++ b/rsocket/rsocket_client.py @@ -1,10 +1,13 @@ import asyncio +from asyncio import Future from datetime import timedelta, datetime -from typing import Optional, Type +from typing import Optional, Callable, AsyncGenerator, Any from typing import Union from reactivestreams.publisher import Publisher +from rsocket.exceptions import RSocketNoAvailableTransport, RSocketTransportError from rsocket.extensions.mimetypes import WellKnownMimeTypes +from rsocket.helpers import create_future from rsocket.logger import logger from rsocket.payload import Payload from rsocket.request_handler import BaseRequestHandler @@ -16,8 +19,8 @@ class RSocketClient(RSocketBase): def __init__(self, - transport: Transport, - handler_factory: Type[RequestHandler] = BaseRequestHandler, + transport_provider: AsyncGenerator[Transport, Any], + handler_factory: Callable[[RSocketBase], RequestHandler] = BaseRequestHandler, honor_lease=False, lease_publisher: Optional[Publisher] = None, request_queue_size: int = 0, @@ -25,13 +28,15 @@ def __init__(self, metadata_encoding: Union[bytes, WellKnownMimeTypes] = WellKnownMimeTypes.APPLICATION_JSON, keep_alive_period: timedelta = timedelta(milliseconds=500), max_lifetime_period: timedelta = timedelta(minutes=10), - setup_payload: Optional[Payload] = None + setup_payload: Optional[Payload] = None, ): + self._transport_provider = transport_provider.__aiter__() self._is_server_alive = True self._update_last_keepalive() + self._transport: Optional[Transport] = None + self._next_transport = asyncio.Future() - super().__init__(transport, - handler_factory=handler_factory, + super().__init__(handler_factory=handler_factory, honor_lease=honor_lease, lease_publisher=lease_publisher, request_queue_size=request_queue_size, @@ -41,11 +46,54 @@ def __init__(self, max_lifetime_period=max_lifetime_period, setup_payload=setup_payload) + def _current_transport(self) -> Future: + return self._next_transport + def _log_identifier(self) -> str: return 'client' + async def connect(self): + logger().debug('%s: connecting', self._log_identifier()) + + if self._current_transport().done(): + await self.close() + + self._is_closing = False + self._reset_internals() + self._start_tasks() + + try: + new_transport = await self._get_new_transport() + except Exception as exception: + logger().error('%s: Connection error', self._log_identifier(), exc_info=True) + self._handler.on_connection_lost(self, exception) + return + + if new_transport is None: + raise RSocketNoAvailableTransport() + + self._next_transport.set_result(new_transport) + transport = await self._current_transport() + + try: + await transport.connect() + except Exception as exception: + raise RSocketTransportError from exception + + return await super().connect() + + async def _get_new_transport(self): + try: + return await self._transport_provider.__anext__() + except StopAsyncIteration: + return + + async def close(self): + await super().close() + self._next_transport = create_future() + async def __aenter__(self) -> 'RSocketClient': - self.connect() + await self.connect() return self def _get_first_stream_id(self) -> int: @@ -82,7 +130,7 @@ async def _keepalive_timeout_task(self): self._is_server_alive = False await self._handler.on_keepalive_timeout( time_since_last_keepalive, - self._stream_control.cancel_all_handlers + self ) except asyncio.CancelledError: logger().debug('%s: Asyncio task canceled: keepalive_timeout', self._log_identifier()) diff --git a/rsocket/rsocket_internal.py b/rsocket/rsocket_internal.py index bdd38430..091a4085 100644 --- a/rsocket/rsocket_internal.py +++ b/rsocket/rsocket_internal.py @@ -28,3 +28,7 @@ def send_payload(self, stream_id: int, payload: Payload, complete=False, is_next @abc.abstractmethod def send_error(self, stream_id: int, exception: Exception): ... + + @abc.abstractmethod + def close_all_streams(self): + ... diff --git a/rsocket/rsocket_server.py b/rsocket/rsocket_server.py index e62cd346..016666f0 100644 --- a/rsocket/rsocket_server.py +++ b/rsocket/rsocket_server.py @@ -1,8 +1,47 @@ +from asyncio import Future +from datetime import timedelta +from typing import Optional, Union, Callable + +from reactivestreams.publisher import Publisher +from rsocket.extensions.mimetypes import WellKnownMimeTypes +from rsocket.helpers import create_future +from rsocket.payload import Payload +from rsocket.request_handler import RequestHandler, BaseRequestHandler from rsocket.rsocket_base import RSocketBase +from rsocket.transports.transport import Transport class RSocketServer(RSocketBase): + def __init__(self, + transport: Transport, + handler_factory: Callable[[RSocketBase], RequestHandler] = BaseRequestHandler, + honor_lease=False, + lease_publisher: Optional[Publisher] = None, + request_queue_size: int = 0, + data_encoding: Union[str, bytes, WellKnownMimeTypes] = WellKnownMimeTypes.APPLICATION_JSON, + metadata_encoding: Union[str, bytes, WellKnownMimeTypes] = WellKnownMimeTypes.APPLICATION_JSON, + keep_alive_period: timedelta = timedelta(milliseconds=500), + max_lifetime_period: timedelta = timedelta(minutes=10), + setup_payload: Optional[Payload] = None): + super().__init__(handler_factory, + honor_lease, + lease_publisher, + request_queue_size, + data_encoding, + metadata_encoding, + keep_alive_period, + max_lifetime_period, + setup_payload) + self._transport = transport + + def _current_transport(self) -> Future: + return create_future(self._transport) + + def _setup_internals(self): + self._reset_internals() + self._start_tasks() + def _log_identifier(self) -> str: return 'server' diff --git a/rsocket/rx_support/rx_rsocket.py b/rsocket/rx_support/rx_rsocket.py index f394c970..c3073ac8 100644 --- a/rsocket/rx_support/rx_rsocket.py +++ b/rsocket/rx_support/rx_rsocket.py @@ -41,8 +41,8 @@ def fire_and_forget(self, request: Payload): def metadata_push(self, metadata: bytes): self._rsocket.metadata_push(metadata) - def connect(self): - return self._rsocket.connect() + async def connect(self): + return await self._rsocket.connect() async def close(self): await self._rsocket.close() diff --git a/rsocket/streams/stream_handler.py b/rsocket/streams/stream_handler.py index 98aa69f9..8063542a 100644 --- a/rsocket/streams/stream_handler.py +++ b/rsocket/streams/stream_handler.py @@ -1,7 +1,7 @@ from abc import abstractmethod, ABCMeta from typing import Optional -from rsocket.exceptions import RSocketValueErrorException +from rsocket.exceptions import RSocketValueError from rsocket.frame import Frame, MAX_REQUEST_N from rsocket.frame_builders import to_cancel_frame, to_request_n_frame from rsocket.logger import logger @@ -22,7 +22,7 @@ def setup(self): def initial_request_n(self, n: int): if n <= 0: self.socket.finish_stream(self.stream_id) - raise RSocketValueErrorException('Initial request N must be > 0') + raise RSocketValueError('Initial request N must be > 0') self._initial_request_n = n return self diff --git a/rsocket/transports/aiohttp_websocket.py b/rsocket/transports/aiohttp_websocket.py index bba0793d..39739607 100644 --- a/rsocket/transports/aiohttp_websocket.py +++ b/rsocket/transports/aiohttp_websocket.py @@ -5,6 +5,7 @@ from aiohttp import web from rsocket.frame import Frame +from rsocket.helpers import wrap_transport_exception, single_transport_provider from rsocket.logger import logger from rsocket.rsocket_client import RSocketClient from rsocket.rsocket_server import RSocketServer @@ -13,15 +14,9 @@ @asynccontextmanager async def websocket_client(url, *args, **kwargs) -> RSocketClient: - async with aiohttp.ClientSession() as session: - async with session.ws_connect(url) as ws: - transport = TransportAioHttpWebsocket(ws) - message_handler = asyncio.create_task(transport.handle_incoming_ws_messages()) - async with RSocketClient(transport, *args, **kwargs) as client: - yield client - - message_handler.cancel() - await message_handler + async with RSocketClient(single_transport_provider(TransportAioHttpClient(url)), + *args, **kwargs) as client: + yield client def websocket_handler_factory(*args, on_server_create=None, **kwargs): @@ -40,22 +35,58 @@ async def websocket_handler(request): return websocket_handler -class TransportAioHttpWebsocket(AbstractWebsocketTransport): - def __init__(self, websocket): +class TransportAioHttpClient(AbstractWebsocketTransport): + + def __init__(self, url): super().__init__() - self._ws = websocket + self._url = url + + async def connect(self): + self._session = aiohttp.ClientSession() + self._ws_context = self._session.ws_connect(self._url) + self._ws = await self._ws_context.__aenter__() + self._message_handler = asyncio.create_task(self.handle_incoming_ws_messages()) async def handle_incoming_ws_messages(self): - try: + with wrap_transport_exception(): async for msg in self._ws: if msg.type == aiohttp.WSMsgType.BINARY: async for frame in self._frame_parser.receive_data(msg.data, 0): self._incoming_frame_queue.put_nowait(frame) + + async def send_frame(self, frame: Frame): + with wrap_transport_exception(): + await self._ws.send_bytes(frame.serialize()) + + async def close(self): + await self._ws_context.__aexit__(None, None, None) + await self._session.__aexit__(None, None, None) + self._message_handler.cancel() + await self._message_handler + + +class TransportAioHttpWebsocket(AbstractWebsocketTransport): + def __init__(self, websocket): + super().__init__() + self._ws = websocket + + async def _message_generator(self): + with wrap_transport_exception(): + async for msg in self._ws: + if msg.type == aiohttp.WSMsgType.BINARY: + yield msg.data + + async def handle_incoming_ws_messages(self): + try: + async for message in self._message_generator(): + async for frame in self._frame_parser.receive_data(message, 0): + self._incoming_frame_queue.put_nowait(frame) except asyncio.CancelledError: logger().debug('Asyncio task canceled: aiohttp_handle_incoming_ws_messages') async def send_frame(self, frame: Frame): - await self._ws.send_bytes(frame.serialize()) + with wrap_transport_exception(): + await self._ws.send_bytes(frame.serialize()) async def close(self): await self._ws.close() diff --git a/rsocket/transports/tcp.py b/rsocket/transports/tcp.py index f99a16bf..376123c2 100644 --- a/rsocket/transports/tcp.py +++ b/rsocket/transports/tcp.py @@ -1,7 +1,7 @@ from asyncio import StreamReader, StreamWriter from rsocket.frame import Frame, serialize_with_frame_size_header -from rsocket.logger import logger +from rsocket.helpers import wrap_transport_exception from rsocket.transports.transport import Transport @@ -12,24 +12,23 @@ def __init__(self, reader: StreamReader, writer: StreamWriter): self._reader = reader async def send_frame(self, frame: Frame): - self._writer.write(serialize_with_frame_size_header(frame)) + with wrap_transport_exception(): + self._writer.write(serialize_with_frame_size_header(frame)) async def on_send_queue_empty(self): - await self._writer.drain() + with wrap_transport_exception(): + await self._writer.drain() async def close(self): self._writer.close() await self._writer.wait_closed() async def next_frame_generator(self, is_server_alive): - try: + with wrap_transport_exception(): data = await self._reader.read(1024) - except (ConnectionResetError, BrokenPipeError) as exception: - logger().debug(str(exception)) - return # todo: workaround to silence errors on client closing. this needs a better solution. - if not data: - self._writer.close() - return + if not data: + self._writer.close() + return return self._frame_parser.receive_data(data) diff --git a/rsocket/transports/transport.py b/rsocket/transports/transport.py index 29b6badf..9ef41777 100644 --- a/rsocket/transports/transport.py +++ b/rsocket/transports/transport.py @@ -9,6 +9,9 @@ class Transport(metaclass=abc.ABCMeta): def __init__(self): self._frame_parser = FrameParser() + async def connect(self): + """"Optional if required""" + @abc.abstractmethod async def send_frame(self, frame: Frame): ... diff --git a/tests/conftest.py b/tests/conftest.py index d93a762c..3e36ce1e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,13 +11,14 @@ from quart import Quart from rsocket.frame_parser import FrameParser -from rsocket.logger import logger +from rsocket.helpers import single_transport_provider from rsocket.rsocket_base import RSocketBase from rsocket.rsocket_client import RSocketClient from rsocket.rsocket_server import RSocketServer from rsocket.transports.aiohttp_websocket import websocket_client, websocket_handler_factory from rsocket.transports.quart_websocket import websocket_handler from rsocket.transports.tcp import TransportTCP +from tests.rsocket.helpers import assert_no_open_streams logging.basicConfig(level=logging.DEBUG) @@ -96,16 +97,16 @@ async def start(): nonlocal service, client service = await asyncio.start_server(session, host, port) connection = await asyncio.open_connection(host, port) - nonlocal client_arguments # test_overrides = {'keep_alive_period': timedelta(minutes=20)} client_arguments = client_arguments or {} + # client_arguments.update(test_overrides) - client = RSocketClient(TransportTCP(*connection), **(client_arguments or {})) + client = RSocketClient(single_transport_provider(TransportTCP(*connection)), **(client_arguments or {})) if auto_connect_client: - client.connect() + await client.connect() async def finish(): if auto_connect_client: @@ -122,26 +123,25 @@ async def finish(): host = 'localhost' await start() - await wait_for_server.wait() + + async def server_provider(): + await wait_for_server.wait() + return server + try: - yield server, client + if auto_connect_client: + await wait_for_server.wait() + yield server, client + else: + yield server_provider, client + assert_no_open_streams(client, server) finally: await finish() -def assert_no_open_streams(client: RSocketBase, server: RSocketBase): - logger().info('Checking for open client streams') - - assert len(client._stream_control._streams) == 0, 'Client has open streams' - - logger().info('Checking for open server streams') - - assert len(server._stream_control._streams) == 0, 'Server has open streams' - - @pytest.fixture -def connection(): +def frame_parser(): return FrameParser() diff --git a/tests/rsocket/helpers.py b/tests/rsocket/helpers.py index a50dda96..458b9f84 100644 --- a/tests/rsocket/helpers.py +++ b/tests/rsocket/helpers.py @@ -1,7 +1,11 @@ from math import ceil +from typing import Type from rsocket.helpers import create_future +from rsocket.logger import logger from rsocket.payload import Payload +from rsocket.request_handler import BaseRequestHandler +from rsocket.rsocket_base import RSocketBase def data_bits(data: bytes, name: str = None): @@ -24,6 +28,35 @@ def bits(bit_count, value, comment) -> str: return f'{value:b}'.zfill(bit_count) -def future_from_request(request: Payload): +def future_from_payload(request: Payload): return create_future(Payload(b'data: ' + request.data, b'meta: ' + request.metadata)) + + +def assert_no_open_streams(client: RSocketBase, server: RSocketBase): + logger().info('Checking for open client streams') + + assert len(client._stream_control._streams) == 0, 'Client has open streams' + + logger().info('Checking for open server streams') + + assert len(server._stream_control._streams) == 0, 'Server has open streams' + + +class IdentifiedHandler(BaseRequestHandler): + def __init__(self, socket, server_id: int): + super().__init__(socket) + self._server_id = server_id + + +class IdentifiedHandlerFactory: + def __init__(self, server_id: int, handler_factory: Type[IdentifiedHandler]): + self._server_id = server_id + self._handler_factory = handler_factory + + def factory(self, socket) -> BaseRequestHandler: + return self._handler_factory(socket, self._server_id) + + +def force_closing_connection(current_connection): + current_connection[1].close() diff --git a/tests/rsocket/misbehaving_rsocket.py b/tests/rsocket/misbehaving_rsocket.py index 95e43458..3892969c 100644 --- a/tests/rsocket/misbehaving_rsocket.py +++ b/tests/rsocket/misbehaving_rsocket.py @@ -3,11 +3,11 @@ class MisbehavingRSocket: - def __init__(self, socket: Transport): - self._socket = socket + def __init__(self, transport: Transport): + self._transport = transport async def send_frame(self, frame: Frame): - await self._socket.send_frame(frame) + await self._transport.send_frame(frame) class BrokenFrame: diff --git a/tests/rsocket/test_connection_lost.py b/tests/rsocket/test_connection_lost.py new file mode 100644 index 00000000..2ff8d2d6 --- /dev/null +++ b/tests/rsocket/test_connection_lost.py @@ -0,0 +1,161 @@ +import asyncio +from asyncio import Event, Future +from asyncio.base_events import Server +from typing import Optional, Tuple + +from rsocket.frame import Frame +from rsocket.logger import logger +from rsocket.payload import Payload +from rsocket.request_handler import BaseRequestHandler +from rsocket.rsocket_client import RSocketClient +from rsocket.rsocket_server import RSocketServer +from rsocket.transports.tcp import TransportTCP +from rsocket.transports.transport import Transport +from tests.rsocket.helpers import future_from_payload, IdentifiedHandlerFactory, \ + IdentifiedHandler, force_closing_connection + + +class ServerHandler(IdentifiedHandler): + async def request_response(self, payload: Payload) -> Future: + return future_from_payload(Payload(payload.data + (' server %d' % self._server_id).encode(), payload.metadata)) + + +async def test_connection_lost(unused_tcp_port): + index_iterator = iter(range(1, 3)) + + wait_for_server = Event() + server_connection: Tuple = None + client_connection: Tuple = None + + class ClientHandler(BaseRequestHandler): + async def on_connection_lost(self, rsocket, exception: Exception): + logger().info('Reconnecting') + await rsocket.connect() + + def session(*connection): + nonlocal server, server_connection + server_connection = connection + server = RSocketServer(TransportTCP(*connection), + IdentifiedHandlerFactory(next(index_iterator), ServerHandler).factory) + wait_for_server.set() + + async def start(): + nonlocal service, client + service = await asyncio.start_server(session, host, port) + + async def transport_provider(): + while True: + try: + nonlocal client_connection + client_connection = await asyncio.open_connection(host, port) + yield TransportTCP(*client_connection) + except Exception: + logger().error('Client connection error', exc_info=True) + raise + + client = RSocketClient(transport_provider(), handler_factory=ClientHandler) + + service: Optional[Server] = None + server: Optional[RSocketServer] = None + client: Optional[RSocketClient] = None + port = unused_tcp_port + host = 'localhost' + + await start() + + try: + async with client as connection: + await wait_for_server.wait() + wait_for_server.clear() + response1 = await connection.request_response(Payload(b'request 1')) + force_closing_connection(server_connection) + await server.close() # cleanup async tasks from previous server to avoid errors (?) + await wait_for_server.wait() + response2 = await connection.request_response(Payload(b'request 2')) + + assert response1.data == b'data: request 1 server 1' + assert response2.data == b'data: request 2 server 2' + finally: + await server.close() + + service.close() + + +class FailingTransportTCP(Transport): + + async def connect(self): + raise Exception + + async def send_frame(self, frame: Frame): + pass + + async def next_frame_generator(self, is_server_alive): + pass + + async def close(self): + pass + + +async def test_connection_failure(unused_tcp_port): + index_iterator = iter(range(1, 3)) + + wait_for_server = Event() + server_connection: Tuple = None + client_connection: Tuple = None + + class ClientHandler(BaseRequestHandler): + async def on_connection_lost(self, rsocket, exception: Exception): + logger().info('Reconnecting') + await rsocket.connect() + + def session(*connection): + nonlocal server, server_connection + server_connection = connection + server = RSocketServer(TransportTCP(*connection), + IdentifiedHandlerFactory(next(index_iterator), ServerHandler).factory) + wait_for_server.set() + + async def start(): + nonlocal service, client + service = await asyncio.start_server(session, host, port) + + async def transport_provider(): + try: + nonlocal client_connection + client_connection = await asyncio.open_connection(host, port) + yield TransportTCP(*client_connection) + + yield FailingTransportTCP() + + client_connection = await asyncio.open_connection(host, port) + yield TransportTCP(*client_connection) + except Exception: + logger().error('Client connection error', exc_info=True) + raise + + client = RSocketClient(transport_provider(), handler_factory=ClientHandler) + + service: Optional[Server] = None + server: Optional[RSocketServer] = None + client: Optional[RSocketClient] = None + port = unused_tcp_port + host = 'localhost' + + await start() + + try: + async with client as connection: + await wait_for_server.wait() + wait_for_server.clear() + response1 = await connection.request_response(Payload(b'request 1')) + force_closing_connection(server_connection) + await server.close() # cleanup async tasks from previous server to avoid errors (?) + await wait_for_server.wait() + response2 = await connection.request_response(Payload(b'request 2')) + + assert response1.data == b'data: request 1 server 1' + assert response2.data == b'data: request 2 server 2' + finally: + await server.close() + + service.close() diff --git a/tests/rsocket/test_frame.py b/tests/rsocket/test_frame.py index 01075694..fb2ce898 100644 --- a/tests/rsocket/test_frame.py +++ b/tests/rsocket/test_frame.py @@ -2,7 +2,7 @@ import pytest from rsocket.error_codes import ErrorCode -from rsocket.exceptions import RSocketProtocolException +from rsocket.exceptions import RSocketProtocolError from rsocket.extensions.authentication_types import WellKnownAuthenticationTypes from rsocket.extensions.composite_metadata import CompositeMetadata from rsocket.extensions.mimetypes import WellKnownMimeTypes @@ -20,7 +20,7 @@ (1, b'\x04\x05\x06\x07\x08', 1, b''), (1, b'\x04\x05\x06\x07\x08', 0, b''), )) -async def test_setup_readable(connection, metadata_flag, metadata, lease, data): +async def test_setup_readable(frame_parser, metadata_flag, metadata, lease, data): def variable_length(): length = len(data) if metadata_flag != 0: @@ -65,7 +65,7 @@ def variable_length(): frame_data = build_frame(*items) - frames = await asyncstdlib.builtins.list(connection.receive_data(frame_data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(frame_data)) frame = frames[0] assert isinstance(frame, SetupFrame) assert serialize_with_frame_size_header(frame) == frame_data @@ -88,7 +88,7 @@ def variable_length(): (0), (1) )) -async def test_setup_with_resume(connection, lease): +async def test_setup_with_resume(frame_parser, lease): data = build_frame( bits(24, 84, 'Frame size'), bits(1, 0, 'Padding'), @@ -124,7 +124,7 @@ async def test_setup_with_resume(connection, lease): data_bits(b'\x01\x02\x03'), ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, SetupFrame) assert serialize_with_frame_size_header(frame) == data @@ -147,7 +147,7 @@ async def test_setup_with_resume(connection, lease): 0, 1 )) -async def test_request_stream_frame(connection, follows): +async def test_request_stream_frame(frame_parser, follows): data = build_frame( bits(24, 32, 'Frame size'), bits(1, 0, 'Padding'), @@ -173,7 +173,7 @@ async def test_request_stream_frame(connection, follows): data_bits(b'\x01\x02\x03'), ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, RequestStreamFrame) assert serialize_with_frame_size_header(frame) == data @@ -199,7 +199,7 @@ async def test_request_stream_frame(connection, follows): (0, 0), (1, 0) )) -async def test_request_channel_frame(connection, follows, complete): +async def test_request_channel_frame(frame_parser, follows, complete): data = build_frame( bits(24, 32, 'Frame size'), bits(1, 0, 'Padding'), @@ -226,7 +226,7 @@ async def test_request_channel_frame(connection, follows, complete): data_bits(b'\x01\x02\x03'), ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, RequestChannelFrame) assert serialize_with_frame_size_header(frame) == data @@ -265,7 +265,7 @@ def test_basic_composite_metadata_item(): assert data == serialized -async def test_request_with_composite_metadata(connection): +async def test_request_with_composite_metadata(frame_parser): data = build_frame( bits(24, 28, 'Frame size'), bits(1, 0, 'Padding'), @@ -289,7 +289,7 @@ async def test_request_with_composite_metadata(connection): data_bits(b'\x01\x02\x03'), ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, RequestResponseFrame) assert serialize_with_frame_size_header(frame) == data @@ -359,7 +359,7 @@ async def test_composite_metadata_multiple_items(): assert composite_metadata.serialize() == data -async def test_cancel(connection): +async def test_cancel(frame_parser): data = build_frame( bits(24, 6, 'Frame size'), bits(1, 0, 'Padding'), @@ -371,14 +371,14 @@ async def test_cancel(connection): bits(8, 0, 'Padding flags'), ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, CancelFrame) assert frame.frame_type is FrameType.CANCEL assert serialize_with_frame_size_header(frame) == data -async def test_error(connection): +async def test_error(frame_parser): data = build_frame( bits(24, 20, 'Frame size'), bits(1, 0, 'Padding'), @@ -393,7 +393,7 @@ async def test_error(connection): data_bits(b'error data') ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, ErrorFrame) assert serialize_with_frame_size_header(frame) == data @@ -403,7 +403,7 @@ async def test_error(connection): assert frame.frame_type is FrameType.ERROR -async def test_request_n_frame(connection): +async def test_request_n_frame(frame_parser): data = build_frame( bits(24, 10, 'Frame size'), bits(1, 0, 'Padding'), @@ -418,7 +418,7 @@ async def test_request_n_frame(connection): bits(31, 23, 'Number of frames to request'), ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, RequestNFrame) assert serialize_with_frame_size_header(frame) == data @@ -427,7 +427,7 @@ async def test_request_n_frame(connection): assert frame.frame_type is FrameType.REQUEST_N -async def test_resume_frame(connection): +async def test_resume_frame(frame_parser): data = build_frame( bits(24, 40, 'Frame size'), bits(1, 0, 'Padding'), @@ -449,7 +449,7 @@ async def test_resume_frame(connection): bits(63, 456, 'first client position'), ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, ResumeFrame) assert frame.last_server_position == 123 @@ -458,7 +458,7 @@ async def test_resume_frame(connection): assert serialize_with_frame_size_header(frame) == data -async def test_metadata_push_frame(connection): +async def test_metadata_push_frame(frame_parser): data = build_frame( bits(24, 14, 'Frame size'), bits(1, 0, 'Padding'), @@ -471,7 +471,7 @@ async def test_metadata_push_frame(connection): data_bits(b'metadata') ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, MetadataPushFrame) assert frame.metadata == b'metadata' @@ -479,7 +479,7 @@ async def test_metadata_push_frame(connection): assert serialize_with_frame_size_header(frame) == data -async def test_payload_frame(connection): +async def test_payload_frame(frame_parser): data = build_frame( bits(24, 28, 'Frame size'), bits(1, 0, 'Padding'), @@ -494,7 +494,7 @@ async def test_payload_frame(connection): data_bits(b'actual_data'), ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, PayloadFrame) assert frame.metadata == b'metadata' @@ -504,7 +504,7 @@ async def test_payload_frame(connection): assert serialize_with_frame_size_header(frame) == data -async def test_lease_frame(connection): +async def test_lease_frame(frame_parser): data = build_frame( bits(24, 37, 'Frame size'), bits(1, 0, 'Padding'), @@ -521,7 +521,7 @@ async def test_lease_frame(connection): data_bits(b'Metadata on lease frame') ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, LeaseFrame) assert frame.number_of_requests == 123 @@ -531,7 +531,7 @@ async def test_lease_frame(connection): assert serialize_with_frame_size_header(frame) == data -async def test_resume_ok_frame(connection): +async def test_resume_ok_frame(frame_parser): data = build_frame( bits(24, 14, 'Frame size'), bits(1, 0, 'Padding'), @@ -545,7 +545,7 @@ async def test_resume_ok_frame(connection): bits(63, 456, 'Last Received Client Position') ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, ResumeOKFrame) assert frame.last_received_client_position == 456 @@ -553,7 +553,7 @@ async def test_resume_ok_frame(connection): assert serialize_with_frame_size_header(frame) == data -async def test_keepalive_frame(connection): +async def test_keepalive_frame(frame_parser): data = build_frame( bits(24, 29, 'Frame size'), bits(1, 0, 'Padding'), @@ -568,7 +568,7 @@ async def test_keepalive_frame(connection): data_bits(b'additional data') ) - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) frame = frames[0] assert isinstance(frame, KeepAliveFrame) @@ -598,5 +598,5 @@ def test_parse_broken_frame_raises_exception(): bits(13, 23, 'Number of frames to request - broken. smaller than 31 bits'), ) - with pytest.raises(RSocketProtocolException): + with pytest.raises(RSocketProtocolError): parse_or_ignore(broken_frame_data) diff --git a/tests/rsocket/test_frame_decode.py b/tests/rsocket/test_frame_decode.py index 779661ea..fa7b32b3 100644 --- a/tests/rsocket/test_frame_decode.py +++ b/tests/rsocket/test_frame_decode.py @@ -1,5 +1,9 @@ +from typing import cast + import asyncstdlib +from rsocket.extensions.authentication import AuthenticationSimple +from rsocket.extensions.authentication_content import AuthenticationContent from rsocket.extensions.composite_metadata import CompositeMetadata @@ -12,17 +16,18 @@ async def test_decode_spring_demo_auth(): assert len(composite_metadata.items) == 2 - authentication = composite_metadata.items[1].authentication + composite_item = cast(AuthenticationContent, composite_metadata.items[1]) + authentication = cast(AuthenticationSimple, composite_item.authentication) assert authentication.username == b'user' assert authentication.password == b'pass' -async def test_multiple_frames(connection): +async def test_multiple_frames(frame_parser): data = b'\x00\x00\x06\x00\x00\x00\x7b\x24\x00' data += b'\x00\x00\x13\x00\x00\x26\x6a\x2c\x00\x00\x00\x02\x04\x77\x65\x69' data += b'\x72\x64\x6e\x65\x73\x73' data += b'\x00\x00\x06\x00\x00\x00\x7b\x24\x00' data += b'\x00\x00\x06\x00\x00\x00\x7b\x24\x00' data += b'\x00\x00\x06\x00\x00\x00\x7b\x24\x00' - frames = await asyncstdlib.builtins.list(connection.receive_data(data)) + frames = await asyncstdlib.builtins.list(frame_parser.receive_data(data)) assert len(frames) == 5 diff --git a/tests/rsocket/test_lease.py b/tests/rsocket/test_lease.py index 9bde2892..2e02bc76 100644 --- a/tests/rsocket/test_lease.py +++ b/tests/rsocket/test_lease.py @@ -4,11 +4,11 @@ import pytest from reactivestreams.subscriber import Subscriber -from rsocket.exceptions import RSocketProtocolException +from rsocket.exceptions import RSocketProtocolError from rsocket.lease import SingleLeasePublisher, DefinedLease from rsocket.payload import Payload from rsocket.request_handler import BaseRequestHandler -from tests.rsocket.helpers import future_from_request +from tests.rsocket.helpers import future_from_payload class PeriodicalLeasePublisher(SingleLeasePublisher): @@ -44,7 +44,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): async def test_request_response_with_server_side_lease_works(lazy_pipe): class Handler(BaseRequestHandler): async def request_response(self, request: Payload): - return future_from_request(request) + return future_from_payload(request) async with lazy_pipe(client_arguments={'honor_lease': True}, server_arguments={'handler_factory': Handler, @@ -59,7 +59,7 @@ async def request_response(self, request: Payload): async def test_request_response_with_client_and_server_side_lease_works(lazy_pipe): class Handler(BaseRequestHandler): async def request_response(self, request: Payload): - return future_from_request(request) + return future_from_payload(request) async with PeriodicalLeasePublisher( maximum_request_count=2, @@ -88,7 +88,7 @@ async def request_response(self, request: Payload): async def test_request_response_with_lease_too_many_requests(lazy_pipe): class Handler(BaseRequestHandler): async def request_response(self, request: Payload): - return future_from_request(request) + return future_from_payload(request) async with lazy_pipe(client_arguments={'honor_lease': True}, server_arguments={'handler_factory': Handler, @@ -106,7 +106,7 @@ async def request_response(self, request: Payload): async def test_request_response_with_lease_client_side_exception_requests_late(lazy_pipe): class Handler(BaseRequestHandler): async def request_response(self, request: Payload): - return future_from_request(request) + return future_from_payload(request) async with lazy_pipe(client_arguments={'honor_lease': True}, server_arguments={'handler_factory': Handler, @@ -134,7 +134,7 @@ async def test_server_rejects_all_requests_if_lease_not_supported(lazy_pipe): async def test_request_response_with_lease_server_side_exception(lazy_pipe): class Handler(BaseRequestHandler): async def request_response(self, request: Payload): - return future_from_request(request) + return future_from_payload(request) async with lazy_pipe(client_arguments={'honor_lease': True}, server_arguments={'handler_factory': Handler, @@ -145,5 +145,5 @@ async def request_response(self, request: Payload): response = await client.request_response(Payload(b'dog', b'cat')) assert response == Payload(b'data: dog', b'meta: cat') - with pytest.raises(RSocketProtocolException): + with pytest.raises(RSocketProtocolError): await client.request_response(Payload(b'invalid request')) diff --git a/tests/rsocket/test_load_balancer.py b/tests/rsocket/test_load_balancer.py index 7ef62f84..f1829ffb 100644 --- a/tests/rsocket/test_load_balancer.py +++ b/tests/rsocket/test_load_balancer.py @@ -4,22 +4,17 @@ from rsocket.load_balancer.load_balancer_rsocket import LoadBalancerRSocket from rsocket.load_balancer.round_robin import LoadBalancerRoundRobin from rsocket.payload import Payload -from rsocket.request_handler import BaseRequestHandler from tests.conftest import pipe_factory_tcp -from tests.rsocket.helpers import future_from_request +from tests.rsocket.helpers import future_from_payload, IdentifiedHandlerFactory, IdentifiedHandler -class Handler(BaseRequestHandler): - def __init__(self, socket, server_id: int): - super().__init__(socket) - self.server_id = server_id +class Handler(IdentifiedHandler): async def request_response(self, request: Payload): - return future_from_request(Payload(request.data + (' server %d' % self.server_id).encode(), request.metadata)) + return future_from_payload(Payload(request.data + (' server %d' % self._server_id).encode(), request.metadata)) async def test_load_balancer_round_robin(unused_tcp_port_factory): - servers = [] clients = [] server_count = 3 request_count = 7 @@ -27,11 +22,10 @@ async def test_load_balancer_round_robin(unused_tcp_port_factory): async with AsyncExitStack() as stack: for i in range(server_count): tcp_port = unused_tcp_port_factory() - server, client = await stack.enter_async_context( + _, client = await stack.enter_async_context( pipe_factory_tcp(tcp_port, - server_arguments={'handler_factory': lambda socket: Handler(socket, i)}, + server_arguments={'handler_factory': IdentifiedHandlerFactory(i, Handler).factory}, auto_connect_client=False)) - servers.append(server) clients.append(client) round_robin = LoadBalancerRoundRobin(clients) diff --git a/tests/rsocket/test_request_response.py b/tests/rsocket/test_request_response.py index 74ec7d86..6db97ba0 100644 --- a/tests/rsocket/test_request_response.py +++ b/tests/rsocket/test_request_response.py @@ -7,13 +7,13 @@ from rsocket.helpers import create_future from rsocket.payload import Payload from rsocket.request_handler import BaseRequestHandler -from tests.rsocket.helpers import future_from_request +from tests.rsocket.helpers import future_from_payload async def test_request_response_awaitable_wrapper(pipe): class Handler(BaseRequestHandler): async def request_response(self, request: Payload): - return future_from_request(request) + return future_from_payload(request) server, client = pipe server._handler = Handler(server) @@ -25,7 +25,7 @@ async def request_response(self, request: Payload): async def test_request_response_repeated(pipe): class Handler(BaseRequestHandler): async def request_response(self, request: Payload): - return future_from_request(request) + return future_from_payload(request) server, client = pipe server._handler = Handler(server) diff --git a/tests/rsocket/test_request_stream.py b/tests/rsocket/test_request_stream.py index 78aa7028..d1ddd21c 100644 --- a/tests/rsocket/test_request_stream.py +++ b/tests/rsocket/test_request_stream.py @@ -8,7 +8,7 @@ from reactivestreams.subscriber import DefaultSubscriber, Subscriber from rsocket.awaitable.awaitable_rsocket import AwaitableRSocket from rsocket.awaitable.collector_subscriber import CollectorSubscriber -from rsocket.exceptions import RSocketValueErrorException +from rsocket.exceptions import RSocketValueError from rsocket.frame_helpers import ensure_bytes from rsocket.helpers import DefaultPublisherSubscription from rsocket.payload import Payload @@ -52,7 +52,7 @@ async def test_request_stream_prevent_negative_initial_request_n(pipe: Tuple[RSo initial_request_n): server, client = pipe - with pytest.raises(RSocketValueErrorException): + with pytest.raises(RSocketValueError): client.request_stream(Payload()).initial_request_n(initial_request_n) diff --git a/tests/rsocket/test_resume_unsupported.py b/tests/rsocket/test_resume_unsupported.py index a4808e1c..a57f7013 100644 --- a/tests/rsocket/test_resume_unsupported.py +++ b/tests/rsocket/test_resume_unsupported.py @@ -15,7 +15,7 @@ @pytest.mark.allow_error_log async def test_setup_resume_unsupported(pipe_tcp_without_auto_connect: Tuple[RSocketServer, RSocketClient]): - server, client = pipe_tcp_without_auto_connect + _, client = pipe_tcp_without_auto_connect received_error_code = None error_received = asyncio.Event() @@ -25,26 +25,27 @@ async def on_error(self, error_code: ErrorCode, payload: Payload): received_error_code = error_code error_received.set() - client.set_handler_using_factory(Handler) - bad_client = MisbehavingRSocket(client._transport) + client.set_handler_factory(Handler) - setup = SetupFrame() - setup.flags_lease = False - setup.flags_resume = True - setup.token_length = 1 - setup.resume_identification_token = b'a' - setup.keep_alive_milliseconds = 123 - setup.max_lifetime_milliseconds = 456 - setup.data_encoding = WellKnownMimeTypes.APPLICATION_JSON.name.encode() - setup.metadata_encoding = WellKnownMimeTypes.APPLICATION_JSON.name.encode() + async with client as connected_client: + transport = await connected_client._current_transport() + bad_client = MisbehavingRSocket(transport) - await bad_client.send_frame(setup) + setup = SetupFrame() + setup.flags_lease = False + setup.flags_resume = True + setup.token_length = 1 + setup.resume_identification_token = b'a' + setup.keep_alive_milliseconds = 123 + setup.max_lifetime_milliseconds = 456 + setup.data_encoding = WellKnownMimeTypes.APPLICATION_JSON.name.encode() + setup.metadata_encoding = WellKnownMimeTypes.APPLICATION_JSON.name.encode() - await error_received.wait() + await bad_client.send_frame(setup) - await client.close() + await error_received.wait() - assert received_error_code == ErrorCode.UNSUPPORTED_SETUP + assert received_error_code == ErrorCode.UNSUPPORTED_SETUP @pytest.mark.allow_error_log @@ -62,7 +63,8 @@ async def on_error(self, error_code: ErrorCode, payload: Payload): client.set_handler_using_factory(Handler) - bad_client = MisbehavingRSocket(client._transport) + transport = await client._current_transport() + bad_client = MisbehavingRSocket(transport) resume = ResumeFrame() resume.token_length = 1 diff --git a/tests/rsocket/test_rsocket.py b/tests/rsocket/test_rsocket.py index 4dad0e85..a0140667 100644 --- a/tests/rsocket/test_rsocket.py +++ b/tests/rsocket/test_rsocket.py @@ -1,15 +1,15 @@ import asyncio import logging from datetime import timedelta -from typing import Callable import pytest from rsocket.error_codes import ErrorCode -from rsocket.exceptions import RSocketProtocolException +from rsocket.exceptions import RSocketProtocolError from rsocket.helpers import create_future from rsocket.payload import Payload from rsocket.request_handler import BaseRequestHandler +from rsocket.rsocket_internal import RSocketInternal async def test_rsocket_client_closed_without_requests(lazy_pipe): @@ -46,8 +46,8 @@ class ClientHandler(BaseRequestHandler): async def on_keepalive_timeout(self, time_since_last_keepalive: timedelta, - cancel_all_streams: Callable): - cancel_all_streams() + socket: RSocketInternal): + socket.close_all_streams() async with lazy_pipe( client_arguments={ @@ -55,7 +55,7 @@ async def on_keepalive_timeout(self, 'max_lifetime_period': timedelta(seconds=1), 'handler_factory': ClientHandler}, server_arguments={'handler_factory': Handler}) as (server, client): - with pytest.raises(RSocketProtocolException) as exc_info: + with pytest.raises(RSocketProtocolError) as exc_info: await client.request_response(Payload(b'dog', b'cat')) assert exc_info.value.data == 'Server not alive' diff --git a/tests/rx_support/test_rx_support.py b/tests/rx_support/test_rx_support.py index 820fc0a3..87376b82 100644 --- a/tests/rx_support/test_rx_support.py +++ b/tests/rx_support/test_rx_support.py @@ -172,10 +172,12 @@ class Handler(BaseRequestHandler): async def request_response(self, payload: Payload) -> Future: return create_future(Payload(b'Response')) - server, client = pipe_tcp_without_auto_connect - server.set_handler_using_factory(Handler) + server_provider, client = pipe_tcp_without_auto_connect async with RxRSocket(client) as rx_client: + server = await server_provider() + server.set_handler_using_factory(Handler) + received_message = await rx_client.request_response(Payload(b'request text')).pipe( operators.map(lambda payload: payload.data), operators.single()