diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d59d8117..67928416 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -9,6 +9,7 @@ v0.4.1 - Performance test examples available in *performance* folder - WSS (Secure websocket) example and support (aiohttp) - Refactored Websocket transport to allow providing either url or an existing websocket +- Added command line tool (rsocket-py) v0.4.0 ====== diff --git a/README.md b/README.md index 52e16eb1..90730dc0 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ or install any of the extras: * aiohttp * quart * quic +* cli Example: @@ -62,7 +63,6 @@ all the examples | | ServerWithFragmentation | client_with_routing.py | | | server_quart_websocket.py | | client_websocket.py | | | server_aiohttp_websocket.py | | client_websocket.py | | -| server_aiohttp_websocket_secure.py | | client_wss.py | | # Build Status diff --git a/examples/client_websocket.py b/examples/client_websocket.py index f7a25ea5..a65ddebb 100644 --- a/examples/client_websocket.py +++ b/examples/client_websocket.py @@ -1,18 +1,34 @@ import asyncio import logging -import sys +import aiohttp +import asyncclick as click + +from rsocket.helpers import single_transport_provider from rsocket.payload import Payload +from rsocket.rsocket_client import RSocketClient +from rsocket.transports.aiohttp_websocket import TransportAioHttpClient from rsocket.transports.aiohttp_websocket import websocket_client -async def application(serve_port): - async with websocket_client('http://localhost:%s' % serve_port) as client: - result = await client.request_response(Payload(b'ping')) - print(result) +async def application(with_ssl: bool, serve_port: int): + if with_ssl: + async with aiohttp.ClientSession() as session: + async with session.ws_connect('wss://localhost:%s' % serve_port, verify_ssl=False) as websocket: + async with RSocketClient( + single_transport_provider(TransportAioHttpClient(websocket=websocket))) as client: + result = await client.request_response(Payload(b'ping')) + print(result) + + else: + async with websocket_client('http://localhost:%s' % serve_port) as client: + result = await client.request_response(Payload(b'ping')) + print(result) -if __name__ == '__main__': - port = sys.argv[1] if len(sys.argv) > 1 else 6565 +@click.command() +@click.option('--with-ssl', is_flag=False, default=False) +@click.option('--port', is_flag=False, default=6565) +async def command(with_ssl, port: int): logging.basicConfig(level=logging.DEBUG) - asyncio.run(application(port)) + asyncio.run(application(with_ssl, port)) diff --git a/examples/client_wss.py b/examples/client_wss.py deleted file mode 100644 index 97d085df..00000000 --- a/examples/client_wss.py +++ /dev/null @@ -1,24 +0,0 @@ -import asyncio -import logging -import sys - -import aiohttp - -from rsocket.helpers import single_transport_provider -from rsocket.payload import Payload -from rsocket.rsocket_client import RSocketClient -from rsocket.transports.aiohttp_websocket import TransportAioHttpClient - - -async def application(serve_port): - async with aiohttp.ClientSession() as session: - async with session.ws_connect('wss://localhost:%s' % serve_port, verify_ssl=False) as websocket: - async with RSocketClient(single_transport_provider(TransportAioHttpClient(websocket=websocket))) as client: - result = await client.request_response(Payload(b'ping')) - print(result) - - -if __name__ == '__main__': - port = sys.argv[1] if len(sys.argv) > 1 else 6565 - logging.basicConfig(level=logging.DEBUG) - asyncio.run(application(port)) diff --git a/examples/server_aiohttp_websocket.py b/examples/server_aiohttp_websocket.py index c77f0ad4..d976638d 100644 --- a/examples/server_aiohttp_websocket.py +++ b/examples/server_aiohttp_websocket.py @@ -1,13 +1,16 @@ import logging -import sys +import ssl +import asyncclick as click from aiohttp import web +from examples.fixtures import cert_gen from rsocket.helpers import create_future from rsocket.local_typing import Awaitable from rsocket.payload import Payload from rsocket.request_handler import BaseRequestHandler -from rsocket.transports.aiohttp_websocket import websocket_handler_factory +from rsocket.rsocket_server import RSocketServer +from rsocket.transports.aiohttp_websocket import TransportAioHttpWebsocket class Handler(BaseRequestHandler): @@ -16,9 +19,36 @@ async def request_response(self, payload: Payload) -> Awaitable[Payload]: return create_future(Payload(b'pong')) -if __name__ == '__main__': - port = sys.argv[1] if len(sys.argv) > 1 else 6565 +def websocket_handler_factory( **kwargs): + async def websocket_handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + transport = TransportAioHttpWebsocket(ws) + RSocketServer(transport, **kwargs) + await transport.handle_incoming_ws_messages() + return ws + + return websocket_handler + + +@click.command() +@click.option('--port', help='Port to listen on', default=6565, type=int) +@click.option('--with-ssl', is_flag=True, help='Enable SSL mode') +async def start_server(with_ssl: bool, port: int): logging.basicConfig(level=logging.DEBUG) app = web.Application() app.add_routes([web.get('/', websocket_handler_factory(handler_factory=Handler))]) - web.run_app(app, port=port) + + if with_ssl: + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + + with cert_gen() as (certificate, key): + ssl_context.load_cert_chain(certificate, key) + else: + ssl_context = None + + await web._run_app(app, port=port, ssl_context=ssl_context) + + +if __name__ == '__main__': + start_server() diff --git a/examples/server_aiohttp_websocket_secure.py b/examples/server_aiohttp_websocket_secure.py deleted file mode 100644 index b60d14b1..00000000 --- a/examples/server_aiohttp_websocket_secure.py +++ /dev/null @@ -1,32 +0,0 @@ -import logging -import ssl -import sys - -from aiohttp import web - -from examples.fixtures import cert_gen -from rsocket.helpers import create_future -from rsocket.local_typing import Awaitable -from rsocket.payload import Payload -from rsocket.request_handler import BaseRequestHandler -from rsocket.transports.aiohttp_websocket import websocket_handler_factory - - -class Handler(BaseRequestHandler): - - async def request_response(self, payload: Payload) -> Awaitable[Payload]: - return create_future(Payload(b'pong')) - - -if __name__ == '__main__': - port = sys.argv[1] if len(sys.argv) > 1 else 6565 - logging.basicConfig(level=logging.DEBUG) - app = web.Application() - app.add_routes([web.get('/', websocket_handler_factory(handler_factory=Handler))]) - - ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - - with cert_gen() as (certificate, key): - ssl_context.load_cert_chain(certificate, key) - - web.run_app(app, port=port, ssl_context=ssl_context) diff --git a/examples/server_with_routing.py b/examples/server_with_routing.py index b1b7d31e..aeb406eb 100644 --- a/examples/server_with_routing.py +++ b/examples/server_with_routing.py @@ -1,11 +1,15 @@ import asyncio import logging -import sys +import ssl from dataclasses import dataclass from datetime import timedelta from typing import Optional +import asyncclick as click +from aiohttp import web + from examples.example_fixtures import large_data1 +from examples.fixtures import cert_gen from examples.response_channel import response_stream_1, LoggingSubscriber from response_stream import response_stream_2 from rsocket.extensions.authentication import Authentication, AuthenticationSimple @@ -15,6 +19,7 @@ from rsocket.routing.request_router import RequestRouter from rsocket.routing.routing_request_handler import RoutingRequestHandler from rsocket.rsocket_server import RSocketServer +from rsocket.transports.aiohttp_websocket import TransportAioHttpWebsocket from rsocket.transports.tcp import TransportTCP router = RequestRouter() @@ -106,16 +111,49 @@ def handle_client(reader, writer): RSocketServer(TransportTCP(reader, writer), handler_factory=handler_factory) -async def run_server(server_port): - logging.info('Starting server at localhost:%s', server_port) +def websocket_handler_factory(**kwargs): + async def websocket_handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + transport = TransportAioHttpWebsocket(ws) + RSocketServer(transport, **kwargs) + await transport.handle_incoming_ws_messages() + return ws + + return websocket_handler + + +@click.command() +@click.option('--port', help='Port to listen on', default=6565, type=int) +@click.option('--with-ssl', is_flag=True, help='Enable SSL mode') +@click.option('--transport', is_flag=False, default='tcp') +async def start_server(with_ssl: bool, port: int, transport: str): + logging.basicConfig(level=logging.DEBUG) + + logging.info(f'Starting {transport} server at localhost:{port}') + + if transport in ['ws', 'wss']: + app = web.Application() + app.add_routes([web.get('/', websocket_handler_factory(handler_factory=handler_factory))]) + + if with_ssl: + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + + with cert_gen() as (certificate, key): + ssl_context.load_cert_chain(certificate, key) + else: + ssl_context = None - server = await asyncio.start_server(handle_client, 'localhost', server_port) + await web._run_app(app, port=port, ssl_context=ssl_context) + elif transport == 'tcp': - async with server: - await server.serve_forever() + server = await asyncio.start_server(handle_client, 'localhost', port) + + async with server: + await server.serve_forever() + else: + raise Exception(f'Unsupported transport {transport}') if __name__ == '__main__': - port = sys.argv[1] if len(sys.argv) > 1 else 6565 - logging.basicConfig(level=logging.DEBUG) - asyncio.run(run_server(port)) + start_server() diff --git a/examples/test_examples.py b/examples/test_examples.py index ab17bef5..fdaa8124 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -90,11 +90,12 @@ def test_client_server_over_websocket_aiohttp(unused_tcp_port): def test_client_server_over_websocket_secure_aiohttp(unused_tcp_port): - pid = os.spawnlp(os.P_NOWAIT, 'python3', 'python3', './server_aiohttp_websocket_secure.py', str(unused_tcp_port)) + pid = os.spawnlp(os.P_NOWAIT, 'python3', 'python3', 'server_aiohttp_websocket.py', '--port', str(unused_tcp_port), + '--with-ssl') try: sleep(2) - client = subprocess.Popen(['python3', './client_wss.py', str(unused_tcp_port)]) + client = subprocess.Popen(['python3', './client_websocket.py', '--port', str(unused_tcp_port), '--with-ssl']) client.wait(timeout=3) assert client.returncode == 0 @@ -103,11 +104,11 @@ def test_client_server_over_websocket_secure_aiohttp(unused_tcp_port): def test_client_server_over_websocket_quart(unused_tcp_port): - pid = os.spawnlp(os.P_NOWAIT, 'python3', 'python3', './server_quart_websocket.py', str(unused_tcp_port)) + pid = os.spawnlp(os.P_NOWAIT, 'python3', 'python3', './server_quart_websocket.py', '--port', str(unused_tcp_port)) try: sleep(2) - client = subprocess.Popen(['python3', './client_websocket.py', str(unused_tcp_port)]) + client = subprocess.Popen(['python3', './client_websocket.py', '--port', str(unused_tcp_port)]) client.wait(timeout=3) assert client.returncode == 0 diff --git a/requirements.txt b/requirements.txt index 4fd6d968..a9b87dd4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ quart==0.18.2 coveralls==3.3.1 aioquic==0.9.20 reactivex==4.0.4 -starlette==0.16.0 \ No newline at end of file +starlette==0.16.0 +asyncclick==8.1.3.4 \ No newline at end of file diff --git a/rsocket/__init__.py b/rsocket/__init__.py index e69de29b..f0ede3d3 100644 --- a/rsocket/__init__.py +++ b/rsocket/__init__.py @@ -0,0 +1 @@ +__version__ = '0.4.1' diff --git a/rsocket/awaitable/awaitable_rsocket.py b/rsocket/awaitable/awaitable_rsocket.py index 64e2764d..43fc4fb2 100644 --- a/rsocket/awaitable/awaitable_rsocket.py +++ b/rsocket/awaitable/awaitable_rsocket.py @@ -24,20 +24,20 @@ async def request_response(self, payload: Payload) -> Payload: async def request_stream(self, payload: Payload, - initial_request_n=MAX_REQUEST_N) -> List[Payload]: - subscriber = CollectorSubscriber() + limit_rate=MAX_REQUEST_N) -> List[Payload]: + subscriber = CollectorSubscriber(limit_rate) - self._rsocket.request_stream(payload).initial_request_n(initial_request_n).subscribe(subscriber) + self._rsocket.request_stream(payload).initial_request_n(limit_rate).subscribe(subscriber) return await subscriber.run() async def request_channel(self, payload: Payload, publisher: Optional[Publisher] = None, - initial_request_n=MAX_REQUEST_N) -> List[Payload]: - subscriber = CollectorSubscriber() + limit_rate=MAX_REQUEST_N) -> List[Payload]: + subscriber = CollectorSubscriber(limit_rate) - self._rsocket.request_channel(payload, publisher).initial_request_n(initial_request_n).subscribe(subscriber) + self._rsocket.request_channel(payload, publisher).initial_request_n(limit_rate).subscribe(subscriber) return await subscriber.run() diff --git a/rsocket/awaitable/collector_subscriber.py b/rsocket/awaitable/collector_subscriber.py index 025f13a3..cf29c599 100644 --- a/rsocket/awaitable/collector_subscriber.py +++ b/rsocket/awaitable/collector_subscriber.py @@ -2,11 +2,16 @@ from reactivestreams.subscriber import Subscriber from reactivestreams.subscription import DefaultSubscription +from rsocket.frame import MAX_REQUEST_N class CollectorSubscriber(Subscriber): - def __init__(self) -> None: + def __init__(self, limit_rate=MAX_REQUEST_N, limit_count=None) -> None: + self._limit_count = limit_count + self._limit_rate = limit_rate + self._received_count = 0 + self._total_received_count = 0 self.is_done = asyncio.Event() self.error = None self.values = [] @@ -21,8 +26,18 @@ def on_subscribe(self, subscription: DefaultSubscription): def on_next(self, value, is_complete=False): self.values.append(value) + self._received_count += 1 + self._total_received_count += 1 + if is_complete: self.is_done.set() + elif self._limit_count is not None and self._limit_count == self._total_received_count: + self.subscription.cancel() + self.is_done.set() + else: + if self._received_count == self._limit_rate: + self._received_count = 0 + self.subscription.request(self._limit_rate) def on_error(self, exception: Exception): self.error = exception diff --git a/rsocket/cli/__init__.py b/rsocket/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/rsocket/cli/command.py b/rsocket/cli/command.py new file mode 100644 index 00000000..b87072c1 --- /dev/null +++ b/rsocket/cli/command.py @@ -0,0 +1,283 @@ +import asyncio +import logging +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import Optional, Type, Collection, List + +import aiohttp +import asyncclick as click +from werkzeug.routing import Map + +from rsocket.awaitable.awaitable_rsocket import AwaitableRSocket +from rsocket.extensions.helpers import route, composite, authenticate_simple, authenticate_bearer +from rsocket.extensions.mimetypes import WellKnownMimeTypes +from rsocket.frame import MAX_REQUEST_N +from rsocket.frame_helpers import ensure_bytes, safe_len +from rsocket.helpers import single_transport_provider +from rsocket.payload import Payload +from rsocket.rsocket_client import RSocketClient +from rsocket.transports.abstract_messaging import AbstractMessagingTransport +from rsocket.transports.aiohttp_websocket import TransportAioHttpClient +from rsocket.transports.tcp import TransportTCP +from importlib.metadata import version as get_version + + +@dataclass(frozen=True) +class RSocketUri: + host: str + port: str + schema: str + path: Optional[str] = None + original_uri: Optional[str] = None + + +def parse_uri(uri: str) -> RSocketUri: + schema, rest = uri.split(':', 1) + rest = rest.strip('/') + host_port = rest.split('/', 1) + host, port = host_port[0].split(':') + + if len(host_port) > 1: + rest = host_port[1] + else: + rest = None + + return RSocketUri(host, port, schema, rest, uri) + + +@asynccontextmanager +async def transport_from_uri(uri: RSocketUri, + verify_ssl=True, + headers: Optional[Map] = None) -> Type[AbstractMessagingTransport]: + if uri.schema == 'tcp': + connection = await asyncio.open_connection(uri.host, uri.port) + yield TransportTCP(*connection) + elif uri.schema in ['wss', 'ws']: + async with aiohttp.ClientSession() as session: + async with session.ws_connect(uri.original_uri, + verify_ssl=verify_ssl, + headers=headers) as websocket: + yield TransportAioHttpClient(websocket=websocket) + else: + raise Exception('Unsupported schema in CLI') + + +def build_composite_metadata(auth_simple: Optional[str], + route_value: Optional[str], + auth_bearer: Optional[str]): + composite_items = [] + + if route_value is not None: + composite_items.append(route(route_value)) + + if auth_simple is not None: + composite_items.append(authenticate_simple(*auth_simple.split(':'))) + + if auth_bearer is not None: + composite_items.append(authenticate_bearer(auth_bearer)) + + return composite_items + + +@asynccontextmanager +async def create_client(parsed_uri, + data_mime_type, + metadata_mime_type, + setup_payload, + allow_untrusted_ssl=False, + http_headers=None): + async with transport_from_uri(parsed_uri, verify_ssl=not allow_untrusted_ssl, headers=http_headers) as transport: + async with RSocketClient(single_transport_provider(transport), + data_encoding=data_mime_type or WellKnownMimeTypes.APPLICATION_JSON, + metadata_encoding=metadata_mime_type or WellKnownMimeTypes.APPLICATION_JSON, + setup_payload=setup_payload) as client: + yield AwaitableRSocket(client) + + +@click.command(name='rsocket-py', help='Supported connection strings: tcp/ws/wss') +@click.option('--request', is_flag=True, + help='Request response') +@click.option('--stream', is_flag=True, + help='Request stream') +@click.option('--channel', is_flag=True, + help='Request channel') +@click.option('--fnf', is_flag=True, + help='Fire and Forget') +@click.option('-d', '--data', is_flag=False, + help='Data. Use "-" to read data from standard input. (default: )') +@click.option('-l', '--load', is_flag=False, + help='Load a file as Data. (e.g. ./foo.txt, /tmp/foo.txt)') +@click.option('-m', '--metadata', is_flag=False, default=None, + help='Metadata (default: )') +@click.option('-r', '--route', 'route_value', is_flag=False, default=None, + help='Enable Routing Metadata Extension') +@click.option('--limitRate', 'limit_rate', is_flag=False, default=None, type=int, + help='Enable limitRate(rate)') +@click.option('--take', 'take_n', is_flag=False, default=None, type=int, + help='Enable take(n)') +@click.option('-u', '--as', '--authSimple', 'auth_simple', is_flag=False, default=None, + help='Enable Authentication Metadata Extension (Simple). The format must be "username: password"') +@click.option('--sd', '--setupData', 'setup_data', is_flag=False, default=None, + help='Data for Setup payload') +@click.option('--sm', '--setupMetadata', 'setup_metadata', is_flag=False, default=None, + help='Metadata for Setup payload') +@click.option('--ab', '--authBearer', 'auth_bearer', is_flag=False, default=None, + help='Enable Authentication Metadata Extension (Bearer)') +@click.option('--dataMimeType', '--dmt', 'data_mime_type', is_flag=False, + help='MimeType for data (default: application/json)') +@click.option('--metadataMimeType', '--mmt', 'metadata_mime_type', is_flag=False, + help='MimeType for metadata (default:application/json)') +@click.option('--allowUntrustedSsl', 'allow_untrusted_ssl', is_flag=True, default=False, + help='Do not verify SSL certificate (for wss:// urls)') +@click.option('--httpHeader', 'http_header', multiple=True, + help='ws/wss headers') +@click.option('--debug', is_flag=True, + help='Show debug log') +@click.option('--quiet', '-q', is_flag=True, + help='Disable the output on next') +@click.option('--version', is_flag=True, + help='Print version') +@click.argument('uri') +async def command(data, load, + metadata, route_value, auth_simple, auth_bearer, + limit_rate, take_n, allow_untrusted_ssl, + setup_data, setup_metadata, + http_header, + data_mime_type, metadata_mime_type, + request, stream, channel, fnf, + uri, debug, version, quiet): + if version: + try: + print(get_version('rsocket')) + except Exception: + print('Failed to find version') + return + + if quiet: + logging.basicConfig(handlers=[]) + + if debug: + logging.basicConfig(level=logging.DEBUG) + + if take_n == 0: + return + + http_headers = parse_headers(http_header) + + composite_items = build_composite_metadata(auth_simple, route_value, auth_bearer) + + async with create_client(parse_uri(uri), + data_mime_type, + normalize_metadata_mime_type(composite_items, metadata_mime_type), + create_setup_payload(setup_data, setup_metadata), + allow_untrusted_ssl=allow_untrusted_ssl, + http_headers=http_headers + ) as client: + + result = await execute_request(client, + channel, + fnf, + normalize_limit_rate(limit_rate), + create_request_payload(data, load, metadata, composite_items), + request, + stream) + + if not quiet: + output_result(result) + + +def parse_headers(http_headers): + if safe_len(http_headers) > 0: + headers = dict() + + for header in http_headers: + parts = header.split('=', 2) + headers[parts[0]] = parts[1] + + return headers + + return None + + +def normalize_metadata_mime_type(composite_items, metadata_mime_type): + if len(composite_items) > 0: + metadata_mime_type = WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA + + return metadata_mime_type + + +def create_request_payload(data: Optional[str], + load: Optional[str], + metadata: Optional[str], + composite_items: List) -> Payload: + data = normalize_data(data, load) + metadata_value = get_metadata_value(composite_items, metadata) + return Payload(data, metadata_value) + + +def output_result(result): + if isinstance(result, Payload): + print(result.data.decode('utf-8')) + elif isinstance(result, Collection): + print([p.data.decode('utf-8') for p in result]) + + +async def execute_request(awaitable_client, channel, fnf, limit_rate, payload, request, stream): + result = None + + if request: + result = await awaitable_client.request_response(payload) + elif stream: + result = await awaitable_client.request_stream(payload, limit_rate=limit_rate) + elif channel: + result = await awaitable_client.request_channel(payload, limit_rate=limit_rate) + elif fnf: + await awaitable_client.fire_and_forget(payload) + + return result + + +def get_metadata_value(composite_items: List, metadata: Optional[str]) -> bytes: + if len(composite_items) > 0: + metadata_value = composite(*composite_items) + else: + metadata_value = metadata + + return ensure_bytes(metadata_value) + + +def create_setup_payload(setup_data: Optional[str], setup_metadata: Optional[str]) -> Optional[Payload]: + setup_payload = None + + if setup_data is not None or setup_metadata is not None: + setup_payload = Payload( + ensure_bytes(setup_data), + ensure_bytes(setup_metadata) + ) + + return setup_payload + + +def normalize_data(data: Optional[str], load: Optional[str]) -> bytes: + if data == '-': + stdin_text = click.get_text_stream('stdin') + return ensure_bytes(stdin_text.read()) + + if load is not None: + with open(load) as fd: + return ensure_bytes(fd.read()) + + return ensure_bytes(data) + + +def normalize_limit_rate(limit_rate): + if limit_rate is not None and not limit_rate > 0: + limit_rate = MAX_REQUEST_N + else: + limit_rate = MAX_REQUEST_N + + return limit_rate + + +if __name__ == '__main__': + command() diff --git a/setup.py b/setup.py index 0d3c24e2..f2c1d108 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,13 @@ 'reactivex': {'reactivex >= 4.0.0'}, 'aiohttp': {'aiohttp >= 3.0.0'}, 'quart': {'quart >= 0.15.0'}, - 'quic': {'aioquic >= 0.9.0'} + 'quic': {'aioquic >= 0.9.0'}, + 'cli': {'asyncclick >= 8.0.0'} + }, + entry_points={ + 'console_scripts': [ + 'rsocket-py = rsocket.cli.command:command [cli]', + ], }, classifiers=[ 'Development Status :: 3 - Alpha', diff --git a/tests/rsocket/helpers.py b/tests/rsocket/helpers.py index 010e233d..096f5875 100644 --- a/tests/rsocket/helpers.py +++ b/tests/rsocket/helpers.py @@ -6,7 +6,7 @@ from typing import Tuple, Any from typing import Type, Callable -from rsocket.frame_helpers import str_to_bytes +from rsocket.frame_helpers import str_to_bytes, ensure_bytes from rsocket.helpers import create_future, noop from rsocket.logger import logger from rsocket.payload import Payload @@ -97,3 +97,7 @@ def get_components(pipe) -> Tuple[RSocketServer, RSocketClient]: def to_json_bytes(item: Any) -> bytes: return str_to_bytes(json.dumps(item)) + + +def create_data(base: bytes, multiplier: int, limit: float = None): + return b''.join([ensure_bytes(str(i)) + base for i in range(multiplier)])[0:limit] diff --git a/tests/rsocket/test_cli_command.py b/tests/rsocket/test_cli_command.py new file mode 100644 index 00000000..f075f116 --- /dev/null +++ b/tests/rsocket/test_cli_command.py @@ -0,0 +1,93 @@ +import io +import sys +import tempfile + +from rsocket.cli.command import parse_uri, build_composite_metadata, create_request_payload, get_metadata_value, \ + create_setup_payload, normalize_data, normalize_limit_rate +from rsocket.frame import MAX_REQUEST_N +from tests.rsocket.helpers import create_data + + +def test_parse_uri(): + parsed = parse_uri('wss://localhost:6565') + + assert parsed.schema == 'wss' + assert parsed.port == '6565' + assert parsed.host == 'localhost' + + +def test_build_composite_metadata(): + composite = build_composite_metadata( + None, None, None + ) + + assert len(composite) == 0 + + +def test_create_request_payload(): + payload = create_request_payload( + None, None, None, [] + ) + + assert payload.data is None + assert payload.metadata is None + + +def test_get_metadata_value(): + result = get_metadata_value([], None) + + assert result is None + + +def test_create_setup_payload(): + result = create_setup_payload(None, None) + + assert result is None + + +def test_normalize_data(): + data = normalize_data(None, None) + + assert data is None + + +def test_normalize_data_from_file(): + with tempfile.NamedTemporaryFile() as fd: + fixture_data = create_data(b'1234567890', 20) + fd.write(fixture_data) + fd.flush() + + data = normalize_data(None, fd.name) + + assert data == fixture_data + + +def test_normalize_data_from_stdin(): + fixture_data = create_data(b'1234567890', 20) + stdin = io.BytesIO(fixture_data) + sys.stdin = stdin + + data = normalize_data('-', None) + + assert data == fixture_data + + +def test_normalize_data_from_stdin_takes_precedence_over_load_from_file(): + with tempfile.NamedTemporaryFile() as fd: + fixture_data_file = create_data(b'1234567890', 20) + fd.write(fixture_data_file) + fd.flush() + + fixture_data_stdin = create_data(b'0987654321', 20) + stdin = io.BytesIO(fixture_data_stdin) + sys.stdin = stdin + + data = normalize_data('-', fd.name) + + assert data == fixture_data_stdin + + +def test_normalize_limit_rate(): + result = normalize_limit_rate(None) + + assert result == MAX_REQUEST_N diff --git a/tests/rsocket/test_fragments.py b/tests/rsocket/test_fragments.py index 714df9e9..5e79e5de 100644 --- a/tests/rsocket/test_fragments.py +++ b/tests/rsocket/test_fragments.py @@ -8,10 +8,7 @@ from rsocket.frame_fragment_cache import FrameFragmentCache from rsocket.frame_helpers import ensure_bytes from rsocket.payload import Payload - - -def create_data(base: bytes, multiplier: int, limit: float = None): - return b''.join([ensure_bytes(str(i)) + base for i in range(multiplier)])[0:limit] +from tests.rsocket.helpers import create_data def test_create_data(): diff --git a/tests/rsocket/test_request_stream.py b/tests/rsocket/test_request_stream.py index c9ed844d..d2b711d4 100644 --- a/tests/rsocket/test_request_stream.py +++ b/tests/rsocket/test_request_stream.py @@ -200,15 +200,9 @@ def request(self, n: int): async def request_stream(self, payload: Payload) -> Publisher: return self - class StreamSubscriber(CollectorSubscriber): - - def on_next(self, value, is_complete=False): - super().on_next(value, is_complete) - self.subscription.request(1) - server.set_handler_using_factory(Handler) - stream_subscriber = StreamSubscriber() + stream_subscriber = CollectorSubscriber(limit_rate=1) client.request_stream(Payload()).initial_request_n(1).subscribe(stream_subscriber)