Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging

from reactivestreams.subscriber import DefaultSubscriber
from rsocket.helpers import single_transport_provider
from rsocket.payload import Payload
from rsocket.rsocket_client import RSocketClient
from rsocket.transports.tcp import TransportTCP
Expand All @@ -17,7 +18,7 @@ def on_next(self, value, is_complete=False):
async def main():
connection = await asyncio.open_connection('localhost', 6565)

async with RSocketClient(TransportTCP(*connection)) as client:
async with RSocketClient(single_transport_provider(TransportTCP(*connection))) as client:
payload = Payload(b'%Y-%m-%d %H:%M:%S')

async def run_request_response():
Expand Down
6 changes: 4 additions & 2 deletions examples/client_springboot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from uuid import uuid4

from reactivestreams.subscriber import DefaultSubscriber
from rsocket.extensions.helpers import composite, route, authenticate_simple
from rsocket.extensions.mimetypes import WellKnownMimeTypes
from rsocket.helpers import single_transport_provider
from rsocket.payload import Payload
from rsocket.extensions.helpers import composite, route, authenticate_simple
from rsocket.rsocket_client import RSocketClient
from rsocket.transports.tcp import TransportTCP

Expand All @@ -24,7 +25,8 @@ async def main():
setup_payload = Payload(
data=str(uuid4()).encode(),
metadata=composite(route('shell-client'), authenticate_simple('user', 'pass')))
async with RSocketClient(TransportTCP(*connection),

async with RSocketClient(single_transport_provider(TransportTCP(*connection)),
setup_payload=setup_payload,
metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA):
await asyncio.sleep(5)
Expand Down
6 changes: 4 additions & 2 deletions examples/client_with_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

from reactivestreams.subscriber import Subscriber
from reactivestreams.subscription import Subscription
from rsocket.extensions.helpers import route, composite, authenticate_simple
from rsocket.extensions.mimetypes import WellKnownMimeTypes
from rsocket.fragment import Fragment
from rsocket.helpers import single_transport_provider
from rsocket.payload import Payload
from rsocket.extensions.helpers import route, composite, authenticate_simple
from rsocket.rsocket_client import RSocketClient
from rsocket.streams.stream_from_async_generator import StreamFromAsyncGenerator
from rsocket.transports.tcp import TransportTCP
Expand Down Expand Up @@ -143,9 +144,10 @@ async def request_fragmented_stream(socket: RSocketClient):


async def main():

connection = await asyncio.open_connection('localhost', 6565)

async with RSocketClient(TransportTCP(*connection),
async with RSocketClient(single_transport_provider(TransportTCP(*connection)),
metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA) as client:
await request_response(client)
await request_stream(client)
Expand Down
4 changes: 3 additions & 1 deletion examples/run_against_example_java_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from rsocket.extensions.composite_metadata import CompositeMetadata
from rsocket.extensions.mimetypes import WellKnownMimeTypes
from rsocket.extensions.routing import RoutingMetadata
from rsocket.helpers import single_transport_provider
from rsocket.payload import Payload
from rsocket.rsocket_client import RSocketClient
from rsocket.transports.tcp import TransportTCP
Expand All @@ -34,7 +35,8 @@ def on_error(self, exception: Exception):
completion_event.set()

connection = await asyncio.open_connection('localhost', 6565)
async with RSocketClient(TransportTCP(*connection),

async with RSocketClient(single_transport_provider(TransportTCP(*connection)),
metadata_encoding=WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA.value.name,
data_encoding=WellKnownMimeTypes.APPLICATION_JSON.value.name) as client:
metadata = CompositeMetadata()
Expand Down
4 changes: 2 additions & 2 deletions rsocket/awaitable/awaitable_rsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ async def __aenter__(self):
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._rsocket.__aexit__(exc_type, exc_val, exc_tb)

def connect(self):
return self._rsocket.connect()
async def connect(self):
return await self._rsocket.connect()

def close(self):
self._rsocket.close()
14 changes: 11 additions & 3 deletions rsocket/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ class RSocketStreamAllocationFailure(RSocketError):
pass


class RSocketValueErrorException(RSocketError):
class RSocketValueError(RSocketError):
pass


class RSocketProtocolException(RSocketError):
class RSocketProtocolError(RSocketError):
def __init__(self, error_code: ErrorCode, data: Optional[str] = None):
self.error_code = error_code
self.data = data
Expand All @@ -47,7 +47,7 @@ def __str__(self) -> str:
return 'RSocket error %s(%s): "%s"' % (self.error_code.name, self.error_code.value, self.data or '')


class RSocketStreamIdInUse(RSocketProtocolException):
class RSocketStreamIdInUse(RSocketProtocolError):

def __init__(self, stream_id: int):
super().__init__(ErrorCode.REJECTED)
Expand All @@ -56,3 +56,11 @@ def __init__(self, stream_id: int):

class RSocketFrameFragmentDifferentType(RSocketError):
pass


class RSocketTransportError(RSocketError):
pass


class RSocketNoAvailableTransport(RSocketError):
pass
8 changes: 4 additions & 4 deletions rsocket/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Tuple, Optional

from rsocket.error_codes import ErrorCode
from rsocket.exceptions import RSocketProtocolException, ParseError, RSocketUnknownFrameType
from rsocket.exceptions import RSocketProtocolError, ParseError, RSocketUnknownFrameType
from rsocket.frame_helpers import is_flag_set, unpack_position, pack_position, unpack_24bit, pack_24bit, unpack_32bit, \
ensure_bytes

Expand Down Expand Up @@ -626,7 +626,7 @@ def parse_or_ignore(buffer: bytes) -> Optional[Frame]:
return frame
except Exception as exception:
if not header.flags_ignore:
raise RSocketProtocolException(ErrorCode.CONNECTION_ERROR, str(exception)) from exception
raise RSocketProtocolError(ErrorCode.CONNECTION_ERROR, str(exception)) from exception


def is_fragmentable_frame(frame: Frame) -> bool:
Expand All @@ -643,7 +643,7 @@ def exception_to_error_frame(stream_id: int, exception: Exception) -> ErrorFrame
frame = ErrorFrame()
frame.stream_id = stream_id

if isinstance(exception, RSocketProtocolException):
if isinstance(exception, RSocketProtocolError):
frame.error_code = exception.error_code
frame.data = ensure_bytes(exception.data)
else:
Expand All @@ -655,7 +655,7 @@ def exception_to_error_frame(stream_id: int, exception: Exception) -> ErrorFrame

def error_frame_to_exception(frame: ErrorFrame) -> Exception:
if frame.error_code != ErrorCode.APPLICATION_ERROR:
return RSocketProtocolException(frame.error_code, data=frame.data.decode())
return RSocketProtocolError(frame.error_code, data=frame.data.decode())

return RuntimeError(frame.data.decode('utf-8'))

Expand Down
22 changes: 18 additions & 4 deletions rsocket/helpers.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
import asyncio
from typing import Optional
from contextlib import contextmanager
from typing import Optional, Any

from reactivestreams.publisher import DefaultPublisher
from reactivestreams.subscriber import Subscriber
from reactivestreams.subscription import DefaultSubscription
from rsocket.exceptions import RSocketTransportError
from rsocket.frame import Frame
from rsocket.payload import Payload

_default = object()


def create_future(payload: Optional[Payload] = _default) -> asyncio.Future:
def create_future(value: Optional[Any] = _default) -> asyncio.Future:
future = asyncio.get_event_loop().create_future()

if payload is not _default:
future.set_result(payload)
if value is not _default:
future.set_result(value)

return future

Expand Down Expand Up @@ -50,3 +52,15 @@ def __eq__(self, other):

def __hash__(self):
return hash((self.id, self.name))


@contextmanager
def wrap_transport_exception():
try:
yield
except Exception as exception:
raise RSocketTransportError from exception


async def single_transport_provider(transport):
yield transport
6 changes: 3 additions & 3 deletions rsocket/load_balancer/load_balancer_rsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ def request_stream(self, payload: Payload) -> Union[BackpressureApi, Publisher]:
def metadata_push(self, metadata: bytes):
self._select_client().metadata_push(metadata)

def connect(self):
self._strategy.connect()
async def connect(self):
await self._strategy.connect()

async def close(self):
await self._strategy.close()

async def __aenter__(self) -> RSocket:
self._strategy.connect()
await self._strategy.connect()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
Expand Down
2 changes: 1 addition & 1 deletion rsocket/load_balancer/load_balancer_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def select(self) -> RSocket:
...

@abc.abstractmethod
def connect(self):
async def connect(self):
...

@abc.abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions rsocket/load_balancer/random_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def select(self) -> RSocket:
random_client_id = random.randint(0, len(self._pool))
return self._pool[random_client_id]

def connect(self):
async def connect(self):
if self._auto_connect:
[client.connect() for client in self._pool]
[await client.connect() for client in self._pool]

async def close(self):
if self._auto_close:
Expand Down
4 changes: 2 additions & 2 deletions rsocket/load_balancer/round_robin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def select(self) -> RSocket:
self._current_index = (self._current_index + 1) % len(self._pool)
return client

def connect(self):
async def connect(self):
if self._auto_connect:
[client.connect() for client in self._pool]
[await client.connect() for client in self._pool]

async def close(self):
if self._auto_close:
Expand Down
13 changes: 10 additions & 3 deletions rsocket/request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import ABCMeta, abstractmethod
from asyncio import Future
from datetime import timedelta
from typing import Tuple, Optional, Callable
from typing import Tuple, Optional

from reactivestreams.publisher import Publisher
from reactivestreams.subscriber import Subscriber
Expand Down Expand Up @@ -58,7 +58,11 @@ async def on_error(self, error_code: ErrorCode, payload: Payload):
@abstractmethod
async def on_keepalive_timeout(self,
time_since_last_keepalive: timedelta,
cancel_all_streams: Callable):
rsocket):
...

@abstractmethod
async def on_connection_lost(self, rsocket, exception):
...

def _parse_composite_metadata(self, metadata: bytes) -> CompositeMetadata:
Expand Down Expand Up @@ -93,7 +97,10 @@ async def request_stream(self, payload: Payload) -> Publisher:
async def on_error(self, error_code: ErrorCode, payload: Payload):
logger().error('Error: %s, %s', error_code, payload)

async def on_connection_lost(self, rsocket, exception: Exception):
await rsocket.close()

async def on_keepalive_timeout(self,
time_since_last_keepalive: timedelta,
cancel_all_streams: Callable):
rsocket):
pass
2 changes: 1 addition & 1 deletion rsocket/rsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def metadata_push(self, metadata: bytes):
...

@abc.abstractmethod
def connect(self):
async def connect(self):
...

@abc.abstractmethod
Expand Down
Loading