diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 507cfbd..53e95d0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,7 +9,7 @@ jobs: matrix: os: [ubuntu-latest] python-version: ['3.6', '3.7', '3.8'] # 'pypy-3.6' - backend: [asyncio, curio, trio, uvloop] + backend: [asyncio, trio, uvloop] steps: - uses: actions/checkout@v3 - name: Setup Python @@ -29,7 +29,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest] - backend: [asyncio, curio, trio, uvloop] + backend: [asyncio, trio, uvloop] steps: - uses: actions/checkout@v3 - name: Setup Python diff --git a/README.md b/README.md index 6384629..3041330 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,6 @@ Asynchronous pure Python gRPC client and server implementation supporting [asyncio](https://docs.python.org/3/library/asyncio.html), [uvloop](https://github.com/MagicStack/uvloop), -[curio](https://github.com/dabeaz/curio) and [trio](https://github.com/python-trio/trio) (achieved with [anyio](https://github.com/agronholm/anyio) compatibility layer). ## Requirements @@ -29,7 +28,7 @@ Latest development version: pip install git+https://github.com/standy66/purerpc.git ``` -By default purerpc uses asyncio event loop, if you want to use uvloop, curio or trio, please install them manually. +By default purerpc uses asyncio event loop, if you want to use uvloop or trio, please install them manually. ## protoc plugin @@ -57,9 +56,9 @@ Just mark yielding coroutines with `@async_generator` decorator and use `await y ### Server ```python +from purerpc import Server from greeter_pb2 import HelloRequest, HelloReply from greeter_grpc import GreeterServicer -from purerpc import Server class Greeter(GreeterServicer): @@ -71,15 +70,17 @@ class Greeter(GreeterServicer): yield HelloReply(message=f"Hello, {message.name}") -server = Server(50055) -server.add_service(Greeter().service) -server.serve(backend="asyncio") # backend can also be one of: "uvloop", "curio", "trio" +if __name__ == '__main__': + server = Server(50055) + server.add_service(Greeter().service) + # NOTE: if you already have an async loop running, use "await server.serve_async()" + import anyio + anyio.run(server.serve_async) # or set explicit backend="asyncio" or "trio" ``` ### Client ```python -import anyio import purerpc from greeter_pb2 import HelloRequest, HelloReply from greeter_grpc import GreeterStub @@ -90,7 +91,7 @@ async def gen(): yield HelloRequest(name=str(i)) -async def main(): +async def listen(): async with purerpc.insecure_channel("localhost", 50055) as channel: stub = GreeterStub(channel) reply = await stub.SayHello(HelloRequest(name="World")) @@ -100,8 +101,10 @@ async def main(): print(reply.message) -if __name__ == "__main__": - anyio.run(main, backend="asyncio") # backend can also be one of: "uvloop", "curio", "trio" +if __name__ == '__main__': + # NOTE: if you already have an async loop running, use "await listen()" + import anyio + anyio.run(listen) # or set explicit backend="asyncio" or "trio" ``` You can mix server and client code, for example make a server that requests something using purerpc from another gRPC server, etc. diff --git a/misc/greeter/client.py b/misc/greeter/client.py index 3925a96..73cc4e1 100644 --- a/misc/greeter/client.py +++ b/misc/greeter/client.py @@ -16,13 +16,12 @@ async def worker(channel): async def main_coro(): - # await curio.spawn(print_memory_growth_statistics(), daemon=True) async with purerpc.insecure_channel("localhost", 50055) as channel: for _ in range(100): start = time.time() async with anyio.create_task_group() as task_group: for _ in range(100): - await task_group.spawn(worker, channel) + task_group.start_soon(worker, channel) print("RPS: {}".format(10000 / (time.time() - start))) diff --git a/misc/greeter/failing_client.py b/misc/greeter/failing_client.py index 8a89008..cfac2b6 100644 --- a/misc/greeter/failing_client.py +++ b/misc/greeter/failing_client.py @@ -16,7 +16,7 @@ async def do_load_unary(result_queue, stub, num_requests, message_size): result = (await stub.SayHello(greeter_pb2.HelloRequest(name=message))).message assert (len(result) == message_size) avg_latency = (time.time() - start) / num_requests - await result_queue.put(avg_latency) + await result_queue.send(avg_latency) async def do_load_stream(result_queue, stub, num_requests, message_size): @@ -30,7 +30,7 @@ async def do_load_stream(result_queue, stub, num_requests, message_size): avg_latency = (time.time() - start) / num_requests await stream.close() await stream.receive_message() - await result_queue.put(avg_latency) + await result_queue.send(avg_latency) async def worker(port, num_concurrent_streams, num_requests_per_stream, @@ -45,17 +45,17 @@ async def worker(port, num_concurrent_streams, num_requests_per_stream, raise ValueError(f"Unknown load type: {load_type}") for idx in range(num_rounds): start = time.time() - task_results = anyio.create_queue(sys.maxsize) + send_queue, receive_queue = anyio.create_memory_object_stream(max_buffer_size=sys.maxsize) async with anyio.create_task_group() as task_group: for _ in range(num_concurrent_streams): - await task_group.spawn(load_fn, task_results, stub, num_requests_per_stream, message_size) + task_group.start_soon(load_fn, send_queue, stub, num_requests_per_stream, message_size) end = time.time() rps = num_concurrent_streams * num_requests_per_stream / (end - start) latencies = [] for _ in range(num_concurrent_streams): - latencies.append(await task_results.get()) + latencies.append(await receive_queue.receive()) print("Round", idx, "rps", rps, "avg latency", 1000 * sum(latencies) / len(latencies)) diff --git a/misc/greeter/test_perf.py b/misc/greeter/test_perf.py index 70c758a..4dedd3a 100644 --- a/misc/greeter/test_perf.py +++ b/misc/greeter/test_perf.py @@ -29,7 +29,7 @@ async def do_load_unary(result_queue, stub, num_requests, message_size): result = (await stub.SayHello(HelloRequest(name=message))).message assert (len(result) == message_size) avg_latency = (time.time() - start) / num_requests - await result_queue.put(avg_latency) + await result_queue.send(avg_latency) async def do_load_stream(result_queue, stub, num_requests, message_size): @@ -43,7 +43,7 @@ async def do_load_stream(result_queue, stub, num_requests, message_size): avg_latency = (time.time() - start) / num_requests await stream.close() await stream.receive_message() - await result_queue.put(avg_latency) + await result_queue.send(avg_latency) async def worker(port, queue, num_concurrent_streams, num_requests_per_stream, @@ -58,16 +58,16 @@ async def worker(port, queue, num_concurrent_streams, num_requests_per_stream, raise ValueError(f"Unknown load type: {load_type}") for _ in range(num_rounds): start = time.time() - task_results = anyio.create_queue(sys.maxsize) + send_queue, receive_queue = anyio.create_memory_object_stream(max_buffer_size=sys.maxsize) async with anyio.create_task_group() as task_group: for _ in range(num_concurrent_streams): - await task_group.spawn(load_fn, task_results, stub, num_requests_per_stream, message_size) + task_group.start_soon(load_fn, send_queue, stub, num_requests_per_stream, message_size) end = time.time() rps = num_concurrent_streams * num_requests_per_stream / (end - start) queue.put(rps) results = [] for _ in range(num_concurrent_streams): - results.append(await task_results.get()) + results.append(await receive_queue.receive()) queue.put(results) queue.close() queue.join_thread() diff --git a/requirements_test.txt b/requirements_test.txt index 1432c2a..2fded70 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -4,13 +4,12 @@ # # pip-compile --extra=test --output-file=requirements_test.txt setup.py # -anyio==1.4.0 +anyio==3.5.0 # via purerpc (setup.py) async-exit-stack==1.0.1 # via purerpc (setup.py) async-generator==1.10 # via - # anyio # purerpc (setup.py) # trio attrs==21.4.0 @@ -18,8 +17,10 @@ attrs==21.4.0 # outcome # pytest # trio -curio==1.5 - # via purerpc (setup.py) +cffi==1.15.0 + # via cryptography +cryptography==36.0.2 + # via trustme grpcio==1.44.0 # via # grpcio-tools @@ -36,6 +37,7 @@ idna==3.3 # via # anyio # trio + # trustme importlib-metadata==4.11.3 # via # pluggy @@ -54,6 +56,8 @@ protobuf==3.20.0 # purerpc (setup.py) py==1.11.0 # via pytest +pycparser==2.21 + # via cffi pyparsing==3.0.8 # via packaging pytest==7.1.1 @@ -74,8 +78,12 @@ tomli==2.0.1 # via pytest trio==0.20.0 # via purerpc (setup.py) +trustme==0.9.0 + # via purerpc (setup.py) typing-extensions==4.1.1 - # via importlib-metadata + # via + # anyio + # importlib-metadata uvloop==0.16.0 # via purerpc (setup.py) zipp==3.8.0 diff --git a/requirements_test_py36.txt b/requirements_test_py36.txt index 483a220..868a424 100644 --- a/requirements_test_py36.txt +++ b/requirements_test_py36.txt @@ -4,13 +4,12 @@ # # pip-compile --output-file=requirements_test_py36.txt requirements_test.in setup.py # -anyio==1.4.0 +anyio==3.5.0 # via purerpc (setup.py) async-exit-stack==1.0.1 # via purerpc (setup.py) async-generator==1.10 # via - # anyio # purerpc (setup.py) # trio attrs==21.4.0 @@ -18,12 +17,17 @@ attrs==21.4.0 # outcome # pytest # trio +cffi==1.15.0 + # via cryptography contextvars==2.4 # via + # anyio # sniffio # trio -curio==1.4 - # via -r requirements_test.in +cryptography==36.0.2 + # via trustme +dataclasses==0.8 + # via anyio grpcio==1.44.0 # via # -r requirements_test.in @@ -40,6 +44,7 @@ idna==3.3 # via # anyio # trio + # trustme immutables==0.17 # via contextvars importlib-metadata==4.8.3 @@ -60,6 +65,8 @@ protobuf==3.19.4 # purerpc (setup.py) py==1.11.0 # via pytest +pycparser==2.21 + # via cffi pyparsing==3.0.8 # via packaging pytest==7.0.1 @@ -80,8 +87,11 @@ tomli==1.2.3 # via pytest trio==0.19.0 # via -r requirements_test.in +trustme==0.9.0 + # via -r requirements_test.in typing-extensions==4.1.1 # via + # anyio # immutables # importlib-metadata uvloop==0.14.0 diff --git a/setup.py b/setup.py index 5ae4f8c..4e7a3fd 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ def main(): version=__version__, license="Apache License Version 2.0", description=("Asynchronous pure Python gRPC client and server implementation " - "supporting asyncio, uvloop, curio and trio"), + "supporting asyncio, uvloop, trio"), long_description='%s\n%s' % ( re.compile('^.. start-badges.*^.. end-badges', re.M | re.S).sub('', read('README.md')), re.sub(':[a-z]+:`~?(.*?)`', r'``\1``', read('RELEASE.md')) @@ -64,7 +64,7 @@ def main(): install_requires=[ "h2>=3.1.0,<4", "protobuf>=3.5.1", - "anyio>=1.0.0,<2", # TODO: anyio 3.x upgrade + "anyio>=3.0.0", "async_exit_stack>=1.0.1", "tblib>=1.3.2", "async_generator>=1.10", @@ -79,8 +79,8 @@ def main(): "grpcio_tools>=1.25.0", # same here "uvloop", "trio>=0.11", - "curio>=0.9", "python-forge>=18.6", + "trustme", ] }, ) diff --git a/src/purerpc/client.py b/src/purerpc/client.py index 63761f5..bfdc222 100644 --- a/src/purerpc/client.py +++ b/src/purerpc/client.py @@ -15,14 +15,14 @@ def __init__(self, host, port, ssl_context=None): super().__init__() self._host = host self._port = port - self._ssl = ssl_context + self._ssl_context = ssl_context self._grpc_socket = None async def __aenter__(self): await super().__aenter__() # Does nothing socket = await anyio.connect_tcp(self._host, self._port, - ssl_context=self._ssl, - autostart_tls=self._ssl is not None, + ssl_context=self._ssl_context, + tls=self._ssl_context is not None, tls_standard_compatible=False) config = GRPCConfiguration(client_side=True) self._grpc_socket = await self.enter_async_context(GRPCProtoSocket(config, socket)) diff --git a/src/purerpc/grpc_socket.py b/src/purerpc/grpc_socket.py index 313e4a5..25ef40b 100644 --- a/src/purerpc/grpc_socket.py +++ b/src/purerpc/grpc_socket.py @@ -2,8 +2,10 @@ import enum import socket import datetime +from typing import Dict import anyio +import anyio.abc import async_exit_stack from async_generator import async_generator, yield_, yield_from_ from purerpc.utils import is_darwin, is_windows @@ -16,28 +18,29 @@ class SocketWrapper(async_exit_stack.AsyncExitStack): - def __init__(self, grpc_connection: GRPCConnection, sock: anyio.SocketStream): + def __init__(self, grpc_connection: GRPCConnection, stream: anyio.abc.SocketStream): super().__init__() - self._set_socket_options(sock) - self._socket = sock + self._set_socket_options(stream) + self._stream = stream self._grpc_connection = grpc_connection - self._flush_event = anyio.create_event() + self._flush_event = anyio.Event() self._running = True async def __aenter__(self): await super().__aenter__() task_group = await self.enter_async_context(anyio.create_task_group()) - await task_group.spawn(self._writer_thread) + task_group.start_soon(self._writer_thread) async def callback(): self._running = False - await self._flush_event.set() + self._flush_event.set() self.push_async_callback(callback) return self @staticmethod - def _set_socket_options(sock: anyio.SocketStream): + def _set_socket_options(stream: anyio.abc.SocketStream): + sock = stream.extra(anyio.abc.SocketAttribute.raw_socket) sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) if hasattr(socket, "TCP_KEEPIDLE"): sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 300) @@ -54,20 +57,20 @@ async def _writer_thread(self): while True: data = self._grpc_connection.data_to_send() if data: - await self._socket.send_all(data) + await self._stream.send(data) elif self._running: await self._flush_event.wait() - self._flush_event.clear() + self._flush_event = anyio.Event() else: return async def flush(self): """This maybe called from different threads.""" - await self._flush_event.set() + self._flush_event.set() async def recv(self, buffer_size: int): """This may only be called from single thread.""" - return await self._socket.receive_some(buffer_size) + return await self._stream.receive(buffer_size) class GRPCStreamState(enum.Enum): @@ -84,8 +87,9 @@ def __init__(self, grpc_connection: GRPCConnection, stream_id: int, socket: Sock self._grpc_connection = grpc_connection self._grpc_socket = grpc_socket self._socket = socket - self._flow_control_update_event = anyio.create_event() - self._incoming_events = anyio.create_queue(sys.maxsize) + self._flow_control_update_event = anyio.Event() + # TODO: find a reasonable buffer size, or expose it in the API + self._incoming_events = anyio.create_memory_object_stream(max_buffer_size=sys.maxsize) # (send, receive) self._response_started = False self._state = GRPCStreamState.OPEN self._start_stream_event = None @@ -130,11 +134,11 @@ def _close_local(self): del self._grpc_socket._streams[self._stream_id] async def _set_flow_control_update(self): - await self._flow_control_update_event.set() + self._flow_control_update_event.set() async def _wait_flow_control_update(self): await self._flow_control_update_event.wait() - self._flow_control_update_event.clear() + self._flow_control_update_event = anyio.Event() async def _send(self, message: bytes, compress=False): message_write_buffer = MessageWriteBuffer(self._grpc_connection.config.message_encoding, @@ -151,7 +155,7 @@ async def _send(self, message: bytes, compress=False): await self._socket.flush() async def _receive(self): - event = await self._incoming_events.get() + event = await self._incoming_events[1].receive() if isinstance(event, MessageReceived): self._grpc_connection.acknowledge_received_data(self._stream_id, event.flow_controlled_length) @@ -211,8 +215,8 @@ async def __aenter__(self): await self._socket.flush() if self.client_side: task_group = await self.enter_async_context(anyio.create_task_group()) - self.push_async_callback(task_group.cancel_scope.cancel) - await task_group.spawn(self._reader_thread) + self.callback(task_group.cancel_scope.cancel) + task_group.start_soon(self._reader_thread) return self @property @@ -245,7 +249,7 @@ async def _listen(self): elif isinstance(event, RequestReceived): self._allocate_stream(event.stream_id) - await self._streams[event.stream_id]._incoming_events.put(event) + await self._streams[event.stream_id]._incoming_events[0].send(event) if isinstance(event, RequestReceived): await yield_(self._streams[event.stream_id]) diff --git a/src/purerpc/rpc.py b/src/purerpc/rpc.py index cfe1455..eff9c59 100644 --- a/src/purerpc/rpc.py +++ b/src/purerpc/rpc.py @@ -1,6 +1,7 @@ import enum import typing import collections +import collections.abc Stream = typing.AsyncIterator diff --git a/src/purerpc/server.py b/src/purerpc/server.py index 1fe6b88..7afd5a8 100644 --- a/src/purerpc/server.py +++ b/src/purerpc/server.py @@ -1,12 +1,15 @@ import sys import inspect -import socket import collections import functools + import async_exit_stack import logging import anyio +import anyio.abc +from anyio import TASK_STATUS_IGNORED +from anyio.streams.tls import TLSListener from async_generator import async_generator, asynccontextmanager, yield_ from .grpclib.events import RequestReceived @@ -71,28 +74,6 @@ def service(self) -> Service: raise NotImplementedError() -def tcp_server_socket(host, port, family=socket.AF_INET, backlog=100, - reuse_address=True, reuse_port=False, ssl_context=None): - raw_socket = socket.socket(family, socket.SOCK_STREAM) - try: - if reuse_address: - raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True) - - if reuse_port: - try: - raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, True) - except (AttributeError, OSError) as e: - log.warning('reuse_port=True option failed', exc_info=True) - - raw_socket.bind((host, port)) - raw_socket.listen(backlog) - except Exception: - raw_socket.close() - raise - - return raw_socket - - @asynccontextmanager @async_generator async def _service_wrapper(service=None, setup_fn=None, teardown_fn=None): @@ -108,7 +89,7 @@ async def _service_wrapper(service=None, setup_fn=None, teardown_fn=None): class Server: def __init__(self, port=50055, ssl_context=None): self.port = port - self._ssl = ssl_context + self._ssl_context = ssl_context self.services = {} def add_service(self, service=None, context_manager=None, setup_fn=None, teardown_fn=None, name=None): @@ -133,35 +114,38 @@ def add_service(self, service=None, context_manager=None, setup_fn=None, teardow else: raise ValueError("Shouldn't have happened") - def _create_socket_and_listen(self): - return tcp_server_socket('', self.port, reuse_address=True, reuse_port=True) + async def serve_async(self, *, task_status=TASK_STATUS_IGNORED): + """Run the grpc server - async def _run_async_server(self, raw_socket): - socket = anyio._get_asynclib().Socket(raw_socket) + The task_status protocol lets the caller know when the server is + listening, and yields the port number (same given to Server constructor). + """ # TODO: resource usage warning async with async_exit_stack.AsyncExitStack() as stack: - tcp_server = await stack.enter_async_context( - anyio._networking.SocketStreamServer(socket, - self._ssl, - self._ssl is not None, - False) - ) - task_group = await stack.enter_async_context(anyio.create_task_group()) + tcp_server = await anyio.create_tcp_listener(local_port=self.port, reuse_port=True) + # read the resulting port, in case it was 0 + self.port = tcp_server.extra(anyio.abc.SocketAttribute.local_port) + if self._ssl_context: + tcp_server = TLSListener(tcp_server, self._ssl_context, + standard_compatible=False) + task_status.started(self.port) services_dict = {} for key, value in self.services.items(): services_dict[key] = await stack.enter_async_context(value) - async for socket in tcp_server.accept_connections(): - await task_group.spawn(ConnectionHandler(services_dict), socket) - - def _target_fn(self, backend): - socket = self._create_socket_and_listen() - anyio.run(self._run_async_server, socket, backend=backend) + await tcp_server.serve(ConnectionHandler(services_dict)) def serve(self, backend=None): - self._target_fn(backend) + """ + DEPRECATED - use serve_async() instead + + This function runs an entire async event loop (there can only be one + per thread), and there is no way to know when the server is ready for + connections. + """ + anyio.run(self.serve_async, backend=backend) class ConnectionHandler: @@ -227,16 +211,16 @@ async def request_received(self, stream: GRPCProtoStream): log.warning("Got exception in request_received", exc_info=log.getEffectiveLevel() == logging.DEBUG) - async def __call__(self, socket): + async def __call__(self, stream_: anyio.abc.SocketStream): # TODO: Should at least pass through GeneratorExit try: - async with GRPCProtoSocket(self.config, socket) as self.grpc_socket: + async with GRPCProtoSocket(self.config, stream_) as self.grpc_socket: # TODO: resource usage warning # TODO: TaskGroup() uses a lot of memory if the connection is kept for a long time # TODO: do we really need it here? async with anyio.create_task_group() as task_group: async for stream in self.grpc_socket.listen(): - await task_group.spawn(self.request_received, stream) + task_group.start_soon(self.request_received, stream) except: # TODO: limit catch to Exception, so async cancel can propagate log.warning("Got exception in main dispatch loop", diff --git a/src/purerpc/test_utils.py b/src/purerpc/test_utils.py index 045f31d..78641e7 100644 --- a/src/purerpc/test_utils.py +++ b/src/purerpc/test_utils.py @@ -118,25 +118,38 @@ def _run_context_manager_generator_in_process(cm_gen): parent_conn.close() -def run_purerpc_service_in_process(service): +def run_purerpc_service_in_process(service, ssl_context=None): + # TODO: there is no reason to run the server as a separate process... + # just use serve_async(). This synchronous cm has timing problems, + # because the server may not be listening before yielding to the body. + def target_fn(): import purerpc - server = purerpc.Server(port=0) - server.add_service(service) - socket = server._create_socket_and_listen() - yield socket.getsockname()[1] + import socket - async def sleep_10_seconds_then_die(): - await anyio.sleep(20) - raise ValueError + # Grab an ephemeral port in advance, because we need to yield the port + # before blocking on serve()... + with socket.socket() as sock: + sock.bind(('127.0.0.1', 0)) + port = sock.getsockname()[1] - async def main(): - async with anyio.create_task_group() as tg: - await tg.spawn(server._run_async_server, socket) - await tg.spawn(sleep_10_seconds_then_die) + server = purerpc.Server(port=port, ssl_context=ssl_context) + server.add_service(service) + yield port + server.serve() + + # async def sleep_10_seconds_then_die(): + # await anyio.sleep(20) + # raise ValueError + # + # async def main(): + # async with anyio.create_task_group() as tg: + # tg.start_soon(server.serve_async) + # tg.start_soon(sleep_10_seconds_then_die) + # # import cProfile - anyio.run(server._run_async_server, socket) # cProfile.runctx("anyio.run(main)", globals(), locals(), sort="tottime") + return _run_context_manager_generator_in_process(target_fn) @@ -208,7 +221,7 @@ def decorator(corofunc): async def new_corofunc(**kwargs): async with anyio.create_task_group() as tg: for _ in range(num_tasks): - await tg.spawn(functools.partial(corofunc, **kwargs)) + tg.start_soon(functools.partial(corofunc, **kwargs)) return new_corofunc return decorator diff --git a/src/purerpc/wrappers.py b/src/purerpc/wrappers.py index bc37046..ec5e92f 100644 --- a/src/purerpc/wrappers.py +++ b/src/purerpc/wrappers.py @@ -103,7 +103,7 @@ class ClientStubStreamUnary(ClientStub): async def __call__(self, message_aiter, *, metadata=None): stream = await self._stream_fn(metadata=metadata) async with anyio.create_task_group() as task_group: - await task_group.spawn(send_multiple_messages_client, stream, message_aiter) + task_group.start_soon(send_multiple_messages_client, stream, message_aiter) return await extract_message_from_singleton_stream(stream) @@ -112,7 +112,7 @@ class ClientStubStreamStream(ClientStub): async def call_aiter(self, message_aiter, metadata): stream = await self._stream_fn(metadata=metadata) async with anyio.create_task_group() as task_group: - await task_group.spawn(send_multiple_messages_client, stream, message_aiter) + task_group.start_soon(send_multiple_messages_client, stream, message_aiter) await yield_from_(stream_to_async_iterator(stream)) async def call_stream(self, metadata): diff --git a/tests/test_echo.py b/tests/test_echo.py index 850205c..b331b2a 100644 --- a/tests/test_echo.py +++ b/tests/test_echo.py @@ -1,12 +1,34 @@ import functools +import ssl import pytest +import trustme from async_generator import async_generator, yield_ +import purerpc from purerpc.test_utils import run_purerpc_service_in_process, run_grpc_service_in_process, \ async_iterable_to_list, random_payload, grpc_client_parallelize, async_test, purerpc_channel, purerpc_client_parallelize, grpc_channel +@pytest.fixture(scope='module') +def ca(): + return trustme.CA() + + +@pytest.fixture(scope='module') +def server_ssl_context(ca): + server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ca.issue_cert('127.0.0.1').configure_cert(server_context) + return server_context + + +@pytest.fixture(scope='module') +def client_ssl_context(ca): + client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + ca.configure_trust(client_context) + return client_context + + @pytest.fixture(scope="module") def purerpc_echo_port(echo_pb2, echo_grpc): class Servicer(echo_grpc.EchoServicer): @@ -37,6 +59,46 @@ async def EchoLastV2(self, messages): await yield_(echo_pb2.EchoReply(data="".join(data))) with run_purerpc_service_in_process(Servicer().service) as port: + # TODO: migrate to serve_async() to avoid timing problems + import time + time.sleep(.1) + yield port + + +@pytest.fixture(scope="module") +def purerpc_echo_port_ssl(echo_pb2, echo_grpc, server_ssl_context): + class Servicer(echo_grpc.EchoServicer): + async def Echo(self, message): + return echo_pb2.EchoReply(data=message.data) + + @async_generator + async def EchoTwoTimes(self, message): + await yield_(echo_pb2.EchoReply(data=message.data)) + await yield_(echo_pb2.EchoReply(data=message.data)) + + @async_generator + async def EchoEachTime(self, messages): + async for message in messages: + await yield_(echo_pb2.EchoReply(data=message.data)) + + async def EchoLast(self, messages): + data = [] + async for message in messages: + data.append(message.data) + return echo_pb2.EchoReply(data="".join(data)) + + @async_generator + async def EchoLastV2(self, messages): + data = [] + async for message in messages: + data.append(message.data) + await yield_(echo_pb2.EchoReply(data="".join(data))) + + with run_purerpc_service_in_process(Servicer().service, + ssl_context=server_ssl_context) as port: + # TODO: migrate to serve_async() to avoid timing problems + import time + time.sleep(.1) yield port @@ -136,3 +198,19 @@ async def gen(): assert [response.data for response in await async_iterable_to_list( stub.EchoLastV2(gen()))] == [data * 20] + + +@async_test +async def test_purerpc_ssl(echo_pb2, echo_grpc, purerpc_echo_port_ssl, client_ssl_context): + async with purerpc.secure_channel("127.0.0.1", purerpc_echo_port_ssl, + ssl_context=client_ssl_context) as channel: + stub = echo_grpc.EchoStub(channel) + data = random_payload(min_size=32000, max_size=64000) + + @async_generator + async def gen(): + for _ in range(20): + await yield_(echo_pb2.EchoRequest(data=data)) + + assert [response.data for response in await async_iterable_to_list( + stub.EchoLastV2(gen()))] == [data * 20] diff --git a/tests/test_server_http2.py b/tests/test_server_http2.py index 9016c61..51bd1d4 100644 --- a/tests/test_server_http2.py +++ b/tests/test_server_http2.py @@ -17,6 +17,8 @@ @pytest.fixture def dummy_server_port(): with run_purerpc_service_in_process(purerpc.Service("Greeter")) as port: + # TODO: migrate to serve_async() to avoid timing problems + time.sleep(0.1) yield port @@ -35,7 +37,7 @@ def http2_client_connect(host, port): sock.close() -def http2_receive_events(conn, sock, timeout=0.1): +def http2_receive_events(conn, sock): try: sock.settimeout(0.1) events = []