Skip to content

Commit 2878496

Browse files
committed
Abort TCP connection if it doesn't close fast enough.
Fix #112.
1 parent 43784ea commit 2878496

File tree

2 files changed

+55
-7
lines changed

2 files changed

+55
-7
lines changed

websockets/protocol.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol):
6666
The ``timeout`` parameter defines the maximum wait time in seconds for
6767
completing the closing handshake and, only on the client side, for
6868
terminating the TCP connection. :meth:`close()` will complete in at most
69-
``3 * timeout`` on the server side and ``4 * timeout`` on the client side.
69+
``4 * timeout`` on the server side and ``5 * timeout`` on the client side.
7070
7171
The ``max_size`` parameter enforces the maximum size for incoming messages
7272
in bytes. The default value is 1MB. ``None`` disables the limit. If a
@@ -773,12 +773,23 @@ def close_connection(self, after_handshake=True):
773773
# Closing a transport is idempotent. If the transport was already
774774
# closed, for example from eof_received(), it's fine.
775775

776-
# Close the TCP connection.
776+
# Close the TCP connection. Buffers are flushed asynchronously.
777777
logger.debug(
778778
"%s x closing TCP connection", self.side)
779779
self.writer.close()
780-
# There's little need to await self.wait_for_connection_lost()
781-
# here. Closing the transport triggers self.connection_lost().
780+
781+
if (yield from self.wait_for_connection_lost()):
782+
return
783+
logger.debug(
784+
"%s ! timed out waiting for TCP close", self.side)
785+
786+
# Abort the TCP connection. Buffers are discarded.
787+
logger.debug(
788+
"%s x aborting TCP connection", self.side)
789+
self.writer.transport.abort()
790+
791+
# connection_lost() is called quickly after aborting.
792+
yield from self.wait_for_connection_lost()
782793

783794
@asyncio.coroutine
784795
def fail_connection(self, code=1011, reason=''):

websockets/test_protocol.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ def close(self):
7070
self.loop.call_soon(self.protocol.connection_lost, None)
7171
self._closing = True
7272

73+
def abort(self):
74+
if self.protocol.state != CLOSED:
75+
self.loop.call_soon(self.protocol.connection_lost, None)
76+
7377

7478
class CommonTests:
7579
"""
@@ -789,7 +793,7 @@ def test_local_close_receive_close_frame_timeout(self):
789793
self.loop.run_until_complete(self.protocol.close(reason='close'))
790794
self.assertConnectionClosed(1006, '')
791795

792-
def test_local_close_connection_lost_timeout(self):
796+
def test_local_close_connection_lost_timeout_after_write_eof(self):
793797
self.protocol.timeout = 10 * MS
794798
# If the client doesn't close its side of the TCP connection after we
795799
# half-close our side with write_eof(), time out in 10ms.
@@ -801,6 +805,21 @@ def test_local_close_connection_lost_timeout(self):
801805
self.loop.run_until_complete(self.protocol.close(reason='close'))
802806
self.assertConnectionClosed(1000, 'close')
803807

808+
def test_local_close_connection_lost_timeout_after_close(self):
809+
self.protocol.timeout = 10 * MS
810+
# If the client doesn't close its side of the TCP connection after we
811+
# half-close our side with write_eof() and close it with close(), time
812+
# out in 20ms.
813+
# Check the timing within -1/+9ms for robustness.
814+
with self.assertCompletesWithin(19 * MS, 29 * MS):
815+
# HACK: disable write_eof => other end drops connection emulation.
816+
self.transport._eof = True
817+
# HACK: disable close => other end drops connection emulation.
818+
self.transport._closing = True
819+
self.receive_frame(self.close_frame)
820+
self.loop.run_until_complete(self.protocol.close(reason='close'))
821+
self.assertConnectionClosed(1000, 'close')
822+
804823

805824
class ClientTests(CommonTests, unittest.TestCase):
806825

@@ -830,16 +849,34 @@ def test_local_close_receive_close_frame_timeout(self):
830849
self.loop.run_until_complete(self.protocol.close(reason='close'))
831850
self.assertConnectionClosed(1006, '')
832851

833-
def test_local_close_connection_lost_timeout(self):
852+
def test_local_close_connection_lost_timeout_after_write_eof(self):
834853
self.protocol.timeout = 10 * MS
835854
# If the server doesn't half-close its side of the TCP connection
836855
# after we send a close frame, time out in 20ms:
837856
# - 10ms waiting for receiving a half-close
838-
# - 10ms waiting for receiving a close
857+
# - 10ms waiting for receiving a close after write_eof
839858
# Check the timing within -1/+9ms for robustness.
840859
with self.assertCompletesWithin(19 * MS, 29 * MS):
841860
# HACK: disable write_eof => other end drops connection emulation.
842861
self.transport._eof = True
843862
self.receive_frame(self.close_frame)
844863
self.loop.run_until_complete(self.protocol.close(reason='close'))
845864
self.assertConnectionClosed(1000, 'close')
865+
866+
def test_local_close_connection_lost_timeout_after_close(self):
867+
self.protocol.timeout = 10 * MS
868+
# If the client doesn't close its side of the TCP connection after we
869+
# half-close our side with write_eof() and close it with close(), time
870+
# out in 20ms.
871+
# - 10ms waiting for receiving a half-close
872+
# - 10ms waiting for receiving a close after write_eof
873+
# - 10ms waiting for receiving a close after close
874+
# Check the timing within -1/+9ms for robustness.
875+
with self.assertCompletesWithin(29 * MS, 39 * MS):
876+
# HACK: disable write_eof => other end drops connection emulation.
877+
self.transport._eof = True
878+
# HACK: disable close => other end drops connection emulation.
879+
self.transport._closing = True
880+
self.receive_frame(self.close_frame)
881+
self.loop.run_until_complete(self.protocol.close(reason='close'))
882+
self.assertConnectionClosed(1000, 'close')

0 commit comments

Comments
 (0)