Skip to content

Commit

Permalink
fix client hang when connection lost just after remote closes (#182)
Browse files Browse the repository at this point in the history
Bad ordering:
    1. Remote close
    2. TCP closed
    3. Local confirms
    => no ConnectionClosed raised, client hangs forever
  • Loading branch information
orausch committed Sep 6, 2023
1 parent 89f2749 commit b761d41
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 37 deletions.
40 changes: 39 additions & 1 deletion tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import trio
import trustme
import wsproto
from trio.testing import memory_stream_pair
from trio.testing import memory_stream_pair, memory_stream_pump
from wsproto.events import CloseConnection

try:
Expand Down Expand Up @@ -1017,3 +1017,41 @@ async def test_finalization_dropped_exception(echo_server, autojump_clock):
await trio.sleep_forever()
finally:
raise ValueError


async def test_remote_close_rude():
"""
Bad ordering:
1. Remote close
2. TCP closed
3. Local confirms
=> no ConnectionClosed raised, client hangs forever
"""
client_stream, server_stream = memory_stream_pair()

async def client():
client_conn = await wrap_client_stream(nursery, client_stream, HOST, RESOURCE)
assert not client_conn.closed
await client_conn.send_message('Hello from client!')
with pytest.raises(ConnectionClosed):
await client_conn.get_message()

async def server():
server_request = await wrap_server_stream(nursery, server_stream)
server_ws = await server_request.accept()
assert not server_ws.closed
msg = await server_ws.get_message()
assert msg == "Hello from client!"

# disable pumping so that the CloseConnection arrives at the same time as the stream closure
server_stream.send_stream.send_all_hook = None
await server_ws._send(CloseConnection(code=1000, reason=None))
await server_stream.aclose()

# pump the messages over
memory_stream_pump(server_stream.send_stream, client_stream.receive_stream)


async with trio.open_nursery() as nursery:
nursery.start_soon(server)
nursery.start_soon(client)
73 changes: 37 additions & 36 deletions trio_websocket/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,45 +1210,46 @@ async def _reader_task(self):
except ConnectionClosed:
self._reader_running = False

while self._reader_running:
# Process events.
for event in self._wsproto.events():
event_type = type(event)
async with self._send_channel:
while self._reader_running:
# Process events.
for event in self._wsproto.events():
event_type = type(event)
try:
handler = handlers[event_type]
logger.debug('%s received event: %s', self,
event_type)
await handler(event)
except KeyError:
logger.warning('%s received unknown event type: "%s"', self,
event_type)
except ConnectionClosed:
self._reader_running = False
break

# Get network data.
try:
handler = handlers[event_type]
logger.debug('%s received event: %s', self,
event_type)
await handler(event)
except KeyError:
logger.warning('%s received unknown event type: "%s"', self,
event_type)
except ConnectionClosed:
self._reader_running = False
data = await self._stream.receive_some(RECEIVE_BYTES)
except (trio.BrokenResourceError, trio.ClosedResourceError):
await self._abort_web_socket()
break

# Get network data.
try:
data = await self._stream.receive_some(RECEIVE_BYTES)
except (trio.BrokenResourceError, trio.ClosedResourceError):
await self._abort_web_socket()
break
if len(data) == 0:
logger.debug('%s received zero bytes (connection closed)',
self)
# If TCP closed before WebSocket, then record it as an abnormal
# closure.
if len(data) == 0:
logger.debug('%s received zero bytes (connection closed)',
self)
# If TCP closed before WebSocket, then record it as an abnormal
# closure.
if self._wsproto.state != ConnectionState.CLOSED:
await self._abort_web_socket()
break
logger.debug('%s received %d bytes', self, len(data))
if self._wsproto.state != ConnectionState.CLOSED:
await self._abort_web_socket()
break
logger.debug('%s received %d bytes', self, len(data))
if self._wsproto.state != ConnectionState.CLOSED:
try:
self._wsproto.receive_data(data)
except wsproto.utilities.RemoteProtocolError as err:
logger.debug('%s remote protocol error: %s', self, err)
if err.event_hint:
await self._send(err.event_hint)
await self._close_stream()
try:
self._wsproto.receive_data(data)
except wsproto.utilities.RemoteProtocolError as err:
logger.debug('%s remote protocol error: %s', self, err)
if err.event_hint:
await self._send(err.event_hint)
await self._close_stream()

logger.debug('%s reader task finished', self)

Expand Down

0 comments on commit b761d41

Please sign in to comment.