diff --git a/ydb/resolver.py b/ydb/resolver.py index d4fb1aff..b795af92 100644 --- a/ydb/resolver.py +++ b/ydb/resolver.py @@ -1,10 +1,21 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import contextlib import logging import threading import random import itertools -from . import connection as conn_impl, issues, settings as settings_impl, _apis +import typing +from . import connection as conn_impl, driver, issues, settings as settings_impl, _apis + + +# Workaround for good IDE and universal for runtime +if typing.TYPE_CHECKING: + from ._grpc.v4.protos import ydb_discovery_pb2 +else: + from ._grpc.common.protos import ydb_discovery_pb2 + logger = logging.getLogger(__name__) @@ -22,7 +33,7 @@ class EndpointInfo(object): "node_id", ) - def __init__(self, endpoint_info): + def __init__(self, endpoint_info: ydb_discovery_pb2.EndpointInfo): self.address = endpoint_info.address self.endpoint = "%s:%s" % (endpoint_info.address, endpoint_info.port) self.location = endpoint_info.location @@ -33,7 +44,7 @@ def __init__(self, endpoint_info): self.ssl_target_name_override = endpoint_info.ssl_target_name_override self.node_id = endpoint_info.node_id - def endpoints_with_options(self): + def endpoints_with_options(self) -> typing.Generator[typing.Tuple[str, conn_impl.EndpointOptions], None, None]: ssl_target_name_override = None if self.ssl: if self.ssl_target_name_override: @@ -73,14 +84,14 @@ def __eq__(self, other): return self.endpoint == other.endpoint -def _list_endpoints_request_factory(connection_params): +def _list_endpoints_request_factory(connection_params: driver.DriverConfig) -> _apis.ydb_discovery.ListEndpointsRequest: request = _apis.ydb_discovery.ListEndpointsRequest() request.database = connection_params.database return request class DiscoveryResult(object): - def __init__(self, self_location, endpoints): + def __init__(self, self_location: str, endpoints: "list[EndpointInfo]"): self.self_location = self_location self.endpoints = endpoints @@ -94,7 +105,12 @@ def __repr__(self): return self.__str__() @classmethod - def from_response(cls, rpc_state, response, use_all_nodes=False): + def from_response( + cls, + rpc_state: conn_impl._RpcState, + response: ydb_discovery_pb2.ListEndpointsResponse, + use_all_nodes: bool = False, + ) -> DiscoveryResult: issues._process_response(response.operation) message = _apis.ydb_discovery.ListEndpointsResult() response.operation.result.Unpack(message) @@ -123,7 +139,7 @@ def from_response(cls, rpc_state, response, use_all_nodes=False): class DiscoveryEndpointsResolver(object): - def __init__(self, driver_config): + def __init__(self, driver_config: driver.DriverConfig): self.logger = logger.getChild(self.__class__.__name__) self._driver_config = driver_config self._ready_timeout = getattr(self._driver_config, "discovery_request_timeout", 10) @@ -136,7 +152,7 @@ def __init__(self, driver_config): random.shuffle(self._endpoints) self._endpoints_iter = itertools.cycle(self._endpoints) - def _add_debug_details(self, message, *args): + def _add_debug_details(self, message: str, *args): self.logger.debug(message, *args) message = message % args with self._lock: @@ -144,19 +160,19 @@ def _add_debug_details(self, message, *args): if len(self._debug_details_items) > self._debug_details_history_size: self._debug_details_items.pop() - def debug_details(self): + def debug_details(self) -> str: """ Returns last resolver errors as a debug string. """ with self._lock: return "\n".join(self._debug_details_items) - def resolve(self): + def resolve(self) -> typing.ContextManager[typing.Optional[DiscoveryResult]]: with self.context_resolve() as result: return result @contextlib.contextmanager - def context_resolve(self): + def context_resolve(self) -> typing.ContextManager[typing.Optional[DiscoveryResult]]: self.logger.debug("Preparing initial endpoint to resolve endpoints") endpoint = next(self._endpoints_iter) initial = conn_impl.Connection.ready_factory(endpoint, self._driver_config, ready_timeout=self._ready_timeout)