diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 67928416..7ba88b14 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,13 @@ Changelog --------- +v0.4.2 +====== +- Command line fixes: + - Support passing ssl certificate and http headers when using ws/wss + - Support requesting --version without the need to specify URI arguments + - Option --interactionModel to specify interaction (eg. request_response, request_stream) + - Added Metadata Push support v0.4.1 ====== diff --git a/examples/fixtures.py b/examples/fixtures.py index 977d506a..975e3b7b 100644 --- a/examples/fixtures.py +++ b/examples/fixtures.py @@ -7,46 +7,63 @@ @contextmanager -def cert_gen(emailAddress="emailAddress", - commonName="commonName", - countryName="NT", - localityName="localityName", - stateOrProvinceName="stateOrProvinceName", - organizationName="organizationName", - organizationUnitName="organizationUnitName", - serialNumber=0, - validityStartInSeconds=0, - validityEndInSeconds=None) -> Tuple[str, str]: - if validityEndInSeconds is None: - validityEndInSeconds = int(timedelta(days=3650).total_seconds()) +def generate_certificate_and_key(email_address="emailAddress", + common_name="localhost", + country_name="NT", + locality_name="localityName", + state_or_province_name="stateOrProvinceName", + organization_name="organizationName", + organization_unit_name="organizationUnitName", + serial_number=0, + validity_start_in_seconds=0, + validity_end_in_seconds=None) -> Tuple[str, str]: + if validity_end_in_seconds is None: + validity_end_in_seconds = int(timedelta(days=3650).total_seconds()) + # can look at generated file using openssl: # openssl x509 -inform pem -in selfsigned.crt -noout -text # create a key pair - k = crypto.PKey() - k.generate_key(crypto.TYPE_RSA, 4096) + private_key = create_key() # create a self-signed cert - cert = crypto.X509() - cert.get_subject().C = countryName - cert.get_subject().ST = stateOrProvinceName - cert.get_subject().L = localityName - cert.get_subject().O = organizationName - cert.get_subject().OU = organizationUnitName - cert.get_subject().CN = commonName - cert.get_subject().emailAddress = emailAddress - cert.set_serial_number(serialNumber) - cert.gmtime_adj_notBefore(0) - cert.gmtime_adj_notAfter(validityEndInSeconds) - cert.set_issuer(cert.get_subject()) - cert.set_pubkey(k) - cert.sign(k, 'sha512') + cert = create_self_signed_certificate(common_name, country_name, email_address, private_key, locality_name, + organization_name, + organization_unit_name, serial_number, state_or_province_name, + validity_end_in_seconds, validity_start_in_seconds) with tempfile.NamedTemporaryFile() as certificate_file: with tempfile.NamedTemporaryFile() as key_file: certificate_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert)) certificate_file.flush() - key_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k)) + key_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, private_key)) key_file.flush() yield certificate_file.name, key_file.name + + +def create_key(): + k = crypto.PKey() + k.generate_key(crypto.TYPE_RSA, 4096) + return k + + +def create_self_signed_certificate(common_name, country_name, email_address, private_key, locality_name, + organization_name, + organization_unit_name, serial_number, state_or_province_name, + validity_end_in_seconds, validity_start_in_seconds): + cert = crypto.X509() + cert.get_subject().C = country_name + cert.get_subject().ST = state_or_province_name + cert.get_subject().L = locality_name + cert.get_subject().O = organization_name + cert.get_subject().OU = organization_unit_name + cert.get_subject().CN = common_name + cert.get_subject().emailAddress = email_address + cert.set_serial_number(serial_number) + cert.gmtime_adj_notBefore(validity_start_in_seconds) + cert.gmtime_adj_notAfter(validity_end_in_seconds) + cert.set_issuer(cert.get_subject()) + cert.set_pubkey(private_key) + cert.sign(private_key, 'sha512') + return cert diff --git a/examples/server_aiohttp_websocket.py b/examples/server_aiohttp_websocket.py index d976638d..d62139a1 100644 --- a/examples/server_aiohttp_websocket.py +++ b/examples/server_aiohttp_websocket.py @@ -4,7 +4,7 @@ import asyncclick as click from aiohttp import web -from examples.fixtures import cert_gen +from examples.fixtures import generate_certificate_and_key from rsocket.helpers import create_future from rsocket.local_typing import Awaitable from rsocket.payload import Payload @@ -42,7 +42,7 @@ async def start_server(with_ssl: bool, port: int): if with_ssl: ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - with cert_gen() as (certificate, key): + with generate_certificate_and_key() as (certificate, key): ssl_context.load_cert_chain(certificate, key) else: ssl_context = None diff --git a/examples/server_with_routing.py b/examples/server_with_routing.py index aeb406eb..4762d6d0 100644 --- a/examples/server_with_routing.py +++ b/examples/server_with_routing.py @@ -9,7 +9,7 @@ from aiohttp import web from examples.example_fixtures import large_data1 -from examples.fixtures import cert_gen +from examples.fixtures import generate_certificate_and_key from examples.response_channel import response_stream_1, LoggingSubscriber from response_stream import response_stream_2 from rsocket.extensions.authentication import Authentication, AuthenticationSimple @@ -136,15 +136,18 @@ async def start_server(with_ssl: bool, port: int, transport: str): app = web.Application() app.add_routes([web.get('/', websocket_handler_factory(handler_factory=handler_factory))]) - if with_ssl: - ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + with generate_certificate_and_key() as (certificate_path, key_path): + if with_ssl: + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - with cert_gen() as (certificate, key): - ssl_context.load_cert_chain(certificate, key) - else: - ssl_context = None + logging.info('Certificate %s', certificate_path) + logging.info('Private-key %s', key_path) - await web._run_app(app, port=port, ssl_context=ssl_context) + ssl_context.load_cert_chain(certificate_path, key_path) + else: + ssl_context = None + + await web._run_app(app, port=port, ssl_context=ssl_context) elif transport == 'tcp': server = await asyncio.start_server(handle_client, 'localhost', port) diff --git a/rsocket/cli/command.py b/rsocket/cli/command.py index b87072c1..fa7de9e5 100644 --- a/rsocket/cli/command.py +++ b/rsocket/cli/command.py @@ -1,8 +1,11 @@ import asyncio import logging +import ssl from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import Optional, Type, Collection, List +from enum import Enum, unique +from importlib.metadata import version as get_version +from typing import Optional, Type, Collection, List, Callable import aiohttp import asyncclick as click @@ -19,13 +22,24 @@ from rsocket.transports.abstract_messaging import AbstractMessagingTransport from rsocket.transports.aiohttp_websocket import TransportAioHttpClient from rsocket.transports.tcp import TransportTCP -from importlib.metadata import version as get_version + + +@unique +class RequestType(Enum): + response = 'REQUEST_RESPONSE' + stream = 'REQUEST_STREAM' + channel = 'REQUEST_CHANNEL' + fnf = 'FIRE_AND_FORGET' + metadata_push = 'METADATA_PUSH' + + +interaction_models: List[str] = [str(e.value) for e in RequestType] @dataclass(frozen=True) class RSocketUri: host: str - port: str + port: int schema: str path: Optional[str] = None original_uri: Optional[str] = None @@ -34,11 +48,17 @@ class RSocketUri: def parse_uri(uri: str) -> RSocketUri: schema, rest = uri.split(':', 1) rest = rest.strip('/') - host_port = rest.split('/', 1) - host, port = host_port[0].split(':') + host_port_path = rest.split('/', 1) + host_port = host_port_path[0].split(':') + host = host_port[0] - if len(host_port) > 1: - rest = host_port[1] + if len(host_port) == 1: + port = None + else: + port = int(host_port[1]) + + if len(host_port_path) > 1: + rest = host_port_path[1] else: rest = None @@ -48,14 +68,21 @@ def parse_uri(uri: str) -> RSocketUri: @asynccontextmanager async def transport_from_uri(uri: RSocketUri, verify_ssl=True, - headers: Optional[Map] = None) -> Type[AbstractMessagingTransport]: + headers: Optional[Map] = None, + trust_cert: Optional[str] = None) -> Type[AbstractMessagingTransport]: if uri.schema == 'tcp': connection = await asyncio.open_connection(uri.host, uri.port) yield TransportTCP(*connection) elif uri.schema in ['wss', 'ws']: async with aiohttp.ClientSession() as session: + if trust_cert is not None: + ssl_context = ssl.create_default_context(cafile=trust_cert) + else: + ssl_context = None + async with session.ws_connect(uri.original_uri, verify_ssl=verify_ssl, + ssl_context=ssl_context, headers=headers) as websocket: yield TransportAioHttpClient(websocket=websocket) else: @@ -64,12 +91,15 @@ async def transport_from_uri(uri: RSocketUri, def build_composite_metadata(auth_simple: Optional[str], route_value: Optional[str], - auth_bearer: Optional[str]): + auth_bearer: Optional[str]) -> List: composite_items = [] if route_value is not None: composite_items.append(route(route_value)) + if auth_simple is not None and auth_bearer is not None: + raise click.UsageError('Multiple authentication methods specified.') + if auth_simple is not None: composite_items.append(authenticate_simple(*auth_simple.split(':'))) @@ -85,8 +115,12 @@ async def create_client(parsed_uri, metadata_mime_type, setup_payload, allow_untrusted_ssl=False, - http_headers=None): - async with transport_from_uri(parsed_uri, verify_ssl=not allow_untrusted_ssl, headers=http_headers) as transport: + http_headers=None, + trust_cert=None): + async with transport_from_uri(parsed_uri, + verify_ssl=not allow_untrusted_ssl, + headers=http_headers, + trust_cert=trust_cert) as transport: async with RSocketClient(single_transport_provider(transport), data_encoding=data_mime_type or WellKnownMimeTypes.APPLICATION_JSON, metadata_encoding=metadata_mime_type or WellKnownMimeTypes.APPLICATION_JSON, @@ -94,7 +128,37 @@ async def create_client(parsed_uri, yield AwaitableRSocket(client) +def get_request_type(request: bool, + stream: bool, + fnf: bool, + metadata_push: bool, + channel: bool, + interaction_model: Optional[str]) -> RequestType: + interaction_options = list(filter(lambda _: _ is True, [request, stream, fnf, channel, metadata_push])) + + if len(interaction_options) >= 2 or (len(interaction_options) >= 1 and interaction_model is not None): + raise click.UsageError('Multiple interaction methods specified.') + + if interaction_model is not None: + return RequestType(interaction_model.upper()) + if request: + return RequestType.response + if stream: + return RequestType.stream + if channel: + return RequestType.channel + if fnf: + return RequestType.fnf + if metadata_push: + return RequestType.metadata_push + + raise click.UsageError('No interaction method specified (eg. --request)') + + @click.command(name='rsocket-py', help='Supported connection strings: tcp/ws/wss') +@click.option('--im', '--interactionModel', 'interaction_model', is_flag=False, + type=click.Choice(interaction_models, case_sensitive=False), + help='Interaction Model') @click.option('--request', is_flag=True, help='Request response') @click.option('--stream', is_flag=True, @@ -103,7 +167,9 @@ async def create_client(parsed_uri, help='Request channel') @click.option('--fnf', is_flag=True, help='Fire and Forget') -@click.option('-d', '--data', is_flag=False, +@click.option('--metadataPush', 'metadata_push', is_flag=True, + help='Metadata Push') +@click.option('-d', '--data', '--input', 'data', is_flag=False, help='Data. Use "-" to read data from standard input. (default: )') @click.option('-l', '--load', is_flag=False, help='Load a file as Data. (e.g. ./foo.txt, /tmp/foo.txt)') @@ -117,34 +183,39 @@ async def create_client(parsed_uri, help='Enable take(n)') @click.option('-u', '--as', '--authSimple', 'auth_simple', is_flag=False, default=None, help='Enable Authentication Metadata Extension (Simple). The format must be "username: password"') -@click.option('--sd', '--setupData', 'setup_data', is_flag=False, default=None, +@click.option('--sd', '--setup', '--setupData', 'setup_data', is_flag=False, default=None, help='Data for Setup payload') @click.option('--sm', '--setupMetadata', 'setup_metadata', is_flag=False, default=None, help='Metadata for Setup payload') @click.option('--ab', '--authBearer', 'auth_bearer', is_flag=False, default=None, help='Enable Authentication Metadata Extension (Bearer)') -@click.option('--dataMimeType', '--dmt', 'data_mime_type', is_flag=False, +@click.option('--dataMimeType', '--dataFormat', '--dmt', 'data_mime_type', is_flag=False, help='MimeType for data (default: application/json)') -@click.option('--metadataMimeType', '--mmt', 'metadata_mime_type', is_flag=False, +@click.option('--metadataMimeType', '--metadataFormat', '--mmt', 'metadata_mime_type', is_flag=False, help='MimeType for metadata (default:application/json)') @click.option('--allowUntrustedSsl', 'allow_untrusted_ssl', is_flag=True, default=False, help='Do not verify SSL certificate (for wss:// urls)') -@click.option('--httpHeader', 'http_header', multiple=True, +@click.option('-H', '--header', '--httpHeader', 'http_header', multiple=True, help='ws/wss headers') +@click.option('--trustCert', 'trust_cert', is_flag=False, + help='PEM file for a trusted certificate. (e.g. ./foo.crt, /tmp/foo.crt)') @click.option('--debug', is_flag=True, help='Show debug log') @click.option('--quiet', '-q', is_flag=True, help='Disable the output on next') +@click.option('--timeout', 'timeout_seconds', is_flag=False, type=int, + help='Timeout in seconds') @click.option('--version', is_flag=True, help='Print version') -@click.argument('uri') -async def command(data, load, +@click.argument('uri', required=False) +@click.pass_context +async def command(context, data, load, metadata, route_value, auth_simple, auth_bearer, limit_rate, take_n, allow_untrusted_ssl, - setup_data, setup_metadata, - http_header, + setup_data, setup_metadata, interaction_model, + http_header, metadata_push, timeout_seconds, data_mime_type, metadata_mime_type, - request, stream, channel, fnf, + request, stream, channel, fnf, trust_cert, uri, debug, version, quiet): if version: try: @@ -153,6 +224,9 @@ async def command(data, load, print('Failed to find version') return + if uri is None: + raise click.MissingParameter(param=context.command.params[-1]) + if quiet: logging.basicConfig(handlers=[]) @@ -162,28 +236,44 @@ async def command(data, load, if take_n == 0: return + request_type = get_request_type(request, stream, fnf, metadata_push, channel, interaction_model) http_headers = parse_headers(http_header) - composite_items = build_composite_metadata(auth_simple, route_value, auth_bearer) + setup_payload = create_setup_payload(setup_data, setup_metadata) + metadata_value = get_metadata_value(composite_items, metadata) + metadata_mime_type = normalize_metadata_mime_type(composite_items, metadata_mime_type) + parsed_uri = parse_uri(uri) + + def payload_provider(): + return create_request_payload(data, load, metadata_value) + + future = run_request(request_type, limit_rate, payload_provider, + http_headers=http_headers, + allow_untrusted_ssl=allow_untrusted_ssl, + metadata_mime_type=metadata_mime_type, + data_mime_type=data_mime_type, + setup_payload=setup_payload, + trust_cert=trust_cert, + parsed_uri=parsed_uri) + + if timeout_seconds is not None: + result = await asyncio.wait_for(future, timeout_seconds) + else: + result = await future - async with create_client(parse_uri(uri), - data_mime_type, - normalize_metadata_mime_type(composite_items, metadata_mime_type), - create_setup_payload(setup_data, setup_metadata), - allow_untrusted_ssl=allow_untrusted_ssl, - http_headers=http_headers - ) as client: + if not quiet: + output_result(result) - result = await execute_request(client, - channel, - fnf, - normalize_limit_rate(limit_rate), - create_request_payload(data, load, metadata, composite_items), - request, - stream) - if not quiet: - output_result(result) +async def run_request(request_type: RequestType, + limit_rate: Optional[int], + payload_provider: Callable[[], Payload], + **kwargs): + async with create_client(**kwargs) as client: + return await execute_request(client, + request_type, + normalize_limit_rate(limit_rate), + payload_provider()) def parse_headers(http_headers): @@ -199,7 +289,7 @@ def parse_headers(http_headers): return None -def normalize_metadata_mime_type(composite_items, metadata_mime_type): +def normalize_metadata_mime_type(composite_items: List, metadata_mime_type): if len(composite_items) > 0: metadata_mime_type = WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA @@ -208,11 +298,10 @@ def normalize_metadata_mime_type(composite_items, metadata_mime_type): def create_request_payload(data: Optional[str], load: Optional[str], - metadata: Optional[str], - composite_items: List) -> Payload: + metadata: Optional[bytes]) -> Payload: data = normalize_data(data, load) - metadata_value = get_metadata_value(composite_items, metadata) - return Payload(data, metadata_value) + + return Payload(data, metadata) def output_result(result): @@ -222,17 +311,22 @@ def output_result(result): print([p.data.decode('utf-8') for p in result]) -async def execute_request(awaitable_client, channel, fnf, limit_rate, payload, request, stream): +async def execute_request(awaitable_client: AwaitableRSocket, + request_type: RequestType, + limit_rate: int, + payload: Payload): result = None - if request: + if request_type is RequestType.response: result = await awaitable_client.request_response(payload) - elif stream: + elif request_type is RequestType.stream: result = await awaitable_client.request_stream(payload, limit_rate=limit_rate) - elif channel: + elif request_type is RequestType.channel: result = await awaitable_client.request_channel(payload, limit_rate=limit_rate) - elif fnf: + elif request_type is RequestType.fnf: await awaitable_client.fire_and_forget(payload) + elif request_type is RequestType.metadata_push: + await awaitable_client.metadata_push(payload.metadata) return result diff --git a/rsocket/rsocket_client.py b/rsocket/rsocket_client.py index 22c19dc0..ef86d469 100644 --- a/rsocket/rsocket_client.py +++ b/rsocket/rsocket_client.py @@ -25,8 +25,8 @@ def __init__(self, honor_lease=False, lease_publisher: Optional[Publisher] = None, request_queue_size: int = 0, - data_encoding: Union[bytes, WellKnownMimeTypes] = WellKnownMimeTypes.APPLICATION_JSON, - metadata_encoding: Union[bytes, WellKnownMimeTypes] = WellKnownMimeTypes.APPLICATION_JSON, + data_encoding: Union[str, bytes, WellKnownMimeTypes] = WellKnownMimeTypes.APPLICATION_JSON, + metadata_encoding: Union[str, bytes, WellKnownMimeTypes] = WellKnownMimeTypes.APPLICATION_JSON, keep_alive_period: timedelta = timedelta(milliseconds=500), max_lifetime_period: timedelta = timedelta(minutes=10), setup_payload: Optional[Payload] = None, diff --git a/setup.py b/setup.py index f2c1d108..d2612169 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='rsocket', - version='0.4.1', + version='0.4.2', description='Python RSocket library', long_description=long_description, long_description_content_type='text/markdown', diff --git a/tests/rsocket/test_cli_command.py b/tests/rsocket/test_cli_command.py index f075f116..beff4d8f 100644 --- a/tests/rsocket/test_cli_command.py +++ b/tests/rsocket/test_cli_command.py @@ -2,31 +2,55 @@ import sys import tempfile +import pytest + from rsocket.cli.command import parse_uri, build_composite_metadata, create_request_payload, get_metadata_value, \ - create_setup_payload, normalize_data, normalize_limit_rate + create_setup_payload, normalize_data, normalize_limit_rate, RequestType, get_request_type, parse_headers, \ + normalize_metadata_mime_type +from rsocket.extensions.helpers import route, authenticate_simple, authenticate_bearer +from rsocket.extensions.mimetypes import WellKnownMimeTypes from rsocket.frame import MAX_REQUEST_N +from rsocket.payload import Payload from tests.rsocket.helpers import create_data def test_parse_uri(): - parsed = parse_uri('wss://localhost:6565') + parsed = parse_uri('tcp://localhost:6565') + + assert parsed.schema == 'tcp' + assert parsed.port == 6565 + assert parsed.host == 'localhost' + + +def test_parse_uri_wss(): + parsed = parse_uri('wss://localhost/path') assert parsed.schema == 'wss' - assert parsed.port == '6565' + assert parsed.port is None assert parsed.host == 'localhost' + assert parsed.path == 'path' -def test_build_composite_metadata(): - composite = build_composite_metadata( - None, None, None - ) +@pytest.mark.parametrize('route_path, auth_simple, auth_bearer, expected', ( + (None, None, None, []), + ('path1', None, None, [route('path1')]), + ('path1', 'user:pass', None, [route('path1'), authenticate_simple('user', 'pass')]), + ('path1', None, 'token', [route('path1'), authenticate_bearer('token')]), + ('path1', 'user:pass', 'token', Exception), +)) +def test_build_composite_metadata(route_path, auth_simple, auth_bearer, expected): + if isinstance(expected, list): + actual = build_composite_metadata(auth_simple, route_path, auth_bearer) - assert len(composite) == 0 + assert actual == expected + else: + with pytest.raises(expected): + build_composite_metadata(auth_simple, route_path, auth_bearer) def test_create_request_payload(): payload = create_request_payload( - None, None, None, [] + None, None, None ) assert payload.data is None @@ -39,10 +63,16 @@ def test_get_metadata_value(): assert result is None -def test_create_setup_payload(): - result = create_setup_payload(None, None) +@pytest.mark.parametrize('data, metadata, expected', ( + (None, None, None), + ('data', None, Payload(b'data')), + ('data', 'metadata', Payload(b'data', b'metadata')), + (None, 'metadata', Payload(None, b'metadata')), +)) +def test_create_setup_payload(data, metadata, expected): + result = create_setup_payload(data, metadata) - assert result is None + assert result == expected def test_normalize_data(): @@ -87,7 +117,52 @@ def test_normalize_data_from_stdin_takes_precedence_over_load_from_file(): assert data == fixture_data_stdin -def test_normalize_limit_rate(): - result = normalize_limit_rate(None) +@pytest.mark.parametrize('limit_rate, expected', ( + (None, MAX_REQUEST_N), +)) +def test_normalize_limit_rate(limit_rate, expected): + result = normalize_limit_rate(limit_rate) + + assert result == expected + + +@pytest.mark.parametrize('is_request, stream, fnf, metadata_push, channel, interaction_model, expected', ( + (None, None, None, None, None, None, Exception), + (True, None, None, None, None, None, RequestType.response), + (None, True, None, None, None, None, RequestType.stream), + (None, None, True, None, None, None, RequestType.fnf), + (None, None, None, True, None, None, RequestType.metadata_push), + (None, None, None, None, True, None, RequestType.channel), + (None, None, None, None, None, 'request_channel', RequestType.channel), + (None, None, None, None, True, RequestType.response, Exception), + (None, None, None, True, True, None, Exception), +)) +def test_get_request_type(is_request, stream, fnf, metadata_push, channel, interaction_model, expected): + if isinstance(expected, RequestType): + actual = get_request_type(is_request, stream, fnf, metadata_push, channel, interaction_model) + + assert actual == expected + else: + with pytest.raises(expected): + get_request_type(is_request, stream, fnf, metadata_push, channel, interaction_model) + + +@pytest.mark.parametrize('headers, expected', ( + (None, None), + (['a=b'], {'a': 'b'}), + ([], None), +)) +def test_parse_headers(headers, expected): + actual = parse_headers(headers) + + assert actual == expected + + +@pytest.mark.parametrize('composite_items, metadata_mime_type, expected', ( + ([], 'application/json', 'application/json'), + ([route('path')], 'application/json', WellKnownMimeTypes.MESSAGE_RSOCKET_COMPOSITE_METADATA), +)) +def test_normalize_metadata_mime_type(composite_items, metadata_mime_type, expected): + actual = normalize_metadata_mime_type(composite_items, metadata_mime_type) - assert result == MAX_REQUEST_N + assert actual == expected