diff --git a/ydb/aio/pool.py b/ydb/aio/pool.py index fd4e8016..c637a7ca 100644 --- a/ydb/aio/pool.py +++ b/ydb/aio/pool.py @@ -1,11 +1,12 @@ import asyncio import logging import random +import typing from ydb import issues from ydb.pool import ConnectionsCache as _ConnectionsCache, IConnectionPool -from .connection import Connection +from .connection import Connection, EndpointKey from . import resolver @@ -21,7 +22,7 @@ def __init__(self, use_all_nodes: bool = False): self._fast_fail_error = None - async def get(self, preferred_endpoint=None, fast_fail=False, wait_timeout=10): + async def get(self, preferred_endpoint: typing.Optional[EndpointKey] = None, fast_fail=False, wait_timeout=10): if fast_fail: await asyncio.wait_for(self._fast_fail_event.wait(), timeout=wait_timeout) diff --git a/ydb/pool.py b/ydb/pool.py index 1f33cf8a..1e75950e 100644 --- a/ydb/pool.py +++ b/ydb/pool.py @@ -5,11 +5,12 @@ from concurrent import futures import collections import random +import typing from . import connection as connection_impl, issues, resolver, _utilities, tracing from abc import abstractmethod -from .connection import Connection +from .connection import Connection, EndpointKey logger = logging.getLogger(__name__) @@ -123,7 +124,7 @@ def subscribe(self): return subscription @tracing.with_trace() - def get(self, preferred_endpoint=None) -> Connection: + def get(self, preferred_endpoint: typing.Optional[EndpointKey] = None) -> Connection: with self.lock: if preferred_endpoint is not None and preferred_endpoint.node_id in self.connections_by_node_id: return self.connections_by_node_id[preferred_endpoint.node_id]