From 1cf5583331974bbaa1adcfa672585d7b4d0665d1 Mon Sep 17 00:00:00 2001 From: Gabriel Shaar Date: Tue, 25 Oct 2022 16:59:05 +0300 Subject: [PATCH 01/13] more compatibility with making/rsc project. --- rsocket/cli/command.py | 98 ++++++++++++++++++++++++------- tests/rsocket/test_cli_command.py | 13 +++- 2 files changed, 87 insertions(+), 24 deletions(-) diff --git a/rsocket/cli/command.py b/rsocket/cli/command.py index b87072c1..b6708756 100644 --- a/rsocket/cli/command.py +++ b/rsocket/cli/command.py @@ -2,6 +2,7 @@ import logging from contextlib import asynccontextmanager from dataclasses import dataclass +from enum import Enum, unique from typing import Optional, Type, Collection, List import aiohttp @@ -22,10 +23,22 @@ from importlib.metadata import version as get_version +@unique +class RequestType(Enum): + response = 'REQUEST_RESPONSE' + stream = 'REQUEST_STREAM' + channel = 'REQUEST_CHANNEL' + fnf = 'FIRE_AND_FORGET' + metadata_push = 'METADATA_PUSH' + + +interaction_models: List[str] = [str(e.value) for e in RequestType] + + @dataclass(frozen=True) class RSocketUri: host: str - port: str + port: int schema: str path: Optional[str] = None original_uri: Optional[str] = None @@ -34,11 +47,17 @@ class RSocketUri: def parse_uri(uri: str) -> RSocketUri: schema, rest = uri.split(':', 1) rest = rest.strip('/') - host_port = rest.split('/', 1) - host, port = host_port[0].split(':') + host_port_path = rest.split('/', 1) + host_port = host_port_path[0].split(':') + host = host_port[0] + + if len(host_port) == 1: + port = None + else: + port = int(host_port[1]) - if len(host_port) > 1: - rest = host_port[1] + if len(host_port_path) > 1: + rest = host_port_path[1] else: rest = None @@ -48,7 +67,8 @@ def parse_uri(uri: str) -> RSocketUri: @asynccontextmanager async def transport_from_uri(uri: RSocketUri, verify_ssl=True, - headers: Optional[Map] = None) -> Type[AbstractMessagingTransport]: + headers: Optional[Map] = None, + trust_cert: Optional[str] = None) -> Type[AbstractMessagingTransport]: if uri.schema == 'tcp': connection = await asyncio.open_connection(uri.host, uri.port) yield TransportTCP(*connection) @@ -85,8 +105,12 @@ async def create_client(parsed_uri, metadata_mime_type, setup_payload, allow_untrusted_ssl=False, - http_headers=None): - async with transport_from_uri(parsed_uri, verify_ssl=not allow_untrusted_ssl, headers=http_headers) as transport: + http_headers=None, + trust_cert=None): + async with transport_from_uri(parsed_uri, + verify_ssl=not allow_untrusted_ssl, + headers=http_headers, + trust_cert=trust_cert) as transport: async with RSocketClient(single_transport_provider(transport), data_encoding=data_mime_type or WellKnownMimeTypes.APPLICATION_JSON, metadata_encoding=metadata_mime_type or WellKnownMimeTypes.APPLICATION_JSON, @@ -94,7 +118,30 @@ async def create_client(parsed_uri, yield AwaitableRSocket(client) +def get_request_type(request: bool, + stream: bool, + fnf: bool, + metadata_push: bool, + channel: bool, + interaction_model: str) -> RequestType: + if interaction_model is not None: + return RequestType(interaction_model.upper()) + if request: + return RequestType.response + if stream: + return RequestType.stream + if channel: + return RequestType.channel + if fnf: + return RequestType.fnf + if metadata_push: + return RequestType.metadata_push + + @click.command(name='rsocket-py', help='Supported connection strings: tcp/ws/wss') +@click.option('--im', '--interactionModel', 'interaction_model', is_flag=False, + type=click.Choice(interaction_models, case_sensitive=False), + help='Interaction Model') @click.option('--request', is_flag=True, help='Request response') @click.option('--stream', is_flag=True, @@ -103,6 +150,8 @@ async def create_client(parsed_uri, help='Request channel') @click.option('--fnf', is_flag=True, help='Fire and Forget') +@click.option('--metadataPush', 'metadata_push', is_flag=True, + help='Metadata Push') @click.option('-d', '--data', is_flag=False, help='Data. Use "-" to read data from standard input. (default: )') @click.option('-l', '--load', is_flag=False, @@ -131,6 +180,8 @@ async def create_client(parsed_uri, help='Do not verify SSL certificate (for wss:// urls)') @click.option('--httpHeader', 'http_header', multiple=True, help='ws/wss headers') +@click.option('--trustCert', 'trust_cert', is_flag=False, + help='PEM file for a trusted certificate. (e.g. ./foo.crt, /tmp/foo.crt)') @click.option('--debug', is_flag=True, help='Show debug log') @click.option('--quiet', '-q', is_flag=True, @@ -141,10 +192,10 @@ async def create_client(parsed_uri, async def command(data, load, metadata, route_value, auth_simple, auth_bearer, limit_rate, take_n, allow_untrusted_ssl, - setup_data, setup_metadata, - http_header, + setup_data, setup_metadata, interaction_model, + http_header, metadata_push, data_mime_type, metadata_mime_type, - request, stream, channel, fnf, + request, stream, channel, fnf, trust_cert, uri, debug, version, quiet): if version: try: @@ -162,6 +213,8 @@ async def command(data, load, if take_n == 0: return + request_type = get_request_type(request, stream, fnf, metadata_push, channel, interaction_model) + http_headers = parse_headers(http_header) composite_items = build_composite_metadata(auth_simple, route_value, auth_bearer) @@ -171,16 +224,14 @@ async def command(data, load, normalize_metadata_mime_type(composite_items, metadata_mime_type), create_setup_payload(setup_data, setup_metadata), allow_untrusted_ssl=allow_untrusted_ssl, - http_headers=http_headers + http_headers=http_headers, + trust_cert=trust_cert ) as client: result = await execute_request(client, - channel, - fnf, + request_type, normalize_limit_rate(limit_rate), - create_request_payload(data, load, metadata, composite_items), - request, - stream) + create_request_payload(data, load, metadata, composite_items)) if not quiet: output_result(result) @@ -222,17 +273,20 @@ def output_result(result): print([p.data.decode('utf-8') for p in result]) -async def execute_request(awaitable_client, channel, fnf, limit_rate, payload, request, stream): +async def execute_request(awaitable_client: AwaitableRSocket, request_type: RequestType, limit_rate: int, + payload: Payload): result = None - if request: + if request_type is RequestType.response: result = await awaitable_client.request_response(payload) - elif stream: + elif request_type is RequestType.stream: result = await awaitable_client.request_stream(payload, limit_rate=limit_rate) - elif channel: + elif request_type is RequestType.channel: result = await awaitable_client.request_channel(payload, limit_rate=limit_rate) - elif fnf: + elif request_type is RequestType.fnf: await awaitable_client.fire_and_forget(payload) + elif request_type is RequestType.metadata_push: + await awaitable_client.metadata_push(payload.metadata) return result diff --git a/tests/rsocket/test_cli_command.py b/tests/rsocket/test_cli_command.py index f075f116..941d7c44 100644 --- a/tests/rsocket/test_cli_command.py +++ b/tests/rsocket/test_cli_command.py @@ -9,11 +9,20 @@ def test_parse_uri(): - parsed = parse_uri('wss://localhost:6565') + parsed = parse_uri('tcp://localhost:6565') + + assert parsed.schema == 'tcp' + assert parsed.port == 6565 + assert parsed.host == 'localhost' + + +def test_parse_uri_wss(): + parsed = parse_uri('wss://localhost/path') assert parsed.schema == 'wss' - assert parsed.port == '6565' + assert parsed.port is None assert parsed.host == 'localhost' + assert parsed.path == 'path' def test_build_composite_metadata(): From 00efe109fb8b7c99c82651c48e1a4fd0abee7d5e Mon Sep 17 00:00:00 2001 From: Gabriel Shaar Date: Tue, 25 Oct 2022 18:24:44 +0300 Subject: [PATCH 02/13] allow --version to work without specifying uri --- rsocket/cli/command.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/rsocket/cli/command.py b/rsocket/cli/command.py index b6708756..794b98e2 100644 --- a/rsocket/cli/command.py +++ b/rsocket/cli/command.py @@ -188,8 +188,9 @@ def get_request_type(request: bool, help='Disable the output on next') @click.option('--version', is_flag=True, help='Print version') -@click.argument('uri') -async def command(data, load, +@click.argument('uri', required=False) +@click.pass_context +async def command(context, data, load, metadata, route_value, auth_simple, auth_bearer, limit_rate, take_n, allow_untrusted_ssl, setup_data, setup_metadata, interaction_model, @@ -204,6 +205,9 @@ async def command(data, load, print('Failed to find version') return + if uri is None: + raise click.MissingParameter(param=context.command.params[-1]) + if quiet: logging.basicConfig(handlers=[]) From 452b4637b49efadfa35eeeb4437b7a79d0c84c9e Mon Sep 17 00:00:00 2001 From: Gabriel Shaar Date: Tue, 25 Oct 2022 18:34:19 +0300 Subject: [PATCH 03/13] added using ssl certificate --- examples/fixtures.py | 2 +- examples/server_with_routing.py | 17 ++++++++++------- rsocket/cli/command.py | 7 +++++++ 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/examples/fixtures.py b/examples/fixtures.py index 977d506a..382e1a0e 100644 --- a/examples/fixtures.py +++ b/examples/fixtures.py @@ -8,7 +8,7 @@ @contextmanager def cert_gen(emailAddress="emailAddress", - commonName="commonName", + commonName="localhost", countryName="NT", localityName="localityName", stateOrProvinceName="stateOrProvinceName", diff --git a/examples/server_with_routing.py b/examples/server_with_routing.py index aeb406eb..988d6f2b 100644 --- a/examples/server_with_routing.py +++ b/examples/server_with_routing.py @@ -136,15 +136,18 @@ async def start_server(with_ssl: bool, port: int, transport: str): app = web.Application() app.add_routes([web.get('/', websocket_handler_factory(handler_factory=handler_factory))]) - if with_ssl: - ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + with cert_gen() as (certificate_path, key_path): + if with_ssl: + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - with cert_gen() as (certificate, key): - ssl_context.load_cert_chain(certificate, key) - else: - ssl_context = None + logging.info('Certificate %s', certificate_path) + logging.info('Private-key %s', key_path) - await web._run_app(app, port=port, ssl_context=ssl_context) + ssl_context.load_cert_chain(certificate_path, key_path) + else: + ssl_context = None + + await web._run_app(app, port=port, ssl_context=ssl_context) elif transport == 'tcp': server = await asyncio.start_server(handle_client, 'localhost', port) diff --git a/rsocket/cli/command.py b/rsocket/cli/command.py index 794b98e2..f98109ab 100644 --- a/rsocket/cli/command.py +++ b/rsocket/cli/command.py @@ -1,5 +1,6 @@ import asyncio import logging +import ssl from contextlib import asynccontextmanager from dataclasses import dataclass from enum import Enum, unique @@ -74,8 +75,14 @@ async def transport_from_uri(uri: RSocketUri, yield TransportTCP(*connection) elif uri.schema in ['wss', 'ws']: async with aiohttp.ClientSession() as session: + if trust_cert is not None: + ssl_context = ssl.create_default_context(cafile=trust_cert) + else: + ssl_context = None + async with session.ws_connect(uri.original_uri, verify_ssl=verify_ssl, + ssl_context=ssl_context, headers=headers) as websocket: yield TransportAioHttpClient(websocket=websocket) else: From 70319f4814ae3d5e539bf0c8588684afd006feb2 Mon Sep 17 00:00:00 2001 From: Gabriel Shaar Date: Tue, 25 Oct 2022 18:35:09 +0300 Subject: [PATCH 04/13] refactoring examples --- examples/fixtures.py | 43 ++++++++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/examples/fixtures.py b/examples/fixtures.py index 382e1a0e..08f0c307 100644 --- a/examples/fixtures.py +++ b/examples/fixtures.py @@ -7,18 +7,19 @@ @contextmanager -def cert_gen(emailAddress="emailAddress", - commonName="localhost", - countryName="NT", - localityName="localityName", - stateOrProvinceName="stateOrProvinceName", - organizationName="organizationName", - organizationUnitName="organizationUnitName", - serialNumber=0, - validityStartInSeconds=0, - validityEndInSeconds=None) -> Tuple[str, str]: - if validityEndInSeconds is None: - validityEndInSeconds = int(timedelta(days=3650).total_seconds()) +def cert_gen(email_address="emailAddress", + common_name="localhost", + country_name="NT", + locality_name="localityName", + state_or_province_name="stateOrProvinceName", + organization_name="organizationName", + organization_unit_name="organizationUnitName", + serial_number=0, + validity_start_in_seconds=0, + validity_end_in_seconds=None) -> Tuple[str, str]: + if validity_end_in_seconds is None: + validity_end_in_seconds = int(timedelta(days=3650).total_seconds()) + # can look at generated file using openssl: # openssl x509 -inform pem -in selfsigned.crt -noout -text # create a key pair @@ -27,16 +28,16 @@ def cert_gen(emailAddress="emailAddress", # create a self-signed cert cert = crypto.X509() - cert.get_subject().C = countryName - cert.get_subject().ST = stateOrProvinceName - cert.get_subject().L = localityName - cert.get_subject().O = organizationName - cert.get_subject().OU = organizationUnitName - cert.get_subject().CN = commonName - cert.get_subject().emailAddress = emailAddress - cert.set_serial_number(serialNumber) + cert.get_subject().C = country_name + cert.get_subject().ST = state_or_province_name + cert.get_subject().L = locality_name + cert.get_subject().O = organization_name + cert.get_subject().OU = organization_unit_name + cert.get_subject().CN = common_name + cert.get_subject().emailAddress = email_address + cert.set_serial_number(serial_number) cert.gmtime_adj_notBefore(0) - cert.gmtime_adj_notAfter(validityEndInSeconds) + cert.gmtime_adj_notAfter(validity_end_in_seconds) cert.set_issuer(cert.get_subject()) cert.set_pubkey(k) cert.sign(k, 'sha512') From f7c7890e9b12096efc1fd262e34f83eba5c6748f Mon Sep 17 00:00:00 2001 From: Gabriel Shaar Date: Tue, 25 Oct 2022 18:36:41 +0300 Subject: [PATCH 05/13] refactoring examples --- examples/fixtures.py | 64 +++++++++++++++++----------- examples/server_aiohttp_websocket.py | 4 +- examples/server_with_routing.py | 4 +- 3 files changed, 44 insertions(+), 28 deletions(-) diff --git a/examples/fixtures.py b/examples/fixtures.py index 08f0c307..975e3b7b 100644 --- a/examples/fixtures.py +++ b/examples/fixtures.py @@ -7,26 +7,51 @@ @contextmanager -def cert_gen(email_address="emailAddress", - common_name="localhost", - country_name="NT", - locality_name="localityName", - state_or_province_name="stateOrProvinceName", - organization_name="organizationName", - organization_unit_name="organizationUnitName", - serial_number=0, - validity_start_in_seconds=0, - validity_end_in_seconds=None) -> Tuple[str, str]: +def generate_certificate_and_key(email_address="emailAddress", + common_name="localhost", + country_name="NT", + locality_name="localityName", + state_or_province_name="stateOrProvinceName", + organization_name="organizationName", + organization_unit_name="organizationUnitName", + serial_number=0, + validity_start_in_seconds=0, + validity_end_in_seconds=None) -> Tuple[str, str]: if validity_end_in_seconds is None: validity_end_in_seconds = int(timedelta(days=3650).total_seconds()) # can look at generated file using openssl: # openssl x509 -inform pem -in selfsigned.crt -noout -text # create a key pair + private_key = create_key() + + # create a self-signed cert + cert = create_self_signed_certificate(common_name, country_name, email_address, private_key, locality_name, + organization_name, + organization_unit_name, serial_number, state_or_province_name, + validity_end_in_seconds, validity_start_in_seconds) + + with tempfile.NamedTemporaryFile() as certificate_file: + with tempfile.NamedTemporaryFile() as key_file: + certificate_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert)) + certificate_file.flush() + + key_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, private_key)) + key_file.flush() + + yield certificate_file.name, key_file.name + + +def create_key(): k = crypto.PKey() k.generate_key(crypto.TYPE_RSA, 4096) + return k - # create a self-signed cert + +def create_self_signed_certificate(common_name, country_name, email_address, private_key, locality_name, + organization_name, + organization_unit_name, serial_number, state_or_province_name, + validity_end_in_seconds, validity_start_in_seconds): cert = crypto.X509() cert.get_subject().C = country_name cert.get_subject().ST = state_or_province_name @@ -36,18 +61,9 @@ def cert_gen(email_address="emailAddress", cert.get_subject().CN = common_name cert.get_subject().emailAddress = email_address cert.set_serial_number(serial_number) - cert.gmtime_adj_notBefore(0) + cert.gmtime_adj_notBefore(validity_start_in_seconds) cert.gmtime_adj_notAfter(validity_end_in_seconds) cert.set_issuer(cert.get_subject()) - cert.set_pubkey(k) - cert.sign(k, 'sha512') - - with tempfile.NamedTemporaryFile() as certificate_file: - with tempfile.NamedTemporaryFile() as key_file: - certificate_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert)) - certificate_file.flush() - - key_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k)) - key_file.flush() - - yield certificate_file.name, key_file.name + cert.set_pubkey(private_key) + cert.sign(private_key, 'sha512') + return cert diff --git a/examples/server_aiohttp_websocket.py b/examples/server_aiohttp_websocket.py index d976638d..d62139a1 100644 --- a/examples/server_aiohttp_websocket.py +++ b/examples/server_aiohttp_websocket.py @@ -4,7 +4,7 @@ import asyncclick as click from aiohttp import web -from examples.fixtures import cert_gen +from examples.fixtures import generate_certificate_and_key from rsocket.helpers import create_future from rsocket.local_typing import Awaitable from rsocket.payload import Payload @@ -42,7 +42,7 @@ async def start_server(with_ssl: bool, port: int): if with_ssl: ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - with cert_gen() as (certificate, key): + with generate_certificate_and_key() as (certificate, key): ssl_context.load_cert_chain(certificate, key) else: ssl_context = None diff --git a/examples/server_with_routing.py b/examples/server_with_routing.py index 988d6f2b..4762d6d0 100644 --- a/examples/server_with_routing.py +++ b/examples/server_with_routing.py @@ -9,7 +9,7 @@ from aiohttp import web from examples.example_fixtures import large_data1 -from examples.fixtures import cert_gen +from examples.fixtures import generate_certificate_and_key from examples.response_channel import response_stream_1, LoggingSubscriber from response_stream import response_stream_2 from rsocket.extensions.authentication import Authentication, AuthenticationSimple @@ -136,7 +136,7 @@ async def start_server(with_ssl: bool, port: int, transport: str): app = web.Application() app.add_routes([web.get('/', websocket_handler_factory(handler_factory=handler_factory))]) - with cert_gen() as (certificate_path, key_path): + with generate_certificate_and_key() as (certificate_path, key_path): if with_ssl: ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) From 2a3db1dc8104b96eefba1783451266e86334a164 Mon Sep 17 00:00:00 2001 From: Gabriel Shaar Date: Tue, 25 Oct 2022 18:49:32 +0300 Subject: [PATCH 06/13] version bump and changelog update --- CHANGELOG.rst | 8 ++++++++ setup.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 67928416..ee19530e 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,14 @@ Changelog --------- +v0.4.2 +====== +- Command line fixes: + - Support passing ssl certificate and http headers when using ws/wss + - Support requesting --version without the need to specify URI arguments + - Option --interactionModel to specify interaction (eg. request_response, request_stream) + - Added Metadata Push support + v0.4.1 ====== diff --git a/setup.py b/setup.py index f2c1d108..d2612169 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='rsocket', - version='0.4.1', + version='0.4.2', description='Python RSocket library', long_description=long_description, long_description_content_type='text/markdown', From 27db1f0bf746d8892abce714bddbe8ff2932a523 Mon Sep 17 00:00:00 2001 From: Gabriel Shaar Date: Tue, 25 Oct 2022 18:49:40 +0300 Subject: [PATCH 07/13] version bump and changelog update --- CHANGELOG.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index ee19530e..7ba88b14 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -9,7 +9,6 @@ v0.4.2 - Option --interactionModel to specify interaction (eg. request_response, request_stream) - Added Metadata Push support - v0.4.1 ====== - Added running tests on python 3.11 and package classification From 195a29509663f57a64de8151c5a893864a63a296 Mon Sep 17 00:00:00 2001 From: Gabriel Shaar Date: Tue, 25 Oct 2022 18:58:27 +0300 Subject: [PATCH 08/13] cli - prevent multiple interaction methods or multiple authentication methods --- rsocket/cli/command.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/rsocket/cli/command.py b/rsocket/cli/command.py index f98109ab..6dc72067 100644 --- a/rsocket/cli/command.py +++ b/rsocket/cli/command.py @@ -97,6 +97,9 @@ def build_composite_metadata(auth_simple: Optional[str], if route_value is not None: composite_items.append(route(route_value)) + if auth_simple is not None and auth_bearer is not None: + raise click.UsageError('Multiple authentication methods specified.') + if auth_simple is not None: composite_items.append(authenticate_simple(*auth_simple.split(':'))) @@ -130,7 +133,12 @@ def get_request_type(request: bool, fnf: bool, metadata_push: bool, channel: bool, - interaction_model: str) -> RequestType: + interaction_model: Optional[str]) -> RequestType: + interaction_options = list(filter(lambda _: _ is True, [request, stream, fnf, channel, metadata_push])) + + if len(interaction_options) >= 2 or (len(interaction_options) >= 1 and interaction_model is not None): + raise click.UsageError('Multiple interaction methods specified.') + if interaction_model is not None: return RequestType(interaction_model.upper()) if request: From 83ff73bb27ad856e6573d9c555520a6de6d6ab1c Mon Sep 17 00:00:00 2001 From: Gabriel Shaar Date: Tue, 25 Oct 2022 19:18:02 +0300 Subject: [PATCH 09/13] cli - refactoring. added timeout option. added some compatibility with rsocket-cli --- rsocket/cli/command.py | 77 +++++++++++++++++++------------ tests/rsocket/test_cli_command.py | 2 +- 2 files changed, 49 insertions(+), 30 deletions(-) diff --git a/rsocket/cli/command.py b/rsocket/cli/command.py index 6dc72067..3b8067ec 100644 --- a/rsocket/cli/command.py +++ b/rsocket/cli/command.py @@ -4,7 +4,8 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from enum import Enum, unique -from typing import Optional, Type, Collection, List +from importlib.metadata import version as get_version +from typing import Optional, Type, Collection, List, Callable import aiohttp import asyncclick as click @@ -21,7 +22,6 @@ from rsocket.transports.abstract_messaging import AbstractMessagingTransport from rsocket.transports.aiohttp_websocket import TransportAioHttpClient from rsocket.transports.tcp import TransportTCP -from importlib.metadata import version as get_version @unique @@ -167,7 +167,7 @@ def get_request_type(request: bool, help='Fire and Forget') @click.option('--metadataPush', 'metadata_push', is_flag=True, help='Metadata Push') -@click.option('-d', '--data', is_flag=False, +@click.option('-d', '--data', '--input', 'data', is_flag=False, help='Data. Use "-" to read data from standard input. (default: )') @click.option('-l', '--load', is_flag=False, help='Load a file as Data. (e.g. ./foo.txt, /tmp/foo.txt)') @@ -181,19 +181,19 @@ def get_request_type(request: bool, help='Enable take(n)') @click.option('-u', '--as', '--authSimple', 'auth_simple', is_flag=False, default=None, help='Enable Authentication Metadata Extension (Simple). The format must be "username: password"') -@click.option('--sd', '--setupData', 'setup_data', is_flag=False, default=None, +@click.option('--sd', '--setup', '--setupData', 'setup_data', is_flag=False, default=None, help='Data for Setup payload') @click.option('--sm', '--setupMetadata', 'setup_metadata', is_flag=False, default=None, help='Metadata for Setup payload') @click.option('--ab', '--authBearer', 'auth_bearer', is_flag=False, default=None, help='Enable Authentication Metadata Extension (Bearer)') -@click.option('--dataMimeType', '--dmt', 'data_mime_type', is_flag=False, +@click.option('--dataMimeType', '--dataFormat', '--dmt', 'data_mime_type', is_flag=False, help='MimeType for data (default: application/json)') -@click.option('--metadataMimeType', '--mmt', 'metadata_mime_type', is_flag=False, +@click.option('--metadataMimeType', '--metadataFormat', '--mmt', 'metadata_mime_type', is_flag=False, help='MimeType for metadata (default:application/json)') @click.option('--allowUntrustedSsl', 'allow_untrusted_ssl', is_flag=True, default=False, help='Do not verify SSL certificate (for wss:// urls)') -@click.option('--httpHeader', 'http_header', multiple=True, +@click.option('-H', '--header', '--httpHeader', 'http_header', multiple=True, help='ws/wss headers') @click.option('--trustCert', 'trust_cert', is_flag=False, help='PEM file for a trusted certificate. (e.g. ./foo.crt, /tmp/foo.crt)') @@ -201,6 +201,8 @@ def get_request_type(request: bool, help='Show debug log') @click.option('--quiet', '-q', is_flag=True, help='Disable the output on next') +@click.option('--timeout', 'timeout_seconds', is_flag=False, type=int, + help='Timeout in seconds') @click.option('--version', is_flag=True, help='Print version') @click.argument('uri', required=False) @@ -209,7 +211,7 @@ async def command(context, data, load, metadata, route_value, auth_simple, auth_bearer, limit_rate, take_n, allow_untrusted_ssl, setup_data, setup_metadata, interaction_model, - http_header, metadata_push, + http_header, metadata_push, timeout_seconds, data_mime_type, metadata_mime_type, request, stream, channel, fnf, trust_cert, uri, debug, version, quiet): @@ -233,27 +235,43 @@ async def command(context, data, load, return request_type = get_request_type(request, stream, fnf, metadata_push, channel, interaction_model) - http_headers = parse_headers(http_header) - composite_items = build_composite_metadata(auth_simple, route_value, auth_bearer) + setup_payload = create_setup_payload(setup_data, setup_metadata) + metadata_value = get_metadata_value(composite_items, metadata) + metadata_mime_type = normalize_metadata_mime_type(composite_items, metadata_mime_type) + parsed_uri = parse_uri(uri) + + def payload_provider(): + return create_request_payload(data, load, metadata_value) + + future = run_request(request_type, limit_rate, payload_provider, + http_headers=http_headers, + allow_untrusted_ssl=allow_untrusted_ssl, + metadata_mime_type=metadata_mime_type, + data_mime_type=data_mime_type, + setup_payload=setup_payload, + trust_cert=trust_cert, + parsed_uri=parsed_uri) + + if timeout_seconds is not None: + result = await asyncio.wait_for(future, timeout_seconds) + else: + result = await future - async with create_client(parse_uri(uri), - data_mime_type, - normalize_metadata_mime_type(composite_items, metadata_mime_type), - create_setup_payload(setup_data, setup_metadata), - allow_untrusted_ssl=allow_untrusted_ssl, - http_headers=http_headers, - trust_cert=trust_cert - ) as client: + if not quiet: + output_result(result) - result = await execute_request(client, - request_type, - normalize_limit_rate(limit_rate), - create_request_payload(data, load, metadata, composite_items)) - if not quiet: - output_result(result) +async def run_request(request_type: RequestType, + limit_rate: Optional[int], + payload_provider: Callable[[], Payload], + **kwargs): + async with create_client(**kwargs) as client: + return await execute_request(client, + request_type, + normalize_limit_rate(limit_rate), + payload_provider()) def parse_headers(http_headers): @@ -278,11 +296,10 @@ def normalize_metadata_mime_type(composite_items, metadata_mime_type): def create_request_payload(data: Optional[str], load: Optional[str], - metadata: Optional[str], - composite_items: List) -> Payload: + metadata: Optional[bytes]) -> Payload: data = normalize_data(data, load) - metadata_value = get_metadata_value(composite_items, metadata) - return Payload(data, metadata_value) + + return Payload(data, metadata) def output_result(result): @@ -292,7 +309,9 @@ def output_result(result): print([p.data.decode('utf-8') for p in result]) -async def execute_request(awaitable_client: AwaitableRSocket, request_type: RequestType, limit_rate: int, +async def execute_request(awaitable_client: AwaitableRSocket, + request_type: RequestType, + limit_rate: int, payload: Payload): result = None diff --git a/tests/rsocket/test_cli_command.py b/tests/rsocket/test_cli_command.py index 941d7c44..80671c1c 100644 --- a/tests/rsocket/test_cli_command.py +++ b/tests/rsocket/test_cli_command.py @@ -35,7 +35,7 @@ def test_build_composite_metadata(): def test_create_request_payload(): payload = create_request_payload( - None, None, None, [] + None, None, None ) assert payload.data is None From 0bc201000e829f9639f372dd6f526c8e6c0edd2a Mon Sep 17 00:00:00 2001 From: Gabriel Shaar Date: Tue, 25 Oct 2022 19:51:00 +0300 Subject: [PATCH 10/13] cli - added unit tests --- tests/rsocket/test_cli_command.py | 55 +++++++++++++++++++++++++++---- 1 file changed, 49 insertions(+), 6 deletions(-) diff --git a/tests/rsocket/test_cli_command.py b/tests/rsocket/test_cli_command.py index 80671c1c..38a63292 100644 --- a/tests/rsocket/test_cli_command.py +++ b/tests/rsocket/test_cli_command.py @@ -2,8 +2,11 @@ import sys import tempfile +import pytest + from rsocket.cli.command import parse_uri, build_composite_metadata, create_request_payload, get_metadata_value, \ - create_setup_payload, normalize_data, normalize_limit_rate + create_setup_payload, normalize_data, normalize_limit_rate, RequestType, get_request_type, parse_headers +from rsocket.extensions.helpers import route, authenticate_simple, authenticate_bearer from rsocket.frame import MAX_REQUEST_N from tests.rsocket.helpers import create_data @@ -25,12 +28,21 @@ def test_parse_uri_wss(): assert parsed.path == 'path' -def test_build_composite_metadata(): - composite = build_composite_metadata( - None, None, None - ) +@pytest.mark.parametrize('route_path, auth_simple, auth_bearer, expected', ( + (None, None, None, []), + ('path1', None, None, [route('path1')]), + ('path1', 'user:pass', None, [route('path1'), authenticate_simple('user', 'pass')]), + ('path1', None, 'token', [route('path1'), authenticate_bearer('token')]), + ('path1', 'user:pass', 'token', Exception), +)) +def test_build_composite_metadata(route_path, auth_simple, auth_bearer, expected): + if isinstance(expected, list): + actual = build_composite_metadata(auth_simple, route_path, auth_bearer) - assert len(composite) == 0 + assert actual == expected + else: + with pytest.raises(expected): + build_composite_metadata(auth_simple, route_path, auth_bearer) def test_create_request_payload(): @@ -100,3 +112,34 @@ def test_normalize_limit_rate(): result = normalize_limit_rate(None) assert result == MAX_REQUEST_N + + +@pytest.mark.parametrize('is_request, stream, fnf, metadata_push, channel, interaction_model, expected', ( + (True, None, None, None, None, None, RequestType.response), + (None, True, None, None, None, None, RequestType.stream), + (None, None, True, None, None, None, RequestType.fnf), + (None, None, None, True, None, None, RequestType.metadata_push), + (None, None, None, None, True, None, RequestType.channel), + (None, None, None, None, None, 'request_channel', RequestType.channel), + (None, None, None, None, True, RequestType.response, Exception), + (None, None, None, True, True, None, Exception), +)) +def test_get_request_type(is_request, stream, fnf, metadata_push, channel, interaction_model, expected): + if isinstance(expected, RequestType): + actual = get_request_type(is_request, stream, fnf, metadata_push, channel, interaction_model) + + assert actual == expected + else: + with pytest.raises(expected): + get_request_type(is_request, stream, fnf, metadata_push, channel, interaction_model) + + +@pytest.mark.parametrize('headers, expected', ( + (None, None), + (['a=b'], {'a': 'b'}), + ([], None), +)) +def test_parse_headers(headers, expected): + actual = parse_headers(headers) + + assert actual == expected From d198ee040715fb4a59dac2ebcf9b9be774ba5dfe Mon Sep 17 00:00:00 2001 From: Gabriel Shaar Date: Tue, 25 Oct 2022 20:00:19 +0300 Subject: [PATCH 11/13] cli - added unit tests. raise error if no interaction method specified --- rsocket/cli/command.py | 2 ++ tests/rsocket/test_cli_command.py | 10 +++++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/rsocket/cli/command.py b/rsocket/cli/command.py index 3b8067ec..ad6e5de8 100644 --- a/rsocket/cli/command.py +++ b/rsocket/cli/command.py @@ -152,6 +152,8 @@ def get_request_type(request: bool, if metadata_push: return RequestType.metadata_push + raise click.UsageError('No interaction method specified (eg. --request)') + @click.command(name='rsocket-py', help='Supported connection strings: tcp/ws/wss') @click.option('--im', '--interactionModel', 'interaction_model', is_flag=False, diff --git a/tests/rsocket/test_cli_command.py b/tests/rsocket/test_cli_command.py index 38a63292..038e378a 100644 --- a/tests/rsocket/test_cli_command.py +++ b/tests/rsocket/test_cli_command.py @@ -108,13 +108,17 @@ def test_normalize_data_from_stdin_takes_precedence_over_load_from_file(): assert data == fixture_data_stdin -def test_normalize_limit_rate(): - result = normalize_limit_rate(None) +@pytest.mark.parametrize('limit_rate, expected', ( + (None, MAX_REQUEST_N), +)) +def test_normalize_limit_rate(limit_rate, expected): + result = normalize_limit_rate(limit_rate) - assert result == MAX_REQUEST_N + assert result == expected @pytest.mark.parametrize('is_request, stream, fnf, metadata_push, channel, interaction_model, expected', ( + (None, None, None, None, None, None, Exception), (True, None, None, None, None, None, RequestType.response), (None, True, None, None, None, None, RequestType.stream), (None, None, True, None, None, None, RequestType.fnf), From 5209fc7f6e1014654d28b2ba0ca05e95c0694366 Mon Sep 17 00:00:00 2001 From: Gabriel Shaar Date: Tue, 25 Oct 2022 22:26:26 +0300 Subject: [PATCH 12/13] cli - typer hint fixes. added unit tests --- rsocket/cli/command.py | 4 ++-- rsocket/rsocket_client.py | 4 ++-- tests/rsocket/test_cli_command.py | 14 +++++++++++++- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/rsocket/cli/command.py b/rsocket/cli/command.py index ad6e5de8..fa7de9e5 100644 --- a/rsocket/cli/command.py +++ b/rsocket/cli/command.py @@ -91,7 +91,7 @@ async def transport_from_uri(uri: RSocketUri, def build_composite_metadata(auth_simple: Optional[str], route_value: Optional[str], - auth_bearer: Optional[str]): + auth_bearer: Optional[str]) -> List: composite_items = [] if route_value is not None: @@ -289,7 +289,7 @@ def parse_headers(http_headers): return None -def normalize_metadata_mime_type(composite_items, metadata_mime_type): +def normalize_metadata_mime_type(composite_items: List, metadata_mime_type): if len(composite_items) > 0: metadata_mime_type = WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA diff --git a/rsocket/rsocket_client.py b/rsocket/rsocket_client.py index 22c19dc0..ef86d469 100644 --- a/rsocket/rsocket_client.py +++ b/rsocket/rsocket_client.py @@ -25,8 +25,8 @@ def __init__(self, honor_lease=False, lease_publisher: Optional[Publisher] = None, request_queue_size: int = 0, - data_encoding: Union[bytes, WellKnownMimeTypes] = WellKnownMimeTypes.APPLICATION_JSON, - metadata_encoding: Union[bytes, WellKnownMimeTypes] = WellKnownMimeTypes.APPLICATION_JSON, + data_encoding: Union[str, bytes, WellKnownMimeTypes] = WellKnownMimeTypes.APPLICATION_JSON, + metadata_encoding: Union[str, bytes, WellKnownMimeTypes] = WellKnownMimeTypes.APPLICATION_JSON, keep_alive_period: timedelta = timedelta(milliseconds=500), max_lifetime_period: timedelta = timedelta(minutes=10), setup_payload: Optional[Payload] = None, diff --git a/tests/rsocket/test_cli_command.py b/tests/rsocket/test_cli_command.py index 038e378a..b8f3f683 100644 --- a/tests/rsocket/test_cli_command.py +++ b/tests/rsocket/test_cli_command.py @@ -5,8 +5,10 @@ import pytest from rsocket.cli.command import parse_uri, build_composite_metadata, create_request_payload, get_metadata_value, \ - create_setup_payload, normalize_data, normalize_limit_rate, RequestType, get_request_type, parse_headers + create_setup_payload, normalize_data, normalize_limit_rate, RequestType, get_request_type, parse_headers, \ + normalize_metadata_mime_type from rsocket.extensions.helpers import route, authenticate_simple, authenticate_bearer +from rsocket.extensions.mimetypes import WellKnownMimeTypes from rsocket.frame import MAX_REQUEST_N from tests.rsocket.helpers import create_data @@ -147,3 +149,13 @@ def test_parse_headers(headers, expected): actual = parse_headers(headers) assert actual == expected + + +@pytest.mark.parametrize('composite_items, metadata_mime_type, expected', ( + ([], 'application/json', 'application/json'), + ([route('path')], 'application/json', WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA), +)) +def test_normalize_metadata_mime_type(composite_items, metadata_mime_type, expected): + actual = normalize_metadata_mime_type(composite_items, metadata_mime_type) + + assert actual == expected From 08695b1bfd5e6f9a1ab43d9ea8a8f23542fbfeb2 Mon Sep 17 00:00:00 2001 From: Gabriel Shaar Date: Tue, 25 Oct 2022 22:28:37 +0300 Subject: [PATCH 13/13] cli - added unit tests --- tests/rsocket/test_cli_command.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/rsocket/test_cli_command.py b/tests/rsocket/test_cli_command.py index b8f3f683..beff4d8f 100644 --- a/tests/rsocket/test_cli_command.py +++ b/tests/rsocket/test_cli_command.py @@ -10,6 +10,7 @@ from rsocket.extensions.helpers import route, authenticate_simple, authenticate_bearer from rsocket.extensions.mimetypes import WellKnownMimeTypes from rsocket.frame import MAX_REQUEST_N +from rsocket.payload import Payload from tests.rsocket.helpers import create_data @@ -62,10 +63,16 @@ def test_get_metadata_value(): assert result is None -def test_create_setup_payload(): - result = create_setup_payload(None, None) +@pytest.mark.parametrize('data, metadata, expected', ( + (None, None, None), + ('data', None, Payload(b'data')), + ('data', 'metadata', Payload(b'data', b'metadata')), + (None, 'metadata', Payload(None, b'metadata')), +)) +def test_create_setup_payload(data, metadata, expected): + result = create_setup_payload(data, metadata) - assert result is None + assert result == expected def test_normalize_data():