From cf51aa43c47596ada05c19300476d09f399db860 Mon Sep 17 00:00:00 2001 From: John Belmonte Date: Tue, 6 Nov 2018 17:56:24 +0900 Subject: [PATCH 1/5] add test for client context manager bug --- tests/test_connection.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_connection.py b/tests/test_connection.py index 33c194e..3cee2bb 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -380,3 +380,14 @@ async def handler(request): with pytest.raises(ConnectionClosed) as e: await connection.get_message() assert e.reason.name == 'NORMAL_CLOSURE' + + +async def test_client_cm_exit_with_pending_messages(echo_server, autojump_clock): + with trio.fail_after(1): + async with open_websocket(HOST, echo_server.port, RESOURCE, + use_ssl=False) as ws: + await ws.send_message('hello') + # allow time for the server to respond + await trio.sleep(.1) + # bug: context manager exit is blocked on unconsumed message + #await ws.get_message() From aa47e2bb8561b2560374a91a52d7bc46243c1093 Mon Sep 17 00:00:00 2001 From: "Mark E. Haase" Date: Wed, 7 Nov 2018 16:40:30 -0500 Subject: [PATCH 2/5] Delay connection closed (#69) As described in the issue, get_message() was raising connection closed even if there were pending messages. Per Nathaniel's suggestion, the proper behavior is this: 1. If the remote endpoint closed the connection, then the local endpoint may continue reading all messages sent prior to closing. 2. If the local endpoint closed the connection, then the local endpoint may not read any more messages. I added tests for these two conditions and implemented the behavior by closing the ReceiveChannel inside the connection's `aclose()`. This requires a bit of additional exception handling inside `get_message()` and inside the reader task. One slight surprise is that the test can't be written due to the bug in #74! The client would hang because the reader task is blocked by the unconsumed messages. So I changed the channel size to 32, which allows this test to work, and I will replace this hard-coded value when I fix #74. --- tests/test_connection.py | 59 ++++++++++++++++++++++++++++++++++++++-- trio_websocket/_impl.py | 15 ++++++---- 2 files changed, 67 insertions(+), 7 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 78a6854..fa0bdf0 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -377,6 +377,61 @@ async def handler(request): with trio.fail_after(2): async with open_websocket( HOST, server.port, '/', use_ssl=False) as connection: - with pytest.raises(ConnectionClosed) as e: + with pytest.raises(ConnectionClosed) as exc_info: await connection.get_message() - assert e.reason.name == 'NORMAL_CLOSURE' + exc = exc_info.value + assert exc.reason.name == 'NORMAL_CLOSURE' + + +async def test_read_messages_after_remote_close(nursery): + ''' + When the remote endpoint closes, the local endpoint can still reading all + of the messages sent prior to closing. Any attempt to read beyond that will + raise ConnectionClosed. + ''' + server_closed = trio.Event() + + async def handler(request): + server = await request.accept() + async with server: + await server.send_message('1') + await server.send_message('2') + server_closed.set() + + server = await nursery.start( + partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + + async with open_websocket(HOST, server.port, '/', use_ssl=False) as client: + await server_closed.wait() + assert await client.get_message() == '1' + assert await client.get_message() == '2' + with pytest.raises(ConnectionClosed): + await client.get_message() + + +async def test_no_messages_after_local_close(nursery): + ''' + If the local endpoint initiates closing, then pending messages are discarded + and any attempt to read a message will raise ConnectionClosed. + ''' + client_closed = trio.Event() + + async def handler(request): + # The server sends some messages and then closes. + server = await request.accept() + async with server: + await server.send_message('1') + await server.send_message('2') + await client_closed.wait() + + server = await nursery.start( + partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + + # The client waits until the server closes (using an out-of-band trio Event) + # and then reads the messages. After reading all messages, it should raise + # ConnectionClosed. + async with open_websocket(HOST, server.port, '/', use_ssl=False) as client: + pass + with pytest.raises(ConnectionClosed): + await client.get_message() + client_closed.set() diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 38977c4..603ad1b 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -445,7 +445,7 @@ def __init__(self, stream, wsproto, *, path=None): self._reader_running = True self._path = path self._subprotocol = None - self._send_channel, self._recv_channel = trio.open_memory_channel(0) + self._send_channel, self._recv_channel = trio.open_memory_channel(32) self._pings = OrderedDict() # Set when the server has received a connection request event. This # future is never set on client connections. @@ -520,6 +520,7 @@ async def aclose(self, code=1000, reason=None): return self._wsproto.close(code=code, reason=reason) try: + await self._recv_channel.aclose() await self._write_pending() await self._close_handshake.wait() finally: @@ -538,11 +539,9 @@ async def get_message(self): :raises ConnectionClosed: if connection is closed before a message arrives. ''' - if self._close_reason: - raise ConnectionClosed(self._close_reason) try: message = await self._recv_channel.receive() - except trio.EndOfChannel: + except (trio.ClosedResourceError, trio.EndOfChannel): raise ConnectionClosed(self._close_reason) from None return message @@ -739,7 +738,13 @@ async def _handle_text_received_event(self, event): ''' self._str_message += event.data if event.message_finished: - await self._send_channel.send(self._str_message) + try: + await self._send_channel.send(self._str_message) + except trio.BrokenResourceError: + # The receive channel is closed, probably because somebody + # called ``aclose()``. We don't want to abort the reader task, + # and there's no useful cleanup that we can do here. + pass self._str_message = '' async def _handle_ping_received_event(self, event): From e33a36eedd2c7769f638a4c34b64475315b697d7 Mon Sep 17 00:00:00 2001 From: John Belmonte Date: Thu, 8 Nov 2018 11:42:13 +0900 Subject: [PATCH 3/5] refactor WebsocketConnection received data handling: * combine BytesReceived and TextReceived handlers * use join to build message rather than += --- trio_websocket/_impl.py | 31 +++++++++++-------------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 38977c4..2aa4f84 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -14,6 +14,7 @@ import trio.ssl import wsproto.connection as wsconnection import wsproto.frame_protocol as wsframeproto +from wsproto.events import BytesReceived from yarl import URL from .version import __version__ @@ -440,8 +441,7 @@ def __init__(self, stream, wsproto, *, path=None): self._stream = stream self._stream_lock = trio.StrictFIFOLock() self._wsproto = wsproto - self._bytes_message = b'' - self._str_message = '' + self._message_parts = [] # type: List[bytes|str] self._reader_running = True self._path = path self._subprotocol = None @@ -720,27 +720,18 @@ async def _handle_connection_failed_event(self, event): self._open_handshake.set() self._close_handshake.set() - async def _handle_bytes_received_event(self, event): + async def _handle_data_received_event(self, event): ''' - Handle a BytesReceived event. + Handle a BytesReceived or TextReceived event. :param event: ''' - self._bytes_message += event.data + self._message_parts.append(event.data) if event.message_finished: - await self._send_channel.send(self._bytes_message) - self._bytes_message = b'' - - async def _handle_text_received_event(self, event): - ''' - Handle a TextReceived event. - - :param event: - ''' - self._str_message += event.data - if event.message_finished: - await self._send_channel.send(self._str_message) - self._str_message = '' + msg = (b'' if isinstance(event, BytesReceived) else '') \ + .join(self._message_parts) + await self._send_channel.send(msg) + self._message_parts = [] async def _handle_ping_received_event(self, event): ''' @@ -790,8 +781,8 @@ async def _reader_task(self): 'ConnectionFailed': self._handle_connection_failed_event, 'ConnectionEstablished': self._handle_connection_established_event, 'ConnectionClosed': self._handle_connection_closed_event, - 'BytesReceived': self._handle_bytes_received_event, - 'TextReceived': self._handle_text_received_event, + 'BytesReceived': self._handle_data_received_event, + 'TextReceived': self._handle_data_received_event, 'PingReceived': self._handle_ping_received_event, 'PongReceived': self._handle_pong_received_event, } From 99326934498927f0f9385fa460e9e3dc4e900b66 Mon Sep 17 00:00:00 2001 From: "Mark E. Haase" Date: Thu, 8 Nov 2018 13:06:52 -0500 Subject: [PATCH 4/5] Code review for #79 --- tests/test_connection.py | 2 +- trio_websocket/_impl.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index fa0bdf0..f7b0a39 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -385,7 +385,7 @@ async def handler(request): async def test_read_messages_after_remote_close(nursery): ''' - When the remote endpoint closes, the local endpoint can still reading all + When the remote endpoint closes, the local endpoint can still read all of the messages sent prior to closing. Any attempt to read beyond that will raise ConnectionClosed. ''' diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 603ad1b..331a758 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -418,7 +418,8 @@ class WebSocketConnection(trio.abc.AsyncResource): CONNECTION_ID = itertools.count() - def __init__(self, stream, wsproto, *, path=None): + def __init__(self, stream, wsproto, *, path=None, + message_queue_size_do_not_use=2): ''' Constructor. @@ -445,7 +446,11 @@ def __init__(self, stream, wsproto, *, path=None): self._reader_running = True self._path = path self._subprotocol = None - self._send_channel, self._recv_channel = trio.open_memory_channel(32) + # TODO changed channel size from 0 to 2 temporarily to enable + # test_read_messages_after_remote_close to pass. The channel size will + # become a configurable setting when #74 is fixed. + self._send_channel, self._recv_channel = trio.open_memory_channel( + message_queue_size_do_not_use) self._pings = OrderedDict() # Set when the server has received a connection request event. This # future is never set on client connections. From 1ca82e68ab28ccc92af50d19f0908af198380104 Mon Sep 17 00:00:00 2001 From: "Mark E. Haase" Date: Fri, 9 Nov 2018 13:18:48 -0500 Subject: [PATCH 5/5] Code review feedback (#69) * Leave the default channel size hard coded to zero. * Skip the test that requires non-zero channel size. * Fix incomplete doc string on get_message(). --- tests/test_connection.py | 4 +--- trio_websocket/_impl.py | 21 +++++++++++---------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 0c29efe..3b9aa71 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -383,6 +383,7 @@ async def handler(request): assert exc.reason.name == 'NORMAL_CLOSURE' +@pytest.mark.skip(reason='Hangs because channel size is hard coded to 0') async def test_read_messages_after_remote_close(nursery): ''' When the remote endpoint closes, the local endpoint can still read all @@ -427,9 +428,6 @@ async def handler(request): server = await nursery.start( partial(serve_websocket, handler, HOST, 0, ssl_context=None)) - # The client waits until the server closes (using an out-of-band trio Event) - # and then reads the messages. After reading all messages, it should raise - # ConnectionClosed. async with open_websocket(HOST, server.port, '/', use_ssl=False) as client: pass with pytest.raises(ConnectionClosed): diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 41cc97b..d3d2d23 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -419,8 +419,7 @@ class WebSocketConnection(trio.abc.AsyncResource): CONNECTION_ID = itertools.count() - def __init__(self, stream, wsproto, *, path=None, - message_queue_size_do_not_use=2): + def __init__(self, stream, wsproto, *, path=None): ''' Constructor. @@ -446,11 +445,7 @@ def __init__(self, stream, wsproto, *, path=None, self._reader_running = True self._path = path self._subprotocol = None - # TODO changed channel size from 0 to 2 temporarily to enable - # test_read_messages_after_remote_close to pass. The channel size will - # become a configurable setting when #74 is fixed. - self._send_channel, self._recv_channel = trio.open_memory_channel( - message_queue_size_do_not_use) + self._send_channel, self._recv_channel = trio.open_memory_channel(0) self._pings = OrderedDict() # Set when the server has received a connection request event. This # future is never set on client connections. @@ -538,11 +533,17 @@ async def get_message(self): Receive the next WebSocket message. If no message is available immediately, then this function blocks until - a message is ready. When the connection is closed, this message + a message is ready. + + If the remote endpoint closes the connection, then the caller can still + get messages sent prior to closing. Once all pending messages have been + retrieved, additional calls to this method will raise + ``ConnectionClosed``. If the local endpoint closes the connection, then + pending messages are discarded and calls to this method will immediately + raise ``ConnectionClosed``. :rtype: str or bytes - :raises ConnectionClosed: if connection is closed before a message - arrives. + :raises ConnectionClosed: if the connection is closed. ''' try: message = await self._recv_channel.receive()