Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ all the examples
| server_with_lease.py | | | ClientWithLease |
| server_with_routing.py | | client_with_routing.py | Client |
| server_with_routing.py | | client_rx.py | |
| server_with_routing.py | | client_reconnect.py | |
| | Server | run_against_example_java_server.py | |
| server_quart_websocket.py | | client_websocket.py | |
| server_aiohttp_websocket.py | | client_websocket.py | |
Expand Down
55 changes: 55 additions & 0 deletions examples/client_reconnect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import asyncio
import logging
import sys

from rsocket.extensions.helpers import route, composite, authenticate_simple
from rsocket.extensions.mimetypes import WellKnownMimeTypes
from rsocket.payload import Payload
from rsocket.request_handler import BaseRequestHandler
from rsocket.rsocket_client import RSocketClient
from rsocket.transports.tcp import TransportTCP


async def request_response(client: RSocketClient) -> Payload:
payload = Payload(b'The quick brown fox', composite(
route('single_request'),
authenticate_simple('user', '12345')
))

return await client.request_response(payload)


class Handler(BaseRequestHandler):

async def on_connection_lost(self, rsocket: RSocketClient, exception: Exception):
await asyncio.sleep(5)
await rsocket.reconnect()


async def main(server_port):
logging.info('Connecting to server at localhost:%s', server_port)

async def transport_provider(max_reconnect):
for i in range(max_reconnect):
connection = await asyncio.open_connection('localhost', server_port)
yield TransportTCP(*connection)

async with RSocketClient(transport_provider(3),
metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA,
handler_factory=Handler) as client:
result1 = await request_response(client)
assert result1.data == b'single_response'

await asyncio.sleep(10)

result2 = await request_response(client)
assert result2.data == b'single_response'

result3 = await request_response(client)
assert result3.data == b'single_response'


if __name__ == '__main__':
port = sys.argv[1] if len(sys.argv) > 1 else 6565
logging.basicConfig(level=logging.DEBUG)
asyncio.run(main(port))
11 changes: 5 additions & 6 deletions rsocket/rsocket_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(self,
self._lease_publisher = lease_publisher
self._sender_task = None
self._receiver_task = None
self._handler = None
self._handler = self._handler_factory(self)
self._responder_lease = None
self._requester_lease = None
self._is_closing = False
Expand Down Expand Up @@ -115,7 +115,6 @@ def _reset_internals(self):

self._responder_lease = NullLease()
self._stream_control = StreamControl(self._get_first_stream_id())
self._handler = self._handler_factory(self)
self._is_closing = False

def stop_all_streams(self, error_code=ErrorCode.CANCELED, data=b''):
Expand All @@ -139,9 +138,6 @@ def _start_task_if_not_closing(self, task_factory: Callable[[], Coroutine]) -> O
if not self._is_closing:
return asyncio.create_task(task_factory())

def set_handler_factory(self, handler_factory):
self._handler_factory = handler_factory

def set_handler_using_factory(self, handler_factory) -> RequestHandler:
self._handler = handler_factory(self)
return self._handler
Expand Down Expand Up @@ -421,7 +417,10 @@ async def close(self):
async def _close_transport(self):
if self._current_transport().done():
logger().debug('%s: Closing transport', self._log_identifier())
transport = await self._current_transport()
try:
transport = await self._current_transport()
except asyncio.CancelledError:
raise RSocketTransportError()

if transport is not None:
try:
Expand Down
3 changes: 3 additions & 0 deletions rsocket/rsocket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self,
self._transport: Optional[Transport] = None
self._next_transport = asyncio.Future()
self._reconnect_task = asyncio.create_task(self._reconnect_listener())
self._keepalive_task = None

super().__init__(handler_factory=handler_factory,
honor_lease=honor_lease,
Expand Down Expand Up @@ -128,6 +129,8 @@ async def _reconnect_listener(self):
logger().debug('%s: Asyncio task canceled: reconnect_listener', self._log_identifier())
except Exception:
logger().error('%s: Reconnect listener', self._log_identifier(), exc_info=True)
finally:
self.stop_all_streams()

async def _keepalive_send_task(self):
try:
Expand Down
2 changes: 1 addition & 1 deletion tests/rsocket/test_resume_unsupported.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def on_error(self, error_code: ErrorCode, payload: Payload):
received_error_code = error_code
error_received.set()

client.set_handler_factory(Handler)
client.set_handler_using_factory(Handler)

async with client as connected_client:
transport = await connected_client._current_transport()
Expand Down
31 changes: 31 additions & 0 deletions tests/rsocket/test_without_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import asyncio

import pytest

from rsocket.exceptions import RSocketTransportError
from rsocket.logger import logger
from rsocket.request_handler import BaseRequestHandler
from rsocket.rsocket_client import RSocketClient
from rsocket.transports.tcp import TransportTCP


@pytest.mark.allow_error_log()
async def test_connection_never_established(unused_tcp_port: int):
class ClientHandler(BaseRequestHandler):
async def on_connection_lost(self, rsocket, exception: Exception):
logger().info('Test Reconnecting')
await rsocket.reconnect()

async def transport_provider():
try:
for i in range(3):
client_connection = await asyncio.open_connection('localhost', unused_tcp_port)
yield TransportTCP(*client_connection)

except Exception:
logger().error('Client connection error', exc_info=True)
raise

with pytest.raises(RSocketTransportError):
async with RSocketClient(transport_provider(), handler_factory=ClientHandler):
await asyncio.sleep(1)