From a96a38a0bb5aa05f22ad6fa3a3f5235e70b46ee3 Mon Sep 17 00:00:00 2001 From: dvora-h <67596500+dvora-h@users.noreply.github.com> Date: Mon, 24 Apr 2023 15:49:27 +0300 Subject: [PATCH] Add support for PubSub with RESP3 parser (#2721) * add resp3 pubsub * linters * _set_info_logger func * async pubsun * docstring --- redis/asyncio/client.py | 20 ++++++-- redis/asyncio/connection.py | 16 +++++- redis/client.py | 16 ++++-- redis/connection.py | 12 +++-- redis/parsers/resp3.py | 81 ++++++++++++++++++++++++++++--- redis/utils.py | 14 ++++++ tests/test_asyncio/test_pubsub.py | 31 ++++++++++-- tests/test_pubsub.py | 37 +++++++++++--- 8 files changed, 197 insertions(+), 30 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index ffd68c14d0..5ef1f3292e 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -57,7 +57,7 @@ WatchError, ) from redis.typing import ChannelT, EncodableT, KeyT -from redis.utils import safe_str, str_if_bytes +from redis.utils import HIREDIS_AVAILABLE, _set_info_logger, safe_str, str_if_bytes PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]] _KeyT = TypeVar("_KeyT", bound=KeyT) @@ -658,6 +658,7 @@ def __init__( shard_hint: Optional[str] = None, ignore_subscribe_messages: bool = False, encoder=None, + push_handler_func: Optional[Callable] = None, ): self.connection_pool = connection_pool self.shard_hint = shard_hint @@ -666,6 +667,7 @@ def __init__( # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. self.encoder = encoder + self.push_handler_func = push_handler_func if self.encoder is None: self.encoder = self.connection_pool.get_encoder() if self.encoder.decode_responses: @@ -678,6 +680,8 @@ def __init__( b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE), ] + if self.push_handler_func is None: + _set_info_logger() self.channels = {} self.pending_unsubscribe_channels = set() self.patterns = {} @@ -757,6 +761,8 @@ async def connect(self): self.connection.register_connect_callback(self.on_connect) else: await self.connection.connect() + if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + self.connection._parser.set_push_handler(self.push_handler_func) async def _disconnect_raise_connect(self, conn, error): """ @@ -797,7 +803,9 @@ async def parse_response(self, block: bool = True, timeout: float = 0): await conn.connect() read_timeout = None if block else timeout - response = await self._execute(conn, conn.read_response, timeout=read_timeout) + response = await self._execute( + conn, conn.read_response, timeout=read_timeout, push_request=True + ) if conn.health_check_interval and response == self.health_check_response: # ignore the health check message as user might not expect it @@ -927,8 +935,8 @@ def ping(self, message=None) -> Awaitable: """ Ping the Redis server """ - message = "" if message is None else message - return self.execute_command("PING", message) + args = ["PING", message] if message is not None else ["PING"] + return self.execute_command(*args) async def handle_message(self, response, ignore_subscribe_messages=False): """ @@ -936,6 +944,10 @@ async def handle_message(self, response, ignore_subscribe_messages=False): with a message handler, the handler is invoked instead of a parsed message being returned. """ + if response is None: + return None + if isinstance(response, bytes): + response = [b"pong", response] if response != b"PONG" else [b"pong", b""] message_type = str_if_bytes(response[0]) if message_type == "pmessage": message = { diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index d9c95834d5..bc872ff358 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -485,15 +485,29 @@ async def read_response( self, disable_decoding: bool = False, timeout: Optional[float] = None, + push_request: Optional[bool] = False, ): """Read the response from a previously sent command""" read_timeout = timeout if timeout is not None else self.socket_timeout try: - if read_timeout is not None: + if ( + read_timeout is not None + and self.protocol == "3" + and not HIREDIS_AVAILABLE + ): + async with async_timeout(read_timeout): + response = await self._parser.read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + elif read_timeout is not None: async with async_timeout(read_timeout): response = await self._parser.read_response( disable_decoding=disable_decoding ) + elif self.protocol == "3" and not HIREDIS_AVAILABLE: + response = await self._parser.read_response( + disable_decoding=disable_decoding, push_request=push_request + ) else: response = await self._parser.read_response( disable_decoding=disable_decoding diff --git a/redis/client.py b/redis/client.py index 15dddc9bd7..71048f548f 100755 --- a/redis/client.py +++ b/redis/client.py @@ -27,7 +27,7 @@ ) from redis.lock import Lock from redis.retry import Retry -from redis.utils import safe_str, str_if_bytes +from redis.utils import HIREDIS_AVAILABLE, _set_info_logger, safe_str, str_if_bytes SYM_EMPTY = b"" EMPTY_RESPONSE = "EMPTY_RESPONSE" @@ -1429,6 +1429,7 @@ def __init__( shard_hint=None, ignore_subscribe_messages=False, encoder=None, + push_handler_func=None, ): self.connection_pool = connection_pool self.shard_hint = shard_hint @@ -1438,6 +1439,7 @@ def __init__( # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. self.encoder = encoder + self.push_handler_func = push_handler_func if self.encoder is None: self.encoder = self.connection_pool.get_encoder() self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE) @@ -1445,6 +1447,8 @@ def __init__( self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE] else: self.health_check_response = [b"pong", self.health_check_response_b] + if self.push_handler_func is None: + _set_info_logger() self.reset() def __enter__(self): @@ -1515,6 +1519,8 @@ def execute_command(self, *args): # register a callback that re-subscribes to any channels we # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) + if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + self.connection._parser.set_push_handler(self.push_handler_func) connection = self.connection kwargs = {"check_health": not self.subscribed} if not self.subscribed: @@ -1580,7 +1586,7 @@ def try_read(): return None else: conn.connect() - return conn.read_response() + return conn.read_response(push_request=True) response = self._execute(conn, try_read) @@ -1739,8 +1745,8 @@ def ping(self, message=None): """ Ping the Redis server """ - message = "" if message is None else message - return self.execute_command("PING", message) + args = ["PING", message] if message is not None else ["PING"] + return self.execute_command(*args) def handle_message(self, response, ignore_subscribe_messages=False): """ @@ -1750,6 +1756,8 @@ def handle_message(self, response, ignore_subscribe_messages=False): """ if response is None: return None + if isinstance(response, bytes): + response = [b"pong", response] if response != b"PONG" else [b"pong", b""] message_type = str_if_bytes(response[0]) if message_type == "pmessage": message = { diff --git a/redis/connection.py b/redis/connection.py index 85509f7ef7..19c80e08f5 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -406,13 +406,18 @@ def can_read(self, timeout=0): self.disconnect() raise ConnectionError(f"Error while reading from {host_error}: {e.args}") - def read_response(self, disable_decoding=False): + def read_response(self, disable_decoding=False, push_request=False): """Read the response from a previously sent command""" host_error = self._host_error() try: - response = self._parser.read_response(disable_decoding=disable_decoding) + if self.protocol == "3" and not HIREDIS_AVAILABLE: + response = self._parser.read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + response = self._parser.read_response(disable_decoding=disable_decoding) except socket.timeout: self.disconnect() raise TimeoutError(f"Timeout reading from {host_error}") @@ -705,8 +710,9 @@ def _connect(self): class UnixDomainSocketConnection(AbstractConnection): "Manages UDS communication to and from a Redis server" - def __init__(self, path="", **kwargs): + def __init__(self, path="", socket_timeout=None, **kwargs): self.path = path + self.socket_timeout = socket_timeout super().__init__(**kwargs) def repr_pieces(self): diff --git a/redis/parsers/resp3.py b/redis/parsers/resp3.py index 2753d39f1a..93fb6ff554 100644 --- a/redis/parsers/resp3.py +++ b/redis/parsers/resp3.py @@ -1,3 +1,4 @@ +from logging import getLogger from typing import Any, Union from ..exceptions import ConnectionError, InvalidResponse, ResponseError @@ -9,10 +10,21 @@ class _RESP3Parser(_RESPBase): """RESP3 protocol implementation""" - def read_response(self, disable_decoding=False): + def __init__(self, socket_read_size): + super().__init__(socket_read_size) + self.push_handler_func = self.handle_push_response + + def handle_push_response(self, response): + logger = getLogger("push_response") + logger.info("Push response: " + str(response)) + return response + + def read_response(self, disable_decoding=False, push_request=False): pos = self._buffer.get_pos() try: - result = self._read_response(disable_decoding=disable_decoding) + result = self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) except BaseException: self._buffer.rewind(pos) raise @@ -20,7 +32,7 @@ def read_response(self, disable_decoding=False): self._buffer.purge() return result - def _read_response(self, disable_decoding=False): + def _read_response(self, disable_decoding=False, push_request=False): raw = self._buffer.readline() if not raw: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) @@ -77,9 +89,26 @@ def _read_response(self, disable_decoding=False): response = { self._read_response( disable_decoding=disable_decoding - ): self._read_response(disable_decoding=disable_decoding) + ): self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) for _ in range(int(response)) } + # push response + elif byte == b">": + response = [ + self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + for _ in range(int(response)) + ] + res = self.push_handler_func(response) + if not push_request: + return self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + return res else: raise InvalidResponse(f"Protocol Error: {raw!r}") @@ -87,21 +116,37 @@ def _read_response(self, disable_decoding=False): response = self.encoder.decode(response) return response + def set_push_handler(self, push_handler_func): + self.push_handler_func = push_handler_func + class _AsyncRESP3Parser(_AsyncRESPBase): - async def read_response(self, disable_decoding: bool = False): + def __init__(self, socket_read_size): + super().__init__(socket_read_size) + self.push_handler_func = self.handle_push_response + + def handle_push_response(self, response): + logger = getLogger("push_response") + logger.info("Push response: " + str(response)) + return response + + async def read_response( + self, disable_decoding: bool = False, push_request: bool = False + ): if self._chunks: # augment parsing buffer with previously read data self._buffer += b"".join(self._chunks) self._chunks.clear() self._pos = 0 - response = await self._read_response(disable_decoding=disable_decoding) + response = await self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) # Successfully parsing a response allows us to clear our parsing buffer self._clear() return response async def _read_response( - self, disable_decoding: bool = False + self, disable_decoding: bool = False, push_request: bool = False ) -> Union[EncodableT, ResponseError, None]: if not self._stream or not self.encoder: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) @@ -166,9 +211,31 @@ async def _read_response( ) for _ in range(int(response)) } + # push response + elif byte == b">": + response = [ + ( + await self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + ) + for _ in range(int(response)) + ] + res = self.push_handler_func(response) + if not push_request: + return await ( + self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + ) + else: + return res else: raise InvalidResponse(f"Protocol Error: {raw!r}") if isinstance(response, bytes) and disable_decoding is False: response = self.encoder.decode(response) return response + + def set_push_handler(self, push_handler_func): + self.push_handler_func = push_handler_func diff --git a/redis/utils.py b/redis/utils.py index a6e620088b..148d15246b 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -1,3 +1,4 @@ +import logging from contextlib import contextmanager from functools import wraps from typing import Any, Dict, Mapping, Union @@ -117,3 +118,16 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +def _set_info_logger(): + """ + Set up a logger that log info logs to stdout. + (This is used by the default push response handler) + """ + if "push_response" not in logging.root.manager.loggerDict.keys(): + logger = logging.getLogger("push_response") + logger.setLevel(logging.INFO) + handler = logging.StreamHandler() + handler.setLevel(logging.INFO) + logger.addHandler(handler) diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 0c0b7dbca6..8cd5cf6fba 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -16,9 +16,11 @@ import redis.asyncio as redis from redis.exceptions import ConnectionError from redis.typing import EncodableT +from redis.utils import HIREDIS_AVAILABLE from tests.conftest import skip_if_server_version_lt from .compat import create_task, mock +from .conftest import get_protocol_version def with_timeout(t): @@ -420,6 +422,23 @@ async def test_get_message_without_subscribe(self, r: redis.Redis, pubsub): assert expect in info.exconly() +class TestPubSubRESP3Handler: + def my_handler(self, message): + self.message = ["my handler", message] + + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") + async def test_push_handler(self, r): + if get_protocol_version(r) in [2, "2", None]: + return + p = r.pubsub(push_handler_func=self.my_handler) + await p.subscribe("foo") + assert await wait_for_message(p) is None + assert self.message == ["my handler", [b"subscribe", b"foo", 1]] + assert await r.publish("foo", "test message") == 1 + assert await wait_for_message(p) is None + assert self.message == ["my handler", [b"message", b"foo", b"test message"]] + + @pytest.mark.onlynoncluster class TestPubSubAutoDecoding: """These tests only validate that we get unicode values back""" @@ -995,13 +1014,15 @@ async def get_msg(): assert msg is not None # timeout waiting for another message which never arrives assert pubsub.connection.is_connected - with patch("redis.parsers._AsyncRESP2Parser.read_response") as mock1: + with patch("redis.parsers._AsyncRESP2Parser.read_response") as mock1, patch( + "redis.parsers._AsyncHiredisParser.read_response" + ) as mock2, patch("redis.parsers._AsyncRESP3Parser.read_response") as mock3: mock1.side_effect = BaseException("boom") - with patch("redis.parsers._AsyncHiredisParser.read_response") as mock2: - mock2.side_effect = BaseException("boom") + mock2.side_effect = BaseException("boom") + mock3.side_effect = BaseException("boom") - with pytest.raises(BaseException): - await get_msg() + with pytest.raises(BaseException): + await get_msg() # the timeout on the read should not cause disconnect assert pubsub.connection.is_connected diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 48c0f3ac47..e1e4311511 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -10,8 +10,14 @@ import redis from redis.exceptions import ConnectionError +from redis.utils import HIREDIS_AVAILABLE -from .conftest import _get_client, skip_if_redis_enterprise, skip_if_server_version_lt +from .conftest import ( + _get_client, + is_resp2_connection, + skip_if_redis_enterprise, + skip_if_server_version_lt, +) def wait_for_message(pubsub, timeout=0.5, ignore_subscribe_messages=False): @@ -352,6 +358,23 @@ def test_unicode_pattern_message_handler(self, r): ) +class TestPubSubRESP3Handler: + def my_handler(self, message): + self.message = ["my handler", message] + + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") + def test_push_handler(self, r): + if is_resp2_connection(r): + return + p = r.pubsub(push_handler_func=self.my_handler) + p.subscribe("foo") + assert wait_for_message(p) is None + assert self.message == ["my handler", [b"subscribe", b"foo", 1]] + assert r.publish("foo", "test message") == 1 + assert wait_for_message(p) is None + assert self.message == ["my handler", [b"message", b"foo", b"test message"]] + + class TestPubSubAutoDecoding: "These tests only validate that we get unicode values back" @@ -767,13 +790,15 @@ def get_msg(): assert msg is not None # timeout waiting for another message which never arrives assert is_connected() - with patch("redis.parsers._RESP2Parser.read_response") as mock1: + with patch("redis.parsers._RESP2Parser.read_response") as mock1, patch( + "redis.parsers._HiredisParser.read_response" + ) as mock2, patch("redis.parsers._RESP3Parser.read_response") as mock3: mock1.side_effect = BaseException("boom") - with patch("redis.parsers._HiredisParser.read_response") as mock2: - mock2.side_effect = BaseException("boom") + mock2.side_effect = BaseException("boom") + mock3.side_effect = BaseException("boom") - with pytest.raises(BaseException): - get_msg() + with pytest.raises(BaseException): + get_msg() # the timeout on the read should not cause disconnect assert is_connected()