Skip to content

Commit

Permalink
Add support for CLIENT SETINFO (#2857)
Browse files Browse the repository at this point in the history
Co-authored-by: Kristján Valur Jónsson <sweskman@gmail.com>
Co-authored-by: Chayim <chayim@users.noreply.github.com>
Co-authored-by: Chayim I. Kirshen <c@kirshen.com>
  • Loading branch information
4 people committed Aug 9, 2023
1 parent d5c2d1d commit f121cf2
Show file tree
Hide file tree
Showing 18 changed files with 181 additions and 25 deletions.
9 changes: 7 additions & 2 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@ exclude =
whitelist.py,
tasks.py
ignore =
E126
E203
F405
N801
N802
N803
N806
N815
W503
E203
E126
4 changes: 2 additions & 2 deletions redis/_parsers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,8 +629,7 @@ def parse_client_info(value):
"key1=value1 key2=value2 key3=value3"
"""
client_info = {}
infos = str_if_bytes(value).split(" ")
for info in infos:
for info in str_if_bytes(value).strip().split():
key, value = info.split("=")
client_info[key] = value

Expand Down Expand Up @@ -700,6 +699,7 @@ def string_keys_to_dict(key_string, callback):
"CLIENT KILL": parse_client_kill,
"CLIENT LIST": parse_client_list,
"CLIENT PAUSE": bool_ok,
"CLIENT SETINFO": bool_ok,
"CLIENT SETNAME": bool_ok,
"CLIENT UNBLOCK": bool,
"CLUSTER ADDSLOTS": bool_ok,
Expand Down
12 changes: 11 additions & 1 deletion redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,13 @@
WatchError,
)
from redis.typing import ChannelT, EncodableT, KeyT
from redis.utils import HIREDIS_AVAILABLE, _set_info_logger, safe_str, str_if_bytes
from redis.utils import (
HIREDIS_AVAILABLE,
_set_info_logger,
get_lib_version,
safe_str,
str_if_bytes,
)

PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
_KeyT = TypeVar("_KeyT", bound=KeyT)
Expand Down Expand Up @@ -190,6 +196,8 @@ def __init__(
single_connection_client: bool = False,
health_check_interval: int = 0,
client_name: Optional[str] = None,
lib_name: Optional[str] = "redis-py",
lib_version: Optional[str] = get_lib_version(),
username: Optional[str] = None,
retry: Optional[Retry] = None,
auto_close_connection_pool: bool = True,
Expand Down Expand Up @@ -232,6 +240,8 @@ def __init__(
"max_connections": max_connections,
"health_check_interval": health_check_interval,
"client_name": client_name,
"lib_name": lib_name,
"lib_version": lib_version,
"redis_connect_func": redis_connect_func,
"protocol": protocol,
}
Expand Down
6 changes: 5 additions & 1 deletion redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
TryAgainError,
)
from redis.typing import AnyKeyT, EncodableT, KeyT
from redis.utils import dict_merge, safe_str, str_if_bytes
from redis.utils import dict_merge, get_lib_version, safe_str, str_if_bytes

TargetNodesT = TypeVar(
"TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
Expand Down Expand Up @@ -237,6 +237,8 @@ def __init__(
username: Optional[str] = None,
password: Optional[str] = None,
client_name: Optional[str] = None,
lib_name: Optional[str] = "redis-py",
lib_version: Optional[str] = get_lib_version(),
# Encoding related kwargs
encoding: str = "utf-8",
encoding_errors: str = "strict",
Expand Down Expand Up @@ -288,6 +290,8 @@ def __init__(
"username": username,
"password": password,
"client_name": client_name,
"lib_name": lib_name,
"lib_version": lib_version,
# Encoding related kwargs
"encoding": encoding,
"encoding_errors": encoding_errors,
Expand Down
24 changes: 22 additions & 2 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
TimeoutError,
)
from redis.typing import EncodableT
from redis.utils import HIREDIS_AVAILABLE, str_if_bytes
from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes

from .._parsers import (
BaseParser,
Expand Down Expand Up @@ -101,6 +101,8 @@ class AbstractConnection:
"db",
"username",
"client_name",
"lib_name",
"lib_version",
"credential_provider",
"password",
"socket_timeout",
Expand Down Expand Up @@ -140,6 +142,8 @@ def __init__(
socket_read_size: int = 65536,
health_check_interval: float = 0,
client_name: Optional[str] = None,
lib_name: Optional[str] = "redis-py",
lib_version: Optional[str] = get_lib_version(),
username: Optional[str] = None,
retry: Optional[Retry] = None,
redis_connect_func: Optional[ConnectCallbackT] = None,
Expand All @@ -157,6 +161,8 @@ def __init__(
self.pid = os.getpid()
self.db = db
self.client_name = client_name
self.lib_name = lib_name
self.lib_version = lib_version
self.credential_provider = credential_provider
self.password = password
self.username = username
Expand Down Expand Up @@ -347,9 +353,23 @@ async def on_connect(self) -> None:
if str_if_bytes(await self.read_response()) != "OK":
raise ConnectionError("Error setting client name")

# if a database is specified, switch to it
# set the library name and version, pipeline for lower startup latency
if self.lib_name:
await self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name)
if self.lib_version:
await self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version)
# if a database is specified, switch to it. Also pipeline this
if self.db:
await self.send_command("SELECT", self.db)

# read responses from pipeline
for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
try:
await self.read_response()
except ResponseError:
pass

if self.db:
if str_if_bytes(await self.read_response()) != "OK":
raise ConnectionError("Invalid Database")

Expand Down
12 changes: 11 additions & 1 deletion redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@
)
from redis.lock import Lock
from redis.retry import Retry
from redis.utils import HIREDIS_AVAILABLE, _set_info_logger, safe_str, str_if_bytes
from redis.utils import (
HIREDIS_AVAILABLE,
_set_info_logger,
get_lib_version,
safe_str,
str_if_bytes,
)

SYM_EMPTY = b""
EMPTY_RESPONSE = "EMPTY_RESPONSE"
Expand Down Expand Up @@ -171,6 +177,8 @@ def __init__(
single_connection_client=False,
health_check_interval=0,
client_name=None,
lib_name="redis-py",
lib_version=get_lib_version(),
username=None,
retry=None,
redis_connect_func=None,
Expand Down Expand Up @@ -222,6 +230,8 @@ def __init__(
"max_connections": max_connections,
"health_check_interval": health_check_interval,
"client_name": client_name,
"lib_name": lib_name,
"lib_version": lib_version,
"redis_connect_func": redis_connect_func,
"credential_provider": credential_provider,
"protocol": protocol,
Expand Down
3 changes: 3 additions & 0 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ def parse_cluster_myshardid(resp, **options):
"encoding_errors",
"errors",
"host",
"lib_name",
"lib_version",
"max_connections",
"nodes_flag",
"redis_connect_func",
Expand Down Expand Up @@ -225,6 +227,7 @@ class AbstractRedisCluster:
"ACL WHOAMI",
"AUTH",
"CLIENT LIST",
"CLIENT SETINFO",
"CLIENT SETNAME",
"CLIENT GETNAME",
"CONFIG SET",
Expand Down
7 changes: 7 additions & 0 deletions redis/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,13 @@ def client_setname(self, name: str, **kwargs) -> ResponseT:
"""
return self.execute_command("CLIENT SETNAME", name, **kwargs)

def client_setinfo(self, attr: str, value: str, **kwargs) -> ResponseT:
"""
Sets the current connection library name or version
For mor information see https://redis.io/commands/client-setinfo
"""
return self.execute_command("CLIENT SETINFO", attr, value, **kwargs)

def client_unblock(
self, client_id: int, error: bool = False, **kwargs
) -> ResponseT:
Expand Down
20 changes: 20 additions & 0 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
HIREDIS_AVAILABLE,
HIREDIS_PACK_AVAILABLE,
SSL_AVAILABLE,
get_lib_version,
str_if_bytes,
)

