Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ v0.4.5
- Breaking change: ReactiveX clients will remove empty payload from request_response Observable, resulting in an actually empty Observable
- Bug fix: fixed channel stream being released prematurely when canceled by requester, and responder side still working
- Bug fix: removed cyclic references in RSocketBase which caused old sessions not to be released
- Bug fix: fixed ability for rxpy streams and fragmented responses to send payloads concurrently
- CollectorSubscriber : exposed subscription methods directly instead of relying on internal **subscription** variable
- Reactivex server side request_response allowed to return reactivex.empty(). Library code will replace with empty Payload when needed
- Added EmptyStream for use in stream and channel responses
Expand Down
7 changes: 7 additions & 0 deletions rsocket/async_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import asyncio


async def async_range(count: int):
for i in range(count):
yield i
await asyncio.sleep(0.0)
2 changes: 1 addition & 1 deletion rsocket/awaitable/collector_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def on_next(self, value, is_complete=False):
else:
if self._received_count == self._limit_rate:
self._received_count = 0
self.subscription.request(self._limit_rate)
self.subscription.request(self._limit_rate)

def on_error(self, exception: Exception):
self.error = exception
Expand Down
1 change: 0 additions & 1 deletion rsocket/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,5 @@ def serialize_with_frame_size_header(frame: Frame) -> bytes:
RequestChannelFrame: 10,
}


def get_header_length(frame: FragmentableFrame) -> int:
return frame_header_length[frame.__class__]
21 changes: 17 additions & 4 deletions rsocket/reactivex/back_pressure_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from reactivestreams.publisher import Publisher
from reactivestreams.subscriber import Subscriber
from rsocket.async_helpers import async_range
from rsocket.helpers import DefaultPublisherSubscription
from rsocket.logger import logger
from rsocket.reactivex.subscriber_adapter import SubscriberAdapter
Expand Down Expand Up @@ -66,8 +67,12 @@ async def _aio_next():

try:
while True:
next_n = await request_n_queue.get()
for i in range(next_n):
try:
next_n = await request_n_queue.get()
except RuntimeError:
return

async for i in async_range(next_n):
try:
value = await iterator.__anext__()
observer.on_next(value)
Expand Down Expand Up @@ -100,15 +105,23 @@ def cancel_sender():
async def observable_to_async_event_generator(observable: Observable) -> AsyncGenerator[Notification, None]:
queue = asyncio.Queue()

completed = object()

def on_next(i):
queue.put_nowait(i)

observable.pipe(materialize()).subscribe(
on_next=on_next
on_next=on_next,
on_completed=lambda: queue.put_nowait(completed)
)

while True:
value = await queue.get()

if value is completed:
queue.task_done()
return

yield value
queue.task_done()

Expand All @@ -128,7 +141,7 @@ async def _aio_next():
try:
while True:
next_n = await request_n_queue.get()
for i in range(next_n):
async for i in async_range(next_n):
event = await iterator.__anext__()

if isinstance(event, OnNext):
Expand Down
4 changes: 3 additions & 1 deletion rsocket/rsocket_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,9 @@ async def _get_next_frame_to_send(self, transport: Transport) -> Frame:
if isinstance(next_frame_source, FrameFragmentMixin):
next_fragment = next_frame_source.get_next_fragment(transport.requires_length_header())

if not next_fragment.flags_follows:
if next_fragment.flags_follows:
self._send_queue.put_nowait(self._send_queue.get_nowait()) # cycle to next frame source in queue
else:
next_frame_source.get_next_fragment(
transport.requires_length_header()) # workaround to clean-up generator.
self._send_queue.get_nowait()
Expand Down
15 changes: 12 additions & 3 deletions rsocket/rx_support/back_pressure_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from reactivestreams.publisher import Publisher
from reactivestreams.subscriber import Subscriber
from rsocket.async_helpers import async_range
from rsocket.helpers import DefaultPublisherSubscription
from rsocket.logger import logger
from rsocket.rx_support.subscriber_adapter import SubscriberAdapter
Expand Down Expand Up @@ -68,7 +69,7 @@ async def _aio_next():
try:
while True:
next_n = await request_n_queue.get()
for i in range(next_n):
async for i in async_range(next_n):
try:
value = await iterator.__anext__()
observer.on_next(value)
Expand Down Expand Up @@ -101,15 +102,23 @@ def cancel_sender():
async def observable_to_async_event_generator(observable: Observable) -> AsyncGenerator[Notification, None]:
queue = asyncio.Queue()

completed = object()

def on_next(i):
queue.put_nowait(i)

observable.pipe(materialize()).subscribe(
on_next=on_next
on_next=on_next,
on_completed=lambda: queue.put_nowait(completed)
)

while True:
value = await queue.get()

if value is completed:
queue.task_done()
return

yield value
queue.task_done()

