From d2f5ac5bb397668795979cd5550a92d72f9f49ab Mon Sep 17 00:00:00 2001 From: Nathaniel Braun Date: Wed, 13 Oct 2021 10:16:09 +0000 Subject: [PATCH] Fix `retry` attribute in UnixDomainSocketConnection --- redis/connection.py | 16 +++++++++++++++- tests/test_retry.py | 16 +++++++++------- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index de30f0c638..5528589026 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -880,7 +880,13 @@ def __init__(self, path='', db=0, username=None, password=None, encoding_errors='strict', decode_responses=False, retry_on_timeout=False, parser_class=DefaultParser, socket_read_size=65536, - health_check_interval=0, client_name=None): + health_check_interval=0, client_name=None, + retry=None): + """ + Initialize a new UnixDomainSocketConnection. + To specify a retry policy, first set `retry_on_timeout` to `True` + then set `retry` to a valid `Retry` object + """ self.pid = os.getpid() self.path = path self.db = db @@ -889,6 +895,14 @@ def __init__(self, path='', db=0, username=None, password=None, self.password = password self.socket_timeout = socket_timeout self.retry_on_timeout = retry_on_timeout + if retry_on_timeout: + if retry is None: + self.retry = Retry(NoBackoff(), 1) + else: + # deep-copy the Retry object as it is mutable + self.retry = copy.deepcopy(retry) + else: + self.retry = Retry(NoBackoff(), 0) self.health_check_interval = health_check_interval self.next_health_check = 0 self.encoder = Encoder(encoding, encoding_errors, decode_responses) diff --git a/tests/test_retry.py b/tests/test_retry.py index 24d9683f17..535485acae 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -2,7 +2,7 @@ import pytest from redis.exceptions import ConnectionError -from redis.connection import Connection +from redis.connection import Connection, UnixDomainSocketConnection from redis.retry import Retry @@ -20,20 +20,22 @@ def compute(self, failures): class TestConnectionConstructorWithRetry: - "Test that the Connection constructor properly handles Retry objects" + "Test that the Connection constructors properly handles Retry objects" @pytest.mark.parametrize("retry_on_timeout", [False, True]) - def test_retry_on_timeout_boolean(self, retry_on_timeout): - c = Connection(retry_on_timeout=retry_on_timeout) + @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection]) + def test_retry_on_timeout_boolean(self, Class, retry_on_timeout): + c = Class(retry_on_timeout=retry_on_timeout) assert c.retry_on_timeout == retry_on_timeout assert isinstance(c.retry, Retry) assert c.retry._retries == (1 if retry_on_timeout else 0) @pytest.mark.parametrize("retries", range(10)) - def test_retry_on_timeout_retry(self, retries): + @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection]) + def test_retry_on_timeout_retry(self, Class, retries): retry_on_timeout = retries > 0 - c = Connection(retry_on_timeout=retry_on_timeout, - retry=Retry(NoBackoff(), retries)) + c = Class(retry_on_timeout=retry_on_timeout, + retry=Retry(NoBackoff(), retries)) assert c.retry_on_timeout == retry_on_timeout assert isinstance(c.retry, Retry) assert c.retry._retries == retries