From 40095e3c21ee6b8a74437b1cfb058436359d854e Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Sun, 23 Jan 2022 19:58:51 +0200 Subject: [PATCH 1/2] Added retry mechanism on socket timeouts when connecting to the server --- redis/connection.py | 6 ++++-- redis/retry.py | 6 +++++- tests/test_connection.py | 44 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 52 insertions(+), 4 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 5fdac54c00..508c1961e5 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -604,7 +604,9 @@ def connect(self): if self._sock: return try: - sock = self._connect() + sock = self.retry.call_with_retry( + lambda: self._connect(), lambda error: self.disconnect(error) + ) except socket.timeout: raise TimeoutError("Timeout connecting to server") except OSError as e: @@ -721,7 +723,7 @@ def on_connect(self): if str_if_bytes(self.read_response()) != "OK": raise ConnectionError("Invalid Database") - def disconnect(self): + def disconnect(self, *args): "Disconnects from the Redis server" self._parser.on_disconnect() if self._sock is None: diff --git a/redis/retry.py b/redis/retry.py index 6147fbd9f9..3dced35d24 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -1,3 +1,4 @@ +import socket from time import sleep from redis.exceptions import ConnectionError, TimeoutError @@ -7,7 +8,10 @@ class Retry: """Retry a specific number of times after a failure""" def __init__( - self, backoff, retries, supported_errors=(ConnectionError, TimeoutError) + self, + backoff, + retries, + supported_errors=(ConnectionError, TimeoutError, socket.timeout), ): """ Initialize a `Retry` object with a `Backoff` object diff --git a/tests/test_connection.py b/tests/test_connection.py index d94a815159..b2b6663fed 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,10 +1,14 @@ +import socket import types from unittest import mock +from unittest.mock import patch import pytest +from redis.backoff import NoBackoff from redis.connection import Connection -from redis.exceptions import InvalidResponse +from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError +from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE from .conftest import skip_if_server_version_lt @@ -74,3 +78,41 @@ def test_disconnect__close_OSError(self): mock_sock.shutdown.assert_called_once() mock_sock.close.assert_called_once() assert conn._sock is None + + def test_retry_connect_on_timeout_error(self): + """Test that the _connect function is retried in case of a timeout""" + conn = Connection(retry_on_timeout=True, retry=Retry(NoBackoff(), 3)) + origin_connect = conn._connect + conn._connect = mock.Mock() + + def mock_connect(): + # connect only on the last retry + if conn._connect.call_count <= 2: + raise socket.timeout + else: + return origin_connect() + + conn._connect.side_effect = mock_connect + conn.connect() + assert conn._connect.call_count == 3 + + def test_connect_without_retry_on_os_error(self): + """Test that the _connect function is not being retried in case of a OSError""" + with patch.object(Connection, "_connect") as _connect: + _connect.side_effect = OSError("") + conn = Connection(retry_on_timeout=True, retry=Retry(NoBackoff(), 2)) + with pytest.raises(ConnectionError): + conn.connect() + assert _connect.call_count == 1 + + def test_connect_timeout_error_without_retry(self): + """Test that the _connect function is not being retried if retry_on_timeout is + set to False (default value)""" + conn = Connection() + conn._connect = mock.Mock() + conn._connect.side_effect = socket.timeout + + with pytest.raises(TimeoutError) as e: + conn.connect() + assert conn._connect.call_count == 1 + assert str(e.value) == "Timeout connecting to server" From 8f6d4d4adbb62b5e967eb9e540b77983a1339f34 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Mon, 24 Jan 2022 10:11:13 +0200 Subject: [PATCH 2/2] Clear retry_on_error list at the end of each test to fix tests --- tests/test_connection.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index b2b6663fed..d9251c31dc 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -79,6 +79,9 @@ def test_disconnect__close_OSError(self): mock_sock.close.assert_called_once() assert conn._sock is None + def clear(self, conn): + conn.retry_on_error.clear() + def test_retry_connect_on_timeout_error(self): """Test that the _connect function is retried in case of a timeout""" conn = Connection(retry_on_timeout=True, retry=Retry(NoBackoff(), 3)) @@ -95,6 +98,7 @@ def mock_connect(): conn._connect.side_effect = mock_connect conn.connect() assert conn._connect.call_count == 3 + self.clear(conn) def test_connect_without_retry_on_os_error(self): """Test that the _connect function is not being retried in case of a OSError""" @@ -104,11 +108,12 @@ def test_connect_without_retry_on_os_error(self): with pytest.raises(ConnectionError): conn.connect() assert _connect.call_count == 1 + self.clear(conn) def test_connect_timeout_error_without_retry(self): """Test that the _connect function is not being retried if retry_on_timeout is - set to False (default value)""" - conn = Connection() + set to False""" + conn = Connection(retry_on_timeout=False) conn._connect = mock.Mock() conn._connect.side_effect = socket.timeout @@ -116,3 +121,4 @@ def test_connect_timeout_error_without_retry(self): conn.connect() assert conn._connect.call_count == 1 assert str(e.value) == "Timeout connecting to server" + self.clear(conn)