Expand All @@ -129,7 +138,7 @@ async def _aio_next():
try:
while True:
next_n = await request_n_queue.get()
for i in range(next_n):
async for i in async_range(next_n):
event = await iterator.__anext__()

if isinstance(event, OnNext):
Expand Down
3 changes: 3 additions & 0 deletions rsocket/stream_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from rsocket.error_codes import ErrorCode
from rsocket.exceptions import RSocketStreamAllocationFailure, RSocketStreamIdInUse
from rsocket.frame import CONNECTION_STREAM_ID, Frame, ErrorFrame
from rsocket.logger import logger
from rsocket.streams.stream_handler import StreamHandler

MAX_STREAM_ID = 0x7FFFFFFF
Expand Down Expand Up @@ -33,6 +34,7 @@ def _increment_stream_id(self):
self._current_stream_id = (self._current_stream_id + 2) & self._maximum_stream_id

def finish_stream(self, stream_id: int):
logger().debug('Finishing stream: %s', stream_id)
self._streams.pop(stream_id, None)

def register_stream(self, stream_id: int, handler: StreamHandler):
Expand All @@ -54,6 +56,7 @@ def handle_stream(self, frame: Frame) -> bool:
return False

def stop_all_streams(self, error_code=ErrorCode.CANCELED, data=b''):
logger().debug('Stopping all streams')
for stream_id, stream in list(self._streams.items()):
frame = ErrorFrame()
frame.stream_id = stream_id
Expand Down
3 changes: 2 additions & 1 deletion rsocket/streams/stream_from_async_generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import AsyncGenerator, Tuple

from rsocket.async_helpers import async_range
from rsocket.payload import Payload
from rsocket.streams.exceptions import FinishedIterator
from rsocket.streams.stream_from_generator import StreamFromGenerator
Expand All @@ -11,7 +12,7 @@ async def _start_generator(self):

async def _generate_next_n(self, n: int) -> AsyncGenerator[Tuple[Payload, bool], None]:
is_complete_sent = False
for i in range(n):
async for i in async_range(n):
try:
next_value = await self._iteration.__anext__()
is_complete_sent = next_value[1]
Expand Down
3 changes: 2 additions & 1 deletion rsocket/streams/stream_from_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import AsyncGenerator, Tuple, Optional, Callable, Generator

from reactivestreams.subscriber import Subscriber
from rsocket.async_helpers import async_range
from rsocket.helpers import DefaultPublisherSubscription
from rsocket.logger import logger
from rsocket.payload import Payload
Expand Down Expand Up @@ -71,7 +72,7 @@ async def queue_next_n(self):

async def _generate_next_n(self, n: int) -> AsyncGenerator[Tuple[Payload, bool], None]:
is_complete_sent = False
for i in range(n):
async for i in async_range(n):
next_value = next(self._iteration, _finished_iterator)

if next_value is _finished_iterator:
Expand Down
69 changes: 69 additions & 0 deletions tests/rsocket/test_concurrency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import asyncio
from typing import Tuple, Optional

from rsocket.async_helpers import async_range
from rsocket.awaitable.awaitable_rsocket import AwaitableRSocket
from rsocket.frame_helpers import ensure_bytes
from rsocket.helpers import utf8_decode, create_future
from rsocket.payload import Payload
from rsocket.request_handler import BaseRequestHandler
from rsocket.rsocket_client import RSocketClient
from rsocket.rsocket_server import RSocketServer
from rsocket.streams.stream_from_async_generator import StreamFromAsyncGenerator
from tests.tools.helpers import measure_time


async def test_concurrent_streams(pipe: Tuple[RSocketServer, RSocketClient]):
class Handler(BaseRequestHandler):

def __init__(self, server_done: Optional[asyncio.Event] = None):
self._server_done = server_done

async def request_stream(self, payload: Payload):
count = int(utf8_decode(payload.data))

async def generator():
async for index in async_range(count):
yield Payload(ensure_bytes('Feed Item: {}/{}'.format(index, count))), index == count - 1

return StreamFromAsyncGenerator(generator)

server, client = pipe

server.set_handler_using_factory(Handler)

request_1 = asyncio.create_task(measure_time(AwaitableRSocket(client).request_stream(Payload(b'2000'))))

request_2 = asyncio.create_task(measure_time(AwaitableRSocket(client).request_stream(Payload(b'10'))))

results = (await request_1, await request_2)

print(results)
delta = abs(results[0].delta - results[1].delta)

assert len(results[0].result) == 2000
assert len(results[1].result) == 10
assert delta > 0.8


async def test_concurrent_fragmented_responses(lazy_pipe_tcp): # check problems with quic and http3 frame boundary
class Handler(BaseRequestHandler):
async def request_response(self, request: Payload):
data = 'a' * 100 * int(utf8_decode(request.data))
return create_future(Payload(ensure_bytes(data)))