Expand Down Expand Up @@ -140,6 +141,8 @@ def __init__(
socket_read_size=65536,
health_check_interval=0,
client_name=None,
lib_name="redis-py",
lib_version=get_lib_version(),
username=None,
retry=None,
redis_connect_func=None,
Expand All @@ -164,6 +167,8 @@ def __init__(
self.pid = os.getpid()
self.db = db
self.client_name = client_name
self.lib_name = lib_name
self.lib_version = lib_version
self.credential_provider = credential_provider
self.password = password
self.username = username
Expand Down Expand Up @@ -360,6 +365,21 @@ def on_connect(self):
if str_if_bytes(self.read_response()) != "OK":
raise ConnectionError("Error setting client name")

try:
# set the library name and version
if self.lib_name:
self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name)
self.read_response()
except ResponseError:
pass

try:
if self.lib_version:
self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version)
self.read_response()
except ResponseError:
pass

# if a database is specified, switch to it
if self.db:
self.send_command("SELECT", self.db)
Expand Down
14 changes: 14 additions & 0 deletions redis/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import sys
from contextlib import contextmanager
from functools import wraps
from typing import Any, Dict, Mapping, Union
Expand Down Expand Up @@ -27,6 +28,11 @@
except ImportError:
CRYPTOGRAPHY_AVAILABLE = False

