diff --git a/stubs/redis/redis/asyncio/client.pyi b/stubs/redis/redis/asyncio/client.pyi index 2457d315f6f7..df616e6be01e 100644 --- a/stubs/redis/redis/asyncio/client.pyi +++ b/stubs/redis/redis/asyncio/client.pyi @@ -44,7 +44,7 @@ class Redis(AbstractRedis, RedisModuleCommands, AsyncCoreCommands[_StrType], Asy socket_connect_timeout: float | None = None, socket_keepalive: bool | None = None, socket_keepalive_options: Mapping[int, int | bytes] | None = None, - connection_pool: ConnectionPool | None = None, + connection_pool: ConnectionPool[Any] | None = None, unix_socket_path: str | None = None, encoding: str = "utf-8", encoding_errors: str = "strict", @@ -82,7 +82,7 @@ class Redis(AbstractRedis, RedisModuleCommands, AsyncCoreCommands[_StrType], Asy socket_connect_timeout: float | None = None, socket_keepalive: bool | None = None, socket_keepalive_options: Mapping[int, int | bytes] | None = None, - connection_pool: ConnectionPool | None = None, + connection_pool: ConnectionPool[Any] | None = None, unix_socket_path: str | None = None, encoding: str = "utf-8", encoding_errors: str = "strict", @@ -118,7 +118,7 @@ class Redis(AbstractRedis, RedisModuleCommands, AsyncCoreCommands[_StrType], Asy socket_connect_timeout: float | None = None, socket_keepalive: bool | None = None, socket_keepalive_options: Mapping[int, int | bytes] | None = None, - connection_pool: ConnectionPool | None = None, + connection_pool: ConnectionPool[Any] | None = None, unix_socket_path: str | None = None, encoding: str = "utf-8", encoding_errors: str = "strict", @@ -154,7 +154,7 @@ class Redis(AbstractRedis, RedisModuleCommands, AsyncCoreCommands[_StrType], Asy socket_connect_timeout: float | None = None, socket_keepalive: bool | None = None, socket_keepalive_options: Mapping[int, int | bytes] | None = None, - connection_pool: ConnectionPool | None = None, + connection_pool: ConnectionPool[Any] | None = None, unix_socket_path: str | None = None, encoding: str = "utf-8", encoding_errors: str = "strict", @@ -228,7 +228,7 @@ class Monitor: command_re: Any connection_pool: Any connection: Any - def __init__(self, connection_pool: ConnectionPool) -> None: ... + def __init__(self, connection_pool: ConnectionPool[Any]) -> None: ... async def connect(self) -> None: ... async def __aenter__(self) -> Self: ... async def __aexit__(self, *args: Unused) -> None: ... @@ -251,7 +251,7 @@ class PubSub: pending_unsubscribe_patterns: Any def __init__( self, - connection_pool: ConnectionPool, + connection_pool: ConnectionPool[Any], shard_hint: str | None = None, ignore_subscribe_messages: bool = False, encoder: Incomplete | None = None, @@ -302,7 +302,7 @@ class Pipeline(Redis[_StrType]): explicit_transaction: bool def __init__( self, - connection_pool: ConnectionPool, + connection_pool: ConnectionPool[Any], response_callbacks: MutableMapping[str | bytes, ResponseCallbackT], transaction: bool, shard_hint: str | None, diff --git a/stubs/redis/redis/asyncio/cluster.pyi b/stubs/redis/redis/asyncio/cluster.pyi index 25d8f9dd143f..2ac5263d45f9 100644 --- a/stubs/redis/redis/asyncio/cluster.pyi +++ b/stubs/redis/redis/asyncio/cluster.pyi @@ -1,11 +1,11 @@ from _typeshed import Incomplete from collections.abc import Awaitable, Callable, Mapping from types import TracebackType -from typing import Any, Generic +from typing import Any, Generic, TypeVar from typing_extensions import Self from redis.asyncio.client import ResponseCallbackT -from redis.asyncio.connection import BaseParser, Connection, Encoder +from redis.asyncio.connection import AbstractConnection, BaseParser, Connection, Encoder from redis.asyncio.parser import CommandsParser from redis.client import AbstractRedis from redis.cluster import AbstractRedisCluster, LoadBalancer @@ -14,13 +14,65 @@ from redis.cluster import AbstractRedisCluster, LoadBalancer # from redis.commands import AsyncRedisClusterCommands from redis.commands.core import _StrType from redis.credentials import CredentialProvider +from redis.exceptions import ResponseError from redis.retry import Retry from redis.typing import AnyKeyT, EncodableT, KeyT +TargetNodesT = TypeVar("TargetNodesT", str, ClusterNode, list[ClusterNode], dict[Any, ClusterNode]) # noqa: Y001 + # It uses `DefaultParser` in real life, but it is a dynamic base class. -class ClusterParser(BaseParser): ... +class ClusterParser(BaseParser): + def on_disconnect(self) -> None: ... + def on_connect(self, connection: AbstractConnection) -> None: ... + async def can_read_destructive(self) -> bool: ... + async def read_response(self, disable_decoding: bool = False) -> EncodableT | ResponseError | list[EncodableT] | None: ... class RedisCluster(AbstractRedis, AbstractRedisCluster, Generic[_StrType]): # TODO: AsyncRedisClusterCommands + @classmethod + def from_url( + cls, + url: str, + *, + host: str | None = None, + port: str | int = 6379, + # Cluster related kwargs + startup_nodes: list[ClusterNode] | None = None, + require_full_coverage: bool = True, + read_from_replicas: bool = False, + reinitialize_steps: int = 5, + cluster_error_retry_attempts: int = 3, + connection_error_retry_attempts: int = 3, + max_connections: int = 2147483648, + # Client related kwargs + db: str | int = 0, + path: str | None = None, + credential_provider: CredentialProvider | None = None, + username: str | None = None, + password: str | None = None, + client_name: str | None = None, + # Encoding related kwargs + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + # Connection related kwargs + health_check_interval: float = 0, + socket_connect_timeout: float | None = None, + socket_keepalive: bool = False, + socket_keepalive_options: Mapping[int, int | bytes] | None = None, + socket_timeout: float | None = None, + retry: Retry | None = None, + retry_on_error: list[Exception] | None = None, + # SSL related kwargs + ssl: bool = False, + ssl_ca_certs: str | None = None, + ssl_ca_data: str | None = None, + ssl_cert_reqs: str = "required", + ssl_certfile: str | None = None, + ssl_check_hostname: bool = False, + ssl_keyfile: str | None = None, + address_remap: Callable[[str, int], tuple[str, int]] | None = None, + ) -> Self: ... + retry: Retry | None connection_kwargs: dict[str, Any] nodes_manager: NodesManager @@ -34,6 +86,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, Generic[_StrType]): # T command_flags: dict[str, str] response_callbacks: Incomplete result_callbacks: dict[str, Callable[[Incomplete, Incomplete], Incomplete]] + def __init__( self, host: str | None = None, @@ -98,8 +151,6 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, Generic[_StrType]): # T def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None: ... async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: ... def pipeline(self, transaction: Any | None = None, shard_hint: Any | None = None) -> ClusterPipeline[_StrType]: ... - @classmethod - def from_url(cls, url: str, **kwargs) -> Self: ... class ClusterNode: host: str diff --git a/stubs/redis/redis/asyncio/connection.pyi b/stubs/redis/redis/asyncio/connection.pyi index 988cb36110da..a20696a4ba76 100644 --- a/stubs/redis/redis/asyncio/connection.pyi +++ b/stubs/redis/redis/asyncio/connection.pyi @@ -1,38 +1,42 @@ import asyncio import enum import ssl -from _typeshed import Incomplete +from _typeshed import Unused +from abc import abstractmethod from collections.abc import Callable, Iterable, Mapping -from typing import Any, Literal, Protocol, TypedDict, overload -from typing_extensions import TypeAlias +from types import MappingProxyType +from typing import Any, Final, Generic, Literal, Protocol, TypedDict, TypeVar, overload +from typing_extensions import Self, TypeAlias -from redis import RedisError from redis.asyncio.retry import Retry from redis.credentials import CredentialProvider -from redis.exceptions import ResponseError +from redis.exceptions import AuthenticationError, RedisError, ResponseError from redis.typing import EncodableT, EncodedT -hiredis: Any -SYM_STAR: bytes -SYM_DOLLAR: bytes -SYM_CRLF: bytes -SYM_LF: bytes -SYM_EMPTY: bytes -SERVER_CLOSED_CONNECTION_ERROR: str +_SSLVerifyMode: TypeAlias = Literal["none", "optional", "required"] + +SYM_STAR: Final[bytes] +SYM_DOLLAR: Final[bytes] +SYM_CRLF: Final[bytes] +SYM_LF: Final[bytes] +SYM_EMPTY: Final[bytes] + +SERVER_CLOSED_CONNECTION_ERROR: Final[str] class _Sentinel(enum.Enum): - sentinel: Any + sentinel: object -SENTINEL: Any -MODULE_LOAD_ERROR: str -NO_SUCH_MODULE_ERROR: str -MODULE_UNLOAD_NOT_POSSIBLE_ERROR: str -MODULE_EXPORTS_DATA_TYPES_ERROR: str +SENTINEL: Final[object] +MODULE_LOAD_ERROR: Final[str] +NO_SUCH_MODULE_ERROR: Final[str] +MODULE_UNLOAD_NOT_POSSIBLE_ERROR: Final[str] +MODULE_EXPORTS_DATA_TYPES_ERROR: Final[str] +NO_AUTH_SET_ERROR: Final[dict[str, type[AuthenticationError]]] class Encoder: - encoding: Any - encoding_errors: Any - decode_responses: Any + encoding: str + encoding_errors: str + decode_responses: bool def __init__(self, encoding: str, encoding_errors: str, decode_responses: bool) -> None: ... def encode(self, value: EncodableT) -> EncodedT: ... def decode(self, value: EncodableT, force: bool = False) -> EncodableT: ... @@ -44,21 +48,28 @@ class BaseParser: def __init__(self, socket_read_size: int) -> None: ... @classmethod def parse_error(cls, response: str) -> ResponseError: ... + @abstractmethod def on_disconnect(self) -> None: ... - def on_connect(self, connection: Connection): ... + @abstractmethod + def on_connect(self, connection: AbstractConnection) -> None: ... + @abstractmethod + async def can_read_destructive(self) -> bool: ... + @abstractmethod async def read_response(self, disable_decoding: bool = False) -> EncodableT | ResponseError | list[EncodableT] | None: ... class PythonParser(BaseParser): - encoder: Any + encoder: Encoder | None def __init__(self, socket_read_size: int) -> None: ... - def on_connect(self, connection: Connection): ... + def on_connect(self, connection: AbstractConnection) -> None: ... def on_disconnect(self) -> None: ... + async def can_read_destructive(self) -> bool: ... async def read_response(self, disable_decoding: bool = False) -> EncodableT | ResponseError | None: ... class HiredisParser(BaseParser): def __init__(self, socket_read_size: int) -> None: ... - def on_connect(self, connection: Connection): ... + def on_connect(self, connection: AbstractConnection) -> None: ... def on_disconnect(self) -> None: ... + async def can_read_destructive(self) -> bool: ... async def read_from_socket(self) -> Literal[True]: ... async def read_response(self, disable_decoding: bool = False) -> EncodableT | list[EncodableT]: ... @@ -72,39 +83,30 @@ class AsyncConnectCallbackProtocol(Protocol): ConnectCallbackT: TypeAlias = ConnectCallbackProtocol | AsyncConnectCallbackProtocol -class Connection: - pid: Any - host: Any - port: Any - db: Any - username: Any - client_name: Any - password: Any +class AbstractConnection: + pid: int + db: str | int + client_name: str | None + credential_provider: CredentialProvider | None + password: str | None + username: str | None socket_timeout: float | None socket_connect_timeout: float | None - socket_keepalive: Any - socket_keepalive_options: Any - socket_type: Any - retry_on_timeout: Any - retry_on_error: list[type[RedisError]] + retry_on_timeout: bool + retry_on_error: list[type[Exception]] retry: Retry - health_check_interval: Any - next_health_check: int - ssl_context: Any - encoder: Any + health_check_interval: float + next_health_check: float + encoder: Encoder redis_connect_func: ConnectCallbackT | None + def __init__( self, *, - host: str = "localhost", - port: str | int = 6379, db: str | int = 0, password: str | None = None, socket_timeout: float | None = None, socket_connect_timeout: float | None = None, - socket_keepalive: bool = False, - socket_keepalive_options: Mapping[int, int | bytes] | None = None, - socket_type: int = 0, retry_on_timeout: bool = False, retry_on_error: list[type[RedisError]] | _Sentinel = ..., encoding: str = "utf-8", @@ -120,67 +122,125 @@ class Connection: encoder_class: type[Encoder] = ..., credential_provider: CredentialProvider | None = None, ) -> None: ... - def repr_pieces(self): ... + @abstractmethod + def repr_pieces(self) -> list[tuple[str, Any]]: ... @property - def is_connected(self): ... - def register_connect_callback(self, callback) -> None: ... + def is_connected(self) -> bool: ... + def register_connect_callback(self, callback: ConnectCallbackT) -> None: ... def clear_connect_callbacks(self) -> None: ... - def set_parser(self, parser_class) -> None: ... + def set_parser(self, parser_class: type[BaseParser]) -> None: ... async def connect(self) -> None: ... async def on_connect(self) -> None: ... async def disconnect(self, nowait: bool = False) -> None: ... async def check_health(self) -> None: ... - async def send_packed_command(self, command: bytes | str | Iterable[bytes], check_health: bool = True): ... - async def send_command(self, *args, **kwargs) -> None: ... - @overload - async def read_response(self, *, timeout: float, disconnect_on_error: bool = True) -> Incomplete | None: ... - @overload + async def send_packed_command(self, command: bytes | str | Iterable[bytes], check_health: bool = True) -> None: ... + async def send_command(self, *args: Any, **kwargs: Any) -> None: ... + async def can_read_destructive(self) -> bool: ... async def read_response( - self, disable_decoding: bool, timeout: float, *, disconnect_on_error: bool = True - ) -> Incomplete | None: ... - @overload - async def read_response(self, disable_decoding: bool = False, timeout: None = None, *, disconnect_on_error: bool = True): ... + self, disable_decoding: bool = False, timeout: float | None = None, *, disconnect_on_error: bool = True + ) -> EncodableT | list[EncodableT] | None: ... def pack_command(self, *args: EncodableT) -> list[bytes]: ... def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> list[bytes]: ... +class Connection(AbstractConnection): + host: str + port: int + socket_keepalive: bool + socket_keepalive_options: Mapping[int, int | bytes] | None + socket_type: int + + def __init__( + self, + *, + host: str = "localhost", + port: str | int = 6379, + socket_keepalive: bool = False, + socket_keepalive_options: Mapping[int, int | bytes] | None = None, + socket_type: int = 0, + # **kwargs forwarded to AbstractConnection. + db: str | int = 0, + password: str | None = None, + socket_timeout: float | None = None, + socket_connect_timeout: float | None = None, + retry_on_timeout: bool = False, + retry_on_error: list[type[RedisError]] | _Sentinel = ..., + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + parser_class: type[BaseParser] = ..., + socket_read_size: int = 65536, + health_check_interval: float = 0, + client_name: str | None = None, + username: str | None = None, + retry: Retry | None = None, + redis_connect_func: ConnectCallbackT | None = None, + encoder_class: type[Encoder] = ..., + credential_provider: CredentialProvider | None = None, + ) -> None: ... + def repr_pieces(self) -> list[tuple[str, Any]]: ... + class SSLConnection(Connection): - ssl_context: Any + ssl_context: RedisSSLContext def __init__( self, ssl_keyfile: str | None = None, ssl_certfile: str | None = None, - ssl_cert_reqs: str = "required", + ssl_cert_reqs: _SSLVerifyMode = "required", ssl_ca_certs: str | None = None, ssl_ca_data: str | None = None, ssl_check_hostname: bool = False, - **kwargs, + *, + # **kwargs forwarded to Connection. + host: str = "localhost", + port: str | int = 6379, + socket_keepalive: bool = False, + socket_keepalive_options: Mapping[int, int | bytes] | None = None, + socket_type: int = 0, + db: str | int = 0, + password: str | None = None, + socket_timeout: float | None = None, + socket_connect_timeout: float | None = None, + retry_on_timeout: bool = False, + retry_on_error: list[type[RedisError]] | _Sentinel = ..., + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + parser_class: type[BaseParser] = ..., + socket_read_size: int = 65536, + health_check_interval: float = 0, + client_name: str | None = None, + username: str | None = None, + retry: Retry | None = None, + redis_connect_func: ConnectCallbackT | None = None, + encoder_class: type[Encoder] = ..., + credential_provider: CredentialProvider | None = None, ) -> None: ... @property - def keyfile(self): ... + def keyfile(self) -> str | None: ... @property - def certfile(self): ... + def certfile(self) -> str | None: ... @property - def cert_reqs(self): ... + def cert_reqs(self) -> ssl.VerifyMode: ... @property - def ca_certs(self): ... + def ca_certs(self) -> str | None: ... @property - def ca_data(self): ... + def ca_data(self) -> str | None: ... @property - def check_hostname(self): ... + def check_hostname(self) -> bool: ... class RedisSSLContext: - keyfile: Any - certfile: Any - cert_reqs: Any - ca_certs: Any - ca_data: Any - check_hostname: Any - context: Any + keyfile: str | None + certfile: str | None + cert_reqs: ssl.VerifyMode + ca_certs: str | None + ca_data: str | None + check_hostname: bool + context: ssl.SSLContext | None def __init__( self, keyfile: str | None = None, certfile: str | None = None, - cert_reqs: str | None = None, + cert_reqs: _SSLVerifyMode | None = None, ca_certs: str | None = None, ca_data: str | None = None, check_hostname: bool = False, @@ -188,53 +248,43 @@ class RedisSSLContext: def get(self) -> ssl.SSLContext: ... class UnixDomainSocketConnection(Connection): - pid: Any - path: Any - db: Any - username: Any - client_name: Any - password: Any - retry_on_timeout: Any - retry_on_error: list[type[RedisError]] - retry: Any - health_check_interval: Any - next_health_check: int - redis_connect_func: ConnectCallbackT | None - encoder: Any + path: str def __init__( self, *, path: str = "", + # **kwargs forwarded to AbstractConnection. db: str | int = 0, - username: str | None = None, password: str | None = None, socket_timeout: float | None = None, socket_connect_timeout: float | None = None, + retry_on_timeout: bool = False, + retry_on_error: list[type[RedisError]] | _Sentinel = ..., encoding: str = "utf-8", encoding_errors: str = "strict", decode_responses: bool = False, - retry_on_timeout: bool = False, - retry_on_error: list[type[RedisError]] | _Sentinel = ..., parser_class: type[BaseParser] = ..., socket_read_size: int = 65536, - health_check_interval: float = 0.0, + health_check_interval: float = 0, client_name: str | None = None, + username: str | None = None, retry: Retry | None = None, redis_connect_func: ConnectCallbackT | None = None, + encoder_class: type[Encoder] = ..., credential_provider: CredentialProvider | None = None, ) -> None: ... - def repr_pieces(self) -> Iterable[tuple[str, str | int]]: ... + def repr_pieces(self) -> list[tuple[str, Any]]: ... -FALSE_STRINGS: Any +FALSE_STRINGS: Final[tuple[str, ...]] -def to_bool(value) -> bool | None: ... +def to_bool(value: object) -> bool | None: ... -URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] +URL_QUERY_ARGUMENT_PARSERS: MappingProxyType[str, Callable[[str], Any]] class ConnectKwargs(TypedDict): username: str password: str - connection_class: type[Connection] + connection_class: type[AbstractConnection] host: str port: int db: int @@ -242,40 +292,72 @@ class ConnectKwargs(TypedDict): def parse_url(url: str) -> ConnectKwargs: ... -class ConnectionPool: +_ConnectionT = TypeVar("_ConnectionT", bound=AbstractConnection) + +class ConnectionPool(Generic[_ConnectionT]): + # kwargs accepts all arguments from the connection class chosen for + # the given URL, except those encoded in the URL itself. @classmethod - def from_url(cls, url: str, **kwargs) -> ConnectionPool: ... - connection_class: Any - connection_kwargs: Any - max_connections: Any - encoder_class: Any + def from_url(cls, url: str, **kwargs: Any) -> Self: ... + + connection_class: type[_ConnectionT] + connection_kwargs: Mapping[str, Any] + max_connections: int + encoder_class: type[Encoder] + pid: int + + @overload def __init__( - self, connection_class: type[Connection] = ..., max_connections: int | None = None, **connection_kwargs + self: ConnectionPool[_ConnectionT], + connection_class: type[_ConnectionT], + max_connections: int | None = None, + # **kwargs are passed to the constructed connection instances. + **connection_kwargs: Any, ) -> None: ... - pid: Any + @overload + def __init__(self: ConnectionPool[Connection], *, max_connections: int | None = None, **connection_kwargs) -> None: ... def reset(self) -> None: ... - async def get_connection(self, command_name, *keys, **options): ... - def get_encoder(self): ... - def make_connection(self): ... - async def release(self, connection: Connection): ... - def owns_connection(self, connection: Connection): ... - async def disconnect(self, inuse_connections: bool = True): ... - -class BlockingConnectionPool(ConnectionPool): - queue_class: Any - timeout: Any + async def get_connection(self, command_name: Unused, *keys: Unused, **options: Unused) -> _ConnectionT: ... + def get_encoder(self) -> Encoder: ... + def make_connection(self) -> _ConnectionT: ... + async def release(self, connection: AbstractConnection) -> None: ... + def owns_connection(self, connection: AbstractConnection) -> bool: ... + async def disconnect(self, inuse_connections: bool = True) -> None: ... + def set_retry(self, retry: Retry) -> None: ... + +class BlockingConnectionPool(ConnectionPool[_ConnectionT]): + queue_class: type[asyncio.Queue[_ConnectionT | None]] + timeout: int | None + pool: asyncio.Queue[_ConnectionT | None] + + @overload def __init__( - self, + self: BlockingConnectionPool[_ConnectionT], + max_connections: int, + timeout: int | None, + connection_class: type[_ConnectionT], + queue_class: type[asyncio.Queue[_ConnectionT | None]] = ..., + # **kwargs are passed to the constructed connection instances. + **connection_kwargs: Any, + ) -> None: ... + @overload + def __init__( + self: BlockingConnectionPool[_ConnectionT], max_connections: int = 50, timeout: int | None = 20, - connection_class: type[Connection] = ..., - queue_class: type[asyncio.Queue[Any]] = ..., - **connection_kwargs, + *, + connection_class: type[_ConnectionT], + queue_class: type[asyncio.Queue[_ConnectionT | None]] = ..., + # **kwargs are passed to the constructed connection instances. + **connection_kwargs: Any, + ) -> None: ... + @overload + def __init__( + self: BlockingConnectionPool[Connection], + max_connections: int = 50, + timeout: int | None = 20, + *, + queue_class: type[asyncio.Queue[Connection | None]] = ..., + # **kwargs are passed to the constructed connection instances. + **connection_kwargs: Any, ) -> None: ... - pool: Any - pid: Any - def reset(self) -> None: ... - def make_connection(self): ... - async def get_connection(self, command_name, *keys, **options): ... - async def release(self, connection: Connection): ... - async def disconnect(self, inuse_connections: bool = True): ... diff --git a/stubs/redis/redis/asyncio/sentinel.pyi b/stubs/redis/redis/asyncio/sentinel.pyi index a25a30cbdbce..a50ebdee57c3 100644 --- a/stubs/redis/redis/asyncio/sentinel.pyi +++ b/stubs/redis/redis/asyncio/sentinel.pyi @@ -1,69 +1,162 @@ -from _typeshed import Incomplete -from collections.abc import AsyncIterator, Iterable, Mapping, Sequence -from typing import Any, overload +from collections.abc import AsyncIterator, Iterable, Mapping +from typing import Any, Literal, TypedDict, TypeVar, overload from redis.asyncio.client import Redis -from redis.asyncio.connection import Connection, ConnectionPool, SSLConnection +from redis.asyncio.connection import ( + BaseParser, + ConnectCallbackT, + Connection, + ConnectionPool, + Encoder, + SSLConnection, + _ConnectionT, + _Sentinel, +) +from redis.asyncio.retry import Retry from redis.commands import AsyncSentinelCommands -from redis.exceptions import ConnectionError -from redis.typing import EncodableT +from redis.credentials import CredentialProvider +from redis.exceptions import ConnectionError, RedisError + +_RedisT = TypeVar("_RedisT", bound=Redis[Any]) class MasterNotFoundError(ConnectionError): ... class SlaveNotFoundError(ConnectionError): ... class SentinelManagedConnection(Connection): - connection_pool: Any - def __init__(self, **kwargs) -> None: ... - async def connect_to(self, address) -> None: ... - async def connect(self): ... - @overload - async def read_response(self, *, timeout: float, disconnect_on_error: bool = True) -> Incomplete | None: ... - @overload - async def read_response( - self, disable_decoding: bool, timeout: float, *, disconnect_on_error: bool = True - ) -> Incomplete | None: ... - @overload - async def read_response(self, disable_decoding: bool = False, timeout: None = None, *, disconnect_on_error: bool = True): ... + connection_pool: ConnectionPool[Any] | None + def __init__( + self, + *, + connection_pool: ConnectionPool[Any] | None, + # **kwargs forwarded to Connection. + host: str = "localhost", + port: str | int = 6379, + socket_keepalive: bool = False, + socket_keepalive_options: Mapping[int, int | bytes] | None = None, + socket_type: int = 0, + db: str | int = 0, + password: str | None = None, + socket_timeout: float | None = None, + socket_connect_timeout: float | None = None, + retry_on_timeout: bool = False, + retry_on_error: list[type[RedisError]] | _Sentinel = ..., + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + parser_class: type[BaseParser] = ..., + socket_read_size: int = 65536, + health_check_interval: float = 0, + client_name: str | None = None, + username: str | None = None, + retry: Retry | None = None, + redis_connect_func: ConnectCallbackT | None = None, + encoder_class: type[Encoder] = ..., + credential_provider: CredentialProvider | None = None, + ) -> None: ... + async def connect_to(self, address: tuple[str, int]) -> None: ... + async def connect(self) -> None: ... class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection): ... -class SentinelConnectionPool(ConnectionPool): - is_master: Any - check_connection: Any - service_name: Any - sentinel_manager: Any - master_address: Any - slave_rr_counter: Any - def __init__(self, service_name, sentinel_manager, **kwargs) -> None: ... - def reset(self) -> None: ... - def owns_connection(self, connection: Connection): ... - async def get_master_address(self): ... - async def rotate_slaves(self) -> AsyncIterator[Any]: ... +class SentinelConnectionPool(ConnectionPool[_ConnectionT]): + is_master: bool + check_connection: bool + service_name: str + sentinel_manager: Sentinel + master_address: tuple[str, int] | None + slave_rr_counter: int | None + + def __init__( + self, + service_name: str, + sentinel_manager: Sentinel, + *, + ssl: bool = False, + connection_class: type[SentinelManagedConnection] = ..., + is_master: bool = True, + check_connection: bool = False, + # **kwargs ultimately forwarded to construction Connection instances. + host: str = "localhost", + port: str | int = 6379, + socket_keepalive: bool = False, + socket_keepalive_options: Mapping[int, int | bytes] | None = None, + socket_type: int = 0, + db: str | int = 0, + password: str | None = None, + socket_timeout: float | None = None, + socket_connect_timeout: float | None = None, + retry_on_timeout: bool = False, + retry_on_error: list[type[RedisError]] | _Sentinel = ..., + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + parser_class: type[BaseParser] = ..., + socket_read_size: int = 65536, + health_check_interval: float = 0, + client_name: str | None = None, + username: str | None = None, + retry: Retry | None = None, + redis_connect_func: ConnectCallbackT | None = None, + encoder_class: type[Encoder] = ..., + credential_provider: CredentialProvider | None = None, + ) -> None: ... + async def get_master_address(self) -> tuple[str, int]: ... + async def rotate_slaves(self) -> AsyncIterator[tuple[str, int]]: ... + +_State = TypedDict( + "_State", {"ip": str, "port": int, "is_master": bool, "is_sdown": bool, "is_odown": bool, "num-other-sentinels": int} +) class Sentinel(AsyncSentinelCommands): - sentinel_kwargs: Any - sentinels: Any - min_other_sentinels: Any - connection_kwargs: Any + sentinel_kwargs: Mapping[str, Any] + sentinels: list[Redis[Any]] + min_other_sentinels: int + connection_kwargs: Mapping[str, Any] def __init__( - self, sentinels, min_other_sentinels: int = 0, sentinel_kwargs: Incomplete | None = None, **connection_kwargs + self, + sentinels: Iterable[tuple[str, int]], + min_other_sentinels: int = 0, + sentinel_kwargs: Mapping[str, Any] | None = None, + **connection_kwargs: Any, ) -> None: ... - async def execute_command(self, *args, **kwargs): ... - def check_master_state(self, state: dict[Any, Any], service_name: str) -> bool: ... - async def discover_master(self, service_name: str): ... - def filter_slaves(self, slaves: Iterable[Mapping[Any, Any]]) -> Sequence[tuple[EncodableT, EncodableT]]: ... - async def discover_slaves(self, service_name: str) -> Sequence[tuple[EncodableT, EncodableT]]: ... + async def execute_command(self, *args: Any, once: bool = False, **kwargs: Any) -> Literal[True]: ... + def check_master_state(self, state: _State, service_name: str) -> bool: ... + async def discover_master(self, service_name: str) -> tuple[str, int]: ... + def filter_slaves(self, slaves: Iterable[_State]) -> list[tuple[str, int]]: ... + async def discover_slaves(self, service_name: str) -> list[tuple[str, int]]: ... + @overload + def master_for( + self, + service_name: str, + redis_class: type[_RedisT], + connection_pool_class: type[SentinelConnectionPool[Any]] = ..., + # Forwarded to the connection pool constructor. + **kwargs: Any, + ) -> _RedisT: ... + @overload def master_for( self, service_name: str, - redis_class: type[Redis[Any]] = ..., - connection_pool_class: type[SentinelConnectionPool] = ..., - **kwargs, - ): ... + *, + connection_pool_class: type[SentinelConnectionPool[Any]] = ..., + # Forwarded to the connection pool constructor. + **kwargs: Any, + ) -> Redis[Any]: ... + @overload + def slave_for( + self, + service_name: str, + redis_class: type[_RedisT], + connection_pool_class: type[SentinelConnectionPool[Any]] = ..., + # Forwarded to the connection pool constructor. + **kwargs: Any, + ) -> _RedisT: ... + @overload def slave_for( self, service_name: str, - redis_class: type[Redis[Any]] = ..., - connection_pool_class: type[SentinelConnectionPool] = ..., - **kwargs, - ): ... + *, + connection_pool_class: type[SentinelConnectionPool[Any]] = ..., + # Forwarded to the connection pool constructor. + **kwargs: Any, + ) -> Redis[Any]: ... diff --git a/stubs/redis/redis/typing.pyi b/stubs/redis/redis/typing.pyi index f351ed45ac76..f5eb13f831a2 100644 --- a/stubs/redis/redis/typing.pyi +++ b/stubs/redis/redis/typing.pyi @@ -1,6 +1,6 @@ from collections.abc import Iterable from datetime import datetime, timedelta -from typing import Protocol, TypeVar +from typing import Any, Protocol, TypeVar from typing_extensions import TypeAlias from redis.asyncio.connection import ConnectionPool as AsyncConnectionPool @@ -30,5 +30,5 @@ AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview) # noqa: Y001 AnyChannelT = TypeVar("AnyChannelT", bytes, str, memoryview) # noqa: Y001 class CommandsProtocol(Protocol): - connection_pool: AsyncConnectionPool | ConnectionPool + connection_pool: AsyncConnectionPool[Any] | ConnectionPool def execute_command(self, *args, **options): ...