async with lazy_pipe_tcp(
server_arguments={'handler_factory': Handler, 'fragment_size_bytes': 100},
client_arguments={'fragment_size_bytes': 100}) as (server, client):
request_1 = asyncio.create_task(measure_time(client.request_response(Payload(b'10000'))))

request_2 = asyncio.create_task(measure_time(client.request_response(Payload(b'10'))))

results = (await request_1, await request_2)

print(results[0].delta, results[1].delta)
delta = abs(results[0].delta - results[1].delta)

assert len(results[0].result.data) == 10000 * 100
assert len(results[1].result.data) == 10 * 100
assert delta > 0.8
9 changes: 9 additions & 0 deletions tests/rsocket/test_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,12 @@ class S(str):
del a

assert len(d) == 0


async def test_range():
async def loop(ii):
for i in range(100):
await asyncio.sleep(0)
print(ii + str(i))

await asyncio.gather(loop('a'), loop('b'))
50 changes: 50 additions & 0 deletions tests/test_reactivex/test_concurrency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import asyncio
from typing import Tuple, Optional

import reactivex
from reactivex import operators

from rsocket.frame_helpers import ensure_bytes
from rsocket.helpers import utf8_decode
from rsocket.payload import Payload
from rsocket.reactivex.reactivex_client import ReactiveXClient
from rsocket.reactivex.reactivex_handler import BaseReactivexHandler
from rsocket.reactivex.reactivex_handler_adapter import reactivex_handler_factory
from rsocket.rsocket_client import RSocketClient
from rsocket.rsocket_server import RSocketServer
from tests.tools.helpers import measure_time


class Handler(BaseReactivexHandler):

def __init__(self, server_done: Optional[asyncio.Event] = None):
self._server_done = server_done

async def request_stream(self, payload: Payload):
count = int(utf8_decode(payload.data))
return reactivex.from_iterable(
(Payload(ensure_bytes('Feed Item: {}/{}'.format(index, count))) for index in range(count)))


async def test_concurrent_streams(pipe: Tuple[RSocketServer, RSocketClient]):
server, client = pipe

server.set_handler_using_factory(reactivex_handler_factory(Handler))

request_1 = asyncio.create_task(measure_time(ReactiveXClient(client).request_stream(Payload(b'2000')).pipe(
operators.map(lambda payload: payload.data),
operators.do_action(on_next=lambda x: print(x)),
operators.to_list()
)))

request_2 = asyncio.create_task(measure_time(ReactiveXClient(client).request_stream(Payload(b'10')).pipe(
operators.map(lambda payload: payload.data),
operators.do_action(on_next=lambda x: print(x)),
operators.to_list()
)))

results = (await request_1, await request_2)

delta = abs(results[0].delta - results[1].delta)

assert delta > 0.8
4 changes: 3 additions & 1 deletion tests/test_reactivex/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# (10, 10, 10), # fixme: failing on python 3.10
# (0, 10, 0), # operators.take(0) is problematic
))
@pytest.mark.skip
async def test_helper(request_n, generate_n, expected_n):
async def generator():
for i in range(generate_n):
Expand All @@ -31,4 +32,5 @@ async def generator():

assert len(result) == expected_n

await asyncio.sleep(1) # wait for task to finish
# await asyncio.sleep(1) # wait for task to finish

2 changes: 1 addition & 1 deletion tests/tools/fixtures_aioquic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rsocket.rsocket_client import RSocketClient
from rsocket.transports.aioquic_transport import rsocket_connect, rsocket_serve
from tests.rsocket.helpers import assert_no_open_streams
from tests.tools.herlpers import quic_client_configuration
from tests.tools.helpers import quic_client_configuration


@asynccontextmanager
Expand Down
16 changes: 16 additions & 0 deletions tests/tools/herlpers.py → tests/tools/helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Awaitable, Any

from aioquic.quic.configuration import QuicConfiguration
from cryptography.hazmat.primitives import serialization

Expand All @@ -10,3 +14,15 @@ def quic_client_configuration(certificate, **kwargs):
ca_data = certificate.public_bytes(serialization.Encoding.PEM)
client_configuration.load_verify_locations(cadata=ca_data, cafile=None)
return client_configuration


@dataclass
class MeasureTime:
result: Any
delta: float


async def measure_time(coroutine: Awaitable) -> MeasureTime:
start = datetime.now()
result = await coroutine
return MeasureTime(result, (datetime.now() - start).total_seconds())
2 changes: 1 addition & 1 deletion tests/tools/http3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from aioquic.h3.connection import H3_ALPN, ErrorCode

from rsocket.transports.http3_transport import Http3TransportWebsocket, RSocketHttp3ClientProtocol
from tests.tools.herlpers import quic_client_configuration
from tests.tools.helpers import quic_client_configuration


@asynccontextmanager
Expand Down