diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a14e1c84..e5da5f98 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -11,6 +11,7 @@ v0.4.0 - fragment size now includes frame header and length. - Added checking fragment size limit (minimum 64) as in java implementation - Updated examples +- Added reactivex (RxPy version 4) wrapper client v0.3.0 ====== diff --git a/README.md b/README.md index 96a507a6..f99b06c9 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,8 @@ pip install rsocket or install any of the extras: -* rx +* rx (RxPy3) +* reactivex (RxPy4) * aiohttp * quart * uic @@ -20,7 +21,7 @@ or install any of the extras: Example: ```shell -pip install --pre rsocket[rx] +pip install --pre rsocket[reactivex] ``` Alternatively, download the source code, build a package: diff --git a/requirements.txt b/requirements.txt index f4e3f04e..a7d1da0a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ aiohttp==3.8.1 quart==0.17.0 coveralls==3.3.1 aioquic==0.9.19 +reactivex==4.0.4 \ No newline at end of file diff --git a/rsocket/reactivex/__init__.py b/rsocket/reactivex/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/rsocket/reactivex/back_pressure_publisher.py b/rsocket/reactivex/back_pressure_publisher.py new file mode 100644 index 00000000..13912208 --- /dev/null +++ b/rsocket/reactivex/back_pressure_publisher.py @@ -0,0 +1,77 @@ +import asyncio +from typing import Optional + +import reactivex +from reactivex import Observable, Observer +from reactivex.notification import OnNext, OnError, OnCompleted +from reactivex.operators import materialize +from reactivex.subject import Subject + +from reactivestreams.subscriber import Subscriber +from rsocket.helpers import DefaultPublisherSubscription +from rsocket.logger import logger +from rsocket.reactivex.subscriber_adapter import SubscriberAdapter + + +async def observable_to_async_event_generator(observable: Observable): + queue = asyncio.Queue() + + def on_next(i): + queue.put_nowait(i) + + observable.pipe(materialize()).subscribe( + on_next=on_next + ) + + while True: + value = await queue.get() + yield value + queue.task_done() + + +def from_aiter(iterator, feedback: Optional[Observable] = None): + # noinspection PyUnusedLocal + def on_subscribe(observer: Observer, scheduler): + async def _aio_next(): + try: + event = await iterator.__anext__() + + if isinstance(event, OnNext): + observer.on_next(event.value) + elif isinstance(event, OnError): + observer.on_error(event.exception) + elif isinstance(event, OnCompleted): + observer.on_completed() + except StopAsyncIteration: + pass + except Exception as exception: + logger().error(str(exception), exc_info=True) + observer.on_error(exception) + + def create_next_task(): + asyncio.create_task(_aio_next()) + + return feedback.subscribe( + on_next=lambda i: create_next_task() + ) + + return reactivex.create(on_subscribe) + + +class BackPressurePublisher(DefaultPublisherSubscription): + def __init__(self, wrapped_observable: Observable): + self._wrapped_observable = wrapped_observable + self._feedback = None + + def subscribe(self, subscriber: Subscriber): + super().subscribe(subscriber) + self._feedback = Subject() + async_iterator = observable_to_async_event_generator(self._wrapped_observable).__aiter__() + from_aiter(async_iterator, self._feedback).subscribe(SubscriberAdapter(subscriber)) + + def request(self, n: int): + for i in range(n): + self._feedback.on_next(True) + + def cancel(self): + self._feedback.on_completed() diff --git a/rsocket/reactivex/from_rsocket_publisher.py b/rsocket/reactivex/from_rsocket_publisher.py new file mode 100644 index 00000000..523e113b --- /dev/null +++ b/rsocket/reactivex/from_rsocket_publisher.py @@ -0,0 +1,92 @@ +import asyncio +import functools + +import reactivex +from reactivex import Observable, Observer +from reactivex.disposable import Disposable + +from reactivestreams.publisher import Publisher +from reactivestreams.subscriber import Subscriber +from reactivestreams.subscription import Subscription +from rsocket.logger import logger + + +class RxSubscriber(Subscriber): + def __init__(self, observer, limit_rate: int): + self.limit_rate = limit_rate + self.observer = observer + self._received_messages = 0 + self.done = asyncio.Event() + self.get_next_n = asyncio.Event() + self.subscription = None + + def on_subscribe(self, subscription: Subscription): + self.subscription = subscription + + def on_next(self, value, is_complete=False): + self._received_messages += 1 + self.observer.on_next(value) + if is_complete: + self.observer.on_completed() + self._finish() + + else: + if self._received_messages == self.limit_rate: + self._received_messages = 0 + self.get_next_n.set() + + def _finish(self): + self.done.set() + + def on_error(self, exception: Exception): + self.observer.on_error(exception) + self._finish() + + def on_complete(self): + self.observer.on_completed() + self._finish() + + +async def _aio_sub(publisher: Publisher, subscriber: RxSubscriber, observer: Observer, loop): + try: + publisher.subscribe(subscriber) + await subscriber.done.wait() + + except asyncio.CancelledError: + if not subscriber.done.is_set(): + subscriber.subscription.cancel() + except Exception as exception: + loop.call_soon(functools.partial(observer.on_error, exception)) + + +async def _trigger_next_request_n(subscriber, limit_rate): + try: + while True: + await subscriber.get_next_n.wait() + subscriber.subscription.request(limit_rate) + subscriber.get_next_n.clear() + except asyncio.CancelledError: + logger().debug('Asyncio task canceled: trigger_next_request_n') + + +def from_rsocket_publisher(publisher: Publisher, limit_rate=5) -> Observable: + loop = asyncio.get_event_loop() + + # noinspection PyUnusedLocal + def on_subscribe(observer: Observer, scheduler): + subscriber = RxSubscriber(observer, limit_rate) + + get_next_task = asyncio.create_task( + _trigger_next_request_n(subscriber, limit_rate) + ) + task = asyncio.create_task( + _aio_sub(publisher, subscriber, observer, loop) + ) + + def dispose(): + get_next_task.cancel() + task.cancel() + + return Disposable(dispose) + + return reactivex.create(on_subscribe) diff --git a/rsocket/reactivex/reactivex_client.py b/rsocket/reactivex/reactivex_client.py new file mode 100644 index 00000000..c4face1e --- /dev/null +++ b/rsocket/reactivex/reactivex_client.py @@ -0,0 +1,57 @@ +from asyncio import Future + +from typing import Optional, cast + +import reactivex +from reactivex import Observable + +from rsocket.frame import MAX_REQUEST_N +from rsocket.payload import Payload +from rsocket.rsocket import RSocket +from rsocket.reactivex.back_pressure_publisher import BackPressurePublisher +from rsocket.reactivex.from_rsocket_publisher import from_rsocket_publisher + + +class ReactiveXClient: + def __init__(self, rsocket: RSocket): + self._rsocket = rsocket + + def request_stream(self, request: Payload, request_limit: int = MAX_REQUEST_N) -> Observable: + response_publisher = self._rsocket.request_stream(request).initial_request_n(request_limit) + return from_rsocket_publisher(response_publisher, request_limit) + + def request_response(self, request: Payload) -> Observable: + return reactivex.from_future(cast(Future, self._rsocket.request_response(request))) + + def request_channel(self, + request: Payload, + request_limit: int = MAX_REQUEST_N, + observable: Optional[Observable] = None) -> Observable: + if observable is not None: + local_publisher = BackPressurePublisher(observable) + else: + local_publisher = None + + response_publisher = self._rsocket.request_channel( + request, local_publisher + ).initial_request_n(request_limit) + return from_rsocket_publisher(response_publisher, request_limit) + + def fire_and_forget(self, request: Payload) -> Observable: + return reactivex.from_future(cast(Future, self._rsocket.fire_and_forget(request))) + + def metadata_push(self, metadata: bytes) -> Observable: + return reactivex.from_future(cast(Future, self._rsocket.metadata_push(metadata))) + + async def connect(self): + return await self._rsocket.connect() + + async def close(self): + await self._rsocket.close() + + async def __aenter__(self): + await self._rsocket.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self._rsocket.__aexit__(exc_type, exc_val, exc_tb) diff --git a/rsocket/reactivex/subscriber_adapter.py b/rsocket/reactivex/subscriber_adapter.py new file mode 100644 index 00000000..164545b4 --- /dev/null +++ b/rsocket/reactivex/subscriber_adapter.py @@ -0,0 +1,17 @@ +from reactivex.abc import ObserverBase + +from reactivestreams.subscriber import Subscriber + + +class SubscriberAdapter(ObserverBase): + def __init__(self, subscriber: Subscriber): + self._subscriber = subscriber + + def on_next(self, value): + self._subscriber.on_next(value) + + def on_error(self, error): + self._subscriber.on_error(error) + + def on_completed(self): + self._subscriber.on_complete() diff --git a/setup.py b/setup.py index 0f58303d..feb32374 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,7 @@ python_requires='>=3.8', extras_require={ 'rx': {'Rx >= 3.0.0'}, + 'reactivex': {'reactivex >= 4.0.0'}, 'aiohttp': {'aiohttp >= 3.0.0'}, 'quart': {'quart >= 0.15.0'}, 'quic': {'aioquic >= 0.9.0'} diff --git a/tests/rx_support/__init__.py b/tests/rx_support/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_reactivex/__init__.py b/tests/test_reactivex/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_reactivex/test_reactivex_canceled.py b/tests/test_reactivex/test_reactivex_canceled.py new file mode 100644 index 00000000..167f88b0 --- /dev/null +++ b/tests/test_reactivex/test_reactivex_canceled.py @@ -0,0 +1,182 @@ +import asyncio +from typing import Tuple, AsyncGenerator, Optional + +import pytest +import reactivex +from reactivex import operators + +from reactivestreams.publisher import Publisher +from reactivestreams.subscriber import Subscriber, DefaultSubscriber +from reactivestreams.subscription import Subscription +from rsocket.error_codes import ErrorCode +from rsocket.payload import Payload +from rsocket.request_handler import BaseRequestHandler +from rsocket.rsocket_client import RSocketClient +from rsocket.rsocket_server import RSocketServer +from rsocket.reactivex.reactivex_client import ReactiveXClient +from rsocket.streams.stream_from_async_generator import StreamFromAsyncGenerator + + +@pytest.mark.parametrize('take_only_n', ( + 1, + 2, + 5, +)) +async def test_rx_support_request_stream_take_only_n(pipe: Tuple[RSocketServer, RSocketClient], + take_only_n): + server, client = pipe + maximum_message_count = 4 + wait_for_server_finish = asyncio.Event() + items_generated = 0 + + async def generator() -> AsyncGenerator[Tuple[Payload, bool], None]: + nonlocal items_generated + for x in range(maximum_message_count): + items_generated += 1 + yield Payload('Feed Item: {}'.format(x).encode('utf-8')), x == maximum_message_count - 1 + + class Handler(BaseRequestHandler): + async def request_stream(self, payload: Payload) -> Publisher: + def set_server_finished(): wait_for_server_finish.set() + + return StreamFromAsyncGenerator(generator, + on_cancel=set_server_finished, + on_complete=set_server_finished) + + server.set_handler_using_factory(Handler) + + rx_client = ReactiveXClient(client) + received_messages = await rx_client.request_stream(Payload(b'request text'), + request_limit=1).pipe( + operators.map(lambda payload: payload.data), + operators.take(take_only_n), + operators.to_list() + ) + + await wait_for_server_finish.wait() + + maximum_message_received = min(maximum_message_count, take_only_n) + + assert len(received_messages) == maximum_message_received, 'Received message count wrong' + assert items_generated == maximum_message_received, 'Received message count wrong' + + for i in range(maximum_message_received): + assert received_messages[i] == ('Feed Item: %d' % i).encode() + + +@pytest.mark.parametrize('take_only_n', ( + # 0, + 1, + 2, + 6, +)) +async def test_rx_support_request_channel_response_take_only_n(pipe: Tuple[RSocketServer, RSocketClient], + take_only_n): + server, client = pipe + + maximum_message_count = 4 + wait_for_server_finish = asyncio.Event() + items_generated = 0 + + async def generator() -> AsyncGenerator[Tuple[Payload, bool], None]: + nonlocal items_generated + for x in range(maximum_message_count): + items_generated += 1 + yield Payload('Feed Item: {}'.format(x).encode('utf-8')), x == maximum_message_count - 1 + + class Handler(BaseRequestHandler): + async def request_channel(self, payload: Payload) -> Tuple[Optional[Publisher], Optional[Subscriber]]: + def set_server_finished(): wait_for_server_finish.set() + + return StreamFromAsyncGenerator(generator, + on_cancel=set_server_finished, + on_complete=set_server_finished), None + + server.set_handler_using_factory(Handler) + + rx_client = ReactiveXClient(client) + + received_messages = await rx_client.request_channel( + Payload(b'request text'), + request_limit=1 + ).pipe( + operators.map(lambda payload: payload.data), + operators.take(take_only_n), + operators.to_list() + ) + + if take_only_n > 0: + await wait_for_server_finish.wait() + + maximum_message_received = min(maximum_message_count, take_only_n) + + assert len(received_messages) == maximum_message_received, 'Received message count wrong' + assert items_generated == maximum_message_received, 'Received message count wrong' + + for i in range(maximum_message_received): + assert received_messages[i] == ('Feed Item: %d' % i).encode() + + +@pytest.mark.parametrize('take_only_n', ( + 1, + 2, + 6, +)) +async def test_rx_support_request_channel_server_take_only_n(pipe: Tuple[RSocketServer, RSocketClient], + take_only_n): + server, client = pipe + received_messages = [] + items_generated = 0 + maximum_message_count = 3 + wait_for_server_finish = asyncio.Event() + + class Handler(BaseRequestHandler, DefaultSubscriber): + + def on_next(self, value: Payload, is_complete=False): + received_messages.append(value) + if len(received_messages) < take_only_n: + self.subscription.request(1) + else: + self.subscription.cancel() + wait_for_server_finish.set() + + def on_complete(self): + wait_for_server_finish.set() + + async def on_error(self, error_code: ErrorCode, payload: Payload): + wait_for_server_finish.set() + + def on_subscribe(self, subscription: Subscription): + super().on_subscribe(subscription) + subscription.request(1) + + async def request_channel(self, payload: Payload) -> Tuple[Optional[Publisher], Optional[Subscriber]]: + return None, self + + server.set_handler_using_factory(Handler) + + rx_client = ReactiveXClient(client) + + def generator(): + nonlocal items_generated + for x in range(maximum_message_count): + items_generated += 1 + yield Payload('Feed Item: {}'.format(x).encode('utf-8')) + + await rx_client.request_channel( + Payload(b'request text'), + observable=reactivex.from_iterable(generator()) + ).pipe( + operators.to_list() + ) + + await wait_for_server_finish.wait() + + maximum_message_received = min(maximum_message_count, take_only_n) + + # assert items_generated == maximum_message_received # todo: Stop async generator on cancel from server requester + + assert len(received_messages) == maximum_message_received + + for i in range(maximum_message_received): + assert received_messages[i].data == ('Feed Item: %d' % i).encode() diff --git a/tests/test_reactivex/test_reactivex_error.py b/tests/test_reactivex/test_reactivex_error.py new file mode 100644 index 00000000..0b465dfc --- /dev/null +++ b/tests/test_reactivex/test_reactivex_error.py @@ -0,0 +1,112 @@ +import asyncio +from typing import Tuple, AsyncGenerator, Optional + +import pytest +import reactivex +from reactivex import operators, Observer +from reactivex.scheduler.scheduler import Scheduler +from reactivex.disposable import Disposable + +from reactivestreams.publisher import Publisher +from reactivestreams.subscriber import Subscriber, DefaultSubscriber +from reactivestreams.subscription import Subscription +from rsocket.payload import Payload +from rsocket.request_handler import BaseRequestHandler +from rsocket.rsocket_client import RSocketClient +from rsocket.rsocket_server import RSocketServer +from rsocket.reactivex.reactivex_client import ReactiveXClient +from rsocket.streams.stream_from_async_generator import StreamFromAsyncGenerator + + +@pytest.mark.parametrize('success_count, request_limit', ( + (0, 2), + (2, 2), + (3, 2), +)) +async def test_rx_support_request_stream_with_error(pipe: Tuple[RSocketServer, RSocketClient], + success_count, + request_limit): + server, client = pipe + + async def generator() -> AsyncGenerator[Tuple[Payload, bool], None]: + for x in range(success_count): + yield Payload('Feed Item: {}'.format(x).encode('utf-8')), False + + raise Exception('Some error from responder') + + class Handler(BaseRequestHandler): + async def request_stream(self, payload: Payload) -> Publisher: + return StreamFromAsyncGenerator(generator) + + server.set_handler_using_factory(Handler) + + rx_client = ReactiveXClient(client) + + with pytest.raises(Exception): + await rx_client.request_stream( + Payload(b'request text'), + request_limit=request_limit + ).pipe( + operators.map(lambda payload: payload.data), + operators.to_list() + ) + + +@pytest.mark.parametrize('success_count, request_limit', ( + (0, 2), + (2, 2), + (3, 2), +)) +async def test_rx_support_request_channel_with_error_from_requester( + pipe: Tuple[RSocketServer, RSocketClient], + success_count, + request_limit): + server, client = pipe + responder_received_error = asyncio.Event() + server_received_messages = [] + received_error = None + + class ResponderSubscriber(DefaultSubscriber): + + def on_subscribe(self, subscription: Subscription): + super().on_subscribe(subscription) + self.subscription.request(1) + + def on_next(self, value, is_complete=False): + if len(value.data) > 0: + server_received_messages.append(value.data) + self.subscription.request(1) + + def on_error(self, exception: Exception): + nonlocal received_error + received_error = exception + responder_received_error.set() + + async def generator() -> AsyncGenerator[Tuple[Payload, bool], None]: + for x in range(success_count): + yield Payload('Feed Item: {}'.format(x).encode('utf-8')), x == success_count - 1 + + class Handler(BaseRequestHandler): + async def request_channel(self, payload: Payload) -> Tuple[Optional[Publisher], Optional[Subscriber]]: + return StreamFromAsyncGenerator(generator), ResponderSubscriber() + + server.set_handler_using_factory(Handler) + + rx_client = ReactiveXClient(client) + + def test_observable(observer: Observer, scheduler: Optional[Scheduler]): + observer.on_error(Exception('Some error')) + return Disposable() + + await rx_client.request_channel( + Payload(b'request text'), + observable=reactivex.create(test_observable), + request_limit=request_limit + ).pipe( + operators.map(lambda payload: payload.data), + operators.to_list() + ) + + await responder_received_error.wait() + + assert str(received_error) == 'Some error' diff --git a/tests/test_reactivex/test_reactivex_support.py b/tests/test_reactivex/test_reactivex_support.py new file mode 100644 index 00000000..0dba9742 --- /dev/null +++ b/tests/test_reactivex/test_reactivex_support.py @@ -0,0 +1,228 @@ +import asyncio +from asyncio import Future +from typing import Tuple, AsyncGenerator, Optional + +import reactivex +from reactivex import operators + +from reactivestreams.publisher import Publisher +from reactivestreams.subscriber import Subscriber, DefaultSubscriber +from reactivestreams.subscription import Subscription +from rsocket.helpers import create_future, DefaultPublisherSubscription +from rsocket.payload import Payload +from rsocket.request_handler import BaseRequestHandler +from rsocket.rsocket_client import RSocketClient +from rsocket.rsocket_server import RSocketServer +from rsocket.reactivex.reactivex_client import ReactiveXClient +from rsocket.streams.stream_from_async_generator import StreamFromAsyncGenerator + + +async def test_rx_support_request_stream_properly_finished(pipe: Tuple[RSocketServer, RSocketClient]): + server, client = pipe + + async def generator() -> AsyncGenerator[Tuple[Payload, bool], None]: + for x in range(3): + yield Payload('Feed Item: {}'.format(x).encode('utf-8')), x == 2 + + class Handler(BaseRequestHandler): + async def request_stream(self, payload: Payload) -> Publisher: + return StreamFromAsyncGenerator(generator) + + server.set_handler_using_factory(Handler) + + rx_client = ReactiveXClient(client) + received_messages = await rx_client.request_stream(Payload(b'request text'), + request_limit=2).pipe( + operators.map(lambda payload: payload.data), + operators.to_list() + ) + + assert len(received_messages) == 3 + assert received_messages[0] == b'Feed Item: 0' + assert received_messages[1] == b'Feed Item: 1' + assert received_messages[2] == b'Feed Item: 2' + + +async def test_rx_support_request_stream_immediate_complete(pipe: Tuple[RSocketServer, RSocketClient]): + server, client = pipe + + class SimpleCompleted(DefaultPublisherSubscription): + + def request(self, n: int): + self._subscriber.on_complete() + + class Handler(BaseRequestHandler): + async def request_stream(self, payload: Payload) -> Publisher: + return SimpleCompleted() + + server.set_handler_using_factory(Handler) + + rx_client = ReactiveXClient(client) + + result = await rx_client.request_stream( + Payload(b'request text'), + request_limit=2 + ).pipe( + operators.map(lambda payload: payload.data), + operators.to_list() + ) + + assert len(result) == 0 + + +async def test_rx_support_request_response_properly_finished(pipe: Tuple[RSocketServer, RSocketClient]): + server, client = pipe + + class Handler(BaseRequestHandler): + async def request_response(self, payload: Payload) -> Future: + return create_future(Payload(b'Response')) + + server.set_handler_using_factory(Handler) + + rx_client = ReactiveXClient(client) + received_message = await rx_client.request_response(Payload(b'request text')).pipe( + operators.map(lambda payload: payload.data), + operators.single() + ) + + assert received_message == b'Response' + + +async def test_rx_support_request_channel_properly_finished(pipe: Tuple[RSocketServer, RSocketClient]): + server, client = pipe + server_received_messages = [] + + responder_received_all = asyncio.Event() + + async def generator() -> AsyncGenerator[Tuple[Payload, bool], None]: + for x in range(3): + yield Payload('Feed Item: {}'.format(x).encode('utf-8')), x == 2 + + class ResponderSubscriber(DefaultSubscriber): + + def on_subscribe(self, subscription: Subscription): + super().on_subscribe(subscription) + self.subscription.request(1) + + def on_next(self, value, is_complete=False): + if len(value.data) > 0: + server_received_messages.append(value.data) + + self.subscription.request(1) + + def on_complete(self): + responder_received_all.set() + + class Handler(BaseRequestHandler): + async def request_channel(self, payload: Payload) -> Tuple[Optional[Publisher], Optional[Subscriber]]: + return StreamFromAsyncGenerator(generator), ResponderSubscriber() + + server.set_handler_using_factory(Handler) + + rx_client = ReactiveXClient(client) + + sent_messages = [b'1', b'2', b'3'] + sent_payloads = [Payload(data) for data in sent_messages] + received_messages = await rx_client.request_channel(Payload(b'request text'), + observable=reactivex.from_list(sent_payloads), + request_limit=2).pipe( + operators.map(lambda payload: payload.data), + operators.to_list() + ) + + await responder_received_all.wait() + + assert server_received_messages == sent_messages + + assert len(received_messages) == 3 + assert received_messages[0] == b'Feed Item: 0' + assert received_messages[1] == b'Feed Item: 1' + assert received_messages[2] == b'Feed Item: 2' + + +async def test_rx_support_request_channel_response_only_properly_finished(pipe: Tuple[RSocketServer, RSocketClient]): + server, client = pipe + + async def generator() -> AsyncGenerator[Tuple[Payload, bool], None]: + for x in range(3): + yield Payload('Feed Item: {}'.format(x).encode('utf-8')), x == 2 + + class Handler(BaseRequestHandler): + async def request_channel(self, payload: Payload) -> Tuple[Optional[Publisher], Optional[Subscriber]]: + return StreamFromAsyncGenerator(generator), None + + server.set_handler_using_factory(Handler) + + rx_client = ReactiveXClient(client) + + received_messages = await rx_client.request_channel(Payload(b'request text'), + request_limit=2).pipe( + operators.map(lambda payload: payload.data), + operators.to_list() + ) + + assert len(received_messages) == 3 + assert received_messages[0] == b'Feed Item: 0' + assert received_messages[1] == b'Feed Item: 1' + assert received_messages[2] == b'Feed Item: 2' + + +async def test_rx_rsocket_context_manager(pipe_tcp_without_auto_connect): + class Handler(BaseRequestHandler): + async def request_response(self, payload: Payload) -> Future: + return create_future(Payload(b'Response')) + + server_provider, client = pipe_tcp_without_auto_connect + + async with ReactiveXClient(client) as rx_client: + server = await server_provider() + server.set_handler_using_factory(Handler) + + received_message = await rx_client.request_response(Payload(b'request text')).pipe( + operators.map(lambda payload: payload.data), + operators.single() + ) + + assert received_message == b'Response' + + +async def test_rx_support_metadata_push(pipe: Tuple[RSocketServer, RSocketClient]): + server, client = pipe + received_item_event = asyncio.Event() + received_item = None + + class Handler(BaseRequestHandler): + async def on_metadata_push(self, payload: Payload): + nonlocal received_item + received_item = payload.metadata + received_item_event.set() + + server.set_handler_using_factory(Handler) + + rx_client = ReactiveXClient(client) + await rx_client.metadata_push(b'request text') + + await received_item_event.wait() + + assert received_item == b'request text' + + +async def test_rx_support_fire_and_forget(pipe: Tuple[RSocketServer, RSocketClient]): + server, client = pipe + received_item_event = asyncio.Event() + received_item = None + + class Handler(BaseRequestHandler): + async def request_fire_and_forget(self, payload: Payload): + nonlocal received_item + received_item = payload.data + received_item_event.set() + + server.set_handler_using_factory(Handler) + + rx_client = ReactiveXClient(client) + await rx_client.fire_and_forget(Payload(b'request text')) + + await received_item_event.wait() + + assert received_item == b'request text' diff --git a/tests/test_reactivex/test_reactivex_timeout.py b/tests/test_reactivex/test_reactivex_timeout.py new file mode 100644 index 00000000..7a21cef5 --- /dev/null +++ b/tests/test_reactivex/test_reactivex_timeout.py @@ -0,0 +1,79 @@ +import asyncio +from asyncio import Future +from typing import Tuple + +import pytest +from reactivex import operators + +from reactivestreams.publisher import Publisher +from rsocket.helpers import create_future, DefaultPublisherSubscription +from rsocket.payload import Payload +from rsocket.request_handler import BaseRequestHandler +from rsocket.rsocket_client import RSocketClient +from rsocket.rsocket_server import RSocketServer +from rsocket.reactivex.reactivex_client import ReactiveXClient + + +async def test_rx_support_request_stream_cancel_on_timeout(pipe: Tuple[RSocketServer, RSocketClient]): + server, client = pipe + cancel_done = asyncio.Event() + stream_messages_sent_count = 0 + + class Handler(BaseRequestHandler, DefaultPublisherSubscription): + + async def delayed_stream(self): + nonlocal stream_messages_sent_count + try: + await asyncio.sleep(3) + self._subscriber.on_next(Payload(b'success')) + stream_messages_sent_count += 1 + except asyncio.CancelledError: + cancel_done.set() + + def cancel(self): + self._task.cancel() + + def request(self, n: int): + self._task = asyncio.create_task(self.delayed_stream()) + + async def request_stream(self, payload: Payload) -> Publisher: + return self + + server.set_handler_using_factory(Handler) + + rx_client = ReactiveXClient(client) + + with pytest.raises(Exception): + await asyncio.wait_for(rx_client.request_stream( + Payload(b'request text') + ).pipe( + operators.to_list() + ), 2) + + await cancel_done.wait() + + assert stream_messages_sent_count == 0 + + +async def test_rx_support_request_response_cancel_on_timeout(pipe: Tuple[RSocketServer, RSocketClient]): + server, client = pipe + response_sent = False + + class Handler(BaseRequestHandler): + + async def request_response(self, payload: Payload) -> Future: + nonlocal response_sent + await asyncio.sleep(3) + response_sent = True + return create_future(Payload(b'response')) + + server.set_handler_using_factory(Handler) + + rx_client = ReactiveXClient(client) + + with pytest.raises(Exception): + await asyncio.wait_for(rx_client.request_response( + Payload(b'request text') + ), 2) + + assert not response_sent