7 changes: 7 additions & 0 deletions docs/project/changelog.rst
@@ -34,6 +34,13 @@ notice.

.. _14.0:

Bug fixes
.........

* Once the connection is closed, messages previously received and buffered can
be read in the :mod:`asyncio` and :mod:`threading` implementations, just like
in the legacy implementation.

14.0
----

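To make the changelog entry concrete, here is a minimal client-side sketch of the fixed behavior in the new asyncio implementation. The endpoint URI and the assumption that the server already sent "😀" are illustrative only, not part of the change.

```python
import asyncio

from websockets.asyncio.client import connect
from websockets.exceptions import ConnectionClosedOK


async def main() -> None:
    # "ws://localhost:8765" is a placeholder endpoint for this sketch.
    async with connect("ws://localhost:8765") as websocket:
        await asyncio.sleep(0.1)       # assume the server sent "😀" by now
        await websocket.close()        # close before reading it
        print(await websocket.recv())  # the buffered "😀" is still returned
        try:
            await websocket.recv()     # only now does recv() report the close
        except ConnectionClosedOK:
            print("connection closed cleanly")


asyncio.run(main())
```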
33 changes: 12 additions & 21 deletions src/websockets/asyncio/messages.py
@@ -40,11 +40,12 @@ def put(self, item: T) -> None:
if self.get_waiter is not None and not self.get_waiter.done():
self.get_waiter.set_result(None)

async def get(self) -> T:
async def get(self, block: bool = True) -> T:
"""Remove and return an item from the queue, waiting if necessary."""
if not self.queue:
if self.get_waiter is not None:
raise ConcurrencyError("get is already running")
if not block:
raise EOFError("stream of frames ended")
assert self.get_waiter is None, "cannot call get() concurrently"
self.get_waiter = self.loop.create_future()
try:
await self.get_waiter
@@ -63,8 +64,6 @@ def abort(self) -> None:
"""Close the queue, raising EOFError in get() if necessary."""
if self.get_waiter is not None and not self.get_waiter.done():
self.get_waiter.set_exception(EOFError("stream of frames ended"))
# Clear the queue to avoid storing unnecessary data in memory.
self.queue.clear()


class Assembler:
@@ -136,20 +135,16 @@ async def get(self, decode: bool | None = None) -> Data:
:meth:`get_iter` concurrently.

"""
if self.closed:
raise EOFError("stream of frames ended")

if self.get_in_progress:
raise ConcurrencyError("get() or get_iter() is already running")

self.get_in_progress = True

# Locking with get_in_progress prevents concurrent execution until
# get() fetches a complete message or is cancelled.
# Locking with get_in_progress prevents concurrent execution
# until get() fetches a complete message or is cancelled.

try:
# First frame
frame = await self.frames.get()
frame = await self.frames.get(not self.closed)
self.maybe_resume()
assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
if decode is None:
@@ -159,7 +154,7 @@ async def get(self, decode: bool | None = None) -> Data:
# Following frames, for fragmented messages
while not frame.fin:
try:
frame = await self.frames.get()
frame = await self.frames.get(not self.closed)
except asyncio.CancelledError:
# Put frames already received back into the queue
# so that future calls to get() can return them.
@@ -203,23 +198,19 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
:meth:`get_iter` concurrently.

"""
if self.closed:
raise EOFError("stream of frames ended")

if self.get_in_progress:
raise ConcurrencyError("get() or get_iter() is already running")

self.get_in_progress = True

# Locking with get_in_progress prevents concurrent execution until
# get_iter() fetches a complete message or is cancelled.
# Locking with get_in_progress prevents concurrent execution
# until get_iter() fetches a complete message or is cancelled.

# If get_iter() raises an exception e.g. in decoder.decode(),
# get_in_progress remains set and the connection becomes unusable.

# First frame
try:
frame = await self.frames.get()
frame = await self.frames.get(not self.closed)
except asyncio.CancelledError:
self.get_in_progress = False
raise
@@ -239,7 +230,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
# previous fragments — we're streaming them. Canceling get_iter()
# here will leave the assembler in a stuck state. Future calls to
# get() or get_iter() will raise ConcurrencyError.
frame = await self.frames.get()
frame = await self.frames.get(not self.closed)
self.maybe_resume()
assert frame.opcode is OP_CONT
if decode:
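The asyncio queue change above replaces the ConcurrencyError path with an assertion, adds a `block` parameter, and stops clearing the queue in `abort()`, so frames buffered before the close remain readable via `get(block=False)`. Below is a self-contained sketch of that pattern; `BufferedQueue` and `demo` are hypothetical names, not the library's actual `SimpleQueue`.

```python
import asyncio
from collections import deque
from typing import Any


class BufferedQueue:
    """Minimal sketch: a queue whose buffered items stay readable
    after the stream ends (not the real websockets class)."""

    def __init__(self) -> None:
        self.loop = asyncio.get_running_loop()
        self.queue: deque[Any] = deque()
        self.get_waiter: asyncio.Future[None] | None = None

    def put(self, item: Any) -> None:
        self.queue.append(item)
        if self.get_waiter is not None and not self.get_waiter.done():
            self.get_waiter.set_result(None)

    async def get(self, block: bool = True) -> Any:
        if not self.queue:
            if not block:
                # After the stream ends, callers drain the queue with
                # block=False; an empty queue means nothing is left.
                raise EOFError("stream ended")
            assert self.get_waiter is None, "cannot call get() concurrently"
            self.get_waiter = self.loop.create_future()
            try:
                await self.get_waiter
            finally:
                self.get_waiter = None
        return self.queue.popleft()

    def abort(self) -> None:
        # Wake up a pending get(), but keep buffered items readable
        # instead of clearing the queue.
        if self.get_waiter is not None and not self.get_waiter.done():
            self.get_waiter.set_exception(EOFError("stream ended"))


async def demo() -> None:
    q = BufferedQueue()
    q.put("buffered")
    q.abort()                        # the stream ends...
    print(await q.get(block=False))  # ...but "buffered" is still readable
    try:
        await q.get(block=False)     # nothing left: EOFError
    except EOFError:
        print("end of stream")


asyncio.run(demo())
```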
38 changes: 19 additions & 19 deletions src/websockets/sync/messages.py
@@ -69,10 +69,16 @@ def __init__(
def get_next_frame(self, timeout: float | None = None) -> Frame:
# Helper to factor out the logic for getting the next frame from the
# queue, while handling timeouts and reaching the end of the stream.
try:
frame = self.frames.get(timeout=timeout)
except queue.Empty:
raise TimeoutError(f"timed out in {timeout:.1f}s") from None
if self.closed:
try:
frame = self.frames.get(block=False)
except queue.Empty:
raise EOFError("stream of frames ended") from None
else:
try:
frame = self.frames.get(block=True, timeout=timeout)
except queue.Empty:
raise TimeoutError(f"timed out in {timeout:.1f}s") from None
if frame is None:
raise EOFError("stream of frames ended")
return frame
@@ -87,7 +93,7 @@ def reset_queue(self, frames: Iterable[Frame]) -> None:
queued = []
try:
while True:
queued.append(self.frames.get_nowait())
queued.append(self.frames.get(block=False))
except queue.Empty:
pass
for frame in frames:
@@ -123,15 +129,13 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data:

"""
with self.mutex:
if self.closed:
raise EOFError("stream of frames ended")

if self.get_in_progress:
raise ConcurrencyError("get() or get_iter() is already running")

# Locking with get_in_progress ensures only one thread can get here.
self.get_in_progress = True

# Locking with get_in_progress prevents concurrent execution
# until get() fetches a complete message or times out.

try:
deadline = Deadline(timeout)

@@ -193,17 +197,12 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]:

"""
with self.mutex:
if self.closed:
raise EOFError("stream of frames ended")

if self.get_in_progress:
raise ConcurrencyError("get() or get_iter() is already running")

# Locking with get_in_progress ensures only one coroutine can get here.
self.get_in_progress = True

# Locking with get_in_progress prevents concurrent execution until
# get_iter() fetches a complete message or is cancelled.
# Locking with get_in_progress prevents concurrent execution
# until get_iter() fetches a complete message or times out.

# If get_iter() raises an exception e.g. in decoder.decode(),
# get_in_progress remains set and the connection becomes unusable.
@@ -289,5 +288,6 @@ def close(self) -> None:

self.closed = True

# Unblock get() or get_iter().
self.frames.put(None)
if self.get_in_progress:
# Unblock get() or get_iter().
self.frames.put(None)
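The `get_next_frame` helper above switches to a non-blocking read once the assembler is closed, mapping `queue.Empty` to `EOFError`, and otherwise blocks with a timeout, mapping `queue.Empty` to `TimeoutError`; `None` is the sentinel that `close()` enqueues only when a reader is blocked. A standalone sketch of that logic, assuming module-level `frames`/`closed` variables and a `Frame = bytes` alias for illustration:

```python
import queue

Frame = bytes  # stand-in for websockets.frames.Frame in this sketch

frames = queue.SimpleQueue()  # holds Frame items, or None as a sentinel
closed = False


def get_next_frame(timeout: float | None = None) -> Frame:
    """Sketch of the closed/timeout handling shown in the diff above."""
    if closed:
        # After close(), only frames buffered earlier are returned; an
        # empty queue means the stream of frames has ended.
        try:
            frame = frames.get(block=False)
        except queue.Empty:
            raise EOFError("stream of frames ended") from None
    else:
        # Before close(), wait up to `timeout` for the next frame.
        try:
            frame = frames.get(block=True, timeout=timeout)
        except queue.Empty:
            raise TimeoutError(f"timed out in {timeout:.1f}s") from None
    if frame is None:
        # The None sentinel unblocks a reader that was waiting in get().
        raise EOFError("stream of frames ended")
    return frame


# A buffered frame survives close(); afterwards the empty queue ends the stream.
frames.put(b"buffered")
closed = True
print(get_next_frame())  # b'buffered'
try:
    get_next_frame()
except EOFError:
    print("end of stream")
```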
8 changes: 3 additions & 5 deletions tests/asyncio/test_connection.py
@@ -793,14 +793,12 @@ async def test_close_timeout_waiting_for_connection_closed(self):
# Remove socket.timeout when dropping Python < 3.10.
self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError))

async def test_close_does_not_wait_for_recv(self):
# Closing the connection discards messages buffered in the assembler.
# This is allowed by the RFC:
# > However, there is no guarantee that the endpoint that has already
# > sent a Close frame will continue to process data.
async def test_close_preserves_queued_messages(self):
"""close preserves messages buffered in the assembler."""
await self.remote_connection.send("😀")
await self.connection.close()

self.assertEqual(await self.connection.recv(), "😀")
with self.assertRaises(ConnectionClosedOK) as raised:
await self.connection.recv()

67 changes: 52 additions & 15 deletions tests/asyncio/test_messages.py
@@ -37,14 +37,6 @@ async def test_get_then_put(self):
item = await getter_task
self.assertEqual(item, 42)

async def test_get_concurrently(self):
"""get cannot be called concurrently."""
getter_task = asyncio.create_task(self.queue.get())
await asyncio.sleep(0) # let the task start
with self.assertRaises(ConcurrencyError):
await self.queue.get()
getter_task.cancel()

async def test_reset(self):
"""reset sets the content of the queue."""
self.queue.reset([42])
@@ -59,13 +51,6 @@ async def test_abort(self):
with self.assertRaises(EOFError):
await getter_task

async def test_abort_clears_queue(self):
"""abort clears buffered data from the queue."""
self.queue.put(42)
self.assertEqual(len(self.queue), 1)
self.queue.abort()
self.assertEqual(len(self.queue), 0)


class AssemblerTests(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
@@ -410,6 +395,58 @@ async def test_get_iter_fails_after_close(self):
async for _ in self.assembler.get_iter():
self.fail("no fragment expected")

async def test_get_queued_message_after_close(self):
"""get returns a message after close is called."""
self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))
self.assembler.close()
message = await self.assembler.get()
self.assertEqual(message, "café")

async def test_get_iter_queued_message_after_close(self):
"""get_iter yields a message after close is called."""
self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))
self.assembler.close()
fragments = await alist(self.assembler.get_iter())
self.assertEqual(fragments, ["café"])

async def test_get_queued_fragmented_message_after_close(self):
"""get reassembles a fragmented message after close is called."""
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
self.assembler.put(Frame(OP_CONT, b"a"))
self.assembler.close()
message = await self.assembler.get()
self.assertEqual(message, b"tea")

async def test_get_iter_queued_fragmented_message_after_close(self):
"""get_iter yields a fragmented message after close is called."""
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
self.assembler.put(Frame(OP_CONT, b"a"))
self.assembler.close()
fragments = await alist(self.assembler.get_iter())
self.assertEqual(fragments, [b"t", b"e", b"a"])

async def test_get_partially_queued_fragmented_message_after_close(self):
"""get raises EOF on a partial fragmented message after close is called."""
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
self.assembler.close()
with self.assertRaises(EOFError):
await self.assembler.get()

async def test_get_iter_partially_queued_fragmented_message_after_close(self):
"""get_iter yields a partial fragmented message after close is called."""
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
self.assembler.close()
fragments = []
with self.assertRaises(EOFError):
async for fragment in self.assembler.get_iter():
fragments.append(fragment)
self.assertEqual(fragments, [b"t", b"e"])

async def test_put_fails_after_close(self):
"""put raises EOFError after close is called."""
self.assembler.close()
19 changes: 7 additions & 12 deletions tests/sync/test_connection.py
@@ -543,17 +543,12 @@ def test_close_timeout_waiting_for_connection_closed(self):
# Remove socket.timeout when dropping Python < 3.10.
self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError))

def test_close_does_not_wait_for_recv(self):
# Closing the connection discards messages buffered in the assembler.
# This is allowed by the RFC:
# > However, there is no guarantee that the endpoint that has already
# > sent a Close frame will continue to process data.
def test_close_preserves_queued_messages(self):
"""close preserves messages buffered in the assembler."""
self.remote_connection.send("😀")
self.connection.close()

close_thread = threading.Thread(target=self.connection.close)
close_thread.start()

self.assertEqual(self.connection.recv(), "😀")
with self.assertRaises(ConnectionClosedOK) as raised:
self.connection.recv()

@@ -576,10 +571,10 @@ def test_close_idempotency(self):
def test_close_idempotency_race_condition(self):
"""close waits if the connection is already closing."""

self.connection.close_timeout = 5 * MS
self.connection.close_timeout = 6 * MS

def closer():
with self.delay_frames_rcvd(3 * MS):
with self.delay_frames_rcvd(4 * MS):
self.connection.close()

close_thread = threading.Thread(target=closer)
@@ -591,14 +586,14 @@ def closer():

# Connection isn't closed yet.
with self.assertRaises(TimeoutError):
self.connection.recv(timeout=0)
self.connection.recv(timeout=MS)

self.connection.close()
self.assertNoFrameSent()

# Connection is closed now.
with self.assertRaises(ConnectionClosedOK):
self.connection.recv(timeout=0)
self.connection.recv(timeout=MS)

close_thread.join()

Expand Down
Loading