diff --git a/README.md b/README.md index 422f9739..80ebef82 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ all the examples | server_with_lease.py | | | ClientWithLease | | server_with_routing.py | | client_with_routing.py | Client | | server_with_routing.py | | client_rx.py | | +| server_with_routing.py | | client_reconnect.py | | | | Server | run_against_example_java_server.py | | | server_quart_websocket.py | | client_websocket.py | | | server_aiohttp_websocket.py | | client_websocket.py | | diff --git a/examples/client_reconnect.py b/examples/client_reconnect.py new file mode 100644 index 00000000..f789d722 --- /dev/null +++ b/examples/client_reconnect.py @@ -0,0 +1,55 @@ +import asyncio +import logging +import sys + +from rsocket.extensions.helpers import route, composite, authenticate_simple +from rsocket.extensions.mimetypes import WellKnownMimeTypes +from rsocket.payload import Payload +from rsocket.request_handler import BaseRequestHandler +from rsocket.rsocket_client import RSocketClient +from rsocket.transports.tcp import TransportTCP + + +async def request_response(client: RSocketClient) -> Payload: + payload = Payload(b'The quick brown fox', composite( + route('single_request'), + authenticate_simple('user', '12345') + )) + + return await client.request_response(payload) + + +class Handler(BaseRequestHandler): + + async def on_connection_lost(self, rsocket: RSocketClient, exception: Exception): + await asyncio.sleep(5) + await rsocket.reconnect() + + +async def main(server_port): + logging.info('Connecting to server at localhost:%s', server_port) + + async def transport_provider(max_reconnect): + for i in range(max_reconnect): + connection = await asyncio.open_connection('localhost', server_port) + yield TransportTCP(*connection) + + async with RSocketClient(transport_provider(3), + metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA, + handler_factory=Handler) as client: + result1 = await request_response(client) + assert result1.data == b'single_response' + + await asyncio.sleep(10) + + result2 = await request_response(client) + assert result2.data == b'single_response' + + result3 = await request_response(client) + assert result3.data == b'single_response' + + +if __name__ == '__main__': + port = sys.argv[1] if len(sys.argv) > 1 else 6565 + logging.basicConfig(level=logging.DEBUG) + asyncio.run(main(port)) diff --git a/rsocket/rsocket_base.py b/rsocket/rsocket_base.py index 3cee9182..e0c417e1 100644 --- a/rsocket/rsocket_base.py +++ b/rsocket/rsocket_base.py @@ -75,7 +75,7 @@ def __init__(self, self._lease_publisher = lease_publisher self._sender_task = None self._receiver_task = None - self._handler = None + self._handler = self._handler_factory(self) self._responder_lease = None self._requester_lease = None self._is_closing = False @@ -115,7 +115,6 @@ def _reset_internals(self): self._responder_lease = NullLease() self._stream_control = StreamControl(self._get_first_stream_id()) - self._handler = self._handler_factory(self) self._is_closing = False def stop_all_streams(self, error_code=ErrorCode.CANCELED, data=b''): @@ -139,9 +138,6 @@ def _start_task_if_not_closing(self, task_factory: Callable[[], Coroutine]) -> O if not self._is_closing: return asyncio.create_task(task_factory()) - def set_handler_factory(self, handler_factory): - self._handler_factory = handler_factory - def set_handler_using_factory(self, handler_factory) -> RequestHandler: self._handler = handler_factory(self) return self._handler @@ -421,7 +417,10 @@ async def close(self): async def _close_transport(self): if self._current_transport().done(): logger().debug('%s: Closing transport', self._log_identifier()) - transport = await self._current_transport() + try: + transport = await self._current_transport() + except asyncio.CancelledError: + raise RSocketTransportError() if transport is not None: try: diff --git a/rsocket/rsocket_client.py b/rsocket/rsocket_client.py index eb44e7e2..b5826a12 100644 --- a/rsocket/rsocket_client.py +++ b/rsocket/rsocket_client.py @@ -37,6 +37,7 @@ def __init__(self, self._transport: Optional[Transport] = None self._next_transport = asyncio.Future() self._reconnect_task = asyncio.create_task(self._reconnect_listener()) + self._keepalive_task = None super().__init__(handler_factory=handler_factory, honor_lease=honor_lease, @@ -128,6 +129,8 @@ async def _reconnect_listener(self): logger().debug('%s: Asyncio task canceled: reconnect_listener', self._log_identifier()) except Exception: logger().error('%s: Reconnect listener', self._log_identifier(), exc_info=True) + finally: + self.stop_all_streams() async def _keepalive_send_task(self): try: diff --git a/tests/rsocket/test_resume_unsupported.py b/tests/rsocket/test_resume_unsupported.py index c81adf00..71fb3d51 100644 --- a/tests/rsocket/test_resume_unsupported.py +++ b/tests/rsocket/test_resume_unsupported.py @@ -25,7 +25,7 @@ async def on_error(self, error_code: ErrorCode, payload: Payload): received_error_code = error_code error_received.set() - client.set_handler_factory(Handler) + client.set_handler_using_factory(Handler) async with client as connected_client: transport = await connected_client._current_transport() diff --git a/tests/rsocket/test_without_server.py b/tests/rsocket/test_without_server.py new file mode 100644 index 00000000..26f75855 --- /dev/null +++ b/tests/rsocket/test_without_server.py @@ -0,0 +1,31 @@ +import asyncio + +import pytest + +from rsocket.exceptions import RSocketTransportError +from rsocket.logger import logger +from rsocket.request_handler import BaseRequestHandler +from rsocket.rsocket_client import RSocketClient +from rsocket.transports.tcp import TransportTCP + + +@pytest.mark.allow_error_log() +async def test_connection_never_established(unused_tcp_port: int): + class ClientHandler(BaseRequestHandler): + async def on_connection_lost(self, rsocket, exception: Exception): + logger().info('Test Reconnecting') + await rsocket.reconnect() + + async def transport_provider(): + try: + for i in range(3): + client_connection = await asyncio.open_connection('localhost', unused_tcp_port) + yield TransportTCP(*client_connection) + + except Exception: + logger().error('Client connection error', exc_info=True) + raise + + with pytest.raises(RSocketTransportError): + async with RSocketClient(transport_provider(), handler_factory=ClientHandler): + await asyncio.sleep(1)