if sys.version_info >= (3, 8):
from importlib import metadata
else:
import importlib_metadata as metadata


def from_url(url, **kwargs):
"""
Expand Down Expand Up @@ -131,3 +137,11 @@ def _set_info_logger():
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
logger.addHandler(handler)


def get_lib_version():
try:
libver = metadata.version("redis")
except metadata.PackageNotFoundError:
libver = "99.99.99"
return libver
5 changes: 2 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def pytest_sessionstart(session):
enterprise = info["enterprise"]
except redis.ConnectionError:
# provide optimistic defaults
info = {}
version = "10.0.0"
arch_bits = 64
cluster_enabled = False
Expand All @@ -145,9 +146,7 @@ def pytest_sessionstart(session):
# module info
try:
REDIS_INFO["modules"] = info["modules"]
except redis.exceptions.ConnectionError:
pass
except KeyError:
except (KeyError, redis.exceptions.ConnectionError):
pass

if cluster_enabled:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_asyncio/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,15 @@ async def __aexit__(self, exc_type, exc_inst, tb):

def asynccontextmanager(func):
return _asynccontextmanager(func)


# helpers to get the connection arguments for this run
@pytest.fixture()
def redis_url(request):
return request.config.getoption("--redis-url")


@pytest.fixture()
def connect_args(request):
url = request.config.getoption("--redis-url")
return parse_url(url)
2 changes: 1 addition & 1 deletion tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2294,7 +2294,7 @@ async def test_acl_log(
await user_client.hset("{cache}:0", "hkey", "hval")

assert isinstance(await r.acl_log(target_nodes=node), list)
assert len(await r.acl_log(target_nodes=node)) == 2
assert len(await r.acl_log(target_nodes=node)) == 3
assert len(await r.acl_log(count=1, target_nodes=node)) == 1
assert isinstance((await r.acl_log(target_nodes=node))[0], dict)
assert "client-info" in (await r.acl_log(count=1, target_nodes=node))[0]
Expand Down
22 changes: 21 additions & 1 deletion tests/test_asyncio/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ async def test_acl_log(self, r_teardown, create_redis):
await user_client.hset("cache:0", "hkey", "hval")

assert isinstance(await r.acl_log(), list)
assert len(await r.acl_log()) == 2
assert len(await r.acl_log()) == 3
assert len(await r.acl_log(count=1)) == 1
assert isinstance((await r.acl_log())[0], dict)
expected = (await r.acl_log(count=1))[0]
Expand Down Expand Up @@ -355,6 +355,26 @@ async def test_client_setname(self, r: redis.Redis):
r, await r.client_getname(), "redis_py_test", b"redis_py_test"
)

@skip_if_server_version_lt("7.2.0")
async def test_client_setinfo(self, r: redis.Redis):
await r.ping()
info = await r.client_info()
assert info["lib-name"] == "redis-py"
assert info["lib-ver"] == redis.__version__
assert await r.client_setinfo("lib-name", "test")
assert await r.client_setinfo("lib-ver", "123")
info = await r.client_info()
assert info["lib-name"] == "test"
assert info["lib-ver"] == "123"
r2 = redis.asyncio.Redis(lib_name="test2", lib_version="1234")
info = await r2.client_info()
assert info["lib-name"] == "test2"
assert info["lib-ver"] == "1234"
r3 = redis.asyncio.Redis(lib_name=None, lib_version=None)
info = await r3.client_info()
assert info["lib-name"] == ""
assert info["lib-ver"] == ""

@skip_if_server_version_lt("2.6.9")
@pytest.mark.onlynoncluster
async def test_client_kill(self, r: redis.Redis, r2):
Expand Down

0 comments on commit f121cf2

Please sign in to comment.