From 86bf0c5afc4c88429231ac7ba5f857fd32a02d0e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 17:40:55 +0100 Subject: [PATCH 1/4] Remove unnecessary branch. --- src/websockets/asyncio/messages.py | 3 +-- tests/asyncio/test_messages.py | 8 -------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index b57c0ca4e..69636e3d0 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -43,8 +43,7 @@ def put(self, item: T) -> None: async def get(self) -> T: """Remove and return an item from the queue, waiting if necessary.""" if not self.queue: - if self.get_waiter is not None: - raise ConcurrencyError("get is already running") + assert self.get_waiter is None, "cannot call get() concurrently" self.get_waiter = self.loop.create_future() try: await self.get_waiter diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index 2ff929d3a..5c9ac9445 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -37,14 +37,6 @@ async def test_get_then_put(self): item = await getter_task self.assertEqual(item, 42) - async def test_get_concurrently(self): - """get cannot be called concurrently.""" - getter_task = asyncio.create_task(self.queue.get()) - await asyncio.sleep(0) # let the task start - with self.assertRaises(ConcurrencyError): - await self.queue.get() - getter_task.cancel() - async def test_reset(self): """reset sets the content of the queue.""" self.queue.reset([42]) From f17c11ab6b46cfe6817cbab5d92ba2626d3e87d5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 17:42:16 +0100 Subject: [PATCH 2/4] Keep queued messages after abort. --- src/websockets/asyncio/messages.py | 2 -- tests/asyncio/test_messages.py | 7 ------- 2 files changed, 9 deletions(-) diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 69636e3d0..678a3c14e 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -62,8 +62,6 @@ def abort(self) -> None: """Close the queue, raising EOFError in get() if necessary.""" if self.get_waiter is not None and not self.get_waiter.done(): self.get_waiter.set_exception(EOFError("stream of frames ended")) - # Clear the queue to avoid storing unnecessary data in memory. - self.queue.clear() class Assembler: diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index 5c9ac9445..181ffd376 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -51,13 +51,6 @@ async def test_abort(self): with self.assertRaises(EOFError): await getter_task - async def test_abort_clears_queue(self): - """abort clears buffered data from the queue.""" - self.queue.put(42) - self.assertEqual(len(self.queue), 1) - self.queue.abort() - self.assertEqual(len(self.queue), 0) - class AssemblerTests(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): From bdfc8cf90301b528eebfbd0ef8a31e5a2fc7d5f7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 18:11:02 +0100 Subject: [PATCH 3/4] Uniformize comments. --- src/websockets/asyncio/messages.py | 8 ++++---- src/websockets/sync/messages.py | 11 +++++------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 678a3c14e..814a3c03c 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -141,8 +141,8 @@ async def get(self, decode: bool | None = None) -> Data: self.get_in_progress = True - # Locking with get_in_progress prevents concurrent execution until - # get() fetches a complete message or is cancelled. + # Locking with get_in_progress prevents concurrent execution + # until get() fetches a complete message or is cancelled. try: # First frame @@ -208,8 +208,8 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: self.get_in_progress = True - # Locking with get_in_progress prevents concurrent execution until - # get_iter() fetches a complete message or is cancelled. + # Locking with get_in_progress prevents concurrent execution + # until get_iter() fetches a complete message or is cancelled. # If get_iter() raises an exception e.g. in decoder.decode(), # get_in_progress remains set and the connection becomes unusable. diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 17f8dce7e..ce08172b2 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -128,10 +128,11 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - - # Locking with get_in_progress ensures only one thread can get here. self.get_in_progress = True + # Locking with get_in_progress prevents concurrent execution + # until get() fetches a complete message or times out. + try: deadline = Deadline(timeout) @@ -198,12 +199,10 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - - # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True - # Locking with get_in_progress prevents concurrent execution until - # get_iter() fetches a complete message or is cancelled. + # Locking with get_in_progress prevents concurrent execution + # until get_iter() fetches a complete message or times out. # If get_iter() raises an exception e.g. in decoder.decode(), # get_in_progress remains set and the connection becomes unusable. From 303483412dc5b420d09c1421792b3f8b99c323e6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 18:30:36 +0100 Subject: [PATCH 4/4] Support recv() after the connection is closed. Fix #1538. --- docs/project/changelog.rst | 7 ++++ src/websockets/asyncio/messages.py | 20 ++++-------- src/websockets/sync/messages.py | 27 ++++++++-------- tests/asyncio/test_connection.py | 8 ++--- tests/asyncio/test_messages.py | 52 ++++++++++++++++++++++++++++++ tests/sync/test_connection.py | 19 ++++------- tests/sync/test_messages.py | 52 ++++++++++++++++++++++++++++++ 7 files changed, 142 insertions(+), 43 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 9e6a9d113..074a81c85 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -34,6 +34,13 @@ notice. .. _14.0: +Bug fixes +......... + +* Once the connection is closed, messages previously received and buffered can + be read in the :mod:`asyncio` and :mod:`threading` implementations, just like + in the legacy implementation. + 14.0 ---- diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 814a3c03c..14ea7bf90 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -40,9 +40,11 @@ def put(self, item: T) -> None: if self.get_waiter is not None and not self.get_waiter.done(): self.get_waiter.set_result(None) - async def get(self) -> T: + async def get(self, block: bool = True) -> T: """Remove and return an item from the queue, waiting if necessary.""" if not self.queue: + if not block: + raise EOFError("stream of frames ended") assert self.get_waiter is None, "cannot call get() concurrently" self.get_waiter = self.loop.create_future() try: @@ -133,12 +135,8 @@ async def get(self, decode: bool | None = None) -> Data: :meth:`get_iter` concurrently. """ - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - self.get_in_progress = True # Locking with get_in_progress prevents concurrent execution @@ -146,7 +144,7 @@ async def get(self, decode: bool | None = None) -> Data: try: # First frame - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) self.maybe_resume() assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY if decode is None: @@ -156,7 +154,7 @@ async def get(self, decode: bool | None = None) -> Data: # Following frames, for fragmented messages while not frame.fin: try: - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) except asyncio.CancelledError: # Put frames already received back into the queue # so that future calls to get() can return them. @@ -200,12 +198,8 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: :meth:`get_iter` concurrently. """ - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - self.get_in_progress = True # Locking with get_in_progress prevents concurrent execution @@ -216,7 +210,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: # First frame try: - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) except asyncio.CancelledError: self.get_in_progress = False raise @@ -236,7 +230,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: # previous fragments — we're streaming them. Canceling get_iter() # here will leave the assembler in a stuck state. Future calls to # get() or get_iter() will raise ConcurrencyError. - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) self.maybe_resume() assert frame.opcode is OP_CONT if decode: diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index ce08172b2..af8635f16 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -69,10 +69,16 @@ def __init__( def get_next_frame(self, timeout: float | None = None) -> Frame: # Helper to factor out the logic for getting the next frame from the # queue, while handling timeouts and reaching the end of the stream. - try: - frame = self.frames.get(timeout=timeout) - except queue.Empty: - raise TimeoutError(f"timed out in {timeout:.1f}s") from None + if self.closed: + try: + frame = self.frames.get(block=False) + except queue.Empty: + raise EOFError("stream of frames ended") from None + else: + try: + frame = self.frames.get(block=True, timeout=timeout) + except queue.Empty: + raise TimeoutError(f"timed out in {timeout:.1f}s") from None if frame is None: raise EOFError("stream of frames ended") return frame @@ -87,7 +93,7 @@ def reset_queue(self, frames: Iterable[Frame]) -> None: queued = [] try: while True: - queued.append(self.frames.get_nowait()) + queued.append(self.frames.get(block=False)) except queue.Empty: pass for frame in frames: @@ -123,9 +129,6 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ with self.mutex: - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") self.get_in_progress = True @@ -194,9 +197,6 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]: """ with self.mutex: - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") self.get_in_progress = True @@ -288,5 +288,6 @@ def close(self) -> None: self.closed = True - # Unblock get() or get_iter(). - self.frames.put(None) + if self.get_in_progress: + # Unblock get() or get_iter(). + self.frames.put(None) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 902b3b847..b1c57c8ca 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -793,14 +793,12 @@ async def test_close_timeout_waiting_for_connection_closed(self): # Remove socket.timeout when dropping Python < 3.10. self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) - async def test_close_does_not_wait_for_recv(self): - # Closing the connection discards messages buffered in the assembler. - # This is allowed by the RFC: - # > However, there is no guarantee that the endpoint that has already - # > sent a Close frame will continue to process data. + async def test_close_preserves_queued_messages(self): + """close preserves messages buffered in the assembler.""" await self.remote_connection.send("😀") await self.connection.close() + self.assertEqual(await self.connection.recv(), "😀") with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index 181ffd376..566f71cea 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -395,6 +395,58 @@ async def test_get_iter_fails_after_close(self): async for _ in self.assembler.get_iter(): self.fail("no fragment expected") + async def test_get_queued_message_after_close(self): + """get returns a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_iter_queued_message_after_close(self): + """get_iter yields a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_get_queued_fragmented_message_after_close(self): + """get reassembles a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + self.assembler.close() + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_iter_queued_fragmented_message_after_close(self): + """get_iter yields a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + async def test_get_partially_queued_fragmented_message_after_close(self): + """get raises EOF on a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_partially_queued_fragmented_message_after_close(self): + """get_iter yields a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + fragments = [] + with self.assertRaises(EOFError): + async for fragment in self.assembler.get_iter(): + fragments.append(fragment) + self.assertEqual(fragments, [b"t", b"e"]) + async def test_put_fails_after_close(self): """put raises EOFError after close is called.""" self.assembler.close() diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index abdfd3f78..408b9697a 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -543,17 +543,12 @@ def test_close_timeout_waiting_for_connection_closed(self): # Remove socket.timeout when dropping Python < 3.10. self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) - def test_close_does_not_wait_for_recv(self): - # Closing the connection discards messages buffered in the assembler. - # This is allowed by the RFC: - # > However, there is no guarantee that the endpoint that has already - # > sent a Close frame will continue to process data. + def test_close_preserves_queued_messages(self): + """close preserves messages buffered in the assembler.""" self.remote_connection.send("😀") self.connection.close() - close_thread = threading.Thread(target=self.connection.close) - close_thread.start() - + self.assertEqual(self.connection.recv(), "😀") with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -576,10 +571,10 @@ def test_close_idempotency(self): def test_close_idempotency_race_condition(self): """close waits if the connection is already closing.""" - self.connection.close_timeout = 5 * MS + self.connection.close_timeout = 6 * MS def closer(): - with self.delay_frames_rcvd(3 * MS): + with self.delay_frames_rcvd(4 * MS): self.connection.close() close_thread = threading.Thread(target=closer) @@ -591,14 +586,14 @@ def closer(): # Connection isn't closed yet. with self.assertRaises(TimeoutError): - self.connection.recv(timeout=0) + self.connection.recv(timeout=MS) self.connection.close() self.assertNoFrameSent() # Connection is closed now. with self.assertRaises(ConnectionClosedOK): - self.connection.recv(timeout=0) + self.connection.recv(timeout=MS) close_thread.join() diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index 02513894a..9ebe45088 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -374,6 +374,58 @@ def test_get_iter_fails_after_close(self): for _ in self.assembler.get_iter(): self.fail("no fragment expected") + def test_get_queued_message_after_close(self): + """get returns a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + message = self.assembler.get() + self.assertEqual(message, "café") + + def test_get_iter_queued_message_after_close(self): + """get_iter yields a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + def test_get_queued_fragmented_message_after_close(self): + """get reassembles a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + self.assembler.close() + message = self.assembler.get() + self.assertEqual(message, b"tea") + + def test_get_iter_queued_fragmented_message_after_close(self): + """get_iter yields a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + def test_get_partially_queued_fragmented_message_after_close(self): + """get raises EOF on a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + with self.assertRaises(EOFError): + self.assembler.get() + + def test_get_iter_partially_queued_fragmented_message_after_close(self): + """get_iter yields a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + fragments = [] + with self.assertRaises(EOFError): + for fragment in self.assembler.get_iter(): + fragments.append(fragment) + self.assertEqual(fragments, [b"t", b"e"]) + def test_put_fails_after_close(self): """put raises EOFError after close is called.""" self.assembler.close()