From 31a4cf407685d69a27feb0c164ccba501e683ce8 Mon Sep 17 00:00:00 2001 From: Elay Gelbart Date: Sat, 15 Jun 2024 18:28:00 +0300 Subject: [PATCH] refactor: Add and update type annotations for core Redis commands --- redis/commands/core.py | 940 ++++++++++++++++++++--------------------- redis/typing.py | 17 +- 2 files changed, 461 insertions(+), 496 deletions(-) diff --git a/redis/commands/core.py b/redis/commands/core.py index a56b3d2cb..710d1cd99 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -1,12 +1,10 @@ -# from __future__ import annotations - import datetime import hashlib import warnings from typing import ( TYPE_CHECKING, + Any, AsyncIterator, - Awaitable, Callable, Dict, Iterable, @@ -16,16 +14,18 @@ Mapping, Optional, Sequence, - Set, Tuple, Union, ) from redis.exceptions import ConnectionError, DataError, NoScriptError, RedisError from redis.typing import ( + OKT, AbsExpiryT, AnyKeyT, + ArrayResponseT, BitfieldOffsetT, + BulkStringResponseT, ChannelT, CommandsProtocol, ConsumerT, @@ -33,6 +33,7 @@ ExpiryT, FieldT, GroupT, + IntegerResponseT, KeysT, KeyT, PatternT, @@ -56,7 +57,9 @@ class ACLCommands(CommandsProtocol): see: https://redis.io/topics/acl """ - def acl_cat(self, category: Union[str, None] = None, **kwargs) -> ResponseT: + def acl_cat( + self, category: Union[str, None] = None, **kwargs + ) -> ResponseT[ArrayResponseT]: """ Returns a list of categories or commands within a category. @@ -66,10 +69,12 @@ def acl_cat(self, category: Union[str, None] = None, **kwargs) -> ResponseT: For more information see https://redis.io/commands/acl-cat """ - pieces: list[EncodableT] = [category] if category else [] + pieces: List[EncodableT] = [category] if category else [] return self.execute_command("ACL CAT", *pieces, **kwargs) - def acl_dryrun(self, username, *args, **kwargs): + def acl_dryrun( + self, username, *args, **kwargs + ) -> ResponseT[Union[BulkStringResponseT, OKT]]: """ Simulate the execution of a given command by a given ``username``. @@ -77,7 +82,7 @@ def acl_dryrun(self, username, *args, **kwargs): """ return self.execute_command("ACL DRYRUN", username, *args, **kwargs) - def acl_deluser(self, *username: str, **kwargs) -> ResponseT: + def acl_deluser(self, *username: str, **kwargs) -> ResponseT[IntegerResponseT]: """ Delete the ACL for the specified ``username``s @@ -85,7 +90,9 @@ def acl_deluser(self, *username: str, **kwargs) -> ResponseT: """ return self.execute_command("ACL DELUSER", *username, **kwargs) - def acl_genpass(self, bits: Union[int, None] = None, **kwargs) -> ResponseT: + def acl_genpass( + self, bits: Union[int, None] = None, **kwargs + ) -> ResponseT[BulkStringResponseT]: """Generate a random password value. If ``bits`` is supplied then use this number of bits, rounded to the next multiple of 4. @@ -104,7 +111,9 @@ def acl_genpass(self, bits: Union[int, None] = None, **kwargs) -> ResponseT: ) return self.execute_command("ACL GENPASS", *pieces, **kwargs) - def acl_getuser(self, username: str, **kwargs) -> ResponseT: + def acl_getuser( + self, username: str, **kwargs + ) -> ResponseT[Union[ArrayResponseT, None]]: """ Get the ACL details for the specified ``username``. @@ -114,7 +123,7 @@ def acl_getuser(self, username: str, **kwargs) -> ResponseT: """ return self.execute_command("ACL GETUSER", username, **kwargs) - def acl_help(self, **kwargs) -> ResponseT: + def acl_help(self, **kwargs) -> ResponseT[ArrayResponseT]: """The ACL HELP command returns helpful text describing the different subcommands. @@ -122,7 +131,7 @@ def acl_help(self, **kwargs) -> ResponseT: """ return self.execute_command("ACL HELP", **kwargs) - def acl_list(self, **kwargs) -> ResponseT: + def acl_list(self, **kwargs) -> ResponseT[ArrayResponseT]: """ Return a list of all ACLs on the server @@ -130,7 +139,9 @@ def acl_list(self, **kwargs) -> ResponseT: """ return self.execute_command("ACL LIST", **kwargs) - def acl_log(self, count: Union[int, None] = None, **kwargs) -> ResponseT: + def acl_log( + self, count: Union[int, None] = None, **kwargs + ) -> ResponseT[Union[ArrayResponseT, OKT]]: """ Get ACL logs as a list. :param int count: Get logs[0:count]. @@ -143,10 +154,9 @@ def acl_log(self, count: Union[int, None] = None, **kwargs) -> ResponseT: if not isinstance(count, int): raise DataError("ACL LOG count must be an integer") args.append(count) - return self.execute_command("ACL LOG", *args, **kwargs) - def acl_log_reset(self, **kwargs) -> ResponseT: + def acl_log_reset(self, **kwargs) -> ResponseT[Union[ArrayResponseT, OKT]]: """ Reset ACL logs. :rtype: Boolean. @@ -156,7 +166,7 @@ def acl_log_reset(self, **kwargs) -> ResponseT: args = [b"RESET"] return self.execute_command("ACL LOG", *args, **kwargs) - def acl_load(self, **kwargs) -> ResponseT: + def acl_load(self, **kwargs) -> ResponseT[OKT]: """ Load ACL rules from the configured ``aclfile``. @@ -167,7 +177,7 @@ def acl_load(self, **kwargs) -> ResponseT: """ return self.execute_command("ACL LOAD", **kwargs) - def acl_save(self, **kwargs) -> ResponseT: + def acl_save(self, **kwargs) -> ResponseT[OKT]: """ Save ACL rules to the configured ``aclfile``. @@ -195,7 +205,7 @@ def acl_setuser( reset_channels: bool = False, reset_passwords: bool = False, **kwargs, - ) -> ResponseT: + ) -> ResponseT[OKT]: """ Create or update an ACL user. @@ -254,29 +264,22 @@ def acl_setuser( """ encoder = self.get_encoder() pieces: List[EncodableT] = [username] - if reset: pieces.append(b"reset") - if reset_keys: pieces.append(b"resetkeys") - if reset_channels: pieces.append(b"resetchannels") - if reset_passwords: pieces.append(b"resetpass") - if enabled: pieces.append(b"on") else: pieces.append(b"off") - if (passwords or hashed_passwords) and nopass: raise DataError( "Cannot set 'nopass' and supply 'passwords' or 'hashed_passwords'" ) - if passwords: # as most users will have only one password, allow remove_passwords # to be specified as a simple string or a list @@ -292,7 +295,6 @@ def acl_setuser( f"Password {i} must be prefixed with a " f'"+" to add or a "-" to remove' ) - if hashed_passwords: # as most users will have only one password, allow remove_passwords # to be specified as a simple string or a list @@ -308,10 +310,8 @@ def acl_setuser( f"Hashed password {i} must be prefixed with a " f'"+" to add or a "-" to remove' ) - if nopass: pieces.append(b"nopass") - if categories: for category in categories: category = encoder.encode(category) @@ -338,19 +338,16 @@ def acl_setuser( 'must be prefixed with "+" or "-"' ) pieces.append(cmd) - if keys: for key in keys: key = encoder.encode(key) if not key.startswith(b"%") and not key.startswith(b"~"): key = b"~%s" % key pieces.append(key) - if channels: for channel in channels: channel = encoder.encode(channel) pieces.append(b"&%s" % channel) - if selectors: for cmd, key in selectors: cmd = encoder.encode(cmd) @@ -359,23 +356,20 @@ def acl_setuser( f'Command "{encoder.decode(cmd, force=True)}" ' 'must be prefixed with "+" or "-"' ) - key = encoder.encode(key) if not key.startswith(b"%") and not key.startswith(b"~"): key = b"~%s" % key - pieces.append(b"(%s %s)" % (cmd, key)) - return self.execute_command("ACL SETUSER", *pieces, **kwargs) - def acl_users(self, **kwargs) -> ResponseT: + def acl_users(self, **kwargs) -> ResponseT[ArrayResponseT]: """Returns a list of all registered users on the server. For more information see https://redis.io/commands/acl-users """ return self.execute_command("ACL USERS", **kwargs) - def acl_whoami(self, **kwargs) -> ResponseT: + def acl_whoami(self, **kwargs) -> ResponseT[BulkStringResponseT]: """Get the username for the current connection For more information see https://redis.io/commands/acl-whoami @@ -391,7 +385,9 @@ class ManagementCommands(CommandsProtocol): Redis management commands """ - def auth(self, password: str, username: Optional[str] = None, **kwargs): + def auth( + self, password: str, username: Optional[str] = None, **kwargs + ) -> ResponseT[OKT]: """ Authenticates the user. If you do not pass username, Redis will try to authenticate for the "default" user. If you do pass username, it will @@ -404,14 +400,14 @@ def auth(self, password: str, username: Optional[str] = None, **kwargs): pieces.append(password) return self.execute_command("AUTH", *pieces, **kwargs) - def bgrewriteaof(self, **kwargs): + def bgrewriteaof(self, **kwargs) -> ResponseT[str]: """Tell the Redis server to rewrite the AOF file from data in memory. For more information see https://redis.io/commands/bgrewriteaof """ return self.execute_command("BGREWRITEAOF", **kwargs) - def bgsave(self, schedule: bool = True, **kwargs) -> ResponseT: + def bgsave(self, schedule: bool = True, **kwargs) -> ResponseT[str]: """ Tell the Redis server to save its data to disk. Unlike save(), this method is asynchronous and returns immediately. @@ -423,7 +419,7 @@ def bgsave(self, schedule: bool = True, **kwargs) -> ResponseT: pieces.append("SCHEDULE") return self.execute_command("BGSAVE", *pieces, **kwargs) - def role(self) -> ResponseT: + def role(self) -> ResponseT[ArrayResponseT]: """ Provide information on the role of a Redis instance in the context of replication, by returning if the instance @@ -433,7 +429,9 @@ def role(self) -> ResponseT: """ return self.execute_command("ROLE") - def client_kill(self, address: str, **kwargs) -> ResponseT: + def client_kill( + self, address: str, **kwargs + ) -> ResponseT[Union[IntegerResponseT, OKT]]: """Disconnects the client at ``address`` (ip:port) For more information see https://redis.io/commands/client-kill @@ -447,10 +445,10 @@ def client_kill_filter( addr: Union[str, None] = None, skipme: Union[bool, None] = None, laddr: Union[bool, None] = None, - user: str = None, + user: Optional[str] = None, maxage: Union[int, None] = None, **kwargs, - ) -> ResponseT: + ) -> ResponseT[Union[IntegerResponseT, OKT]]: """ Disconnects client(s) using a variety of filter options :param _id: Kills a client by its unique ID field @@ -494,7 +492,7 @@ def client_kill_filter( ) return self.execute_command("CLIENT KILL", *args, **kwargs) - def client_info(self, **kwargs) -> ResponseT: + def client_info(self, **kwargs) -> ResponseT[BulkStringResponseT]: """ Returns information and statistics about the current client connection. @@ -505,7 +503,7 @@ def client_info(self, **kwargs) -> ResponseT: def client_list( self, _type: Union[str, None] = None, client_id: List[EncodableT] = [], **kwargs - ) -> ResponseT: + ) -> ResponseT[BulkStringResponseT]: """ Returns a list of currently connected clients. If type of client specified, only that type will be returned. @@ -530,7 +528,7 @@ def client_list( args.append(" ".join(client_id)) return self.execute_command("CLIENT LIST", *args, **kwargs) - def client_getname(self, **kwargs) -> ResponseT: + def client_getname(self, **kwargs) -> ResponseT[Union[BulkStringResponseT, None]]: """ Returns the current connection name @@ -538,7 +536,7 @@ def client_getname(self, **kwargs) -> ResponseT: """ return self.execute_command("CLIENT GETNAME", **kwargs) - def client_getredir(self, **kwargs) -> ResponseT: + def client_getredir(self, **kwargs) -> ResponseT[IntegerResponseT]: """ Returns the ID (an integer) of the client to whom we are redirecting tracking notifications. @@ -549,7 +547,7 @@ def client_getredir(self, **kwargs) -> ResponseT: def client_reply( self, reply: Union[Literal["ON"], Literal["OFF"], Literal["SKIP"]], **kwargs - ) -> ResponseT: + ) -> ResponseT[OKT]: """ Enable and disable redis server replies. @@ -571,7 +569,7 @@ def client_reply( raise DataError(f"CLIENT REPLY must be one of {replies!r}") return self.execute_command("CLIENT REPLY", reply, **kwargs) - def client_id(self, **kwargs) -> ResponseT: + def client_id(self, **kwargs) -> ResponseT[IntegerResponseT]: """ Returns the current connection id @@ -587,7 +585,7 @@ def client_tracking_on( optin: bool = False, optout: bool = False, noloop: bool = False, - ) -> ResponseT: + ) -> ResponseT[OKT]: """ Turn on the tracking mode. For more information about the options look at client_tracking func. @@ -606,7 +604,7 @@ def client_tracking_off( optin: bool = False, optout: bool = False, noloop: bool = False, - ) -> ResponseT: + ) -> ResponseT[OKT]: """ Turn off the tracking mode. For more information about the options look at client_tracking func. @@ -627,7 +625,7 @@ def client_tracking( optout: bool = False, noloop: bool = False, **kwargs, - ) -> ResponseT: + ) -> ResponseT[OKT]: """ Enables the tracking feature of the Redis server, that is used for server assisted client side caching. @@ -657,10 +655,8 @@ def client_tracking( See https://redis.io/commands/client-tracking """ - if len(prefix) != 0 and bcast is False: raise DataError("Prefix can only be used with bcast") - pieces = ["ON"] if on else ["OFF"] if clientid is not None: pieces.extend(["REDIRECT", clientid]) @@ -674,10 +670,9 @@ def client_tracking( pieces.append("OPTOUT") if noloop: pieces.append("NOLOOP") - return self.execute_command("CLIENT TRACKING", *pieces) - def client_trackinginfo(self, **kwargs) -> ResponseT: + def client_trackinginfo(self, **kwargs) -> ResponseT[ArrayResponseT]: """ Returns the information about the current client connection's use of the server assisted client side cache. @@ -686,7 +681,7 @@ def client_trackinginfo(self, **kwargs) -> ResponseT: """ return self.execute_command("CLIENT TRACKINGINFO", **kwargs) - def client_setname(self, name: str, **kwargs) -> ResponseT: + def client_setname(self, name: str, **kwargs) -> ResponseT[OKT]: """ Sets the current connection name @@ -700,7 +695,7 @@ def client_setname(self, name: str, **kwargs) -> ResponseT: """ return self.execute_command("CLIENT SETNAME", name, **kwargs) - def client_setinfo(self, attr: str, value: str, **kwargs) -> ResponseT: + def client_setinfo(self, attr: str, value: str, **kwargs) -> ResponseT[OKT]: """ Sets the current connection library name or version For mor information see https://redis.io/commands/client-setinfo @@ -709,7 +704,7 @@ def client_setinfo(self, attr: str, value: str, **kwargs) -> ResponseT: def client_unblock( self, client_id: int, error: bool = False, **kwargs - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Unblocks a connection by its client id. If ``error`` is True, unblocks the client with a special error message. @@ -723,7 +718,7 @@ def client_unblock( args.append(b"ERROR") return self.execute_command(*args, **kwargs) - def client_pause(self, timeout: int, all: bool = True, **kwargs) -> ResponseT: + def client_pause(self, timeout: int, all: bool = True, **kwargs) -> ResponseT[OKT]: """ Suspend all the Redis clients for the specified amount of time. @@ -748,7 +743,7 @@ def client_pause(self, timeout: int, all: bool = True, **kwargs) -> ResponseT: args.append("WRITE") return self.execute_command(*args, **kwargs) - def client_unpause(self, **kwargs) -> ResponseT: + def client_unpause(self, **kwargs) -> ResponseT[OKT]: """ Unpause all redis clients @@ -756,7 +751,7 @@ def client_unpause(self, **kwargs) -> ResponseT: """ return self.execute_command("CLIENT UNPAUSE", **kwargs) - def client_no_evict(self, mode: str) -> Union[Awaitable[str], str]: + def client_no_evict(self, mode: str) -> ResponseT[OKT]: """ Sets the client eviction mode for the current connection. @@ -764,7 +759,7 @@ def client_no_evict(self, mode: str) -> Union[Awaitable[str], str]: """ return self.execute_command("CLIENT NO-EVICT", mode) - def client_no_touch(self, mode: str) -> Union[Awaitable[str], str]: + def client_no_touch(self, mode: str) -> ResponseT[OKT]: """ # The command controls whether commands sent by the client will alter # the LRU/LFU of the keys they access. @@ -775,7 +770,7 @@ def client_no_touch(self, mode: str) -> Union[Awaitable[str], str]: """ return self.execute_command("CLIENT NO-TOUCH", mode) - def command(self, **kwargs): + def command(self, **kwargs) -> ResponseT[ArrayResponseT]: """ Returns dict reply of details about all Redis commands. @@ -788,7 +783,7 @@ def command_info(self, **kwargs) -> None: "COMMAND INFO is intentionally not implemented in the client." ) - def command_count(self, **kwargs) -> ResponseT: + def command_count(self, **kwargs) -> ResponseT[IntegerResponseT]: return self.execute_command("COMMAND COUNT", **kwargs) def command_list( @@ -796,7 +791,7 @@ def command_list( module: Optional[str] = None, category: Optional[str] = None, pattern: Optional[str] = None, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Return an array of the server's command names. You can use one of the following filters: @@ -813,13 +808,11 @@ def command_list( pieces.extend(["ACLCAT", category]) if pattern is not None: pieces.extend(["PATTERN", pattern]) - if pieces: pieces.insert(0, "FILTERBY") - return self.execute_command("COMMAND LIST", *pieces) - def command_getkeysandflags(self, *args: List[str]) -> List[Union[str, List[str]]]: + def command_getkeysandflags(self, *args: List[str]) -> ResponseT[ArrayResponseT]: """ Returns array of keys from a full Redis command and their usage flags. @@ -827,7 +820,7 @@ def command_getkeysandflags(self, *args: List[str]) -> List[Union[str, List[str] """ return self.execute_command("COMMAND GETKEYSANDFLAGS", *args) - def command_docs(self, *args): + def command_docs(self, *args) -> None: """ This function throws a NotImplementedError since it is intentionally not supported. @@ -838,7 +831,7 @@ def command_docs(self, *args): def config_get( self, pattern: PatternT = "*", *args: List[PatternT], **kwargs - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Return a dictionary of configuration based on the ``pattern`` @@ -852,14 +845,14 @@ def config_set( value: EncodableT, *args: List[Union[KeyT, EncodableT]], **kwargs, - ) -> ResponseT: + ) -> ResponseT[OKT]: """Set config item ``name`` with ``value`` For more information see https://redis.io/commands/config-set """ return self.execute_command("CONFIG SET", name, value, *args, **kwargs) - def config_resetstat(self, **kwargs) -> ResponseT: + def config_resetstat(self, **kwargs) -> ResponseT[OKT]: """ Reset runtime statistics @@ -867,7 +860,7 @@ def config_resetstat(self, **kwargs) -> ResponseT: """ return self.execute_command("CONFIG RESETSTAT", **kwargs) - def config_rewrite(self, **kwargs) -> ResponseT: + def config_rewrite(self, **kwargs) -> ResponseT[OKT]: """ Rewrite config file with the minimal change to reflect running config. @@ -875,7 +868,7 @@ def config_rewrite(self, **kwargs) -> ResponseT: """ return self.execute_command("CONFIG REWRITE", **kwargs) - def dbsize(self, **kwargs) -> ResponseT: + def dbsize(self, **kwargs) -> ResponseT[IntegerResponseT]: """ Returns the number of keys in the current database @@ -883,7 +876,7 @@ def dbsize(self, **kwargs) -> ResponseT: """ return self.execute_command("DBSIZE", **kwargs) - def debug_object(self, key: KeyT, **kwargs) -> ResponseT: + def debug_object(self, key: KeyT, **kwargs) -> ResponseT[Any]: """ Returns version specific meta information about a given key @@ -900,7 +893,7 @@ def debug_segfault(self, **kwargs) -> None: """ ) - def echo(self, value: EncodableT, **kwargs) -> ResponseT: + def echo(self, value: EncodableT, **kwargs) -> ResponseT[BulkStringResponseT]: """ Echo the string back from the server @@ -908,7 +901,7 @@ def echo(self, value: EncodableT, **kwargs) -> ResponseT: """ return self.execute_command("ECHO", value, **kwargs) - def flushall(self, asynchronous: bool = False, **kwargs) -> ResponseT: + def flushall(self, asynchronous: bool = False, **kwargs) -> ResponseT[OKT]: """ Delete all keys in all databases on the current host. @@ -922,7 +915,7 @@ def flushall(self, asynchronous: bool = False, **kwargs) -> ResponseT: args.append(b"ASYNC") return self.execute_command("FLUSHALL", *args, **kwargs) - def flushdb(self, asynchronous: bool = False, **kwargs) -> ResponseT: + def flushdb(self, asynchronous: bool = False, **kwargs) -> ResponseT[OKT]: """ Delete all keys in the current database. @@ -936,7 +929,7 @@ def flushdb(self, asynchronous: bool = False, **kwargs) -> ResponseT: args.append(b"ASYNC") return self.execute_command("FLUSHDB", *args, **kwargs) - def sync(self) -> ResponseT: + def sync(self) -> ResponseT[Any]: """ Initiates a replication stream from the master. @@ -948,7 +941,7 @@ def sync(self) -> ResponseT: options[NEVER_DECODE] = [] return self.execute_command("SYNC", **options) - def psync(self, replicationid: str, offset: int): + def psync(self, replicationid: str, offset: int) -> ResponseT[Any]: """ Initiates a replication stream from the master. Newer version for `sync`. @@ -961,7 +954,7 @@ def psync(self, replicationid: str, offset: int): options[NEVER_DECODE] = [] return self.execute_command("PSYNC", replicationid, offset, **options) - def swapdb(self, first: int, second: int, **kwargs) -> ResponseT: + def swapdb(self, first: int, second: int, **kwargs) -> ResponseT[OKT]: """ Swap two databases @@ -969,7 +962,7 @@ def swapdb(self, first: int, second: int, **kwargs) -> ResponseT: """ return self.execute_command("SWAPDB", first, second, **kwargs) - def select(self, index: int, **kwargs) -> ResponseT: + def select(self, index: int, **kwargs) -> ResponseT[OKT]: """Select the Redis logical database at index. See: https://redis.io/commands/select @@ -978,7 +971,7 @@ def select(self, index: int, **kwargs) -> ResponseT: def info( self, section: Union[str, None] = None, *args: List[str], **kwargs - ) -> ResponseT: + ) -> ResponseT[BulkStringResponseT]: """ Returns a dictionary containing information about the Redis server @@ -995,7 +988,7 @@ def info( else: return self.execute_command("INFO", section, *args, **kwargs) - def lastsave(self, **kwargs) -> ResponseT: + def lastsave(self, **kwargs) -> ResponseT[IntegerResponseT]: """ Return a Python datetime object representing the last time the Redis database was saved to disk @@ -1004,7 +997,7 @@ def lastsave(self, **kwargs) -> ResponseT: """ return self.execute_command("LASTSAVE", **kwargs) - def latency_doctor(self): + def latency_doctor(self) -> None: """Raise a NotImplementedError, as the client will not support LATENCY DOCTOR. This funcion is best used within the redis-cli. @@ -1018,7 +1011,7 @@ def latency_doctor(self): """ ) - def latency_graph(self): + def latency_graph(self) -> None: """Raise a NotImplementedError, as the client will not support LATENCY GRAPH. This funcion is best used within the redis-cli. @@ -1032,7 +1025,9 @@ def latency_graph(self): """ ) - def lolwut(self, *version_numbers: Union[str, float], **kwargs) -> ResponseT: + def lolwut( + self, *version_numbers: Union[str, float], **kwargs + ) -> ResponseT[BulkStringResponseT]: """ Get the Redis version and a piece of generative computer art @@ -1043,7 +1038,7 @@ def lolwut(self, *version_numbers: Union[str, float], **kwargs) -> ResponseT: else: return self.execute_command("LOLWUT", **kwargs) - def reset(self) -> ResponseT: + def reset(self) -> ResponseT[str]: """Perform a full reset on the connection's server side contenxt. See: https://redis.io/commands/reset @@ -1061,7 +1056,7 @@ def migrate( replace: bool = False, auth: Union[str, None] = None, **kwargs, - ) -> ResponseT: + ) -> ResponseT[Union[str, OKT]]: """ Migrate 1 or more keys from the current Redis server to a different server specified by the ``host``, ``port`` and ``destination_db``. @@ -1098,7 +1093,7 @@ def migrate( "MIGRATE", host, port, "", destination_db, timeout, *pieces, **kwargs ) - def object(self, infotype: str, key: KeyT, **kwargs) -> ResponseT: + def object(self, infotype: str, key: KeyT, **kwargs) -> ResponseT[Any]: """ Return the encoding, idletime, or refcount about the key """ @@ -1124,7 +1119,7 @@ def memory_help(self, **kwargs) -> None: """ ) - def memory_stats(self, **kwargs) -> ResponseT: + def memory_stats(self, **kwargs) -> ResponseT[ArrayResponseT]: """ Return a dictionary of memory stats @@ -1132,7 +1127,7 @@ def memory_stats(self, **kwargs) -> ResponseT: """ return self.execute_command("MEMORY STATS", **kwargs) - def memory_malloc_stats(self, **kwargs) -> ResponseT: + def memory_malloc_stats(self, **kwargs) -> ResponseT[BulkStringResponseT]: """ Return an internal statistics report from the memory allocator. @@ -1142,7 +1137,7 @@ def memory_malloc_stats(self, **kwargs) -> ResponseT: def memory_usage( self, key: KeyT, samples: Union[int, None] = None, **kwargs - ) -> ResponseT: + ) -> ResponseT[Union[IntegerResponseT, None]]: """ Return the total memory usage for key, its value and associated administrative overheads. @@ -1158,7 +1153,7 @@ def memory_usage( args.extend([b"SAMPLES", samples]) return self.execute_command("MEMORY USAGE", key, *args, **kwargs) - def memory_purge(self, **kwargs) -> ResponseT: + def memory_purge(self, **kwargs) -> ResponseT[OKT]: """ Attempts to purge dirty pages for reclamation by allocator @@ -1166,7 +1161,7 @@ def memory_purge(self, **kwargs) -> ResponseT: """ return self.execute_command("MEMORY PURGE", **kwargs) - def latency_histogram(self, *args): + def latency_histogram(self, *args) -> None: """ This function throws a NotImplementedError since it is intentionally not supported. @@ -1175,7 +1170,7 @@ def latency_histogram(self, *args): "LATENCY HISTOGRAM is intentionally not implemented in the client." ) - def latency_history(self, event: str) -> ResponseT: + def latency_history(self, event: str) -> ResponseT[ArrayResponseT]: """ Returns the raw data of the ``event``'s latency spikes time series. @@ -1183,7 +1178,7 @@ def latency_history(self, event: str) -> ResponseT: """ return self.execute_command("LATENCY HISTORY", event) - def latency_latest(self) -> ResponseT: + def latency_latest(self) -> ResponseT[ArrayResponseT]: """ Reports the latest latency events logged. @@ -1191,7 +1186,7 @@ def latency_latest(self) -> ResponseT: """ return self.execute_command("LATENCY LATEST") - def latency_reset(self, *events: str) -> ResponseT: + def latency_reset(self, *events: str) -> ResponseT[IntegerResponseT]: """ Resets the latency spikes time series of all, or only some, events. @@ -1199,7 +1194,7 @@ def latency_reset(self, *events: str) -> ResponseT: """ return self.execute_command("LATENCY RESET", *events) - def ping(self, **kwargs) -> ResponseT: + def ping(self, **kwargs) -> ResponseT[Union[BulkStringResponseT, str]]: """ Ping the Redis server @@ -1207,7 +1202,7 @@ def ping(self, **kwargs) -> ResponseT: """ return self.execute_command("PING", **kwargs) - def quit(self, **kwargs) -> ResponseT: + def quit(self, **kwargs) -> ResponseT[OKT]: """ Ask the server to close the connection. @@ -1215,7 +1210,7 @@ def quit(self, **kwargs) -> ResponseT: """ return self.execute_command("QUIT", **kwargs) - def replicaof(self, *args, **kwargs) -> ResponseT: + def replicaof(self, *args, **kwargs) -> ResponseT[OKT]: """ Update the replication settings of a redis replica, on the fly. @@ -1228,7 +1223,7 @@ def replicaof(self, *args, **kwargs) -> ResponseT: """ return self.execute_command("REPLICAOF", *args, **kwargs) - def save(self, **kwargs) -> ResponseT: + def save(self, **kwargs) -> ResponseT[OKT]: """ Tell the Redis server to save its data to disk, blocking until the save is complete @@ -1276,12 +1271,12 @@ def shutdown( self.execute_command(*args, **kwargs) except ConnectionError: # a ConnectionError here is expected - return + return None raise RedisError("SHUTDOWN seems to have failed.") def slaveof( self, host: Union[str, None] = None, port: Union[int, None] = None, **kwargs - ) -> ResponseT: + ) -> ResponseT[OKT]: """ Set the server to be a replicated slave of the instance identified by the ``host`` and ``port``. If called without arguments, the @@ -1293,7 +1288,9 @@ def slaveof( return self.execute_command("SLAVEOF", b"NO", b"ONE", **kwargs) return self.execute_command("SLAVEOF", host, port, **kwargs) - def slowlog_get(self, num: Union[int, None] = None, **kwargs) -> ResponseT: + def slowlog_get( + self, num: Union[int, None] = None, **kwargs + ) -> ResponseT[ArrayResponseT]: """ Get the entries from the slowlog. If ``num`` is specified, get the most recent ``num`` items. @@ -1310,7 +1307,7 @@ def slowlog_get(self, num: Union[int, None] = None, **kwargs) -> ResponseT: kwargs[NEVER_DECODE] = [] return self.execute_command(*args, **kwargs) - def slowlog_len(self, **kwargs) -> ResponseT: + def slowlog_len(self, **kwargs) -> ResponseT[IntegerResponseT]: """ Get the number of items in the slowlog @@ -1318,7 +1315,7 @@ def slowlog_len(self, **kwargs) -> ResponseT: """ return self.execute_command("SLOWLOG LEN", **kwargs) - def slowlog_reset(self, **kwargs) -> ResponseT: + def slowlog_reset(self, **kwargs) -> ResponseT[OKT]: """ Remove all items in the slowlog @@ -1326,7 +1323,7 @@ def slowlog_reset(self, **kwargs) -> ResponseT: """ return self.execute_command("SLOWLOG RESET", **kwargs) - def time(self, **kwargs) -> ResponseT: + def time(self, **kwargs) -> ResponseT[ArrayResponseT]: """ Returns the server time as a 2-item tuple of ints: (seconds since epoch, microseconds into this second). @@ -1335,7 +1332,9 @@ def time(self, **kwargs) -> ResponseT: """ return self.execute_command("TIME", **kwargs) - def wait(self, num_replicas: int, timeout: int, **kwargs) -> ResponseT: + def wait( + self, num_replicas: int, timeout: int, **kwargs + ) -> ResponseT[IntegerResponseT]: """ Redis synchronous replication That returns the number of replicas that processed the query when @@ -1348,7 +1347,7 @@ def wait(self, num_replicas: int, timeout: int, **kwargs) -> ResponseT: def waitaof( self, num_local: int, num_replicas: int, timeout: int, **kwargs - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ This command blocks the current client until all previous write commands by that client are acknowledged as having been fsynced @@ -1361,7 +1360,7 @@ def waitaof( "WAITAOF", num_local, num_replicas, timeout, **kwargs ) - def hello(self): + def hello(self) -> None: """ This function throws a NotImplementedError since it is intentionally not supported. @@ -1370,7 +1369,7 @@ def hello(self): "HELLO is intentionally not implemented in the client." ) - def failover(self): + def failover(self) -> None: """ This function throws a NotImplementedError since it is intentionally not supported. @@ -1380,10 +1379,11 @@ def failover(self): ) -AsyncManagementCommands = ManagementCommands +# AsyncManagementCommands = ManagementCommands class AsyncManagementCommands(ManagementCommands): + async def command_info(self, **kwargs) -> None: return super().command_info(**kwargs) @@ -1430,7 +1430,7 @@ async def shutdown( await self.execute_command(*args, **kwargs) except ConnectionError: # a ConnectionError here is expected - return + return None raise RedisError("SHUTDOWN seems to have failed.") @@ -1449,7 +1449,7 @@ def __init__( self.key = key self._default_overflow = default_overflow # for typing purposes, run the following in constructor and in reset() - self.operations: list[tuple[EncodableT, ...]] = [] + self.operations: List[tuple[EncodableT, ...]] = [] self._last_overflow = "WRAP" self.reset() @@ -1496,7 +1496,6 @@ def incrby( """ if overflow is not None: self.overflow(overflow) - self.operations.append(("INCRBY", fmt, offset, increment)) return self @@ -1534,7 +1533,7 @@ def command(self): cmd.extend(ops) return cmd - def execute(self) -> ResponseT: + def execute(self) -> ResponseT[Any]: """ Execute the operation(s) in a single BITFIELD command. The return value is a list of values corresponding to each operation. If the client @@ -1551,7 +1550,7 @@ class BasicKeyCommands(CommandsProtocol): Redis basic key-based commands """ - def append(self, key: KeyT, value: EncodableT) -> ResponseT: + def append(self, key: KeyT, value: EncodableT) -> ResponseT[IntegerResponseT]: """ Appends the string ``value`` to the value at ``key``. If ``key`` doesn't already exist, create it with a value of ``value``. @@ -1567,7 +1566,7 @@ def bitcount( start: Union[int, None] = None, end: Union[int, None] = None, mode: Optional[str] = None, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Returns the count of set bits in the value of ``key``. Optional ``start`` and ``end`` parameters indicate which bytes to consider @@ -1602,8 +1601,8 @@ def bitfield_ro( key: KeyT, encoding: str, offset: BitfieldOffsetT, - items: Optional[list] = None, - ) -> ResponseT: + items: Optional[List] = None, + ) -> ResponseT[ArrayResponseT]: """ Return an array of the specified bitfield values where the first value is found using ``encoding`` and ``offset`` @@ -1614,13 +1613,14 @@ def bitfield_ro( For more information see https://redis.io/commands/bitfield_ro """ params = [key, "GET", encoding, offset] - items = items or [] for encoding, offset in items: params.extend(["GET", encoding, offset]) return self.execute_command("BITFIELD_RO", *params, keys=[key]) - def bitop(self, operation: str, dest: KeyT, *keys: KeyT) -> ResponseT: + def bitop( + self, operation: str, dest: KeyT, *keys: KeyT + ) -> ResponseT[IntegerResponseT]: """ Perform a bitwise operation using ``operation`` between ``keys`` and store the result in ``dest``. @@ -1636,7 +1636,7 @@ def bitpos( start: Union[int, None] = None, end: Union[int, None] = None, mode: Optional[str] = None, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Return the position of the first bit set to 1 or 0 in a string. ``start`` and ``end`` defines search range. The range is interpreted @@ -1648,14 +1648,12 @@ def bitpos( if bit not in (0, 1): raise DataError("bit must be 0 or 1") params = [key, bit] - - start is not None and params.append(start) - + if start is not None: + params.append(start) if start is not None and end is not None: params.append(end) elif start is None and end is not None: raise DataError("start argument is not set, when end is specified") - if mode is not None: params.append(mode) return self.execute_command("BITPOS", *params, keys=[key]) @@ -1666,7 +1664,7 @@ def copy( destination: str, destination_db: Union[str, None] = None, replace: bool = False, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Copy the value stored in the ``source`` key to the ``destination`` key. @@ -1686,7 +1684,7 @@ def copy( params.append("REPLACE") return self.execute_command("COPY", *params) - def decrby(self, name: KeyT, amount: int = 1) -> ResponseT: + def decrby(self, name: KeyT, amount: int = 1) -> ResponseT[IntegerResponseT]: """ Decrements the value of ``key`` by ``amount``. If no key exists, the value will be initialized as 0 - ``amount`` @@ -1697,16 +1695,16 @@ def decrby(self, name: KeyT, amount: int = 1) -> ResponseT: decr = decrby - def delete(self, *names: KeyT) -> ResponseT: + def delete(self, *names: KeyT) -> ResponseT[IntegerResponseT]: """ Delete one or more keys specified by ``names`` """ return self.execute_command("DEL", *names) - def __delitem__(self, name: KeyT): + def __delitem__(self, name: KeyT) -> None: self.delete(name) - def dump(self, name: KeyT) -> ResponseT: + def dump(self, name: KeyT) -> ResponseT[Union[BulkStringResponseT, None]]: """ Return a serialized version of the value stored at the specified key. If key does not exist a nil bulk reply is returned. @@ -1719,7 +1717,7 @@ def dump(self, name: KeyT) -> ResponseT: options[NEVER_DECODE] = [] return self.execute_command("DUMP", name, **options) - def exists(self, *names: KeyT) -> ResponseT: + def exists(self, *names: KeyT) -> ResponseT[IntegerResponseT]: """ Returns the number of ``names`` that exist @@ -1737,7 +1735,7 @@ def expire( xx: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Set an expire flag on key ``name`` for ``time`` seconds with given ``option``. ``time`` can be represented by an integer or a Python timedelta @@ -1753,7 +1751,6 @@ def expire( """ if isinstance(time, datetime.timedelta): time = int(time.total_seconds()) - exp_option = list() if nx: exp_option.append("NX") @@ -1763,7 +1760,6 @@ def expire( exp_option.append("GT") if lt: exp_option.append("LT") - return self.execute_command("EXPIRE", name, time, *exp_option) def expireat( @@ -1774,7 +1770,7 @@ def expireat( xx: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Set an expire flag on key ``name`` with given ``option``. ``when`` can be represented as an integer indicating unix time or a Python @@ -1790,7 +1786,6 @@ def expireat( """ if isinstance(when, datetime.datetime): when = int(when.timestamp()) - exp_option = list() if nx: exp_option.append("NX") @@ -1800,10 +1795,9 @@ def expireat( exp_option.append("GT") if lt: exp_option.append("LT") - return self.execute_command("EXPIREAT", name, when, *exp_option) - def expiretime(self, key: str) -> int: + def expiretime(self, key: str) -> ResponseT[IntegerResponseT]: """ Returns the absolute Unix timestamp (since January 1, 1970) in seconds at which the given key will expire. @@ -1812,7 +1806,7 @@ def expiretime(self, key: str) -> int: """ return self.execute_command("EXPIRETIME", key) - def get(self, name: KeyT) -> ResponseT: + def get(self, name: KeyT) -> ResponseT[Union[BulkStringResponseT, None]]: """ Return the value at key ``name``, or None if the key doesn't exist @@ -1820,7 +1814,7 @@ def get(self, name: KeyT) -> ResponseT: """ return self.execute_command("GET", name, keys=[name]) - def getdel(self, name: KeyT) -> ResponseT: + def getdel(self, name: KeyT) -> ResponseT[Union[BulkStringResponseT, None]]: """ Get the value at key ``name`` and delete the key. This command is similar to GET, except for the fact that it also deletes @@ -1839,7 +1833,7 @@ def getex( exat: Union[AbsExpiryT, None] = None, pxat: Union[AbsExpiryT, None] = None, persist: bool = False, - ) -> ResponseT: + ) -> ResponseT[BulkStringResponseT]: """ Get the value of key and optionally set its expiration. GETEX is similar to GET, but is a write command with @@ -1860,15 +1854,13 @@ def getex( For more information see https://redis.io/commands/getex """ - opset = {ex, px, exat, pxat} if len(opset) > 2 or len(opset) > 1 and persist: raise DataError( "``ex``, ``px``, ``exat``, ``pxat``, " "and ``persist`` are mutually exclusive." ) - - pieces: list[EncodableT] = [] + pieces: List[EncodableT] = [] # similar to set command if ex is not None: pieces.append("EX") @@ -1893,10 +1885,9 @@ def getex( pieces.append(pxat) if persist: pieces.append("PERSIST") - return self.execute_command("GETEX", name, *pieces) - def __getitem__(self, name: KeyT): + def __getitem__(self, name: KeyT) -> ResponseT[BulkStringResponseT]: """ Return the value at key ``name``, raises a KeyError if the key doesn't exist. @@ -1906,7 +1897,7 @@ def __getitem__(self, name: KeyT): return value raise KeyError(name) - def getbit(self, name: KeyT, offset: int) -> ResponseT: + def getbit(self, name: KeyT, offset: int) -> ResponseT[IntegerResponseT]: """ Returns an integer indicating the value of ``offset`` in ``name`` @@ -1914,7 +1905,9 @@ def getbit(self, name: KeyT, offset: int) -> ResponseT: """ return self.execute_command("GETBIT", name, offset, keys=[name]) - def getrange(self, key: KeyT, start: int, end: int) -> ResponseT: + def getrange( + self, key: KeyT, start: int, end: int + ) -> ResponseT[BulkStringResponseT]: """ Returns the substring of the string value stored at ``key``, determined by the offsets ``start`` and ``end`` (both are inclusive) @@ -1923,7 +1916,9 @@ def getrange(self, key: KeyT, start: int, end: int) -> ResponseT: """ return self.execute_command("GETRANGE", key, start, end, keys=[key]) - def getset(self, name: KeyT, value: EncodableT) -> ResponseT: + def getset( + self, name: KeyT, value: EncodableT + ) -> ResponseT[Union[BulkStringResponseT, None]]: """ Sets the value at key ``name`` to ``value`` and returns the old value at key ``name`` atomically. @@ -1935,7 +1930,7 @@ def getset(self, name: KeyT, value: EncodableT) -> ResponseT: """ return self.execute_command("GETSET", name, value) - def incrby(self, name: KeyT, amount: int = 1) -> ResponseT: + def incrby(self, name: KeyT, amount: int = 1) -> ResponseT[IntegerResponseT]: """ Increments the value of ``key`` by ``amount``. If no key exists, the value will be initialized as ``amount`` @@ -1946,7 +1941,9 @@ def incrby(self, name: KeyT, amount: int = 1) -> ResponseT: incr = incrby - def incrbyfloat(self, name: KeyT, amount: float = 1.0) -> ResponseT: + def incrbyfloat( + self, name: KeyT, amount: float = 1.0 + ) -> ResponseT[BulkStringResponseT]: """ Increments the value at key ``name`` by floating ``amount``. If no key exists, the value will be initialized as ``amount`` @@ -1955,7 +1952,7 @@ def incrbyfloat(self, name: KeyT, amount: float = 1.0) -> ResponseT: """ return self.execute_command("INCRBYFLOAT", name, amount) - def keys(self, pattern: PatternT = "*", **kwargs) -> ResponseT: + def keys(self, pattern: PatternT = "*", **kwargs) -> ResponseT[ArrayResponseT]: """ Returns a list of keys matching ``pattern`` @@ -1965,7 +1962,7 @@ def keys(self, pattern: PatternT = "*", **kwargs) -> ResponseT: def lmove( self, first_list: str, second_list: str, src: str = "LEFT", dest: str = "RIGHT" - ) -> ResponseT: + ) -> ResponseT[BulkStringResponseT]: """ Atomically returns and removes the first/last element of a list, pushing it as the first/last element on the destination list. @@ -1983,7 +1980,7 @@ def blmove( timeout: int, src: str = "LEFT", dest: str = "RIGHT", - ) -> ResponseT: + ) -> ResponseT[Union[BulkStringResponseT, None]]: """ Blocking version of lmove. @@ -1992,7 +1989,7 @@ def blmove( params = [first_list, second_list, src, dest, timeout] return self.execute_command("BLMOVE", *params) - def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT: + def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT[ArrayResponseT]: """ Returns a list of values ordered identically to ``keys`` @@ -2007,7 +2004,7 @@ def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT: options["keys"] = args return self.execute_command("MGET", *args, **options) - def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: + def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT[OKT]: """ Sets key/values based on a mapping. Mapping is a dictionary of key/value pairs. Both keys and values should be strings or types that @@ -2020,7 +2017,9 @@ def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: items.extend(pair) return self.execute_command("MSET", *items) - def msetnx(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: + def msetnx( + self, mapping: Mapping[AnyKeyT, EncodableT] + ) -> ResponseT[IntegerResponseT]: """ Sets key/values based on a mapping if none of the keys are already set. Mapping is a dictionary of key/value pairs. Both keys and values @@ -2034,7 +2033,7 @@ def msetnx(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: items.extend(pair) return self.execute_command("MSETNX", *items) - def move(self, name: KeyT, db: int) -> ResponseT: + def move(self, name: KeyT, db: int) -> ResponseT[IntegerResponseT]: """ Moves the key ``name`` to a different Redis database ``db`` @@ -2042,7 +2041,7 @@ def move(self, name: KeyT, db: int) -> ResponseT: """ return self.execute_command("MOVE", name, db) - def persist(self, name: KeyT) -> ResponseT: + def persist(self, name: KeyT) -> ResponseT[IntegerResponseT]: """ Removes an expiration on ``name`` @@ -2058,7 +2057,7 @@ def pexpire( xx: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Set an expire flag on key ``name`` for ``time`` milliseconds with given ``option``. ``time`` can be represented by an @@ -2074,7 +2073,6 @@ def pexpire( """ if isinstance(time, datetime.timedelta): time = int(time.total_seconds() * 1000) - exp_option = list() if nx: exp_option.append("NX") @@ -2094,7 +2092,7 @@ def pexpireat( xx: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Set an expire flag on key ``name`` with given ``option``. ``when`` can be represented as an integer representing unix time in @@ -2121,7 +2119,7 @@ def pexpireat( exp_option.append("LT") return self.execute_command("PEXPIREAT", name, when, *exp_option) - def pexpiretime(self, key: str) -> int: + def pexpiretime(self, key: str) -> ResponseT[IntegerResponseT]: """ Returns the absolute Unix timestamp (since January 1, 1970) in milliseconds at which the given key will expire. @@ -2130,7 +2128,7 @@ def pexpiretime(self, key: str) -> int: """ return self.execute_command("PEXPIRETIME", key) - def psetex(self, name: KeyT, time_ms: ExpiryT, value: EncodableT): + def psetex(self, name: KeyT, time_ms: ExpiryT, value: EncodableT) -> ResponseT[OKT]: """ Set the value of key ``name`` to ``value`` that expires in ``time_ms`` milliseconds. ``time_ms`` can be represented by an integer or a Python @@ -2142,7 +2140,7 @@ def psetex(self, name: KeyT, time_ms: ExpiryT, value: EncodableT): time_ms = int(time_ms.total_seconds() * 1000) return self.execute_command("PSETEX", name, time_ms, value) - def pttl(self, name: KeyT) -> ResponseT: + def pttl(self, name: KeyT) -> ResponseT[IntegerResponseT]: """ Returns the number of milliseconds until the key ``name`` will expire @@ -2151,8 +2149,8 @@ def pttl(self, name: KeyT) -> ResponseT: return self.execute_command("PTTL", name) def hrandfield( - self, key: str, count: int = None, withvalues: bool = False - ) -> ResponseT: + self, key: str, count: Optional[int] = None, withvalues: bool = False + ) -> ResponseT[Union[BulkStringResponseT, None, ArrayResponseT]]: """ Return a random field from the hash value stored at key. @@ -2171,10 +2169,9 @@ def hrandfield( params.append(count) if withvalues: params.append("WITHVALUES") - return self.execute_command("HRANDFIELD", key, *params) - def randomkey(self, **kwargs) -> ResponseT: + def randomkey(self, **kwargs) -> ResponseT[Union[BulkStringResponseT, None]]: """ Returns the name of a random key @@ -2182,7 +2179,7 @@ def randomkey(self, **kwargs) -> ResponseT: """ return self.execute_command("RANDOMKEY", **kwargs) - def rename(self, src: KeyT, dst: KeyT) -> ResponseT: + def rename(self, src: KeyT, dst: KeyT) -> ResponseT[OKT]: """ Rename key ``src`` to ``dst`` @@ -2190,7 +2187,7 @@ def rename(self, src: KeyT, dst: KeyT) -> ResponseT: """ return self.execute_command("RENAME", src, dst) - def renamenx(self, src: KeyT, dst: KeyT): + def renamenx(self, src: KeyT, dst: KeyT) -> ResponseT[IntegerResponseT]: """ Rename key ``src`` to ``dst`` if ``dst`` doesn't already exist @@ -2207,7 +2204,7 @@ def restore( absttl: bool = False, idletime: Union[int, None] = None, frequency: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[OKT]: """ Create a key using the provided serialized value, previously obtained using DUMP. @@ -2238,14 +2235,12 @@ def restore( params.append(int(idletime)) except ValueError: raise DataError("idletimemust be an integer") - if frequency is not None: params.append("FREQ") try: params.append(int(frequency)) except ValueError: raise DataError("frequency must be an integer") - return self.execute_command("RESTORE", *params) def set( @@ -2260,7 +2255,7 @@ def set( get: bool = False, exat: Union[AbsExpiryT, None] = None, pxat: Union[AbsExpiryT, None] = None, - ) -> ResponseT: + ) -> ResponseT[Union[BulkStringResponseT, None, OKT]]: """ Set the value at key ``name`` to ``value`` @@ -2289,7 +2284,7 @@ def set( For more information see https://redis.io/commands/set """ - pieces: list[EncodableT] = [name, value] + pieces: List[EncodableT] = [name, value] options = {} if ex is not None: pieces.append("EX") @@ -2321,22 +2316,21 @@ def set( pieces.append(pxat) if keepttl: pieces.append("KEEPTTL") - if nx: pieces.append("NX") if xx: pieces.append("XX") - if get: pieces.append("GET") options["get"] = True - return self.execute_command("SET", *pieces, **options) - def __setitem__(self, name: KeyT, value: EncodableT): + def __setitem__(self, name: KeyT, value: EncodableT) -> None: self.set(name, value) - def setbit(self, name: KeyT, offset: int, value: int) -> ResponseT: + def setbit( + self, name: KeyT, offset: int, value: int + ) -> ResponseT[IntegerResponseT]: """ Flag the ``offset`` in ``name`` as ``value``. Returns an integer indicating the previous value of ``offset``. @@ -2346,7 +2340,7 @@ def setbit(self, name: KeyT, offset: int, value: int) -> ResponseT: value = value and 1 or 0 return self.execute_command("SETBIT", name, offset, value) - def setex(self, name: KeyT, time: ExpiryT, value: EncodableT) -> ResponseT: + def setex(self, name: KeyT, time: ExpiryT, value: EncodableT) -> ResponseT[OKT]: """ Set the value of key ``name`` to ``value`` that expires in ``time`` seconds. ``time`` can be represented by an integer or a Python @@ -2358,7 +2352,7 @@ def setex(self, name: KeyT, time: ExpiryT, value: EncodableT) -> ResponseT: time = int(time.total_seconds()) return self.execute_command("SETEX", name, time, value) - def setnx(self, name: KeyT, value: EncodableT) -> ResponseT: + def setnx(self, name: KeyT, value: EncodableT) -> ResponseT[IntegerResponseT]: """ Set the value of key ``name`` to ``value`` if key doesn't exist @@ -2366,7 +2360,9 @@ def setnx(self, name: KeyT, value: EncodableT) -> ResponseT: """ return self.execute_command("SETNX", name, value) - def setrange(self, name: KeyT, offset: int, value: EncodableT) -> ResponseT: + def setrange( + self, name: KeyT, offset: int, value: EncodableT + ) -> ResponseT[IntegerResponseT]: """ Overwrite bytes in the value of ``name`` starting at ``offset`` with ``value``. If ``offset`` plus the length of ``value`` exceeds the @@ -2392,7 +2388,7 @@ def stralgo( minmatchlen: Union[int, None] = None, withmatchlen: bool = False, **kwargs, - ) -> ResponseT: + ) -> ResponseT[Any]: """ Implements complex algorithms that operate on strings. Right now the only algorithm implemented is the LCS algorithm @@ -2421,8 +2417,7 @@ def stralgo( raise DataError("specific_argument can be only keys or strings") if len and idx: raise DataError("len and idx cannot be provided together.") - - pieces: list[EncodableT] = [algo, specific_argument.upper(), value1, value2] + pieces: List[EncodableT] = [algo, specific_argument.upper(), value1, value2] if len: pieces.append(b"LEN") if idx: @@ -2434,7 +2429,6 @@ def stralgo( pass if withmatchlen: pieces.append(b"WITHMATCHLEN") - return self.execute_command( "STRALGO", *pieces, @@ -2445,7 +2439,7 @@ def stralgo( **kwargs, ) - def strlen(self, name: KeyT) -> ResponseT: + def strlen(self, name: KeyT) -> ResponseT[IntegerResponseT]: """ Return the number of bytes stored in the value of ``name`` @@ -2453,14 +2447,16 @@ def strlen(self, name: KeyT) -> ResponseT: """ return self.execute_command("STRLEN", name, keys=[name]) - def substr(self, name: KeyT, start: int, end: int = -1) -> ResponseT: + def substr( + self, name: KeyT, start: int, end: int = -1 + ) -> ResponseT[BulkStringResponseT]: """ Return a substring of the string at key ``name``. ``start`` and ``end`` are 0-based integers specifying the portion of the string to return. """ return self.execute_command("SUBSTR", name, start, end, keys=[name]) - def touch(self, *args: KeyT) -> ResponseT: + def touch(self, *args: KeyT) -> ResponseT[IntegerResponseT]: """ Alters the last access time of a key(s) ``*args``. A key is ignored if it does not exist. @@ -2469,7 +2465,7 @@ def touch(self, *args: KeyT) -> ResponseT: """ return self.execute_command("TOUCH", *args) - def ttl(self, name: KeyT) -> ResponseT: + def ttl(self, name: KeyT) -> ResponseT[IntegerResponseT]: """ Returns the number of seconds until the key ``name`` will expire @@ -2477,7 +2473,7 @@ def ttl(self, name: KeyT) -> ResponseT: """ return self.execute_command("TTL", name) - def type(self, name: KeyT) -> ResponseT: + def type(self, name: KeyT) -> ResponseT[str]: """ Returns the type of key ``name`` @@ -2501,7 +2497,7 @@ def unwatch(self) -> None: """ warnings.warn(DeprecationWarning("Call UNWATCH from a Pipeline object")) - def unlink(self, *names: KeyT) -> ResponseT: + def unlink(self, *names: KeyT) -> ResponseT[IntegerResponseT]: """ Unlink one or more keys specified by ``names`` @@ -2517,7 +2513,7 @@ def lcs( idx: Optional[bool] = False, minmatchlen: Optional[int] = 0, withmatchlen: Optional[bool] = False, - ) -> Union[str, int, list]: + ) -> ResponseT[Union[BulkStringResponseT, IntegerResponseT, ArrayResponseT]]: """ Find the longest common subsequence between ``key1`` and ``key2``. If ``len`` is true the length of the match will will be returned. @@ -2540,16 +2536,17 @@ def lcs( class AsyncBasicKeyCommands(BasicKeyCommands): - def __delitem__(self, name: KeyT): + + def __delitem__(self, name: KeyT) -> None: raise TypeError("Async Redis client does not support class deletion") - def __contains__(self, name: KeyT): + def __contains__(self, name: KeyT) -> None: raise TypeError("Async Redis client does not support class inclusion") - def __getitem__(self, name: KeyT): + def __getitem__(self, name: KeyT) -> None: raise TypeError("Async Redis client does not support class retrieval") - def __setitem__(self, name: KeyT, value: EncodableT): + def __setitem__(self, name: KeyT, value: EncodableT) -> None: raise TypeError("Async Redis client does not support class assignment") async def watch(self, *names: KeyT) -> None: @@ -2567,7 +2564,7 @@ class ListCommands(CommandsProtocol): def blpop( self, keys: List, timeout: Optional[int] = 0 - ) -> Union[Awaitable[list], list]: + ) -> ResponseT[Union[ArrayResponseT, None]]: """ LPOP a value off of the first non-empty list named in the ``keys`` list. @@ -2588,7 +2585,7 @@ def blpop( def brpop( self, keys: List, timeout: Optional[int] = 0 - ) -> Union[Awaitable[list], list]: + ) -> ResponseT[Union[ArrayResponseT, None]]: """ RPOP a value off of the first non-empty list named in the ``keys`` list. @@ -2609,7 +2606,7 @@ def brpop( def brpoplpush( self, src: str, dst: str, timeout: Optional[int] = 0 - ) -> Union[Awaitable[Optional[str]], Optional[str]]: + ) -> ResponseT[Union[BulkStringResponseT, None]]: """ Pop a value off the tail of ``src``, push it on the head of ``dst`` and then return it. @@ -2631,7 +2628,7 @@ def blmpop( *args: List[str], direction: str, count: Optional[int] = 1, - ) -> Optional[list]: + ) -> ResponseT[Union[ArrayResponseT, None]]: """ Pop ``count`` values (default 1) from first non-empty in the list of provided key names. @@ -2642,16 +2639,11 @@ def blmpop( For more information see https://redis.io/commands/blmpop """ args = [timeout, numkeys, *args, direction, "COUNT", count] - return self.execute_command("BLMPOP", *args) def lmpop( - self, - num_keys: int, - *args: List[str], - direction: str, - count: Optional[int] = 1, - ) -> Union[Awaitable[list], list]: + self, num_keys: int, *args: List[str], direction: str, count: Optional[int] = 1 + ) -> ResponseT[Union[ArrayResponseT, None]]: """ Pop ``count`` values (default 1) first non-empty list key from the list of args provided key names. @@ -2661,12 +2653,11 @@ def lmpop( args = [num_keys] + list(args) + [direction] if count != 1: args.extend(["COUNT", count]) - return self.execute_command("LMPOP", *args) def lindex( self, name: str, index: int - ) -> Union[Awaitable[Optional[str]], Optional[str]]: + ) -> ResponseT[Union[BulkStringResponseT, None]]: """ Return the item from list ``name`` at position ``index`` @@ -2679,7 +2670,7 @@ def lindex( def linsert( self, name: str, where: str, refvalue: str, value: str - ) -> Union[Awaitable[int], int]: + ) -> ResponseT[IntegerResponseT]: """ Insert ``value`` in list ``name`` either immediately before or after [``where``] ``refvalue`` @@ -2691,7 +2682,7 @@ def linsert( """ return self.execute_command("LINSERT", name, where, refvalue, value) - def llen(self, name: str) -> Union[Awaitable[int], int]: + def llen(self, name: str) -> ResponseT[IntegerResponseT]: """ Return the length of the list ``name`` @@ -2700,10 +2691,8 @@ def llen(self, name: str) -> Union[Awaitable[int], int]: return self.execute_command("LLEN", name, keys=[name]) def lpop( - self, - name: str, - count: Optional[int] = None, - ) -> Union[Awaitable[Union[str, List, None]], Union[str, List, None]]: + self, name: str, count: Optional[int] = None + ) -> ResponseT[Union[None, BulkStringResponseT, ArrayResponseT]]: """ Removes and returns the first elements of the list ``name``. @@ -2718,7 +2707,7 @@ def lpop( else: return self.execute_command("LPOP", name) - def lpush(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + def lpush(self, name: str, *values: FieldT) -> ResponseT[IntegerResponseT]: """ Push ``values`` onto the head of the list ``name`` @@ -2726,7 +2715,7 @@ def lpush(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: """ return self.execute_command("LPUSH", name, *values) - def lpushx(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + def lpushx(self, name: str, *values: FieldT) -> ResponseT[IntegerResponseT]: """ Push ``value`` onto the head of the list ``name`` if ``name`` exists @@ -2734,7 +2723,7 @@ def lpushx(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: """ return self.execute_command("LPUSHX", name, *values) - def lrange(self, name: str, start: int, end: int) -> Union[Awaitable[list], list]: + def lrange(self, name: str, start: int, end: int) -> ResponseT[ArrayResponseT]: """ Return a slice of the list ``name`` between position ``start`` and ``end`` @@ -2746,7 +2735,7 @@ def lrange(self, name: str, start: int, end: int) -> Union[Awaitable[list], list """ return self.execute_command("LRANGE", name, start, end, keys=[name]) - def lrem(self, name: str, count: int, value: str) -> Union[Awaitable[int], int]: + def lrem(self, name: str, count: int, value: str) -> ResponseT[IntegerResponseT]: """ Remove the first ``count`` occurrences of elements equal to ``value`` from the list stored at ``name``. @@ -2760,7 +2749,7 @@ def lrem(self, name: str, count: int, value: str) -> Union[Awaitable[int], int]: """ return self.execute_command("LREM", name, count, value) - def lset(self, name: str, index: int, value: str) -> Union[Awaitable[str], str]: + def lset(self, name: str, index: int, value: str) -> ResponseT[OKT]: """ Set element at ``index`` of list ``name`` to ``value`` @@ -2768,7 +2757,7 @@ def lset(self, name: str, index: int, value: str) -> Union[Awaitable[str], str]: """ return self.execute_command("LSET", name, index, value) - def ltrim(self, name: str, start: int, end: int) -> Union[Awaitable[str], str]: + def ltrim(self, name: str, start: int, end: int) -> ResponseT[OKT]: """ Trim the list ``name``, removing all values not within the slice between ``start`` and ``end`` @@ -2781,10 +2770,8 @@ def ltrim(self, name: str, start: int, end: int) -> Union[Awaitable[str], str]: return self.execute_command("LTRIM", name, start, end) def rpop( - self, - name: str, - count: Optional[int] = None, - ) -> Union[Awaitable[Union[str, List, None]], Union[str, List, None]]: + self, name: str, count: Optional[int] = None + ) -> ResponseT[Union[None, BulkStringResponseT, ArrayResponseT]]: """ Removes and returns the last elements of the list ``name``. @@ -2799,7 +2786,9 @@ def rpop( else: return self.execute_command("RPOP", name) - def rpoplpush(self, src: str, dst: str) -> Union[Awaitable[str], str]: + def rpoplpush( + self, src: str, dst: str + ) -> ResponseT[Union[BulkStringResponseT, None]]: """ RPOP a value off of the ``src`` list and atomically LPUSH it on to the ``dst`` list. Returns the value. @@ -2808,7 +2797,7 @@ def rpoplpush(self, src: str, dst: str) -> Union[Awaitable[str], str]: """ return self.execute_command("RPOPLPUSH", src, dst) - def rpush(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + def rpush(self, name: str, *values: FieldT) -> ResponseT[IntegerResponseT]: """ Push ``values`` onto the tail of the list ``name`` @@ -2816,7 +2805,7 @@ def rpush(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: """ return self.execute_command("RPUSH", name, *values) - def rpushx(self, name: str, *values: str) -> Union[Awaitable[int], int]: + def rpushx(self, name: str, *values: str) -> ResponseT[IntegerResponseT]: """ Push ``value`` onto the tail of the list ``name`` if ``name`` exists @@ -2831,7 +2820,7 @@ def lpos( rank: Optional[int] = None, count: Optional[int] = None, maxlen: Optional[int] = None, - ) -> Union[str, List, None]: + ) -> ResponseT[Union[IntegerResponseT, None, ArrayResponseT]]: """ Get position of ``value`` within the list ``name`` @@ -2857,16 +2846,13 @@ def lpos( For more information see https://redis.io/commands/lpos """ - pieces: list[EncodableT] = [name, value] + pieces: List[EncodableT] = [name, value] if rank is not None: pieces.extend(["RANK", rank]) - if count is not None: pieces.extend(["COUNT", count]) - if maxlen is not None: pieces.extend(["MAXLEN", maxlen]) - return self.execute_command("LPOS", *pieces, keys=[name]) def sort( @@ -2880,7 +2866,7 @@ def sort( alpha: bool = False, store: Optional[str] = None, groups: Optional[bool] = False, - ) -> Union[List, int]: + ) -> ResponseT[ArrayResponseT]: """ Sort and return the list, set or sorted set at ``name``. @@ -2908,8 +2894,7 @@ def sort( """ if (start is not None and num is None) or (num is not None and start is None): raise DataError("``start`` and ``num`` must both be specified") - - pieces: list[EncodableT] = [name] + pieces: List[EncodableT] = [name] if by is not None: pieces.extend([b"BY", by]) if start is not None and num is not None: @@ -2937,7 +2922,6 @@ def sort( "must be specified and contain at least " "two keys" ) - options = {"groups": len(get) if groups else None} options["keys"] = [name] return self.execute_command("SORT", *pieces, **options) @@ -2951,7 +2935,7 @@ def sort_ro( get: Optional[List[str]] = None, desc: bool = False, alpha: bool = False, - ) -> list: + ) -> ResponseT[ArrayResponseT]: """ Returns the elements contained in the list, set or sorted set at key. (read-only variant of the SORT command) @@ -2992,7 +2976,7 @@ def scan( count: Union[int, None] = None, _type: Union[str, None] = None, **kwargs, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Incrementally return lists of key names. Also return a cursor indicating the scan position. @@ -3009,7 +2993,7 @@ def scan( For more information see https://redis.io/commands/scan """ - pieces: list[EncodableT] = [cursor] + pieces: List[EncodableT] = [cursor] if match is not None: pieces.extend([b"MATCH", match]) if count is not None: @@ -3052,7 +3036,7 @@ def sscan( cursor: int = 0, match: Union[PatternT, None] = None, count: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Incrementally return lists of elements in a set. Also return a cursor indicating the scan position. @@ -3063,7 +3047,7 @@ def sscan( For more information see https://redis.io/commands/sscan """ - pieces: list[EncodableT] = [name, cursor] + pieces: List[EncodableT] = [name, cursor] if match is not None: pieces.extend([b"MATCH", match]) if count is not None: @@ -3096,7 +3080,7 @@ def hscan( match: Union[PatternT, None] = None, count: Union[int, None] = None, no_values: Union[bool, None] = None, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Incrementally return key/value slices in a hash. Also return a cursor indicating the scan position. @@ -3109,7 +3093,7 @@ def hscan( For more information see https://redis.io/commands/hscan """ - pieces: list[EncodableT] = [name, cursor] + pieces: List[EncodableT] = [name, cursor] if match is not None: pieces.extend([b"MATCH", match]) if count is not None: @@ -3152,7 +3136,7 @@ def zscan( match: Union[PatternT, None] = None, count: Union[int, None] = None, score_cast_func: Union[type, Callable] = float, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Incrementally return lists of elements in a sorted set. Also return a cursor indicating the scan position. @@ -3203,6 +3187,7 @@ def zscan_iter( class AsyncScanCommands(ScanCommands): + async def scan_iter( self, match: Union[PatternT, None] = None, @@ -3319,7 +3304,7 @@ class SetCommands(CommandsProtocol): see: https://redis.io/topics/data-types#sets """ - def sadd(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + def sadd(self, name: str, *values: FieldT) -> ResponseT[IntegerResponseT]: """ Add ``value(s)`` to set ``name`` @@ -3327,7 +3312,7 @@ def sadd(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: """ return self.execute_command("SADD", name, *values) - def scard(self, name: str) -> Union[Awaitable[int], int]: + def scard(self, name: str) -> ResponseT[IntegerResponseT]: """ Return the number of elements in set ``name`` @@ -3335,7 +3320,7 @@ def scard(self, name: str) -> Union[Awaitable[int], int]: """ return self.execute_command("SCARD", name, keys=[name]) - def sdiff(self, keys: List, *args: List) -> Union[Awaitable[list], list]: + def sdiff(self, keys: List, *args: List) -> ResponseT[ArrayResponseT]: """ Return the difference of sets specified by ``keys`` @@ -3346,7 +3331,7 @@ def sdiff(self, keys: List, *args: List) -> Union[Awaitable[list], list]: def sdiffstore( self, dest: str, keys: List, *args: List - ) -> Union[Awaitable[int], int]: + ) -> ResponseT[IntegerResponseT]: """ Store the difference of sets specified by ``keys`` into a new set named ``dest``. Returns the number of keys in the new set. @@ -3356,7 +3341,7 @@ def sdiffstore( args = list_or_args(keys, args) return self.execute_command("SDIFFSTORE", dest, *args) - def sinter(self, keys: List, *args: List) -> Union[Awaitable[list], list]: + def sinter(self, keys: List, *args: List) -> ResponseT[ArrayResponseT]: """ Return the intersection of sets specified by ``keys`` @@ -3367,7 +3352,7 @@ def sinter(self, keys: List, *args: List) -> Union[Awaitable[list], list]: def sintercard( self, numkeys: int, keys: List[str], limit: int = 0 - ) -> Union[Awaitable[int], int]: + ) -> ResponseT[IntegerResponseT]: """ Return the cardinality of the intersect of multiple sets specified by ``keys`. @@ -3382,7 +3367,7 @@ def sintercard( def sinterstore( self, dest: str, keys: List, *args: List - ) -> Union[Awaitable[int], int]: + ) -> ResponseT[IntegerResponseT]: """ Store the intersection of sets specified by ``keys`` into a new set named ``dest``. Returns the number of keys in the new set. @@ -3392,9 +3377,7 @@ def sinterstore( args = list_or_args(keys, args) return self.execute_command("SINTERSTORE", dest, *args) - def sismember( - self, name: str, value: str - ) -> Union[Awaitable[Union[Literal[0], Literal[1]]], Union[Literal[0], Literal[1]]]: + def sismember(self, name: str, value: str) -> ResponseT[IntegerResponseT]: """ Return whether ``value`` is a member of set ``name``: - 1 if the value is a member of the set. @@ -3404,7 +3387,7 @@ def sismember( """ return self.execute_command("SISMEMBER", name, value, keys=[name]) - def smembers(self, name: str) -> Union[Awaitable[Set], Set]: + def smembers(self, name: str) -> ResponseT[ArrayResponseT]: """ Return all members of the set ``name`` @@ -3412,10 +3395,9 @@ def smembers(self, name: str) -> Union[Awaitable[Set], Set]: """ return self.execute_command("SMEMBERS", name, keys=[name]) - def smismember(self, name: str, values: List, *args: List) -> Union[ - Awaitable[List[Union[Literal[0], Literal[1]]]], - List[Union[Literal[0], Literal[1]]], - ]: + def smismember( + self, name: str, values: List, *args: List + ) -> ResponseT[ArrayResponseT]: """ Return whether each value in ``values`` is a member of the set ``name`` as a list of ``int`` in the order of ``values``: @@ -3427,7 +3409,7 @@ def smismember(self, name: str, values: List, *args: List) -> Union[ args = list_or_args(values, args) return self.execute_command("SMISMEMBER", name, *args, keys=[name]) - def smove(self, src: str, dst: str, value: str) -> Union[Awaitable[bool], bool]: + def smove(self, src: str, dst: str, value: str) -> ResponseT[IntegerResponseT]: """ Move ``value`` from set ``src`` to set ``dst`` atomically @@ -3435,7 +3417,9 @@ def smove(self, src: str, dst: str, value: str) -> Union[Awaitable[bool], bool]: """ return self.execute_command("SMOVE", src, dst, value) - def spop(self, name: str, count: Optional[int] = None) -> Union[str, List, None]: + def spop( + self, name: str, count: Optional[int] = None + ) -> ResponseT[Union[BulkStringResponseT, None, ArrayResponseT]]: """ Remove and return a random member of set ``name`` @@ -3446,7 +3430,7 @@ def spop(self, name: str, count: Optional[int] = None) -> Union[str, List, None] def srandmember( self, name: str, number: Optional[int] = None - ) -> Union[str, List, None]: + ) -> ResponseT[Union[BulkStringResponseT, ArrayResponseT]]: """ If ``number`` is None, returns a random member of set ``name``. @@ -3459,7 +3443,7 @@ def srandmember( args = (number is not None) and [number] or [] return self.execute_command("SRANDMEMBER", name, *args) - def srem(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + def srem(self, name: str, *values: FieldT) -> ResponseT[IntegerResponseT]: """ Remove ``values`` from set ``name`` @@ -3467,7 +3451,7 @@ def srem(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: """ return self.execute_command("SREM", name, *values) - def sunion(self, keys: List, *args: List) -> Union[Awaitable[List], List]: + def sunion(self, keys: List, *args: List) -> ResponseT[ArrayResponseT]: """ Return the union of sets specified by ``keys`` @@ -3478,7 +3462,7 @@ def sunion(self, keys: List, *args: List) -> Union[Awaitable[List], List]: def sunionstore( self, dest: str, keys: List, *args: List - ) -> Union[Awaitable[int], int]: + ) -> ResponseT[IntegerResponseT]: """ Store the union of sets specified by ``keys`` into a new set named ``dest``. Returns the number of keys in the new set. @@ -3498,7 +3482,9 @@ class StreamCommands(CommandsProtocol): see: https://redis.io/topics/streams-intro """ - def xack(self, name: KeyT, groupname: GroupT, *ids: StreamIdT) -> ResponseT: + def xack( + self, name: KeyT, groupname: GroupT, *ids: StreamIdT + ) -> ResponseT[IntegerResponseT]: """ Acknowledges the successful processing of one or more messages. name: name of the stream. @@ -3519,7 +3505,7 @@ def xadd( nomkstream: bool = False, minid: Union[StreamIdT, None] = None, limit: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[Union[BulkStringResponseT, None]]: """ Add to a stream. name: name of the stream @@ -3535,10 +3521,9 @@ def xadd( For more information see https://redis.io/commands/xadd """ - pieces: list[EncodableT] = [] + pieces: List[EncodableT] = [] if maxlen is not None and minid is not None: raise DataError("Only one of ```maxlen``` or ```minid``` may be specified") - if maxlen is not None: if not isinstance(maxlen, int) or maxlen < 0: raise DataError("XADD maxlen must be non-negative integer") @@ -3571,7 +3556,7 @@ def xautoclaim( start_id: StreamIdT = "0-0", count: Union[int, None] = None, justid: bool = False, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Transfers ownership of pending stream entries that match the specified criteria. Conceptually, equivalent to calling XPENDING and then XCLAIM, @@ -3597,10 +3582,8 @@ def xautoclaim( ) except TypeError: pass - kwargs = {} pieces = [name, groupname, consumername, min_idle_time, start_id] - try: if int(count) < 0: raise DataError("XPENDING count must be a integer >= 0") @@ -3610,7 +3593,6 @@ def xautoclaim( if justid: pieces.append(b"JUSTID") kwargs["parse_justid"] = True - return self.execute_command("XAUTOCLAIM", *pieces, **kwargs) def xclaim( @@ -3625,7 +3607,7 @@ def xclaim( retrycount: Union[int, None] = None, force: bool = False, justid: bool = False, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Changes the ownership of a pending message. @@ -3667,11 +3649,9 @@ def xclaim( "XCLAIM message_ids must be a non empty list or " "tuple of message IDs to claim" ) - kwargs = {} - pieces: list[EncodableT] = [name, groupname, consumername, str(min_idle_time)] + pieces: List[EncodableT] = [name, groupname, consumername, str(min_idle_time)] pieces.extend(list(message_ids)) - if idle is not None: if not isinstance(idle, int): raise DataError("XCLAIM idle must be an integer") @@ -3684,7 +3664,6 @@ def xclaim( if not isinstance(retrycount, int): raise DataError("XCLAIM retrycount must be an integer") pieces.extend((b"RETRYCOUNT", str(retrycount))) - if force: if not isinstance(force, bool): raise DataError("XCLAIM force must be a boolean") @@ -3696,7 +3675,7 @@ def xclaim( kwargs["parse_justid"] = True return self.execute_command("XCLAIM", *pieces, **kwargs) - def xdel(self, name: KeyT, *ids: StreamIdT) -> ResponseT: + def xdel(self, name: KeyT, *ids: StreamIdT) -> ResponseT[IntegerResponseT]: """ Deletes one or more messages from a stream. name: name of the stream. @@ -3713,7 +3692,7 @@ def xgroup_create( id: StreamIdT = "$", mkstream: bool = False, entries_read: Optional[int] = None, - ) -> ResponseT: + ) -> ResponseT[OKT]: """ Create a new consumer group associated with a stream. name: name of the stream. @@ -3722,17 +3701,16 @@ def xgroup_create( For more information see https://redis.io/commands/xgroup-create """ - pieces: list[EncodableT] = ["XGROUP CREATE", name, groupname, id] + pieces: List[EncodableT] = ["XGROUP CREATE", name, groupname, id] if mkstream: pieces.append(b"MKSTREAM") if entries_read is not None: pieces.extend(["ENTRIESREAD", entries_read]) - return self.execute_command(*pieces) def xgroup_delconsumer( self, name: KeyT, groupname: GroupT, consumername: ConsumerT - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Remove a specific consumer from a consumer group. Returns the number of pending messages that the consumer had before it @@ -3745,7 +3723,9 @@ def xgroup_delconsumer( """ return self.execute_command("XGROUP DELCONSUMER", name, groupname, consumername) - def xgroup_destroy(self, name: KeyT, groupname: GroupT) -> ResponseT: + def xgroup_destroy( + self, name: KeyT, groupname: GroupT + ) -> ResponseT[IntegerResponseT]: """ Destroy a consumer group. name: name of the stream. @@ -3757,7 +3737,7 @@ def xgroup_destroy(self, name: KeyT, groupname: GroupT) -> ResponseT: def xgroup_createconsumer( self, name: KeyT, groupname: GroupT, consumername: ConsumerT - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Consumers in a consumer group are auto-created every time a new consumer name is mentioned by some command. @@ -3778,7 +3758,7 @@ def xgroup_setid( groupname: GroupT, id: StreamIdT, entries_read: Optional[int] = None, - ) -> ResponseT: + ) -> ResponseT[OKT]: """ Set the consumer group last delivered ID to something else. name: name of the stream. @@ -3792,7 +3772,9 @@ def xgroup_setid( pieces.extend(["ENTRIESREAD", entries_read]) return self.execute_command("XGROUP SETID", *pieces) - def xinfo_consumers(self, name: KeyT, groupname: GroupT) -> ResponseT: + def xinfo_consumers( + self, name: KeyT, groupname: GroupT + ) -> ResponseT[ArrayResponseT]: """ Returns general information about the consumers in the group. name: name of the stream. @@ -3802,7 +3784,7 @@ def xinfo_consumers(self, name: KeyT, groupname: GroupT) -> ResponseT: """ return self.execute_command("XINFO CONSUMERS", name, groupname) - def xinfo_groups(self, name: KeyT) -> ResponseT: + def xinfo_groups(self, name: KeyT) -> ResponseT[ArrayResponseT]: """ Returns general information about the consumer groups of the stream. name: name of the stream. @@ -3811,7 +3793,7 @@ def xinfo_groups(self, name: KeyT) -> ResponseT: """ return self.execute_command("XINFO GROUPS", name) - def xinfo_stream(self, name: KeyT, full: bool = False) -> ResponseT: + def xinfo_stream(self, name: KeyT, full: bool = False) -> ResponseT[ArrayResponseT]: """ Returns general information about the stream. name: name of the stream. @@ -3826,7 +3808,7 @@ def xinfo_stream(self, name: KeyT, full: bool = False) -> ResponseT: options = {"full": full} return self.execute_command("XINFO STREAM", *pieces, **options) - def xlen(self, name: KeyT) -> ResponseT: + def xlen(self, name: KeyT) -> ResponseT[IntegerResponseT]: """ Returns the number of elements in a given stream. @@ -3834,7 +3816,7 @@ def xlen(self, name: KeyT) -> ResponseT: """ return self.execute_command("XLEN", name, keys=[name]) - def xpending(self, name: KeyT, groupname: GroupT) -> ResponseT: + def xpending(self, name: KeyT, groupname: GroupT) -> ResponseT[ArrayResponseT]: """ Returns information about pending messages of a group. name: name of the stream. @@ -3853,7 +3835,7 @@ def xpending_range( count: int, consumername: Union[ConsumerT, None] = None, idle: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Returns information about pending messages, in a range. @@ -3874,7 +3856,6 @@ def xpending_range( " with min, max and count parameters" ) return self.xpending(name, groupname) - pieces = [name, groupname] if min is None or max is None or count is None: raise DataError( @@ -3898,7 +3879,6 @@ def xpending_range( # consumername if consumername: pieces.append(consumername) - return self.execute_command("XPENDING", *pieces, parse_detail=True) def xrange( @@ -3907,7 +3887,7 @@ def xrange( min: StreamIdT = "-", max: StreamIdT = "+", count: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Read stream values within an interval. @@ -3930,7 +3910,6 @@ def xrange( raise DataError("XRANGE count must be a positive integer") pieces.append(b"COUNT") pieces.append(str(count)) - return self.execute_command("XRANGE", name, *pieces, keys=[name]) def xread( @@ -3938,7 +3917,7 @@ def xread( streams: Dict[KeyT, StreamIdT], count: Union[int, None] = None, block: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[Union[ArrayResponseT, None]]: """ Block and monitor multiple streams for new data. @@ -3979,7 +3958,7 @@ def xreadgroup( count: Union[int, None] = None, block: Union[int, None] = None, noack: bool = False, - ) -> ResponseT: + ) -> ResponseT[Union[ArrayResponseT, None]]: """ Read from a stream via a consumer group. @@ -3998,7 +3977,7 @@ def xreadgroup( For more information see https://redis.io/commands/xreadgroup """ - pieces: list[EncodableT] = [b"GROUP", groupname, consumername] + pieces: List[EncodableT] = [b"GROUP", groupname, consumername] if count is not None: if not isinstance(count, int) or count < 1: raise DataError("XREADGROUP count must be a positive integer") @@ -4024,7 +4003,7 @@ def xrevrange( max: StreamIdT = "+", min: StreamIdT = "-", count: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Read stream values within an interval, in reverse order. @@ -4041,13 +4020,12 @@ def xrevrange( For more information see https://redis.io/commands/xrevrange """ - pieces: list[EncodableT] = [max, min] + pieces: List[EncodableT] = [max, min] if count is not None: if not isinstance(count, int) or count < 1: raise DataError("XREVRANGE count must be a positive integer") pieces.append(b"COUNT") pieces.append(str(count)) - return self.execute_command("XREVRANGE", name, *pieces, keys=[name]) def xtrim( @@ -4057,7 +4035,7 @@ def xtrim( approximate: bool = True, minid: Union[StreamIdT, None] = None, limit: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Trims old messages from a stream. name: name of the stream. @@ -4070,13 +4048,11 @@ def xtrim( For more information see https://redis.io/commands/xtrim """ - pieces: list[EncodableT] = [] + pieces: List[EncodableT] = [] if maxlen is not None and minid is not None: raise DataError("Only one of ``maxlen`` or ``minid`` may be specified") - if maxlen is None and minid is None: raise DataError("One of ``maxlen`` or ``minid`` must be specified") - if maxlen is not None: pieces.append(b"MAXLEN") if minid is not None: @@ -4090,7 +4066,6 @@ def xtrim( if limit is not None: pieces.append(b"LIMIT") pieces.append(limit) - return self.execute_command("XTRIM", name, *pieces) @@ -4113,7 +4088,7 @@ def zadd( incr: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseT[Union[IntegerResponseT, BulkStringResponseT, None]]: """ Set any number of element-name, score pairs to the key ``name``. Pairs are specified as a dict of element-names keys to score values. @@ -4160,8 +4135,7 @@ def zadd( ) if nx and (gt or lt): raise DataError("Only one of 'nx', 'lt', or 'gr' may be defined.") - - pieces: list[EncodableT] = [] + pieces: List[EncodableT] = [] options = {} if nx: pieces.append(b"NX") @@ -4181,7 +4155,7 @@ def zadd( pieces.append(pair[0]) return self.execute_command("ZADD", name, *pieces, **options) - def zcard(self, name: KeyT) -> ResponseT: + def zcard(self, name: KeyT) -> ResponseT[IntegerResponseT]: """ Return the number of elements in the sorted set ``name`` @@ -4189,7 +4163,9 @@ def zcard(self, name: KeyT) -> ResponseT: """ return self.execute_command("ZCARD", name, keys=[name]) - def zcount(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT) -> ResponseT: + def zcount( + self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT + ) -> ResponseT[IntegerResponseT]: """ Returns the number of elements in the sorted set at key ``name`` with a score between ``min`` and ``max``. @@ -4198,7 +4174,7 @@ def zcount(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT) -> ResponseT: """ return self.execute_command("ZCOUNT", name, min, max, keys=[name]) - def zdiff(self, keys: KeysT, withscores: bool = False) -> ResponseT: + def zdiff(self, keys: KeysT, withscores: bool = False) -> ResponseT[ArrayResponseT]: """ Returns the difference between the first and all successive input sorted sets provided in ``keys``. @@ -4210,7 +4186,7 @@ def zdiff(self, keys: KeysT, withscores: bool = False) -> ResponseT: pieces.append("WITHSCORES") return self.execute_command("ZDIFF", *pieces, keys=keys) - def zdiffstore(self, dest: KeyT, keys: KeysT) -> ResponseT: + def zdiffstore(self, dest: KeyT, keys: KeysT) -> ResponseT[IntegerResponseT]: """ Computes the difference between the first and all successive input sorted sets provided in ``keys`` and stores the result in ``dest``. @@ -4220,7 +4196,9 @@ def zdiffstore(self, dest: KeyT, keys: KeysT) -> ResponseT: pieces = [len(keys), *keys] return self.execute_command("ZDIFFSTORE", dest, *pieces) - def zincrby(self, name: KeyT, amount: float, value: EncodableT) -> ResponseT: + def zincrby( + self, name: KeyT, amount: float, value: EncodableT + ) -> ResponseT[BulkStringResponseT]: """ Increment the score of ``value`` in sorted set ``name`` by ``amount`` @@ -4249,7 +4227,7 @@ def zinterstore( dest: KeyT, keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], aggregate: Union[str, None] = None, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Intersect multiple sorted sets specified by ``keys`` into a new sorted set, ``dest``. Scores in the destination will be aggregated @@ -4265,7 +4243,7 @@ def zinterstore( def zintercard( self, numkeys: int, keys: List[str], limit: int = 0 - ) -> Union[Awaitable[int], int]: + ) -> ResponseT[IntegerResponseT]: """ Return the cardinality of the intersect of multiple sorted sets specified by ``keys`. @@ -4278,7 +4256,7 @@ def zintercard( args = [numkeys, *keys, "LIMIT", limit] return self.execute_command("ZINTERCARD", *args, keys=keys) - def zlexcount(self, name, min, max): + def zlexcount(self, name, min, max) -> ResponseT[IntegerResponseT]: """ Return the number of items in the sorted set ``name`` between the lexicographical range ``min`` and ``max``. @@ -4287,7 +4265,9 @@ def zlexcount(self, name, min, max): """ return self.execute_command("ZLEXCOUNT", name, min, max, keys=[name]) - def zpopmax(self, name: KeyT, count: Union[int, None] = None) -> ResponseT: + def zpopmax( + self, name: KeyT, count: Union[int, None] = None + ) -> ResponseT[ArrayResponseT]: """ Remove and return up to ``count`` members with the highest scores from the sorted set ``name``. @@ -4298,7 +4278,9 @@ def zpopmax(self, name: KeyT, count: Union[int, None] = None) -> ResponseT: options = {"withscores": True} return self.execute_command("ZPOPMAX", name, *args, **options) - def zpopmin(self, name: KeyT, count: Union[int, None] = None) -> ResponseT: + def zpopmin( + self, name: KeyT, count: Union[int, None] = None + ) -> ResponseT[ArrayResponseT]: """ Remove and return up to ``count`` members with the lowest scores from the sorted set ``name``. @@ -4311,7 +4293,7 @@ def zpopmin(self, name: KeyT, count: Union[int, None] = None) -> ResponseT: def zrandmember( self, key: KeyT, count: int = None, withscores: bool = False - ) -> ResponseT: + ) -> ResponseT[BulkStringResponseT]: """ Return a random element from the sorted set value stored at key. @@ -4332,10 +4314,11 @@ def zrandmember( params.append(count) if withscores: params.append("WITHSCORES") - return self.execute_command("ZRANDMEMBER", key, *params) - def bzpopmax(self, keys: KeysT, timeout: TimeoutSecT = 0) -> ResponseT: + def bzpopmax( + self, keys: KeysT, timeout: TimeoutSecT = 0 + ) -> ResponseT[Union[ArrayResponseT, None]]: """ ZPOPMAX a value off of the first non-empty sorted set named in the ``keys`` list. @@ -4354,7 +4337,9 @@ def bzpopmax(self, keys: KeysT, timeout: TimeoutSecT = 0) -> ResponseT: keys.append(timeout) return self.execute_command("BZPOPMAX", *keys) - def bzpopmin(self, keys: KeysT, timeout: TimeoutSecT = 0) -> ResponseT: + def bzpopmin( + self, keys: KeysT, timeout: TimeoutSecT = 0 + ) -> ResponseT[Union[ArrayResponseT, None]]: """ ZPOPMIN a value off of the first non-empty sorted set named in the ``keys`` list. @@ -4369,7 +4354,7 @@ def bzpopmin(self, keys: KeysT, timeout: TimeoutSecT = 0) -> ResponseT: """ if timeout is None: timeout = 0 - keys: list[EncodableT] = list_or_args(keys, None) + keys: List[EncodableT] = list_or_args(keys, None) keys.append(timeout) return self.execute_command("BZPOPMIN", *keys) @@ -4380,7 +4365,7 @@ def zmpop( min: Optional[bool] = False, max: Optional[bool] = False, count: Optional[int] = 1, - ) -> Union[Awaitable[list], list]: + ) -> ResponseT[Union[ArrayResponseT, None]]: """ Pop ``count`` values (default 1) off of the first non-empty sorted set named in the ``keys`` list. @@ -4395,7 +4380,6 @@ def zmpop( args.append("MAX") if count != 1: args.extend(["COUNT", count]) - return self.execute_command("ZMPOP", *args) def bzmpop( @@ -4406,7 +4390,7 @@ def bzmpop( min: Optional[bool] = False, max: Optional[bool] = False, count: Optional[int] = 1, - ) -> Optional[list]: + ) -> ResponseT[Union[ArrayResponseT, None]]: """ Pop ``count`` values (default 1) off of the first non-empty sorted set named in the ``keys`` list. @@ -4427,7 +4411,6 @@ def bzmpop( else: args.append("MAX") args.extend(["COUNT", count]) - return self.execute_command("BZMPOP", *args) def _zrange( @@ -4444,7 +4427,7 @@ def _zrange( score_cast_func: Union[type, Callable, None] = float, offset: Union[int, None] = None, num: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[Any]: if byscore and bylex: raise DataError("``byscore`` and ``bylex`` can not be specified together.") if (offset is not None and num is None) or (num is not None and offset is None): @@ -4481,9 +4464,9 @@ def zrange( score_cast_func: Union[type, Callable] = float, byscore: bool = False, bylex: bool = False, - offset: int = None, - num: int = None, - ) -> ResponseT: + offset: Optional[int] = None, + num: Optional[int] = None, + ) -> ResponseT[ArrayResponseT]: """ Return a range of values from sorted set ``name`` between ``start`` and ``end`` sorted in ascending order. @@ -4516,7 +4499,6 @@ def zrange( # because it was supported in 3.5.3 (of redis-py) if not byscore and not bylex and (offset is None and num is None) and desc: return self.zrevrange(name, start, end, withscores, score_cast_func) - return self._zrange( "ZRANGE", None, @@ -4539,7 +4521,7 @@ def zrevrange( end: int, withscores: bool = False, score_cast_func: Union[type, Callable] = float, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Return a range of values from sorted set ``name`` between ``start`` and ``end`` sorted in descending order. @@ -4571,7 +4553,7 @@ def zrangestore( desc: bool = False, offset: Union[int, None] = None, num: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Stores in ``dest`` the result of a range of values from sorted set ``name`` between ``start`` and ``end`` sorted in ascending order. @@ -4617,7 +4599,7 @@ def zrangebylex( max: EncodableT, start: Union[int, None] = None, num: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Return the lexicographical range of values from sorted set ``name`` between ``min`` and ``max``. @@ -4641,7 +4623,7 @@ def zrevrangebylex( min: EncodableT, start: Union[int, None] = None, num: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Return the reversed lexicographical range of values from sorted set ``name`` between ``max`` and ``min``. @@ -4667,7 +4649,7 @@ def zrangebyscore( num: Union[int, None] = None, withscores: bool = False, score_cast_func: Union[type, Callable] = float, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Return a range of values from the sorted set ``name`` with scores between ``min`` and ``max``. @@ -4702,7 +4684,7 @@ def zrevrangebyscore( num: Union[int, None] = None, withscores: bool = False, score_cast_func: Union[type, Callable] = float, - ): + ) -> ResponseT[ArrayResponseT]: """ Return a range of values from the sorted set ``name`` with scores between ``min`` and ``max`` in descending order. @@ -4729,11 +4711,8 @@ def zrevrangebyscore( return self.execute_command(*pieces, **options) def zrank( - self, - name: KeyT, - value: EncodableT, - withscore: bool = False, - ) -> ResponseT: + self, name: KeyT, value: EncodableT, withscore: bool = False + ) -> ResponseT[Union[IntegerResponseT, None, ArrayResponseT]]: """ Returns a 0-based value indicating the rank of ``value`` in sorted set ``name``. @@ -4746,7 +4725,7 @@ def zrank( return self.execute_command("ZRANK", name, value, "WITHSCORE", keys=[name]) return self.execute_command("ZRANK", name, value, keys=[name]) - def zrem(self, name: KeyT, *values: FieldT) -> ResponseT: + def zrem(self, name: KeyT, *values: FieldT) -> ResponseT[IntegerResponseT]: """ Remove member ``values`` from sorted set ``name`` @@ -4754,7 +4733,9 @@ def zrem(self, name: KeyT, *values: FieldT) -> ResponseT: """ return self.execute_command("ZREM", name, *values) - def zremrangebylex(self, name: KeyT, min: EncodableT, max: EncodableT) -> ResponseT: + def zremrangebylex( + self, name: KeyT, min: EncodableT, max: EncodableT + ) -> ResponseT[IntegerResponseT]: """ Remove all elements in the sorted set ``name`` between the lexicographical range specified by ``min`` and ``max``. @@ -4765,7 +4746,9 @@ def zremrangebylex(self, name: KeyT, min: EncodableT, max: EncodableT) -> Respon """ return self.execute_command("ZREMRANGEBYLEX", name, min, max) - def zremrangebyrank(self, name: KeyT, min: int, max: int) -> ResponseT: + def zremrangebyrank( + self, name: KeyT, min: int, max: int + ) -> ResponseT[IntegerResponseT]: """ Remove all elements in the sorted set ``name`` with ranks between ``min`` and ``max``. Values are 0-based, ordered from smallest score @@ -4778,7 +4761,7 @@ def zremrangebyrank(self, name: KeyT, min: int, max: int) -> ResponseT: def zremrangebyscore( self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Remove all elements in the sorted set ``name`` with scores between ``min`` and ``max``. Returns the number of elements removed. @@ -4788,11 +4771,8 @@ def zremrangebyscore( return self.execute_command("ZREMRANGEBYSCORE", name, min, max) def zrevrank( - self, - name: KeyT, - value: EncodableT, - withscore: bool = False, - ) -> ResponseT: + self, name: KeyT, value: EncodableT, withscore: bool = False + ) -> ResponseT[Union[IntegerResponseT, None, ArrayResponseT]]: """ Returns a 0-based value indicating the descending rank of ``value`` in sorted set ``name``. @@ -4807,7 +4787,9 @@ def zrevrank( ) return self.execute_command("ZREVRANK", name, value, keys=[name]) - def zscore(self, name: KeyT, value: EncodableT) -> ResponseT: + def zscore( + self, name: KeyT, value: EncodableT + ) -> ResponseT[Union[BulkStringResponseT, None]]: """ Return the score of element ``value`` in sorted set ``name`` @@ -4820,7 +4802,7 @@ def zunion( keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], aggregate: Union[str, None] = None, withscores: bool = False, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Return the union of multiple sorted sets specified by ``keys``. ``keys`` can be provided as dictionary of keys and their weights. @@ -4836,7 +4818,7 @@ def zunionstore( dest: KeyT, keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], aggregate: Union[str, None] = None, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Union multiple sorted sets specified by ``keys`` into a new sorted set, ``dest``. Scores in the destination will be @@ -4846,7 +4828,9 @@ def zunionstore( """ return self._zaggregate("ZUNIONSTORE", dest, keys, aggregate) - def zmscore(self, key: KeyT, members: List[str]) -> ResponseT: + def zmscore( + self, key: KeyT, members: List[str] + ) -> ResponseT[Union[ArrayResponseT, None]]: """ Returns the scores associated with the specified members in the sorted set stored at key. @@ -4869,8 +4853,8 @@ def _zaggregate( keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], aggregate: Union[str, None] = None, **options, - ) -> ResponseT: - pieces: list[EncodableT] = [command] + ) -> ResponseT[Any]: + pieces: List[EncodableT] = [command] if dest is not None: pieces.append(dest) pieces.append(len(keys)) @@ -4903,7 +4887,7 @@ class HyperlogCommands(CommandsProtocol): see: https://redis.io/topics/data-types-intro#hyperloglogs """ - def pfadd(self, name: KeyT, *values: FieldT) -> ResponseT: + def pfadd(self, name: KeyT, *values: FieldT) -> ResponseT[IntegerResponseT]: """ Adds the specified elements to the specified HyperLogLog. @@ -4911,7 +4895,7 @@ def pfadd(self, name: KeyT, *values: FieldT) -> ResponseT: """ return self.execute_command("PFADD", name, *values) - def pfcount(self, *sources: KeyT) -> ResponseT: + def pfcount(self, *sources: KeyT) -> ResponseT[IntegerResponseT]: """ Return the approximated cardinality of the set observed by the HyperLogLog at key(s). @@ -4920,7 +4904,7 @@ def pfcount(self, *sources: KeyT) -> ResponseT: """ return self.execute_command("PFCOUNT", *sources) - def pfmerge(self, dest: KeyT, *sources: KeyT) -> ResponseT: + def pfmerge(self, dest: KeyT, *sources: KeyT) -> ResponseT[OKT]: """ Merge N different HyperLogLogs into a single one. @@ -4938,7 +4922,7 @@ class HashCommands(CommandsProtocol): see: https://redis.io/topics/data-types-intro#redis-hashes """ - def hdel(self, name: str, *keys: str) -> Union[Awaitable[int], int]: + def hdel(self, name: str, *keys: str) -> ResponseT[IntegerResponseT]: """ Delete ``keys`` from hash ``name`` @@ -4946,7 +4930,7 @@ def hdel(self, name: str, *keys: str) -> Union[Awaitable[int], int]: """ return self.execute_command("HDEL", name, *keys) - def hexists(self, name: str, key: str) -> Union[Awaitable[bool], bool]: + def hexists(self, name: str, key: str) -> ResponseT[IntegerResponseT]: """ Returns a boolean indicating if ``key`` exists within hash ``name`` @@ -4954,9 +4938,7 @@ def hexists(self, name: str, key: str) -> Union[Awaitable[bool], bool]: """ return self.execute_command("HEXISTS", name, key, keys=[name]) - def hget( - self, name: str, key: str - ) -> Union[Awaitable[Optional[str]], Optional[str]]: + def hget(self, name: str, key: str) -> ResponseT[Union[BulkStringResponseT, None]]: """ Return the value of ``key`` within the hash ``name`` @@ -4964,7 +4946,7 @@ def hget( """ return self.execute_command("HGET", name, key, keys=[name]) - def hgetall(self, name: str) -> Union[Awaitable[dict], dict]: + def hgetall(self, name: str) -> ResponseT[ArrayResponseT]: """ Return a Python dict of the hash's name/value pairs @@ -4974,7 +4956,7 @@ def hgetall(self, name: str) -> Union[Awaitable[dict], dict]: def hincrby( self, name: str, key: str, amount: int = 1 - ) -> Union[Awaitable[int], int]: + ) -> ResponseT[IntegerResponseT]: """ Increment the value of ``key`` in hash ``name`` by ``amount`` @@ -4984,7 +4966,7 @@ def hincrby( def hincrbyfloat( self, name: str, key: str, amount: float = 1.0 - ) -> Union[Awaitable[float], float]: + ) -> ResponseT[BulkStringResponseT]: """ Increment the value of ``key`` in hash ``name`` by floating ``amount`` @@ -4992,7 +4974,7 @@ def hincrbyfloat( """ return self.execute_command("HINCRBYFLOAT", name, key, amount) - def hkeys(self, name: str) -> Union[Awaitable[List], List]: + def hkeys(self, name: str) -> ResponseT[ArrayResponseT]: """ Return the list of keys within hash ``name`` @@ -5000,7 +4982,7 @@ def hkeys(self, name: str) -> Union[Awaitable[List], List]: """ return self.execute_command("HKEYS", name, keys=[name]) - def hlen(self, name: str) -> Union[Awaitable[int], int]: + def hlen(self, name: str) -> ResponseT[IntegerResponseT]: """ Return the number of elements in hash ``name`` @@ -5014,8 +4996,8 @@ def hset( key: Optional[str] = None, value: Optional[str] = None, mapping: Optional[dict] = None, - items: Optional[list] = None, - ) -> Union[Awaitable[int], int]: + items: Optional[List] = None, + ) -> ResponseT[IntegerResponseT]: """ Set ``key`` to ``value`` within hash ``name``, ``mapping`` accepts a dict of key/value pairs that will be @@ -5036,10 +5018,9 @@ def hset( if mapping: for pair in mapping.items(): pieces.extend(pair) - return self.execute_command("HSET", name, *pieces) - def hsetnx(self, name: str, key: str, value: str) -> Union[Awaitable[bool], bool]: + def hsetnx(self, name: str, key: str, value: str) -> ResponseT[IntegerResponseT]: """ Set ``key`` to ``value`` within hash ``name`` if ``key`` does not exist. Returns 1 if HSETNX created a field, otherwise 0. @@ -5048,7 +5029,7 @@ def hsetnx(self, name: str, key: str, value: str) -> Union[Awaitable[bool], bool """ return self.execute_command("HSETNX", name, key, value) - def hmset(self, name: str, mapping: dict) -> Union[Awaitable[str], str]: + def hmset(self, name: str, mapping: dict) -> ResponseT[OKT]: """ Set key to value within hash ``name`` for each corresponding key and value from the ``mapping`` dict. @@ -5068,7 +5049,7 @@ def hmset(self, name: str, mapping: dict) -> Union[Awaitable[str], str]: items.extend(pair) return self.execute_command("HMSET", name, *items) - def hmget(self, name: str, keys: List, *args: List) -> Union[Awaitable[List], List]: + def hmget(self, name: str, keys: List, *args: List) -> ResponseT[ArrayResponseT]: """ Returns a list of values ordered identically to ``keys`` @@ -5077,7 +5058,7 @@ def hmget(self, name: str, keys: List, *args: List) -> Union[Awaitable[List], Li args = list_or_args(keys, args) return self.execute_command("HMGET", name, *args, keys=[name]) - def hvals(self, name: str) -> Union[Awaitable[List], List]: + def hvals(self, name: str) -> ResponseT[ArrayResponseT]: """ Return the list of values within hash ``name`` @@ -5085,7 +5066,7 @@ def hvals(self, name: str) -> Union[Awaitable[List], List]: """ return self.execute_command("HVALS", name, keys=[name]) - def hstrlen(self, name: str, key: str) -> Union[Awaitable[int], int]: + def hstrlen(self, name: str, key: str) -> ResponseT[IntegerResponseT]: """ Return the number of bytes stored in the value of ``key`` within hash ``name`` @@ -5103,7 +5084,7 @@ def hexpire( xx: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseT[Union[ArrayResponseT, IntegerResponseT]]: """ Sets or updates the expiration time for fields within a hash key, using relative time in seconds. @@ -5137,10 +5118,8 @@ def hexpire( conditions = [nx, xx, gt, lt] if sum(conditions) > 1: raise ValueError("Only one of 'nx', 'xx', 'gt', 'lt' can be specified.") - if isinstance(seconds, datetime.timedelta): seconds = int(seconds.total_seconds()) - options = [] if nx: options.append("NX") @@ -5150,7 +5129,6 @@ def hexpire( options.append("GT") if lt: options.append("LT") - return self.execute_command( "HEXPIRE", name, seconds, *options, "FIELDS", len(fields), *fields ) @@ -5164,7 +5142,7 @@ def hpexpire( xx: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseT[Union[ArrayResponseT, IntegerResponseT]]: """ Sets or updates the expiration time for fields within a hash key, using relative time in milliseconds. @@ -5198,10 +5176,8 @@ def hpexpire( conditions = [nx, xx, gt, lt] if sum(conditions) > 1: raise ValueError("Only one of 'nx', 'xx', 'gt', 'lt' can be specified.") - if isinstance(milliseconds, datetime.timedelta): milliseconds = int(milliseconds.total_seconds() * 1000) - options = [] if nx: options.append("NX") @@ -5211,7 +5187,6 @@ def hpexpire( options.append("GT") if lt: options.append("LT") - return self.execute_command( "HPEXPIRE", name, milliseconds, *options, "FIELDS", len(fields), *fields ) @@ -5225,7 +5200,7 @@ def hexpireat( xx: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseT[Union[ArrayResponseT, IntegerResponseT]]: """ Sets or updates the expiration time for fields within a hash key, using an absolute Unix timestamp in seconds. @@ -5259,10 +5234,8 @@ def hexpireat( conditions = [nx, xx, gt, lt] if sum(conditions) > 1: raise ValueError("Only one of 'nx', 'xx', 'gt', 'lt' can be specified.") - if isinstance(unix_time_seconds, datetime.datetime): unix_time_seconds = int(unix_time_seconds.timestamp()) - options = [] if nx: options.append("NX") @@ -5272,7 +5245,6 @@ def hexpireat( options.append("GT") if lt: options.append("LT") - return self.execute_command( "HEXPIREAT", name, @@ -5292,7 +5264,7 @@ def hpexpireat( xx: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseT[Union[ArrayResponseT, IntegerResponseT]]: """ Sets or updates the expiration time for fields within a hash key, using an absolute Unix timestamp in milliseconds. @@ -5326,10 +5298,8 @@ def hpexpireat( conditions = [nx, xx, gt, lt] if sum(conditions) > 1: raise ValueError("Only one of 'nx', 'xx', 'gt', 'lt' can be specified.") - if isinstance(unix_time_milliseconds, datetime.datetime): unix_time_milliseconds = int(unix_time_milliseconds.timestamp() * 1000) - options = [] if nx: options.append("NX") @@ -5339,7 +5309,6 @@ def hpexpireat( options.append("GT") if lt: options.append("LT") - return self.execute_command( "HPEXPIREAT", name, @@ -5350,7 +5319,9 @@ def hpexpireat( *fields, ) - def hpersist(self, name: KeyT, *fields: str) -> ResponseT: + def hpersist( + self, name: KeyT, *fields: str + ) -> ResponseT[Union[ArrayResponseT, IntegerResponseT]]: """ Removes the expiration time for each specified field in a hash. @@ -5370,7 +5341,9 @@ def hpersist(self, name: KeyT, *fields: str) -> ResponseT: """ return self.execute_command("HPERSIST", name, "FIELDS", len(fields), *fields) - def hexpiretime(self, key: KeyT, *fields: str) -> ResponseT: + def hexpiretime( + self, key: KeyT, *fields: str + ) -> ResponseT[Union[ArrayResponseT, IntegerResponseT]]: """ Returns the expiration times of hash fields as Unix timestamps in seconds. @@ -5393,7 +5366,9 @@ def hexpiretime(self, key: KeyT, *fields: str) -> ResponseT: "HEXPIRETIME", key, "FIELDS", len(fields), *fields, keys=[key] ) - def hpexpiretime(self, key: KeyT, *fields: str) -> ResponseT: + def hpexpiretime( + self, key: KeyT, *fields: str + ) -> ResponseT[Union[ArrayResponseT, IntegerResponseT]]: """ Returns the expiration times of hash fields as Unix timestamps in milliseconds. @@ -5416,7 +5391,7 @@ def hpexpiretime(self, key: KeyT, *fields: str) -> ResponseT: "HPEXPIRETIME", key, "FIELDS", len(fields), *fields, keys=[key] ) - def httl(self, key: KeyT, *fields: str) -> ResponseT: + def httl(self, key: KeyT, *fields: str) -> ResponseT[ArrayResponseT]: """ Returns the TTL (Time To Live) in seconds for each specified field within a hash key. @@ -5439,7 +5414,7 @@ def httl(self, key: KeyT, *fields: str) -> ResponseT: "HTTL", key, "FIELDS", len(fields), *fields, keys=[key] ) - def hpttl(self, key: KeyT, *fields: str) -> ResponseT: + def hpttl(self, key: KeyT, *fields: str) -> ResponseT[ArrayResponseT]: """ Returns the TTL (Time To Live) in milliseconds for each specified field within a hash key. @@ -5570,7 +5545,9 @@ class PubSubCommands(CommandsProtocol): see https://redis.io/topics/pubsub """ - def publish(self, channel: ChannelT, message: EncodableT, **kwargs) -> ResponseT: + def publish( + self, channel: ChannelT, message: EncodableT, **kwargs + ) -> ResponseT[IntegerResponseT]: """ Publish ``message`` on ``channel``. Returns the number of subscribers the message was delivered to. @@ -5579,7 +5556,9 @@ def publish(self, channel: ChannelT, message: EncodableT, **kwargs) -> ResponseT """ return self.execute_command("PUBLISH", channel, message, **kwargs) - def spublish(self, shard_channel: ChannelT, message: EncodableT) -> ResponseT: + def spublish( + self, shard_channel: ChannelT, message: EncodableT + ) -> ResponseT[IntegerResponseT]: """ Posts a message to the given shard channel. Returns the number of clients that received the message @@ -5588,7 +5567,9 @@ def spublish(self, shard_channel: ChannelT, message: EncodableT) -> ResponseT: """ return self.execute_command("SPUBLISH", shard_channel, message) - def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: + def pubsub_channels( + self, pattern: PatternT = "*", **kwargs + ) -> ResponseT[ArrayResponseT]: """ Return a list of channels that have at least one subscriber @@ -5596,7 +5577,9 @@ def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: """ return self.execute_command("PUBSUB CHANNELS", pattern, **kwargs) - def pubsub_shardchannels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: + def pubsub_shardchannels( + self, pattern: PatternT = "*", **kwargs + ) -> ResponseT[ArrayResponseT]: """ Return a list of shard_channels that have at least one subscriber @@ -5604,7 +5587,7 @@ def pubsub_shardchannels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: """ return self.execute_command("PUBSUB SHARDCHANNELS", pattern, **kwargs) - def pubsub_numpat(self, **kwargs) -> ResponseT: + def pubsub_numpat(self, **kwargs) -> ResponseT[IntegerResponseT]: """ Returns the number of subscriptions to patterns @@ -5612,7 +5595,7 @@ def pubsub_numpat(self, **kwargs) -> ResponseT: """ return self.execute_command("PUBSUB NUMPAT", **kwargs) - def pubsub_numsub(self, *args: ChannelT, **kwargs) -> ResponseT: + def pubsub_numsub(self, *args: ChannelT, **kwargs) -> ResponseT[ArrayResponseT]: """ Return a list of (channel, number of subscribers) tuples for each channel given in ``*args`` @@ -5621,7 +5604,9 @@ def pubsub_numsub(self, *args: ChannelT, **kwargs) -> ResponseT: """ return self.execute_command("PUBSUB NUMSUB", *args, **kwargs) - def pubsub_shardnumsub(self, *args: ChannelT, **kwargs) -> ResponseT: + def pubsub_shardnumsub( + self, *args: ChannelT, **kwargs + ) -> ResponseT[ArrayResponseT]: """ Return a list of (shard_channel, number of subscribers) tuples for each channel given in ``*args`` @@ -5642,12 +5627,10 @@ class ScriptCommands(CommandsProtocol): def _eval( self, command: str, script: str, numkeys: int, *keys_and_args: str - ) -> Union[Awaitable[str], str]: + ) -> ResponseT[Any]: return self.execute_command(command, script, numkeys, *keys_and_args) - def eval( - self, script: str, numkeys: int, *keys_and_args: str - ) -> Union[Awaitable[str], str]: + def eval(self, script: str, numkeys: int, *keys_and_args: str) -> ResponseT[str]: """ Execute the Lua ``script``, specifying the ``numkeys`` the script will touch and the key names and argument values in ``keys_and_args``. @@ -5660,9 +5643,7 @@ def eval( """ return self._eval("EVAL", script, numkeys, *keys_and_args) - def eval_ro( - self, script: str, numkeys: int, *keys_and_args: str - ) -> Union[Awaitable[str], str]: + def eval_ro(self, script: str, numkeys: int, *keys_and_args: str) -> ResponseT[str]: """ The read-only variant of the EVAL command @@ -5675,13 +5656,11 @@ def eval_ro( return self._eval("EVAL_RO", script, numkeys, *keys_and_args) def _evalsha( - self, command: str, sha: str, numkeys: int, *keys_and_args: list - ) -> Union[Awaitable[str], str]: + self, command: str, sha: str, numkeys: int, *keys_and_args: List + ) -> ResponseT[Any]: return self.execute_command(command, sha, numkeys, *keys_and_args) - def evalsha( - self, sha: str, numkeys: int, *keys_and_args: str - ) -> Union[Awaitable[str], str]: + def evalsha(self, sha: str, numkeys: int, *keys_and_args: str) -> ResponseT[str]: """ Use the ``sha`` to execute a Lua script already registered via EVAL or SCRIPT LOAD. Specify the ``numkeys`` the script will touch and the @@ -5695,9 +5674,7 @@ def evalsha( """ return self._evalsha("EVALSHA", sha, numkeys, *keys_and_args) - def evalsha_ro( - self, sha: str, numkeys: int, *keys_and_args: str - ) -> Union[Awaitable[str], str]: + def evalsha_ro(self, sha: str, numkeys: int, *keys_and_args: str) -> ResponseT[str]: """ The read-only variant of the EVALSHA command @@ -5710,7 +5687,7 @@ def evalsha_ro( """ return self._evalsha("EVALSHA_RO", sha, numkeys, *keys_and_args) - def script_exists(self, *args: str) -> ResponseT: + def script_exists(self, *args: str) -> ResponseT[ArrayResponseT]: """ Check if a script exists in the script cache by specifying the SHAs of each script as ``args``. Returns a list of boolean values indicating if @@ -5727,7 +5704,7 @@ def script_debug(self, *args) -> None: def script_flush( self, sync_type: Union[Literal["SYNC"], Literal["ASYNC"]] = None - ) -> ResponseT: + ) -> ResponseT[OKT]: """Flush all scripts from the script cache. ``sync_type`` is by default SYNC (synchronous) but it can also be @@ -5749,7 +5726,7 @@ def script_flush( pieces = [sync_type] return self.execute_command("SCRIPT FLUSH", *pieces) - def script_kill(self) -> ResponseT: + def script_kill(self) -> ResponseT[OKT]: """ Kill the currently executing Lua script @@ -5757,7 +5734,7 @@ def script_kill(self) -> ResponseT: """ return self.execute_command("SCRIPT KILL") - def script_load(self, script: ScriptTextT) -> ResponseT: + def script_load(self, script: ScriptTextT) -> ResponseT[BulkStringResponseT]: """ Load a Lua ``script`` into the script cache. Returns the SHA. @@ -5776,6 +5753,7 @@ def register_script(self: "Redis", script: ScriptTextT) -> Script: class AsyncScriptCommands(ScriptCommands): + async def script_debug(self, *args) -> None: return super().script_debug() @@ -5802,7 +5780,7 @@ def geoadd( nx: bool = False, xx: bool = False, ch: bool = False, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Add the specified geospatial items to the specified key identified by the ``name`` argument. The Geospatial items are given as ordered @@ -5839,7 +5817,7 @@ def geoadd( def geodist( self, name: KeyT, place1: FieldT, place2: FieldT, unit: Union[str, None] = None - ) -> ResponseT: + ) -> ResponseT[Union[BulkStringResponseT, None]]: """ Return the distance between ``place1`` and ``place2`` members of the ``name`` key. @@ -5848,14 +5826,14 @@ def geodist( For more information see https://redis.io/commands/geodist """ - pieces: list[EncodableT] = [name, place1, place2] + pieces: List[EncodableT] = [name, place1, place2] if unit and unit not in ("m", "km", "mi", "ft"): raise DataError("GEODIST invalid unit") elif unit: pieces.append(unit) return self.execute_command("GEODIST", *pieces, keys=[name]) - def geohash(self, name: KeyT, *values: FieldT) -> ResponseT: + def geohash(self, name: KeyT, *values: FieldT) -> ResponseT[ArrayResponseT]: """ Return the geo hash string for each item of ``values`` members of the specified key identified by the ``name`` argument. @@ -5864,7 +5842,7 @@ def geohash(self, name: KeyT, *values: FieldT) -> ResponseT: """ return self.execute_command("GEOHASH", name, *values, keys=[name]) - def geopos(self, name: KeyT, *values: FieldT) -> ResponseT: + def geopos(self, name: KeyT, *values: FieldT) -> ResponseT[ArrayResponseT]: """ Return the positions of each item of ``values`` as members of the specified key identified by the ``name`` argument. Each position @@ -5978,7 +5956,7 @@ def georadiusbymember( def _georadiusgeneric( self, command: str, *args: EncodableT, **kwargs: Union[EncodableT, None] - ) -> ResponseT: + ) -> ResponseT[Any]: pieces = list(args) if kwargs["unit"] and kwargs["unit"] not in ("m", "km", "mi", "ft"): raise DataError("GEORADIUS invalid unit") @@ -5986,10 +5964,8 @@ def _georadiusgeneric( pieces.append(kwargs["unit"]) else: pieces.append("m") - if kwargs["any"] and kwargs["count"] is None: raise DataError("``any`` can't be provided without ``count``") - for arg_name, byte_repr in ( ("withdist", "WITHDIST"), ("withcoord", "WITHCOORD"), @@ -5997,12 +5973,10 @@ def _georadiusgeneric( ): if kwargs[arg_name]: pieces.append(byte_repr) - if kwargs["count"] is not None: pieces.extend(["COUNT", kwargs["count"]]) if kwargs["any"]: pieces.append("ANY") - if kwargs["sort"]: if kwargs["sort"] == "ASC": pieces.append("ASC") @@ -6010,16 +5984,12 @@ def _georadiusgeneric( pieces.append("DESC") else: raise DataError("GEORADIUS invalid sort") - if kwargs["store"] and kwargs["store_dist"]: raise DataError("GEORADIUS store and store_dist cant be set together") - if kwargs["store"]: pieces.extend([b"STORE", kwargs["store"]]) - if kwargs["store_dist"]: pieces.extend([b"STOREDIST", kwargs["store_dist"]]) - return self.execute_command(command, *pieces, **kwargs) def geosearch( @@ -6082,7 +6052,6 @@ def geosearch( For more information see https://redis.io/commands/geosearch """ - return self._geosearchgeneric( "GEOSEARCH", name, @@ -6152,7 +6121,7 @@ def geosearchstore( def _geosearchgeneric( self, command: str, *args: EncodableT, **kwargs: Union[EncodableT, None] - ) -> ResponseT: + ) -> ResponseT[Any]: pieces = list(args) # FROMMEMBER or FROMLONLAT @@ -6211,9 +6180,7 @@ def _geosearchgeneric( ): if kwargs[arg_name]: pieces.append(byte_repr) - kwargs["keys"] = [args[0] if command == "GEOSEARCH" else args[1]] - return self.execute_command(command, *pieces, **kwargs) @@ -6226,7 +6193,7 @@ class ModuleCommands(CommandsProtocol): see: https://redis.io/topics/modules-intro """ - def module_load(self, path, *args) -> ResponseT: + def module_load(self, path, *args) -> ResponseT[OKT]: """ Loads the module from ``path``. Passes all ``*args`` to the module, during loading. @@ -6241,7 +6208,7 @@ def module_loadex( path: str, options: Optional[List[str]] = None, args: Optional[List[str]] = None, - ) -> ResponseT: + ) -> ResponseT[OKT]: """ Loads a module from a dynamic library at runtime with configuration directives. @@ -6254,10 +6221,9 @@ def module_loadex( if args is not None: pieces.append("ARGS") pieces.extend(args) - return self.execute_command("MODULE LOADEX", path, *pieces) - def module_unload(self, name) -> ResponseT: + def module_unload(self, name) -> ResponseT[OKT]: """ Unloads the module ``name``. Raises ``ModuleError`` if ``name`` is not in loaded modules. @@ -6266,7 +6232,7 @@ def module_unload(self, name) -> ResponseT: """ return self.execute_command("MODULE UNLOAD", name) - def module_list(self) -> ResponseT: + def module_list(self) -> ResponseT[ArrayResponseT]: """ Returns a list of dictionaries containing the name and version of all loaded modules. @@ -6280,13 +6246,13 @@ def command_info(self) -> None: "COMMAND INFO is intentionally not implemented in the client." ) - def command_count(self) -> ResponseT: + def command_count(self) -> ResponseT[IntegerResponseT]: return self.execute_command("COMMAND COUNT") - def command_getkeys(self, *args) -> ResponseT: + def command_getkeys(self, *args) -> ResponseT[ArrayResponseT]: return self.execute_command("COMMAND GETKEYS", *args) - def command(self) -> ResponseT: + def command(self) -> ResponseT[ArrayResponseT]: return self.execute_command("COMMAND") @@ -6347,6 +6313,7 @@ def get_encoder(self): class AsyncModuleCommands(ModuleCommands): + async def command_info(self) -> None: return super().command_info() @@ -6356,10 +6323,10 @@ class ClusterCommands(CommandsProtocol): Class for Redis Cluster commands """ - def cluster(self, cluster_arg, *args, **kwargs) -> ResponseT: + def cluster(self, cluster_arg, *args, **kwargs) -> ResponseT[Any]: return self.execute_command(f"CLUSTER {cluster_arg.upper()}", *args, **kwargs) - def readwrite(self, **kwargs) -> ResponseT: + def readwrite(self, **kwargs) -> ResponseT[OKT]: """ Disables read queries for a connection to a Redis Cluster slave node. @@ -6367,7 +6334,7 @@ def readwrite(self, **kwargs) -> ResponseT: """ return self.execute_command("READWRITE", **kwargs) - def readonly(self, **kwargs) -> ResponseT: + def readonly(self, **kwargs) -> ResponseT[OKT]: """ Enables read queries for a connection to a Redis Cluster replica node. @@ -6379,14 +6346,14 @@ def readonly(self, **kwargs) -> ResponseT: AsyncClusterCommands = ClusterCommands -class FunctionCommands: +class FunctionCommands(CommandsProtocol): """ Redis Function commands """ def function_load( self, code: str, replace: Optional[bool] = False - ) -> Union[Awaitable[str], str]: + ) -> ResponseT[BulkStringResponseT]: """ Load a library to Redis. :param code: the source code (must start with @@ -6401,7 +6368,7 @@ def function_load( pieces.append(code) return self.execute_command("FUNCTION LOAD", *pieces) - def function_delete(self, library: str) -> Union[Awaitable[str], str]: + def function_delete(self, library: str) -> ResponseT[OKT]: """ Delete the library called ``library`` and all its functions. @@ -6409,7 +6376,7 @@ def function_delete(self, library: str) -> Union[Awaitable[str], str]: """ return self.execute_command("FUNCTION DELETE", library) - def function_flush(self, mode: str = "SYNC") -> Union[Awaitable[str], str]: + def function_flush(self, mode: str = "SYNC") -> ResponseT[OKT]: """ Deletes all the libraries. @@ -6419,7 +6386,7 @@ def function_flush(self, mode: str = "SYNC") -> Union[Awaitable[str], str]: def function_list( self, library: Optional[str] = "*", withcode: Optional[bool] = False - ) -> Union[Awaitable[List], List]: + ) -> ResponseT[ArrayResponseT]: """ Return information about the functions and libraries. :param library: pecify a pattern for matching library names @@ -6433,12 +6400,12 @@ def function_list( def _fcall( self, command: str, function, numkeys: int, *keys_and_args: Optional[List] - ) -> Union[Awaitable[str], str]: + ) -> ResponseT[Any]: return self.execute_command(command, function, numkeys, *keys_and_args) def fcall( self, function, numkeys: int, *keys_and_args: Optional[List] - ) -> Union[Awaitable[str], str]: + ) -> ResponseT[str]: """ Invoke a function. @@ -6448,7 +6415,7 @@ def fcall( def fcall_ro( self, function, numkeys: int, *keys_and_args: Optional[List] - ) -> Union[Awaitable[str], str]: + ) -> ResponseT[str]: """ This is a read-only variant of the FCALL command that cannot execute commands that modify data. @@ -6457,7 +6424,7 @@ def fcall_ro( """ return self._fcall("FCALL_RO", function, numkeys, *keys_and_args) - def function_dump(self) -> Union[Awaitable[str], str]: + def function_dump(self) -> ResponseT[BulkStringResponseT]: """ Return the serialized payload of loaded libraries. @@ -6467,12 +6434,11 @@ def function_dump(self) -> Union[Awaitable[str], str]: options = {} options[NEVER_DECODE] = [] - return self.execute_command("FUNCTION DUMP", **options) def function_restore( self, payload: str, policy: Optional[str] = "APPEND" - ) -> Union[Awaitable[str], str]: + ) -> ResponseT[OKT]: """ Restore libraries from the serialized ``payload``. You can use the optional policy argument to provide a policy @@ -6482,7 +6448,7 @@ def function_restore( """ return self.execute_command("FUNCTION RESTORE", payload, policy) - def function_kill(self) -> Union[Awaitable[str], str]: + def function_kill(self) -> ResponseT[OKT]: """ Kill a function that is currently executing. @@ -6490,7 +6456,7 @@ def function_kill(self) -> Union[Awaitable[str], str]: """ return self.execute_command("FUNCTION KILL") - def function_stats(self) -> Union[Awaitable[List], List]: + def function_stats(self) -> ResponseT[ArrayResponseT]: """ Return information about the function that's currently running and information about the available execution engines. @@ -6503,7 +6469,8 @@ def function_stats(self) -> Union[Awaitable[List], List]: AsyncFunctionCommands = FunctionCommands -class GearsCommands: +class GearsCommands(CommandsProtocol): + def tfunction_load( self, lib_code: str, replace: bool = False, config: Union[str, None] = None ) -> ResponseT: @@ -6563,14 +6530,13 @@ def tfunction_list( if lib_name is not None: pieces.append("LIBRARY") pieces.append(lib_name) - return self.execute_command("TFUNCTION LIST", *pieces) def _tfcall( self, lib_name: str, func_name: str, - keys: KeysT = None, + keys: Optional[KeysT] = None, _async: bool = False, *args: List, ) -> ResponseT: @@ -6587,11 +6553,7 @@ def _tfcall( return self.execute_command("TFCALL", *pieces) def tfcall( - self, - lib_name: str, - func_name: str, - keys: KeysT = None, - *args: List, + self, lib_name: str, func_name: str, keys: Optional[KeysT] = None, *args: List ) -> ResponseT: """ Invoke a function. @@ -6606,11 +6568,7 @@ def tfcall( return self._tfcall(lib_name, func_name, keys, False, *args) def tfcall_async( - self, - lib_name: str, - func_name: str, - keys: KeysT = None, - *args: List, + self, lib_name: str, func_name: str, keys: Optional[KeysT] = None, *args: List ) -> ResponseT: """ Invoke an async function (coroutine). diff --git a/redis/typing.py b/redis/typing.py index ee4296245..13c679e75 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -1,11 +1,11 @@ -# from __future__ import annotations - from datetime import datetime, timedelta from typing import ( TYPE_CHECKING, Any, Awaitable, Iterable, + List, + Literal, Mapping, Protocol, Type, @@ -32,7 +32,14 @@ PatternT = _StringLikeT # Patterns matched against keys, fields etc FieldT = EncodableT # Fields within hash tables, streams and geo commands KeysT = Union[KeyT, Iterable[KeyT]] -ResponseT = Union[Awaitable[Any], Any] +OldResponseT = Union[Awaitable[Any], Any] # Deprecated +AnyResponseT = TypeVar("AnyResponseT", bound=Any) +ResponseT = Union[AnyResponseT, Awaitable[AnyResponseT]] +OKT = Literal[True] +ArrayResponseT = List +IntegerResponseT = int +NullResponseT = type(None) +BulkStringResponseT = str ChannelT = _StringLikeT GroupT = _StringLikeT # Consumer group ConsumerT = _StringLikeT # Consumer name @@ -54,10 +61,10 @@ class CommandsProtocol(Protocol): connection_pool: Union["AsyncConnectionPool", "ConnectionPool"] - def execute_command(self, *args, **options): ... + def execute_command(self, *args, **options) -> ResponseT[Any]: ... class ClusterCommandsProtocol(CommandsProtocol, Protocol): encoder: "Encoder" - def execute_command(self, *args, **options) -> Union[Any, Awaitable]: ... + def execute_command(self, *args, **options) -> ResponseT[Any]: ...