From 42c291873f19c6533618f9354c18f7196b21fa8a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Jun 2021 11:49:15 +0200 Subject: [PATCH] Add support for reconnecting automatically. Fix #414. --- docs/howto/logging.rst | 1 + docs/project/changelog.rst | 3 ++ docs/spelling_wordlist.txt | 2 + src/websockets/legacy/client.py | 69 +++++++++++++++++++++++--- src/websockets/legacy/server.py | 2 +- tests/legacy/test_client_server.py | 79 ++++++++++++++++++++++++++++++ 6 files changed, 149 insertions(+), 7 deletions(-) diff --git a/docs/howto/logging.rst b/docs/howto/logging.rst index f69ee47b9..824812959 100644 --- a/docs/howto/logging.rst +++ b/docs/howto/logging.rst @@ -210,6 +210,7 @@ Here's what websockets logs at each level. * Server starting and stopping * Server establishing and closing connections +* Client reconnecting automatically :attr:`~logging.DEBUG` ...................... diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 603466593..6fd9b0c3b 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -51,6 +51,9 @@ They may change at any time. * Added compatibility with Python 3.10. +* Added support for reconnecting automatically by using + :func:`~legacy.client.connect` as an asynchronous iterator. + * Added ``open_timeout`` to :func:`~legacy.client.connect`. * Improved logging. diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index b460ef033..8346acefa 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -6,6 +6,7 @@ autoscaler awaitable aymeric backend +backoff backpressure balancer balancers @@ -52,6 +53,7 @@ pong pongs proxying pythonic +reconnection redis retransmit runtime diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index df1a2f57c..20e9c1079 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -9,7 +9,18 @@ import logging import warnings from types import TracebackType -from typing import Any, Callable, Generator, List, Optional, Sequence, Tuple, Type, cast +from typing import ( + Any, + AsyncIterator, + Callable, + Generator, + List, + Optional, + Sequence, + Tuple, + Type, + cast, +) from ..datastructures import Headers, HeadersLike from ..exceptions import ( @@ -413,12 +424,23 @@ class Connect: Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which can then be used to send and receive messages. - :func:`connect` can also be used as a asynchronous context manager:: + :func:`connect` can be used as a asynchronous context manager:: async with connect(...) as websocket: ... - In that case, the connection is closed when exiting the context. + The connection is closed automatically when exiting the context. + + :func:`connect` can be used as an infinite asynchronous iterator to + reconnect automatically on errors:: + + async for websocket in connect(...): + ... + + You must catch all exceptions, or else you will exit the loop prematurely. + As above, connections are closed automatically. Connection attempts are + delayed with exponential backoff, starting at three seconds and + increasing up to one minute. :func:`connect` is a wrapper around the event loop's :meth:`~asyncio.loop.create_connection` method. Unknown keyword arguments @@ -577,6 +599,10 @@ def __init__( ) self.open_timeout = open_timeout + if logger is None: + logger = logging.getLogger("websockets.client") + self.logger = logger + # This is a coroutine function. self._create_connection = create_connection self._wsuri = wsuri @@ -615,7 +641,38 @@ def handle_redirect(self, uri: str) -> None: # Set the new WebSocket URI. This suffices for same-origin redirects. self._wsuri = new_wsuri - # async with connect(...) + # async for ... in connect(...): + + BACKOFF_MIN = 2.0 + BACKOFF_MAX = 60.0 + BACKOFF_FACTOR = 1.5 + + async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: + backoff_delay = self.BACKOFF_MIN + while True: + try: + async with self as protocol: + yield protocol + # Remove this branch when dropping support for Python < 3.8 + # because CancelledError no longer inherits Exception. + except asyncio.CancelledError: # pragma: no cover + raise + except Exception: + # Connection timed out - increase backoff delay + backoff_delay = backoff_delay * self.BACKOFF_FACTOR + backoff_delay = min(backoff_delay, self.BACKOFF_MAX) + self.logger.info( + "! connect failed; retrying in %d seconds", + int(backoff_delay), + exc_info=True, + ) + await asyncio.sleep(backoff_delay) + continue + else: + # Connection succeeded - reset backoff delay + backoff_delay = self.BACKOFF_MIN + + # async with connect(...) as ...: async def __aenter__(self) -> WebSocketClientProtocol: return await self @@ -628,7 +685,7 @@ async def __aexit__( ) -> None: await self.protocol.close() - # await connect(...) + # ... = await connect(...) def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: # Create a suitable iterator by calling __await__ on a coroutine. @@ -665,7 +722,7 @@ async def __await_impl__(self) -> WebSocketClientProtocol: else: raise SecurityError("too many redirects") - # yield from connect(...) + # ... = yield from connect(...) __iter__ = __await__ diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 1704ae083..5dd410246 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -904,7 +904,7 @@ class Serve: :exc:`~websockets.exceptions.ConnectionClosedOK` exception on their current or next interaction with the WebSocket connection. - :func:`serve` can also be used as an asynchronous context manager:: + :func:`serve` can be used as an asynchronous context manager:: stop = asyncio.Future() # set this future to exit the server diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 60c0a14ae..6f48742a4 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1474,6 +1474,85 @@ async def run_client(): self.assertEqual(messages, self.MESSAGES) +class ReconnectionTests(ClientServerTestsMixin, AsyncioTestCase): + async def echo_handler(ws, path): + async for msg in ws: + await ws.send(msg) + + service_available = True + + async def maybe_service_unavailable(path, headers): + if not ReconnectionTests.service_available: + return http.HTTPStatus.SERVICE_UNAVAILABLE, [], b"" + + async def disable_server(self, duration): + ReconnectionTests.service_available = False + await asyncio.sleep(duration) + ReconnectionTests.service_available = True + + @with_server(handler=echo_handler, process_request=maybe_service_unavailable) + def test_reconnect(self): + # Big, ugly integration test :-( + + async def run_client(): + iteration = 0 + connect_inst = connect(get_server_uri(self.server)) + connect_inst.BACKOFF_MIN = 10 * MS + connect_inst.BACKOFF_MAX = 200 * MS + async for ws in connect_inst: + await ws.send("spam") + msg = await ws.recv() + self.assertEqual(msg, "spam") + + iteration += 1 + if iteration == 1: + # Exit block normally. + pass + elif iteration == 2: + # Disable server for a little bit + asyncio.create_task(self.disable_server(70 * MS)) + await asyncio.sleep(0) + elif iteration == 3: + # Exit block after catching connection error. + server_ws = next(iter(self.server.websockets)) + await server_ws.close() + with self.assertRaises(ConnectionClosed): + await ws.recv() + else: + # Exit block with an exception. + raise Exception("BOOM!") + + with self.assertLogs("websockets", logging.INFO) as logs: + with self.assertRaisesRegex(Exception, "BOOM!"): + self.loop.run_until_complete(run_client()) + + self.assertEqual( + [record.getMessage() for record in logs.records], + [ + # Iteration 1 + "connection open", + "connection closed", + # Iteration 2 + "connection open", + "connection closed", + # Iteration 3 + "connection failed (503 Service Unavailable)", + "connection closed", + "! connect failed; retrying in 0 seconds", + "connection failed (503 Service Unavailable)", + "connection closed", + "! connect failed; retrying in 0 seconds", + "connection failed (503 Service Unavailable)", + "connection closed", + "! connect failed; retrying in 0 seconds", + "connection open", + "connection closed", + # Iteration 4 + "connection open", + ], + ) + + class LoggerTests(ClientServerTestsMixin, AsyncioTestCase): def test_logger_client(self): with self.assertLogs("test.server", logging.DEBUG) as server_logs: